Browse Source

Merge pull request #328 from jeffsf/fixes/327

Shut down MQTT relay and smart plug services explicitly
MartinNYHC 3 months ago
parent
commit
179a9de521

+ 5 - 0
backend/app/main.py

@@ -217,6 +217,7 @@ from backend.app.services.bambu_mqtt import PrinterState
 from backend.app.services.github_backup import github_backup_service
 from backend.app.services.github_backup import github_backup_service
 from backend.app.services.homeassistant import homeassistant_service
 from backend.app.services.homeassistant import homeassistant_service
 from backend.app.services.mqtt_relay import mqtt_relay
 from backend.app.services.mqtt_relay import mqtt_relay
+from backend.app.services.mqtt_smart_plug import mqtt_smart_plug_service
 from backend.app.services.notification_service import notification_service
 from backend.app.services.notification_service import notification_service
 from backend.app.services.print_scheduler import scheduler as print_scheduler
 from backend.app.services.print_scheduler import scheduler as print_scheduler
 from backend.app.services.printer_manager import (
 from backend.app.services.printer_manager import (
@@ -3033,6 +3034,10 @@ async def lifespan(app: FastAPI):
     if virtual_printer_manager.is_enabled:
     if virtual_printer_manager.is_enabled:
         await virtual_printer_manager.configure(enabled=False)
         await virtual_printer_manager.configure(enabled=False)
 
 
+    await mqtt_smart_plug_service.disconnect(timeout=2)
+
+    await mqtt_relay.disconnect(timeout=2)
+
 
 
 app = FastAPI(
 app = FastAPI(
     title=app_settings.app_name,
     title=app_settings.app_name,

+ 8 - 2
backend/app/services/bambu_mqtt.py

@@ -11,6 +11,7 @@ import asyncio
 import json
 import json
 import logging
 import logging
 import ssl
 import ssl
+import threading
 import time
 import time
 from collections import deque
 from collections import deque
 from collections.abc import Callable
 from collections.abc import Callable
@@ -279,6 +280,7 @@ class BambuMQTTClient:
         self._message_log: deque[MQTTLogEntry] = deque(maxlen=100)
         self._message_log: deque[MQTTLogEntry] = deque(maxlen=100)
         self._logging_enabled: bool = False
         self._logging_enabled: bool = False
         self._last_message_time: float = 0.0  # Track when we last received a message
         self._last_message_time: float = 0.0  # Track when we last received a message
+        self._disconnection_event: threading.Event | None = None
         self._previous_ams_hash: str | None = None  # Track AMS changes
         self._previous_ams_hash: str | None = None  # Track AMS changes
 
 
         # K-profile command tracking
         # K-profile command tracking
@@ -357,6 +359,8 @@ class BambuMQTTClient:
         self.state.connected = False
         self.state.connected = False
         if self.on_state_change:
         if self.on_state_change:
             self.on_state_change(self.state)
             self.on_state_change(self.state)
+        if self._disconnection_event:
+            self._disconnection_event.set()
 
 
     def _on_message(self, client, userdata, msg):
     def _on_message(self, client, userdata, msg):
         try:
         try:
@@ -2374,11 +2378,13 @@ class BambuMQTTClient:
 
 
         return True
         return True
 
 
-    def disconnect(self):
+    def disconnect(self, timeout: float = 0):
         """Disconnect from the printer."""
         """Disconnect from the printer."""
         if self._client:
         if self._client:
-            self._client.loop_stop()
+            self._disconnection_event = threading.Event()
             self._client.disconnect()
             self._client.disconnect()
+            self._disconnection_event.wait(timeout=timeout)
+            self._client.loop_stop()
             self._client = None
             self._client = None
             self.state.connected = False
             self.state.connected = False
 
 

+ 7 - 2
backend/app/services/mqtt_relay.py

@@ -36,6 +36,7 @@ class MQTTRelayService:
         self._last_printer_status: dict[int, float] = {}  # printer_id -> last publish timestamp
         self._last_printer_status: dict[int, float] = {}  # printer_id -> last publish timestamp
         self._smart_plug_service = None  # Lazy import to avoid circular dependency
         self._smart_plug_service = None  # Lazy import to avoid circular dependency
         self._settings: dict = {}  # Store settings for smart plug service
         self._settings: dict = {}  # Store settings for smart plug service
+        self._disconnection_event: threading.Event | None = None
 
 
     async def configure(self, settings: dict) -> bool:
     async def configure(self, settings: dict) -> bool:
         """Configure MQTT connection from settings.
         """Configure MQTT connection from settings.
@@ -187,15 +188,19 @@ class MQTTRelayService:
             logger.warning("MQTT relay disconnected: %s", rc)
             logger.warning("MQTT relay disconnected: %s", rc)
         else:
         else:
             logger.info("MQTT relay disconnected cleanly")
             logger.info("MQTT relay disconnected cleanly")
+        if self._disconnection_event:
+            self._disconnection_event.set()
 
 
-    async def disconnect(self):
+    async def disconnect(self, timeout: float = 0):
         """Disconnect from MQTT broker."""
         """Disconnect from MQTT broker."""
         if self.client:
         if self.client:
             try:
             try:
                 # Publish offline status before disconnecting
                 # Publish offline status before disconnecting
                 self._publish_status("offline")
                 self._publish_status("offline")
-                self.client.loop_stop()
+                self._disconnection_event = threading.Event()
                 self.client.disconnect()
                 self.client.disconnect()
+                await asyncio.to_thread(self._disconnection_event.wait, timeout=timeout)
+                self.client.loop_stop()
             except Exception as e:
             except Exception as e:
                 logger.debug("MQTT disconnect error (ignored): %s", e)
                 logger.debug("MQTT disconnect error (ignored): %s", e)
             finally:
             finally:

+ 8 - 2
backend/app/services/mqtt_smart_plug.py

@@ -3,6 +3,7 @@
 This service enables integration with Shelly, Zigbee2MQTT, and other MQTT-based energy monitoring devices.
 This service enables integration with Shelly, Zigbee2MQTT, and other MQTT-based energy monitoring devices.
 """
 """
 
 
+import asyncio
 import json
 import json
 import logging
 import logging
 import threading
 import threading
@@ -52,6 +53,7 @@ class MQTTSmartPlugService:
         self.plug_configs: dict[int, dict[str, MQTTDataSourceConfig]] = {}
         self.plug_configs: dict[int, dict[str, MQTTDataSourceConfig]] = {}
         # plug_id -> latest data
         # plug_id -> latest data
         self.plug_data: dict[int, SmartPlugMQTTData] = {}
         self.plug_data: dict[int, SmartPlugMQTTData] = {}
+        self._disconnection_event: threading.Event | None = None
         self._configured = False
         self._configured = False
         self._broker = ""
         self._broker = ""
         self._port = 1883
         self._port = 1883
@@ -209,6 +211,8 @@ class MQTTSmartPlugService:
             logger.warning("MQTT smart plug service disconnected: %s", rc)
             logger.warning("MQTT smart plug service disconnected: %s", rc)
         else:
         else:
             logger.info("MQTT smart plug service disconnected cleanly")
             logger.info("MQTT smart plug service disconnected cleanly")
+        if self._disconnection_event:
+            self._disconnection_event.set()
 
 
     def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
     def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage):
         """Handle incoming MQTT message, extract data using JSON path."""
         """Handle incoming MQTT message, extract data using JSON path."""
@@ -471,12 +475,14 @@ class MQTTSmartPlugService:
         timeout = timedelta(minutes=self.REACHABLE_TIMEOUT_MINUTES)
         timeout = timedelta(minutes=self.REACHABLE_TIMEOUT_MINUTES)
         return datetime.utcnow() - data.last_seen < timeout
         return datetime.utcnow() - data.last_seen < timeout
 
 
-    async def disconnect(self):
+    async def disconnect(self, timeout: float = 0):
         """Disconnect from MQTT broker."""
         """Disconnect from MQTT broker."""
         if self.client:
         if self.client:
             try:
             try:
-                self.client.loop_stop()
+                self._disconnection_event = threading.Event()
                 self.client.disconnect()
                 self.client.disconnect()
+                await asyncio.to_thread(self._disconnection_event.wait, timeout=timeout)
+                self.client.loop_stop()
             except Exception as e:
             except Exception as e:
                 logger.debug("MQTT smart plug disconnect error (ignored): %s", e)
                 logger.debug("MQTT smart plug disconnect error (ignored): %s", e)
             finally:
             finally:

+ 4 - 4
backend/app/services/printer_manager.py

@@ -223,18 +223,18 @@ class PrinterManager:
         await asyncio.sleep(1)
         await asyncio.sleep(1)
         return client.state.connected
         return client.state.connected
 
 
-    def disconnect_printer(self, printer_id: int):
+    def disconnect_printer(self, printer_id: int, timeout: float = 0):
         """Disconnect from a printer."""
         """Disconnect from a printer."""
         if printer_id in self._clients:
         if printer_id in self._clients:
-            self._clients[printer_id].disconnect()
+            self._clients[printer_id].disconnect(timeout=timeout)
             del self._clients[printer_id]
             del self._clients[printer_id]
         self._models.pop(printer_id, None)  # Clean up model cache
         self._models.pop(printer_id, None)  # Clean up model cache
         self._printer_info.pop(printer_id, None)  # Clean up printer info cache
         self._printer_info.pop(printer_id, None)  # Clean up printer info cache
 
 
-    def disconnect_all(self):
+    def disconnect_all(self, timeout: float = 0):
         """Disconnect from all printers."""
         """Disconnect from all printers."""
         for printer_id in list(self._clients.keys()):
         for printer_id in list(self._clients.keys()):
-            self.disconnect_printer(printer_id)
+            self.disconnect_printer(printer_id, timeout=timeout)
 
 
     def get_status(self, printer_id: int) -> PrinterState | None:
     def get_status(self, printer_id: int) -> PrinterState | None:
         """Get the current status of a printer (checks for stale connections)."""
         """Get the current status of a printer (checks for stale connections)."""