tcp_proxy.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. """TLS proxy for slicer-to-printer communication.
  2. This module provides a TLS terminating proxy that forwards data between
  3. a slicer and a real Bambu printer, enabling remote printing over
  4. any network connection.
  5. Unlike a transparent TCP proxy, this terminates TLS on both ends:
  6. - Slicer connects to Bambuddy using Bambuddy's certificate
  7. - Bambuddy connects to printer using printer's certificate
  8. - Data is decrypted, forwarded, and re-encrypted
  9. """
  import asyncio
  import errno
  import logging
  import ssl
  from collections.abc import Callable
  from pathlib import Path
  # Module-level logger used by the proxy classes in this module
  logger: logging.Logger = logging.getLogger(name=__name__)
  16. class TLSProxy:
  17. """TLS terminating proxy that forwards data between client and target.
  18. This proxy terminates TLS on both ends, allowing the slicer to connect
  19. to Bambuddy's certificate while Bambuddy connects to the real printer.
  20. """
  21. def __init__(
  22. self,
  23. name: str,
  24. listen_port: int,
  25. target_host: str,
  26. target_port: int,
  27. server_cert_path: Path,
  28. server_key_path: Path,
  29. on_connect: Callable[[str], None] | None = None,
  30. on_disconnect: Callable[[str], None] | None = None,
  31. ):
  32. """Initialize the TLS proxy.
  33. Args:
  34. name: Friendly name for logging (e.g., "FTP", "MQTT")
  35. listen_port: Port to listen on for incoming connections
  36. target_host: Target printer IP/hostname
  37. target_port: Target printer port
  38. server_cert_path: Path to server certificate (for accepting slicer connections)
  39. server_key_path: Path to server private key
  40. on_connect: Optional callback when client connects (receives client_id)
  41. on_disconnect: Optional callback when client disconnects (receives client_id)
  42. """
  43. self.name = name
  44. self.listen_port = listen_port
  45. self.target_host = target_host
  46. self.target_port = target_port
  47. self.server_cert_path = server_cert_path
  48. self.server_key_path = server_key_path
  49. self.on_connect = on_connect
  50. self.on_disconnect = on_disconnect
  51. self._server: asyncio.Server | None = None
  52. self._running = False
  53. self._active_connections: dict[str, tuple[asyncio.Task, asyncio.Task]] = {}
  54. self._server_ssl_context: ssl.SSLContext | None = None
  55. self._client_ssl_context: ssl.SSLContext | None = None
  56. def _create_server_ssl_context(self) -> ssl.SSLContext:
  57. """Create SSL context for accepting client (slicer) connections."""
  58. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  59. ctx.load_cert_chain(self.server_cert_path, self.server_key_path)
  60. # Allow older TLS versions for compatibility with slicers
  61. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  62. # Don't require client certificates
  63. ctx.verify_mode = ssl.CERT_NONE
  64. return ctx
  65. def _create_client_ssl_context(self) -> ssl.SSLContext:
  66. """Create SSL context for connecting to printer."""
  67. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  68. # Don't verify printer's certificate (self-signed)
  69. ctx.check_hostname = False
  70. ctx.verify_mode = ssl.CERT_NONE
  71. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  72. return ctx
  73. async def start(self) -> None:
  74. """Start the TLS proxy server."""
  75. if self._running:
  76. return
  77. logger.info(
  78. f"Starting {self.name} TLS proxy: 0.0.0.0:{self.listen_port} → {self.target_host}:{self.target_port}"
  79. )
  80. try:
  81. self._running = True
  82. # Create SSL contexts
  83. self._server_ssl_context = self._create_server_ssl_context()
  84. self._client_ssl_context = self._create_client_ssl_context()
  85. # Start server with TLS
  86. self._server = await asyncio.start_server(
  87. self._handle_client,
  88. "0.0.0.0", # nosec B104
  89. self.listen_port,
  90. ssl=self._server_ssl_context,
  91. )
  92. logger.info("%s TLS proxy listening on port %s", self.name, self.listen_port)
  93. async with self._server:
  94. await self._server.serve_forever()
  95. except OSError as e:
  96. if e.errno == 98: # Address already in use
  97. logger.error("%s proxy port %s is already in use", self.name, self.listen_port)
  98. else:
  99. logger.error("%s proxy error: %s", self.name, e)
  100. except asyncio.CancelledError:
  101. logger.debug("%s proxy task cancelled", self.name)
  102. except Exception as e:
  103. logger.error("%s proxy error: %s", self.name, e)
  104. finally:
  105. await self.stop()
  106. async def stop(self) -> None:
  107. """Stop the TLS proxy server."""
  108. logger.info("Stopping %s proxy", self.name)
  109. self._running = False
  110. # Cancel all active connection tasks
  111. for client_id, (task1, task2) in list(self._active_connections.items()):
  112. task1.cancel()
  113. task2.cancel()
  114. if self.on_disconnect:
  115. try:
  116. self.on_disconnect(client_id)
  117. except Exception:
  118. pass # Ignore disconnect callback errors during shutdown
  119. self._active_connections.clear()
  120. if self._server:
  121. try:
  122. self._server.close()
  123. await self._server.wait_closed()
  124. except OSError as e:
  125. logger.debug("Error closing %s proxy server: %s", self.name, e)
  126. self._server = None
  127. async def _handle_client(
  128. self,
  129. client_reader: asyncio.StreamReader,
  130. client_writer: asyncio.StreamWriter,
  131. ) -> None:
  132. """Handle a new client connection by proxying to target."""
  133. peername = client_writer.get_extra_info("peername")
  134. client_id = f"{peername[0]}:{peername[1]}" if peername else "unknown"
  135. logger.info("%s proxy: client connected from %s", self.name, client_id)
  136. if self.on_connect:
  137. try:
  138. self.on_connect(client_id)
  139. except Exception:
  140. pass # Ignore connect callback errors; connection proceeds regardless
  141. # Connect to target printer with TLS
  142. try:
  143. printer_reader, printer_writer = await asyncio.wait_for(
  144. asyncio.open_connection(
  145. self.target_host,
  146. self.target_port,
  147. ssl=self._client_ssl_context,
  148. ),
  149. timeout=10.0,
  150. )
  151. logger.info("%s proxy: connected to printer %s:%s", self.name, self.target_host, self.target_port)
  152. except TimeoutError:
  153. logger.error("%s proxy: timeout connecting to %s:%s", self.name, self.target_host, self.target_port)
  154. client_writer.close()
  155. await client_writer.wait_closed()
  156. return
  157. except ssl.SSLError as e:
  158. logger.error(
  159. "%s proxy: SSL error connecting to %s:%s: %s", self.name, self.target_host, self.target_port, e
  160. )
  161. client_writer.close()
  162. await client_writer.wait_closed()
  163. return
  164. except OSError as e:
  165. logger.error("%s proxy: failed to connect to %s:%s: %s", self.name, self.target_host, self.target_port, e)
  166. client_writer.close()
  167. await client_writer.wait_closed()
  168. return
  169. # Create bidirectional forwarding tasks
  170. client_to_printer = asyncio.create_task(
  171. self._forward(client_reader, printer_writer, f"{client_id}→printer"),
  172. name=f"{self.name}_c2p_{client_id}",
  173. )
  174. printer_to_client = asyncio.create_task(
  175. self._forward(printer_reader, client_writer, f"printer→{client_id}"),
  176. name=f"{self.name}_p2c_{client_id}",
  177. )
  178. self._active_connections[client_id] = (client_to_printer, printer_to_client)
  179. try:
  180. # Wait for either direction to complete (connection closed)
  181. done, pending = await asyncio.wait(
  182. [client_to_printer, printer_to_client],
  183. return_when=asyncio.FIRST_COMPLETED,
  184. )
  185. # Cancel the other direction
  186. for task in pending:
  187. task.cancel()
  188. try:
  189. await task
  190. except asyncio.CancelledError:
  191. pass # Expected when cancelling the other forwarding direction
  192. except Exception as e:
  193. logger.debug("%s proxy connection error: %s", self.name, e)
  194. finally:
  195. # Clean up
  196. self._active_connections.pop(client_id, None)
  197. for writer in [client_writer, printer_writer]:
  198. try:
  199. writer.close()
  200. await writer.wait_closed()
  201. except OSError:
  202. pass # Best-effort connection cleanup; peer may have disconnected
  203. logger.info("%s proxy: client %s disconnected", self.name, client_id)
  204. if self.on_disconnect:
  205. try:
  206. self.on_disconnect(client_id)
  207. except Exception:
  208. pass # Ignore disconnect callback errors; cleanup continues
  209. async def _forward(
  210. self,
  211. reader: asyncio.StreamReader,
  212. writer: asyncio.StreamWriter,
  213. direction: str,
  214. ) -> None:
  215. """Forward data from reader to writer.
  216. Args:
  217. reader: Source stream (already TLS-decrypted)
  218. writer: Destination stream (will be TLS-encrypted by the stream)
  219. direction: Description for logging (e.g., "client→printer")
  220. """
  221. total_bytes = 0
  222. try:
  223. while self._running:
  224. # Read chunk - use reasonable buffer size
  225. data = await reader.read(65536)
  226. if not data:
  227. # Connection closed
  228. break
  229. # Forward to destination
  230. writer.write(data)
  231. await writer.drain()
  232. total_bytes += len(data)
  233. logger.debug("%s proxy %s: %s bytes", self.name, direction, len(data))
  234. except asyncio.CancelledError:
  235. pass # Expected when the other forwarding direction closes first
  236. except ConnectionResetError:
  237. logger.debug("%s proxy %s: connection reset", self.name, direction)
  238. except BrokenPipeError:
  239. logger.debug("%s proxy %s: broken pipe", self.name, direction)
  240. except OSError as e:
  241. logger.debug("%s proxy %s error: %s", self.name, direction, e)
  242. logger.debug("%s proxy %s: total %s bytes", self.name, direction, total_bytes)
  243. class SlicerProxyManager:
  244. """Manages FTP and MQTT TLS proxies for a single printer target."""
  245. # Bambu printer ports
  246. PRINTER_FTP_PORT = 990
  247. PRINTER_MQTT_PORT = 8883
  248. # Local listen ports - must match what Bambu Studio expects
  249. # Note: Port 990 requires root or CAP_NET_BIND_SERVICE capability
  250. LOCAL_FTP_PORT = 990
  251. LOCAL_MQTT_PORT = 8883
  252. def __init__(
  253. self,
  254. target_host: str,
  255. cert_path: Path,
  256. key_path: Path,
  257. on_activity: Callable[[str, str], None] | None = None,
  258. ):
  259. """Initialize the slicer proxy manager.
  260. Args:
  261. target_host: Target printer IP address
  262. cert_path: Path to server certificate
  263. key_path: Path to server private key
  264. on_activity: Optional callback for activity logging (name, message)
  265. """
  266. self.target_host = target_host
  267. self.cert_path = cert_path
  268. self.key_path = key_path
  269. self.on_activity = on_activity
  270. self._ftp_proxy: TLSProxy | None = None
  271. self._mqtt_proxy: TLSProxy | None = None
  272. self._tasks: list[asyncio.Task] = []
  273. async def start(self) -> None:
  274. """Start FTP and MQTT TLS proxies."""
  275. logger.info("Starting slicer TLS proxy to %s", self.target_host)
  276. # Create proxies with TLS
  277. self._ftp_proxy = TLSProxy(
  278. name="FTP",
  279. listen_port=self.LOCAL_FTP_PORT,
  280. target_host=self.target_host,
  281. target_port=self.PRINTER_FTP_PORT,
  282. server_cert_path=self.cert_path,
  283. server_key_path=self.key_path,
  284. on_connect=lambda cid: self._log_activity("FTP", f"connected: {cid}"),
  285. on_disconnect=lambda cid: self._log_activity("FTP", f"disconnected: {cid}"),
  286. )
  287. self._mqtt_proxy = TLSProxy(
  288. name="MQTT",
  289. listen_port=self.LOCAL_MQTT_PORT,
  290. target_host=self.target_host,
  291. target_port=self.PRINTER_MQTT_PORT,
  292. server_cert_path=self.cert_path,
  293. server_key_path=self.key_path,
  294. on_connect=lambda cid: self._log_activity("MQTT", f"connected: {cid}"),
  295. on_disconnect=lambda cid: self._log_activity("MQTT", f"disconnected: {cid}"),
  296. )
  297. # Start as background tasks
  298. async def run_with_logging(proxy: TLSProxy) -> None:
  299. try:
  300. await proxy.start()
  301. except Exception as e:
  302. logger.error("Slicer proxy %s failed: %s", proxy.name, e)
  303. self._tasks = [
  304. asyncio.create_task(
  305. run_with_logging(self._ftp_proxy),
  306. name="slicer_proxy_ftp",
  307. ),
  308. asyncio.create_task(
  309. run_with_logging(self._mqtt_proxy),
  310. name="slicer_proxy_mqtt",
  311. ),
  312. ]
  313. logger.info("Slicer TLS proxy started for %s", self.target_host)
  314. # Wait for tasks to complete (they run until cancelled)
  315. # This keeps the start() coroutine alive so the parent task doesn't complete
  316. try:
  317. await asyncio.gather(*self._tasks)
  318. except asyncio.CancelledError:
  319. logger.debug("Slicer proxy start cancelled")
  320. async def stop(self) -> None:
  321. """Stop all proxies."""
  322. logger.info("Stopping slicer proxy")
  323. # Stop proxies
  324. if self._ftp_proxy:
  325. await self._ftp_proxy.stop()
  326. self._ftp_proxy = None
  327. if self._mqtt_proxy:
  328. await self._mqtt_proxy.stop()
  329. self._mqtt_proxy = None
  330. # Cancel tasks
  331. for task in self._tasks:
  332. task.cancel()
  333. if self._tasks:
  334. try:
  335. await asyncio.wait_for(
  336. asyncio.gather(*self._tasks, return_exceptions=True),
  337. timeout=2.0,
  338. )
  339. except TimeoutError:
  340. logger.debug("Some proxy tasks didn't stop in time")
  341. self._tasks = []
  342. logger.info("Slicer proxy stopped")
  343. def _log_activity(self, name: str, message: str) -> None:
  344. """Log activity via callback if configured."""
  345. if self.on_activity:
  346. try:
  347. self.on_activity(name, message)
  348. except Exception:
  349. pass # Ignore activity callback errors; logging is non-critical
  350. @property
  351. def is_running(self) -> bool:
  352. """Check if proxies are running."""
  353. return len(self._tasks) > 0 and all(not t.done() for t in self._tasks)
  354. def get_status(self) -> dict:
  355. """Get proxy status."""
  356. return {
  357. "running": self.is_running,
  358. "target_host": self.target_host,
  359. "ftp_port": self.LOCAL_FTP_PORT,
  360. "mqtt_port": self.LOCAL_MQTT_PORT,
  361. "ftp_connections": (len(self._ftp_proxy._active_connections) if self._ftp_proxy else 0),
  362. "mqtt_connections": (len(self._mqtt_proxy._active_connections) if self._mqtt_proxy else 0),
  363. }