Browse Source

1. ArchiveService import shadowing bug (main.py:978) - Local import was shadowing module-level import, causing "cannot access local variable" error
2. Timelapse race condition (bambu_mqtt.py) - xcam data was parsed before _was_running was set, so timelapse wasn't detected at print start
3. Added exception handling (printer_manager.py) - _schedule_async now logs exceptions instead of swallowing them silently
4. Added debug logging (main.py) - [CALLBACK] logs at key decision points to help debug future issues
5. Added code quality tests (test_code_quality.py) - Static analysis to catch import shadowing bugs automatically

maziggy 5 months ago
parent
commit
ae55884858
3 changed files with 355 additions and 3 deletions
  1. 9 3
      backend/app/main.py
  2. 79 0
      backend/tests/conftest.py
  3. 267 0
      backend/tests/unit/test_code_quality.py

+ 9 - 3
backend/app/main.py

@@ -329,6 +329,8 @@ async def on_print_start(printer_id: int, data: dict):
     import logging
     import logging
     logger = logging.getLogger(__name__)
     logger = logging.getLogger(__name__)
 
 
+    logger.info(f"[CALLBACK] on_print_start called for printer {printer_id}, data keys: {list(data.keys())}")
+
     await ws_manager.send_print_start(printer_id, data)
     await ws_manager.send_print_start(printer_id, data)
 
 
     # Track if notification was sent (to avoid sending twice)
     # Track if notification was sent (to avoid sending twice)
@@ -352,6 +354,7 @@ async def on_print_start(printer_id: int, data: dict):
 
 
         if not printer or not printer.auto_archive:
         if not printer or not printer.auto_archive:
             # Send notification without archive data (auto-archive disabled)
             # Send notification without archive data (auto-archive disabled)
+            logger.info(f"[CALLBACK] Skipping archive - printer: {printer is not None}, auto_archive: {printer.auto_archive if printer else 'N/A'}")
             if not notification_sent:
             if not notification_sent:
                 await _send_print_start_notification(printer_id, data, logger=logger)
                 await _send_print_start_notification(printer_id, data, logger=logger)
             return
             return
@@ -360,10 +363,11 @@ async def on_print_start(printer_id: int, data: dict):
         filename = data.get("filename", "")
         filename = data.get("filename", "")
         subtask_name = data.get("subtask_name", "")
         subtask_name = data.get("subtask_name", "")
 
 
-        logger.info(f"Print start detected - filename: {filename}, subtask: {subtask_name}")
+        logger.info(f"[CALLBACK] Print start detected - filename: {filename}, subtask: {subtask_name}")
 
 
         if not filename and not subtask_name:
         if not filename and not subtask_name:
             # Send notification without archive data (no filename)
             # Send notification without archive data (no filename)
+            logger.info(f"[CALLBACK] Skipping archive - no filename or subtask_name")
             if not notification_sent:
             if not notification_sent:
                 await _send_print_start_notification(printer_id, data, logger=logger)
                 await _send_print_start_notification(printer_id, data, logger=logger)
             return
             return
@@ -975,13 +979,15 @@ async def on_print_complete(printer_id: int, data: dict):
             async with async_session() as db:
             async with async_session() as db:
                 from backend.app.models.printer import Printer
                 from backend.app.models.printer import Printer
                 from backend.app.models.archive import PrintArchive
                 from backend.app.models.archive import PrintArchive
-                from backend.app.services.archive import ArchiveService
+                # NOTE: ArchiveService is imported at module level (line 67)
+                # Do NOT import it here - it causes a Python scoping issue that breaks
+                # the earlier usage of ArchiveService in this function
                 from backend.app.services.bambu_ftp import list_files_async, download_file_bytes_async
                 from backend.app.services.bambu_ftp import list_files_async, download_file_bytes_async
                 from pathlib import Path
                 from pathlib import Path
                 import re
                 import re
                 from datetime import timedelta
                 from datetime import timedelta
 
 
-                # Get archive
+                # Get archive (ArchiveService from module-level import)
                 service = ArchiveService(db)
                 service = ArchiveService(db)
                 archive = await service.get_archive(archive_id)
                 archive = await service.get_archive(archive_id)
                 if not archive:
                 if not archive:

+ 79 - 0
backend/tests/conftest.py

@@ -2,6 +2,7 @@
 
 
 import asyncio
 import asyncio
 import json
 import json
+import logging
 import os
 import os
 import sys
 import sys
 import pytest
 import pytest
@@ -378,3 +379,81 @@ def sample_printer_status():
         "remaining_time": 0,
         "remaining_time": 0,
         "filename": None,
         "filename": None,
     }
     }
+
+
+# ============================================================================
+# Log Capture Fixtures for Error Detection
+# ============================================================================
+
+class LogCapture(logging.Handler):
+    """Handler that captures log records for testing."""
+
+    def __init__(self):
+        super().__init__()
+        self.records: list[logging.LogRecord] = []
+
+    def emit(self, record: logging.LogRecord):
+        self.records.append(record)
+
+    def clear(self):
+        self.records.clear()
+
+    def get_errors(self) -> list[logging.LogRecord]:
+        """Get all ERROR and CRITICAL level records."""
+        return [r for r in self.records if r.levelno >= logging.ERROR]
+
+    def get_warnings(self) -> list[logging.LogRecord]:
+        """Get all WARNING level records."""
+        return [r for r in self.records if r.levelno == logging.WARNING]
+
+    def has_errors(self) -> bool:
+        """Check if any errors were logged."""
+        return len(self.get_errors()) > 0
+
+    def format_errors(self) -> str:
+        """Format all errors as a string for assertion messages."""
+        errors = self.get_errors()
+        if not errors:
+            return "No errors"
+        formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
+        return "\n".join(formatter.format(r) for r in errors)
+
+
+@pytest.fixture
+def capture_logs():
+    """Fixture that captures log output during a test.
+
+    Usage:
+        def test_something(capture_logs):
+            # Do something that might log errors
+            some_function()
+
+            # Check no errors were logged
+            assert not capture_logs.has_errors(), capture_logs.format_errors()
+    """
+    handler = LogCapture()
+    handler.setLevel(logging.DEBUG)
+
+    # Attach to root logger to capture all logs
+    root_logger = logging.getLogger()
+    root_logger.addHandler(handler)
+
+    yield handler
+
+    root_logger.removeHandler(handler)
+
+
+@pytest.fixture
+def assert_no_log_errors(capture_logs):
+    """Fixture that automatically asserts no errors were logged.
+
+    Usage:
+        def test_something(assert_no_log_errors):
+            # If any ERROR logs occur during this test, it will fail
+            some_function()
+    """
+    yield capture_logs
+
+    errors = capture_logs.get_errors()
+    if errors:
+        pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")

+ 267 - 0
backend/tests/unit/test_code_quality.py

@@ -0,0 +1,267 @@
+"""
+Code quality tests for BamBuddy backend.
+
+These tests check for common anti-patterns and code quality issues
+that could cause runtime errors but aren't caught by normal tests.
+"""
+
+import ast
+import os
+import pytest
+from pathlib import Path
+
+
+# Get the backend source directory
+BACKEND_DIR = Path(__file__).parent.parent.parent / "app"
+
+
+# Safe imports that are commonly re-imported in functions without issues
+# These are typically imported at the START of a function, not midway through
+SAFE_REIMPORT_NAMES = {
+    'logging', 're', 'os', 'sys', 'json', 'Path', 'datetime', 'timedelta',
+    'asyncio', 'time', 'typing', 'Optional', 'List', 'Dict', 'Any', 'Union',
+}
+
+
+class DangerousImportVisitor(ast.NodeVisitor):
+    """AST visitor that detects dangerous import patterns.
+
+    Specifically looks for cases where:
+    1. A name is imported at module level
+    2. The same name is imported locally in a function
+    3. The name is USED before the local import in that function
+
+    This pattern causes 'cannot access local variable' errors.
+    """
+
+    def __init__(self):
+        self.module_imports: set[str] = set()
+        self.dangerous_imports: list[tuple[str, int, str, int]] = []  # (name, import_line, function, first_use_line)
+        self.current_function: str | None = None
+        self.function_start_line: int = 0
+        self.in_function = False
+
+    def visit_Import(self, node: ast.Import):
+        for alias in node.names:
+            name = alias.asname or alias.name
+            if not self.in_function:
+                self.module_imports.add(name)
+        self.generic_visit(node)
+
+    def visit_ImportFrom(self, node: ast.ImportFrom):
+        for alias in node.names:
+            name = alias.asname or alias.name
+            if not self.in_function:
+                self.module_imports.add(name)
+        self.generic_visit(node)
+
+    def _check_function(self, node):
+        """Check a function for dangerous import patterns."""
+        if not self.in_function:
+            return
+
+        # Skip safe reimports
+        # Collect all local imports in this function
+        local_imports: dict[str, int] = {}  # name -> line number
+        name_uses: dict[str, int] = {}  # name -> first use line number
+
+        for child in ast.walk(node):
+            # Find local imports
+            if isinstance(child, (ast.Import, ast.ImportFrom)):
+                if isinstance(child, ast.Import):
+                    for alias in child.names:
+                        name = alias.asname or alias.name
+                        if name in self.module_imports and name not in SAFE_REIMPORT_NAMES:
+                            local_imports[name] = child.lineno
+                elif isinstance(child, ast.ImportFrom):
+                    for alias in child.names:
+                        name = alias.asname or alias.name
+                        if name in self.module_imports and name not in SAFE_REIMPORT_NAMES:
+                            local_imports[name] = child.lineno
+
+            # Find name uses
+            if isinstance(child, ast.Name):
+                if child.id not in name_uses:
+                    name_uses[child.id] = child.lineno
+
+        # Check for dangerous pattern: use before import
+        for name, import_line in local_imports.items():
+            if name in name_uses:
+                first_use = name_uses[name]
+                if first_use < import_line:
+                    self.dangerous_imports.append((name, import_line, self.current_function, first_use))
+
+    def visit_FunctionDef(self, node: ast.FunctionDef):
+        old_function = self.current_function
+        old_in_function = self.in_function
+        old_start_line = self.function_start_line
+
+        self.current_function = node.name
+        self.in_function = True
+        self.function_start_line = node.lineno
+
+        self._check_function(node)
+        self.generic_visit(node)
+
+        self.current_function = old_function
+        self.in_function = old_in_function
+        self.function_start_line = old_start_line
+
+    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
+        old_function = self.current_function
+        old_in_function = self.in_function
+        old_start_line = self.function_start_line
+
+        self.current_function = node.name
+        self.in_function = True
+        self.function_start_line = node.lineno
+
+        self._check_function(node)
+        self.generic_visit(node)
+
+        self.current_function = old_function
+        self.in_function = old_in_function
+        self.function_start_line = old_start_line
+
+
+def find_import_shadowing(file_path: Path) -> list[tuple[str, int, str]]:
+    """Find cases where local imports shadow module-level imports AND are used before import.
+
+    Returns list of (name, line_number, function_name) tuples.
+    """
+    try:
+        with open(file_path, 'r') as f:
+            source = f.read()
+        tree = ast.parse(source)
+        visitor = DangerousImportVisitor()
+        visitor.visit(tree)
+        # Convert (name, import_line, function, first_use_line) to (name, import_line, function)
+        return [(name, import_line, func) for name, import_line, func, _ in visitor.dangerous_imports]
+    except SyntaxError:
+        return []  # Skip files with syntax errors
+
+
+def get_python_files(directory: Path) -> list[Path]:
+    """Get all Python files in a directory recursively."""
+    return list(directory.rglob("*.py"))
+
+
+class TestImportShadowing:
+    """Tests for import shadowing anti-pattern."""
+
+    def test_no_import_shadowing_in_main(self):
+        """Check main.py has no import shadowing issues.
+
+        This test would have caught the ArchiveService scoping bug.
+        """
+        main_file = BACKEND_DIR / "main.py"
+        if not main_file.exists():
+            pytest.skip("main.py not found")
+
+        shadows = find_import_shadowing(main_file)
+
+        if shadows:
+            error_msg = "Import shadowing detected in main.py:\n"
+            for name, line, func in shadows:
+                error_msg += f"  - '{name}' at line {line} in function '{func}' shadows module-level import\n"
+            error_msg += "\nThis can cause 'cannot access local variable' errors."
+            pytest.fail(error_msg)
+
+    def test_no_import_shadowing_in_services(self):
+        """Check service files have no import shadowing issues."""
+        services_dir = BACKEND_DIR / "services"
+        if not services_dir.exists():
+            pytest.skip("services directory not found")
+
+        all_shadows = []
+        for py_file in get_python_files(services_dir):
+            shadows = find_import_shadowing(py_file)
+            for name, line, func in shadows:
+                all_shadows.append((py_file.name, name, line, func))
+
+        if all_shadows:
+            error_msg = "Import shadowing detected in services:\n"
+            for filename, name, line, func in all_shadows:
+                error_msg += f"  - {filename}: '{name}' at line {line} in function '{func}'\n"
+            pytest.fail(error_msg)
+
+    def test_no_import_shadowing_in_routes(self):
+        """Check route files have no import shadowing issues."""
+        routes_dir = BACKEND_DIR / "api" / "routes"
+        if not routes_dir.exists():
+            pytest.skip("routes directory not found")
+
+        all_shadows = []
+        for py_file in get_python_files(routes_dir):
+            shadows = find_import_shadowing(py_file)
+            for name, line, func in shadows:
+                all_shadows.append((py_file.name, name, line, func))
+
+        if all_shadows:
+            error_msg = "Import shadowing detected in routes:\n"
+            for filename, name, line, func in all_shadows:
+                error_msg += f"  - {filename}: '{name}' at line {line} in function '{func}'\n"
+            pytest.fail(error_msg)
+
+
+class TestModuleImports:
+    """Tests for module import health."""
+
+    def test_all_modules_importable(self):
+        """Verify all Python modules can be imported without errors.
+
+        This catches syntax errors and missing dependencies.
+        """
+        import importlib
+        import sys
+
+        # Modules to test importing
+        modules = [
+            "backend.app.main",
+            "backend.app.services.bambu_mqtt",
+            "backend.app.services.printer_manager",
+            "backend.app.services.archive",
+            "backend.app.services.notification_service",
+            "backend.app.services.smart_plug_manager",
+        ]
+
+        errors = []
+        for module_name in modules:
+            try:
+                # Remove from cache first to ensure fresh import
+                if module_name in sys.modules:
+                    del sys.modules[module_name]
+                importlib.import_module(module_name)
+            except Exception as e:
+                errors.append(f"{module_name}: {type(e).__name__}: {e}")
+
+        if errors:
+            pytest.fail("Failed to import modules:\n" + "\n".join(errors))
+
+
+class TestLogErrorPatterns:
+    """Tests that use log capture to detect runtime errors."""
+
+    def test_mqtt_message_processing_no_errors(self, capture_logs):
+        """Test that MQTT message processing doesn't log errors."""
+        from backend.app.services.bambu_mqtt import BambuMQTTClient
+
+        client = BambuMQTTClient(
+            ip_address="192.168.1.100",
+            serial_number="TEST123",
+            access_code="12345678",
+        )
+        client.on_print_start = lambda data: None
+        client.on_print_complete = lambda data: None
+
+        # Process a realistic print lifecycle
+        messages = [
+            {"print": {"gcode_state": "RUNNING", "gcode_file": "/test.gcode", "subtask_name": "Test"}},
+            {"print": {"gcode_state": "RUNNING", "gcode_file": "/test.gcode", "mc_percent": 50}},
+            {"print": {"gcode_state": "FINISH", "gcode_file": "/test.gcode", "subtask_name": "Test"}},
+        ]
+
+        for msg in messages:
+            client._process_message(msg)
+
+        assert not capture_logs.has_errors(), f"Errors during MQTT processing:\n{capture_logs.format_errors()}"