websocket.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import asyncio
  2. import json
  3. from typing import Any
  4. from fastapi import WebSocket
  5. class ConnectionManager:
  6. """Manages WebSocket connections and broadcasts."""
  7. def __init__(self):
  8. self.active_connections: list[WebSocket] = []
  9. self._lock = asyncio.Lock()
  10. async def connect(self, websocket: WebSocket):
  11. """Accept a new WebSocket connection."""
  12. await websocket.accept()
  13. async with self._lock:
  14. self.active_connections.append(websocket)
  15. async def disconnect(self, websocket: WebSocket):
  16. """Remove a WebSocket connection."""
  17. async with self._lock:
  18. if websocket in self.active_connections:
  19. self.active_connections.remove(websocket)
  20. async def broadcast(self, message: dict[str, Any]):
  21. """Broadcast a message to all connected clients."""
  22. if not self.active_connections:
  23. return
  24. data = json.dumps(message)
  25. async with self._lock:
  26. disconnected = []
  27. for connection in self.active_connections:
  28. try:
  29. await connection.send_text(data)
  30. except Exception:
  31. disconnected.append(connection)
  32. # Clean up disconnected clients
  33. for conn in disconnected:
  34. if conn in self.active_connections:
  35. self.active_connections.remove(conn)
  36. async def send_printer_status(self, printer_id: int, status: dict):
  37. """Send printer status update to all clients."""
  38. await self.broadcast(
  39. {
  40. "type": "printer_status",
  41. "printer_id": printer_id,
  42. "data": status,
  43. }
  44. )
  45. async def send_print_start(self, printer_id: int, data: dict):
  46. """Notify clients that a print has started."""
  47. await self.broadcast(
  48. {
  49. "type": "print_start",
  50. "printer_id": printer_id,
  51. "data": data,
  52. }
  53. )
  54. async def send_print_complete(self, printer_id: int, data: dict):
  55. """Notify clients that a print has completed."""
  56. await self.broadcast(
  57. {
  58. "type": "print_complete",
  59. "printer_id": printer_id,
  60. "data": data,
  61. }
  62. )
  63. async def send_archive_created(self, archive: dict):
  64. """Notify clients that a new archive was created."""
  65. await self.broadcast(
  66. {
  67. "type": "archive_created",
  68. "data": archive,
  69. }
  70. )
  71. async def send_archive_updated(self, archive: dict):
  72. """Notify clients that an archive was updated."""
  73. await self.broadcast(
  74. {
  75. "type": "archive_updated",
  76. "data": archive,
  77. }
  78. )
  79. # Global connection manager
  80. ws_manager = ConnectionManager()