bind_server.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
"""Bind/detect server for virtual printer discovery (ports 3000 + 3002).

Bambu slicers (BambuStudio, OrcaSlicer) connect to a printer on port 3000
or 3002 to perform the "bind with access code" handshake before using
MQTT/FTP.

Port 3000: plain TCP (legacy / some printer models).
Port 3002: TLS (newer firmware, e.g. A1 Mini 01.07.x).

Protocol (same on both ports, only the transport differs):
- Framing: 0xA5A5 + uint16_le(total message size, including the 4-byte
  header and 2-byte trailer) + JSON payload + 0xA7A7
- Slicer sends: {"login":{"command":"detect","sequence_id":"20000"}}
- Printer replies: {"login":{"bind":"free","command":"detect","connect":"lan",
  "dev_cap":1,"id":"<serial>","model":"<model>","name":"<name>",
  "sequence_id":<int>,"version":"<firmware>"}}
- The connection closes after one exchange.
"""
import asyncio
import errno
import json
import logging
import ssl
import struct
from pathlib import Path
  21. logger = logging.getLogger(__name__)
  22. BIND_PORT_PLAIN = 3000
  23. BIND_PORT_TLS = 3002
  24. BIND_PORTS = [BIND_PORT_PLAIN, BIND_PORT_TLS]
  25. FRAME_HEADER = b"\xa5\xa5"
  26. FRAME_TRAILER = b"\xa7\xa7"
  27. HEADER_SIZE = 4 # 2 bytes magic + 2 bytes length
  28. TRAILER_SIZE = 2
  29. class BindServer:
  30. """Responds to slicer bind/detect requests on ports 3000 and 3002.
  31. In server mode, Bambuddy IS the printer — it responds with its own
  32. identity so the slicer can discover and bind to it.
  33. Port 3000 is plain TCP, port 3002 is TLS. BambuStudio chooses which
  34. port to use based on the printer model discovered via SSDP.
  35. """
  36. def __init__(
  37. self,
  38. serial: str,
  39. model: str,
  40. name: str,
  41. version: str = "01.00.00.00",
  42. bind_address: str = "0.0.0.0", # nosec B104
  43. cert_path: Path | None = None,
  44. key_path: Path | None = None,
  45. ):
  46. self.serial = serial
  47. self.model = model
  48. self.name = name
  49. self.version = version
  50. self.bind_address = bind_address
  51. self.cert_path = cert_path
  52. self.key_path = key_path
  53. self._servers: list[asyncio.Server] = []
  54. self._running = False
  55. def _create_tls_context(self) -> ssl.SSLContext | None:
  56. """Create SSL context for the TLS bind port (3002)."""
  57. if not self.cert_path or not self.key_path:
  58. return None
  59. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  60. ctx.load_cert_chain(str(self.cert_path), str(self.key_path))
  61. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  62. ctx.verify_mode = ssl.CERT_NONE
  63. return ctx
  64. async def start(self) -> None:
  65. """Start the bind server on ports 3000 (plain) and 3002 (TLS)."""
  66. if self._running:
  67. return
  68. self._running = True
  69. tls_ctx = self._create_tls_context()
  70. if not tls_ctx:
  71. logger.warning("Bind server: no TLS cert provided, port %s will be plain TCP", BIND_PORT_TLS)
  72. logger.info(
  73. "Starting bind server on ports %s (serial=%s, model=%s, tls=%s)",
  74. BIND_PORTS,
  75. self.serial,
  76. self.model,
  77. tls_ctx is not None,
  78. )
  79. try:
  80. for port in BIND_PORTS:
  81. use_tls = port == BIND_PORT_TLS and tls_ctx is not None
  82. try:
  83. server = await asyncio.start_server(
  84. self._handle_client,
  85. self.bind_address,
  86. port,
  87. ssl=tls_ctx if use_tls else None,
  88. )
  89. self._servers.append(server)
  90. logger.info(
  91. "Bind server listening on %s:%s (%s)",
  92. self.bind_address,
  93. port,
  94. "TLS" if use_tls else "plain",
  95. )
  96. except OSError as e:
  97. if e.errno == 98:
  98. logger.warning("Bind server port %s already in use, skipping", port)
  99. elif e.errno == 13:
  100. logger.warning("Bind server: cannot bind to port %s (permission denied), skipping", port)
  101. else:
  102. logger.warning("Bind server: failed to bind port %s: %s", port, e)
  103. if not self._servers:
  104. logger.error("Bind server: could not bind to any port")
  105. return
  106. # Serve all successfully bound ports
  107. await asyncio.gather(*(s.serve_forever() for s in self._servers))
  108. except asyncio.CancelledError:
  109. logger.debug("Bind server task cancelled")
  110. except Exception as e:
  111. logger.error("Bind server error: %s", e)
  112. finally:
  113. await self.stop()
  114. async def stop(self) -> None:
  115. """Stop the bind server."""
  116. logger.info("Stopping bind server")
  117. self._running = False
  118. for server in self._servers:
  119. try:
  120. server.close()
  121. await server.wait_closed()
  122. except OSError as e:
  123. logger.debug("Error closing bind server: %s", e)
  124. self._servers = []
  125. async def _handle_client(
  126. self,
  127. reader: asyncio.StreamReader,
  128. writer: asyncio.StreamWriter,
  129. ) -> None:
  130. """Handle a single bind/detect request from a slicer."""
  131. peername = writer.get_extra_info("peername")
  132. client_id = f"{peername[0]}:{peername[1]}" if peername else "unknown"
  133. logger.info("Bind server: client connected from %s", client_id)
  134. try:
  135. # Read the framed message (timeout after 10s)
  136. data = await asyncio.wait_for(reader.read(4096), timeout=10.0)
  137. if not data:
  138. return
  139. # Parse the request
  140. request = self._parse_frame(data)
  141. if request is None:
  142. logger.warning("Bind server: invalid frame from %s", client_id)
  143. return
  144. logger.info("Bind server: received from %s: %s", client_id, request)
  145. # Check if this is a detect command
  146. login = request.get("login", {})
  147. if not isinstance(login, dict) or login.get("command") != "detect":
  148. logger.warning("Bind server: unexpected command from %s: %s", client_id, request)
  149. return
  150. # Build response
  151. response = {
  152. "login": {
  153. "bind": "free",
  154. "command": "detect",
  155. "connect": "lan",
  156. "dev_cap": 1,
  157. "id": self.serial,
  158. "model": self.model,
  159. "name": self.name,
  160. "sequence_id": 3021,
  161. "version": self.version,
  162. }
  163. }
  164. frame = self._build_frame(response)
  165. writer.write(frame)
  166. await writer.drain()
  167. logger.info("Bind server: sent detect response to %s (serial=%s)", client_id, self.serial)
  168. except TimeoutError:
  169. logger.debug("Bind server: timeout waiting for data from %s", client_id)
  170. except Exception as e:
  171. logger.error("Bind server: error handling %s: %s", client_id, e)
  172. finally:
  173. try:
  174. writer.close()
  175. await writer.wait_closed()
  176. except OSError:
  177. pass
  178. logger.debug("Bind server: client %s disconnected", client_id)
  179. def _parse_frame(self, data: bytes) -> dict | None:
  180. """Parse a framed message: 0xA5A5 + len(u16le) + JSON + 0xA7A7."""
  181. if len(data) < HEADER_SIZE + TRAILER_SIZE:
  182. return None
  183. if data[:2] != FRAME_HEADER:
  184. return None
  185. if data[-2:] != FRAME_TRAILER:
  186. return None
  187. # Length field is total message size (header + json + trailer)
  188. total_len = struct.unpack_from("<H", data, 2)[0]
  189. if total_len != len(data):
  190. logger.debug("Bind frame length mismatch: header says %d, got %d", total_len, len(data))
  191. # JSON payload is between header and trailer
  192. json_bytes = data[HEADER_SIZE:-TRAILER_SIZE]
  193. try:
  194. return json.loads(json_bytes)
  195. except (json.JSONDecodeError, UnicodeDecodeError) as e:
  196. logger.warning("Bind server: failed to parse JSON: %s", e)
  197. return None
  198. def _build_frame(self, payload: dict) -> bytes:
  199. """Build a framed message: 0xA5A5 + len(u16le) + JSON + 0xA7A7."""
  200. json_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8")
  201. total_len = HEADER_SIZE + len(json_bytes) + TRAILER_SIZE
  202. header = FRAME_HEADER + struct.pack("<H", total_len)
  203. return header + json_bytes + FRAME_TRAILER