tcp_proxy.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. """TLS proxy for slicer-to-printer communication.
  2. This module provides a TLS terminating proxy that forwards data between
  3. a slicer and a real Bambu printer, enabling remote printing over
  4. any network connection.
  5. Unlike a transparent TCP proxy, this terminates TLS on both ends:
  6. - Slicer connects to Bambuddy using Bambuddy's certificate
  7. - Bambuddy connects to printer using printer's certificate
  8. - Data is decrypted, forwarded, and re-encrypted
  9. """
  10. import asyncio
  11. import logging
  12. import ssl
  13. from collections.abc import Callable
  14. from pathlib import Path
  15. logger = logging.getLogger(__name__)
  16. class TLSProxy:
  17. """TLS terminating proxy that forwards data between client and target.
  18. This proxy terminates TLS on both ends, allowing the slicer to connect
  19. to Bambuddy's certificate while Bambuddy connects to the real printer.
  20. """
  21. def __init__(
  22. self,
  23. name: str,
  24. listen_port: int,
  25. target_host: str,
  26. target_port: int,
  27. server_cert_path: Path,
  28. server_key_path: Path,
  29. on_connect: Callable[[str], None] | None = None,
  30. on_disconnect: Callable[[str], None] | None = None,
  31. ):
  32. """Initialize the TLS proxy.
  33. Args:
  34. name: Friendly name for logging (e.g., "FTP", "MQTT")
  35. listen_port: Port to listen on for incoming connections
  36. target_host: Target printer IP/hostname
  37. target_port: Target printer port
  38. server_cert_path: Path to server certificate (for accepting slicer connections)
  39. server_key_path: Path to server private key
  40. on_connect: Optional callback when client connects (receives client_id)
  41. on_disconnect: Optional callback when client disconnects (receives client_id)
  42. """
  43. self.name = name
  44. self.listen_port = listen_port
  45. self.target_host = target_host
  46. self.target_port = target_port
  47. self.server_cert_path = server_cert_path
  48. self.server_key_path = server_key_path
  49. self.on_connect = on_connect
  50. self.on_disconnect = on_disconnect
  51. self._server: asyncio.Server | None = None
  52. self._running = False
  53. self._active_connections: dict[str, tuple[asyncio.Task, asyncio.Task]] = {}
  54. self._server_ssl_context: ssl.SSLContext | None = None
  55. self._client_ssl_context: ssl.SSLContext | None = None
  56. def _create_server_ssl_context(self) -> ssl.SSLContext:
  57. """Create SSL context for accepting client (slicer) connections."""
  58. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  59. ctx.load_cert_chain(self.server_cert_path, self.server_key_path)
  60. # Allow older TLS versions for compatibility with slicers
  61. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  62. # Don't require client certificates
  63. ctx.verify_mode = ssl.CERT_NONE
  64. return ctx
  65. def _create_client_ssl_context(self) -> ssl.SSLContext:
  66. """Create SSL context for connecting to printer."""
  67. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
  68. # Don't verify printer's certificate (self-signed)
  69. ctx.check_hostname = False
  70. ctx.verify_mode = ssl.CERT_NONE
  71. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  72. return ctx
  73. async def start(self) -> None:
  74. """Start the TLS proxy server."""
  75. if self._running:
  76. return
  77. logger.info(
  78. f"Starting {self.name} TLS proxy: 0.0.0.0:{self.listen_port} → {self.target_host}:{self.target_port}"
  79. )
  80. try:
  81. self._running = True
  82. # Create SSL contexts
  83. self._server_ssl_context = self._create_server_ssl_context()
  84. self._client_ssl_context = self._create_client_ssl_context()
  85. # Start server with TLS
  86. self._server = await asyncio.start_server(
  87. self._handle_client,
  88. "0.0.0.0",
  89. self.listen_port,
  90. ssl=self._server_ssl_context,
  91. )
  92. logger.info(f"{self.name} TLS proxy listening on port {self.listen_port}")
  93. async with self._server:
  94. await self._server.serve_forever()
  95. except OSError as e:
  96. if e.errno == 98: # Address already in use
  97. logger.error(f"{self.name} proxy port {self.listen_port} is already in use")
  98. else:
  99. logger.error(f"{self.name} proxy error: {e}")
  100. except asyncio.CancelledError:
  101. logger.debug(f"{self.name} proxy task cancelled")
  102. except Exception as e:
  103. logger.error(f"{self.name} proxy error: {e}")
  104. finally:
  105. await self.stop()
  106. async def stop(self) -> None:
  107. """Stop the TLS proxy server."""
  108. logger.info(f"Stopping {self.name} proxy")
  109. self._running = False
  110. # Cancel all active connection tasks
  111. for client_id, (task1, task2) in list(self._active_connections.items()):
  112. task1.cancel()
  113. task2.cancel()
  114. if self.on_disconnect:
  115. try:
  116. self.on_disconnect(client_id)
  117. except Exception:
  118. pass
  119. self._active_connections.clear()
  120. if self._server:
  121. try:
  122. self._server.close()
  123. await self._server.wait_closed()
  124. except Exception as e:
  125. logger.debug(f"Error closing {self.name} proxy server: {e}")
  126. self._server = None
  127. async def _handle_client(
  128. self,
  129. client_reader: asyncio.StreamReader,
  130. client_writer: asyncio.StreamWriter,
  131. ) -> None:
  132. """Handle a new client connection by proxying to target."""
  133. peername = client_writer.get_extra_info("peername")
  134. client_id = f"{peername[0]}:{peername[1]}" if peername else "unknown"
  135. logger.info(f"{self.name} proxy: client connected from {client_id}")
  136. if self.on_connect:
  137. try:
  138. self.on_connect(client_id)
  139. except Exception:
  140. pass
  141. # Connect to target printer with TLS
  142. try:
  143. printer_reader, printer_writer = await asyncio.wait_for(
  144. asyncio.open_connection(
  145. self.target_host,
  146. self.target_port,
  147. ssl=self._client_ssl_context,
  148. ),
  149. timeout=10.0,
  150. )
  151. logger.info(f"{self.name} proxy: connected to printer {self.target_host}:{self.target_port}")
  152. except TimeoutError:
  153. logger.error(f"{self.name} proxy: timeout connecting to {self.target_host}:{self.target_port}")
  154. client_writer.close()
  155. await client_writer.wait_closed()
  156. return
  157. except ssl.SSLError as e:
  158. logger.error(f"{self.name} proxy: SSL error connecting to {self.target_host}:{self.target_port}: {e}")
  159. client_writer.close()
  160. await client_writer.wait_closed()
  161. return
  162. except Exception as e:
  163. logger.error(f"{self.name} proxy: failed to connect to {self.target_host}:{self.target_port}: {e}")
  164. client_writer.close()
  165. await client_writer.wait_closed()
  166. return
  167. # Create bidirectional forwarding tasks
  168. client_to_printer = asyncio.create_task(
  169. self._forward(client_reader, printer_writer, f"{client_id}→printer"),
  170. name=f"{self.name}_c2p_{client_id}",
  171. )
  172. printer_to_client = asyncio.create_task(
  173. self._forward(printer_reader, client_writer, f"printer→{client_id}"),
  174. name=f"{self.name}_p2c_{client_id}",
  175. )
  176. self._active_connections[client_id] = (client_to_printer, printer_to_client)
  177. try:
  178. # Wait for either direction to complete (connection closed)
  179. done, pending = await asyncio.wait(
  180. [client_to_printer, printer_to_client],
  181. return_when=asyncio.FIRST_COMPLETED,
  182. )
  183. # Cancel the other direction
  184. for task in pending:
  185. task.cancel()
  186. try:
  187. await task
  188. except asyncio.CancelledError:
  189. pass
  190. except Exception as e:
  191. logger.debug(f"{self.name} proxy connection error: {e}")
  192. finally:
  193. # Clean up
  194. self._active_connections.pop(client_id, None)
  195. for writer in [client_writer, printer_writer]:
  196. try:
  197. writer.close()
  198. await writer.wait_closed()
  199. except Exception:
  200. pass
  201. logger.info(f"{self.name} proxy: client {client_id} disconnected")
  202. if self.on_disconnect:
  203. try:
  204. self.on_disconnect(client_id)
  205. except Exception:
  206. pass
  207. async def _forward(
  208. self,
  209. reader: asyncio.StreamReader,
  210. writer: asyncio.StreamWriter,
  211. direction: str,
  212. ) -> None:
  213. """Forward data from reader to writer.
  214. Args:
  215. reader: Source stream (already TLS-decrypted)
  216. writer: Destination stream (will be TLS-encrypted by the stream)
  217. direction: Description for logging (e.g., "client→printer")
  218. """
  219. total_bytes = 0
  220. try:
  221. while self._running:
  222. # Read chunk - use reasonable buffer size
  223. data = await reader.read(65536)
  224. if not data:
  225. # Connection closed
  226. break
  227. # Forward to destination
  228. writer.write(data)
  229. await writer.drain()
  230. total_bytes += len(data)
  231. logger.debug(f"{self.name} proxy {direction}: {len(data)} bytes")
  232. except asyncio.CancelledError:
  233. pass
  234. except ConnectionResetError:
  235. logger.debug(f"{self.name} proxy {direction}: connection reset")
  236. except BrokenPipeError:
  237. logger.debug(f"{self.name} proxy {direction}: broken pipe")
  238. except Exception as e:
  239. logger.debug(f"{self.name} proxy {direction} error: {e}")
  240. logger.debug(f"{self.name} proxy {direction}: total {total_bytes} bytes")
  241. class SlicerProxyManager:
  242. """Manages FTP and MQTT TLS proxies for a single printer target."""
  243. # Bambu printer ports
  244. PRINTER_FTP_PORT = 990
  245. PRINTER_MQTT_PORT = 8883
  246. # Local listen ports (same as virtual printer)
  247. LOCAL_FTP_PORT = 9990
  248. LOCAL_MQTT_PORT = 8883
  249. def __init__(
  250. self,
  251. target_host: str,
  252. cert_path: Path,
  253. key_path: Path,
  254. on_activity: Callable[[str, str], None] | None = None,
  255. ):
  256. """Initialize the slicer proxy manager.
  257. Args:
  258. target_host: Target printer IP address
  259. cert_path: Path to server certificate
  260. key_path: Path to server private key
  261. on_activity: Optional callback for activity logging (name, message)
  262. """
  263. self.target_host = target_host
  264. self.cert_path = cert_path
  265. self.key_path = key_path
  266. self.on_activity = on_activity
  267. self._ftp_proxy: TLSProxy | None = None
  268. self._mqtt_proxy: TLSProxy | None = None
  269. self._tasks: list[asyncio.Task] = []
  270. async def start(self) -> None:
  271. """Start FTP and MQTT TLS proxies."""
  272. logger.info(f"Starting slicer TLS proxy to {self.target_host}")
  273. # Create proxies with TLS
  274. self._ftp_proxy = TLSProxy(
  275. name="FTP",
  276. listen_port=self.LOCAL_FTP_PORT,
  277. target_host=self.target_host,
  278. target_port=self.PRINTER_FTP_PORT,
  279. server_cert_path=self.cert_path,
  280. server_key_path=self.key_path,
  281. on_connect=lambda cid: self._log_activity("FTP", f"connected: {cid}"),
  282. on_disconnect=lambda cid: self._log_activity("FTP", f"disconnected: {cid}"),
  283. )
  284. self._mqtt_proxy = TLSProxy(
  285. name="MQTT",
  286. listen_port=self.LOCAL_MQTT_PORT,
  287. target_host=self.target_host,
  288. target_port=self.PRINTER_MQTT_PORT,
  289. server_cert_path=self.cert_path,
  290. server_key_path=self.key_path,
  291. on_connect=lambda cid: self._log_activity("MQTT", f"connected: {cid}"),
  292. on_disconnect=lambda cid: self._log_activity("MQTT", f"disconnected: {cid}"),
  293. )
  294. # Start as background tasks
  295. async def run_with_logging(proxy: TLSProxy) -> None:
  296. try:
  297. await proxy.start()
  298. except Exception as e:
  299. logger.error(f"Slicer proxy {proxy.name} failed: {e}")
  300. self._tasks = [
  301. asyncio.create_task(
  302. run_with_logging(self._ftp_proxy),
  303. name="slicer_proxy_ftp",
  304. ),
  305. asyncio.create_task(
  306. run_with_logging(self._mqtt_proxy),
  307. name="slicer_proxy_mqtt",
  308. ),
  309. ]
  310. logger.info(f"Slicer TLS proxy started for {self.target_host}")
  311. # Wait for tasks to complete (they run until cancelled)
  312. # This keeps the start() coroutine alive so the parent task doesn't complete
  313. try:
  314. await asyncio.gather(*self._tasks)
  315. except asyncio.CancelledError:
  316. logger.debug("Slicer proxy start cancelled")
  317. async def stop(self) -> None:
  318. """Stop all proxies."""
  319. logger.info("Stopping slicer proxy")
  320. # Stop proxies
  321. if self._ftp_proxy:
  322. await self._ftp_proxy.stop()
  323. self._ftp_proxy = None
  324. if self._mqtt_proxy:
  325. await self._mqtt_proxy.stop()
  326. self._mqtt_proxy = None
  327. # Cancel tasks
  328. for task in self._tasks:
  329. task.cancel()
  330. if self._tasks:
  331. try:
  332. await asyncio.wait_for(
  333. asyncio.gather(*self._tasks, return_exceptions=True),
  334. timeout=2.0,
  335. )
  336. except TimeoutError:
  337. logger.debug("Some proxy tasks didn't stop in time")
  338. self._tasks = []
  339. logger.info("Slicer proxy stopped")
  340. def _log_activity(self, name: str, message: str) -> None:
  341. """Log activity via callback if configured."""
  342. if self.on_activity:
  343. try:
  344. self.on_activity(name, message)
  345. except Exception:
  346. pass
  347. @property
  348. def is_running(self) -> bool:
  349. """Check if proxies are running."""
  350. return len(self._tasks) > 0 and all(not t.done() for t in self._tasks)
  351. def get_status(self) -> dict:
  352. """Get proxy status."""
  353. return {
  354. "running": self.is_running,
  355. "target_host": self.target_host,
  356. "ftp_port": self.LOCAL_FTP_PORT,
  357. "mqtt_port": self.LOCAL_MQTT_PORT,
  358. "ftp_connections": (len(self._ftp_proxy._active_connections) if self._ftp_proxy else 0),
  359. "mqtt_connections": (len(self._mqtt_proxy._active_connections) if self._mqtt_proxy else 0),
  360. }