# printer_manager.py
import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.app.models.printer import Printer
from backend.app.services.bambu_mqtt import BambuMQTTClient, MQTTLogEntry, PrinterState, get_stage_name
  7. class PrinterManager:
  8. """Manager for multiple printer connections."""
  9. def __init__(self):
  10. self._clients: dict[int, BambuMQTTClient] = {}
  11. self._on_print_start: Callable[[int, dict], None] | None = None
  12. self._on_print_complete: Callable[[int, dict], None] | None = None
  13. self._on_status_change: Callable[[int, PrinterState], None] | None = None
  14. self._on_ams_change: Callable[[int, list], None] | None = None
  15. self._loop: asyncio.AbstractEventLoop | None = None
  16. def set_event_loop(self, loop: asyncio.AbstractEventLoop):
  17. """Set the event loop for async callbacks."""
  18. self._loop = loop
  19. def set_print_start_callback(self, callback: Callable[[int, dict], None]):
  20. """Set callback for print start events."""
  21. self._on_print_start = callback
  22. def set_print_complete_callback(self, callback: Callable[[int, dict], None]):
  23. """Set callback for print completion events."""
  24. self._on_print_complete = callback
  25. def set_status_change_callback(self, callback: Callable[[int, PrinterState], None]):
  26. """Set callback for status change events."""
  27. self._on_status_change = callback
  28. def set_ams_change_callback(self, callback: Callable[[int, list], None]):
  29. """Set callback for AMS data change events."""
  30. self._on_ams_change = callback
  31. def _schedule_async(self, coro):
  32. """Schedule an async coroutine from a sync context.
  33. Captures exceptions from the coroutine and logs them to prevent
  34. silent failures in callbacks.
  35. """
  36. if self._loop and self._loop.is_running():
  37. future = asyncio.run_coroutine_threadsafe(coro, self._loop)
  38. def handle_exception(f):
  39. try:
  40. # This will re-raise any exception from the coroutine
  41. f.result()
  42. except Exception as e:
  43. import logging
  44. logging.getLogger(__name__).error(f"Exception in scheduled callback: {e}", exc_info=True)
  45. future.add_done_callback(handle_exception)
  46. async def connect_printer(self, printer: Printer) -> bool:
  47. """Connect to a printer."""
  48. if printer.id in self._clients:
  49. self.disconnect_printer(printer.id)
  50. printer_id = printer.id
  51. def on_state_change(state: PrinterState):
  52. if self._on_status_change:
  53. self._schedule_async(self._on_status_change(printer_id, state))
  54. def on_print_start(data: dict):
  55. if self._on_print_start:
  56. self._schedule_async(self._on_print_start(printer_id, data))
  57. def on_print_complete(data: dict):
  58. if self._on_print_complete:
  59. self._schedule_async(self._on_print_complete(printer_id, data))
  60. def on_ams_change(ams_data: list):
  61. if self._on_ams_change:
  62. self._schedule_async(self._on_ams_change(printer_id, ams_data))
  63. client = BambuMQTTClient(
  64. ip_address=printer.ip_address,
  65. serial_number=printer.serial_number,
  66. access_code=printer.access_code,
  67. on_state_change=on_state_change,
  68. on_print_start=on_print_start,
  69. on_print_complete=on_print_complete,
  70. on_ams_change=on_ams_change,
  71. )
  72. client.connect()
  73. self._clients[printer_id] = client
  74. # Wait a moment for connection
  75. await asyncio.sleep(1)
  76. return client.state.connected
  77. def disconnect_printer(self, printer_id: int):
  78. """Disconnect from a printer."""
  79. if printer_id in self._clients:
  80. self._clients[printer_id].disconnect()
  81. del self._clients[printer_id]
  82. def disconnect_all(self):
  83. """Disconnect from all printers."""
  84. for printer_id in list(self._clients.keys()):
  85. self.disconnect_printer(printer_id)
  86. def get_status(self, printer_id: int) -> PrinterState | None:
  87. """Get the current status of a printer (checks for stale connections)."""
  88. if printer_id in self._clients:
  89. client = self._clients[printer_id]
  90. # Check staleness and update connected state if needed
  91. client.check_staleness()
  92. return client.state
  93. return None
  94. def get_all_statuses(self) -> dict[int, PrinterState]:
  95. """Get status of all connected printers (checks for stale connections)."""
  96. result = {}
  97. for printer_id, client in self._clients.items():
  98. # Check staleness and update connected state if needed
  99. client.check_staleness()
  100. result[printer_id] = client.state
  101. return result
  102. def is_connected(self, printer_id: int) -> bool:
  103. """Check if a printer is connected (checks for stale connections)."""
  104. if printer_id in self._clients:
  105. client = self._clients[printer_id]
  106. # Check staleness and update connected state if needed
  107. return client.check_staleness()
  108. return False
  109. def get_client(self, printer_id: int) -> BambuMQTTClient | None:
  110. """Get the MQTT client for a printer."""
  111. return self._clients.get(printer_id)
  112. def mark_printer_offline(self, printer_id: int):
  113. """Mark a printer as offline and trigger status callback.
  114. This is used when we know the printer power was cut (e.g., smart plug turned off)
  115. to immediately update the UI without waiting for MQTT timeout.
  116. """
  117. import logging
  118. logger = logging.getLogger(__name__)
  119. if printer_id in self._clients:
  120. client = self._clients[printer_id]
  121. if client.state.connected:
  122. logger.info(f"Marking printer {printer_id} as offline (smart plug power off)")
  123. client.state.connected = False
  124. client.state.state = "unknown"
  125. # Trigger the status change callback to broadcast via WebSocket
  126. if self._on_status_change:
  127. self._schedule_async(self._on_status_change(printer_id, client.state))
  128. def start_print(self, printer_id: int, filename: str, plate_id: int = 1) -> bool:
  129. """Start a print on a connected printer."""
  130. if printer_id in self._clients:
  131. return self._clients[printer_id].start_print(filename, plate_id)
  132. return False
  133. def stop_print(self, printer_id: int) -> bool:
  134. """Stop the current print on a connected printer."""
  135. if printer_id in self._clients:
  136. return self._clients[printer_id].stop_print()
  137. return False
  138. async def wait_for_cooldown(
  139. self,
  140. printer_id: int,
  141. target_temp: float = 50.0,
  142. timeout: int = 600,
  143. check_interval: int = 10,
  144. ) -> bool:
  145. """Wait for the nozzle to cool down to a safe temperature.
  146. Args:
  147. printer_id: The printer to monitor
  148. target_temp: Target temperature to wait for (default 50°C)
  149. timeout: Maximum seconds to wait (default 600s = 10 min)
  150. check_interval: Seconds between temperature checks (default 10s)
  151. Returns:
  152. True if cooled down, False if timeout or not connected
  153. """
  154. import logging
  155. logger = logging.getLogger(__name__)
  156. elapsed = 0
  157. while elapsed < timeout:
  158. state = self.get_status(printer_id)
  159. if not state or not state.connected:
  160. logger.warning(f"Printer {printer_id} disconnected during cooldown wait")
  161. return False
  162. # Check nozzle temperature (and nozzle_2 for dual extruders)
  163. nozzle_temp = state.temperatures.get("nozzle", 0)
  164. nozzle_2_temp = state.temperatures.get("nozzle_2", 0)
  165. max_temp = max(nozzle_temp, nozzle_2_temp)
  166. if max_temp <= target_temp:
  167. logger.info(f"Printer {printer_id} cooled down to {max_temp}°C")
  168. return True
  169. logger.debug(f"Printer {printer_id} nozzle at {max_temp}°C, waiting for {target_temp}°C...")
  170. await asyncio.sleep(check_interval)
  171. elapsed += check_interval
  172. logger.warning(f"Printer {printer_id} cooldown timeout after {timeout}s")
  173. return False
  174. def enable_logging(self, printer_id: int, enabled: bool = True) -> bool:
  175. """Enable or disable MQTT logging for a printer."""
  176. if printer_id in self._clients:
  177. self._clients[printer_id].enable_logging(enabled)
  178. return True
  179. return False
  180. def get_logs(self, printer_id: int) -> list[MQTTLogEntry]:
  181. """Get MQTT logs for a printer."""
  182. if printer_id in self._clients:
  183. return self._clients[printer_id].get_logs()
  184. return []
  185. def clear_logs(self, printer_id: int) -> bool:
  186. """Clear MQTT logs for a printer."""
  187. if printer_id in self._clients:
  188. self._clients[printer_id].clear_logs()
  189. return True
  190. return False
  191. def is_logging_enabled(self, printer_id: int) -> bool:
  192. """Check if logging is enabled for a printer."""
  193. if printer_id in self._clients:
  194. return self._clients[printer_id].logging_enabled
  195. return False
  196. def request_status_update(self, printer_id: int) -> bool:
  197. """Request a full status update from the printer.
  198. This sends a 'pushall' command to get the latest data including nozzle info.
  199. """
  200. if printer_id in self._clients:
  201. return self._clients[printer_id].request_status_update()
  202. return False
  203. async def test_connection(
  204. self,
  205. ip_address: str,
  206. serial_number: str,
  207. access_code: str,
  208. ) -> dict:
  209. """Test connection to a printer without persisting."""
  210. client = BambuMQTTClient(
  211. ip_address=ip_address,
  212. serial_number=serial_number,
  213. access_code=access_code,
  214. )
  215. try:
  216. client.connect()
  217. await asyncio.sleep(2)
  218. result = {
  219. "success": client.state.connected,
  220. "state": client.state.state if client.state.connected else None,
  221. "model": client.state.raw_data.get("device_model"),
  222. }
  223. finally:
  224. client.disconnect()
  225. return result
  226. def get_derived_status_name(state: PrinterState) -> str | None:
  227. """
  228. Compute a human-readable status name based on printer state.
  229. Uses stg_cur when available, otherwise derives status from temperature data
  230. when the printer is heating before a print starts.
  231. """
  232. # If we have a valid calibration stage, use it
  233. # X1 models use -1 for idle, A1/P1 models use 255 for idle
  234. # Valid stage numbers are 0-254
  235. if 0 <= state.stg_cur < 255:
  236. return get_stage_name(state.stg_cur)
  237. # If not in RUNNING state, no derived status needed
  238. if state.state != "RUNNING":
  239. return None
  240. # Check if we're in an early phase where temperatures are heating
  241. temps = state.temperatures or {}
  242. progress = state.progress or 0
  243. # Only derive heating status when progress is very low (< 2%)
  244. # This indicates we're in the preparation phase, not actually printing
  245. if progress >= 2:
  246. return None
  247. # Check bed temperature - if target is set and current is significantly below
  248. bed_temp = temps.get("bed", 0)
  249. bed_target = temps.get("bed_target", 0)
  250. # Check nozzle temperature
  251. nozzle_temp = temps.get("nozzle", 0)
  252. nozzle_target = temps.get("nozzle_target", 0)
  253. # Temperature thresholds: consider "heating" if more than 10°C below target
  254. TEMP_THRESHOLD = 10
  255. # Determine what's heating (prioritize bed since it takes longer)
  256. if bed_target > 30 and (bed_target - bed_temp) > TEMP_THRESHOLD:
  257. return "Heating heatbed"
  258. elif nozzle_target > 30 and (nozzle_target - nozzle_temp) > TEMP_THRESHOLD:
  259. return "Heating nozzle"
  260. # If targets are set but we're close to them, we might be in final prep
  261. if bed_target > 30 or nozzle_target > 30:
  262. if progress == 0 and state.layer_num == 0:
  263. return "Preparing"
  264. return None
  265. def printer_state_to_dict(state: PrinterState, printer_id: int | None = None) -> dict:
  266. """Convert PrinterState to a JSON-serializable dict."""
  267. # Parse AMS data from raw_data
  268. ams_units = []
  269. vt_tray = None
  270. raw_data = state.raw_data or {}
  271. # Build K-profile lookup map: cali_idx -> k_value
  272. kprofile_map: dict[int, float] = {}
  273. for kp in state.kprofiles or []:
  274. if kp.slot_id is not None and kp.k_value:
  275. try:
  276. kprofile_map[kp.slot_id] = float(kp.k_value)
  277. except (ValueError, TypeError):
  278. pass
  279. if "ams" in raw_data and isinstance(raw_data["ams"], list):
  280. for ams_data in raw_data["ams"]:
  281. trays = []
  282. for tray in ams_data.get("tray", []):
  283. tag_uid = tray.get("tag_uid")
  284. if tag_uid in ("", "0000000000000000"):
  285. tag_uid = None
  286. tray_uuid = tray.get("tray_uuid")
  287. if tray_uuid in ("", "00000000000000000000000000000000"):
  288. tray_uuid = None
  289. # Get K value: first try tray's k field, then lookup from K-profiles
  290. k_value = tray.get("k")
  291. cali_idx = tray.get("cali_idx")
  292. if k_value is None and cali_idx is not None and cali_idx in kprofile_map:
  293. k_value = kprofile_map[cali_idx]
  294. trays.append(
  295. {
  296. "id": tray.get("id", 0),
  297. "tray_color": tray.get("tray_color"),
  298. "tray_type": tray.get("tray_type"),
  299. "tray_sub_brands": tray.get("tray_sub_brands"),
  300. "tray_id_name": tray.get("tray_id_name"),
  301. "tray_info_idx": tray.get("tray_info_idx"),
  302. "remain": tray.get("remain", 0),
  303. "k": k_value,
  304. "cali_idx": cali_idx,
  305. "tag_uid": tag_uid,
  306. "tray_uuid": tray_uuid,
  307. "nozzle_temp_min": tray.get("nozzle_temp_min"),
  308. "nozzle_temp_max": tray.get("nozzle_temp_max"),
  309. }
  310. )
  311. # Prefer humidity_raw (actual percentage) over humidity (index 1-5)
  312. humidity_raw = ams_data.get("humidity_raw")
  313. humidity_idx = ams_data.get("humidity")
  314. humidity_value = None
  315. if humidity_raw is not None:
  316. try:
  317. humidity_value = int(humidity_raw)
  318. except (ValueError, TypeError):
  319. pass
  320. # Fall back to index if no raw value (index is 1-5, not percentage)
  321. if humidity_value is None and humidity_idx is not None:
  322. try:
  323. humidity_value = int(humidity_idx)
  324. except (ValueError, TypeError):
  325. pass
  326. # AMS-HT has 1 tray, regular AMS has 4 trays
  327. is_ams_ht = len(trays) == 1
  328. ams_units.append(
  329. {
  330. "id": ams_data.get("id", 0),
  331. "humidity": humidity_value,
  332. "temp": ams_data.get("temp"),
  333. "is_ams_ht": is_ams_ht,
  334. "tray": trays,
  335. }
  336. )
  337. # Parse virtual tray (external spool)
  338. if "vt_tray" in raw_data:
  339. vt_data = raw_data["vt_tray"]
  340. vt_tag_uid = vt_data.get("tag_uid")
  341. if vt_tag_uid in ("", "0000000000000000"):
  342. vt_tag_uid = None
  343. vt_tray_uuid = vt_data.get("tray_uuid")
  344. if vt_tray_uuid in ("", "00000000000000000000000000000000"):
  345. vt_tray_uuid = None
  346. # Get K value for vt_tray
  347. vt_k_value = vt_data.get("k")
  348. vt_cali_idx = vt_data.get("cali_idx")
  349. if vt_k_value is None and vt_cali_idx is not None and vt_cali_idx in kprofile_map:
  350. vt_k_value = kprofile_map[vt_cali_idx]
  351. vt_tray = {
  352. "id": 254,
  353. "tray_color": vt_data.get("tray_color"),
  354. "tray_type": vt_data.get("tray_type"),
  355. "tray_sub_brands": vt_data.get("tray_sub_brands"),
  356. "tray_id_name": vt_data.get("tray_id_name"),
  357. "tray_info_idx": vt_data.get("tray_info_idx"),
  358. "remain": vt_data.get("remain", 0),
  359. "k": vt_k_value,
  360. "cali_idx": vt_cali_idx,
  361. "tag_uid": vt_tag_uid,
  362. "tray_uuid": vt_tray_uuid,
  363. "nozzle_temp_min": vt_data.get("nozzle_temp_min"),
  364. "nozzle_temp_max": vt_data.get("nozzle_temp_max"),
  365. }
  366. # Get ams_extruder_map from raw_data (populated by MQTT handler from AMS info field)
  367. ams_extruder_map = raw_data.get("ams_extruder_map", {})
  368. result = {
  369. "connected": state.connected,
  370. "state": state.state,
  371. "current_print": state.current_print,
  372. "subtask_name": state.subtask_name,
  373. "gcode_file": state.gcode_file,
  374. "progress": state.progress,
  375. "remaining_time": state.remaining_time,
  376. "layer_num": state.layer_num,
  377. "total_layers": state.total_layers,
  378. "temperatures": state.temperatures,
  379. "hms_errors": [
  380. {"code": e.code, "attr": e.attr, "module": e.module, "severity": e.severity}
  381. for e in (state.hms_errors or [])
  382. ],
  383. # AMS data for filament colors
  384. "ams": ams_units if ams_units else None,
  385. "vt_tray": vt_tray,
  386. # AMS status for filament change tracking
  387. "ams_status_main": state.ams_status_main,
  388. "ams_status_sub": state.ams_status_sub,
  389. "tray_now": state.tray_now,
  390. # Per-AMS extruder map: {ams_id: extruder_id} where 0=right, 1=left
  391. "ams_extruder_map": ams_extruder_map,
  392. # WiFi signal strength
  393. "wifi_signal": state.wifi_signal,
  394. # Calibration stage tracking
  395. "stg_cur": state.stg_cur,
  396. "stg_cur_name": get_derived_status_name(state),
  397. # Printable objects count for skip objects feature
  398. "printable_objects_count": len(state.printable_objects),
  399. }
  400. # Add cover URL if there's an active print and printer_id is provided
  401. # Include PAUSE/PAUSED states so skip objects modal can show cover
  402. if printer_id and state.state in ("RUNNING", "PAUSE", "PAUSED") and state.gcode_file:
  403. result["cover_url"] = f"/api/v1/printers/{printer_id}/cover"
  404. else:
  405. result["cover_url"] = None
  406. return result
  407. # Global printer manager instance
  408. printer_manager = PrinterManager()
  409. async def init_printer_connections(db: AsyncSession):
  410. """Initialize connections to all active printers."""
  411. result = await db.execute(select(Printer).where(Printer.is_active.is_(True)))
  412. printers = result.scalars().all()
  413. for printer in printers:
  414. await printer_manager.connect_printer(printer)