websocket.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. """GHSA-r2qv follow-up — WebSocket auth gate.
  2. Previously ``/api/v1/ws`` accepted *any* network client and immediately
  3. streamed every ``printer_status`` / ``print_start`` / ``print_complete``
  4. / ``archive_*`` / ``inventory_changed`` broadcast back to it. That is
  5. the GHSA-gc24 shape on a different protocol — anyone who could reach
  6. the HTTP port could subscribe to every printer event in the system.
  7. This endpoint now validates a short-lived token (minted by
  8. ``POST /api/v1/auth/ws-token`` behind ``Permission.WEBSOCKET_CONNECT``)
  9. *before* ``websocket.accept()``. When auth is disabled, no token is
  10. required (the legacy SPA-friendly path). The token is reused across
  11. reconnects within its 60-minute window so a brief network blip does
  12. not require a round-trip to the auth router.
  13. """
  14. from __future__ import annotations
  15. import logging
  16. from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
  17. from backend.app.core.auth import is_auth_enabled, verify_websocket_token
  18. from backend.app.core.database import async_session
  19. from backend.app.core.websocket import ws_manager
  20. from backend.app.services.background_dispatch import background_dispatch
  21. from backend.app.services.printer_manager import printer_manager, printer_state_to_dict
  22. logger = logging.getLogger(__name__)
  23. router = APIRouter()
  24. # 4401 mirrors the WebSocket "unauthorised" application close code
  25. # convention used by Sec-WebSocket-Protocol authors (private-use range
  26. # is 4000-4999 per RFC 6455). The SPA distinguishes 4401 from network
  27. # drops and refetches a token instead of retrying with the old one.
  28. _WS_CLOSE_UNAUTHORIZED = 4401
  29. @router.websocket("/ws")
  30. async def websocket_endpoint(websocket: WebSocket, token: str | None = Query(default=None)) -> None:
  31. """WebSocket endpoint for real-time updates.
  32. Connection auth (GHSA-r2qv follow-up):
  33. - Auth disabled → connect without a token, identical to the prior
  34. behaviour (single-user / local-network deployments).
  35. - Auth enabled → ``?token=<value>`` query param must hold an
  36. unexpired token minted via ``POST /api/v1/auth/ws-token``.
  37. Missing / invalid / expired token → ``close(code=4401)`` *before*
  38. ``accept()`` so no ``ws_manager.broadcast`` ever reaches the
  39. caller (broadcasts walk ``active_connections`` blindly — letting
  40. an unauthenticated socket into that list is a fan-out leak).
  41. The auth check is fail-closed at every error path: a DB exception
  42. while reading the ``auth_enabled`` setting closes the connection
  43. rather than admitting the caller.
  44. """
  45. # Authenticate before accept() so an unauth caller never lands in
  46. # ws_manager.active_connections (where broadcasts blindly fan out).
  47. try:
  48. async with async_session() as db:
  49. auth_required = await is_auth_enabled(db)
  50. except Exception: # SEC-AUTH-EXC: DB failure on auth probe → fail-closed (refuse connect), matches is_auth_enabled itself which returns True on error
  51. logger.error("WebSocket auth probe failed; refusing connection", exc_info=True)
  52. await websocket.close(code=_WS_CLOSE_UNAUTHORIZED)
  53. return
  54. principal: str | None = None
  55. if auth_required:
  56. if not token:
  57. logger.info("WebSocket connect refused: no token (auth enabled)")
  58. await websocket.close(code=_WS_CLOSE_UNAUTHORIZED)
  59. return
  60. principal = await verify_websocket_token(token)
  61. if principal is None:
  62. logger.info("WebSocket connect refused: invalid or expired token")
  63. await websocket.close(code=_WS_CLOSE_UNAUTHORIZED)
  64. return
  65. # Token verified (or auth disabled); now safe to admit the connection.
  66. logger.info("WebSocket client connecting (principal=%s)", principal if principal else "<anonymous>")
  67. await ws_manager.connect(websocket)
  68. # Stash on connection state for any future per-message permission
  69. # logic; today the message handlers are read-only and only respond
  70. # to the requesting socket, so the stash is informational. The
  71. # explicit attribute (rather than a side dict) means a future
  72. # ``broadcast_to_principal()`` helper can filter on it without
  73. # touching every call site.
  74. websocket.state.bambuddy_principal = principal
  75. logger.info("WebSocket client connected")
  76. try:
  77. # Send initial status of all printers.
  78. statuses = printer_manager.get_all_statuses()
  79. for printer_id, state in statuses.items():
  80. await websocket.send_json(
  81. {
  82. "type": "printer_status",
  83. "printer_id": printer_id,
  84. "data": printer_state_to_dict(state, printer_id, printer_manager.get_model(printer_id)),
  85. }
  86. )
  87. dispatch_state = await background_dispatch.get_state()
  88. if (dispatch_state.get("dispatched", 0) + dispatch_state.get("processing", 0)) > 0:
  89. await websocket.send_json(
  90. {
  91. "type": "background_dispatch",
  92. "data": dispatch_state,
  93. }
  94. )
  95. logger.info("Sent initial status for %s printers", len(statuses))
  96. # Keep connection alive and handle incoming messages.
  97. while True:
  98. data = await websocket.receive_json()
  99. # Handle ping/pong for keepalive
  100. if data.get("type") == "ping":
  101. await websocket.send_json({"type": "pong"})
  102. # Handle status request
  103. elif data.get("type") == "get_status":
  104. printer_id = data.get("printer_id")
  105. if printer_id:
  106. state = printer_manager.get_status(printer_id)
  107. if state:
  108. await websocket.send_json(
  109. {
  110. "type": "printer_status",
  111. "printer_id": printer_id,
  112. "data": printer_state_to_dict(state, printer_id, printer_manager.get_model(printer_id)),
  113. }
  114. )
  115. except WebSocketDisconnect:
  116. logger.info("WebSocket client disconnected normally")
  117. await ws_manager.disconnect(websocket)
  118. except Exception as e:
  119. logger.error("WebSocket error: %s", e, exc_info=True)
  120. await ws_manager.disconnect(websocket)