"""Manage MQTT connections to printers and convert their state to JSON-ready dicts."""
import asyncio
import logging
from dataclasses import asdict
from typing import Callable

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.app.models.printer import Printer
from backend.app.services.bambu_ftp import BambuFTPClient
from backend.app.services.bambu_mqtt import BambuMQTTClient, PrinterState, MQTTLogEntry
class PrinterManager:
    """Manager for multiple printer connections.

    Holds one ``BambuMQTTClient`` per printer id and forwards the client's
    events (status change, print start, print complete, AMS change) to the
    callbacks registered through the ``set_*_callback`` methods.

    Each registered callback is invoked as ``callback(printer_id, payload)``
    and whatever it returns is handed to ``_schedule_async``; callbacks are
    therefore expected to be coroutine functions. The coroutine is run on the
    loop registered with ``set_event_loop``.
    """

    _logger = logging.getLogger(__name__)

    # Pause after ``client.connect()`` before the connection flag is read.
    _CONNECT_WAIT_SECONDS = 1
    # Longer pause used by ``test_connection`` before it reads the state.
    _TEST_CONNECT_WAIT_SECONDS = 2

    def __init__(self):
        self._clients: dict[int, BambuMQTTClient] = {}
        self._on_print_start: Callable[[int, dict], None] | None = None
        self._on_print_complete: Callable[[int, dict], None] | None = None
        self._on_status_change: Callable[[int, PrinterState], None] | None = None
        self._on_ams_change: Callable[[int, list], None] | None = None
        self._loop: asyncio.AbstractEventLoop | None = None

    def set_event_loop(self, loop: asyncio.AbstractEventLoop):
        """Set the event loop for async callbacks."""
        self._loop = loop

    def set_print_start_callback(self, callback: Callable[[int, dict], None]):
        """Set callback for print start events."""
        self._on_print_start = callback

    def set_print_complete_callback(self, callback: Callable[[int, dict], None]):
        """Set callback for print completion events."""
        self._on_print_complete = callback

    def set_status_change_callback(self, callback: Callable[[int, PrinterState], None]):
        """Set callback for status change events."""
        self._on_status_change = callback

    def set_ams_change_callback(self, callback: Callable[[int, list], None]):
        """Set callback for AMS data change events."""
        self._on_ams_change = callback

    def _log_callback_failure(self, future):
        """Log an exception raised by a scheduled callback coroutine."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self._logger.error("Printer event callback failed", exc_info=exc)

    def _schedule_async(self, coro):
        """Schedule an async coroutine from a sync context.

        Args:
            coro: The object returned by a registered callback. Anything that
                is not a coroutine object (e.g. ``None`` from a plain
                function) is ignored, since there is nothing left to run.

        The coroutine is submitted to the loop given to ``set_event_loop``.
        If no loop is set, or it is not running, the coroutine is closed so
        it does not trigger a "coroutine was never awaited" warning.
        Exceptions raised by the coroutine are logged instead of being lost
        with the discarded future.
        """
        if not asyncio.iscoroutine(coro):
            return

        loop = self._loop
        if not loop or not loop.is_running():
            coro.close()
            return

        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # The loop was closed between the check above and the submit.
            coro.close()
            return
        future.add_done_callback(self._log_callback_failure)

    def _dispatch(self, callback, printer_id: int, payload):
        """Call *callback* (if one is registered) and schedule its coroutine."""
        if callback:
            self._schedule_async(callback(printer_id, payload))

    async def connect_printer(self, printer: Printer) -> bool:
        """Connect to a printer.

        Any existing client for the same printer id is disconnected first.

        Args:
            printer: Printer record supplying ``id``, ``ip_address``,
                ``serial_number`` and ``access_code``.

        Returns:
            The client's ``state.connected`` flag after a short fixed wait.
        """
        printer_id = printer.id
        self.disconnect_printer(printer_id)

        # The handlers read the registered callback at event time, so a
        # callback set after connecting is still picked up.
        def on_state_change(state: PrinterState):
            self._dispatch(self._on_status_change, printer_id, state)

        def on_print_start(data: dict):
            self._dispatch(self._on_print_start, printer_id, data)

        def on_print_complete(data: dict):
            self._dispatch(self._on_print_complete, printer_id, data)

        def on_ams_change(ams_data: list):
            self._dispatch(self._on_ams_change, printer_id, ams_data)

        client = BambuMQTTClient(
            ip_address=printer.ip_address,
            serial_number=printer.serial_number,
            access_code=printer.access_code,
            on_state_change=on_state_change,
            on_print_start=on_print_start,
            on_print_complete=on_print_complete,
            on_ams_change=on_ams_change,
        )
        client.connect()
        self._clients[printer_id] = client

        # Wait a moment for connection
        await asyncio.sleep(self._CONNECT_WAIT_SECONDS)
        return client.state.connected

    def disconnect_printer(self, printer_id: int):
        """Disconnect from a printer and forget its client (no-op if unknown)."""
        client = self._clients.get(printer_id)
        if client is not None:
            client.disconnect()
            self._clients.pop(printer_id, None)

    def disconnect_all(self):
        """Disconnect from all printers."""
        # Iterate over a snapshot: disconnect_printer mutates the dict.
        for printer_id in list(self._clients):
            self.disconnect_printer(printer_id)

    def get_status(self, printer_id: int) -> PrinterState | None:
        """Get the current status of a printer, or None if it has no client."""
        client = self._clients.get(printer_id)
        return client.state if client is not None else None

    def get_all_statuses(self) -> dict[int, PrinterState]:
        """Get status of all connected printers, keyed by printer id."""
        return {
            printer_id: client.state
            for printer_id, client in self._clients.items()
        }

    def is_connected(self, printer_id: int) -> bool:
        """Check if a printer is connected (False if it has no client)."""
        client = self._clients.get(printer_id)
        return client.state.connected if client is not None else False

    def get_client(self, printer_id: int) -> BambuMQTTClient | None:
        """Get the MQTT client for a printer."""
        return self._clients.get(printer_id)

    def mark_printer_offline(self, printer_id: int):
        """Mark a printer as offline and trigger status callback.

        This is used when we know the printer power was cut (e.g., smart plug
        turned off) to immediately update the UI without waiting for MQTT
        timeout. Does nothing if the printer has no client or is already
        marked as disconnected.
        """
        client = self._clients.get(printer_id)
        if client is None or not client.state.connected:
            return

        self._logger.info(
            "Marking printer %s as offline (smart plug power off)", printer_id
        )
        client.state.connected = False
        client.state.state = "unknown"
        # Trigger the status change callback to broadcast via WebSocket
        self._dispatch(self._on_status_change, printer_id, client.state)

    def start_print(self, printer_id: int, filename: str) -> bool:
        """Start a print on a connected printer (False if it has no client)."""
        client = self._clients.get(printer_id)
        return client.start_print(filename) if client is not None else False

    def stop_print(self, printer_id: int) -> bool:
        """Stop the current print on a connected printer (False if no client)."""
        client = self._clients.get(printer_id)
        return client.stop_print() if client is not None else False

    async def wait_for_cooldown(
        self,
        printer_id: int,
        target_temp: float = 50.0,
        timeout: int = 600,
        check_interval: int = 10,
    ) -> bool:
        """Wait for the nozzle to cool down to a safe temperature.

        Args:
            printer_id: The printer to monitor
            target_temp: Target temperature to wait for (default 50°C)
            timeout: Maximum seconds to wait (default 600s = 10 min)
            check_interval: Seconds between temperature checks (default 10s)

        Returns:
            True if cooled down, False if timeout or not connected
        """
        # Measure against the loop clock rather than summing check_interval,
        # so a zero or negative interval cannot loop forever.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while loop.time() < deadline:
            state = self.get_status(printer_id)
            if not state or not state.connected:
                self._logger.warning(
                    "Printer %s disconnected during cooldown wait", printer_id
                )
                return False

            # Check nozzle temperature (and nozzle_2 for dual extruders)
            nozzle_temp = state.temperatures.get("nozzle", 0)
            nozzle_2_temp = state.temperatures.get("nozzle_2", 0)
            max_temp = max(nozzle_temp, nozzle_2_temp)

            if max_temp <= target_temp:
                self._logger.info(
                    "Printer %s cooled down to %s°C", printer_id, max_temp
                )
                return True

            self._logger.debug(
                "Printer %s nozzle at %s°C, waiting for %s°C...",
                printer_id,
                max_temp,
                target_temp,
            )
            await asyncio.sleep(max(check_interval, 0))

        self._logger.warning(
            "Printer %s cooldown timeout after %ss", printer_id, timeout
        )
        return False

    def enable_logging(self, printer_id: int, enabled: bool = True) -> bool:
        """Enable or disable MQTT logging for a printer.

        Returns:
            True if the printer has a client, False otherwise.
        """
        client = self._clients.get(printer_id)
        if client is None:
            return False
        client.enable_logging(enabled)
        return True

    def get_logs(self, printer_id: int) -> list[MQTTLogEntry]:
        """Get MQTT logs for a printer (empty list if it has no client)."""
        client = self._clients.get(printer_id)
        return client.get_logs() if client is not None else []

    def clear_logs(self, printer_id: int) -> bool:
        """Clear MQTT logs for a printer.

        Returns:
            True if the printer has a client, False otherwise.
        """
        client = self._clients.get(printer_id)
        if client is None:
            return False
        client.clear_logs()
        return True

    def is_logging_enabled(self, printer_id: int) -> bool:
        """Check if logging is enabled for a printer (False if no client)."""
        client = self._clients.get(printer_id)
        return client.logging_enabled if client is not None else False

    def request_status_update(self, printer_id: int) -> bool:
        """Request a full status update from the printer.

        This sends a 'pushall' command to get the latest data including nozzle info.
        Returns False if the printer has no client.
        """
        client = self._clients.get(printer_id)
        return client.request_status_update() if client is not None else False

    async def test_connection(
        self,
        ip_address: str,
        serial_number: str,
        access_code: str,
    ) -> dict:
        """Test connection to a printer without persisting.

        A throwaway client is created, given a fixed wait to connect, and
        always disconnected afterwards; it is never stored in the manager.

        Returns:
            Dict with ``success`` (connection flag), ``state`` (the reported
            state, or None when not connected) and ``model`` (the
            ``device_model`` entry of the client's raw data, if any).
        """
        client = BambuMQTTClient(
            ip_address=ip_address,
            serial_number=serial_number,
            access_code=access_code,
        )

        try:
            client.connect()
            await asyncio.sleep(self._TEST_CONNECT_WAIT_SECONDS)
            connected = client.state.connected
            result = {
                "success": connected,
                "state": client.state.state if connected else None,
                "model": client.state.raw_data.get("device_model"),
            }
        finally:
            client.disconnect()

        return result
def printer_state_to_dict(state: PrinterState, printer_id: int | None = None) -> dict:
    """Convert PrinterState to a JSON-serializable dict.

    Args:
        state: The printer state to convert.
        printer_id: Id used to build the cover image URL. When omitted
            (``None``) no cover URL is produced.

    Returns:
        A dict with the connection flag, print progress fields, temperatures,
        the HMS errors reduced to ``code``/``module``/``severity``, and
        ``cover_url`` (``None`` unless a print is running, a gcode file is
        set and ``printer_id`` was given).
    """
    # Only an omitted id (None) suppresses the URL; an id of 0 is a real id.
    has_cover = (
        printer_id is not None
        and state.state == "RUNNING"
        and bool(state.gcode_file)
    )

    return {
        "connected": state.connected,
        "state": state.state,
        "current_print": state.current_print,
        "subtask_name": state.subtask_name,
        "gcode_file": state.gcode_file,
        "progress": state.progress,
        "remaining_time": state.remaining_time,
        "layer_num": state.layer_num,
        "total_layers": state.total_layers,
        "temperatures": state.temperatures,
        "hms_errors": [
            {"code": e.code, "module": e.module, "severity": e.severity}
            for e in (state.hms_errors or [])
        ],
        "cover_url": f"/api/v1/printers/{printer_id}/cover" if has_cover else None,
    }
# Shared module-level PrinterManager instance; import this rather than
# constructing a new manager.
printer_manager = PrinterManager()
async def init_printer_connections(db: AsyncSession):
    """Initialize connections to all active printers.

    All active printers are connected concurrently, so the fixed wait inside
    ``PrinterManager.connect_printer`` is paid once rather than once per
    printer. A printer whose connection attempt raises is logged and does
    not prevent the remaining printers from being connected.

    Args:
        db: Session used to load the printers flagged ``is_active``.
    """
    result = await db.execute(
        # ``==`` is required here: SQLAlchemy overloads it to build the SQL
        # comparison, which ``is True`` would not do.
        select(Printer).where(Printer.is_active == True)  # noqa: E712
    )
    printers = result.scalars().all()

    outcomes = await asyncio.gather(
        *(printer_manager.connect_printer(printer) for printer in printers),
        return_exceptions=True,
    )

    for printer, outcome in zip(printers, outcomes):
        if isinstance(outcome, BaseException):
            logging.getLogger(__name__).error(
                "Failed to connect printer %s", printer.id, exc_info=outcome
            )
|