Просмотр исходного кода

Merge pull request #295 from bambuman/bug/spoolman-creates-dublicate-spools

bug/spoolman-creates-dublicate-spools
MartinNYHC 3 месяцев назад
Родитель
Сommit
00d60d4bb8

+ 54 - 4
backend/app/api/routes/spoolman.py

@@ -217,6 +217,19 @@ async def sync_printer_ams(
             detail=f"AMS data format not supported. Keys: {list(ams_data.keys()) if isinstance(ams_data, dict) else type(ams_data).__name__}",
         )
 
+    # OPTIMIZATION: Fetch all spools once before processing trays
+    # This eliminates redundant API calls (one per tray) when syncing multiple trays
+    logger.debug("[Printer %s] Fetching spools cache for sync...", printer.name)
+    try:
+        cached_spools = await client.get_spools()
+        logger.debug("[Printer %s] Cached %d spools for batch sync", printer.name, len(cached_spools))
+    except Exception as e:
+        logger.error("[Printer %s] Failed to fetch spools cache after retries: %s", printer.name, e)
+        raise HTTPException(
+            status_code=503,
+            detail=f"Failed to connect to Spoolman after multiple retries: {str(e)}",
+        )
+
     for ams_unit in ams_units:
         if not isinstance(ams_unit, dict):
             continue
@@ -257,9 +270,20 @@ async def sync_printer_ams(
                 current_tray_uuids.add(spool_tag.upper())
 
             try:
-                sync_result = await client.sync_ams_tray(tray, printer.name, disable_weight_sync=disable_weight_sync)
+                sync_result = await client.sync_ams_tray(
+                    tray,
+                    printer.name,
+                    disable_weight_sync=disable_weight_sync,
+                    cached_spools=cached_spools,
+                )
                 if sync_result:
                     synced += 1
+                    # Add newly created spool to cache
+                    if sync_result.get("id"):
+                        spool_exists = any(s.get("id") == sync_result["id"] for s in cached_spools)
+                        if not spool_exists:
+                            cached_spools.append(sync_result)
+                            logger.debug("Added newly created spool %s to cache", sync_result["id"])
                     logger.info(
                         "Synced %s from %s AMS %s tray %s", tray.tray_sub_brands, printer.name, ams_id, tray.tray_id
                     )
@@ -273,7 +297,9 @@ async def sync_printer_ams(
 
     # Clear location for spools that were removed from this printer's AMS
     try:
-        cleared = await client.clear_location_for_removed_spools(printer.name, current_tray_uuids)
+        cleared = await client.clear_location_for_removed_spools(
+            printer.name, current_tray_uuids, cached_spools=cached_spools
+        )
         if cleared > 0:
             logger.info("Cleared location for %s spools removed from %s", cleared, printer.name)
     except Exception as e:
@@ -320,6 +346,19 @@ async def sync_all_printers(
     # Track tray UUIDs per printer (for clearing removed spools)
     printer_tray_uuids: dict[str, set[str]] = {}
 
+    # OPTIMIZATION: Fetch all spools once before processing ALL printers/trays
+    # This eliminates redundant API calls across all printers
+    logger.debug("Fetching spools cache for sync-all operation...")
+    try:
+        cached_spools = await client.get_spools()
+        logger.debug("Cached %d spools for batch sync across %d printers", len(cached_spools), len(printers))
+    except Exception as e:
+        logger.error("Failed to fetch spools cache after retries: %s", e)
+        raise HTTPException(
+            status_code=503,
+            detail=f"Failed to connect to Spoolman after multiple retries: {str(e)}",
+        )
+
     for printer in printers:
         state = printer_manager.get_status(printer.id)
         if not state or not state.raw_data:
@@ -394,17 +433,28 @@ async def sync_all_printers(
 
                 try:
                     sync_result = await client.sync_ams_tray(
-                        tray, printer.name, disable_weight_sync=disable_weight_sync
+                        tray,
+                        printer.name,
+                        disable_weight_sync=disable_weight_sync,
+                        cached_spools=cached_spools,
                     )
                     if sync_result:
                         total_synced += 1
+                        # Add newly created spool to cache
+                        if sync_result.get("id"):
+                            spool_exists = any(s.get("id") == sync_result["id"] for s in cached_spools)
+                            if not spool_exists:
+                                cached_spools.append(sync_result)
+                                logger.debug("Added newly created spool %s to cache", sync_result["id"])
                 except Exception as e:
                     all_errors.append(f"{printer.name} AMS {ams_id}:{tray.tray_id}: {e}")
 
     # Clear location for spools that were removed from each printer's AMS
     for printer_name, current_tray_uuids in printer_tray_uuids.items():
         try:
-            cleared = await client.clear_location_for_removed_spools(printer_name, current_tray_uuids)
+            cleared = await client.clear_location_for_removed_spools(
+                printer_name, current_tray_uuids, cached_spools=cached_spools
+            )
             if cleared > 0:
                 logger.info("Cleared location for %s spools removed from %s", cleared, printer_name)
         except Exception as e:

+ 32 - 1
backend/app/main.py

@@ -556,6 +556,20 @@ async def on_ams_change(printer_id: int, ams_data: list):
             printer = result.scalar_one_or_none()
             printer_name = printer.name if printer else f"Printer {printer_id}"
 
+            # OPTIMIZATION: Fetch all spools once before processing trays
+            # This eliminates redundant API calls (one per tray) when syncing multiple trays
+            logger.debug("[Printer %s] Fetching spools cache for AMS sync...", printer_id)
+            try:
+                cached_spools = await client.get_spools()
+                logger.debug("[Printer %s] Cached %d spools for batch sync", printer_id, len(cached_spools))
+            except Exception as e:
+                logger.error(
+                    "[Printer %s] Failed to fetch spools cache after retries, aborting AMS sync: %s",
+                    printer_id,
+                    e,
+                )
+                return
+
             # Sync each AMS tray
             synced = 0
             for ams_unit in ams_data:
@@ -568,9 +582,26 @@ async def on_ams_change(printer_id: int, ams_data: list):
                         continue  # Empty tray
 
                     try:
-                        result = await client.sync_ams_tray(tray, printer_name, disable_weight_sync=disable_weight_sync)
+                        result = await client.sync_ams_tray(
+                            tray,
+                            printer_name,
+                            disable_weight_sync=disable_weight_sync,
+                            cached_spools=cached_spools,
+                        )
                         if result:
                             synced += 1
+                            # If a new spool was created, add it to the cache
+                            # so subsequent trays can find it if they reference the same tag
+                            if result.get("id"):
+                                # Check if this spool already exists in cache
+                                spool_exists = any(s.get("id") == result["id"] for s in cached_spools)
+                                if not spool_exists:
+                                    cached_spools.append(result)
+                                    logger.debug(
+                                        "[Printer %s] Added newly created spool %s to cache",
+                                        printer_id,
+                                        result["id"],
+                                    )
                     except Exception as e:
                         logger.error("Error syncing AMS %s tray %s: %s", ams_id, tray.tray_id, e)
 

+ 89 - 18
backend/app/services/spoolman.py

@@ -1,5 +1,6 @@
 """Spoolman integration service for syncing AMS filament data."""
 
+import asyncio
 import logging
 from dataclasses import dataclass
 from datetime import datetime, timezone
@@ -68,9 +69,22 @@ class SpoolmanClient:
         self._connected = False
 
     async def _get_client(self) -> httpx.AsyncClient:
-        """Get or create the HTTP client."""
+        """Get or create the HTTP client with connection pooling limits.
+
+        Configures the client to prevent idle connection issues:
+        - max_keepalive_connections=5: Limit number of persistent connections
+        - keepalive_expiry=30: Close idle connections after 30 seconds
+        - max_connections=10: Limit total connections to prevent resource exhaustion
+        """
         if self._client is None:
-            self._client = httpx.AsyncClient(timeout=10.0)
+            self._client = httpx.AsyncClient(
+                timeout=10.0,
+                limits=httpx.Limits(
+                    max_keepalive_connections=5,
+                    max_connections=10,
+                    keepalive_expiry=30.0,
+                ),
+            )
         return self._client
 
     async def close(self):
@@ -101,19 +115,59 @@ class SpoolmanClient:
         return self._connected
 
     async def get_spools(self) -> list[dict]:
-        """Get all spools from Spoolman.
+        """Get all spools from Spoolman with retry logic.
+
+        Attempts to fetch spools up to 3 times with 500ms delay between attempts.
+        This handles transient network errors like closed connections.
 
         Returns:
             List of spool dictionaries.
+
+        Raises:
+            Exception: If all 3 retry attempts fail.
         """
-        try:
-            client = await self._get_client()
-            response = await client.get(f"{self.api_url}/spool")
-            response.raise_for_status()
-            return response.json()
-        except Exception as e:
-            logger.error("Failed to get spools from Spoolman: %s", e)
-            return []
+        max_attempts = 3
+        retry_delay = 0.5  # 500ms
+
+        for attempt in range(1, max_attempts + 1):
+            try:
+                client = await self._get_client()
+                response = await client.get(f"{self.api_url}/spool")
+                response.raise_for_status()
+                spools = response.json()
+                if attempt > 1:
+                    logger.info("Successfully fetched %d spools on attempt %d", len(spools), attempt)
+                return spools
+            except (httpx.ReadError, httpx.RemoteProtocolError, httpx.ConnectError) as e:
+                # Connection-related errors - close and recreate client for next attempt
+                if attempt < max_attempts:
+                    logger.warning(
+                        "Connection error getting spools (attempt %d/%d): %s. Recreating client and retrying in %dms...",
+                        attempt,
+                        max_attempts,
+                        e,
+                        int(retry_delay * 1000),
+                    )
+                    # Close the stale client and recreate it
+                    await self.close()
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error("Failed to get spools from Spoolman after %d attempts: %s", max_attempts, e)
+                    raise
+            except Exception as e:
+                # Other errors (HTTP errors, JSON decode errors, etc.)
+                if attempt < max_attempts:
+                    logger.warning(
+                        "Failed to get spools from Spoolman (attempt %d/%d): %s. Retrying in %dms...",
+                        attempt,
+                        max_attempts,
+                        e,
+                        int(retry_delay * 1000),
+                    )
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error("Failed to get spools from Spoolman after %d attempts: %s", max_attempts, e)
+                    raise
 
     async def get_filaments(self) -> list[dict]:
         """Get all internal filaments from Spoolman.
@@ -387,16 +441,18 @@ class SpoolmanClient:
             logger.error("Failed to record spool usage in Spoolman: %s", e)
             return None
 
-    async def find_spool_by_tag(self, tag_uid: str) -> dict | None:
+    async def find_spool_by_tag(self, tag_uid: str, cached_spools: list[dict] | None = None) -> dict | None:
         """Find a spool by its RFID tag UID.
 
         Args:
             tag_uid: The RFID tag UID to search for
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
         Returns:
             Spool dictionary or None if not found.
         """
-        spools = await self.get_spools()
+        # Use cached spools if provided, otherwise fetch from API
+        spools = cached_spools if cached_spools is not None else await self.get_spools()
         # Normalize tag_uid for comparison (uppercase, strip quotes)
         search_tag = tag_uid.strip('"').upper()
 
@@ -412,16 +468,20 @@ class SpoolmanClient:
                         return spool
         return None
 
-    async def find_spools_by_location_prefix(self, location_prefix: str) -> list[dict]:
+    async def find_spools_by_location_prefix(
+        self, location_prefix: str, cached_spools: list[dict] | None = None
+    ) -> list[dict]:
         """Find all spools with locations starting with a given prefix.
 
         Args:
             location_prefix: The location prefix to search for (e.g., "PrinterName - ")
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
         Returns:
             List of spool dictionaries with matching locations.
         """
-        spools = await self.get_spools()
+        # Use cached spools if provided, otherwise fetch from API
+        spools = cached_spools if cached_spools is not None else await self.get_spools()
         matching = []
         for spool in spools:
             location = spool.get("location", "")
@@ -433,6 +493,7 @@ class SpoolmanClient:
         self,
         printer_name: str,
         current_tray_uuids: set[str],
+        cached_spools: list[dict] | None = None,
     ) -> int:
         """Clear location for spools that are no longer in the AMS.
 
@@ -443,12 +504,13 @@ class SpoolmanClient:
         Args:
             printer_name: The printer name used as location prefix
             current_tray_uuids: Set of tray_uuids currently in the AMS
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
         Returns:
             Number of spools whose location was cleared.
         """
         location_prefix = f"{printer_name} - "
-        spools_at_printer = await self.find_spools_by_location_prefix(location_prefix)
+        spools_at_printer = await self.find_spools_by_location_prefix(location_prefix, cached_spools=cached_spools)
         cleared_count = 0
 
         for spool in spools_at_printer:
@@ -662,7 +724,13 @@ class SpoolmanClient:
         """
         return (remain_percent / 100.0) * spool_weight
 
-    async def sync_ams_tray(self, tray: AMSTray, printer_name: str, disable_weight_sync: bool = False) -> dict | None:
+    async def sync_ams_tray(
+        self,
+        tray: AMSTray,
+        printer_name: str,
+        disable_weight_sync: bool = False,
+        cached_spools: list[dict] | None = None,
+    ) -> dict | None:
         """Sync a single AMS tray to Spoolman.
 
         Only syncs trays with valid Bambu Lab tray_uuid (32 hex characters).
@@ -676,6 +744,9 @@ class SpoolmanClient:
             printer_name: Name of the printer for location
             disable_weight_sync: If True, skip updating remaining_weight for existing spools.
                 This allows Spoolman's granular usage tracking to maintain accurate weights.
+            cached_spools: Optional pre-fetched list of spools to search (avoids API calls).
+                When provided, this cache is passed to find_spool_by_tag to avoid redundant
+                API calls during batch sync operations.
 
         Returns:
             Synced spool dictionary or None if skipped or failed.
@@ -716,7 +787,7 @@ class SpoolmanClient:
         location = f"{printer_name} - {self.convert_ams_slot_to_location(tray.ams_id, tray.tray_id)}"
 
         # Find existing spool by tag (tray_uuid or tag_uid, stored as "tag" in Spoolman)
-        existing = await self.find_spool_by_tag(spool_tag)
+        existing = await self.find_spool_by_tag(spool_tag, cached_spools=cached_spools)
         if existing:
             # Update existing spool
             logger.info("Updating existing spool %s for tag %s...", existing["id"], spool_tag[:16])

+ 203 - 1
backend/tests/unit/services/test_spoolman_service.py

@@ -4,7 +4,7 @@ These tests specifically target the sync_ams_tray method's disable_weight_sync
 functionality that controls whether remaining_weight is updated.
 """
 
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
 
@@ -172,3 +172,205 @@ class TestSpoolmanClient:
                 assert call_kwargs["remaining_weight"] == expected, (
                     f"Expected {expected}g for {remain}% of {weight}g, got {call_kwargs['remaining_weight']}"
                 )
+
+    # ========================================================================
+    # Tests for caching functionality
+    # ========================================================================
+
+    @pytest.mark.asyncio
+    async def test_find_spool_by_tag_with_cached_spools(self, client):
+        """Verify find_spool_by_tag uses cached spools when provided (no API call)."""
+        cached = [
+            {"id": 1, "extra": {"tag": '"ABC123"'}},
+            {"id": 2, "extra": {"tag": '"XYZ789"'}},
+        ]
+
+        with patch.object(client, "get_spools", AsyncMock()) as mock_get:
+            result = await client.find_spool_by_tag("ABC123", cached_spools=cached)
+            assert result["id"] == 1
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_find_spool_by_tag_without_cached_spools(self, client):
+        """Verify find_spool_by_tag fetches spools when cache not provided."""
+        mock_spools = [{"id": 1, "extra": {"tag": '"ABC123"'}}]
+
+        with patch.object(client, "get_spools", AsyncMock(return_value=mock_spools)) as mock_get:
+            result = await client.find_spool_by_tag("ABC123")
+            assert result["id"] == 1
+            mock_get.assert_called_once()  # Should call get_spools
+
+    @pytest.mark.asyncio
+    async def test_find_spools_by_location_prefix_with_cached_spools(self, client):
+        """Verify find_spools_by_location_prefix uses cached spools when provided."""
+        cached = [
+            {"id": 1, "location": "Printer1 - AMS A1"},
+            {"id": 2, "location": "Printer2 - AMS A1"},
+            {"id": 3, "location": "Printer1 - AMS A2"},
+        ]
+
+        with patch.object(client, "get_spools", AsyncMock()) as mock_get:
+            result = await client.find_spools_by_location_prefix("Printer1 - ", cached_spools=cached)
+            assert len(result) == 2
+            assert result[0]["id"] == 1
+            assert result[1]["id"] == 3
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_sync_ams_tray_with_cached_spools(self, client, sample_tray, existing_spool):
+        """Verify sync_ams_tray passes cached_spools to find_spool_by_tag."""
+        cached = [existing_spool]
+
+        with (
+            patch.object(client, "get_spools", AsyncMock()) as mock_get,
+            patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})),
+        ):
+            await client.sync_ams_tray(sample_tray, "TestPrinter", cached_spools=cached)
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_clear_location_for_removed_spools_with_cached_spools(self, client):
+        """Verify clear_location_for_removed_spools uses cached spools."""
+        cached = [
+            {"id": 1, "location": "Printer1 - AMS A1", "extra": {"tag": '"TAG1"'}},
+            {"id": 2, "location": "Printer1 - AMS A2", "extra": {"tag": '"TAG2"'}},
+            {"id": 3, "location": "Printer1 - AMS A3", "extra": {"tag": '"TAG3"'}},
+        ]
+        current_tags = {"TAG1", "TAG2"}  # TAG3 was removed
+
+        with (
+            patch.object(client, "get_spools", AsyncMock()) as mock_get,
+            patch.object(client, "update_spool", AsyncMock(return_value={"id": 3})) as mock_update,
+        ):
+            cleared = await client.clear_location_for_removed_spools("Printer1", current_tags, cached_spools=cached)
+            assert cleared == 1
+            mock_get.assert_not_called()  # Should NOT call get_spools
+            mock_update.assert_called_once()
+            # Verify it cleared TAG3 (not in current_tags)
+            call_kwargs = mock_update.call_args.kwargs
+            assert call_kwargs["spool_id"] == 3
+            assert call_kwargs.get("clear_location") is True
+
+    # ========================================================================
+    # Tests for retry logic in get_spools
+    # ========================================================================
+
+    @pytest.mark.asyncio
+    async def test_get_spools_succeeds_on_first_attempt(self, client):
+        """Verify get_spools succeeds immediately when no errors occur."""
+        mock_spools = [{"id": 1}, {"id": 2}]
+
+        with patch.object(client, "_get_client") as mock_get_client:
+            mock_http_client = AsyncMock()
+            mock_response = Mock()
+            mock_response.raise_for_status = Mock()
+            mock_response.json = Mock(return_value=mock_spools)
+            mock_http_client.get = AsyncMock(return_value=mock_response)
+            mock_get_client.return_value = mock_http_client
+
+            result = await client.get_spools()
+
+            assert result == mock_spools
+            mock_get_client.assert_called_once()
+            mock_http_client.get.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_get_spools_retries_on_connection_error(self, client):
+        """Verify get_spools retries up to 3 times on connection errors."""
+        import httpx
+
+        mock_spools = [{"id": 1}]
+
+        with (
+            patch.object(client, "_get_client") as mock_get_client,
+            patch.object(client, "close", AsyncMock()) as mock_close,
+            patch("asyncio.sleep", AsyncMock()) as mock_sleep,
+        ):
+            mock_http_client = AsyncMock()
+            mock_get_client.return_value = mock_http_client
+
+            # First 2 attempts fail with ReadError, 3rd succeeds
+            mock_response = Mock()
+            mock_response.raise_for_status = Mock()
+            mock_response.json = Mock(return_value=mock_spools)
+
+            mock_http_client.get = AsyncMock(
+                side_effect=[
+                    httpx.ReadError("Connection closed"),
+                    httpx.ReadError("Connection closed"),
+                    mock_response,
+                ]
+            )
+
+            result = await client.get_spools()
+
+            assert result == mock_spools
+            assert mock_get_client.call_count == 3
+            assert mock_http_client.get.call_count == 3
+            # Should close client twice (after each failed attempt)
+            assert mock_close.call_count == 2
+            # Should sleep twice (after first 2 attempts)
+            assert mock_sleep.call_count == 2
+            mock_sleep.assert_called_with(0.5)
+
+    @pytest.mark.asyncio
+    async def test_get_spools_raises_after_3_failed_attempts(self, client):
+        """Verify get_spools raises exception after 3 failed attempts."""
+        import httpx
+
+        with (
+            patch.object(client, "_get_client", AsyncMock()) as mock_get_client,
+            patch.object(client, "close", AsyncMock()) as mock_close,
+            patch("asyncio.sleep", AsyncMock()) as mock_sleep,
+        ):
+            mock_http_client = AsyncMock()
+            mock_get_client.return_value = mock_http_client
+
+            # All 3 attempts fail
+            mock_http_client.get.side_effect = httpx.ReadError("Connection closed")
+
+            with pytest.raises(httpx.ReadError):
+                await client.get_spools()
+
+            assert mock_get_client.call_count == 3
+            assert mock_http_client.get.call_count == 3
+            # Should close client twice (after first 2 failed attempts, not after 3rd)
+            assert mock_close.call_count == 2
+            # Should sleep twice (after first 2 attempts, not after 3rd)
+            assert mock_sleep.call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_get_spools_handles_non_connection_errors(self, client):
+        """Verify get_spools retries on non-connection errors without recreating client."""
+        import httpx
+
+        mock_spools = [{"id": 1}]
+
+        with (
+            patch.object(client, "_get_client") as mock_get_client,
+            patch.object(client, "close", AsyncMock()) as mock_close,
+            patch("asyncio.sleep", AsyncMock()) as mock_sleep,
+        ):
+            mock_http_client = AsyncMock()
+            mock_get_client.return_value = mock_http_client
+
+            # First attempt fails with HTTP error, 2nd succeeds
+            mock_response_error = Mock()
+            mock_response_error.raise_for_status = Mock(
+                side_effect=httpx.HTTPStatusError("500 Server Error", request=Mock(), response=Mock())
+            )
+
+            mock_response_success = Mock()
+            mock_response_success.raise_for_status = Mock()
+            mock_response_success.json = Mock(return_value=mock_spools)
+
+            mock_http_client.get = AsyncMock(side_effect=[mock_response_error, mock_response_success])
+
+            result = await client.get_spools()
+
+            assert result == mock_spools
+            assert mock_get_client.call_count == 2
+            # Should NOT close client for HTTP errors (only connection errors)
+            mock_close.assert_not_called()
+            # Should sleep once (after first failed attempt)
+            assert mock_sleep.call_count == 1