test_trace_middleware.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. """Integration test for the trace-ID middleware contract.
  2. Tests focus on observable surface — what headers go out, what
  3. ContextVar value the route handler sees — rather than re-testing the
  4. ContextVar / filter primitives (those are covered in
  5. ``tests/unit/test_trace.py``).
  6. A minimal FastAPI app is used instead of the production ``backend.app.main``
  7. app: importing main.py would pull in the entire startup graph (DB
  8. migrations, MQTT subscribers, scheduler, etc.) just to assert "the
  9. middleware sets a header", and that overhead would dwarf the test value.
  10. The middleware function is copied inline so the test pins the exact
  11. contract expected of it.
  12. """
  13. from __future__ import annotations
  14. import re
  15. import pytest
  16. from fastapi import FastAPI
  17. from fastapi.testclient import TestClient
  18. from backend.app.core.trace import (
  19. generate_trace_id,
  20. get_trace_id,
  21. normalise_inbound_trace_id,
  22. trace_id_var,
  23. )
  def _build_app_with_trace_middleware() -> FastAPI:
      """Return a throwaway FastAPI app that carries the trace-ID middleware.

      The middleware is registered the way main.py is expected to register
      it, and a single ``/echo-trace`` route reports the trace ID visible to
      the handler so tests can compare it with the response header.
      """
      header_name = "X-Trace-Id"
      app = FastAPI()

      @app.middleware("http")
      async def trace_id_middleware(request, call_next):
          # Honour the caller's ID only when the normaliser accepts it; a
          # ``None`` result means a fresh ID has to be generated instead.
          trace_id = normalise_inbound_trace_id(request.headers.get(header_name))
          if trace_id is None:
              trace_id = generate_trace_id()

          reset_token = trace_id_var.set(trace_id)
          try:
              response = await call_next(request)
          finally:
              # Restore the previous ContextVar value even when the
              # downstream handler raised.
              trace_id_var.reset(reset_token)

          response.headers[header_name] = trace_id
          return response

      @app.get("/echo-trace")
      async def echo_trace():
          # Report the handler-side ContextVar value so tests can check that
          # the ID stamped on application logs is the same one the client
          # receives in the header; if the two ever differed, correlating a
          # client report with server logs would be impossible.
          handler_view = get_trace_id()
          return {"trace_id": handler_view}

      return app
  @pytest.fixture
  def client() -> TestClient:
      """Provide a ``TestClient`` bound to a freshly built minimal app."""
      app = _build_app_with_trace_middleware()
      return TestClient(app)
  class TestGeneratedTraceId:
      """Requests that arrive without an ``X-Trace-Id`` header."""

      def test_response_carries_x_trace_id_header(self, client):
          """A response always exposes a non-empty ``X-Trace-Id``.

          Clients paste this value into a server-side log search; without
          it the trace-ID column in bambuddy.log could not be reached from
          the client side.
          """
          response = client.get("/echo-trace")
          headers = response.headers

          assert response.status_code == 200
          assert "X-Trace-Id" in headers
          assert headers["X-Trace-Id"]

      def test_generated_id_matches_handler_view(self, client):
          """The header value equals what the route handler observed.

          If they differed, client-side and server-side log searches would
          be keyed on different IDs and could never be joined.
          """
          response = client.get("/echo-trace")

          seen_by_handler = response.json()["trace_id"]
          sent_to_client = response.headers["X-Trace-Id"]

          assert seen_by_handler == sent_to_client

      def test_each_request_gets_a_unique_id(self, client):
          """Back-to-back requests receive distinct IDs.

          Identical IDs would make the log column useless for telling
          requests apart.
          """
          ids = [client.get("/echo-trace").headers["X-Trace-Id"] for _ in range(2)]

          assert ids[0] != ids[1]

      def test_generated_id_format_is_short_hex(self, client):
          """Pin the shape and width of generated IDs.

          A format change (e.g. UUID-with-dashes) would alter the log
          column width and break grep patterns downstream tooling might
          rely on; failing here makes such a change deliberate.
          """
          response = client.get("/echo-trace")
          tid = response.headers["X-Trace-Id"]

          is_lower_hex = re.fullmatch(r"[0-9a-f]+", tid)
          assert is_lower_hex, tid

          length = len(tid)
          assert 4 <= length <= 32
  class TestInboundTraceIdRespected:
      """Requests where the caller supplies its own ``X-Trace-Id`` header."""

      def test_safe_inbound_id_is_echoed(self, client):
          """When the caller sends a sane X-Trace-Id, we honour it — this
          is the cross-system correlation case (caller's tracing system
          wants its span ID propagated)."""
          inbound = "client-sent-abc123"

          response = client.get("/echo-trace", headers={"X-Trace-Id": inbound})

          assert response.headers["X-Trace-Id"] == inbound
          assert response.json()["trace_id"] == inbound

      def test_hostile_inbound_id_is_replaced(self, client):
          """A header that fails the validator (control chars,
          log-injection-shaped chars, etc.) must NOT reach the response
          header or the log column — silently mint fresh and carry on,
          so a hostile/buggy caller can't break our log file but also
          can't break their own request by sending a bad header."""
          hostile = "abc\ndef rm -rf /"

          response = client.get("/echo-trace", headers={"X-Trace-Id": hostile})

          # The bad header must not turn the request itself into a failure.
          assert response.status_code == 200

          echoed = response.headers["X-Trace-Id"]
          assert echoed != hostile
          assert "\n" not in echoed
          assert " " not in echoed

          # The handler-side value is what ends up in the log column, so it
          # has to be the replacement ID too, not the hostile input.
          assert response.json()["trace_id"] == echoed

      def test_overlong_inbound_id_is_replaced(self, client):
          """The cap protects bambuddy.log from a 1KB-per-line blowup if
          a caller sends a huge X-Trace-Id."""
          too_long = "a" * 100

          response = client.get("/echo-trace", headers={"X-Trace-Id": too_long})

          echoed = response.headers["X-Trace-Id"]
          assert echoed != too_long

          # The handler (and therefore the log column) must see the same
          # replacement ID that the client gets back.
          assert response.json()["trace_id"] == echoed
  class TestContextResetAfterRequest:
      """The request's trace ID must not linger once the request is done."""

      def test_trace_id_var_resets_after_request_completes(self, client):
          """The middleware is expected to reset the ContextVar in its
          ``finally`` block. Otherwise a record emitted later from an
          unrelated background task that inherited the same context would
          still be stamped with the ID of a request that finished long ago.
          """
          from backend.app.core.trace import TRACE_ID_PLACEHOLDER

          client.get("/echo-trace")

          # Once the request has returned, the context this test runs in
          # should report the placeholder rather than the request's ID.
          # NOTE(review): this reads the ContextVar from the test's own
          # context; whether a missing reset in the middleware would be
          # visible here depends on how TestClient executes the app
          # (thread / event loop) — confirm this test can actually fail.
          observed = get_trace_id()
          assert observed == TRACE_ID_PLACEHOLDER