conftest.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. """Shared test fixtures for BamBuddy backend tests."""
  2. import asyncio
  3. import json
  4. import os
  5. import sys
  6. import pytest
  7. from typing import AsyncGenerator
  8. from datetime import datetime
  9. from unittest.mock import AsyncMock, MagicMock, patch
  10. # IMPORTANT: Set environment variables BEFORE any app imports
  11. # This must happen before settings/config are loaded
  12. os.environ["LOG_TO_FILE"] = "false"
  13. os.environ["DEBUG"] = "false"
  14. from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
  15. from httpx import AsyncClient, ASGITransport
  16. # Ensure settings use our env vars - import and override before database import
  17. from backend.app.core.config import settings
  18. settings.log_to_file = False
  19. from backend.app.core.database import Base
  20. # Use in-memory SQLite for tests
  21. TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
  22. @pytest.fixture(scope="session")
  23. def event_loop():
  24. """Create an instance of the default event loop for each test session."""
  25. loop = asyncio.get_event_loop_policy().new_event_loop()
  26. yield loop
  27. loop.close()
  28. @pytest.fixture
  29. async def test_engine():
  30. """Create a test database engine."""
  31. engine = create_async_engine(TEST_DATABASE_URL, echo=False)
  32. # Import all models to register them
  33. from backend.app.models import (
  34. printer, archive, filament, settings, smart_plug,
  35. print_queue, notification, maintenance, kprofile_note,
  36. notification_template, external_link, project, api_key,
  37. ams_history
  38. )
  39. async with engine.begin() as conn:
  40. await conn.run_sync(Base.metadata.create_all)
  41. yield engine
  42. async with engine.begin() as conn:
  43. await conn.run_sync(Base.metadata.drop_all)
  44. await engine.dispose()
  45. @pytest.fixture
  46. async def db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
  47. """Create a test database session."""
  48. async_session_maker = async_sessionmaker(
  49. test_engine, class_=AsyncSession, expire_on_commit=False
  50. )
  51. async with async_session_maker() as session:
  52. yield session
  53. @pytest.fixture
  54. async def async_client(test_engine, db_session) -> AsyncGenerator[AsyncClient, None]:
  55. """Create an async test client."""
  56. from backend.app.main import app
  57. from backend.app.core.database import get_db, async_session
  58. # Create a new session maker for the test engine
  59. test_async_session = async_sessionmaker(
  60. test_engine, class_=AsyncSession, expire_on_commit=False
  61. )
  62. async def override_get_db():
  63. async with test_async_session() as session:
  64. yield session
  65. app.dependency_overrides[get_db] = override_get_db
  66. # Also patch the module-level async_session used by services
  67. with patch('backend.app.core.database.async_session', test_async_session):
  68. async with AsyncClient(
  69. transport=ASGITransport(app=app),
  70. base_url="http://test"
  71. ) as client:
  72. yield client
  73. app.dependency_overrides.clear()
  74. # ============================================================================
  75. # Mock External Services
  76. # ============================================================================
  77. @pytest.fixture
  78. def mock_tasmota_service():
  79. """Mock the Tasmota service for smart plug tests."""
  80. # Patch both the module where it's defined and where it's imported
  81. with patch('backend.app.services.tasmota.tasmota_service') as mock, \
  82. patch('backend.app.api.routes.smart_plugs.tasmota_service') as mock2:
  83. mock.turn_on = AsyncMock(return_value=True)
  84. mock.turn_off = AsyncMock(return_value=True)
  85. mock.toggle = AsyncMock(return_value=True)
  86. mock.get_status = AsyncMock(return_value={
  87. "state": "ON",
  88. "reachable": True,
  89. "device_name": "Test Plug"
  90. })
  91. mock.get_energy = AsyncMock(return_value={
  92. "power": 150.5,
  93. "voltage": 120.0,
  94. "current": 1.25,
  95. "today": 2.5,
  96. "total": 100.0,
  97. "factor": 0.95,
  98. })
  99. mock.test_connection = AsyncMock(return_value={
  100. "success": True,
  101. "state": "ON",
  102. "device_name": "Test Plug"
  103. })
  104. # Copy mocks to second patch target
  105. mock2.turn_on = mock.turn_on
  106. mock2.turn_off = mock.turn_off
  107. mock2.toggle = mock.toggle
  108. mock2.get_status = mock.get_status
  109. mock2.get_energy = mock.get_energy
  110. mock2.test_connection = mock.test_connection
  111. yield mock
  112. @pytest.fixture
  113. def mock_mqtt_client():
  114. """Mock the MQTT client for printer communication tests."""
  115. with patch('backend.app.services.bambu_mqtt.BambuMQTTClient') as mock:
  116. instance = MagicMock()
  117. instance.state = MagicMock(
  118. connected=True,
  119. state="IDLE",
  120. progress=0,
  121. temperatures={"nozzle": 25, "bed": 25}
  122. )
  123. instance.connect = MagicMock()
  124. instance.disconnect = MagicMock()
  125. mock.return_value = instance
  126. yield mock
  127. @pytest.fixture
  128. def mock_ftp_client():
  129. """Mock the FTP client for file transfer tests."""
  130. with patch('backend.app.services.bambu_ftp.download_file_async') as download_mock, \
  131. patch('backend.app.services.bambu_ftp.list_files_async') as list_mock:
  132. download_mock.return_value = True
  133. list_mock.return_value = []
  134. yield {"download": download_mock, "list": list_mock}
  135. @pytest.fixture
  136. def mock_httpx_client():
  137. """Mock httpx for webhook/notification HTTP calls."""
  138. with patch('httpx.AsyncClient') as mock_class:
  139. mock_instance = AsyncMock()
  140. mock_response = MagicMock()
  141. mock_response.status_code = 200
  142. mock_response.text = "OK"
  143. mock_response.json.return_value = {}
  144. mock_instance.get = AsyncMock(return_value=mock_response)
  145. mock_instance.post = AsyncMock(return_value=mock_response)
  146. mock_instance.__aenter__ = AsyncMock(return_value=mock_instance)
  147. mock_instance.__aexit__ = AsyncMock()
  148. mock_class.return_value = mock_instance
  149. yield mock_instance
  150. @pytest.fixture
  151. def mock_printer_manager():
  152. """Mock the printer manager for status checks."""
  153. with patch('backend.app.services.printer_manager.printer_manager') as mock:
  154. mock.get_status = MagicMock(return_value=MagicMock(
  155. connected=True,
  156. state="IDLE",
  157. progress=0,
  158. temperatures={"nozzle": 25, "bed": 25, "chamber": 25},
  159. raw_data={}
  160. ))
  161. mock.mark_printer_offline = MagicMock()
  162. yield mock
  163. # ============================================================================
  164. # Factory Fixtures for Test Data
  165. # ============================================================================
  166. @pytest.fixture
  167. def smart_plug_factory(db_session):
  168. """Factory to create test smart plugs."""
  169. async def _create_plug(**kwargs):
  170. from backend.app.models.smart_plug import SmartPlug
  171. defaults = {
  172. "name": "Test Plug",
  173. "ip_address": "192.168.1.100",
  174. "enabled": True,
  175. "auto_on": True,
  176. "auto_off": True,
  177. "off_delay_mode": "time",
  178. "off_delay_minutes": 5,
  179. "off_temp_threshold": 70,
  180. "schedule_enabled": False,
  181. "power_alert_enabled": False,
  182. }
  183. defaults.update(kwargs)
  184. plug = SmartPlug(**defaults)
  185. db_session.add(plug)
  186. await db_session.commit()
  187. await db_session.refresh(plug)
  188. return plug
  189. return _create_plug
  190. @pytest.fixture
  191. def printer_factory(db_session):
  192. """Factory to create test printers."""
  193. _counter = [0] # Use list to allow mutation in nested function
  194. async def _create_printer(**kwargs):
  195. from backend.app.models.printer import Printer
  196. _counter[0] += 1
  197. counter = _counter[0]
  198. defaults = {
  199. "name": "Test Printer",
  200. "serial_number": f"00M09A{counter:09d}", # Unique serial per printer
  201. "ip_address": f"192.168.1.{100 + counter}", # Unique IP per printer
  202. "access_code": "12345678",
  203. "is_active": True,
  204. "auto_archive": True,
  205. "model": "X1C",
  206. }
  207. defaults.update(kwargs)
  208. printer = Printer(**defaults)
  209. db_session.add(printer)
  210. await db_session.commit()
  211. await db_session.refresh(printer)
  212. return printer
  213. return _create_printer
  214. @pytest.fixture
  215. def notification_provider_factory(db_session):
  216. """Factory to create test notification providers."""
  217. async def _create_provider(**kwargs):
  218. from backend.app.models.notification import NotificationProvider
  219. config = kwargs.pop("config", {"server": "https://ntfy.sh", "topic": "test-topic"})
  220. if isinstance(config, dict):
  221. config = json.dumps(config)
  222. defaults = {
  223. "name": "Test Provider",
  224. "provider_type": "ntfy",
  225. "enabled": True,
  226. "config": config,
  227. "on_print_start": True,
  228. "on_print_complete": True,
  229. "on_print_failed": True,
  230. "on_print_stopped": True,
  231. "on_print_progress": False,
  232. "on_printer_offline": False,
  233. "on_printer_error": False,
  234. "on_filament_low": False,
  235. "on_maintenance_due": False,
  236. "on_ams_humidity_high": False,
  237. "on_ams_temperature_high": False,
  238. "quiet_hours_enabled": False,
  239. "daily_digest_enabled": False,
  240. }
  241. defaults.update(kwargs)
  242. provider = NotificationProvider(**defaults)
  243. db_session.add(provider)
  244. await db_session.commit()
  245. await db_session.refresh(provider)
  246. return provider
  247. return _create_provider
  248. @pytest.fixture
  249. def archive_factory(db_session):
  250. """Factory to create test archives."""
  251. async def _create_archive(printer_id: int, **kwargs):
  252. from backend.app.models.archive import PrintArchive
  253. defaults = {
  254. "printer_id": printer_id,
  255. "filename": "test_print.gcode.3mf",
  256. "print_name": "Test Print",
  257. "file_path": "archives/test/test_print.gcode.3mf",
  258. "file_size": 1024000,
  259. "status": "completed",
  260. "filament_type": "PLA",
  261. "filament_used_grams": 50.0,
  262. "print_time_seconds": 3600,
  263. }
  264. defaults.update(kwargs)
  265. archive = PrintArchive(**defaults)
  266. db_session.add(archive)
  267. await db_session.commit()
  268. await db_session.refresh(archive)
  269. return archive
  270. return _create_archive
  271. # ============================================================================
  272. # Sample Data Fixtures
  273. # ============================================================================
  274. @pytest.fixture
  275. def sample_mqtt_print_start():
  276. """Sample MQTT message for print start."""
  277. return {
  278. "print": {
  279. "command": "project_file",
  280. "param": "/sdcard/test.gcode.3mf",
  281. "subtask_name": "test_print",
  282. "gcode_state": "RUNNING",
  283. "mc_percent": 0,
  284. }
  285. }
  286. @pytest.fixture
  287. def sample_mqtt_print_complete():
  288. """Sample MQTT message for print complete."""
  289. return {
  290. "print": {
  291. "gcode_state": "FINISH",
  292. "mc_percent": 100,
  293. "subtask_name": "test_print",
  294. }
  295. }
  296. @pytest.fixture
  297. def sample_printer_status():
  298. """Sample printer status data."""
  299. return {
  300. "connected": True,
  301. "state": "IDLE",
  302. "progress": 0,
  303. "layer_num": 0,
  304. "total_layers": 0,
  305. "temperatures": {
  306. "nozzle": 25.0,
  307. "bed": 25.0,
  308. "chamber": 25.0,
  309. },
  310. "remaining_time": 0,
  311. "filename": None,
  312. }