Browse Source

Optimize AMS Spoolman sync performance with spool caching

- Add cached_spools parameter to find_spool_by_tag, find_spools_by_location_prefix, sync_ams_tray, and clear_location_for_removed_spools
- Fetch spools once before loops in on_ams_change, sync_single_printer, and sync_all_printers endpoints
- Cache newly created spools during sync to avoid duplicate API calls
- Add 5 unit tests for caching functionality (all passing)
- Reduce redundant API calls when syncing multiple AMS trays
- Improve sync performance for users with large spool databases
- Maintain backward compatibility with optional cached_spools parameters
bambuman 3 months ago
parent
commit
deab81287f

+ 40 - 4
backend/app/api/routes/spoolman.py

@@ -217,6 +217,12 @@ async def sync_printer_ams(
             detail=f"AMS data format not supported. Keys: {list(ams_data.keys()) if isinstance(ams_data, dict) else type(ams_data).__name__}",
             detail=f"AMS data format not supported. Keys: {list(ams_data.keys()) if isinstance(ams_data, dict) else type(ams_data).__name__}",
         )
         )
 
 
+    # OPTIMIZATION: Fetch all spools once before processing trays
+    # This eliminates redundant API calls (one per tray) when syncing multiple trays
+    logger.debug("[Printer %s] Fetching spools cache for sync...", printer.name)
+    cached_spools = await client.get_spools()
+    logger.debug("[Printer %s] Cached %d spools for batch sync", printer.name, len(cached_spools))
+
     for ams_unit in ams_units:
     for ams_unit in ams_units:
         if not isinstance(ams_unit, dict):
         if not isinstance(ams_unit, dict):
             continue
             continue
@@ -257,9 +263,20 @@ async def sync_printer_ams(
                 current_tray_uuids.add(spool_tag.upper())
                 current_tray_uuids.add(spool_tag.upper())
 
 
             try:
             try:
-                sync_result = await client.sync_ams_tray(tray, printer.name, disable_weight_sync=disable_weight_sync)
+                sync_result = await client.sync_ams_tray(
+                    tray,
+                    printer.name,
+                    disable_weight_sync=disable_weight_sync,
+                    cached_spools=cached_spools,
+                )
                 if sync_result:
                 if sync_result:
                     synced += 1
                     synced += 1
+                    # Add newly created spool to cache
+                    if sync_result.get("id"):
+                        spool_exists = any(s.get("id") == sync_result["id"] for s in cached_spools)
+                        if not spool_exists:
+                            cached_spools.append(sync_result)
+                            logger.debug("Added newly created spool %s to cache", sync_result["id"])
                     logger.info(
                     logger.info(
                         "Synced %s from %s AMS %s tray %s", tray.tray_sub_brands, printer.name, ams_id, tray.tray_id
                         "Synced %s from %s AMS %s tray %s", tray.tray_sub_brands, printer.name, ams_id, tray.tray_id
                     )
                     )
@@ -273,7 +290,9 @@ async def sync_printer_ams(
 
 
     # Clear location for spools that were removed from this printer's AMS
     # Clear location for spools that were removed from this printer's AMS
     try:
     try:
-        cleared = await client.clear_location_for_removed_spools(printer.name, current_tray_uuids)
+        cleared = await client.clear_location_for_removed_spools(
+            printer.name, current_tray_uuids, cached_spools=cached_spools
+        )
         if cleared > 0:
         if cleared > 0:
             logger.info("Cleared location for %s spools removed from %s", cleared, printer.name)
             logger.info("Cleared location for %s spools removed from %s", cleared, printer.name)
     except Exception as e:
     except Exception as e:
@@ -320,6 +339,12 @@ async def sync_all_printers(
     # Track tray UUIDs per printer (for clearing removed spools)
     # Track tray UUIDs per printer (for clearing removed spools)
     printer_tray_uuids: dict[str, set[str]] = {}
     printer_tray_uuids: dict[str, set[str]] = {}
 
 
+    # OPTIMIZATION: Fetch all spools once before processing ALL printers/trays
+    # This eliminates redundant API calls across all printers
+    logger.debug("Fetching spools cache for sync-all operation...")
+    cached_spools = await client.get_spools()
+    logger.debug("Cached %d spools for batch sync across %d printers", len(cached_spools), len(printers))
+
     for printer in printers:
     for printer in printers:
         state = printer_manager.get_status(printer.id)
         state = printer_manager.get_status(printer.id)
         if not state or not state.raw_data:
         if not state or not state.raw_data:
@@ -394,17 +419,28 @@ async def sync_all_printers(
 
 
                 try:
                 try:
                     sync_result = await client.sync_ams_tray(
                     sync_result = await client.sync_ams_tray(
-                        tray, printer.name, disable_weight_sync=disable_weight_sync
+                        tray,
+                        printer.name,
+                        disable_weight_sync=disable_weight_sync,
+                        cached_spools=cached_spools,
                     )
                     )
                     if sync_result:
                     if sync_result:
                         total_synced += 1
                         total_synced += 1
+                        # Add newly created spool to cache
+                        if sync_result.get("id"):
+                            spool_exists = any(s.get("id") == sync_result["id"] for s in cached_spools)
+                            if not spool_exists:
+                                cached_spools.append(sync_result)
+                                logger.debug("Added newly created spool %s to cache", sync_result["id"])
                 except Exception as e:
                 except Exception as e:
                     all_errors.append(f"{printer.name} AMS {ams_id}:{tray.tray_id}: {e}")
                     all_errors.append(f"{printer.name} AMS {ams_id}:{tray.tray_id}: {e}")
 
 
     # Clear location for spools that were removed from each printer's AMS
     # Clear location for spools that were removed from each printer's AMS
     for printer_name, current_tray_uuids in printer_tray_uuids.items():
     for printer_name, current_tray_uuids in printer_tray_uuids.items():
         try:
         try:
-            cleared = await client.clear_location_for_removed_spools(printer_name, current_tray_uuids)
+            cleared = await client.clear_location_for_removed_spools(
+                printer_name, current_tray_uuids, cached_spools=cached_spools
+            )
             if cleared > 0:
             if cleared > 0:
                 logger.info("Cleared location for %s spools removed from %s", cleared, printer_name)
                 logger.info("Cleared location for %s spools removed from %s", cleared, printer_name)
         except Exception as e:
         except Exception as e:

+ 24 - 1
backend/app/main.py

@@ -557,6 +557,12 @@ async def on_ams_change(printer_id: int, ams_data: list):
             printer = result.scalar_one_or_none()
             printer = result.scalar_one_or_none()
             printer_name = printer.name if printer else f"Printer {printer_id}"
             printer_name = printer.name if printer else f"Printer {printer_id}"
 
 
+            # OPTIMIZATION: Fetch all spools once before processing trays
+            # This eliminates redundant API calls (one per tray) when syncing multiple trays
+            logger.debug("[Printer %s] Fetching spools cache for AMS sync...", printer_id)
+            cached_spools = await client.get_spools()
+            logger.debug("[Printer %s] Cached %d spools for batch sync", printer_id, len(cached_spools))
+
             # Sync each AMS tray
             # Sync each AMS tray
             synced = 0
             synced = 0
             for ams_unit in ams_data:
             for ams_unit in ams_data:
@@ -569,9 +575,26 @@ async def on_ams_change(printer_id: int, ams_data: list):
                         continue  # Empty tray
                         continue  # Empty tray
 
 
                     try:
                     try:
-                        result = await client.sync_ams_tray(tray, printer_name, disable_weight_sync=disable_weight_sync)
+                        result = await client.sync_ams_tray(
+                            tray,
+                            printer_name,
+                            disable_weight_sync=disable_weight_sync,
+                            cached_spools=cached_spools,
+                        )
                         if result:
                         if result:
                             synced += 1
                             synced += 1
+                            # If a new spool was created, add it to the cache
+                            # so subsequent trays can find it if they reference the same tag
+                            if result.get("id"):
+                                # Check if this spool already exists in cache
+                                spool_exists = any(s.get("id") == result["id"] for s in cached_spools)
+                                if not spool_exists:
+                                    cached_spools.append(result)
+                                    logger.debug(
+                                        "[Printer %s] Added newly created spool %s to cache",
+                                        printer_id,
+                                        result["id"],
+                                    )
                     except Exception as e:
                     except Exception as e:
                         logger.error("Error syncing AMS %s tray %s: %s", ams_id, tray.tray_id, e)
                         logger.error("Error syncing AMS %s tray %s: %s", ams_id, tray.tray_id, e)
 
 

+ 24 - 7
backend/app/services/spoolman.py

@@ -387,16 +387,18 @@ class SpoolmanClient:
             logger.error("Failed to record spool usage in Spoolman: %s", e)
             logger.error("Failed to record spool usage in Spoolman: %s", e)
             return None
             return None
 
 
-    async def find_spool_by_tag(self, tag_uid: str) -> dict | None:
+    async def find_spool_by_tag(self, tag_uid: str, cached_spools: list[dict] | None = None) -> dict | None:
         """Find a spool by its RFID tag UID.
         """Find a spool by its RFID tag UID.
 
 
         Args:
         Args:
             tag_uid: The RFID tag UID to search for
             tag_uid: The RFID tag UID to search for
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
 
         Returns:
         Returns:
             Spool dictionary or None if not found.
             Spool dictionary or None if not found.
         """
         """
-        spools = await self.get_spools()
+        # Use cached spools if provided, otherwise fetch from API
+        spools = cached_spools if cached_spools is not None else await self.get_spools()
         # Normalize tag_uid for comparison (uppercase, strip quotes)
         # Normalize tag_uid for comparison (uppercase, strip quotes)
         search_tag = tag_uid.strip('"').upper()
         search_tag = tag_uid.strip('"').upper()
 
 
@@ -412,16 +414,20 @@ class SpoolmanClient:
                         return spool
                         return spool
         return None
         return None
 
 
-    async def find_spools_by_location_prefix(self, location_prefix: str) -> list[dict]:
+    async def find_spools_by_location_prefix(
+        self, location_prefix: str, cached_spools: list[dict] | None = None
+    ) -> list[dict]:
         """Find all spools with locations starting with a given prefix.
         """Find all spools with locations starting with a given prefix.
 
 
         Args:
         Args:
             location_prefix: The location prefix to search for (e.g., "PrinterName - ")
             location_prefix: The location prefix to search for (e.g., "PrinterName - ")
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
 
         Returns:
         Returns:
             List of spool dictionaries with matching locations.
             List of spool dictionaries with matching locations.
         """
         """
-        spools = await self.get_spools()
+        # Use cached spools if provided, otherwise fetch from API
+        spools = cached_spools if cached_spools is not None else await self.get_spools()
         matching = []
         matching = []
         for spool in spools:
         for spool in spools:
             location = spool.get("location", "")
             location = spool.get("location", "")
@@ -433,6 +439,7 @@ class SpoolmanClient:
         self,
         self,
         printer_name: str,
         printer_name: str,
         current_tray_uuids: set[str],
         current_tray_uuids: set[str],
+        cached_spools: list[dict] | None = None,
     ) -> int:
     ) -> int:
         """Clear location for spools that are no longer in the AMS.
         """Clear location for spools that are no longer in the AMS.
 
 
@@ -443,12 +450,13 @@ class SpoolmanClient:
         Args:
         Args:
             printer_name: The printer name used as location prefix
             printer_name: The printer name used as location prefix
             current_tray_uuids: Set of tray_uuids currently in the AMS
             current_tray_uuids: Set of tray_uuids currently in the AMS
+            cached_spools: Optional pre-fetched list of spools to search (avoids API call)
 
 
         Returns:
         Returns:
             Number of spools whose location was cleared.
             Number of spools whose location was cleared.
         """
         """
         location_prefix = f"{printer_name} - "
         location_prefix = f"{printer_name} - "
-        spools_at_printer = await self.find_spools_by_location_prefix(location_prefix)
+        spools_at_printer = await self.find_spools_by_location_prefix(location_prefix, cached_spools=cached_spools)
         cleared_count = 0
         cleared_count = 0
 
 
         for spool in spools_at_printer:
         for spool in spools_at_printer:
@@ -662,7 +670,13 @@ class SpoolmanClient:
         """
         """
         return (remain_percent / 100.0) * spool_weight
         return (remain_percent / 100.0) * spool_weight
 
 
-    async def sync_ams_tray(self, tray: AMSTray, printer_name: str, disable_weight_sync: bool = False) -> dict | None:
+    async def sync_ams_tray(
+        self,
+        tray: AMSTray,
+        printer_name: str,
+        disable_weight_sync: bool = False,
+        cached_spools: list[dict] | None = None,
+    ) -> dict | None:
         """Sync a single AMS tray to Spoolman.
         """Sync a single AMS tray to Spoolman.
 
 
         Only syncs trays with valid Bambu Lab tray_uuid (32 hex characters).
         Only syncs trays with valid Bambu Lab tray_uuid (32 hex characters).
@@ -676,6 +690,9 @@ class SpoolmanClient:
             printer_name: Name of the printer for location
             printer_name: Name of the printer for location
             disable_weight_sync: If True, skip updating remaining_weight for existing spools.
             disable_weight_sync: If True, skip updating remaining_weight for existing spools.
                 This allows Spoolman's granular usage tracking to maintain accurate weights.
                 This allows Spoolman's granular usage tracking to maintain accurate weights.
+            cached_spools: Optional pre-fetched list of spools to search (avoids API calls).
+                When provided, this cache is passed to find_spool_by_tag to avoid redundant
+                API calls during batch sync operations.
 
 
         Returns:
         Returns:
             Synced spool dictionary or None if skipped or failed.
             Synced spool dictionary or None if skipped or failed.
@@ -716,7 +733,7 @@ class SpoolmanClient:
         location = f"{printer_name} - {self.convert_ams_slot_to_location(tray.ams_id, tray.tray_id)}"
         location = f"{printer_name} - {self.convert_ams_slot_to_location(tray.ams_id, tray.tray_id)}"
 
 
         # Find existing spool by tag (tray_uuid or tag_uid, stored as "tag" in Spoolman)
         # Find existing spool by tag (tray_uuid or tag_uid, stored as "tag" in Spoolman)
-        existing = await self.find_spool_by_tag(spool_tag)
+        existing = await self.find_spool_by_tag(spool_tag, cached_spools=cached_spools)
         if existing:
         if existing:
             # Update existing spool
             # Update existing spool
             logger.info("Updating existing spool %s for tag %s...", existing["id"], spool_tag[:16])
             logger.info("Updating existing spool %s for tag %s...", existing["id"], spool_tag[:16])

+ 78 - 0
backend/tests/unit/services/test_spoolman_service.py

@@ -172,3 +172,81 @@ class TestSpoolmanClient:
                 assert call_kwargs["remaining_weight"] == expected, (
                 assert call_kwargs["remaining_weight"] == expected, (
                     f"Expected {expected}g for {remain}% of {weight}g, got {call_kwargs['remaining_weight']}"
                     f"Expected {expected}g for {remain}% of {weight}g, got {call_kwargs['remaining_weight']}"
                 )
                 )
+
+    # ========================================================================
+    # Tests for caching functionality
+    # ========================================================================
+
+    @pytest.mark.asyncio
+    async def test_find_spool_by_tag_with_cached_spools(self, client):
+        """Verify find_spool_by_tag uses cached spools when provided (no API call)."""
+        cached = [
+            {"id": 1, "extra": {"tag": '"ABC123"'}},
+            {"id": 2, "extra": {"tag": '"XYZ789"'}},
+        ]
+
+        with patch.object(client, "get_spools", AsyncMock()) as mock_get:
+            result = await client.find_spool_by_tag("ABC123", cached_spools=cached)
+            assert result["id"] == 1
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_find_spool_by_tag_without_cached_spools(self, client):
+        """Verify find_spool_by_tag fetches spools when cache not provided."""
+        mock_spools = [{"id": 1, "extra": {"tag": '"ABC123"'}}]
+
+        with patch.object(client, "get_spools", AsyncMock(return_value=mock_spools)) as mock_get:
+            result = await client.find_spool_by_tag("ABC123")
+            assert result["id"] == 1
+            mock_get.assert_called_once()  # Should call get_spools
+
+    @pytest.mark.asyncio
+    async def test_find_spools_by_location_prefix_with_cached_spools(self, client):
+        """Verify find_spools_by_location_prefix uses cached spools when provided."""
+        cached = [
+            {"id": 1, "location": "Printer1 - AMS A1"},
+            {"id": 2, "location": "Printer2 - AMS A1"},
+            {"id": 3, "location": "Printer1 - AMS A2"},
+        ]
+
+        with patch.object(client, "get_spools", AsyncMock()) as mock_get:
+            result = await client.find_spools_by_location_prefix("Printer1 - ", cached_spools=cached)
+            assert len(result) == 2
+            assert result[0]["id"] == 1
+            assert result[1]["id"] == 3
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_sync_ams_tray_with_cached_spools(self, client, sample_tray, existing_spool):
+        """Verify sync_ams_tray passes cached_spools to find_spool_by_tag."""
+        cached = [existing_spool]
+
+        with (
+            patch.object(client, "get_spools", AsyncMock()) as mock_get,
+            patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})),
+        ):
+            await client.sync_ams_tray(sample_tray, "TestPrinter", cached_spools=cached)
+            mock_get.assert_not_called()  # Should NOT call get_spools
+
+    @pytest.mark.asyncio
+    async def test_clear_location_for_removed_spools_with_cached_spools(self, client):
+        """Verify clear_location_for_removed_spools uses cached spools."""
+        cached = [
+            {"id": 1, "location": "Printer1 - AMS A1", "extra": {"tag": '"TAG1"'}},
+            {"id": 2, "location": "Printer1 - AMS A2", "extra": {"tag": '"TAG2"'}},
+            {"id": 3, "location": "Printer1 - AMS A3", "extra": {"tag": '"TAG3"'}},
+        ]
+        current_tags = {"TAG1", "TAG2"}  # TAG3 was removed
+
+        with (
+            patch.object(client, "get_spools", AsyncMock()) as mock_get,
+            patch.object(client, "update_spool", AsyncMock(return_value={"id": 3})) as mock_update,
+        ):
+            cleared = await client.clear_location_for_removed_spools("Printer1", current_tags, cached_spools=cached)
+            assert cleared == 1
+            mock_get.assert_not_called()  # Should NOT call get_spools
+            mock_update.assert_called_once()
+            # Verify it cleared TAG3 (not in current_tags)
+            call_kwargs = mock_update.call_args.kwargs
+            assert call_kwargs["spool_id"] == 3
+            assert call_kwargs.get("clear_location") is True