obico_smoothing.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
"""Temporal smoothing for Obico ML detection scores.

Ports Obico's failure-detection math:
- per-frame `current_p` = sum of detection confidences
- `ewm_mean` = exponentially weighted mean (alpha = 2 / (span + 1), span = 12)
- `rolling_mean_short` = mean of the last ~310 frames of recent activity
  (≈52 min at 10 s/frame)
- `rolling_mean_long` = mean of the last ~7200 frames, the long-term baseline noise
- The first `WARMUP_FRAMES` frames always report "safe" while the state settles
- Final score = max(ewm_mean, rolling_mean_short - rolling_mean_long)
- Thresholds (both scaled by the sensitivity multiplier): LOW <= score < HIGH
  is "warning", score >= HIGH is "failure"
"""
  11. import math
  12. from collections import deque
  13. from dataclasses import dataclass, field
  14. EWM_SPAN = 12
  15. EWM_ALPHA = 2.0 / (EWM_SPAN + 1)
  16. ROLLING_SHORT = 310
  17. ROLLING_LONG = 7200
  18. WARMUP_FRAMES = 30
  19. # Base thresholds; sensitivity multipliers adjust them
  20. BASE_LOW = 0.38
  21. BASE_HIGH = 0.78
  22. SENSITIVITY_MULT = {
  23. "low": 1.25, # harder to trigger — higher thresholds
  24. "medium": 1.0,
  25. "high": 0.75, # easier to trigger — lower thresholds
  26. }
  27. def thresholds(sensitivity: str) -> tuple[float, float]:
  28. mult = SENSITIVITY_MULT.get(sensitivity, 1.0)
  29. return BASE_LOW * mult, BASE_HIGH * mult
  30. @dataclass
  31. class PrintState:
  32. """Per-print smoothing state. Reset when a new print starts."""
  33. frame_count: int = 0
  34. ewm_mean: float = 0.0
  35. short_sum: float = 0.0
  36. long_sum: float = 0.0
  37. short_buf: deque = field(default_factory=lambda: deque(maxlen=ROLLING_SHORT))
  38. long_buf: deque = field(default_factory=lambda: deque(maxlen=ROLLING_LONG))
  39. def update(self, current_p: float) -> float:
  40. """Feed a new per-frame score and return the smoothed score.
  41. Returns 0.0 during warmup so early noise doesn't trigger actions.
  42. """
  43. self.frame_count += 1
  44. if self.frame_count == 1:
  45. self.ewm_mean = current_p
  46. else:
  47. self.ewm_mean = EWM_ALPHA * current_p + (1 - EWM_ALPHA) * self.ewm_mean
  48. if len(self.short_buf) == self.short_buf.maxlen:
  49. self.short_sum -= self.short_buf[0]
  50. self.short_buf.append(current_p)
  51. self.short_sum += current_p
  52. if len(self.long_buf) == self.long_buf.maxlen:
  53. self.long_sum -= self.long_buf[0]
  54. self.long_buf.append(current_p)
  55. self.long_sum += current_p
  56. if self.frame_count <= WARMUP_FRAMES:
  57. return 0.0
  58. short_mean = self.short_sum / len(self.short_buf)
  59. long_mean = self.long_sum / len(self.long_buf)
  60. return max(self.ewm_mean, short_mean - long_mean)
  61. def classify(score: float, sensitivity: str) -> str:
  62. """Return 'safe', 'warning', or 'failure' for a smoothed score."""
  63. low, high = thresholds(sensitivity)
  64. if score >= high:
  65. return "failure"
  66. if score >= low:
  67. return "warning"
  68. return "safe"
  69. def score_from_detections(detections: list) -> float:
  70. """Sum confidences from the ML API `detections` array.
  71. Each detection is `[label, confidence, [x, y, w, h]]`. We only care about
  72. the confidence column — label is always "failure" for the single-class model.
  73. """
  74. total = 0.0
  75. for det in detections or []:
  76. try:
  77. value = float(det[1])
  78. except (IndexError, TypeError, ValueError):
  79. continue
  80. if math.isnan(value) or math.isinf(value):
  81. continue
  82. total += value
  83. return total