# test_spoolman_service.py
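"""Unit tests for the Spoolman service client.

These tests cover SpoolmanClient.sync_ams_tray (including the disable_weight_sync
flag that controls whether remaining_weight is updated on existing spools), the
cached_spools parameter that lets callers reuse one spool listing instead of
calling get_spools repeatedly, and the retry logic in get_spools.
"""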
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest

from backend.app.services.spoolman import AMSTray, SpoolmanClient
  8. class TestSpoolmanClient:
  9. """Tests for SpoolmanClient class."""
  10. @pytest.fixture
  11. def client(self):
  12. """Create a SpoolmanClient instance."""
  13. return SpoolmanClient("http://localhost:7912")
  14. @pytest.fixture
  15. def sample_tray(self):
  16. """Create a sample AMSTray for testing."""
  17. return AMSTray(
  18. ams_id=0,
  19. tray_id=0,
  20. tray_type="PLA",
  21. tray_sub_brands="PLA Basic",
  22. tray_color="FF0000FF",
  23. remain=50,
  24. tag_uid="",
  25. tray_uuid="A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4",
  26. tray_info_idx="GFA00",
  27. tray_weight=1000,
  28. )
  29. @pytest.fixture
  30. def existing_spool(self):
  31. """Create a mock existing spool response."""
  32. return {
  33. "id": 42,
  34. "remaining_weight": 800,
  35. "extra": {"tag": '"A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4"'},
  36. "filament": {"id": 1, "name": "PLA Red", "material": "PLA"},
  37. }
  38. @pytest.fixture
  39. def mock_filament(self):
  40. """Create a mock filament response."""
  41. return {"id": 1, "name": "PLA Basic", "material": "PLA"}
  42. # ========================================================================
  43. # Tests for sync_ams_tray with disable_weight_sync
  44. # ========================================================================
  45. @pytest.mark.asyncio
  46. async def test_sync_ams_tray_updates_weight_by_default(self, client, sample_tray, existing_spool):
  47. """Verify sync_ams_tray updates remaining_weight by default."""
  48. with (
  49. patch.object(client, "find_spool_by_tag", AsyncMock(return_value=existing_spool)),
  50. patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})) as mock_update,
  51. ):
  52. await client.sync_ams_tray(sample_tray, "TestPrinter")
  53. mock_update.assert_called_once()
  54. call_kwargs = mock_update.call_args.kwargs
  55. assert "remaining_weight" in call_kwargs
  56. assert call_kwargs["remaining_weight"] == 500.0 # 50% of 1000g
  57. assert "location" in call_kwargs
  58. @pytest.mark.asyncio
  59. async def test_sync_ams_tray_skips_weight_when_disabled(self, client, sample_tray, existing_spool):
  60. """Verify sync_ams_tray skips remaining_weight when disable_weight_sync=True."""
  61. with (
  62. patch.object(client, "find_spool_by_tag", AsyncMock(return_value=existing_spool)),
  63. patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})) as mock_update,
  64. ):
  65. await client.sync_ams_tray(sample_tray, "TestPrinter", disable_weight_sync=True)
  66. mock_update.assert_called_once()
  67. call_kwargs = mock_update.call_args.kwargs
  68. # remaining_weight should be None (not updated)
  69. assert call_kwargs.get("remaining_weight") is None
  70. # location should still be updated
  71. assert "location" in call_kwargs
  72. assert "TestPrinter" in call_kwargs["location"]
  73. @pytest.mark.asyncio
  74. async def test_sync_ams_tray_new_spool_always_includes_weight(self, client, sample_tray, mock_filament):
  75. """Verify new spool creation always includes remaining_weight even when disabled."""
  76. with (
  77. patch.object(client, "find_spool_by_tag", AsyncMock(return_value=None)),
  78. patch.object(client, "_find_or_create_filament", AsyncMock(return_value=mock_filament)),
  79. patch.object(client, "create_spool", AsyncMock(return_value={"id": 99})) as mock_create,
  80. ):
  81. await client.sync_ams_tray(sample_tray, "TestPrinter", disable_weight_sync=True)
  82. mock_create.assert_called_once()
  83. call_kwargs = mock_create.call_args.kwargs
  84. # New spools should ALWAYS include remaining_weight
  85. assert "remaining_weight" in call_kwargs
  86. assert call_kwargs["remaining_weight"] == 500.0 # 50% of 1000g
  87. @pytest.mark.asyncio
  88. async def test_sync_ams_tray_location_format(self, client, sample_tray, existing_spool):
  89. """Verify location format is correct when updating spool."""
  90. with (
  91. patch.object(client, "find_spool_by_tag", AsyncMock(return_value=existing_spool)),
  92. patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})) as mock_update,
  93. ):
  94. await client.sync_ams_tray(sample_tray, "My Printer", disable_weight_sync=True)
  95. call_kwargs = mock_update.call_args.kwargs
  96. # Location should follow pattern: "PrinterName - AMS A1"
  97. assert "location" in call_kwargs
  98. assert "My Printer" in call_kwargs["location"]
  99. assert "AMS" in call_kwargs["location"]
  100. @pytest.mark.asyncio
  101. async def test_sync_ams_tray_skips_non_bambu_spool(self, client):
  102. """Verify non-Bambu Lab spools are skipped."""
  103. # Third-party spool without proper identifiers
  104. tray = AMSTray(
  105. ams_id=0,
  106. tray_id=0,
  107. tray_type="PLA",
  108. tray_sub_brands="Third Party PLA",
  109. tray_color="FF0000FF",
  110. remain=50,
  111. tag_uid="",
  112. tray_uuid="",
  113. tray_info_idx="", # No Bambu Lab preset ID
  114. tray_weight=1000,
  115. )
  116. result = await client.sync_ams_tray(tray, "TestPrinter")
  117. assert result is None
  118. @pytest.mark.asyncio
  119. async def test_sync_ams_tray_weight_calculation(self, client, existing_spool):
  120. """Verify remaining weight is calculated correctly for various percentages."""
  121. test_cases = [
  122. (100, 1000, 1000.0), # Full spool
  123. (50, 1000, 500.0), # Half spool
  124. (25, 1000, 250.0), # Quarter spool
  125. (0, 1000, 0.0), # Empty spool
  126. (75, 500, 375.0), # Different spool weight
  127. ]
  128. for remain, weight, expected in test_cases:
  129. tray = AMSTray(
  130. ams_id=0,
  131. tray_id=0,
  132. tray_type="PLA",
  133. tray_sub_brands="PLA Basic",
  134. tray_color="FF0000FF",
  135. remain=remain,
  136. tag_uid="",
  137. tray_uuid="A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4",
  138. tray_info_idx="GFA00",
  139. tray_weight=weight,
  140. )
  141. with (
  142. patch.object(client, "find_spool_by_tag", AsyncMock(return_value=existing_spool)),
  143. patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})) as mock_update,
  144. ):
  145. await client.sync_ams_tray(tray, "TestPrinter", disable_weight_sync=False)
  146. call_kwargs = mock_update.call_args.kwargs
  147. assert call_kwargs["remaining_weight"] == expected, (
  148. f"Expected {expected}g for {remain}% of {weight}g, got {call_kwargs['remaining_weight']}"
  149. )
  150. # ========================================================================
  151. # Tests for caching functionality
  152. # ========================================================================
  153. @pytest.mark.asyncio
  154. async def test_find_spool_by_tag_with_cached_spools(self, client):
  155. """Verify find_spool_by_tag uses cached spools when provided (no API call)."""
  156. cached = [
  157. {"id": 1, "extra": {"tag": '"ABC123"'}},
  158. {"id": 2, "extra": {"tag": '"XYZ789"'}},
  159. ]
  160. with patch.object(client, "get_spools", AsyncMock()) as mock_get:
  161. result = await client.find_spool_by_tag("ABC123", cached_spools=cached)
  162. assert result["id"] == 1
  163. mock_get.assert_not_called() # Should NOT call get_spools
  164. @pytest.mark.asyncio
  165. async def test_find_spool_by_tag_without_cached_spools(self, client):
  166. """Verify find_spool_by_tag fetches spools when cache not provided."""
  167. mock_spools = [{"id": 1, "extra": {"tag": '"ABC123"'}}]
  168. with patch.object(client, "get_spools", AsyncMock(return_value=mock_spools)) as mock_get:
  169. result = await client.find_spool_by_tag("ABC123")
  170. assert result["id"] == 1
  171. mock_get.assert_called_once() # Should call get_spools
  172. @pytest.mark.asyncio
  173. async def test_find_spools_by_location_prefix_with_cached_spools(self, client):
  174. """Verify find_spools_by_location_prefix uses cached spools when provided."""
  175. cached = [
  176. {"id": 1, "location": "Printer1 - AMS A1"},
  177. {"id": 2, "location": "Printer2 - AMS A1"},
  178. {"id": 3, "location": "Printer1 - AMS A2"},
  179. ]
  180. with patch.object(client, "get_spools", AsyncMock()) as mock_get:
  181. result = await client.find_spools_by_location_prefix("Printer1 - ", cached_spools=cached)
  182. assert len(result) == 2
  183. assert result[0]["id"] == 1
  184. assert result[1]["id"] == 3
  185. mock_get.assert_not_called() # Should NOT call get_spools
  186. @pytest.mark.asyncio
  187. async def test_sync_ams_tray_with_cached_spools(self, client, sample_tray, existing_spool):
  188. """Verify sync_ams_tray passes cached_spools to find_spool_by_tag."""
  189. cached = [existing_spool]
  190. with (
  191. patch.object(client, "get_spools", AsyncMock()) as mock_get,
  192. patch.object(client, "update_spool", AsyncMock(return_value={"id": 42})),
  193. ):
  194. await client.sync_ams_tray(sample_tray, "TestPrinter", cached_spools=cached)
  195. mock_get.assert_not_called() # Should NOT call get_spools
  196. @pytest.mark.asyncio
  197. async def test_clear_location_for_removed_spools_with_cached_spools(self, client):
  198. """Verify clear_location_for_removed_spools uses cached spools."""
  199. cached = [
  200. {"id": 1, "location": "Printer1 - AMS A1", "extra": {"tag": '"TAG1"'}},
  201. {"id": 2, "location": "Printer1 - AMS A2", "extra": {"tag": '"TAG2"'}},
  202. {"id": 3, "location": "Printer1 - AMS A3", "extra": {"tag": '"TAG3"'}},
  203. ]
  204. current_tags = {"TAG1", "TAG2"} # TAG3 was removed
  205. with (
  206. patch.object(client, "get_spools", AsyncMock()) as mock_get,
  207. patch.object(client, "update_spool", AsyncMock(return_value={"id": 3})) as mock_update,
  208. ):
  209. cleared = await client.clear_location_for_removed_spools("Printer1", current_tags, cached_spools=cached)
  210. assert cleared == 1
  211. mock_get.assert_not_called() # Should NOT call get_spools
  212. mock_update.assert_called_once()
  213. # Verify it cleared TAG3 (not in current_tags)
  214. call_kwargs = mock_update.call_args.kwargs
  215. assert call_kwargs["spool_id"] == 3
  216. assert call_kwargs.get("clear_location") is True
  217. # ========================================================================
  218. # Tests for retry logic in get_spools
  219. # ========================================================================
  220. @pytest.mark.asyncio
  221. async def test_get_spools_succeeds_on_first_attempt(self, client):
  222. """Verify get_spools succeeds immediately when no errors occur."""
  223. mock_spools = [{"id": 1}, {"id": 2}]
  224. with patch.object(client, "_get_client") as mock_get_client:
  225. mock_http_client = AsyncMock()
  226. mock_response = Mock()
  227. mock_response.raise_for_status = Mock()
  228. mock_response.json = Mock(return_value=mock_spools)
  229. mock_http_client.get = AsyncMock(return_value=mock_response)
  230. mock_get_client.return_value = mock_http_client
  231. result = await client.get_spools()
  232. assert result == mock_spools
  233. mock_get_client.assert_called_once()
  234. mock_http_client.get.assert_called_once()
  235. @pytest.mark.asyncio
  236. async def test_get_spools_retries_on_connection_error(self, client):
  237. """Verify get_spools retries up to 3 times on connection errors."""
  238. import httpx
  239. mock_spools = [{"id": 1}]
  240. with (
  241. patch.object(client, "_get_client") as mock_get_client,
  242. patch.object(client, "close", AsyncMock()) as mock_close,
  243. patch("asyncio.sleep", AsyncMock()) as mock_sleep,
  244. ):
  245. mock_http_client = AsyncMock()
  246. mock_get_client.return_value = mock_http_client
  247. # First 2 attempts fail with ReadError, 3rd succeeds
  248. mock_response = Mock()
  249. mock_response.raise_for_status = Mock()
  250. mock_response.json = Mock(return_value=mock_spools)
  251. mock_http_client.get = AsyncMock(
  252. side_effect=[
  253. httpx.ReadError("Connection closed"),
  254. httpx.ReadError("Connection closed"),
  255. mock_response,
  256. ]
  257. )
  258. result = await client.get_spools()
  259. assert result == mock_spools
  260. assert mock_get_client.call_count == 3
  261. assert mock_http_client.get.call_count == 3
  262. # Should close client twice (after each failed attempt)
  263. assert mock_close.call_count == 2
  264. # Should sleep twice (after first 2 attempts)
  265. assert mock_sleep.call_count == 2
  266. mock_sleep.assert_called_with(0.5)
  267. @pytest.mark.asyncio
  268. async def test_get_spools_raises_after_3_failed_attempts(self, client):
  269. """Verify get_spools raises exception after 3 failed attempts."""
  270. import httpx
  271. with (
  272. patch.object(client, "_get_client", AsyncMock()) as mock_get_client,
  273. patch.object(client, "close", AsyncMock()) as mock_close,
  274. patch("asyncio.sleep", AsyncMock()) as mock_sleep,
  275. ):
  276. mock_http_client = AsyncMock()
  277. mock_get_client.return_value = mock_http_client
  278. # All 3 attempts fail
  279. mock_http_client.get.side_effect = httpx.ReadError("Connection closed")
  280. with pytest.raises(httpx.ReadError):
  281. await client.get_spools()
  282. assert mock_get_client.call_count == 3
  283. assert mock_http_client.get.call_count == 3
  284. # Should close client twice (after first 2 failed attempts, not after 3rd)
  285. assert mock_close.call_count == 2
  286. # Should sleep twice (after first 2 attempts, not after 3rd)
  287. assert mock_sleep.call_count == 2
  288. @pytest.mark.asyncio
  289. async def test_get_spools_handles_non_connection_errors(self, client):
  290. """Verify get_spools retries on non-connection errors without recreating client."""
  291. import httpx
  292. mock_spools = [{"id": 1}]
  293. with (
  294. patch.object(client, "_get_client") as mock_get_client,
  295. patch.object(client, "close", AsyncMock()) as mock_close,
  296. patch("asyncio.sleep", AsyncMock()) as mock_sleep,
  297. ):
  298. mock_http_client = AsyncMock()
  299. mock_get_client.return_value = mock_http_client
  300. # First attempt fails with HTTP error, 2nd succeeds
  301. mock_response_error = Mock()
  302. mock_response_error.raise_for_status = Mock(
  303. side_effect=httpx.HTTPStatusError("500 Server Error", request=Mock(), response=Mock())
  304. )
  305. mock_response_success = Mock()
  306. mock_response_success.raise_for_status = Mock()
  307. mock_response_success.json = Mock(return_value=mock_spools)
  308. mock_http_client.get = AsyncMock(side_effect=[mock_response_error, mock_response_success])
  309. result = await client.get_spools()
  310. assert result == mock_spools
  311. assert mock_get_client.call_count == 2
  312. # Should NOT close client for HTTP errors (only connection errors)
  313. mock_close.assert_not_called()
  314. # Should sleep once (after first failed attempt)
  315. assert mock_sleep.call_count == 1