conftest.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. """Shared test fixtures for BamBuddy backend tests."""
  2. import asyncio
  3. import json
  4. import logging
  5. import os
  6. import sys
  7. import pytest
  8. from typing import AsyncGenerator
  9. from datetime import datetime
  10. from unittest.mock import AsyncMock, MagicMock, patch
  11. # IMPORTANT: Set environment variables BEFORE any app imports
  12. # This must happen before settings/config are loaded
  13. os.environ["LOG_TO_FILE"] = "false"
  14. os.environ["DEBUG"] = "false"
  15. from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
  16. from httpx import AsyncClient, ASGITransport
# Ensure settings use our env vars - import and override before database import.
# log_to_file is also forced off on the settings object itself, presumably in
# case the config object was built before the env vars above took effect —
# TODO confirm.
from backend.app.core.config import settings  # noqa: E402
settings.log_to_file = False
# Base.metadata is what test_engine uses to create and drop the test tables.
from backend.app.core.database import Base  # noqa: E402
# Use in-memory SQLite for tests (async aiosqlite driver).
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  23. @pytest.fixture(scope="session")
  24. def event_loop():
  25. """Create an instance of the default event loop for each test session."""
  26. loop = asyncio.get_event_loop_policy().new_event_loop()
  27. yield loop
  28. loop.close()
  29. @pytest.fixture
  30. async def test_engine():
  31. """Create a test database engine."""
  32. engine = create_async_engine(TEST_DATABASE_URL, echo=False)
  33. # Import all models to register them
  34. from backend.app.models import (
  35. printer, archive, filament, settings, smart_plug,
  36. print_queue, notification, maintenance, kprofile_note,
  37. notification_template, external_link, project, api_key,
  38. ams_history
  39. )
  40. async with engine.begin() as conn:
  41. await conn.run_sync(Base.metadata.create_all)
  42. yield engine
  43. async with engine.begin() as conn:
  44. await conn.run_sync(Base.metadata.drop_all)
  45. await engine.dispose()
  46. @pytest.fixture
  47. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  48. """Create a test database session."""
  49. async_session_maker = async_sessionmaker(
  50. test_engine, class_=AsyncSession, expire_on_commit=False
  51. )
  52. async with async_session_maker() as session:
  53. yield session
  54. @pytest.fixture
  55. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  56. """Create an async test client."""
  57. from backend.app.main import app
  58. from backend.app.core.database import get_db, async_session
  59. # Create a new session maker for the test engine
  60. test_async_session = async_sessionmaker(
  61. test_engine, class_=AsyncSession, expire_on_commit=False
  62. )
  63. async def override_get_db():
  64. async with test_async_session() as session:
  65. yield session
  66. app.dependency_overrides[get_db] = override_get_db
  67. # Also patch the module-level async_session used by services
  68. with patch('backend.app.core.database.async_session', test_async_session):
  69. async with AsyncClient(
  70. transport=ASGITransport(app=app),
  71. base_url="http://test"
  72. ) as client:
  73. yield client
  74. app.dependency_overrides.clear()
  75. # ============================================================================
  76. # Mock External Services
  77. # ============================================================================
  78. @pytest.fixture
  79. def mock_tasmota_service():
  80. """Mock the Tasmota service for smart plug tests."""
  81. # Patch both the module where it's defined and where it's imported
  82. with patch('backend.app.services.tasmota.tasmota_service') as mock, \
  83. patch('backend.app.api.routes.smart_plugs.tasmota_service') as mock2:
  84. mock.turn_on = AsyncMock(return_value=True)
  85. mock.turn_off = AsyncMock(return_value=True)
  86. mock.toggle = AsyncMock(return_value=True)
  87. mock.get_status = AsyncMock(return_value={
  88. "state": "ON",
  89. "reachable": True,
  90. "device_name": "Test Plug"
  91. })
  92. mock.get_energy = AsyncMock(return_value={
  93. "power": 150.5,
  94. "voltage": 120.0,
  95. "current": 1.25,
  96. "today": 2.5,
  97. "total": 100.0,
  98. "factor": 0.95,
  99. })
  100. mock.test_connection = AsyncMock(return_value={
  101. "success": True,
  102. "state": "ON",
  103. "device_name": "Test Plug"
  104. })
  105. # Copy mocks to second patch target
  106. mock2.turn_on = mock.turn_on
  107. mock2.turn_off = mock.turn_off
  108. mock2.toggle = mock.toggle
  109. mock2.get_status = mock.get_status
  110. mock2.get_energy = mock.get_energy
  111. mock2.test_connection = mock.test_connection
  112. yield mock
  113. @pytest.fixture
  114. def mock_mqtt_client():
  115. """Mock the MQTT client for printer communication tests."""
  116. with patch('backend.app.services.bambu_mqtt.BambuMQTTClient') as mock:
  117. instance = MagicMock()
  118. instance.state = MagicMock(
  119. connected=True,
  120. state="IDLE",
  121. progress=0,
  122. temperatures={"nozzle": 25, "bed": 25}
  123. )
  124. instance.connect = MagicMock()
  125. instance.disconnect = MagicMock()
  126. mock.return_value = instance
  127. yield mock
  128. @pytest.fixture
  129. def mock_ftp_client():
  130. """Mock the FTP client for file transfer tests."""
  131. with patch('backend.app.services.bambu_ftp.download_file_async') as download_mock, \
  132. patch('backend.app.services.bambu_ftp.list_files_async') as list_mock:
  133. download_mock.return_value = True
  134. list_mock.return_value = []
  135. yield {"download": download_mock, "list": list_mock}
  136. @pytest.fixture
  137. def mock_httpx_client():
  138. """Mock httpx for webhook/notification HTTP calls."""
  139. with patch('httpx.AsyncClient') as mock_class:
  140. mock_instance = AsyncMock()
  141. mock_response = MagicMock()
  142. mock_response.status_code = 200
  143. mock_response.text = "OK"
  144. mock_response.json.return_value = {}
  145. mock_instance.get = AsyncMock(return_value=mock_response)
  146. mock_instance.post = AsyncMock(return_value=mock_response)
  147. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  148. mock_instance.__aexit__ = AsyncMock()
  149. mock_class.return_value = mock_instance
  150. yield mock_instance
  151. @pytest.fixture
  152. def mock_printer_manager():
  153. """Mock the printer manager for status checks."""
  154. with patch('backend.app.services.printer_manager.printer_manager') as mock:
  155. mock.get_status = MagicMock(return_value=MagicMock(
  156. connected=True,
  157. state="IDLE",
  158. progress=0,
  159. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  160. raw_data={}
  161. ))
  162. mock.mark_printer_offline = MagicMock()
  163. yield mock
  164. # ============================================================================
  165. # Factory Fixtures for Test Data
  166. # ============================================================================
  167. @pytest.fixture
  168. def smart_plug_factory(db_session):
  169. """Factory to create test smart plugs."""
  170. async def _create_plug(**kwargs):
  171. from backend.app.models.smart_plug import SmartPlug
  172. defaults = {
  173. "name": "Test Plug",
  174. "ip_address": "192.168.1.100",
  175. "enabled": True,
  176. "auto_on": True,
  177. "auto_off": True,
  178. "off_delay_mode": "time",
  179. "off_delay_minutes": 5,
  180. "off_temp_threshold": 70,
  181. "schedule_enabled": False,
  182. "power_alert_enabled": False,
  183. }
  184. defaults.update(kwargs)
  185. plug = SmartPlug(**defaults)
  186. db_session.add(plug)
  187. await db_session.commit()
  188. await db_session.refresh(plug)
  189. return plug
  190. return _create_plug
  191. @pytest.fixture
  192. def printer_factory(db_session):
  193. """Factory to create test printers."""
  194. _counter = [0] # Use list to allow mutation in nested function
  195. async def _create_printer(**kwargs):
  196. from backend.app.models.printer import Printer
  197. _counter[0] += 1
  198. counter = _counter[0]
  199. defaults = {
  200. "name": "Test Printer",
  201. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  202. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  203. "access_code": "12345678",
  204. "is_active": True,
  205. "auto_archive": True,
  206. "model": "X1C",
  207. }
  208. defaults.update(kwargs)
  209. printer = Printer(**defaults)
  210. db_session.add(printer)
  211. await db_session.commit()
  212. await db_session.refresh(printer)
  213. return printer
  214. return _create_printer
  215. @pytest.fixture
  216. def notification_provider_factory(db_session):
  217. """Factory to create test notification providers."""
  218. async def _create_provider(**kwargs):
  219. from backend.app.models.notification import NotificationProvider
  220. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  221. if isinstance(config, dict):
  222. config = json.dumps(config)
  223. defaults = {
  224. "name": "Test Provider",
  225. "provider_type": "ntfy",
  226. "enabled": True,
  227. "config": config,
  228. "on_print_start": True,
  229. "on_print_complete": True,
  230. "on_print_failed": True,
  231. "on_print_stopped": True,
  232. "on_print_progress": False,
  233. "on_printer_offline": False,
  234. "on_printer_error": False,
  235. "on_filament_low": False,
  236. "on_maintenance_due": False,
  237. "on_ams_humidity_high": False,
  238. "on_ams_temperature_high": False,
  239. "quiet_hours_enabled": False,
  240. "daily_digest_enabled": False,
  241. }
  242. defaults.update(kwargs)
  243. provider = NotificationProvider(**defaults)
  244. db_session.add(provider)
  245. await db_session.commit()
  246. await db_session.refresh(provider)
  247. return provider
  248. return _create_provider
  249. @pytest.fixture
  250. def archive_factory(db_session):
  251. """Factory to create test archives."""
  252. async def _create_archive(printer_id: int, **kwargs):
  253. from backend.app.models.archive import PrintArchive
  254. defaults = {
  255. "printer_id": printer_id,
  256. "filename": "test_print.gcode.3mf",
  257. "print_name": "Test Print",
  258. "file_path": "archives/test/test_print.gcode.3mf",
  259. "file_size": 1024000,
  260. "status": "completed",
  261. "filament_type": "PLA",
  262. "filament_used_grams": 50.0,
  263. "print_time_seconds": 3600,
  264. }
  265. defaults.update(kwargs)
  266. archive = PrintArchive(**defaults)
  267. db_session.add(archive)
  268. await db_session.commit()
  269. await db_session.refresh(archive)
  270. return archive
  271. return _create_archive
  272. # ============================================================================
  273. # Sample Data Fixtures
  274. # ============================================================================
  275. @pytest.fixture
  276. def sample_mqtt_print_start():
  277. """Sample MQTT message for print start."""
  278. return {
  279. "print": {
  280. "command": "project_file",
  281. "param": "/sdcard/test.gcode.3mf",
  282. "subtask_name": "test_print",
  283. "gcode_state": "RUNNING",
  284. "mc_percent": 0,
  285. }
  286. }
  287. @pytest.fixture
  288. def sample_mqtt_print_complete():
  289. """Sample MQTT message for print complete."""
  290. return {
  291. "print": {
  292. "gcode_state": "FINISH",
  293. "mc_percent": 100,
  294. "subtask_name": "test_print",
  295. }
  296. }
  297. @pytest.fixture
  298. def sample_printer_status():
  299. """Sample printer status data."""
  300. return {
  301. "connected": True,
  302. "state": "IDLE",
  303. "progress": 0,
  304. "layer_num": 0,
  305. "total_layers": 0,
  306. "temperatures": {
  307. "nozzle": 25.0,
  308. "bed": 25.0,
  309. "chamber": 25.0,
  310. },
  311. "remaining_time": 0,
  312. "filename": None,
  313. }
  314. # ============================================================================
  315. # Log Capture Fixtures for Error Detection
  316. # ============================================================================
  317. class LogCapture(logging.Handler):
  318. """Handler that captures log records for testing."""
  319. def __init__(self):
  320. super().__init__()
  321. self.records: list[logging.LogRecord] = []
  322. def emit(self, record: logging.LogRecord):
  323. self.records.append(record)
  324. def clear(self):
  325. self.records.clear()
  326. def get_errors(self) -> list[logging.LogRecord]:
  327. """Get all ERROR and CRITICAL level records."""
  328. return [r for r in self.records if r.levelno >= logging.ERROR]
  329. def get_warnings(self) -> list[logging.LogRecord]:
  330. """Get all WARNING level records."""
  331. return [r for r in self.records if r.levelno == logging.WARNING]
  332. def has_errors(self) -> bool:
  333. """Check if any errors were logged."""
  334. return len(self.get_errors()) > 0
  335. def format_errors(self) -> str:
  336. """Format all errors as a string for assertion messages."""
  337. errors = self.get_errors()
  338. if not errors:
  339. return "No errors"
  340. formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
  341. return "\n".join(formatter.format(r) for r in errors)
  342. @pytest.fixture
  343. def capture_logs():
  344. """Fixture that captures log output during a test.
  345. Usage:
  346. def test_something(capture_logs):
  347. # Do something that might log errors
  348. some_function()
  349. # Check no errors were logged
  350. assert not capture_logs.has_errors(), capture_logs.format_errors()
  351. """
  352. handler = LogCapture()
  353. handler.setLevel(logging.DEBUG)
  354. # Attach to root logger to capture all logs
  355. root_logger = logging.getLogger()
  356. root_logger.addHandler(handler)
  357. yield handler
  358. root_logger.removeHandler(handler)
  359. @pytest.fixture
  360. def assert_no_log_errors(capture_logs):
  361. """Fixture that automatically asserts no errors were logged.
  362. Usage:
  363. def test_something(assert_no_log_errors):
  364. # If any ERROR logs occur during this test, it will fail
  365. some_function()
  366. """
  367. yield capture_logs
  368. errors = capture_logs.get_errors()
  369. if errors:
  370. pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")