conftest.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. """Shared test fixtures for BamBuddy backend tests."""
  2. import asyncio
  3. import atexit
  4. import json
  5. import logging
  6. import os
  7. import shutil
  8. import tempfile
  9. from collections.abc import AsyncGenerator
  10. from pathlib import Path
  11. from unittest.mock import AsyncMock, MagicMock, patch
  12. import pytest
# IMPORTANT: Set environment variables BEFORE any app imports.
# This must happen before settings/config are loaded; both switches are
# turned off for the whole test run.
os.environ["LOG_TO_FILE"] = "false"
os.environ["DEBUG"] = "false"
from httpx import ASGITransport, AsyncClient  # noqa: E402
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine  # noqa: E402

# Ensure settings use our env vars - import and override before database import.
from backend.app.core.config import settings  # noqa: E402

# Also disable file logging directly on the loaded settings object.
settings.log_to_file = False
  22. # Use a temp directory for plate calibration to avoid deleting real calibration files
  23. _test_plate_cal_dir = Path(tempfile.mkdtemp(prefix="bambuddy_test_plate_cal_"))
  24. settings.plate_calibration_dir = _test_plate_cal_dir
  25. # Clean up temp directory when tests finish
  26. def _cleanup_test_plate_cal_dir():
  27. if _test_plate_cal_dir.exists():
  28. shutil.rmtree(_test_plate_cal_dir, ignore_errors=True)
  29. atexit.register(_cleanup_test_plate_cal_dir)
# Imported only after the settings overrides above have been applied.
from backend.app.core.database import Base  # noqa: E402

# Use in-memory SQLite (through the aiosqlite async driver) for tests.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  33. @pytest.fixture(scope="session")
  34. def event_loop():
  35. """Create an instance of the default event loop for each test session."""
  36. loop = asyncio.get_event_loop_policy().new_event_loop()
  37. yield loop
  38. loop.close()
  39. @pytest.fixture
  40. async def test_engine():
  41. """Create a test database engine."""
  42. engine = create_async_engine(TEST_DATABASE_URL, echo=False)
  43. # Import all models to register them
  44. from backend.app.models import (
  45. ams_history,
  46. api_key,
  47. archive,
  48. external_link,
  49. filament,
  50. group,
  51. kprofile_note,
  52. maintenance,
  53. notification,
  54. notification_template,
  55. print_queue,
  56. printer,
  57. project,
  58. settings,
  59. smart_plug,
  60. spool,
  61. spool_assignment,
  62. spool_usage_history,
  63. user,
  64. )
  65. async with engine.begin() as conn:
  66. await conn.run_sync(Base.metadata.create_all)
  67. yield engine
  68. async with engine.begin() as conn:
  69. await conn.run_sync(Base.metadata.drop_all)
  70. await engine.dispose()
  71. @pytest.fixture
  72. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  73. """Create a test database session."""
  74. async_session_maker = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  75. async with async_session_maker() as session:
  76. yield session
  77. @pytest.fixture
  78. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  79. """Create an async test client."""
  80. from backend.app.core.database import async_session, get_db
  81. from backend.app.main import app
  82. # Create a new session maker for the test engine
  83. test_async_session = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  84. async def override_get_db():
  85. async with test_async_session() as session:
  86. yield session
  87. app.dependency_overrides[get_db] = override_get_db
  88. # Mock init_printer_connections to prevent MQTT connection attempts during tests
  89. async def mock_init_printer_connections(db):
  90. pass # No-op - don't connect to real printers
  91. # Also patch the module-level async_session used by services, auth, and middleware
  92. with (
  93. patch("backend.app.core.database.async_session", test_async_session),
  94. patch("backend.app.core.auth.async_session", test_async_session),
  95. patch("backend.app.main.async_session", test_async_session),
  96. patch("backend.app.main.init_printer_connections", mock_init_printer_connections),
  97. ):
  98. # Seed default groups for tests that need them
  99. from backend.app.core.database import seed_default_groups
  100. await seed_default_groups()
  101. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  102. yield client
  103. app.dependency_overrides.clear()
  104. # ============================================================================
  105. # Mock External Services
  106. # ============================================================================
  107. @pytest.fixture
  108. def mock_tasmota_service():
  109. """Mock the Tasmota service for smart plug tests."""
  110. # Patch both the module where it's defined and where it's imported
  111. with (
  112. patch("backend.app.services.tasmota.tasmota_service") as mock,
  113. patch("backend.app.api.routes.smart_plugs.tasmota_service") as mock2,
  114. ):
  115. mock.turn_on = AsyncMock(return_value=True)
  116. mock.turn_off = AsyncMock(return_value=True)
  117. mock.toggle = AsyncMock(return_value=True)
  118. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test Plug"})
  119. mock.get_energy = AsyncMock(
  120. return_value={
  121. "power": 150.5,
  122. "voltage": 120.0,
  123. "current": 1.25,
  124. "today": 2.5,
  125. "total": 100.0,
  126. "factor": 0.95,
  127. }
  128. )
  129. mock.test_connection = AsyncMock(return_value={"success": True, "state": "ON", "device_name": "Test Plug"})
  130. # Copy mocks to second patch target
  131. mock2.turn_on = mock.turn_on
  132. mock2.turn_off = mock.turn_off
  133. mock2.toggle = mock.toggle
  134. mock2.get_status = mock.get_status
  135. mock2.get_energy = mock.get_energy
  136. mock2.test_connection = mock.test_connection
  137. yield mock
  138. @pytest.fixture
  139. def mock_homeassistant_service():
  140. """Mock the Home Assistant service for smart plug tests."""
  141. # Patch both the module where it's defined and where it's imported
  142. with (
  143. patch("backend.app.services.homeassistant.homeassistant_service") as mock,
  144. patch("backend.app.api.routes.smart_plugs.homeassistant_service") as mock2,
  145. ):
  146. mock.turn_on = AsyncMock(return_value=True)
  147. mock.turn_off = AsyncMock(return_value=True)
  148. mock.toggle = AsyncMock(return_value=True)
  149. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test HA Entity"})
  150. mock.get_energy = AsyncMock(return_value=None) # Most HA entities don't have power monitoring
  151. mock.test_connection = AsyncMock(return_value={"success": True, "message": "API running", "error": None})
  152. mock.list_entities = AsyncMock(
  153. return_value=[
  154. {
  155. "entity_id": "switch.printer_plug",
  156. "friendly_name": "Printer Plug",
  157. "state": "on",
  158. "domain": "switch",
  159. },
  160. {"entity_id": "switch.test", "friendly_name": "Test Switch", "state": "off", "domain": "switch"},
  161. ]
  162. )
  163. mock.configure = MagicMock()
  164. # Copy mocks to second patch target
  165. mock2.turn_on = mock.turn_on
  166. mock2.turn_off = mock.turn_off
  167. mock2.toggle = mock.toggle
  168. mock2.get_status = mock.get_status
  169. mock2.get_energy = mock.get_energy
  170. mock2.test_connection = mock.test_connection
  171. mock2.list_entities = mock.list_entities
  172. mock2.configure = mock.configure
  173. yield mock
  174. @pytest.fixture
  175. def mock_mqtt_client():
  176. """Mock the MQTT client for printer communication tests."""
  177. with patch("backend.app.services.bambu_mqtt.BambuMQTTClient") as mock:
  178. instance = MagicMock()
  179. instance.state = MagicMock(connected=True, state="IDLE", progress=0, temperatures={"nozzle": 25, "bed": 25})
  180. instance.connect = MagicMock()
  181. instance.disconnect = MagicMock()
  182. mock.return_value = instance
  183. yield mock
  184. @pytest.fixture
  185. def mock_mqtt_smart_plug_service():
  186. """Mock the MQTT smart plug service for MQTT plug tests."""
  187. with patch("backend.app.api.routes.smart_plugs.mqtt_relay") as mock:
  188. # Create a mock smart_plug_service
  189. mock_service = MagicMock()
  190. mock_service.is_configured = MagicMock(return_value=True)
  191. mock_service.has_broker_settings = MagicMock(return_value=True)
  192. mock_service.configure = AsyncMock(return_value=True)
  193. mock_service.subscribe = MagicMock()
  194. mock_service.unsubscribe = MagicMock()
  195. mock_service.get_plug_data = MagicMock(return_value=None)
  196. mock_service.is_reachable = MagicMock(return_value=False)
  197. mock.smart_plug_service = mock_service
  198. yield mock
  199. @pytest.fixture
  200. def mock_ftp_client():
  201. """Mock the FTP client for file transfer tests."""
  202. with (
  203. patch("backend.app.services.bambu_ftp.download_file_async") as download_mock,
  204. patch("backend.app.services.bambu_ftp.list_files_async") as list_mock,
  205. ):
  206. download_mock.return_value = True
  207. list_mock.return_value = []
  208. yield {"download": download_mock, "list": list_mock}
  209. @pytest.fixture
  210. def mock_httpx_client():
  211. """Mock httpx for webhook/notification HTTP calls."""
  212. with patch("httpx.AsyncClient") as mock_class:
  213. mock_instance = AsyncMock()
  214. mock_response = MagicMock()
  215. mock_response.status_code = 200
  216. mock_response.text = "OK"
  217. mock_response.json.return_value = {}
  218. mock_instance.get = AsyncMock(return_value=mock_response)
  219. mock_instance.post = AsyncMock(return_value=mock_response)
  220. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  221. mock_instance.__aexit__ = AsyncMock()
  222. mock_class.return_value = mock_instance
  223. yield mock_instance
  224. @pytest.fixture
  225. def mock_printer_manager():
  226. """Mock the printer manager for status checks."""
  227. with patch("backend.app.services.printer_manager.printer_manager") as mock:
  228. mock.get_status = MagicMock(
  229. return_value=MagicMock(
  230. connected=True,
  231. state="IDLE",
  232. progress=0,
  233. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  234. raw_data={},
  235. )
  236. )
  237. mock.mark_printer_offline = MagicMock()
  238. yield mock
  239. # ============================================================================
  240. # Factory Fixtures for Test Data
  241. # ============================================================================
  242. @pytest.fixture
  243. def smart_plug_factory(db_session):
  244. """Factory to create test smart plugs."""
  245. async def _create_plug(**kwargs):
  246. from backend.app.models.smart_plug import SmartPlug
  247. # Determine defaults based on plug_type
  248. plug_type = kwargs.get("plug_type", "tasmota")
  249. defaults = {
  250. "name": "Test Plug",
  251. "plug_type": plug_type,
  252. "enabled": True,
  253. "auto_on": True,
  254. "auto_off": True,
  255. "off_delay_mode": "time",
  256. "off_delay_minutes": 5,
  257. "off_temp_threshold": 70,
  258. "schedule_enabled": False,
  259. "power_alert_enabled": False,
  260. }
  261. # Set required fields based on plug_type
  262. if plug_type == "homeassistant":
  263. defaults["ha_entity_id"] = "switch.test"
  264. defaults["ip_address"] = None
  265. elif plug_type == "mqtt":
  266. # Legacy fields (for backward compatibility tests)
  267. defaults["mqtt_topic"] = kwargs.get("mqtt_topic", "test/topic")
  268. defaults["mqtt_multiplier"] = kwargs.get("mqtt_multiplier", 1.0)
  269. # New separate topic/path/multiplier fields
  270. defaults["mqtt_power_topic"] = kwargs.get("mqtt_power_topic")
  271. defaults["mqtt_power_path"] = kwargs.get("mqtt_power_path", "power")
  272. defaults["mqtt_power_multiplier"] = kwargs.get("mqtt_power_multiplier", 1.0)
  273. defaults["mqtt_energy_topic"] = kwargs.get("mqtt_energy_topic")
  274. defaults["mqtt_energy_path"] = kwargs.get("mqtt_energy_path")
  275. defaults["mqtt_energy_multiplier"] = kwargs.get("mqtt_energy_multiplier", 1.0)
  276. defaults["mqtt_state_topic"] = kwargs.get("mqtt_state_topic")
  277. defaults["mqtt_state_path"] = kwargs.get("mqtt_state_path")
  278. defaults["mqtt_state_on_value"] = kwargs.get("mqtt_state_on_value")
  279. defaults["ip_address"] = None
  280. defaults["ha_entity_id"] = None
  281. else:
  282. defaults["ip_address"] = "192.168.1.100"
  283. defaults["ha_entity_id"] = None
  284. defaults.update(kwargs)
  285. plug = SmartPlug(**defaults)
  286. db_session.add(plug)
  287. await db_session.commit()
  288. await db_session.refresh(plug)
  289. return plug
  290. return _create_plug
  291. @pytest.fixture
  292. def printer_factory(db_session):
  293. """Factory to create test printers."""
  294. _counter = [0] # Use list to allow mutation in nested function
  295. async def _create_printer(**kwargs):
  296. from backend.app.models.printer import Printer
  297. _counter[0] += 1
  298. counter = _counter[0]
  299. defaults = {
  300. "name": "Test Printer",
  301. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  302. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  303. "access_code": "12345678",
  304. "is_active": True,
  305. "auto_archive": True,
  306. "model": "X1C",
  307. }
  308. defaults.update(kwargs)
  309. printer = Printer(**defaults)
  310. db_session.add(printer)
  311. await db_session.commit()
  312. await db_session.refresh(printer)
  313. return printer
  314. return _create_printer
  315. @pytest.fixture
  316. def notification_provider_factory(db_session):
  317. """Factory to create test notification providers."""
  318. async def _create_provider(**kwargs):
  319. from backend.app.models.notification import NotificationProvider
  320. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  321. if isinstance(config, dict):
  322. config = json.dumps(config)
  323. defaults = {
  324. "name": "Test Provider",
  325. "provider_type": "ntfy",
  326. "enabled": True,
  327. "config": config,
  328. "on_print_start": True,
  329. "on_print_complete": True,
  330. "on_print_failed": True,
  331. "on_print_stopped": True,
  332. "on_print_progress": False,
  333. "on_printer_offline": False,
  334. "on_printer_error": False,
  335. "on_filament_low": False,
  336. "on_maintenance_due": False,
  337. "on_ams_humidity_high": False,
  338. "on_ams_temperature_high": False,
  339. "on_bed_cooled": False,
  340. "quiet_hours_enabled": False,
  341. "daily_digest_enabled": False,
  342. }
  343. defaults.update(kwargs)
  344. provider = NotificationProvider(**defaults)
  345. db_session.add(provider)
  346. await db_session.commit()
  347. await db_session.refresh(provider)
  348. return provider
  349. return _create_provider
  350. @pytest.fixture
  351. def archive_factory(db_session):
  352. """Factory to create test archives."""
  353. async def _create_archive(printer_id: int, **kwargs):
  354. from backend.app.models.archive import PrintArchive
  355. defaults = {
  356. "printer_id": printer_id,
  357. "filename": "test_print.gcode.3mf",
  358. "print_name": "Test Print",
  359. "file_path": "archives/test/test_print.gcode.3mf",
  360. "file_size": 1024000,
  361. "status": "completed",
  362. "filament_type": "PLA",
  363. "filament_used_grams": 50.0,
  364. "print_time_seconds": 3600,
  365. }
  366. defaults.update(kwargs)
  367. archive = PrintArchive(**defaults)
  368. db_session.add(archive)
  369. await db_session.commit()
  370. await db_session.refresh(archive)
  371. return archive
  372. return _create_archive
  373. # ============================================================================
  374. # Sample Data Fixtures
  375. # ============================================================================
  376. @pytest.fixture
  377. def sample_mqtt_print_start():
  378. """Sample MQTT message for print start."""
  379. return {
  380. "print": {
  381. "command": "project_file",
  382. "param": "/sdcard/test.gcode.3mf",
  383. "subtask_name": "test_print",
  384. "gcode_state": "RUNNING",
  385. "mc_percent": 0,
  386. }
  387. }
  388. @pytest.fixture
  389. def sample_mqtt_print_complete():
  390. """Sample MQTT message for print complete."""
  391. return {
  392. "print": {
  393. "gcode_state": "FINISH",
  394. "mc_percent": 100,
  395. "subtask_name": "test_print",
  396. }
  397. }
  398. @pytest.fixture
  399. def sample_printer_status():
  400. """Sample printer status data."""
  401. return {
  402. "connected": True,
  403. "state": "IDLE",
  404. "progress": 0,
  405. "layer_num": 0,
  406. "total_layers": 0,
  407. "temperatures": {
  408. "nozzle": 25.0,
  409. "bed": 25.0,
  410. "chamber": 25.0,
  411. },
  412. "remaining_time": 0,
  413. "filename": None,
  414. }
  415. # ============================================================================
  416. # Log Capture Fixtures for Error Detection
  417. # ============================================================================
  418. class LogCapture(logging.Handler):
  419. """Handler that captures log records for testing."""
  420. def __init__(self):
  421. super().__init__()
  422. self.records: list[logging.LogRecord] = []
  423. def emit(self, record: logging.LogRecord):
  424. self.records.append(record)
  425. def clear(self):
  426. self.records.clear()
  427. def get_errors(self) -> list[logging.LogRecord]:
  428. """Get all ERROR and CRITICAL level records."""
  429. return [r for r in self.records if r.levelno >= logging.ERROR]
  430. def get_warnings(self) -> list[logging.LogRecord]:
  431. """Get all WARNING level records."""
  432. return [r for r in self.records if r.levelno == logging.WARNING]
  433. def has_errors(self) -> bool:
  434. """Check if any errors were logged."""
  435. return len(self.get_errors()) > 0
  436. def format_errors(self) -> str:
  437. """Format all errors as a string for assertion messages."""
  438. errors = self.get_errors()
  439. if not errors:
  440. return "No errors"
  441. formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
  442. return "\n".join(formatter.format(r) for r in errors)
  443. @pytest.fixture
  444. def capture_logs():
  445. """Fixture that captures log output during a test.
  446. Usage:
  447. def test_something(capture_logs):
  448. # Do something that might log errors
  449. some_function()
  450. # Check no errors were logged
  451. assert not capture_logs.has_errors(), capture_logs.format_errors()
  452. """
  453. handler = LogCapture()
  454. handler.setLevel(logging.DEBUG)
  455. # Attach to root logger to capture all logs
  456. root_logger = logging.getLogger()
  457. root_logger.addHandler(handler)
  458. yield handler
  459. root_logger.removeHandler(handler)
  460. @pytest.fixture
  461. def assert_no_log_errors(capture_logs):
  462. """Fixture that automatically asserts no errors were logged.
  463. Usage:
  464. def test_something(assert_no_log_errors):
  465. # If any ERROR logs occur during this test, it will fail
  466. some_function()
  467. """
  468. yield capture_logs
  469. errors = capture_logs.get_errors()
  470. if errors:
  471. pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")