"""Shared test fixtures for BamBuddy backend tests."""
- import asyncio
- import atexit
- import json
- import logging
- import os
- import shutil
- import tempfile
- from collections.abc import AsyncGenerator
- from pathlib import Path
- from unittest.mock import AsyncMock, MagicMock, patch
- import pytest
# IMPORTANT: these environment variables must be set BEFORE any app import,
# i.e. before settings/config are loaded.
os.environ.update({"LOG_TO_FILE": "false", "DEBUG": "false"})
- from httpx import ASGITransport, AsyncClient # noqa: E402
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine # noqa: E402
- # Ensure settings use our env vars - import and override before database import
- from backend.app.core.config import settings # noqa: E402
settings.log_to_file = False

# Plate calibration is redirected to a throwaway temp directory so the test
# run can never delete real calibration files.
_test_plate_cal_dir = settings.plate_calibration_dir = Path(
    tempfile.mkdtemp(prefix="bambuddy_test_plate_cal_")
)
# Registered with atexit so the temp directory is removed when tests finish.
@atexit.register
def _cleanup_test_plate_cal_dir():
    """Delete the temporary plate-calibration directory, ignoring removal errors."""
    if not _test_plate_cal_dir.exists():
        return
    shutil.rmtree(_test_plate_cal_dir, ignore_errors=True)
- from backend.app.core.database import Base # noqa: E402
# Tests run against an in-memory SQLite database (aiosqlite driver).
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture(scope="session")
def event_loop():
    """Create an instance of the default event loop for each test session.

    Yields:
        A new event loop obtained from the current event loop policy.

    On teardown the module-level database engine is disposed on this loop and
    the loop is then closed. The close is guaranteed even when disposal fails.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    try:
        # Dispose the module-level engine so aiosqlite worker threads finish
        # before the event loop closes, preventing "Event loop is closed" errors.
        from backend.app.core.database import engine

        loop.run_until_complete(engine.dispose())
        # Short grace period after dispose, before the loop goes away.
        loop.run_until_complete(asyncio.sleep(0.05))
    finally:
        # Always close the loop, even if the engine import or dispose raised.
        loop.close()
@pytest.fixture
async def test_engine():
    """Create a test database engine.

    Builds an async engine for ``TEST_DATABASE_URL``, creates every table
    registered on ``Base.metadata``, yields the engine to the test, then drops
    the tables again. The engine is always disposed, even when table creation
    or removal raises.
    """
    engine = create_async_engine(TEST_DATABASE_URL, echo=False)

    # Import all models to register them. The names are unused on purpose;
    # the import itself is what matters.
    from backend.app.models import (  # noqa: F401
        ams_history,
        ams_label,
        api_key,
        archive,
        external_link,
        filament,
        group,
        kprofile_note,
        maintenance,
        notification,
        notification_template,
        print_queue,
        printer,
        project,
        settings,
        smart_plug,
        spool,
        spool_assignment,
        spool_usage_history,
        user,
    )

    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine

        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
    finally:
        await engine.dispose()
        # Allow aiosqlite's background thread to finish processing the close
        # response before the per-function event loop shuts down, preventing
        # "RuntimeError: Event loop is closed" in call_soon_threadsafe.
        await asyncio.sleep(0.1)
@pytest.fixture
async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the per-test engine."""
    session_factory = async_sessionmaker(
        bind=test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with session_factory() as session:
        yield session
@pytest.fixture
async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
    """Create an async test client wired to the test database.

    The FastAPI ``get_db`` dependency and the module-level ``async_session``
    factories used by the database, auth and main modules are redirected to a
    session maker bound to ``test_engine``. ``init_printer_connections`` is
    replaced by a no-op, and default groups are seeded before the client is
    yielded.

    Teardown always disposes the module-level engine and clears
    ``app.dependency_overrides``, even if setup fails part-way, so overrides
    never leak into later tests.
    """
    from backend.app.core.database import get_db
    from backend.app.main import app

    # Create a new session maker for the test engine
    test_async_session = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)

    async def override_get_db():
        async with test_async_session() as session:
            yield session

    # Mock init_printer_connections to prevent MQTT connection attempts during tests
    async def mock_init_printer_connections(db):
        pass  # No-op - don't connect to real printers

    app.dependency_overrides[get_db] = override_get_db
    try:
        # Also patch the module-level async_session used by services, auth, and middleware
        with (
            patch("backend.app.core.database.async_session", test_async_session),
            patch("backend.app.core.auth.async_session", test_async_session),
            patch("backend.app.main.async_session", test_async_session),
            patch("backend.app.main.init_printer_connections", mock_init_printer_connections),
        ):
            # Seed default groups for tests that need them
            from backend.app.core.database import seed_default_groups

            await seed_default_groups()

            async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
                yield client
    finally:
        try:
            # The app lifespan called init_db() which used the module-level engine
            # (not the test engine), creating aiosqlite connections. Dispose those
            # connections so their background threads finish before the event loop closes.
            from backend.app.core.database import engine as real_engine

            await real_engine.dispose()
        finally:
            app.dependency_overrides.clear()
- # ============================================================================
- # Mock External Services
- # ============================================================================
@pytest.fixture
def mock_tasmota_service():
    """Mock the Tasmota service for smart plug tests.

    The service object is patched both where it is defined and where the
    smart-plug routes import it. Both patched objects share the same stub
    methods, and the first one is yielded.
    """
    with (
        patch("backend.app.services.tasmota.tasmota_service") as mock,
        patch("backend.app.api.routes.smart_plugs.tasmota_service") as mock2,
    ):
        energy_reading = {
            "power": 150.5,
            "voltage": 120.0,
            "current": 1.25,
            "today": 2.5,
            "total": 100.0,
            "factor": 0.95,
        }
        stubs = {
            "turn_on": AsyncMock(return_value=True),
            "turn_off": AsyncMock(return_value=True),
            "toggle": AsyncMock(return_value=True),
            "get_status": AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test Plug"}),
            "get_energy": AsyncMock(return_value=energy_reading),
            "test_connection": AsyncMock(return_value={"success": True, "state": "ON", "device_name": "Test Plug"}),
        }
        # Install each stub on the primary mock first, then share the very
        # same object with the second patch target.
        for attr_name, stub in stubs.items():
            setattr(mock, attr_name, stub)
            setattr(mock2, attr_name, stub)

        yield mock
@pytest.fixture
def mock_homeassistant_service():
    """Mock the Home Assistant service for smart plug tests.

    The service object is patched both where it is defined and where the
    smart-plug routes import it. Both patched objects share the same stub
    methods, and the first one is yielded.
    """
    with (
        patch("backend.app.services.homeassistant.homeassistant_service") as mock,
        patch("backend.app.api.routes.smart_plugs.homeassistant_service") as mock2,
    ):
        entities = [
            {
                "entity_id": "switch.printer_plug",
                "friendly_name": "Printer Plug",
                "state": "on",
                "domain": "switch",
            },
            {"entity_id": "switch.test", "friendly_name": "Test Switch", "state": "off", "domain": "switch"},
        ]
        stubs = {
            "turn_on": AsyncMock(return_value=True),
            "turn_off": AsyncMock(return_value=True),
            "toggle": AsyncMock(return_value=True),
            "get_status": AsyncMock(
                return_value={"state": "ON", "reachable": True, "device_name": "Test HA Entity"}
            ),
            # Most HA entities don't have power monitoring
            "get_energy": AsyncMock(return_value=None),
            "test_connection": AsyncMock(return_value={"success": True, "message": "API running", "error": None}),
            "list_entities": AsyncMock(return_value=entities),
            "configure": MagicMock(),
        }
        # Install each stub on the primary mock first, then share the very
        # same object with the second patch target.
        for attr_name, stub in stubs.items():
            setattr(mock, attr_name, stub)
            setattr(mock2, attr_name, stub)

        yield mock
@pytest.fixture
def mock_mqtt_client():
    """Mock the MQTT client class for printer communication tests.

    Yields the patched ``BambuMQTTClient`` class. Calling it returns a
    MagicMock instance with a canned idle ``state`` and mocked
    ``connect``/``disconnect`` methods.
    """
    with patch("backend.app.services.bambu_mqtt.BambuMQTTClient") as mock:
        idle_state = MagicMock(
            connected=True,
            state="IDLE",
            progress=0,
            temperatures={"nozzle": 25, "bed": 25},
        )
        client = MagicMock()
        client.state = idle_state
        client.connect = MagicMock()
        client.disconnect = MagicMock()

        mock.return_value = client
        yield mock
@pytest.fixture
def mock_mqtt_smart_plug_service():
    """Mock the MQTT smart plug service for MQTT plug tests.

    Patches ``mqtt_relay`` as imported by the smart-plug routes and attaches a
    mocked ``smart_plug_service`` to it. Yields the patched ``mqtt_relay``.
    """
    with patch("backend.app.api.routes.smart_plugs.mqtt_relay") as mock:
        plug_service = MagicMock()
        service_stubs = {
            "is_configured": MagicMock(return_value=True),
            "has_broker_settings": MagicMock(return_value=True),
            "configure": AsyncMock(return_value=True),
            "subscribe": MagicMock(),
            "unsubscribe": MagicMock(),
            "get_plug_data": MagicMock(return_value=None),
            "is_reachable": MagicMock(return_value=False),
        }
        for attr_name, stub in service_stubs.items():
            setattr(plug_service, attr_name, stub)

        mock.smart_plug_service = plug_service
        yield mock
@pytest.fixture
def mock_ftp_client():
    """Mock the FTP helpers for file transfer tests.

    Yields a dict with the patched ``download_file_async`` under ``"download"``
    (returns ``True``) and ``list_files_async`` under ``"list"`` (returns an
    empty list).
    """
    download_patch = patch("backend.app.services.bambu_ftp.download_file_async")
    list_patch = patch("backend.app.services.bambu_ftp.list_files_async")
    with download_patch as download_mock, list_patch as list_mock:
        download_mock.return_value = True
        list_mock.return_value = []
        mocks = {}
        mocks["download"] = download_mock
        mocks["list"] = list_mock
        yield mocks
@pytest.fixture
def mock_httpx_client():
    """Mock httpx for webhook/notification HTTP calls.

    Patches ``httpx.AsyncClient`` so that constructing it returns one shared
    AsyncMock client, which is also what ``async with`` binds. ``get`` and
    ``post`` resolve to a canned response (status 200, text ``"OK"``,
    ``json()`` returning ``{}``).

    Yields:
        The mocked client instance.
    """
    with patch("httpx.AsyncClient") as mock_class:
        mock_instance = AsyncMock()
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.text = "OK"
        mock_response.json.return_value = {}
        mock_instance.get = AsyncMock(return_value=mock_response)
        mock_instance.post = AsyncMock(return_value=mock_response)
        mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
        # __aexit__ must resolve to a falsy value. A bare AsyncMock() resolves
        # to another (truthy) mock, which would make ``async with`` silently
        # swallow any exception raised inside its body.
        mock_instance.__aexit__ = AsyncMock(return_value=False)
        mock_class.return_value = mock_instance
        yield mock_instance
@pytest.fixture
def mock_printer_manager():
    """Mock the printer manager for status checks.

    Yields the patched ``printer_manager``. Its ``get_status`` returns a canned
    idle, connected status, and ``mark_printer_offline`` is a plain mock.
    """
    with patch("backend.app.services.printer_manager.printer_manager") as mock:
        idle_status = MagicMock(
            connected=True,
            state="IDLE",
            progress=0,
            temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
            raw_data={},
        )
        mock.get_status = MagicMock(return_value=idle_status)
        mock.mark_printer_offline = MagicMock()
        yield mock
- # ============================================================================
- # Factory Fixtures for Test Data
- # ============================================================================
@pytest.fixture
def smart_plug_factory(db_session):
    """Factory to create test smart plugs.

    The returned coroutine accepts ``SmartPlug`` column values as keyword
    arguments, fills in defaults that depend on ``plug_type`` (``"tasmota"``
    when omitted), persists the plug through ``db_session`` and returns the
    refreshed instance. Explicit keyword arguments always win over defaults.
    """

    async def _create_plug(**kwargs):
        from backend.app.models.smart_plug import SmartPlug

        plug_type = kwargs.get("plug_type", "tasmota")

        # Fields whose default depends on the kind of plug being created.
        if plug_type == "homeassistant":
            type_fields = {"ha_entity_id": "switch.test", "ip_address": None}
        elif plug_type == "mqtt":
            type_fields = {
                # Legacy fields (for backward compatibility tests)
                "mqtt_topic": "test/topic",
                "mqtt_multiplier": 1.0,
                # New separate topic/path/multiplier fields
                "mqtt_power_topic": None,
                "mqtt_power_path": "power",
                "mqtt_power_multiplier": 1.0,
                "mqtt_energy_topic": None,
                "mqtt_energy_path": None,
                "mqtt_energy_multiplier": 1.0,
                "mqtt_state_topic": None,
                "mqtt_state_path": None,
                "mqtt_state_on_value": None,
                "ip_address": None,
                "ha_entity_id": None,
            }
        else:
            type_fields = {"ip_address": "192.168.1.100", "ha_entity_id": None}

        fields = {
            "name": "Test Plug",
            "plug_type": plug_type,
            "enabled": True,
            "auto_on": True,
            "auto_off": True,
            "off_delay_mode": "time",
            "off_delay_minutes": 5,
            "off_temp_threshold": 70,
            "schedule_enabled": False,
            "power_alert_enabled": False,
            **type_fields,
            # Caller-supplied values override every default above.
            **kwargs,
        }

        plug = SmartPlug(**fields)
        db_session.add(plug)
        await db_session.commit()
        await db_session.refresh(plug)
        return plug

    return _create_plug
@pytest.fixture
def printer_factory(db_session):
    """Factory to create test printers.

    Each call to the returned coroutine persists a ``Printer`` through
    ``db_session``. The default serial number and IP address are derived from
    a per-fixture call counter, so they differ between printers. Keyword
    arguments override any default.
    """
    created = 0

    async def _create_printer(**kwargs):
        nonlocal created
        from backend.app.models.printer import Printer

        created += 1
        seq = created

        fields = {
            "name": "Test Printer",
            "serial_number": f"00M09A{seq:09d}",  # Unique serial per printer
            "ip_address": f"192.168.1.{100 + seq}",  # Unique IP per printer
            "access_code": "12345678",
            "is_active": True,
            "auto_archive": True,
            "model": "X1C",
            **kwargs,
        }

        printer = Printer(**fields)
        db_session.add(printer)
        await db_session.commit()
        await db_session.refresh(printer)
        return printer

    return _create_printer
@pytest.fixture
def notification_provider_factory(db_session):
    """Factory to create test notification providers.

    The returned coroutine persists a ``NotificationProvider`` through
    ``db_session``. ``config`` may be given as a dict, in which case it is
    serialised to JSON, or as any other value, which is stored unchanged. All
    other keyword arguments override the defaults.
    """

    async def _create_provider(**kwargs):
        from backend.app.models.notification import NotificationProvider

        raw_config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
        config = json.dumps(raw_config) if isinstance(raw_config, dict) else raw_config

        # Event toggles: print lifecycle events are on, everything else is off.
        enabled_events = ("on_print_start", "on_print_complete", "on_print_failed", "on_print_stopped")
        disabled_flags = (
            "on_print_progress",
            "on_printer_offline",
            "on_printer_error",
            "on_filament_low",
            "on_maintenance_due",
            "on_ams_humidity_high",
            "on_ams_temperature_high",
            "on_bed_cooled",
            "quiet_hours_enabled",
            "daily_digest_enabled",
        )

        fields = {
            "name": "Test Provider",
            "provider_type": "ntfy",
            "enabled": True,
            "config": config,
        }
        fields.update(dict.fromkeys(enabled_events, True))
        fields.update(dict.fromkeys(disabled_flags, False))
        fields.update(kwargs)

        provider = NotificationProvider(**fields)
        db_session.add(provider)
        await db_session.commit()
        await db_session.refresh(provider)
        return provider

    return _create_provider
@pytest.fixture
def archive_factory(db_session):
    """Factory to create test archives.

    The returned coroutine takes the owning ``printer_id`` plus optional
    column overrides, persists a ``PrintArchive`` through ``db_session`` and
    returns the refreshed instance.
    """

    async def _create_archive(printer_id: int, **kwargs):
        from backend.app.models.archive import PrintArchive

        fields = {
            "printer_id": printer_id,
            "filename": "test_print.gcode.3mf",
            "print_name": "Test Print",
            "file_path": "archives/test/test_print.gcode.3mf",
            "file_size": 1024000,
            "status": "completed",
            "filament_type": "PLA",
            "filament_used_grams": 50.0,
            "print_time_seconds": 3600,
            **kwargs,
        }

        archive = PrintArchive(**fields)
        db_session.add(archive)
        await db_session.commit()
        await db_session.refresh(archive)
        return archive

    return _create_archive
- # ============================================================================
- # Sample Data Fixtures
- # ============================================================================
@pytest.fixture
def sample_mqtt_print_start():
    """Sample MQTT message for print start."""
    print_section = {
        "command": "project_file",
        "param": "/sdcard/test.gcode.3mf",
        "subtask_name": "test_print",
        "gcode_state": "RUNNING",
        "mc_percent": 0,
    }
    return {"print": print_section}
@pytest.fixture
def sample_mqtt_print_complete():
    """Sample MQTT message for print complete."""
    print_section = {
        "gcode_state": "FINISH",
        "mc_percent": 100,
        "subtask_name": "test_print",
    }
    return {"print": print_section}
@pytest.fixture
def sample_printer_status():
    """Sample printer status data for an idle, connected printer."""
    temperatures = {
        "nozzle": 25.0,
        "bed": 25.0,
        "chamber": 25.0,
    }
    status = {
        "connected": True,
        "state": "IDLE",
        "progress": 0,
        "layer_num": 0,
        "total_layers": 0,
        "temperatures": temperatures,
        "remaining_time": 0,
        "filename": None,
    }
    return status
- # ============================================================================
- # Log Capture Fixtures for Error Detection
- # ============================================================================
class LogCapture(logging.Handler):
    """Handler that captures log records for testing.

    Every record passed to the handler is appended to ``records`` in arrival
    order. Helper methods filter and format the captured records.
    """

    def __init__(self):
        super().__init__()
        # Captured records, oldest first.
        self.records: list[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        """Store *record* instead of writing it anywhere."""
        self.records.append(record)

    def clear(self) -> None:
        """Discard all captured records."""
        self.records.clear()

    def get_errors(self) -> list[logging.LogRecord]:
        """Get all ERROR and CRITICAL level records."""
        return [r for r in self.records if r.levelno >= logging.ERROR]

    def get_warnings(self) -> list[logging.LogRecord]:
        """Get all WARNING level records."""
        return [r for r in self.records if r.levelno == logging.WARNING]

    def has_errors(self) -> bool:
        """Check if any errors were logged."""
        # Short-circuits on the first match instead of building a full list.
        return any(r.levelno >= logging.ERROR for r in self.records)

    def format_errors(self) -> str:
        """Format all errors as a string for assertion messages.

        Returns ``"No errors"`` when nothing at ERROR level or above was
        captured; otherwise one ``name - LEVEL - message`` line per record.
        """
        errors = self.get_errors()
        if not errors:
            return "No errors"
        formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
        return "\n".join(formatter.format(r) for r in errors)
@pytest.fixture
def capture_logs():
    """Capture log output for the duration of a test.

    A ``LogCapture`` handler accepting DEBUG and above is attached to the root
    logger, yielded to the test, and detached again afterwards.

    Usage:
        def test_something(capture_logs):
            # Do something that might log errors
            some_function()
            # Check no errors were logged
            assert not capture_logs.has_errors(), capture_logs.format_errors()
    """
    root = logging.getLogger()
    recorder = LogCapture()
    recorder.setLevel(logging.DEBUG)

    # Attached to the root logger so records from every logger that
    # propagates are captured.
    root.addHandler(recorder)
    yield recorder
    root.removeHandler(recorder)
@pytest.fixture
def assert_no_log_errors(capture_logs):
    """Fail the test automatically if anything was logged at ERROR or above.

    Usage:
        def test_something(assert_no_log_errors):
            # If any ERROR logs occur during this test, it will fail
            some_function()
    """
    yield capture_logs

    # Checked after the test body has finished running.
    if capture_logs.has_errors():
        pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")
|