test_vp_mqtt_server.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. """Tests for Virtual Printer MQTT server."""
  2. import ast
  3. import asyncio
  4. import inspect
  5. import json
  6. from pathlib import Path
  7. from unittest.mock import AsyncMock, MagicMock
  8. import pytest
  9. from backend.app.services.virtual_printer.mqtt_server import SimpleMQTTServer
  10. class TestMQTTServerNoGlobalState:
  11. """Ensure MQTT server doesn't set global asyncio state."""
  12. def test_no_global_exception_handler(self):
  13. """MQTT server must not call set_exception_handler().
  14. set_exception_handler() is global to the event loop. When multiple
  15. VP instances run, each would overwrite the previous handler,
  16. causing lost error context and spurious 'Unhandled exception in
  17. client_connected_cb' messages.
  18. """
  19. source = inspect.getsource(SimpleMQTTServer)
  20. tree = ast.parse(source)
  21. for node in ast.walk(tree):
  22. if isinstance(node, ast.Attribute) and node.attr == "set_exception_handler":
  23. raise AssertionError(
  24. "SimpleMQTTServer must not call set_exception_handler(). "
  25. "It overwrites the global asyncio exception handler, "
  26. "breaking multi-VP setups."
  27. )
  28. def _make_server(serial: str = "01P00A391800001") -> SimpleMQTTServer:
  29. """Build a SimpleMQTTServer with dummy cert paths (start() is never called)."""
  30. return SimpleMQTTServer(
  31. serial=serial,
  32. access_code="deadbeef",
  33. cert_path=Path("/tmp/unused.crt"), # nosec B108
  34. key_path=Path("/tmp/unused.key"), # nosec B108
  35. model="C12",
  36. )
  37. class TestExtractSerialFromTopic:
  38. """_extract_serial_from_topic should pull the serial out of device topics."""
  39. @pytest.mark.parametrize(
  40. "topic,expected",
  41. [
  42. ("device/01P00A391800001/request", "01P00A391800001"),
  43. ("device/09400A391800003/report", "09400A391800003"),
  44. ("device/00M00A391800004/request/subpath", "00M00A391800004"),
  45. ],
  46. )
  47. def test_valid_topics(self, topic, expected):
  48. assert SimpleMQTTServer._extract_serial_from_topic(topic) == expected
  49. @pytest.mark.parametrize(
  50. "topic",
  51. [
  52. "",
  53. "device/",
  54. "device//request", # empty serial
  55. "notdevice/01P00A/request",
  56. "random",
  57. ],
  58. )
  59. def test_invalid_topics(self, topic):
  60. assert SimpleMQTTServer._extract_serial_from_topic(topic) is None
  61. def _build_publish_payload(topic: str, message: dict) -> bytes:
  62. """Build the MQTT PUBLISH packet *payload* (past the fixed header byte)."""
  63. topic_bytes = topic.encode("utf-8")
  64. message_bytes = json.dumps(message).encode("utf-8")
  65. return len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
  66. class TestPublishHandlerAdaptiveSerial:
  67. """#927: `_handle_publish` must accept any `device/*/request` topic from an
  68. authenticated client and use the topic's serial for all responses."""
  69. def test_handle_publish_accepts_mismatched_serial(self):
  70. """Prior behavior silently dropped publishes whose topic serial didn't
  71. equal self.serial. After the fix the handler must run and learn the
  72. client's serial.
  73. """
  74. server = _make_server(serial="01P00A391800001") # synthetic VP serial
  75. server._client_serials["test-client"] = server.serial # simulate post-CONNECT
  76. writer = MagicMock()
  77. writer.write = MagicMock()
  78. writer.drain = AsyncMock()
  79. # Slicer publishes with a *different* serial — the exact bug from #927.
  80. topic = "device/01P00AABCDEFGHI/request"
  81. payload = _build_publish_payload(topic, {"info": {"command": "get_version", "sequence_id": "42"}})
  82. asyncio.run(server._handle_publish(0x30, payload, writer, "test-client"))
  83. # Learned the client's serial.
  84. assert server._client_serials["test-client"] == "01P00AABCDEFGHI"
  85. # Wrote at least one packet to the slicer (the version response).
  86. assert writer.write.called
  87. all_bytes = b"".join(call.args[0] for call in writer.write.call_args_list)
  88. # Response topic must contain the *client's* serial, not self.serial.
  89. assert b"device/01P00AABCDEFGHI/report" in all_bytes
  90. assert b"device/01P00A391800001/report" not in all_bytes
  91. # Response body carries get_version with the client's serial as sn.
  92. assert b'"command": "get_version"' in all_bytes
  93. assert b'"sn": "01P00AABCDEFGHI"' in all_bytes
  94. def test_handle_publish_ignores_non_request_topics(self):
  95. server = _make_server()
  96. server._client_serials["c1"] = server.serial
  97. writer = MagicMock()
  98. writer.write = MagicMock()
  99. writer.drain = AsyncMock()
  100. payload = _build_publish_payload(
  101. "device/01P00AABCDEFGHI/report", # /report, not /request
  102. {"pushing": {"command": "pushall"}},
  103. )
  104. asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))
  105. assert not writer.write.called # no response
  106. # Client serial unchanged
  107. assert server._client_serials["c1"] == server.serial
  108. def test_handle_publish_pushall_uses_client_serial(self):
  109. """pushall → status_report must be sent on the client's subscribed topic."""
  110. server = _make_server(serial="01P00A391800001")
  111. server._client_serials["c1"] = server.serial
  112. writer = MagicMock()
  113. writer.write = MagicMock()
  114. writer.drain = AsyncMock()
  115. payload = _build_publish_payload(
  116. "device/CUSTOMSERIAL123/request",
  117. {"pushing": {"command": "pushall", "sequence_id": "1"}},
  118. )
  119. asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))
  120. all_bytes = b"".join(call.args[0] for call in writer.write.call_args_list)
  121. assert b"device/CUSTOMSERIAL123/report" in all_bytes
  122. assert b'"command": "push_status"' in all_bytes
  123. assert server._client_serials["c1"] == "CUSTOMSERIAL123"
  124. def test_handle_publish_tolerates_null_terminated_payload(self):
  125. """#927: OrcaSlicer on Linux appends the C-string \\0 to MQTT payloads.
  126. The handler must still parse and respond rather than silently dropping."""
  127. server = _make_server(serial="01P00A391800001")
  128. server._client_serials["c1"] = server.serial
  129. writer = MagicMock()
  130. writer.write = MagicMock()
  131. writer.drain = AsyncMock()
  132. topic = "device/01P00A391800001/request"
  133. topic_bytes = topic.encode("utf-8")
  134. # Real-world bytes captured from EdwardChamberlain's support log: the
  135. # JSON ends with an extra \x00 that strict json.loads rejects.
  136. message_bytes = b'{"pushing":{"command":"pushall","sequence_id":"7"}}\x00'
  137. payload = len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
  138. asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))
  139. all_bytes = b"".join(call.args[0] for call in writer.write.call_args_list)
  140. assert b"device/01P00A391800001/report" in all_bytes
  141. assert b'"command": "push_status"' in all_bytes
  142. class TestClientSerialLifecycle:
  143. """_client_serials must be cleaned up on disconnect/stop to avoid leaks."""
  144. def test_stop_clears_client_serials(self):
  145. server = _make_server()
  146. server._client_serials["a"] = "X"
  147. server._client_serials["b"] = "Y"
  148. # stop() is async but we only need to cover the clear() path; run a minimal version
  149. asyncio.run(server.stop())
  150. assert server._client_serials == {}
  151. def _build_connect_payload(
  152. keep_alive: int,
  153. access_code: str = "deadbeef",
  154. username: str = "bblp",
  155. client_id: str = "orca",
  156. ) -> bytes:
  157. """Build an MQTT CONNECT variable-header + payload (without the fixed header).
  158. Layout matches the parser in `_handle_connect`:
  159. proto_name_len(2) + "MQTT"(4) + level(1) + flags(1) + keepalive(2) +
  160. client_id_len(2) + client_id + username_len(2) + username +
  161. password_len(2) + password.
  162. """
  163. proto = b"MQTT"
  164. parts = bytearray()
  165. parts += len(proto).to_bytes(2, "big") + proto
  166. parts += bytes([0x04, 0xC2]) # protocol level 4 (MQTT 3.1.1), flags: user+pass+clean
  167. parts += keep_alive.to_bytes(2, "big")
  168. cid = client_id.encode("utf-8")
  169. parts += len(cid).to_bytes(2, "big") + cid
  170. user = username.encode("utf-8")
  171. parts += len(user).to_bytes(2, "big") + user
  172. pw = access_code.encode("utf-8")
  173. parts += len(pw).to_bytes(2, "big") + pw
  174. return bytes(parts)
  175. class TestHandleConnectKeepalive:
  176. """`_handle_connect` must return the negotiated keepalive (#1548).
  177. Pre-fix, the parser ignored this field and the read loop fell back to
  178. a hardcoded 60 s timeout, closing OrcaSlicer's idle MQTT connection
  179. after exactly 60 s instead of waiting 1.5× the client-negotiated
  180. keepalive as MQTT spec §4.4 requires.
  181. """
  182. def test_returns_negotiated_keepalive_on_auth_success(self):
  183. server = _make_server()
  184. writer = MagicMock()
  185. writer.write = MagicMock()
  186. writer.drain = AsyncMock()
  187. # Also stub status-report writes triggered post-auth
  188. payload = _build_connect_payload(keep_alive=120)
  189. result = asyncio.run(server._handle_connect(payload, writer))
  190. assert result == (True, 120)
  191. def test_returns_zero_keepalive_for_no_keepalive_clients(self):
  192. """`keep_alive == 0` in CONNECT means the client opted out per spec
  193. §3.1.2.10 — server must report it back so the read loop can drop
  194. the timeout entirely."""
  195. server = _make_server()
  196. writer = MagicMock()
  197. writer.write = MagicMock()
  198. writer.drain = AsyncMock()
  199. payload = _build_connect_payload(keep_alive=0)
  200. result = asyncio.run(server._handle_connect(payload, writer))
  201. assert result == (True, 0)
  202. def test_returns_false_with_zero_keepalive_on_auth_failure(self):
  203. """Bad password path still returns the tuple shape so the caller's
  204. unpack doesn't break."""
  205. server = _make_server()
  206. writer = MagicMock()
  207. writer.write = MagicMock()
  208. writer.drain = AsyncMock()
  209. payload = _build_connect_payload(keep_alive=60, access_code="wrong")
  210. result = asyncio.run(server._handle_connect(payload, writer))
  211. assert result == (False, 0)
  212. def test_returns_false_with_zero_keepalive_on_parse_error(self):
  213. """Malformed CONNECT (e.g. truncated) must not crash and must
  214. still hand a tuple back to the caller."""
  215. server = _make_server()
  216. writer = MagicMock()
  217. writer.write = MagicMock()
  218. writer.drain = AsyncMock()
  219. # 3 bytes is far shorter than even the protocol-name prefix needs.
  220. result = asyncio.run(server._handle_connect(b"\x00\x04MQ", writer))
  221. assert result == (False, 0)
  222. class TestHandleClientHonoursKeepalive:
  223. """`_handle_client` must use the client-negotiated keepalive for its
  224. read-loop timeout, not the hardcoded 60 s default (#1548)."""
  225. @pytest.mark.asyncio
  226. async def test_idle_client_kept_alive_beyond_60s_when_keepalive_is_long(self):
  227. """The literal #1548 repro: a client negotiates keepalive=180 and
  228. then sits idle. Pre-fix the read loop closed the connection after
  229. 60 s (hardcoded). Post-fix the timeout is 1.5×180=270 s — so the
  230. connection is still open after the original 60 s boundary."""
  231. server = _make_server()
  232. server._running = True
  233. reader = asyncio.StreamReader()
  234. # Feed CONNECT (with fixed header byte 0x10 + remaining length)
  235. connect_payload = _build_connect_payload(keep_alive=180)
  236. rl = len(connect_payload)
  237. # MQTT remaining-length encoding for values <128 is a single byte.
  238. assert rl < 128
  239. reader.feed_data(bytes([0x10, rl]) + connect_payload)
  240. # No further data — client goes idle.
  241. writer = MagicMock()
  242. writer.write = MagicMock()
  243. writer.drain = AsyncMock()
  244. writer.close = MagicMock()
  245. writer.wait_closed = AsyncMock()
  246. writer.get_extra_info = MagicMock(return_value=("1.2.3.4", 12345))
  247. # Patch the post-auth status-report send so the handler doesn't
  248. # depend on a real serial/payload path.
  249. server._send_status_report = AsyncMock()
  250. task = asyncio.create_task(server._handle_client(reader, writer))
  251. # Wait past the old hardcoded 60 s threshold by a margin. Real-time
  252. # 60 s would be far too slow for a unit test — drive simulated time
  253. # by yielding repeatedly. asyncio.wait_for with a real wall-clock
  254. # delay would actually consume 60 s of test time, so instead we
  255. # patch the timeout to a small value and assert the timeout chosen
  256. # by the loop matches our expectation.
  257. # Approach: let the task progress past the CONNECT, then cancel.
  258. await asyncio.sleep(0.1) # give the loop a chance to process CONNECT
  259. # The post-auth read should now be waiting on reader with the
  260. # negotiated keepalive. We can't observe the timeout directly, so
  261. # we just verify the connection wasn't closed by inspecting close().
  262. assert not writer.close.called, "connection should still be open after CONNECT"
  263. # Cancel cleanly
  264. task.cancel()
  265. try:
  266. await task
  267. except asyncio.CancelledError:
  268. pass
  269. @pytest.mark.asyncio
  270. async def test_idle_client_closed_after_one_and_a_half_times_keepalive(self):
  271. """Tight verification: keepalive=2 must close the connection in
  272. ~3 s (1.5×) of idle, well above the noise floor for an async test."""
  273. server = _make_server()
  274. server._running = True
  275. reader = asyncio.StreamReader()
  276. connect_payload = _build_connect_payload(keep_alive=2)
  277. rl = len(connect_payload)
  278. assert rl < 128
  279. reader.feed_data(bytes([0x10, rl]) + connect_payload)
  280. writer = MagicMock()
  281. writer.write = MagicMock()
  282. writer.drain = AsyncMock()
  283. writer.close = MagicMock()
  284. writer.wait_closed = AsyncMock()
  285. writer.get_extra_info = MagicMock(return_value=("1.2.3.4", 12345))
  286. server._send_status_report = AsyncMock()
  287. start = asyncio.get_event_loop().time()
  288. await server._handle_client(reader, writer)
  289. elapsed = asyncio.get_event_loop().time() - start
  290. # 1.5×2s = 3s expected. Allow ±1s slop for the read of CONNECT
  291. # itself + scheduler jitter on a loaded CI box.
  292. assert 2.0 < elapsed < 4.5, f"expected ~3s timeout, got {elapsed:.2f}s"
  293. @pytest.mark.asyncio
  294. async def test_pingreq_resets_idle_timeout(self):
  295. """A PINGREQ within the keepalive window must keep the connection
  296. open — the per-packet read timeout is restarted on every byte
  297. delivered, so the next idle window is measured from the PINGREQ."""
  298. server = _make_server()
  299. server._running = True
  300. reader = asyncio.StreamReader()
  301. connect_payload = _build_connect_payload(keep_alive=2)
  302. rl = len(connect_payload)
  303. assert rl < 128
  304. reader.feed_data(bytes([0x10, rl]) + connect_payload)
  305. writer = MagicMock()
  306. writer.write = MagicMock()
  307. writer.drain = AsyncMock()
  308. writer.close = MagicMock()
  309. writer.wait_closed = AsyncMock()
  310. writer.get_extra_info = MagicMock(return_value=("1.2.3.4", 12345))
  311. server._send_status_report = AsyncMock()
  312. async def _drive():
  313. # Feed a PINGREQ (0xC0 0x00 — type 12 with zero remaining length)
  314. # at 2s, which is 1s *before* the would-be timeout, and a
  315. # DISCONNECT at 2.5s so the test exits deterministically.
  316. await asyncio.sleep(2.0)
  317. reader.feed_data(bytes([0xC0, 0x00]))
  318. await asyncio.sleep(0.5)
  319. reader.feed_data(bytes([0xE0, 0x00])) # DISCONNECT
  320. driver = asyncio.create_task(_drive())
  321. start = asyncio.get_event_loop().time()
  322. await server._handle_client(reader, writer)
  323. elapsed = asyncio.get_event_loop().time() - start
  324. await driver # ensure no orphan task
  325. # Exit was via DISCONNECT at ~2.5s, NOT a 3s keepalive timeout.
  326. # Allow generous slop.
  327. assert 2.0 < elapsed < 3.0, f"expected exit on DISCONNECT near 2.5s, got {elapsed:.2f}s"