# test_vp_mqtt_server.py (6.3 KB)
  1. """Tests for Virtual Printer MQTT server."""
  2. import ast
  3. import asyncio
  4. import inspect
  5. import json
  6. from pathlib import Path
  7. from unittest.mock import AsyncMock, MagicMock
  8. import pytest
  9. from backend.app.services.virtual_printer.mqtt_server import SimpleMQTTServer
  10. class TestMQTTServerNoGlobalState:
  11. """Ensure MQTT server doesn't set global asyncio state."""
  12. def test_no_global_exception_handler(self):
  13. """MQTT server must not call set_exception_handler().
  14. set_exception_handler() is global to the event loop. When multiple
  15. VP instances run, each would overwrite the previous handler,
  16. causing lost error context and spurious 'Unhandled exception in
  17. client_connected_cb' messages.
  18. """
  19. source = inspect.getsource(SimpleMQTTServer)
  20. tree = ast.parse(source)
  21. for node in ast.walk(tree):
  22. if isinstance(node, ast.Attribute) and node.attr == "set_exception_handler":
  23. raise AssertionError(
  24. "SimpleMQTTServer must not call set_exception_handler(). "
  25. "It overwrites the global asyncio exception handler, "
  26. "breaking multi-VP setups."
  27. )
  28. def _make_server(serial: str = "01P00A391800001") -> SimpleMQTTServer:
  29. """Build a SimpleMQTTServer with dummy cert paths (start() is never called)."""
  30. return SimpleMQTTServer(
  31. serial=serial,
  32. access_code="deadbeef",
  33. cert_path=Path("/tmp/unused.crt"),
  34. key_path=Path("/tmp/unused.key"),
  35. model="C12",
  36. )
  37. class TestExtractSerialFromTopic:
  38. """_extract_serial_from_topic should pull the serial out of device topics."""
  39. @pytest.mark.parametrize(
  40. "topic,expected",
  41. [
  42. ("device/01P00A391800001/request", "01P00A391800001"),
  43. ("device/09400A391800003/report", "09400A391800003"),
  44. ("device/00M00A391800004/request/subpath", "00M00A391800004"),
  45. ],
  46. )
  47. def test_valid_topics(self, topic, expected):
  48. assert SimpleMQTTServer._extract_serial_from_topic(topic) == expected
  49. @pytest.mark.parametrize(
  50. "topic",
  51. [
  52. "",
  53. "device/",
  54. "device//request", # empty serial
  55. "notdevice/01P00A/request",
  56. "random",
  57. ],
  58. )
  59. def test_invalid_topics(self, topic):
  60. assert SimpleMQTTServer._extract_serial_from_topic(topic) is None
  61. def _build_publish_payload(topic: str, message: dict) -> bytes:
  62. """Build the MQTT PUBLISH packet *payload* (past the fixed header byte)."""
  63. topic_bytes = topic.encode("utf-8")
  64. message_bytes = json.dumps(message).encode("utf-8")
  65. return len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
  66. class TestPublishHandlerAdaptiveSerial:
  67. """#927: `_handle_publish` must accept any `device/*/request` topic from an
  68. authenticated client and use the topic's serial for all responses."""
  69. def test_handle_publish_accepts_mismatched_serial(self):
  70. """Prior behavior silently dropped publishes whose topic serial didn't
  71. equal self.serial. After the fix the handler must run and learn the
  72. client's serial.
  73. """
  74. server = _make_server(serial="01P00A391800001") # synthetic VP serial
  75. server._client_serials["test-client"] = server.serial # simulate post-CONNECT
  76. writer = MagicMock()
  77. writer.write = MagicMock()
  78. writer.drain = AsyncMock()
  79. # Slicer publishes with a *different* serial — the exact bug from #927.
  80. topic = "device/01P00AABCDEFGHI/request"
  81. payload = _build_publish_payload(topic, {"info": {"command": "get_version", "sequence_id": "42"}})
  82. asyncio.run(server._handle_publish(0x30, payload, writer, "test-client"))
  83. # Learned the client's serial.
  84. assert server._client_serials["test-client"] == "01P00AABCDEFGHI"
  85. # Wrote at least one packet to the slicer (the version response).
  86. assert writer.write.called
  87. all_bytes = b"".join(call.args[0] for call in writer.write.call_args_list)
  88. # Response topic must contain the *client's* serial, not self.serial.
  89. assert b"device/01P00AABCDEFGHI/report" in all_bytes
  90. assert b"device/01P00A391800001/report" not in all_bytes
  91. # Response body carries get_version with the client's serial as sn.
  92. assert b'"command": "get_version"' in all_bytes
  93. assert b'"sn": "01P00AABCDEFGHI"' in all_bytes
  94. def test_handle_publish_ignores_non_request_topics(self):
  95. server = _make_server()
  96. server._client_serials["c1"] = server.serial
  97. writer = MagicMock()
  98. writer.write = MagicMock()
  99. writer.drain = AsyncMock()
  100. payload = _build_publish_payload(
  101. "device/01P00AABCDEFGHI/report", # /report, not /request
  102. {"pushing": {"command": "pushall"}},
  103. )
  104. asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))
  105. assert not writer.write.called # no response
  106. # Client serial unchanged
  107. assert server._client_serials["c1"] == server.serial
  108. def test_handle_publish_pushall_uses_client_serial(self):
  109. """pushall → status_report must be sent on the client's subscribed topic."""
  110. server = _make_server(serial="01P00A391800001")
  111. server._client_serials["c1"] = server.serial
  112. writer = MagicMock()
  113. writer.write = MagicMock()
  114. writer.drain = AsyncMock()
  115. payload = _build_publish_payload(
  116. "device/CUSTOMSERIAL123/request",
  117. {"pushing": {"command": "pushall", "sequence_id": "1"}},
  118. )
  119. asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))
  120. all_bytes = b"".join(call.args[0] for call in writer.write.call_args_list)
  121. assert b"device/CUSTOMSERIAL123/report" in all_bytes
  122. assert b'"command": "push_status"' in all_bytes
  123. assert server._client_serials["c1"] == "CUSTOMSERIAL123"
  124. class TestClientSerialLifecycle:
  125. """_client_serials must be cleaned up on disconnect/stop to avoid leaks."""
  126. def test_stop_clears_client_serials(self):
  127. server = _make_server()
  128. server._client_serials["a"] = "X"
  129. server._client_serials["b"] = "Y"
  130. # stop() is async but we only need to cover the clear() path; run a minimal version
  131. asyncio.run(server.stop())
  132. assert server._client_serials == {}