conftest.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575
  1. """Shared test fixtures for BamBuddy backend tests."""
  2. import asyncio
  3. import atexit
  4. import json
  5. import logging
  6. import os
  7. import shutil
  8. import sys
  9. import tempfile
  10. from collections.abc import AsyncGenerator
  11. from datetime import datetime
  12. from pathlib import Path
  13. from unittest.mock import AsyncMock, MagicMock, patch
  14. import pytest
  15. # IMPORTANT: Set environment variables BEFORE any app imports
  16. # This must happen before settings/config are loaded
  17. os.environ["LOG_TO_FILE"] = "false"
  18. os.environ["DEBUG"] = "false"
  19. from httpx import ASGITransport, AsyncClient # noqa: E402
  20. from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine # noqa: E402
  21. # Ensure settings use our env vars - import and override before database import
  22. from backend.app.core.config import settings # noqa: E402
# Override the already-loaded settings object directly, in addition to the
# LOG_TO_FILE environment variable set earlier in this module.
settings.log_to_file = False

# Use a temp directory for plate calibration to avoid deleting real calibration files.
# mkdtemp() creates the directory right away; the path is kept in a module-level
# name so it can be referenced again later in this module.
_test_plate_cal_dir = Path(tempfile.mkdtemp(prefix="bambuddy_test_plate_cal_"))
settings.plate_calibration_dir = _test_plate_cal_dir
  27. # Clean up temp directory when tests finish
  28. def _cleanup_test_plate_cal_dir():
  29. if _test_plate_cal_dir.exists():
  30. shutil.rmtree(_test_plate_cal_dir, ignore_errors=True)
  31. atexit.register(_cleanup_test_plate_cal_dir)
  32. from backend.app.core.database import Base # noqa: E402
# Use in-memory SQLite for tests (async access through the aiosqlite driver).
# This URL is what the ``test_engine`` fixture passes to ``create_async_engine``.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  35. @pytest.fixture(scope="session")
  36. def event_loop():
  37. """Create an instance of the default event loop for each test session."""
  38. loop = asyncio.get_event_loop_policy().new_event_loop()
  39. yield loop
  40. loop.close()
  41. @pytest.fixture
  42. async def test_engine():
  43. """Create a test database engine."""
  44. engine = create_async_engine(TEST_DATABASE_URL, echo=False)
  45. # Import all models to register them
  46. from backend.app.models import (
  47. ams_history,
  48. api_key,
  49. archive,
  50. external_link,
  51. filament,
  52. kprofile_note,
  53. maintenance,
  54. notification,
  55. notification_template,
  56. print_queue,
  57. printer,
  58. project,
  59. settings,
  60. smart_plug,
  61. user,
  62. )
  63. async with engine.begin() as conn:
  64. await conn.run_sync(Base.metadata.create_all)
  65. yield engine
  66. async with engine.begin() as conn:
  67. await conn.run_sync(Base.metadata.drop_all)
  68. await engine.dispose()
  69. @pytest.fixture
  70. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  71. """Create a test database session."""
  72. async_session_maker = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  73. async with async_session_maker() as session:
  74. yield session
  75. @pytest.fixture
  76. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  77. """Create an async test client."""
  78. from backend.app.core.database import async_session, get_db
  79. from backend.app.main import app
  80. # Create a new session maker for the test engine
  81. test_async_session = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  82. async def override_get_db():
  83. async with test_async_session() as session:
  84. yield session
  85. app.dependency_overrides[get_db] = override_get_db
  86. # Mock init_printer_connections to prevent MQTT connection attempts during tests
  87. async def mock_init_printer_connections(db):
  88. pass # No-op - don't connect to real printers
  89. # Also patch the module-level async_session used by services and auth
  90. with (
  91. patch("backend.app.core.database.async_session", test_async_session),
  92. patch("backend.app.core.auth.async_session", test_async_session),
  93. patch("backend.app.main.init_printer_connections", mock_init_printer_connections),
  94. ):
  95. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  96. yield client
  97. app.dependency_overrides.clear()
  98. # ============================================================================
  99. # Mock External Services
  100. # ============================================================================
  101. @pytest.fixture
  102. def mock_tasmota_service():
  103. """Mock the Tasmota service for smart plug tests."""
  104. # Patch both the module where it's defined and where it's imported
  105. with (
  106. patch("backend.app.services.tasmota.tasmota_service") as mock,
  107. patch("backend.app.api.routes.smart_plugs.tasmota_service") as mock2,
  108. ):
  109. mock.turn_on = AsyncMock(return_value=True)
  110. mock.turn_off = AsyncMock(return_value=True)
  111. mock.toggle = AsyncMock(return_value=True)
  112. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test Plug"})
  113. mock.get_energy = AsyncMock(
  114. return_value={
  115. "power": 150.5,
  116. "voltage": 120.0,
  117. "current": 1.25,
  118. "today": 2.5,
  119. "total": 100.0,
  120. "factor": 0.95,
  121. }
  122. )
  123. mock.test_connection = AsyncMock(return_value={"success": True, "state": "ON", "device_name": "Test Plug"})
  124. # Copy mocks to second patch target
  125. mock2.turn_on = mock.turn_on
  126. mock2.turn_off = mock.turn_off
  127. mock2.toggle = mock.toggle
  128. mock2.get_status = mock.get_status
  129. mock2.get_energy = mock.get_energy
  130. mock2.test_connection = mock.test_connection
  131. yield mock
  132. @pytest.fixture
  133. def mock_homeassistant_service():
  134. """Mock the Home Assistant service for smart plug tests."""
  135. # Patch both the module where it's defined and where it's imported
  136. with (
  137. patch("backend.app.services.homeassistant.homeassistant_service") as mock,
  138. patch("backend.app.api.routes.smart_plugs.homeassistant_service") as mock2,
  139. ):
  140. mock.turn_on = AsyncMock(return_value=True)
  141. mock.turn_off = AsyncMock(return_value=True)
  142. mock.toggle = AsyncMock(return_value=True)
  143. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test HA Entity"})
  144. mock.get_energy = AsyncMock(return_value=None) # Most HA entities don't have power monitoring
  145. mock.test_connection = AsyncMock(return_value={"success": True, "message": "API running", "error": None})
  146. mock.list_entities = AsyncMock(
  147. return_value=[
  148. {
  149. "entity_id": "switch.printer_plug",
  150. "friendly_name": "Printer Plug",
  151. "state": "on",
  152. "domain": "switch",
  153. },
  154. {"entity_id": "switch.test", "friendly_name": "Test Switch", "state": "off", "domain": "switch"},
  155. ]
  156. )
  157. mock.configure = MagicMock()
  158. # Copy mocks to second patch target
  159. mock2.turn_on = mock.turn_on
  160. mock2.turn_off = mock.turn_off
  161. mock2.toggle = mock.toggle
  162. mock2.get_status = mock.get_status
  163. mock2.get_energy = mock.get_energy
  164. mock2.test_connection = mock.test_connection
  165. mock2.list_entities = mock.list_entities
  166. mock2.configure = mock.configure
  167. yield mock
  168. @pytest.fixture
  169. def mock_mqtt_client():
  170. """Mock the MQTT client for printer communication tests."""
  171. with patch("backend.app.services.bambu_mqtt.BambuMQTTClient") as mock:
  172. instance = MagicMock()
  173. instance.state = MagicMock(connected=True, state="IDLE", progress=0, temperatures={"nozzle": 25, "bed": 25})
  174. instance.connect = MagicMock()
  175. instance.disconnect = MagicMock()
  176. mock.return_value = instance
  177. yield mock
  178. @pytest.fixture
  179. def mock_mqtt_smart_plug_service():
  180. """Mock the MQTT smart plug service for MQTT plug tests."""
  181. with patch("backend.app.api.routes.smart_plugs.mqtt_relay") as mock:
  182. # Create a mock smart_plug_service
  183. mock_service = MagicMock()
  184. mock_service.is_configured = MagicMock(return_value=True)
  185. mock_service.has_broker_settings = MagicMock(return_value=True)
  186. mock_service.configure = AsyncMock(return_value=True)
  187. mock_service.subscribe = MagicMock()
  188. mock_service.unsubscribe = MagicMock()
  189. mock_service.get_plug_data = MagicMock(return_value=None)
  190. mock_service.is_reachable = MagicMock(return_value=False)
  191. mock.smart_plug_service = mock_service
  192. yield mock
  193. @pytest.fixture
  194. def mock_ftp_client():
  195. """Mock the FTP client for file transfer tests."""
  196. with (
  197. patch("backend.app.services.bambu_ftp.download_file_async") as download_mock,
  198. patch("backend.app.services.bambu_ftp.list_files_async") as list_mock,
  199. ):
  200. download_mock.return_value = True
  201. list_mock.return_value = []
  202. yield {"download": download_mock, "list": list_mock}
  203. @pytest.fixture
  204. def mock_httpx_client():
  205. """Mock httpx for webhook/notification HTTP calls."""
  206. with patch("httpx.AsyncClient") as mock_class:
  207. mock_instance = AsyncMock()
  208. mock_response = MagicMock()
  209. mock_response.status_code = 200
  210. mock_response.text = "OK"
  211. mock_response.json.return_value = {}
  212. mock_instance.get = AsyncMock(return_value=mock_response)
  213. mock_instance.post = AsyncMock(return_value=mock_response)
  214. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  215. mock_instance.__aexit__ = AsyncMock()
  216. mock_class.return_value = mock_instance
  217. yield mock_instance
  218. @pytest.fixture
  219. def mock_printer_manager():
  220. """Mock the printer manager for status checks."""
  221. with patch("backend.app.services.printer_manager.printer_manager") as mock:
  222. mock.get_status = MagicMock(
  223. return_value=MagicMock(
  224. connected=True,
  225. state="IDLE",
  226. progress=0,
  227. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  228. raw_data={},
  229. )
  230. )
  231. mock.mark_printer_offline = MagicMock()
  232. yield mock
  233. # ============================================================================
  234. # Factory Fixtures for Test Data
  235. # ============================================================================
  236. @pytest.fixture
  237. def smart_plug_factory(db_session):
  238. """Factory to create test smart plugs."""
  239. async def _create_plug(**kwargs):
  240. from backend.app.models.smart_plug import SmartPlug
  241. # Determine defaults based on plug_type
  242. plug_type = kwargs.get("plug_type", "tasmota")
  243. defaults = {
  244. "name": "Test Plug",
  245. "plug_type": plug_type,
  246. "enabled": True,
  247. "auto_on": True,
  248. "auto_off": True,
  249. "off_delay_mode": "time",
  250. "off_delay_minutes": 5,
  251. "off_temp_threshold": 70,
  252. "schedule_enabled": False,
  253. "power_alert_enabled": False,
  254. }
  255. # Set required fields based on plug_type
  256. if plug_type == "homeassistant":
  257. defaults["ha_entity_id"] = "switch.test"
  258. defaults["ip_address"] = None
  259. elif plug_type == "mqtt":
  260. # Legacy fields (for backward compatibility tests)
  261. defaults["mqtt_topic"] = kwargs.get("mqtt_topic", "test/topic")
  262. defaults["mqtt_multiplier"] = kwargs.get("mqtt_multiplier", 1.0)
  263. # New separate topic/path/multiplier fields
  264. defaults["mqtt_power_topic"] = kwargs.get("mqtt_power_topic")
  265. defaults["mqtt_power_path"] = kwargs.get("mqtt_power_path", "power")
  266. defaults["mqtt_power_multiplier"] = kwargs.get("mqtt_power_multiplier", 1.0)
  267. defaults["mqtt_energy_topic"] = kwargs.get("mqtt_energy_topic")
  268. defaults["mqtt_energy_path"] = kwargs.get("mqtt_energy_path")
  269. defaults["mqtt_energy_multiplier"] = kwargs.get("mqtt_energy_multiplier", 1.0)
  270. defaults["mqtt_state_topic"] = kwargs.get("mqtt_state_topic")
  271. defaults["mqtt_state_path"] = kwargs.get("mqtt_state_path")
  272. defaults["mqtt_state_on_value"] = kwargs.get("mqtt_state_on_value")
  273. defaults["ip_address"] = None
  274. defaults["ha_entity_id"] = None
  275. else:
  276. defaults["ip_address"] = "192.168.1.100"
  277. defaults["ha_entity_id"] = None
  278. defaults.update(kwargs)
  279. plug = SmartPlug(**defaults)
  280. db_session.add(plug)
  281. await db_session.commit()
  282. await db_session.refresh(plug)
  283. return plug
  284. return _create_plug
  285. @pytest.fixture
  286. def printer_factory(db_session):
  287. """Factory to create test printers."""
  288. _counter = [0] # Use list to allow mutation in nested function
  289. async def _create_printer(**kwargs):
  290. from backend.app.models.printer import Printer
  291. _counter[0] += 1
  292. counter = _counter[0]
  293. defaults = {
  294. "name": "Test Printer",
  295. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  296. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  297. "access_code": "12345678",
  298. "is_active": True,
  299. "auto_archive": True,
  300. "model": "X1C",
  301. }
  302. defaults.update(kwargs)
  303. printer = Printer(**defaults)
  304. db_session.add(printer)
  305. await db_session.commit()
  306. await db_session.refresh(printer)
  307. return printer
  308. return _create_printer
  309. @pytest.fixture
  310. def notification_provider_factory(db_session):
  311. """Factory to create test notification providers."""
  312. async def _create_provider(**kwargs):
  313. from backend.app.models.notification import NotificationProvider
  314. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  315. if isinstance(config, dict):
  316. config = json.dumps(config)
  317. defaults = {
  318. "name": "Test Provider",
  319. "provider_type": "ntfy",
  320. "enabled": True,
  321. "config": config,
  322. "on_print_start": True,
  323. "on_print_complete": True,
  324. "on_print_failed": True,
  325. "on_print_stopped": True,
  326. "on_print_progress": False,
  327. "on_printer_offline": False,
  328. "on_printer_error": False,
  329. "on_filament_low": False,
  330. "on_maintenance_due": False,
  331. "on_ams_humidity_high": False,
  332. "on_ams_temperature_high": False,
  333. "quiet_hours_enabled": False,
  334. "daily_digest_enabled": False,
  335. }
  336. defaults.update(kwargs)
  337. provider = NotificationProvider(**defaults)
  338. db_session.add(provider)
  339. await db_session.commit()
  340. await db_session.refresh(provider)
  341. return provider
  342. return _create_provider
  343. @pytest.fixture
  344. def archive_factory(db_session):
  345. """Factory to create test archives."""
  346. async def _create_archive(printer_id: int, **kwargs):
  347. from backend.app.models.archive import PrintArchive
  348. defaults = {
  349. "printer_id": printer_id,
  350. "filename": "test_print.gcode.3mf",
  351. "print_name": "Test Print",
  352. "file_path": "archives/test/test_print.gcode.3mf",
  353. "file_size": 1024000,
  354. "status": "completed",
  355. "filament_type": "PLA",
  356. "filament_used_grams": 50.0,
  357. "print_time_seconds": 3600,
  358. }
  359. defaults.update(kwargs)
  360. archive = PrintArchive(**defaults)
  361. db_session.add(archive)
  362. await db_session.commit()
  363. await db_session.refresh(archive)
  364. return archive
  365. return _create_archive
  366. # ============================================================================
  367. # Sample Data Fixtures
  368. # ============================================================================
  369. @pytest.fixture
  370. def sample_mqtt_print_start():
  371. """Sample MQTT message for print start."""
  372. return {
  373. "print": {
  374. "command": "project_file",
  375. "param": "/sdcard/test.gcode.3mf",
  376. "subtask_name": "test_print",
  377. "gcode_state": "RUNNING",
  378. "mc_percent": 0,
  379. }
  380. }
  381. @pytest.fixture
  382. def sample_mqtt_print_complete():
  383. """Sample MQTT message for print complete."""
  384. return {
  385. "print": {
  386. "gcode_state": "FINISH",
  387. "mc_percent": 100,
  388. "subtask_name": "test_print",
  389. }
  390. }
  391. @pytest.fixture
  392. def sample_printer_status():
  393. """Sample printer status data."""
  394. return {
  395. "connected": True,
  396. "state": "IDLE",
  397. "progress": 0,
  398. "layer_num": 0,
  399. "total_layers": 0,
  400. "temperatures": {
  401. "nozzle": 25.0,
  402. "bed": 25.0,
  403. "chamber": 25.0,
  404. },
  405. "remaining_time": 0,
  406. "filename": None,
  407. }
  408. # ============================================================================
  409. # Log Capture Fixtures for Error Detection
  410. # ============================================================================
  411. class LogCapture(logging.Handler):
  412. """Handler that captures log records for testing."""
  413. def __init__(self):
  414. super().__init__()
  415. self.records: list[logging.LogRecord] = []
  416. def emit(self, record: logging.LogRecord):
  417. self.records.append(record)
  418. def clear(self):
  419. self.records.clear()
  420. def get_errors(self) -> list[logging.LogRecord]:
  421. """Get all ERROR and CRITICAL level records."""
  422. return [r for r in self.records if r.levelno >= logging.ERROR]
  423. def get_warnings(self) -> list[logging.LogRecord]:
  424. """Get all WARNING level records."""
  425. return [r for r in self.records if r.levelno == logging.WARNING]
  426. def has_errors(self) -> bool:
  427. """Check if any errors were logged."""
  428. return len(self.get_errors()) > 0
  429. def format_errors(self) -> str:
  430. """Format all errors as a string for assertion messages."""
  431. errors = self.get_errors()
  432. if not errors:
  433. return "No errors"
  434. formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
  435. return "\n".join(formatter.format(r) for r in errors)
  436. @pytest.fixture
  437. def capture_logs():
  438. """Fixture that captures log output during a test.
  439. Usage:
  440. def test_something(capture_logs):
  441. # Do something that might log errors
  442. some_function()
  443. # Check no errors were logged
  444. assert not capture_logs.has_errors(), capture_logs.format_errors()
  445. """
  446. handler = LogCapture()
  447. handler.setLevel(logging.DEBUG)
  448. # Attach to root logger to capture all logs
  449. root_logger = logging.getLogger()
  450. root_logger.addHandler(handler)
  451. yield handler
  452. root_logger.removeHandler(handler)
  453. @pytest.fixture
  454. def assert_no_log_errors(capture_logs):
  455. """Fixture that automatically asserts no errors were logged.
  456. Usage:
  457. def test_something(assert_no_log_errors):
  458. # If any ERROR logs occur during this test, it will fail
  459. some_function()
  460. """
  461. yield capture_logs
  462. errors = capture_logs.get_errors()
  463. if errors:
  464. pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")