Parcourir la source

Refactor functions in cost statistics integration tests

Matteo Parenti il y a 3 mois
Parent
commit
1dd266b66e
2 fichiers modifiés avec 25 ajouts et 11 suppressions
  1. 3 0
      backend/tests/conftest.py
  2. 22 11
      backend/tests/integration/test_cost_statistics.py

+ 3 - 0
backend/tests/conftest.py

@@ -75,6 +75,9 @@ async def test_engine():
         project,
         project,
         settings,
         settings,
         smart_plug,
         smart_plug,
+        spool,
+        spool_assignment,
+        spool_usage_history,
         user,
         user,
     )
     )
 
 

+ 22 - 11
backend/tests/integration/test_cost_statistics.py

@@ -23,7 +23,7 @@ class TestArchiveCostTracking:
     async def test_archive_has_cost_field(
     async def test_archive_has_cost_field(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify PrintArchive includes cost field in response."""
+        # Verify PrintArchive includes cost field in response.
         printer = await printer_factory()
         printer = await printer_factory()
         archive = await archive_factory(
         archive = await archive_factory(
             printer.id,
             printer.id,
@@ -38,13 +38,14 @@ class TestArchiveCostTracking:
         result = response.json()
         result = response.json()
         assert "cost" in result
         assert "cost" in result
         assert result["cost"] == 5.50
         assert result["cost"] == 5.50
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_archive_cost_null_when_not_set(
     async def test_archive_cost_null_when_not_set(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify cost is null when not set."""
+        # Verify cost is null when not set.
         printer = await printer_factory()
         printer = await printer_factory()
         archive = await archive_factory(
         archive = await archive_factory(
             printer.id,
             printer.id,
@@ -58,6 +59,7 @@ class TestArchiveCostTracking:
         assert response.status_code == 200
         assert response.status_code == 200
         result = response.json()
         result = response.json()
         assert result["cost"] is None or result["cost"] == 0
         assert result["cost"] is None or result["cost"] == 0
+        await db_session.rollback()
 
 
 
 
 class TestStatisticsCostAggregation:
 class TestStatisticsCostAggregation:
@@ -68,7 +70,7 @@ class TestStatisticsCostAggregation:
     async def test_statistics_includes_total_cost(
     async def test_statistics_includes_total_cost(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify statistics endpoint includes total_cost field."""
+        # Verify statistics endpoint includes total_cost field.
         printer = await printer_factory()
         printer = await printer_factory()
 
 
         # Create archives with costs
         # Create archives with costs
@@ -91,13 +93,14 @@ class TestStatisticsCostAggregation:
         result = response.json()
         result = response.json()
         assert "total_cost" in result
         assert "total_cost" in result
         assert result["total_cost"] == 6.25
         assert result["total_cost"] == 6.25
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_statistics_aggregates_costs_correctly(
     async def test_statistics_aggregates_costs_correctly(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify statistics correctly sums costs from all archives."""
+        # Verify statistics correctly sums costs from all archives.
         printer = await printer_factory()
         printer = await printer_factory()
 
 
         # Create multiple archives with different costs
         # Create multiple archives with different costs
@@ -116,13 +119,14 @@ class TestStatisticsCostAggregation:
         result = response.json()
         result = response.json()
         expected_total = sum(costs)
         expected_total = sum(costs)
         assert result["total_cost"] == expected_total
         assert result["total_cost"] == expected_total
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_statistics_handles_null_costs(
     async def test_statistics_handles_null_costs(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify statistics handles archives with null costs gracefully."""
+        # Verify statistics handles archives with null costs gracefully.
         printer = await printer_factory()
         printer = await printer_factory()
 
 
         # Mix of archives with and without costs
         # Mix of archives with and without costs
@@ -137,13 +141,14 @@ class TestStatisticsCostAggregation:
         result = response.json()
         result = response.json()
         # Should sum only non-null costs
         # Should sum only non-null costs
         assert result["total_cost"] == 4.25
         assert result["total_cost"] == 4.25
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_statistics_includes_failed_print_costs(
     async def test_statistics_includes_failed_print_costs(
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
         self, async_client: AsyncClient, archive_factory, printer_factory, db_session
     ):
     ):
-        """Verify failed prints with costs are included in statistics."""
+        # Verify failed prints with costs are included in statistics.
         printer = await printer_factory()
         printer = await printer_factory()
 
 
         await archive_factory(printer.id, status="completed", cost=5.00)
         await archive_factory(printer.id, status="completed", cost=5.00)
@@ -156,6 +161,7 @@ class TestStatisticsCostAggregation:
         result = response.json()
         result = response.json()
         # All prints should contribute to total cost
         # All prints should contribute to total cost
         assert result["total_cost"] == 8.50
         assert result["total_cost"] == 8.50
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
@@ -174,7 +180,7 @@ class TestSpoolCostPersistence:
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_spool_cost_fields_persist(self, async_client: AsyncClient, db_session):
     async def test_spool_cost_fields_persist(self, async_client: AsyncClient, db_session):
-        """Verify cost_per_kg is saved and retrieved."""
+        # Verify cost_per_kg is saved and retrieved.
         # Create a spool with cost
         # Create a spool with cost
         spool_data = {
         spool_data = {
             "material": "PLA",
             "material": "PLA",
@@ -194,11 +200,12 @@ class TestSpoolCostPersistence:
         result = get_response.json()
         result = get_response.json()
 
 
         assert result["cost_per_kg"] == 25.50
         assert result["cost_per_kg"] == 25.50
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_spool_update_cost_fields(self, async_client: AsyncClient, db_session):
     async def test_spool_update_cost_fields(self, async_client: AsyncClient, db_session):
-        """Verify cost fields can be updated."""
+        # Verify cost fields can be updated.
         # Create spool without cost
         # Create spool without cost
         spool_data = {
         spool_data = {
             "material": "PETG",
             "material": "PETG",
@@ -221,11 +228,12 @@ class TestSpoolCostPersistence:
 
 
         result = update_response.json()
         result = update_response.json()
         assert result["cost_per_kg"] == 30.00
         assert result["cost_per_kg"] == 30.00
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_spool_cost_null_by_default(self, async_client: AsyncClient, db_session):
     async def test_spool_cost_null_by_default(self, async_client: AsyncClient, db_session):
-        """Verify cost_per_kg defaults to null when not provided."""
+        # Verify cost_per_kg defaults to null when not provided.
         spool_data = {
         spool_data = {
             "material": "ABS",
             "material": "ABS",
             "label_weight": 1000,
             "label_weight": 1000,
@@ -237,6 +245,7 @@ class TestSpoolCostPersistence:
 
 
         result = create_response.json()
         result = create_response.json()
         assert result["cost_per_kg"] is None
         assert result["cost_per_kg"] is None
+        await db_session.rollback()
 
 
 
 
 class TestCostCalculationScenarios:
 class TestCostCalculationScenarios:
@@ -245,7 +254,7 @@ class TestCostCalculationScenarios:
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_cost_with_multiple_colors(self, async_client: AsyncClient, printer_factory, db_session):
     async def test_cost_with_multiple_colors(self, async_client: AsyncClient, printer_factory, db_session):
-        """Verify cost tracking works for multi-color prints."""
+        # Verify cost tracking works for multi-color prints.
 
 
         # Create two spools with different costs
         # Create two spools with different costs
         spool1_data = {
         spool1_data = {
@@ -271,11 +280,12 @@ class TestCostCalculationScenarios:
         # Verify spools created with correct costs
         # Verify spools created with correct costs
         assert spool1_response.json()["cost_per_kg"] == 20.00
         assert spool1_response.json()["cost_per_kg"] == 20.00
         assert spool2_response.json()["cost_per_kg"] == 25.00
         assert spool2_response.json()["cost_per_kg"] == 25.00
+        await db_session.rollback()
 
 
     @pytest.mark.asyncio
     @pytest.mark.asyncio
     @pytest.mark.integration
     @pytest.mark.integration
     async def test_cost_precision(self, async_client: AsyncClient, db_session):
     async def test_cost_precision(self, async_client: AsyncClient, db_session):
-        """Verify cost calculations maintain proper precision."""
+        # Verify cost calculations maintain proper precision.
         # Create spool with specific cost
         # Create spool with specific cost
         spool_data = {
         spool_data = {
             "material": "PLA",
             "material": "PLA",
@@ -291,3 +301,4 @@ class TestCostCalculationScenarios:
         result = response.json()
         result = response.json()
         # Verify precision is maintained
         # Verify precision is maintained
         assert result["cost_per_kg"] == 19.99
         assert result["cost_per_kg"] == 19.99
+        await db_session.rollback()