| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560 |
- """Tests for Virtual Printer MQTT server."""
- import ast
- import asyncio
- import inspect
- import json
- from pathlib import Path
- from unittest.mock import AsyncMock, MagicMock
- import pytest
- from backend.app.services.virtual_printer.mqtt_server import SimpleMQTTServer
class TestMQTTServerNoGlobalState:
    """Guard against the MQTT server mutating process-wide asyncio state."""

    def test_no_global_exception_handler(self):
        """SimpleMQTTServer must never reference set_exception_handler().

        The exception handler is global to the event loop. With several VP
        instances running, each one installing its own handler would clobber
        the previous one, losing error context and producing spurious
        'Unhandled exception in client_connected_cb' messages.
        """
        class_tree = ast.parse(inspect.getsource(SimpleMQTTServer))
        # Any attribute access named set_exception_handler counts — a call
        # site always goes through such an attribute node.
        offending = [
            node
            for node in ast.walk(class_tree)
            if isinstance(node, ast.Attribute) and node.attr == "set_exception_handler"
        ]
        if offending:
            raise AssertionError(
                "SimpleMQTTServer must not call set_exception_handler(). "
                "It overwrites the global asyncio exception handler, "
                "breaking multi-VP setups."
            )
def _make_server(serial: str = "01P00A391800001") -> SimpleMQTTServer:
    """Return a SimpleMQTTServer configured with placeholder TLS paths.

    ``start()`` is never called by these tests, so the certificate and key
    paths are presumably never read and need not exist.
    """
    placeholder_dir = Path("/tmp")  # nosec B108
    server_kwargs = {
        "serial": serial,
        "access_code": "deadbeef",
        "cert_path": placeholder_dir / "unused.crt",
        "key_path": placeholder_dir / "unused.key",
        "model": "C12",
    }
    return SimpleMQTTServer(**server_kwargs)
class TestExtractSerialFromTopic:
    """Check that ``_extract_serial_from_topic`` isolates the serial segment."""

    @pytest.mark.parametrize(
        ("topic", "expected"),
        [
            ("device/01P00A391800001/request", "01P00A391800001"),
            ("device/09400A391800003/report", "09400A391800003"),
            ("device/00M00A391800004/request/subpath", "00M00A391800004"),
        ],
    )
    def test_valid_topics(self, topic, expected):
        """Well-formed ``device/<serial>/...`` topics yield the serial."""
        extracted = SimpleMQTTServer._extract_serial_from_topic(topic)
        assert extracted == expected

    @pytest.mark.parametrize(
        "topic",
        [
            "",
            "device/",
            "device//request",  # empty serial
            "notdevice/01P00A/request",
            "random",
        ],
    )
    def test_invalid_topics(self, topic):
        """Empty, wrongly-prefixed, or serial-less topics yield None."""
        extracted = SimpleMQTTServer._extract_serial_from_topic(topic)
        assert extracted is None
- def _build_publish_payload(topic: str, message: dict) -> bytes:
- """Build the MQTT PUBLISH packet *payload* (past the fixed header byte)."""
- topic_bytes = topic.encode("utf-8")
- message_bytes = json.dumps(message).encode("utf-8")
- return len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
class TestPublishHandlerAdaptiveSerial:
    """#927: `_handle_publish` must accept any `device/*/request` topic from an
    authenticated client and use the topic's serial for all responses."""

    @staticmethod
    def _make_writer() -> MagicMock:
        """Return a writer stand-in whose write() records calls and drain() is awaitable."""
        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        return writer

    @staticmethod
    def _written_bytes(writer: MagicMock) -> bytes:
        """Concatenate every buffer passed to ``writer.write()``, in call order."""
        return b"".join(call.args[0] for call in writer.write.call_args_list)

    def test_handle_publish_accepts_mismatched_serial(self):
        """Prior behavior silently dropped publishes whose topic serial didn't
        equal self.serial. After the fix the handler must run and learn the
        client's serial.
        """
        server = _make_server(serial="01P00A391800001")  # synthetic VP serial
        server._client_serials["test-client"] = server.serial  # simulate post-CONNECT
        writer = self._make_writer()

        # Slicer publishes with a *different* serial — the exact bug from #927.
        topic = "device/01P00AABCDEFGHI/request"
        payload = _build_publish_payload(topic, {"info": {"command": "get_version", "sequence_id": "42"}})
        asyncio.run(server._handle_publish(0x30, payload, writer, "test-client"))

        # Learned the client's serial.
        assert server._client_serials["test-client"] == "01P00AABCDEFGHI"
        # Wrote at least one packet to the slicer (the version response).
        assert writer.write.called
        all_bytes = self._written_bytes(writer)
        # Response topic must contain the *client's* serial, not self.serial.
        assert b"device/01P00AABCDEFGHI/report" in all_bytes
        assert b"device/01P00A391800001/report" not in all_bytes
        # Response body carries get_version with the client's serial as sn.
        assert b'"command": "get_version"' in all_bytes
        assert b'"sn": "01P00AABCDEFGHI"' in all_bytes

    def test_handle_publish_ignores_non_request_topics(self):
        """A publish to a `/report` topic is not a request: nothing is written
        back and the client's recorded serial stays unchanged."""
        server = _make_server()
        server._client_serials["c1"] = server.serial
        writer = self._make_writer()

        payload = _build_publish_payload(
            "device/01P00AABCDEFGHI/report",  # /report, not /request
            {"pushing": {"command": "pushall"}},
        )
        asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))

        assert not writer.write.called  # no response
        # Client serial unchanged
        assert server._client_serials["c1"] == server.serial

    def test_handle_publish_pushall_uses_client_serial(self):
        """pushall → status_report must be sent on the client's subscribed topic."""
        server = _make_server(serial="01P00A391800001")
        server._client_serials["c1"] = server.serial
        writer = self._make_writer()

        payload = _build_publish_payload(
            "device/CUSTOMSERIAL123/request",
            {"pushing": {"command": "pushall", "sequence_id": "1"}},
        )
        asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))

        all_bytes = self._written_bytes(writer)
        assert b"device/CUSTOMSERIAL123/report" in all_bytes
        assert b'"command": "push_status"' in all_bytes
        assert server._client_serials["c1"] == "CUSTOMSERIAL123"

    def test_handle_publish_tolerates_null_terminated_payload(self):
        """#927: OrcaSlicer on Linux appends the C-string \\0 to MQTT payloads.
        The handler must still parse and respond rather than silently dropping."""
        server = _make_server(serial="01P00A391800001")
        server._client_serials["c1"] = server.serial
        writer = self._make_writer()

        topic_bytes = "device/01P00A391800001/request".encode("utf-8")
        # Real-world bytes captured from EdwardChamberlain's support log: the
        # JSON ends with an extra \x00 that strict json.loads rejects. Built
        # by hand because _build_publish_payload always emits clean JSON.
        message_bytes = b'{"pushing":{"command":"pushall","sequence_id":"7"}}\x00'
        payload = len(topic_bytes).to_bytes(2, "big") + topic_bytes + message_bytes
        asyncio.run(server._handle_publish(0x30, payload, writer, "c1"))

        all_bytes = self._written_bytes(writer)
        assert b"device/01P00A391800001/report" in all_bytes
        assert b'"command": "push_status"' in all_bytes
class TestClientSerialLifecycle:
    """_client_serials must be cleaned up on disconnect/stop to avoid leaks."""

    def test_stop_clears_client_serials(self):
        """stop() must empty _client_serials, even on a never-started server."""
        server = _make_server()
        server._client_serials["a"] = "X"
        server._client_serials["b"] = "Y"

        # The server was never started, so this drives the real stop()
        # coroutine against an idle instance; only its effect on
        # _client_serials is asserted here.
        asyncio.run(server.stop())

        assert server._client_serials == {}
- def _build_connect_payload(
- keep_alive: int,
- access_code: str = "deadbeef",
- username: str = "bblp",
- client_id: str = "orca",
- ) -> bytes:
- """Build an MQTT CONNECT variable-header + payload (without the fixed header).
- Layout matches the parser in `_handle_connect`:
- proto_name_len(2) + "MQTT"(4) + level(1) + flags(1) + keepalive(2) +
- client_id_len(2) + client_id + username_len(2) + username +
- password_len(2) + password.
- """
- proto = b"MQTT"
- parts = bytearray()
- parts += len(proto).to_bytes(2, "big") + proto
- parts += bytes([0x04, 0xC2]) # protocol level 4 (MQTT 3.1.1), flags: user+pass+clean
- parts += keep_alive.to_bytes(2, "big")
- cid = client_id.encode("utf-8")
- parts += len(cid).to_bytes(2, "big") + cid
- user = username.encode("utf-8")
- parts += len(user).to_bytes(2, "big") + user
- pw = access_code.encode("utf-8")
- parts += len(pw).to_bytes(2, "big") + pw
- return bytes(parts)
class TestHandleConnectKeepalive:
    """`_handle_connect` must return the negotiated keepalive (#1548).

    Pre-fix, the parser ignored this field and the read loop fell back to
    a hardcoded 60 s timeout, closing OrcaSlicer's idle MQTT connection
    after exactly 60 s instead of waiting 1.5× the client-negotiated
    keepalive as MQTT 3.1.1 spec §3.1.2.10 (Keep Alive) requires.
    """

    @staticmethod
    def _make_writer() -> MagicMock:
        """Return a writer stand-in whose write() records calls and drain() is awaitable."""
        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        return writer

    def test_returns_negotiated_keepalive_on_auth_success(self):
        """A valid CONNECT reports success plus the client's keepalive verbatim."""
        server = _make_server()
        payload = _build_connect_payload(keep_alive=120)
        result = asyncio.run(server._handle_connect(payload, self._make_writer()))
        assert result == (True, 120)

    def test_returns_zero_keepalive_for_no_keepalive_clients(self):
        """`keep_alive == 0` in CONNECT means the client opted out per spec
        §3.1.2.10 — server must report it back so the read loop can drop
        the timeout entirely."""
        server = _make_server()
        payload = _build_connect_payload(keep_alive=0)
        result = asyncio.run(server._handle_connect(payload, self._make_writer()))
        assert result == (True, 0)

    def test_returns_false_with_zero_keepalive_on_auth_failure(self):
        """Bad password path still returns the tuple shape so the caller's
        unpack doesn't break."""
        server = _make_server()
        payload = _build_connect_payload(keep_alive=60, access_code="wrong")
        result = asyncio.run(server._handle_connect(payload, self._make_writer()))
        assert result == (False, 0)

    def test_returns_false_with_zero_keepalive_on_parse_error(self):
        """Malformed CONNECT (e.g. truncated) must not crash and must
        still hand a tuple back to the caller."""
        server = _make_server()
        # 4 bytes: a protocol-name length prefix announcing 4 bytes ("MQTT")
        # followed by only 2 of them, so the packet is cut off before the
        # protocol name even finishes.
        truncated = b"\x00\x04MQ"
        result = asyncio.run(server._handle_connect(truncated, self._make_writer()))
        assert result == (False, 0)
class TestHandleClientHonoursKeepalive:
    """`_handle_client` must use the client-negotiated keepalive for its
    read-loop timeout, not the hardcoded 60 s default (#1548)."""

    # Hard upper bound on how long a test may wait for _handle_client. A
    # regression to the old 60 s timeout (or to no timeout at all) then
    # fails fast instead of hanging the suite.
    _HANDLER_DEADLINE_S = 10.0

    @staticmethod
    def _connected_client(keep_alive: int):
        """Return ``(server, reader, writer)`` with one CONNECT already queued.

        The reader holds a single complete CONNECT packet negotiating
        *keep_alive*. Nothing else is fed, so the client looks idle after
        CONNECT. The writer is a mock with awaitable drain()/wait_closed().
        The post-auth status-report send is stubbed so the handler doesn't
        depend on a real serial/payload path. Must be called from inside a
        running event loop, because it constructs an ``asyncio.StreamReader``.
        """
        server = _make_server()
        server._running = True

        reader = asyncio.StreamReader()
        connect_payload = _build_connect_payload(keep_alive=keep_alive)
        remaining_length = len(connect_payload)
        # MQTT remaining-length encoding for values <128 is a single byte.
        assert remaining_length < 128
        # Fixed header byte 0x10 (CONNECT) + remaining length, then the body.
        reader.feed_data(bytes([0x10, remaining_length]) + connect_payload)

        writer = MagicMock()
        writer.write = MagicMock()
        writer.drain = AsyncMock()
        writer.close = MagicMock()
        writer.wait_closed = AsyncMock()
        writer.get_extra_info = MagicMock(return_value=("1.2.3.4", 12345))

        server._send_status_report = AsyncMock()
        return server, reader, writer

    @pytest.mark.asyncio
    async def test_idle_client_kept_alive_beyond_60s_when_keepalive_is_long(self):
        """#1548 smoke test: after a CONNECT negotiating keepalive=180, the
        handler must accept the client and keep the idle connection open
        while it waits for the next packet.

        This does not wait out the old 60 s boundary in real time, because
        that would make a unit test far too slow. The exact 1.5× scaling is
        pinned by ``test_idle_client_closed_after_one_and_a_half_times_keepalive``
        and ``test_pingreq_resets_idle_timeout``. Here we only check that the
        connection is not torn down right after CONNECT.
        """
        server, reader, writer = self._connected_client(keep_alive=180)
        task = asyncio.create_task(server._handle_client(reader, writer))
        try:
            await asyncio.sleep(0.1)  # give the loop a chance to process CONNECT
            assert not task.done(), "handler exited right after CONNECT"
            assert not writer.close.called, "connection should still be open after CONNECT"
        finally:
            # Cancel even when an assertion failed, so no task is orphaned.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    @pytest.mark.asyncio
    async def test_idle_client_closed_after_one_and_a_half_times_keepalive(self):
        """Tight verification: keepalive=2 must close the connection after
        ~3 s (1.5×) of idle time, well above the noise floor for an async test."""
        server, reader, writer = self._connected_client(keep_alive=2)

        loop = asyncio.get_running_loop()
        start = loop.time()
        await asyncio.wait_for(
            server._handle_client(reader, writer),
            timeout=self._HANDLER_DEADLINE_S,
        )
        elapsed = loop.time() - start

        # 1.5×2 s = 3 s expected. Allow 1 s below and 1.5 s above for the
        # read of CONNECT itself plus scheduler jitter on a loaded CI box.
        assert 2.0 < elapsed < 4.5, f"expected ~3s timeout, got {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_pingreq_resets_idle_timeout(self):
        """A PINGREQ within the keepalive window must restart the idle timer.

        With keepalive=2 the idle limit is 3 s. A PINGREQ arrives at ~2 s
        and a DISCONNECT at ~4 s. The DISCONNECT lands *after* the 3 s
        deadline measured from CONNECT but *before* the 5 s deadline
        measured from the PINGREQ. The handler only lives long enough to
        see the DISCONNECT if the PINGREQ reset the timer. A timer that is
        never reset closes the connection at ~3 s and fails the lower bound.
        """
        server, reader, writer = self._connected_client(keep_alive=2)

        async def _drive():
            await asyncio.sleep(2.0)
            reader.feed_data(bytes([0xC0, 0x00]))  # PINGREQ: type 12, zero remaining length
            await asyncio.sleep(2.0)
            reader.feed_data(bytes([0xE0, 0x00]))  # DISCONNECT

        driver = asyncio.create_task(_drive())
        loop = asyncio.get_running_loop()
        start = loop.time()
        try:
            await asyncio.wait_for(
                server._handle_client(reader, writer),
                timeout=self._HANDLER_DEADLINE_S,
            )
            elapsed = loop.time() - start
        finally:
            # Never leave the feeder orphaned. On the success path it has
            # already finished, so cancel() is a no-op.
            driver.cancel()
            try:
                await driver
            except asyncio.CancelledError:
                pass

        # Exit must come from the DISCONNECT at ~4 s, not from a keepalive
        # timeout at ~3 s (no reset) or ~5 s (reset, but DISCONNECT missed).
        assert 3.5 < elapsed < 4.8, f"expected exit on DISCONNECT near 4s, got {elapsed:.2f}s"
class TestAuthRateLimit:
    """Per-IP rate-limiting of MQTT CONNECT auth attempts.

    Bambuddy's VP exposes an 8-char access code via the slicer-facing MQTT
    server. Without a rate-limit the code is brute-forceable by anyone who
    can reach the VP's bind IP (LAN or VPN). The limiter records each
    failed auth attempt per source IP and rejects further CONNECTs from
    that IP once the per-window threshold is crossed. It auto-recovers
    when the window expires. Verified here against the production
    constants imported from the module.
    """

    @pytest.fixture
    def server(self):
        """A fresh server per test, so failure history never leaks between tests."""
        return _make_server(serial="01P00A391800002")

    def test_under_limit_attempts_are_allowed(self, server):
        """One failure short of the cap must not trigger the limiter."""
        from backend.app.services.virtual_printer.mqtt_server import _AUTH_RATE_LIMIT_MAX_ATTEMPTS

        ip = "192.168.1.50"
        # Record (max-1) failures and verify the next attempt is still allowed.
        for _ in range(_AUTH_RATE_LIMIT_MAX_ATTEMPTS - 1):
            server._record_auth_failure(ip)
        assert server._is_auth_rate_limited(ip) is False

    def test_exactly_max_attempts_triggers_rate_limit(self, server):
        """Reaching the cap exactly must trigger the limiter."""
        from backend.app.services.virtual_printer.mqtt_server import _AUTH_RATE_LIMIT_MAX_ATTEMPTS

        ip = "192.168.1.50"
        for _ in range(_AUTH_RATE_LIMIT_MAX_ATTEMPTS):
            server._record_auth_failure(ip)
        # At exactly the cap, further attempts must be rejected.
        assert server._is_auth_rate_limited(ip) is True

    def test_window_recovery_clears_old_failures(self, server):
        """A burst of failures older than the window must NOT count
        against the IP — the limiter is sliding, not cumulative."""
        import time as _time

        from backend.app.services.virtual_printer.mqtt_server import (
            _AUTH_RATE_LIMIT_MAX_ATTEMPTS,
            _AUTH_RATE_LIMIT_WINDOW_SECONDS,
        )

        ip = "192.168.1.50"
        # Inject stale timestamps directly — older than the window means the
        # limiter should drop them on the next probe.
        stale = _time.monotonic() - _AUTH_RATE_LIMIT_WINDOW_SECONDS - 1.0
        server._auth_failures[ip] = [stale] * _AUTH_RATE_LIMIT_MAX_ATTEMPTS
        # All recorded failures are outside the window — IP is no longer rate-limited.
        assert server._is_auth_rate_limited(ip) is False
        # And the dict entry was pruned (empty) instead of leaking forever.
        assert ip not in server._auth_failures

    def test_multiple_ips_tracked_independently(self, server):
        """Failures from one IP must not affect another IP."""
        from backend.app.services.virtual_printer.mqtt_server import _AUTH_RATE_LIMIT_MAX_ATTEMPTS

        # One IP exhausts the budget; another IP must still be allowed.
        for _ in range(_AUTH_RATE_LIMIT_MAX_ATTEMPTS):
            server._record_auth_failure("10.0.0.1")
        assert server._is_auth_rate_limited("10.0.0.1") is True
        assert server._is_auth_rate_limited("10.0.0.2") is False

    def test_successful_auth_clears_failure_history(self, server):
        """A successful auth must wipe the IP's prior-failures stash so the
        user isn't penalised for typos that they ultimately corrected."""
        from backend.app.services.virtual_printer.mqtt_server import _AUTH_RATE_LIMIT_MAX_ATTEMPTS

        ip = "192.168.1.50"
        # Build up failures one short of the cap.
        for _ in range(_AUTH_RATE_LIMIT_MAX_ATTEMPTS - 1):
            server._record_auth_failure(ip)
        # Successful auth must clear them.
        server._clear_auth_failures(ip)
        # Now a subsequent failure starts the count over at 1 (well under cap).
        # Without the clear, this would be the max-th failure and trip the limit.
        server._record_auth_failure(ip)
        assert server._is_auth_rate_limited(ip) is False
class TestPendingRequestRouting:
    """`push_raw_to_clients` routes the printer's response back only to the
    slicer that originated the request, not to every connected slicer.

    The bridge calls `push_raw_to_clients(topic, payload)` for every
    response it sees from the real printer. Before the fix, this fanned
    out to every connected slicer — leaking slicer A's
    `extrusion_cali_get` response into slicer B's command stream. The
    fix records `sequence_id → client_id` on the way out and looks it
    back up on the way in.
    """

    @pytest.fixture
    def server(self):
        """A fresh server per test, so pending-request state never leaks."""
        return _make_server(serial="01P00A391800003")

    def test_single_slicer_routes_to_that_slicer(self, server):
        """Sanity check: with one slicer connected, its responses reach it
        whether or not the seq_id was recorded. A recorded seq_id routes
        directly to that client. An unrecorded one yields None (broadcast
        fallback), which also reaches the only slicer."""
        server._record_pending_request({"print": {"sequence_id": "5"}}, "clientA")
        # Recorded request → routed straight back to the originating slicer.
        assert server._lookup_pending_request_client(b'{"print": {"sequence_id": "5"}}') == "clientA"
        # No recorded request → None (broadcast).
        assert server._lookup_pending_request_client(b'{"print": {"sequence_id": "999"}}') is None

    def test_record_pending_request_walks_nested_blocks(self, server):
        """The slicer wraps its sequence_id under whichever subsystem the
        command targets (`print`, `info`, `system`, …). The helper must
        find it regardless of which key it's nested under."""
        server._record_pending_request(
            {"print": {"command": "extrusion_cali_get", "sequence_id": "42"}},
            "clientA",
        )
        assert server._pending_requests.get("42") == "clientA"
        server._record_pending_request(
            {"info": {"command": "get_version", "sequence_id": "43"}},
            "clientB",
        )
        assert server._pending_requests.get("43") == "clientB"

    def test_lookup_pops_entry_so_each_response_routes_once(self, server):
        """Once a response is matched, the pending entry is consumed so
        a later coincidental sequence_id from a printer-initiated push
        doesn't mis-route to the original client."""
        server._record_pending_request({"print": {"sequence_id": "100"}}, "clientA")
        # First lookup finds it…
        assert server._lookup_pending_request_client(b'{"print": {"sequence_id": "100"}}') == "clientA"
        # …and removes it. Second lookup with the same seq returns None
        # (treated as printer-initiated → broadcast fallback).
        assert server._lookup_pending_request_client(b'{"print": {"sequence_id": "100"}}') is None

    def test_fifo_eviction_when_cache_fills(self, server):
        """If a slicer sends many commands without responses (or the
        responses never arrive), the oldest entries age out so the dict
        can't grow unbounded."""
        from backend.app.services.virtual_printer.mqtt_server import _PENDING_REQUEST_MAX_ENTRIES

        # Fill the dict to one over the cap.
        for i in range(_PENDING_REQUEST_MAX_ENTRIES + 1):
            server._record_pending_request({"print": {"sequence_id": str(i)}}, "clientA")
        # The dict is capped — the oldest entry ("0") is gone, the newest is in.
        assert len(server._pending_requests) <= _PENDING_REQUEST_MAX_ENTRIES
        assert "0" not in server._pending_requests
        assert str(_PENDING_REQUEST_MAX_ENTRIES) in server._pending_requests

    def test_response_without_recorded_seq_returns_none_for_broadcast(self, server):
        """Printer-initiated pushes (push_status etc.) carry a sequence_id
        that was never recorded. ``_lookup_pending_request_client``
        must return None so ``push_raw_to_clients`` falls back to fan-out
        — every slicer expects to receive these unsolicited messages."""
        # No record for this seq id.
        assert server._lookup_pending_request_client(b'{"print": {"sequence_id": "777"}}') is None

    def test_malformed_payload_falls_through_to_broadcast(self, server):
        """A non-JSON / non-dict payload must NOT crash the routing path —
        return None so the response broadcasts."""
        assert server._lookup_pending_request_client(b"not valid json") is None
        assert server._lookup_pending_request_client(b'"a string, not a dict"') is None
|