# bind_server.py (9.4 KB)
"""Bind/detect server for virtual printer discovery (ports 3000 + 3002).

Bambu slicers (BambuStudio, OrcaSlicer) connect to a printer on port 3000
or 3002 to perform the "bind with access code" handshake before using
MQTT/FTP.

Port 3000: plain TCP (legacy / some printer models).
Port 3002: TLS (newer firmware, e.g. A1 Mini 01.07.x).

Protocol (same on both ports, only transport differs):
- Framing: 0xA5A5 + uint16_le(total_msg_size) + JSON payload + 0xA7A7
- Slicer sends: {"login":{"command":"detect","sequence_id":"20000"}}
- Printer replies: {"login":{"bind":"free","command":"detect","connect":"lan",
    "dev_cap":1,"id":"<serial>","model":"<model>","name":"<name>",
    "sequence_id":<int>,"version":"<firmware>"}}
- Connection closes after one exchange.
"""
import asyncio
import errno
import json
import logging
import ssl
import struct
from pathlib import Path
  21. logger = logging.getLogger(__name__)
  22. BIND_PORT_PLAIN = 3000
  23. BIND_PORT_TLS = 3002
  24. BIND_PORTS = [BIND_PORT_PLAIN, BIND_PORT_TLS]
  25. FRAME_HEADER = b"\xa5\xa5"
  26. FRAME_TRAILER = b"\xa7\xa7"
  27. HEADER_SIZE = 4 # 2 bytes magic + 2 bytes length
  28. TRAILER_SIZE = 2
  29. class BindServer:
  30. """Responds to slicer bind/detect requests on ports 3000 and 3002.
  31. In server mode, Bambuddy IS the printer — it responds with its own
  32. identity so the slicer can discover and bind to it.
  33. Port 3000 is plain TCP, port 3002 is TLS. BambuStudio chooses which
  34. port to use based on the printer model discovered via SSDP.
  35. """
  36. def __init__(
  37. self,
  38. serial: str,
  39. model: str,
  40. name: str,
  41. version: str = "01.00.00.00",
  42. bind_address: str = "0.0.0.0", # nosec B104
  43. cert_path: Path | None = None,
  44. key_path: Path | None = None,
  45. ):
  46. self.serial = serial
  47. self.model = model
  48. self.name = name
  49. self.version = version
  50. self.bind_address = bind_address
  51. self.cert_path = cert_path
  52. self.key_path = key_path
  53. self._servers: list[asyncio.Server] = []
  54. self._running = False
  55. # Set after at least one bind port is listening — see ftp_server.py
  56. # for rationale. Bind server is best-effort across BIND_PORTS, so
  57. # "ready" means "at least one port bound", matching the existing
  58. # serve_forever path.
  59. self.ready = asyncio.Event()
  60. def _create_tls_context(self) -> ssl.SSLContext | None:
  61. """Create SSL context for the TLS bind port (3002)."""
  62. if not self.cert_path or not self.key_path:
  63. return None
  64. ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  65. ctx.load_cert_chain(str(self.cert_path), str(self.key_path))
  66. ctx.minimum_version = ssl.TLSVersion.TLSv1_2
  67. ctx.verify_mode = ssl.CERT_NONE
  68. return ctx
  69. async def start(self) -> None:
  70. """Start the bind server on ports 3000 (plain) and 3002 (TLS)."""
  71. if self._running:
  72. return
  73. self._running = True
  74. tls_ctx = self._create_tls_context()
  75. if not tls_ctx:
  76. logger.warning("Bind server: no TLS cert provided, port %s will be plain TCP", BIND_PORT_TLS)
  77. logger.info(
  78. "Starting bind server on ports %s (serial=%s, model=%s, tls=%s)",
  79. BIND_PORTS,
  80. self.serial,
  81. self.model,
  82. tls_ctx is not None,
  83. )
  84. try:
  85. for port in BIND_PORTS:
  86. use_tls = port == BIND_PORT_TLS and tls_ctx is not None
  87. try:
  88. server = await asyncio.start_server(
  89. self._handle_client,
  90. self.bind_address,
  91. port,
  92. ssl=tls_ctx if use_tls else None,
  93. )
  94. self._servers.append(server)
  95. logger.info(
  96. "Bind server listening on %s:%s (%s)",
  97. self.bind_address,
  98. port,
  99. "TLS" if use_tls else "plain",
  100. )
  101. except OSError as e:
  102. if e.errno == 98:
  103. logger.warning("Bind server port %s already in use, skipping", port)
  104. elif e.errno == 13:
  105. logger.warning("Bind server: cannot bind to port %s (permission denied), skipping", port)
  106. else:
  107. logger.warning("Bind server: failed to bind port %s: %s", port, e)
  108. if not self._servers:
  109. logger.error("Bind server: could not bind to any port")
  110. return
  111. self.ready.set()
  112. # Serve all successfully bound ports
  113. await asyncio.gather(*(s.serve_forever() for s in self._servers))
  114. except asyncio.CancelledError:
  115. logger.debug("Bind server task cancelled")
  116. except Exception as e:
  117. logger.error("Bind server error: %s", e)
  118. finally:
  119. await self.stop()
  120. async def stop(self) -> None:
  121. """Stop the bind server."""
  122. logger.info("Stopping bind server")
  123. self._running = False
  124. self.ready.clear()
  125. for server in self._servers:
  126. try:
  127. server.close()
  128. await server.wait_closed()
  129. except OSError as e:
  130. logger.debug("Error closing bind server: %s", e)
  131. self._servers = []
  132. async def _handle_client(
  133. self,
  134. reader: asyncio.StreamReader,
  135. writer: asyncio.StreamWriter,
  136. ) -> None:
  137. """Handle a single bind/detect request from a slicer."""
  138. peername = writer.get_extra_info("peername")
  139. client_id = f"{peername[0]}:{peername[1]}" if peername else "unknown"
  140. logger.info("Bind server: client connected from %s", client_id)
  141. try:
  142. # Read the framed message (timeout after 10s)
  143. data = await asyncio.wait_for(reader.read(4096), timeout=10.0)
  144. if not data:
  145. return
  146. # Parse the request
  147. request = self._parse_frame(data)
  148. if request is None:
  149. logger.warning("Bind server: invalid frame from %s", client_id)
  150. return
  151. logger.info("Bind server: received from %s: %s", client_id, request)
  152. # Check if this is a detect command
  153. login = request.get("login", {})
  154. if not isinstance(login, dict) or login.get("command") != "detect":
  155. logger.warning("Bind server: unexpected command from %s: %s", client_id, request)
  156. return
  157. # Build response. `sequence_id` is an INTEGER counter chosen by
  158. # the printer side (not an echo of the slicer's string seq_id).
  159. # The protocol docstring at the top of this file documents the
  160. # asymmetry: slicer sends `"20000"` (string), printer replies
  161. # with an int. The hardcoded 3021 mirrors real-firmware-captured
  162. # value; an earlier audit suggesting we echo the slicer's seq_id
  163. # was wrong and would have broken slicers that validate the
  164. # type (int vs string).
  165. response = {
  166. "login": {
  167. "bind": "free",
  168. "command": "detect",
  169. "connect": "lan",
  170. "dev_cap": 1,
  171. "id": self.serial,
  172. "model": self.model,
  173. "name": self.name,
  174. "sequence_id": 3021,
  175. "version": self.version,
  176. }
  177. }
  178. frame = self._build_frame(response)
  179. writer.write(frame)
  180. await writer.drain()
  181. logger.info("Bind server: sent detect response to %s (serial=%s)", client_id, self.serial)
  182. except TimeoutError:
  183. logger.debug("Bind server: timeout waiting for data from %s", client_id)
  184. except Exception as e:
  185. logger.error("Bind server: error handling %s: %s", client_id, e)
  186. finally:
  187. try:
  188. writer.close()
  189. await writer.wait_closed()
  190. except OSError:
  191. pass
  192. logger.debug("Bind server: client %s disconnected", client_id)
  193. def _parse_frame(self, data: bytes) -> dict | None:
  194. """Parse a framed message: 0xA5A5 + len(u16le) + JSON + 0xA7A7."""
  195. if len(data) < HEADER_SIZE + TRAILER_SIZE:
  196. return None
  197. if data[:2] != FRAME_HEADER:
  198. return None
  199. if data[-2:] != FRAME_TRAILER:
  200. return None
  201. # Length field is total message size (header + json + trailer)
  202. total_len = struct.unpack_from("<H", data, 2)[0]
  203. if total_len != len(data):
  204. logger.debug("Bind frame length mismatch: header says %d, got %d", total_len, len(data))
  205. # JSON payload is between header and trailer
  206. json_bytes = data[HEADER_SIZE:-TRAILER_SIZE]
  207. try:
  208. return json.loads(json_bytes)
  209. except (json.JSONDecodeError, UnicodeDecodeError) as e:
  210. logger.warning("Bind server: failed to parse JSON: %s", e)
  211. return None
  212. def _build_frame(self, payload: dict) -> bytes:
  213. """Build a framed message: 0xA5A5 + len(u16le) + JSON + 0xA7A7."""
  214. json_bytes = json.dumps(payload, separators=(",", ":")).encode("utf-8")
  215. total_len = HEADER_SIZE + len(json_bytes) + TRAILER_SIZE
  216. header = FRAME_HEADER + struct.pack("<H", total_len)
  217. return header + json_bytes + FRAME_TRAILER