conftest.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. """Shared test fixtures for BamBuddy backend tests."""
  2. import asyncio
  3. import atexit
  4. import json
  5. import logging
  6. import os
  7. import shutil
  8. import tempfile
  9. from collections.abc import AsyncGenerator
  10. from pathlib import Path
  11. from unittest.mock import AsyncMock, MagicMock, patch
  12. import pytest
# IMPORTANT: Set environment variables BEFORE any app imports
# This must happen before settings/config are loaded, which is why the
# remaining imports in this file come after these assignments (and carry
# ``noqa: E402`` for "module level import not at top of file").
os.environ["LOG_TO_FILE"] = "false"
os.environ["DEBUG"] = "false"
from httpx import ASGITransport, AsyncClient  # noqa: E402
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine  # noqa: E402
# Ensure settings use our env vars - import and override before database import
from backend.app.core.config import settings  # noqa: E402
# Also force the already-constructed settings object, in addition to the env var above.
settings.log_to_file = False
# Use a temp directory for plate calibration to avoid deleting real calibration files
_test_plate_cal_dir = Path(tempfile.mkdtemp(prefix="bambuddy_test_plate_cal_"))
settings.plate_calibration_dir = _test_plate_cal_dir
  25. # Clean up temp directory when tests finish
  26. def _cleanup_test_plate_cal_dir():
  27. if _test_plate_cal_dir.exists():
  28. shutil.rmtree(_test_plate_cal_dir, ignore_errors=True)
  29. atexit.register(_cleanup_test_plate_cal_dir)
from backend.app.core.database import Base  # noqa: E402
# Use in-memory SQLite for tests (SQLAlchemy URL with the aiosqlite driver)
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  33. @pytest.fixture(autouse=True)
  34. def mfa_encryption_isolation(monkeypatch, tmp_path):
  35. """Per-test isolation for MFA encryption state.
  36. - Sets ``DATA_DIR`` to an isolated tmp path so the auto-bootstrap can
  37. never write ``.mfa_encryption_key`` into the repo or share state
  38. across tests / xdist workers.
  39. - Removes any inherited ``MFA_ENCRYPTION_KEY`` env var.
  40. - With ``DATA_DIR`` pointing at a writable ``tmp_path``, the default
  41. bootstrap path on first ``_get_fernet()`` call is **auto-generation**
  42. (key_source='generated'), NOT plaintext fallback. Tests that need the
  43. plaintext fallback path must monkeypatch ``_load_or_generate_key`` to
  44. return ``(None, 'none')`` (or 'none_write_failed' / 'none_corrupted')
  45. explicitly — see ``test_plaintext_passthrough_without_key`` for an
  46. example.
  47. - Resets the ``encryption`` module-level singletons before AND after the
  48. test so reorder doesn't leak cached Fernet instances.
  49. Tests that want to exercise an active key should call
  50. ``monkeypatch.setenv("MFA_ENCRYPTION_KEY", valid_key)`` and
  51. ``enc_mod._fernet_instance = None`` inside the test body — the autouse
  52. fixture only sets defaults, it doesn't lock them in.
  53. """
  54. from backend.app.core import encryption as enc_mod
  55. monkeypatch.setenv("DATA_DIR", str(tmp_path))
  56. monkeypatch.delenv("MFA_ENCRYPTION_KEY", raising=False)
  57. enc_mod._fernet_instance = None
  58. enc_mod._warn_shown = False
  59. enc_mod._key_source = None
  60. yield
  61. enc_mod._fernet_instance = None
  62. enc_mod._warn_shown = False
  63. enc_mod._key_source = None
  64. @pytest.fixture(scope="session")
  65. def event_loop():
  66. """Create an instance of the default event loop for each test session."""
  67. loop = asyncio.get_event_loop_policy().new_event_loop()
  68. yield loop
  69. # Dispose the module-level engine so aiosqlite worker threads finish
  70. # before the event loop closes, preventing "Event loop is closed" errors.
  71. from backend.app.core.database import engine
  72. loop.run_until_complete(engine.dispose())
  73. loop.run_until_complete(asyncio.sleep(0.05))
  74. loop.close()
@pytest.fixture
async def test_engine():
    """Create a test database engine.

    Builds an async engine on ``TEST_DATABASE_URL``, imports the model
    modules, creates all tables known to ``Base.metadata`` and yields the
    engine. After the test, the tables are dropped and the engine is disposed.
    """
    engine = create_async_engine(TEST_DATABASE_URL, echo=False)
    # Import all models to register them
    # (imported for that side effect only; the names are not used below)
    from backend.app.models import (
        ams_history,
        ams_label,
        api_key,
        archive,
        auth_ephemeral,
        color_catalog,
        external_link,
        filament,
        group,
        kprofile_note,
        maintenance,
        notification,
        notification_template,
        oidc_provider,
        print_log,
        print_queue,
        printer,
        project,
        project_bom,
        settings,
        slot_preset,
        smart_plug,
        smart_plug_energy_snapshot,  # noqa: F401
        spool,
        spool_assignment,
        spool_catalog,
        spool_k_profile,
        spool_usage_history,
        spoolbuddy_device,
        spoolman_k_profile,
        spoolman_slot_assignment,
        user,
        user_email_pref,
        user_otp_code,
        user_totp,
        virtual_printer,
    )
    # Setup: create the schema on the fresh engine.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    # Teardown: drop the schema, then release the engine's connections.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()
    # Allow aiosqlite's background thread to finish processing the close
    # response before the event loop shuts down, preventing
    # "RuntimeError: Event loop is closed" in call_soon_threadsafe.
    # NOTE(review): this comment originally said "per-function event loop",
    # but the event_loop fixture in this file is session-scoped — confirm.
    await asyncio.sleep(0.1)
  128. @pytest.fixture
  129. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  130. """Create a test database session."""
  131. async_session_maker = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  132. async with async_session_maker() as session:
  133. yield session
  134. @pytest.fixture
  135. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  136. """Create an async test client."""
  137. from backend.app.core.database import async_session, get_db
  138. from backend.app.main import app
  139. # Create a new session maker for the test engine
  140. test_async_session = async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
  141. async def override_get_db():
  142. async with test_async_session() as session:
  143. yield session
  144. app.dependency_overrides[get_db] = override_get_db
  145. # Mock init_printer_connections to prevent MQTT connection attempts during tests
  146. async def mock_init_printer_connections(db):
  147. pass # No-op - don't connect to real printers
  148. # Also patch the module-level async_session used by services, auth, and middleware
  149. with (
  150. patch("backend.app.core.database.async_session", test_async_session),
  151. patch("backend.app.core.auth.async_session", test_async_session),
  152. patch("backend.app.main.async_session", test_async_session),
  153. patch("backend.app.main.init_printer_connections", mock_init_printer_connections),
  154. ):
  155. # Seed default groups for tests that need them
  156. from backend.app.core.database import seed_default_groups
  157. await seed_default_groups()
  158. async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
  159. yield client
  160. # The app lifespan called init_db() which used the module-level engine
  161. # (not the test engine), creating aiosqlite connections. Dispose those
  162. # connections so their background threads finish before the event loop closes.
  163. from backend.app.core.database import engine as real_engine
  164. await real_engine.dispose()
  165. app.dependency_overrides.clear()
  166. # ============================================================================
  167. # Mock External Services
  168. # ============================================================================
  169. @pytest.fixture
  170. def mock_tasmota_service():
  171. """Mock the Tasmota service for smart plug tests."""
  172. # Patch both the module where it's defined and where it's imported
  173. with (
  174. patch("backend.app.services.tasmota.tasmota_service") as mock,
  175. patch("backend.app.api.routes.smart_plugs.tasmota_service") as mock2,
  176. ):
  177. mock.turn_on = AsyncMock(return_value=True)
  178. mock.turn_off = AsyncMock(return_value=True)
  179. mock.toggle = AsyncMock(return_value=True)
  180. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test Plug"})
  181. mock.get_energy = AsyncMock(
  182. return_value={
  183. "power": 150.5,
  184. "voltage": 120.0,
  185. "current": 1.25,
  186. "today": 2.5,
  187. "total": 100.0,
  188. "factor": 0.95,
  189. }
  190. )
  191. mock.test_connection = AsyncMock(return_value={"success": True, "state": "ON", "device_name": "Test Plug"})
  192. # Copy mocks to second patch target
  193. mock2.turn_on = mock.turn_on
  194. mock2.turn_off = mock.turn_off
  195. mock2.toggle = mock.toggle
  196. mock2.get_status = mock.get_status
  197. mock2.get_energy = mock.get_energy
  198. mock2.test_connection = mock.test_connection
  199. yield mock
  200. @pytest.fixture
  201. def mock_homeassistant_service():
  202. """Mock the Home Assistant service for smart plug tests."""
  203. # Patch both the module where it's defined and where it's imported
  204. with (
  205. patch("backend.app.services.homeassistant.homeassistant_service") as mock,
  206. patch("backend.app.api.routes.smart_plugs.homeassistant_service") as mock2,
  207. ):
  208. mock.turn_on = AsyncMock(return_value=True)
  209. mock.turn_off = AsyncMock(return_value=True)
  210. mock.toggle = AsyncMock(return_value=True)
  211. mock.get_status = AsyncMock(return_value={"state": "ON", "reachable": True, "device_name": "Test HA Entity"})
  212. mock.get_energy = AsyncMock(return_value=None) # Most HA entities don't have power monitoring
  213. mock.test_connection = AsyncMock(return_value={"success": True, "message": "API running", "error": None})
  214. mock.list_entities = AsyncMock(
  215. return_value=[
  216. {
  217. "entity_id": "switch.printer_plug",
  218. "friendly_name": "Printer Plug",
  219. "state": "on",
  220. "domain": "switch",
  221. },
  222. {"entity_id": "switch.test", "friendly_name": "Test Switch", "state": "off", "domain": "switch"},
  223. ]
  224. )
  225. mock.configure = MagicMock()
  226. # Copy mocks to second patch target
  227. mock2.turn_on = mock.turn_on
  228. mock2.turn_off = mock.turn_off
  229. mock2.toggle = mock.toggle
  230. mock2.get_status = mock.get_status
  231. mock2.get_energy = mock.get_energy
  232. mock2.test_connection = mock.test_connection
  233. mock2.list_entities = mock.list_entities
  234. mock2.configure = mock.configure
  235. yield mock
  236. @pytest.fixture
  237. def mock_mqtt_client():
  238. """Mock the MQTT client for printer communication tests."""
  239. with patch("backend.app.services.bambu_mqtt.BambuMQTTClient") as mock:
  240. instance = MagicMock()
  241. instance.state = MagicMock(connected=True, state="IDLE", progress=0, temperatures={"nozzle": 25, "bed": 25})
  242. instance.connect = MagicMock()
  243. instance.disconnect = MagicMock()
  244. mock.return_value = instance
  245. yield mock
  246. @pytest.fixture
  247. def mock_mqtt_smart_plug_service():
  248. """Mock the MQTT smart plug service for MQTT plug tests."""
  249. with patch("backend.app.api.routes.smart_plugs.mqtt_relay") as mock:
  250. # Create a mock smart_plug_service
  251. mock_service = MagicMock()
  252. mock_service.is_configured = MagicMock(return_value=True)
  253. mock_service.has_broker_settings = MagicMock(return_value=True)
  254. mock_service.configure = AsyncMock(return_value=True)
  255. mock_service.subscribe = MagicMock()
  256. mock_service.unsubscribe = MagicMock()
  257. mock_service.get_plug_data = MagicMock(return_value=None)
  258. mock_service.is_reachable = MagicMock(return_value=False)
  259. mock.smart_plug_service = mock_service
  260. yield mock
  261. @pytest.fixture
  262. def mock_ftp_client():
  263. """Mock the FTP client for file transfer tests."""
  264. with (
  265. patch("backend.app.services.bambu_ftp.download_file_async") as download_mock,
  266. patch("backend.app.services.bambu_ftp.list_files_async") as list_mock,
  267. ):
  268. download_mock.return_value = True
  269. list_mock.return_value = []
  270. yield {"download": download_mock, "list": list_mock}
  271. @pytest.fixture
  272. def mock_httpx_client():
  273. """Mock httpx for webhook/notification HTTP calls."""
  274. with patch("httpx.AsyncClient") as mock_class:
  275. mock_instance = AsyncMock()
  276. mock_response = MagicMock()
  277. mock_response.status_code = 200
  278. mock_response.text = "OK"
  279. mock_response.json.return_value = {}
  280. mock_instance.get = AsyncMock(return_value=mock_response)
  281. mock_instance.post = AsyncMock(return_value=mock_response)
  282. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  283. mock_instance.__aexit__ = AsyncMock()
  284. mock_class.return_value = mock_instance
  285. yield mock_instance
  286. @pytest.fixture
  287. def mock_printer_manager():
  288. """Mock the printer manager for status checks."""
  289. with patch("backend.app.services.printer_manager.printer_manager") as mock:
  290. mock.get_status = MagicMock(
  291. return_value=MagicMock(
  292. connected=True,
  293. state="IDLE",
  294. progress=0,
  295. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  296. raw_data={},
  297. )
  298. )
  299. mock.mark_printer_offline = MagicMock()
  300. yield mock
  301. # ============================================================================
  302. # Factory Fixtures for Test Data
  303. # ============================================================================
  304. @pytest.fixture
  305. def smart_plug_factory(db_session):
  306. """Factory to create test smart plugs."""
  307. async def _create_plug(**kwargs):
  308. from backend.app.models.smart_plug import SmartPlug
  309. # Determine defaults based on plug_type
  310. plug_type = kwargs.get("plug_type", "tasmota")
  311. defaults = {
  312. "name": "Test Plug",
  313. "plug_type": plug_type,
  314. "enabled": True,
  315. "auto_on": True,
  316. "auto_off": True,
  317. "off_delay_mode": "time",
  318. "off_delay_minutes": 5,
  319. "off_temp_threshold": 70,
  320. "schedule_enabled": False,
  321. "power_alert_enabled": False,
  322. }
  323. # Set required fields based on plug_type
  324. if plug_type == "homeassistant":
  325. defaults["ha_entity_id"] = "switch.test"
  326. defaults["ip_address"] = None
  327. elif plug_type == "mqtt":
  328. # Legacy fields (for backward compatibility tests)
  329. defaults["mqtt_topic"] = kwargs.get("mqtt_topic", "test/topic")
  330. defaults["mqtt_multiplier"] = kwargs.get("mqtt_multiplier", 1.0)
  331. # New separate topic/path/multiplier fields
  332. defaults["mqtt_power_topic"] = kwargs.get("mqtt_power_topic")
  333. defaults["mqtt_power_path"] = kwargs.get("mqtt_power_path", "power")
  334. defaults["mqtt_power_multiplier"] = kwargs.get("mqtt_power_multiplier", 1.0)
  335. defaults["mqtt_energy_topic"] = kwargs.get("mqtt_energy_topic")
  336. defaults["mqtt_energy_path"] = kwargs.get("mqtt_energy_path")
  337. defaults["mqtt_energy_multiplier"] = kwargs.get("mqtt_energy_multiplier", 1.0)
  338. defaults["mqtt_state_topic"] = kwargs.get("mqtt_state_topic")
  339. defaults["mqtt_state_path"] = kwargs.get("mqtt_state_path")
  340. defaults["mqtt_state_on_value"] = kwargs.get("mqtt_state_on_value")
  341. defaults["ip_address"] = None
  342. defaults["ha_entity_id"] = None
  343. elif plug_type == "rest":
  344. defaults["rest_on_url"] = kwargs.get("rest_on_url", "http://192.168.1.100/api/plug/on")
  345. defaults["rest_off_url"] = kwargs.get("rest_off_url", "http://192.168.1.100/api/plug/off")
  346. defaults["rest_method"] = kwargs.get("rest_method", "POST")
  347. defaults["ip_address"] = None
  348. defaults["ha_entity_id"] = None
  349. else:
  350. defaults["ip_address"] = "192.168.1.100"
  351. defaults["ha_entity_id"] = None
  352. defaults.update(kwargs)
  353. plug = SmartPlug(**defaults)
  354. db_session.add(plug)
  355. await db_session.commit()
  356. await db_session.refresh(plug)
  357. return plug
  358. return _create_plug
  359. @pytest.fixture
  360. def printer_factory(db_session):
  361. """Factory to create test printers."""
  362. _counter = [0] # Use list to allow mutation in nested function
  363. async def _create_printer(**kwargs):
  364. from backend.app.models.printer import Printer
  365. _counter[0] += 1
  366. counter = _counter[0]
  367. defaults = {
  368. "name": "Test Printer",
  369. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  370. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  371. "access_code": "12345678",
  372. "is_active": True,
  373. "auto_archive": True,
  374. "model": "X1C",
  375. }
  376. defaults.update(kwargs)
  377. printer = Printer(**defaults)
  378. db_session.add(printer)
  379. await db_session.commit()
  380. await db_session.refresh(printer)
  381. return printer
  382. return _create_printer
  383. @pytest.fixture
  384. def notification_provider_factory(db_session):
  385. """Factory to create test notification providers."""
  386. async def _create_provider(**kwargs):
  387. from backend.app.models.notification import NotificationProvider
  388. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  389. if isinstance(config, dict):
  390. config = json.dumps(config)
  391. defaults = {
  392. "name": "Test Provider",
  393. "provider_type": "ntfy",
  394. "enabled": True,
  395. "config": config,
  396. "on_print_start": True,
  397. "on_print_complete": True,
  398. "on_print_failed": True,
  399. "on_print_stopped": True,
  400. "on_print_progress": False,
  401. "on_print_missing_spool_assignment": False,
  402. "on_printer_offline": False,
  403. "on_printer_error": False,
  404. "on_filament_low": False,
  405. "on_maintenance_due": False,
  406. "on_ams_humidity_high": False,
  407. "on_ams_temperature_high": False,
  408. "on_bed_cooled": False,
  409. "quiet_hours_enabled": False,
  410. "daily_digest_enabled": False,
  411. }
  412. defaults.update(kwargs)
  413. provider = NotificationProvider(**defaults)
  414. db_session.add(provider)
  415. await db_session.commit()
  416. await db_session.refresh(provider)
  417. return provider
  418. return _create_provider
  419. @pytest.fixture
  420. def archive_factory(db_session):
  421. """Factory to create test archives.
  422. Also synthesizes one PrintLogEntry per archive (matching the production
  423. flow where statistics are aggregated from PrintLogEntry, not PrintArchive,
  424. per #1378). Pass ``with_run=False`` to skip — useful for testing the
  425. "archived but never printed" state. Pass ``run_status=...`` to override
  426. the run's status independently of the archive's status field.
  427. """
  428. async def _create_archive(printer_id: int, **kwargs):
  429. from backend.app.models.archive import PrintArchive
  430. from backend.app.models.print_log import PrintLogEntry
  431. with_run = kwargs.pop("with_run", True)
  432. run_status = kwargs.pop("run_status", None)
  433. defaults = {
  434. "printer_id": printer_id,
  435. "filename": "test_print.gcode.3mf",
  436. "print_name": "Test Print",
  437. "file_path": "archives/test/test_print.gcode.3mf",
  438. "file_size": 1024000,
  439. "status": "completed",
  440. "filament_type": "PLA",
  441. "filament_used_grams": 50.0,
  442. "print_time_seconds": 3600,
  443. }
  444. defaults.update(kwargs)
  445. archive = PrintArchive(**defaults)
  446. db_session.add(archive)
  447. await db_session.commit()
  448. await db_session.refresh(archive)
  449. if with_run:
  450. duration = None
  451. if archive.started_at and archive.completed_at:
  452. duration = int((archive.completed_at - archive.started_at).total_seconds()) or None
  453. run = PrintLogEntry(
  454. archive_id=archive.id,
  455. printer_id=archive.printer_id,
  456. status=run_status or archive.status,
  457. started_at=archive.started_at,
  458. completed_at=archive.completed_at,
  459. duration_seconds=duration,
  460. filament_type=archive.filament_type,
  461. filament_color=archive.filament_color,
  462. filament_used_grams=archive.filament_used_grams,
  463. cost=archive.cost,
  464. energy_kwh=archive.energy_kwh,
  465. energy_cost=archive.energy_cost,
  466. failure_reason=archive.failure_reason,
  467. print_name=archive.print_name,
  468. created_by_id=archive.created_by_id,
  469. )
  470. db_session.add(run)
  471. await db_session.commit()
  472. return archive
  473. return _create_archive
  474. # ============================================================================
  475. # Sample Data Fixtures
  476. # ============================================================================
  477. @pytest.fixture
  478. def sample_mqtt_print_start():
  479. """Sample MQTT message for print start."""
  480. return {
  481. "print": {
  482. "command": "project_file",
  483. "param": "/sdcard/test.gcode.3mf",
  484. "subtask_name": "test_print",
  485. "gcode_state": "RUNNING",
  486. "mc_percent": 0,
  487. }
  488. }
  489. @pytest.fixture
  490. def sample_mqtt_print_complete():
  491. """Sample MQTT message for print complete."""
  492. return {
  493. "print": {
  494. "gcode_state": "FINISH",
  495. "mc_percent": 100,
  496. "subtask_name": "test_print",
  497. }
  498. }
  499. @pytest.fixture
  500. def sample_printer_status():
  501. """Sample printer status data."""
  502. return {
  503. "connected": True,
  504. "state": "IDLE",
  505. "progress": 0,
  506. "layer_num": 0,
  507. "total_layers": 0,
  508. "temperatures": {
  509. "nozzle": 25.0,
  510. "bed": 25.0,
  511. "chamber": 25.0,
  512. },
  513. "remaining_time": 0,
  514. "filename": None,
  515. }
  516. # ============================================================================
  517. # Log Capture Fixtures for Error Detection
  518. # ============================================================================
  519. class LogCapture(logging.Handler):
  520. """Handler that captures log records for testing."""
  521. def __init__(self):
  522. super().__init__()
  523. self.records: list[logging.LogRecord] = []
  524. def emit(self, record: logging.LogRecord):
  525. self.records.append(record)
  526. def clear(self):
  527. self.records.clear()
  528. def get_errors(self) -> list[logging.LogRecord]:
  529. """Get all ERROR and CRITICAL level records."""
  530. return [r for r in self.records if r.levelno >= logging.ERROR]
  531. def get_warnings(self) -> list[logging.LogRecord]:
  532. """Get all WARNING level records."""
  533. return [r for r in self.records if r.levelno == logging.WARNING]
  534. def has_errors(self) -> bool:
  535. """Check if any errors were logged."""
  536. return len(self.get_errors()) > 0
  537. def format_errors(self) -> str:
  538. """Format all errors as a string for assertion messages."""
  539. errors = self.get_errors()
  540. if not errors:
  541. return "No errors"
  542. formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
  543. return "\n".join(formatter.format(r) for r in errors)
  544. @pytest.fixture
  545. def capture_logs():
  546. """Fixture that captures log output during a test.
  547. Usage:
  548. def test_something(capture_logs):
  549. # Do something that might log errors
  550. some_function()
  551. # Check no errors were logged
  552. assert not capture_logs.has_errors(), capture_logs.format_errors()
  553. """
  554. handler = LogCapture()
  555. handler.setLevel(logging.DEBUG)
  556. # Attach to root logger to capture all logs
  557. root_logger = logging.getLogger()
  558. root_logger.addHandler(handler)
  559. yield handler
  560. root_logger.removeHandler(handler)
  561. @pytest.fixture
  562. def assert_no_log_errors(capture_logs):
  563. """Fixture that automatically asserts no errors were logged.
  564. Usage:
  565. def test_something(assert_no_log_errors):
  566. # If any ERROR logs occur during this test, it will fail
  567. some_function()
  568. """
  569. yield capture_logs
  570. errors = capture_logs.get_errors()
  571. if errors:
  572. pytest.fail(f"Unexpected log errors:\n{capture_logs.format_errors()}")