# test_obico_smoothing.py (3.0 KB)

  1. """Unit tests for Obico detection smoothing math."""
  2. import pytest
  3. from backend.app.services.obico_smoothing import (
  4. BASE_HIGH,
  5. BASE_LOW,
  6. WARMUP_FRAMES,
  7. PrintState,
  8. classify,
  9. score_from_detections,
  10. thresholds,
  11. )
  12. class TestThresholds:
  13. def test_medium_matches_base(self):
  14. low, high = thresholds("medium")
  15. assert low == pytest.approx(BASE_LOW)
  16. assert high == pytest.approx(BASE_HIGH)
  17. def test_low_sensitivity_is_stricter(self):
  18. low, high = thresholds("low")
  19. assert low > BASE_LOW
  20. assert high > BASE_HIGH
  21. def test_high_sensitivity_is_looser(self):
  22. low, high = thresholds("high")
  23. assert low < BASE_LOW
  24. assert high < BASE_HIGH
  25. def test_unknown_falls_back_to_medium(self):
  26. assert thresholds("bogus") == thresholds("medium")
  27. class TestScoreFromDetections:
  28. def test_empty(self):
  29. assert score_from_detections([]) == 0.0
  30. assert score_from_detections(None) == 0.0
  31. def test_sums_confidences(self):
  32. dets = [["failure", 0.3, [0, 0, 10, 10]], ["failure", 0.5, [0, 0, 10, 10]]]
  33. assert score_from_detections(dets) == pytest.approx(0.8)
  34. def test_ignores_malformed(self):
  35. dets = [["failure", 0.4, []], ["bad"], ["failure", "nan", []]]
  36. assert score_from_detections(dets) == pytest.approx(0.4)
  37. class TestPrintState:
  38. def test_warmup_returns_zero(self):
  39. state = PrintState()
  40. for _ in range(WARMUP_FRAMES):
  41. assert state.update(0.9) == 0.0
  42. def test_after_warmup_returns_nonzero_for_hits(self):
  43. state = PrintState()
  44. for _ in range(WARMUP_FRAMES):
  45. state.update(0.9)
  46. score = state.update(0.9)
  47. assert score > 0.0
  48. def test_sustained_zero_stays_safe(self):
  49. state = PrintState()
  50. scores = [state.update(0.0) for _ in range(WARMUP_FRAMES + 50)]
  51. assert max(scores) == 0.0
  52. def test_sustained_hits_eventually_cross_high(self):
  53. """A stream of high-confidence frames must escalate to 'failure'."""
  54. state = PrintState()
  55. final = 0.0
  56. for _ in range(WARMUP_FRAMES + 200):
  57. final = state.update(1.0)
  58. _, high = thresholds("medium")
  59. assert final >= high
  60. def test_isolated_spike_does_not_trigger_failure(self):
  61. """A single noisy frame in a clean stream must not cross HIGH."""
  62. state = PrintState()
  63. for _ in range(WARMUP_FRAMES):
  64. state.update(0.0)
  65. score = state.update(1.0)
  66. _, high = thresholds("medium")
  67. assert score < high
  68. class TestClassify:
  69. def test_safe(self):
  70. assert classify(0.0, "medium") == "safe"
  71. assert classify(BASE_LOW - 0.01, "medium") == "safe"
  72. def test_warning(self):
  73. assert classify(BASE_LOW, "medium") == "warning"
  74. assert classify((BASE_LOW + BASE_HIGH) / 2, "medium") == "warning"
  75. def test_failure(self):
  76. assert classify(BASE_HIGH, "medium") == "failure"
  77. assert classify(1.0, "medium") == "failure"