| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- """Tests for Virtual Printer MQTT server."""
- import ast
- import asyncio
- import inspect
- import json
- from pathlib import Path
- from unittest.mock import AsyncMock, MagicMock
- import pytest
- from backend.app.services.virtual_printer.mqtt_server import SimpleMQTTServer
class TestMQTTServerNoGlobalState:
    """Guard against SimpleMQTTServer touching event-loop-wide asyncio state."""

    def test_no_global_exception_handler(self):
        """SimpleMQTTServer must never reference ``set_exception_handler``.

        ``set_exception_handler()`` installs one handler for the whole event
        loop. When several virtual-printer instances share a loop, each one
        would replace the previous handler. That loses error context and
        produces spurious 'Unhandled exception in client_connected_cb'
        messages.

        The check is static: it parses the class source and flags any
        attribute access with that name, whether or not the code path runs.
        """
        class_tree = ast.parse(inspect.getsource(SimpleMQTTServer))
        offending_nodes = [
            node
            for node in ast.walk(class_tree)
            if isinstance(node, ast.Attribute)
            and node.attr == "set_exception_handler"
        ]
        if offending_nodes:
            raise AssertionError(
                "SimpleMQTTServer must not call set_exception_handler(). "
                "It overwrites the global asyncio exception handler, "
                "breaking multi-VP setups."
            )
def _make_server(serial: str = "01P00A391800001") -> SimpleMQTTServer:
    """Return a SimpleMQTTServer configured for unit tests.

    The certificate and key paths are placeholders. The tests never call
    ``start()``, so the server is only constructed, never started.
    """
    placeholder_dir = Path("/tmp")
    server_kwargs = {
        "serial": serial,
        "access_code": "deadbeef",
        "cert_path": placeholder_dir / "unused.crt",
        "key_path": placeholder_dir / "unused.key",
        "model": "C12",
    }
    return SimpleMQTTServer(**server_kwargs)
class TestExtractSerialFromTopic:
    """_extract_serial_from_topic should pull the serial out of device topics."""

    # (topic, serial the helper is expected to return)
    _VALID_CASES = [
        ("device/01P00A391800001/request", "01P00A391800001"),
        ("device/09400A391800003/report", "09400A391800003"),
        ("device/00M00A391800004/request/subpath", "00M00A391800004"),
    ]

    # Topics for which the helper is expected to return None.
    _INVALID_TOPICS = [
        "",
        "device/",
        "device//request",  # empty serial
        "notdevice/01P00A/request",
        "random",
    ]

    @pytest.mark.parametrize(("topic", "expected"), _VALID_CASES)
    def test_valid_topics(self, topic, expected):
        """The segment following 'device/' is returned, whatever comes after it."""
        extracted = SimpleMQTTServer._extract_serial_from_topic(topic)
        assert extracted == expected

    @pytest.mark.parametrize("topic", _INVALID_TOPICS)
    def test_invalid_topics(self, topic):
        """Empty, serial-less, or non-device topics yield None."""
        assert SimpleMQTTServer._extract_serial_from_topic(topic) is None
- def _build_publish_payload(topic: str, message: dict) -> bytes:
- """Build the MQTT PUBLISH packet *payload* (past the fixed header byte)."""
- topic_bytes = topic.encode("utf-8")
- message_bytes = json.dumps(message).encode("utf-8")
- return len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
class TestPublishHandlerAdaptiveSerial:
    """#927: `_handle_publish` must accept any `device/*/request` topic from an
    authenticated client and use the topic's serial for all responses."""

    @staticmethod
    def _make_writer() -> MagicMock:
        """Return a stand-in stream writer with a sync ``write`` and awaitable ``drain``."""
        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        return writer

    @staticmethod
    def _written_bytes(writer: MagicMock) -> bytes:
        """Concatenate every buffer passed to ``writer.write``, in call order."""
        return b"".join(call.args[0] for call in writer.write.call_args_list)

    def test_handle_publish_accepts_mismatched_serial(self):
        """Prior behavior silently dropped publishes whose topic serial didn't
        equal self.serial. After the fix the handler must run and learn the
        client's serial.
        """
        server = _make_server(serial="01P00A391800001")  # synthetic VP serial
        server._client_serials["test-client"] = server.serial  # simulate post-CONNECT
        writer = self._make_writer()

        # Slicer publishes with a *different* serial — the exact bug from #927.
        topic = "device/01P00AABCDEFGHI/request"
        payload = _build_publish_payload(
            topic, {"info": {"command": "get_version", "sequence_id": "42"}}
        )
        asyncio.run(server._handle_publish(0x30, payload, writer, "test-client"))

        # Learned the client's serial.
        assert server._client_serials["test-client"] == "01P00AABCDEFGHI"
        # Wrote at least one packet to the slicer (the version response).
        assert writer.write.called
        all_bytes = self._written_bytes(writer)
        # Response topic must contain the *client's* serial, not self.serial.
        assert b"device/01P00AABCDEFGHI/report" in all_bytes
        assert b"device/01P00A391800001/report" not in all_bytes
        # Response body carries get_version with the client's serial as sn.
        assert b'"command": "get_version"' in all_bytes
        assert b'"sn": "01P00AABCDEFGHI"' in all_bytes

    def test_handle_publish_ignores_non_request_topics(self):
        """Publishes to a /report topic get no response and teach no serial."""
        server = _make_server()
        server._client_serials["c1"] = server.serial
        writer = self._make_writer()

        payload = _build_publish_payload(
            "device/01P00AABCDEFGHI/report",  # /report, not /request
            {"pushing": {"command": "pushall"}},
        )
        asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))

        assert not writer.write.called  # no response
        # Client serial unchanged
        assert server._client_serials["c1"] == server.serial

    def test_handle_publish_pushall_uses_client_serial(self):
        """pushall → status_report must be sent on the client's subscribed topic."""
        server = _make_server(serial="01P00A391800001")
        server._client_serials["c1"] = server.serial
        writer = self._make_writer()

        payload = _build_publish_payload(
            "device/CUSTOMSERIAL123/request",
            {"pushing": {"command": "pushall", "sequence_id": "1"}},
        )
        asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))

        # Without this, an empty write log would only surface as a confusing
        # "not in b''" failure below.
        assert writer.write.called
        all_bytes = self._written_bytes(writer)
        assert b"device/CUSTOMSERIAL123/report" in all_bytes
        # Per the class contract, no response may go out on the VP's own serial.
        assert b"device/01P00A391800001/report" not in all_bytes
        assert b'"command": "push_status"' in all_bytes
        assert server._client_serials["c1"] == "CUSTOMSERIAL123"
class TestClientSerialLifecycle:
    """_client_serials must be cleaned up on disconnect/stop to avoid leaks."""

    def test_stop_clears_client_serials(self):
        """stop() must leave no per-client serial entries behind."""
        server = _make_server()
        server._client_serials.update({"a": "X", "b": "Y"})

        # This awaits the real async stop(). start() was never called, so
        # stop() is exercised on a server that was never started.
        asyncio.run(server.stop())

        assert server._client_serials == {}
|