# conftest.py (16 KB)
"""Shared test fixtures for BamBuddy backend tests.

Provides a test database engine/session, an HTTP client bound to the app,
mocks for external services (smart plugs, MQTT, FTP, HTTP, printer manager),
factories that persist test rows, sample payloads, and log-capture helpers.
"""
import asyncio
import json
import logging
import os
import sys
from collections.abc import AsyncGenerator
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# IMPORTANT: Set environment variables BEFORE any app imports
# This must happen before settings/config are loaded, because the config
# module presumably reads these variables when it is first imported.
os.environ["LOG_TO_FILE"] = "false"
os.environ["DEBUG"] = "false"
from httpx import ASGITransport, AsyncClient # noqa: E402
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine # noqa: E402
# Ensure settings use our env vars - import and override before database import.
# The attribute is also forced directly in case the settings object was already
# built before the env vars above were set (e.g. by an earlier import).
from backend.app.core.config import settings # noqa: E402
settings.log_to_file = False
# Imported only after the settings override above; the ordering suggests the
# database module may read settings at import time — TODO confirm.
from backend.app.core.database import Base # noqa: E402
# Use in-memory SQLite for tests (async driver: aiosqlite)
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  23. @pytest.fixture(scope="session")
  24. def event_loop():
  25. """Create an instance of the default event loop for each test session."""
  26. loop = asyncio.get_event_loop_policy().new_event_loop()
  27. yield loop
  28. loop.close()
  29. @pytest.fixture
  30. async def test_engine():
  31. """Create a test database engine."""
  32. engine = create_async_engine(TEST_DATABASE_URL, echo=False)
  33. # Import all models to register them
  34. from backend.app.models import (
  35. ams_history,
  36. api_key,
  37. archive,
  38. external_link,
  39. filament,
  40. kprofile_note,
  41. maintenance,
  42. notification,
  43. notification_template,
  44. print_queue,
  45. printer,
  46. project,
  47. settings,
  48. smart_plug,
  49. user,
  50. )
  51. async with engine.begin() as conn:
  52. await conn.run_sync(Base.metadata.create_all)
  53. yield engine
  54. async with engine.begin() as conn:
  55. await conn.run_sync(Base.metadata.drop_all)
  56. await engine.dispose()
  57. @pytest.fixture
  58. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  59. """Create a test database session."""
  60. async_session_maker = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  61. async with async_session_maker() as session:
  62. yield session
  63. @pytest.fixture
  64. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  65. """Create an async test client."""
  66. from backend.app.core.database import async_session, get_db
  67. from backend.app.main import app
  68. # Create a new session maker for the test engine
  69. test_async_session = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  70. async def override_get_db():
  71. async with test_async_session() as session:
  72. yield session
  73. app.dependency_overrides[get_db] = override_get_db
  74. # Mock init_printer_connections to prevent MQTT connection attempts during tests
  75. async def mock_init_printer_connections(db):
  76. pass # No-op - don't connect to real printers
  77. # Also patch the module-level async_session used by services
  78. with (
  79. patch("backend.app.core.database.async_session", test_async_session),
  80. patch("backend.app.main.init_printer_connections", mock_init_printer_connections),
  81. ):
  82. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  83. yield client
  84. app.dependency_overrides.clear()
# ============================================================================
# Mock External Services
# ============================================================================
  88. @pytest.fixture
  89. def mock_tasmota_service():
  90. """Mock the Tasmota service for smart plug tests."""
  91. # Patch both the module where it's defined and where it's imported
  92. with (
  93. patch("backend.app.services.tasmota.tasmota_service") as mock,
  94. patch("backend.app.api.routes.smart_plugs.tasmota_service") as mock2,
  95. ):
  96. mock.turn_on = AsyncMock(return_value=True)
  97. mock.turn_off = AsyncMock(return_value=True)
  98. mock.toggle = AsyncMock(return_value=True)
  99. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test Plug"})
  100. mock.get_energy = AsyncMock(
  101. return_value={
  102. "power": 150.5,
  103. "voltage": 120.0,
  104. "current": 1.25,
  105. "today": 2.5,
  106. "total": 100.0,
  107. "factor": 0.95,
  108. }
  109. )
  110. mock.test_connection = AsyncMock(return_value={"success": True, "state": "ON", "device_name": "Test Plug"})
  111. # Copy mocks to second patch target
  112. mock2.turn_on = mock.turn_on
  113. mock2.turn_off = mock.turn_off
  114. mock2.toggle = mock.toggle
  115. mock2.get_status = mock.get_status
  116. mock2.get_energy = mock.get_energy
  117. mock2.test_connection = mock.test_connection
  118. yield mock
  119. @pytest.fixture
  120. def mock_homeassistant_service():
  121. """Mock the Home Assistant service for smart plug tests."""
  122. # Patch both the module where it's defined and where it's imported
  123. with (
  124. patch("backend.app.services.homeassistant.homeassistant_service") as mock,
  125. patch("backend.app.api.routes.smart_plugs.homeassistant_service") as mock2,
  126. ):
  127. mock.turn_on = AsyncMock(return_value=True)
  128. mock.turn_off = AsyncMock(return_value=True)
  129. mock.toggle = AsyncMock(return_value=True)
  130. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test HA Entity"})
  131. mock.get_energy = AsyncMock(return_value=None) # Most HA entities don't have power monitoring
  132. mock.test_connection = AsyncMock(return_value={"success": True, "message": "API running", "error": None})
  133. mock.list_entities = AsyncMock(
  134. return_value=[
  135. {
  136. "entity_id": "switch.printer_plug",
  137. "friendly_name": "Printer Plug",
  138. "state": "on",
  139. "domain": "switch",
  140. },
  141. {"entity_id": "switch.test", "friendly_name": "Test Switch", "state": "off", "domain": "switch"},
  142. ]
  143. )
  144. mock.configure = MagicMock()
  145. # Copy mocks to second patch target
  146. mock2.turn_on = mock.turn_on
  147. mock2.turn_off = mock.turn_off
  148. mock2.toggle = mock.toggle
  149. mock2.get_status = mock.get_status
  150. mock2.get_energy = mock.get_energy
  151. mock2.test_connection = mock.test_connection
  152. mock2.list_entities = mock.list_entities
  153. mock2.configure = mock.configure
  154. yield mock
  155. @pytest.fixture
  156. def mock_mqtt_client():
  157. """Mock the MQTT client for printer communication tests."""
  158. with patch("backend.app.services.bambu_mqtt.BambuMQTTClient") as mock:
  159. instance = MagicMock()
  160. instance.state = MagicMock(connected=True, state="IDLE", progress=0, temperatures={"nozzle": 25, "bed": 25})
  161. instance.connect = MagicMock()
  162. instance.disconnect = MagicMock()
  163. mock.return_value = instance
  164. yield mock
  165. @pytest.fixture
  166. def mock_ftp_client():
  167. """Mock the FTP client for file transfer tests."""
  168. with (
  169. patch("backend.app.services.bambu_ftp.download_file_async") as download_mock,
  170. patch("backend.app.services.bambu_ftp.list_files_async") as list_mock,
  171. ):
  172. download_mock.return_value = True
  173. list_mock.return_value = []
  174. yield {"download": download_mock, "list": list_mock}
  175. @pytest.fixture
  176. def mock_httpx_client():
  177. """Mock httpx for webhook/notification HTTP calls."""
  178. with patch("httpx.AsyncClient") as mock_class:
  179. mock_instance = AsyncMock()
  180. mock_response = MagicMock()
  181. mock_response.status_code = 200
  182. mock_response.text = "OK"
  183. mock_response.json.return_value = {}
  184. mock_instance.get = AsyncMock(return_value=mock_response)
  185. mock_instance.post = AsyncMock(return_value=mock_response)
  186. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  187. mock_instance.__aexit__ = AsyncMock()
  188. mock_class.return_value = mock_instance
  189. yield mock_instance
  190. @pytest.fixture
  191. def mock_printer_manager():
  192. """Mock the printer manager for status checks."""
  193. with patch("backend.app.services.printer_manager.printer_manager") as mock:
  194. mock.get_status = MagicMock(
  195. return_value=MagicMock(
  196. connected=True,
  197. state="IDLE",
  198. progress=0,
  199. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  200. raw_data={},
  201. )
  202. )
  203. mock.mark_printer_offline = MagicMock()
  204. yield mock
# ============================================================================
# Factory Fixtures for Test Data
# ============================================================================
  208. @pytest.fixture
  209. def smart_plug_factory(db_session):
  210. """Factory to create test smart plugs."""
  211. async def _create_plug(**kwargs):
  212. from backend.app.models.smart_plug import SmartPlug
  213. # Determine defaults based on plug_type
  214. plug_type = kwargs.get("plug_type", "tasmota")
  215. defaults = {
  216. "name": "Test Plug",
  217. "plug_type": plug_type,
  218. "enabled": True,
  219. "auto_on": True,
  220. "auto_off": True,
  221. "off_delay_mode": "time",
  222. "off_delay_minutes": 5,
  223. "off_temp_threshold": 70,
  224. "schedule_enabled": False,
  225. "power_alert_enabled": False,
  226. }
  227. # Set required fields based on plug_type
  228. if plug_type == "homeassistant":
  229. defaults["ha_entity_id"] = "switch.test"
  230. defaults["ip_address"] = None
  231. else:
  232. defaults["ip_address"] = "192.168.1.100"
  233. defaults["ha_entity_id"] = None
  234. defaults.update(kwargs)
  235. plug = SmartPlug(**defaults)
  236. db_session.add(plug)
  237. await db_session.commit()
  238. await db_session.refresh(plug)
  239. return plug
  240. return _create_plug
  241. @pytest.fixture
  242. def printer_factory(db_session):
  243. """Factory to create test printers."""
  244. _counter = [0] # Use list to allow mutation in nested function
  245. async def _create_printer(**kwargs):
  246. from backend.app.models.printer import Printer
  247. _counter[0] += 1
  248. counter = _counter[0]
  249. defaults = {
  250. "name": "Test Printer",
  251. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  252. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  253. "access_code": "12345678",
  254. "is_active": True,
  255. "auto_archive": True,
  256. "model": "X1C",
  257. }
  258. defaults.update(kwargs)
  259. printer = Printer(**defaults)
  260. db_session.add(printer)
  261. await db_session.commit()
  262. await db_session.refresh(printer)
  263. return printer
  264. return _create_printer
  265. @pytest.fixture
  266. def notification_provider_factory(db_session):
  267. """Factory to create test notification providers."""
  268. async def _create_provider(**kwargs):
  269. from backend.app.models.notification import NotificationProvider
  270. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  271. if isinstance(config, dict):
  272. config = json.dumps(config)
  273. defaults = {
  274. "name": "Test Provider",
  275. "provider_type": "ntfy",
  276. "enabled": True,
  277. "config": config,
  278. "on_print_start": True,
  279. "on_print_complete": True,
  280. "on_print_failed": True,
  281. "on_print_stopped": True,
  282. "on_print_progress": False,
  283. "on_printer_offline": False,
  284. "on_printer_error": False,
  285. "on_filament_low": False,
  286. "on_maintenance_due": False,
  287. "on_ams_humidity_high": False,
  288. "on_ams_temperature_high": False,
  289. "quiet_hours_enabled": False,
  290. "daily_digest_enabled": False,
  291. }
  292. defaults.update(kwargs)
  293. provider = NotificationProvider(**defaults)
  294. db_session.add(provider)
  295. await db_session.commit()
  296. await db_session.refresh(provider)
  297. return provider
  298. return _create_provider
  299. @pytest.fixture
  300. def archive_factory(db_session):
  301. """Factory to create test archives."""
  302. async def _create_archive(printer_id: int, **kwargs):
  303. from backend.app.models.archive import PrintArchive
  304. defaults = {
  305. "printer_id": printer_id,
  306. "filename": "test_print.gcode.3mf",
  307. "print_name": "Test Print",
  308. "file_path": "archives/test/test_print.gcode.3mf",
  309. "file_size": 1024000,
  310. "status": "completed",
  311. "filament_type": "PLA",
  312. "filament_used_grams": 50.0,
  313. "print_time_seconds": 3600,
  314. }
  315. defaults.update(kwargs)
  316. archive = PrintArchive(**defaults)
  317. db_session.add(archive)
  318. await db_session.commit()
  319. await db_session.refresh(archive)
  320. return archive
  321. return _create_archive
# ============================================================================
# Sample Data Fixtures
# ============================================================================
  325. @pytest.fixture
  326. def sample_mqtt_print_start():
  327. """Sample MQTT message for print start."""
  328. return {
  329. "print": {
  330. "command": "project_file",
  331. "param": "/sdcard/test.gcode.3mf",
  332. "subtask_name": "test_print",
  333. "gcode_state": "RUNNING",
  334. "mc_percent": 0,
  335. }
  336. }
  337. @pytest.fixture
  338. def sample_mqtt_print_complete():
  339. """Sample MQTT message for print complete."""
  340. return {
  341. "print": {
  342. "gcode_state": "FINISH",
  343. "mc_percent": 100,
  344. "subtask_name": "test_print",
  345. }
  346. }
  347. @pytest.fixture
  348. def sample_printer_status():
  349. """Sample printer status data."""
  350. return {
  351. "connected": True,
  352. "state": "IDLE",
  353. "progress": 0,
  354. "layer_num": 0,
  355. "total_layers": 0,
  356. "temperatures": {
  357. "nozzle": 25.0,
  358. "bed": 25.0,
  359. "chamber": 25.0,
  360. },
  361. "remaining_time": 0,
  362. "filename": None,
  363. }
# ============================================================================
# Log Capture Fixtures for Error Detection
# ============================================================================
  367. class LogCapture(logging.Handler):
  368. """Handler that captures log records for testing."""
  369. def __init__(self):
  370. super().__init__()
  371. self.records: list[logging.LogRecord] = []
  372. def emit(self, record: logging.LogRecord):
  373. self.records.append(record)
  374. def clear(self):
  375. self.records.clear()
  376. def get_errors(self) -> list[logging.LogRecord]:
  377. """Get all ERROR and CRITICAL level records."""
  378. return [r for r in self.records if r.levelno >= logging.ERROR]
  379. def get_warnings(self) -> list[logging.LogRecord]:
  380. """Get all WARNING level records."""
  381. return [r for r in self.records if r.levelno == logging.WARNING]
  382. def has_errors(self) -> bool:
  383. """Check if any errors were logged."""
  384. return len(self.get_errors()) > 0
  385. def format_errors(self) -> str:
  386. """Format all errors as a string for assertion messages."""
  387. errors = self.get_errors()
  388. if not errors:
  389. return "No errors"
  390. formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
  391. return "\n".join(formatter.format(r) for r in errors)
  392. @pytest.fixture
  393. def capture_logs():
  394. """Fixture that captures log output during a test.
  395. Usage:
  396. def test_something(capture_logs):
  397. # Do something that might log errors
  398. some_function()
  399. # Check no errors were logged
  400. assert not capture_logs.has_errors(), capture_logs.format_errors()
  401. """
  402. handler = LogCapture()
  403. handler.setLevel(logging.DEBUG)
  404. # Attach to root logger to capture all logs
  405. root_logger = logging.getLogger()
  406. root_logger.addHandler(handler)
  407. yield handler
  408. root_logger.removeHandler(handler)
  409. @pytest.fixture
  410. def assert_no_log_errors(capture_logs):
  411. """Fixture that automatically asserts no errors were logged.
  412. Usage:
  413. def test_something(assert_no_log_errors):
  414. # If any ERROR logs occur during this test, it will fail
  415. some_function()
  416. """
  417. yield capture_logs
  418. errors = capture_logs.get_errors()
  419. if errors:
  420. pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")