# test_db_dialect.py
"""Unit tests for database dialect helpers and PostgreSQL compatibility."""
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
  4. class TestDialectDetection:
  5. """Test is_sqlite() and is_postgres() detection."""
  6. def test_sqlite_detected(self):
  7. with patch("backend.app.core.config.settings") as mock_settings:
  8. mock_settings.database_url = "sqlite+aiosqlite:///path/to/db.sqlite"
  9. from backend.app.core.db_dialect import is_postgres, is_sqlite
  10. assert is_sqlite() is True
  11. assert is_postgres() is False
  12. def test_postgres_detected(self):
  13. with patch("backend.app.core.config.settings") as mock_settings:
  14. mock_settings.database_url = "postgresql+asyncpg://user:pass@host:5432/db"
  15. from backend.app.core.db_dialect import is_postgres, is_sqlite
  16. assert is_postgres() is True
  17. assert is_sqlite() is False
  18. class TestRunPragma:
  19. """Test that PRAGMAs only run on SQLite."""
  20. @pytest.mark.asyncio
  21. async def test_pragma_runs_on_sqlite(self):
  22. with patch("backend.app.core.db_dialect.is_sqlite", return_value=True):
  23. from backend.app.core.db_dialect import run_pragma
  24. mock_conn = AsyncMock()
  25. await run_pragma(mock_conn, "PRAGMA journal_mode = WAL")
  26. mock_conn.execute.assert_called_once()
  27. @pytest.mark.asyncio
  28. async def test_pragma_skipped_on_postgres(self):
  29. with patch("backend.app.core.db_dialect.is_sqlite", return_value=False):
  30. from backend.app.core.db_dialect import run_pragma
  31. mock_conn = AsyncMock()
  32. await run_pragma(mock_conn, "PRAGMA journal_mode = WAL")
  33. mock_conn.execute.assert_not_called()
  34. class TestTimezoneStripping:
  35. """Test that the before_cursor_execute event strips timezone info."""
  36. def test_strip_aware_datetime(self):
  37. """Verify the timezone stripping logic works correctly."""
  38. import datetime
  39. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  40. naive = aware.replace(tzinfo=None)
  41. def _strip(val):
  42. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  43. return val.replace(tzinfo=None)
  44. return val
  45. assert _strip(aware) == naive
  46. assert _strip(aware).tzinfo is None
  47. assert _strip(naive) == naive
  48. assert _strip("not a datetime") == "not a datetime"
  49. assert _strip(None) is None
  50. def test_strip_in_dict_params(self):
  51. """Verify timezone stripping works on dict parameters."""
  52. import datetime
  53. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  54. def _strip(val):
  55. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  56. return val.replace(tzinfo=None)
  57. return val
  58. params = {"name": "test", "created_at": aware, "count": 5}
  59. result = {k: _strip(v) for k, v in params.items()}
  60. assert result["created_at"].tzinfo is None
  61. assert result["name"] == "test"
  62. assert result["count"] == 5
  63. def test_strip_in_tuple_params(self):
  64. """Verify timezone stripping works on tuple parameters."""
  65. import datetime
  66. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  67. def _strip(val):
  68. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  69. return val.replace(tzinfo=None)
  70. return val
  71. params = ("test", aware, 5)
  72. result = tuple(_strip(v) for v in params)
  73. assert result[1].tzinfo is None
  74. assert result[0] == "test"
  75. def test_naive_datetime_unchanged(self):
  76. """Naive datetimes should pass through untouched."""
  77. import datetime
  78. naive = datetime.datetime(2026, 4, 3, 10, 0, 0)
  79. def _strip(val):
  80. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  81. return val.replace(tzinfo=None)
  82. return val
  83. result = _strip(naive)
  84. assert result == naive
  85. assert result.tzinfo is None
  86. class TestCrossDatabaseConversion:
  87. """Test SQLite→Postgres type conversion logic used in cross-database import."""
  88. def test_boolean_conversion(self):
  89. """SQLite stores booleans as 0/1, Postgres needs Python bool."""
  90. assert bool(0) is False
  91. assert bool(1) is True
  92. def test_datetime_string_conversion(self):
  93. """SQLite stores datetimes as strings, Postgres needs datetime objects."""
  94. from datetime import datetime
  95. val = "2026-04-02 11:01:52.105147"
  96. result = datetime.fromisoformat(val)
  97. assert result.year == 2026
  98. assert result.month == 4
  99. assert result.microsecond == 105147
  100. def test_datetime_with_timezone_string(self):
  101. """SQLite may store timezone-aware strings."""
  102. from datetime import datetime
  103. val = "2026-04-02T11:01:52+00:00"
  104. result = datetime.fromisoformat(val)
  105. assert result.year == 2026
  106. def test_json_serialization_for_backup(self):
  107. """JSON/list/dict values must be serialized for SQLite backup."""
  108. import json
  109. values = [{"key": "val"}, [1, 2, 3], "plain string", 42, None]
  110. for val in values:
  111. if isinstance(val, (list, dict)):
  112. serialized = json.dumps(val)
  113. assert isinstance(serialized, str)
  114. else:
  115. assert val == val # noqa: PLR0124 — no conversion needed
  116. class TestSafeExecutePattern:
  117. """Test _safe_execute error handling logic."""
  118. def test_safe_execute_catches_expected_exceptions(self):
  119. """Verify _safe_execute catches both OperationalError and ProgrammingError."""
  120. from sqlalchemy.exc import OperationalError, ProgrammingError
  121. for exc_type in (OperationalError, ProgrammingError):
  122. try:
  123. raise exc_type("test", [], Exception("column already exists"))
  124. except (OperationalError, ProgrammingError):
  125. pass
  126. def test_safe_execute_would_not_catch_integrity_error(self):
  127. """IntegrityError should NOT be caught by _safe_execute."""
  128. from sqlalchemy.exc import IntegrityError, OperationalError, ProgrammingError
  129. with pytest.raises(IntegrityError):
  130. try:
  131. raise IntegrityError("test", [], Exception("unique violation"))
  132. except (OperationalError, ProgrammingError):
  133. pass
  134. @pytest.mark.asyncio
  135. async def test_safe_execute_reraises_non_idempotency_errors(self):
  136. """Non-idempotency errors must propagate so startup fails loudly."""
  137. from sqlalchemy.exc import OperationalError
  138. from sqlalchemy.ext.asyncio import create_async_engine
  139. from backend.app.core.database import _safe_execute
  140. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  141. async with engine.begin() as conn:
  142. with pytest.raises(OperationalError):
  143. await _safe_execute(conn, "SELECT * FROM nonexistent_table_xyz")
  144. await engine.dispose()
  145. @pytest.mark.asyncio
  146. async def test_safe_execute_swallows_already_exists(self):
  147. """Idempotency errors (already exists) must be silently ignored."""
  148. from sqlalchemy import text
  149. from sqlalchemy.ext.asyncio import create_async_engine
  150. from backend.app.core.database import _safe_execute
  151. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  152. async with engine.begin() as conn:
  153. await conn.execute(text("CREATE TABLE t (id INTEGER)"))
  154. # Second CREATE must not raise
  155. await _safe_execute(conn, "CREATE TABLE t (id INTEGER)")
  156. await engine.dispose()
  157. @pytest.mark.asyncio
  158. async def test_provider_email_lowercasing_migration(self):
  159. """SEC-3: provider_email normalisation lowers mixed-case values, leaves NULL intact.
  160. The production migration runs this UPDATE directly (not via _safe_execute)
  161. so any failure is always fatal and visible at startup.
  162. """
  163. from sqlalchemy import text
  164. from sqlalchemy.ext.asyncio import create_async_engine
  165. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  166. async with engine.begin() as conn:
  167. await conn.execute(text("CREATE TABLE user_oidc_links (id INTEGER PRIMARY KEY, provider_email TEXT)"))
  168. await conn.execute(text("INSERT INTO user_oidc_links VALUES (1, 'User@Example.COM')"))
  169. await conn.execute(text("INSERT INTO user_oidc_links VALUES (2, 'already@lower.com')"))
  170. await conn.execute(text("INSERT INTO user_oidc_links VALUES (3, NULL)"))
  171. async with conn.begin_nested():
  172. await conn.execute(
  173. text(
  174. "UPDATE user_oidc_links SET provider_email = LOWER(provider_email) "
  175. "WHERE provider_email IS NOT NULL AND provider_email != LOWER(provider_email)"
  176. )
  177. )
  178. result = await conn.execute(text("SELECT provider_email FROM user_oidc_links ORDER BY id"))
  179. rows = [r[0] for r in result.fetchall()]
  180. await engine.dispose()
  181. assert rows[0] == "user@example.com"
  182. assert rows[1] == "already@lower.com"
  183. assert rows[2] is None
  184. @pytest.mark.asyncio
  185. async def test_safe_execute_swallows_no_such_column_for_rename(self):
  186. """'no such column' is swallowed for RENAME COLUMN idempotency.
  187. When a column has already been renamed, re-running the RENAME COLUMN
  188. migration raises 'no such column' — that must be silently swallowed.
  189. DML safety is guaranteed by never passing DML through _safe_execute.
  190. """
  191. from sqlalchemy.ext.asyncio import create_async_engine
  192. from backend.app.core.database import _safe_execute
  193. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  194. async with engine.begin() as conn:
  195. await conn.execute(__import__("sqlalchemy").text("CREATE TABLE t (id INTEGER, new_col INTEGER)"))
  196. # Column 'old_col' does not exist — simulates re-running a RENAME COLUMN migration
  197. # Must NOT raise.
  198. await _safe_execute(conn, "ALTER TABLE t RENAME COLUMN old_col TO new_col")
  199. await engine.dispose()
  200. @pytest.mark.asyncio
  201. async def test_safe_execute_swallows_does_not_exist_for_rename_postgres(self):
  202. """'does not exist' (PostgreSQL UndefinedColumnError) is swallowed for RENAME COLUMN idempotency."""
  203. from unittest.mock import AsyncMock, MagicMock
  204. from sqlalchemy.exc import ProgrammingError
  205. from backend.app.core.database import _safe_execute
  206. fake_exc = ProgrammingError('column "quantity_printed" does not exist', [], Exception())
  207. nested_cm = MagicMock()
  208. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  209. nested_cm.execute = AsyncMock(side_effect=fake_exc)
  210. nested_cm.__aexit__ = AsyncMock(return_value=False)
  211. mock_conn = MagicMock()
  212. mock_conn.begin_nested.return_value = nested_cm
  213. mock_conn.execute = AsyncMock(side_effect=fake_exc)
  214. # Must NOT raise — "does not exist" is in the swallow-list
  215. await _safe_execute(
  216. mock_conn, "ALTER TABLE project_bom_items RENAME COLUMN quantity_printed TO quantity_acquired"
  217. )
  218. @pytest.mark.asyncio
  219. async def test_safe_execute_swallows_duplicate_key(self):
  220. """'duplicate key' errors (PostgreSQL unique-constraint violations on re-run)
  221. must be silently swallowed for idempotent DDL migrations."""
  222. from unittest.mock import AsyncMock, MagicMock
  223. from sqlalchemy.exc import OperationalError
  224. from backend.app.core.database import _safe_execute
  225. fake_exc = OperationalError("duplicate key value violates unique constraint", [], Exception())
  226. # begin_nested() is called synchronously (not awaited) and returns an
  227. # async context manager. Use MagicMock so the call returns a regular
  228. # object, then attach __aenter__/__aexit__ for the async with protocol.
  229. nested_cm = MagicMock()
  230. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  231. # Raise on execute inside the context, simulating PG duplicate key
  232. nested_cm.execute = AsyncMock(side_effect=fake_exc)
  233. nested_cm.__aexit__ = AsyncMock(return_value=False)
  234. mock_conn = MagicMock()
  235. mock_conn.begin_nested.return_value = nested_cm
  236. mock_conn.execute = AsyncMock(side_effect=fake_exc)
  237. # Must NOT raise — "duplicate key" is in the swallow-list
  238. await _safe_execute(mock_conn, "CREATE UNIQUE INDEX ...")
  239. @pytest.mark.asyncio
  240. async def test_check_constraint_false_true_on_sqlite(self):
  241. """New constraint formula is enforced on SQLite (3.23+).
  242. New: auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE
  243. Blocks Fall B (auto_link=1 + email_claim='email' + require_ev=0).
  244. Allows Fall A (email_claim='email' + require_ev=1) and Fall C (custom claim).
  245. """
  246. from sqlalchemy import text
  247. from sqlalchemy.exc import IntegrityError
  248. from sqlalchemy.ext.asyncio import create_async_engine
  249. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  250. async with engine.begin() as conn:
  251. await conn.execute(
  252. text("""
  253. CREATE TABLE ck_test (
  254. id INTEGER PRIMARY KEY,
  255. auto_link BOOLEAN,
  256. require_ev BOOLEAN,
  257. email_claim TEXT,
  258. CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)
  259. )
  260. """)
  261. )
  262. # Valid: auto_link=0 (FALSE) — any combo allowed
  263. await conn.execute(text("INSERT INTO ck_test VALUES (1, 0, 0, 'upn')"))
  264. # Valid: Fall A — auto_link=1, require_ev=1, email_claim='email'
  265. await conn.execute(text("INSERT INTO ck_test VALUES (2, 1, 1, 'email')"))
  266. # Valid: Fall C — auto_link=1, email_claim='upn' (require_ev irrelevant)
  267. await conn.execute(text("INSERT INTO ck_test VALUES (3, 1, 0, 'upn')"))
  268. await conn.execute(text("INSERT INTO ck_test VALUES (4, 1, 1, 'upn')"))
  269. async with engine.begin() as conn:
  270. # Invalid: Fall B — auto_link=1 + email_claim='email' + require_ev=0
  271. with pytest.raises(IntegrityError):
  272. await conn.execute(text("INSERT INTO ck_test VALUES (5, 1, 0, 'email')"))
  273. await engine.dispose()
  274. @pytest.mark.asyncio
  275. async def test_auto_link_sec1_backfill_resets_unsafe_rows(self):
  276. """SEC-1 backfill resets auto_link=TRUE only for Fall B (email_claim='email' + require_ev=FALSE).
  277. Three cases:
  278. 1. auto_link=TRUE + email_claim='email' + require_ev=FALSE → reset to FALSE (Fall B, unsafe)
  279. 2. auto_link=TRUE + custom claim + require_ev=TRUE → unchanged (Fall C, now allowed)
  280. 3. auto_link=TRUE + email_claim='email' + require_ev=TRUE → unchanged (Fall A, safe)
  281. """
  282. from sqlalchemy import text
  283. from sqlalchemy.ext.asyncio import create_async_engine
  284. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  285. async with engine.begin() as conn:
  286. await conn.execute(
  287. text(
  288. "CREATE TABLE oidc_providers ("
  289. "id INTEGER PRIMARY KEY, "
  290. "auto_link_existing_accounts BOOLEAN, "
  291. "require_email_verified BOOLEAN, "
  292. "email_claim TEXT"
  293. ")"
  294. )
  295. )
  296. # Row 1: Fall B — email_claim='email' + require_ev=FALSE → must be reset
  297. await conn.execute(text("INSERT INTO oidc_providers VALUES (1, 1, 0, 'email')"))
  298. # Row 2: Fall C — custom claim → must NOT be reset (now allowed)
  299. await conn.execute(text("INSERT INTO oidc_providers VALUES (2, 1, 1, 'preferred_username')"))
  300. # Row 3: Fall A — email_claim='email' + require_ev=TRUE → must NOT be reset (always safe)
  301. await conn.execute(text("INSERT INTO oidc_providers VALUES (3, 1, 1, 'email')"))
  302. async with conn.begin_nested():
  303. await conn.execute(
  304. text(
  305. "UPDATE oidc_providers SET auto_link_existing_accounts = FALSE "
  306. "WHERE auto_link_existing_accounts = TRUE "
  307. "AND email_claim = 'email' AND require_email_verified = FALSE"
  308. )
  309. )
  310. result = await conn.execute(text("SELECT id, auto_link_existing_accounts FROM oidc_providers ORDER BY id"))
  311. rows = {r[0]: r[1] for r in result.fetchall()}
  312. await engine.dispose()
  313. assert rows[1] == 0, "Fall B (require_ev=FALSE) must be reset to FALSE"
  314. assert rows[2] == 1, "Fall C (custom claim) must remain TRUE"
  315. assert rows[3] == 1, "Fall A (require_ev=TRUE) must remain TRUE"
  316. @pytest.mark.asyncio
  317. async def test_safe_execute_reraises_does_not_exist_without_column(self):
  318. """'does not exist' without 'column' in the message must NOT be swallowed.
  319. This verifies that the narrowing from the broad 'does not exist' substring
  320. to the compound RENAME-COLUMN-only guard works correctly. A missing-relation
  321. error must propagate so the operator sees a startup failure rather than a
  322. silent schema gap.
  323. """
  324. from unittest.mock import AsyncMock, MagicMock
  325. from sqlalchemy.exc import ProgrammingError
  326. from backend.app.core.database import _safe_execute
  327. # PostgreSQL error for a missing relation — contains "does not exist" but NOT "column"
  328. fake_exc = ProgrammingError('relation "oidc_providers" does not exist', [], Exception())
  329. nested_cm = MagicMock()
  330. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  331. nested_cm.execute = AsyncMock(side_effect=fake_exc)
  332. nested_cm.__aexit__ = AsyncMock(return_value=False)
  333. mock_conn = MagicMock()
  334. mock_conn.begin_nested.return_value = nested_cm
  335. mock_conn.execute = AsyncMock(side_effect=fake_exc)
  336. # Must RAISE — "column" is absent so this is not RENAME COLUMN idempotency
  337. with pytest.raises(ProgrammingError):
  338. await _safe_execute(
  339. mock_conn, "ALTER TABLE oidc_providers ADD COLUMN auto_link_existing_accounts BOOLEAN DEFAULT false"
  340. )
  341. @pytest.mark.asyncio
  342. async def test_oidc_boolean_default_migrations_sqlite_defaults(self):
  343. """auto_link defaults to 0 (FALSE) and require_email_verified defaults to 1 (TRUE) on SQLite.
  344. Verifies that the SQLite branch of the BOOLEAN DEFAULT dialect-branch uses
  345. the correct integer literals so new rows get safe defaults without explicit values.
  346. """
  347. from sqlalchemy import text
  348. from sqlalchemy.ext.asyncio import create_async_engine
  349. from backend.app.core.database import _safe_execute
  350. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  351. async with engine.begin() as conn:
  352. await conn.execute(text("CREATE TABLE oidc_providers (id INTEGER PRIMARY KEY, name TEXT)"))
  353. await _safe_execute(
  354. conn, "ALTER TABLE oidc_providers ADD COLUMN auto_link_existing_accounts BOOLEAN DEFAULT 0"
  355. )
  356. await _safe_execute(conn, "ALTER TABLE oidc_providers ADD COLUMN require_email_verified BOOLEAN DEFAULT 1")
  357. await conn.execute(text("INSERT INTO oidc_providers (id, name) VALUES (1, 'test')"))
  358. result = await conn.execute(
  359. text("SELECT auto_link_existing_accounts, require_email_verified FROM oidc_providers WHERE id = 1")
  360. )
  361. row = result.fetchone()
  362. await engine.dispose()
  363. assert row[0] == 0, "auto_link_existing_accounts must default to 0 (FALSE) on SQLite"
  364. assert row[1] == 1, "require_email_verified must default to 1 (TRUE) on SQLite"
  365. @pytest.mark.asyncio
  366. async def test_safe_execute_column_not_exists_only_swallowed_for_rename(self):
  367. """'column … does not exist' is swallowed only when the SQL is RENAME COLUMN.
  368. The compound guard must NOT swallow the same error pattern when the SQL is
  369. an ADD COLUMN statement — that would indicate schema corruption, not idempotency.
  370. """
  371. from unittest.mock import AsyncMock, MagicMock
  372. from sqlalchemy.exc import ProgrammingError
  373. from backend.app.core.database import _safe_execute
  374. fake_exc = ProgrammingError('column "auto_link_existing_accounts" does not exist', [], Exception())
  375. nested_cm = MagicMock()
  376. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  377. nested_cm.execute = AsyncMock(side_effect=fake_exc)
  378. nested_cm.__aexit__ = AsyncMock(return_value=False)
  379. mock_conn = MagicMock()
  380. mock_conn.begin_nested.return_value = nested_cm
  381. mock_conn.execute = AsyncMock(side_effect=fake_exc)
  382. # ADD COLUMN statement — must RAISE even though message contains "column" + "does not exist"
  383. with pytest.raises(ProgrammingError):
  384. await _safe_execute(
  385. mock_conn, "ALTER TABLE oidc_providers ADD COLUMN auto_link_existing_accounts BOOLEAN DEFAULT false"
  386. )
  387. # RENAME COLUMN statement — must NOT raise (idempotency)
  388. await _safe_execute(
  389. mock_conn, "ALTER TABLE oidc_providers RENAME COLUMN auto_link_existing_accounts TO auto_link"
  390. )
  391. @pytest.mark.asyncio
  392. async def test_normalize_printer_ids_sqlite_uses_plain_comparison(self):
  393. """SQLite path executes plain string comparison (no cast)."""
  394. from sqlalchemy import text
  395. from sqlalchemy.ext.asyncio import create_async_engine
  396. from backend.app.core.database import _migrate_normalize_printer_ids
  397. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  398. async with engine.begin() as conn:
  399. await conn.execute(text("CREATE TABLE api_keys (id INTEGER PRIMARY KEY, printer_ids TEXT)"))
  400. await conn.execute(text("INSERT INTO api_keys VALUES (1, '[]')"))
  401. await conn.execute(text("INSERT INTO api_keys VALUES (2, '[1,2]')"))
  402. with patch("backend.app.core.database.is_sqlite", return_value=True):
  403. await _migrate_normalize_printer_ids(conn)
  404. result = await conn.execute(text("SELECT id, printer_ids FROM api_keys ORDER BY id"))
  405. rows = {r[0]: r[1] for r in result.fetchall()}
  406. await engine.dispose()
  407. assert rows[1] is None, "printer_ids='[]' must be normalised to NULL"
  408. assert rows[2] == "[1,2]", "non-empty printer_ids must be unchanged"
  409. @pytest.mark.asyncio
  410. async def test_normalize_printer_ids_postgres_uses_text_cast(self):
  411. """PostgreSQL path casts printer_ids to text for comparison (works for json and jsonb)."""
  412. from unittest.mock import AsyncMock, MagicMock
  413. from backend.app.core.database import _migrate_normalize_printer_ids
  414. nested_cm = MagicMock()
  415. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  416. nested_cm.__aexit__ = AsyncMock(return_value=False)
  417. mock_conn = MagicMock()
  418. mock_conn.begin_nested.return_value = nested_cm
  419. mock_conn.execute = AsyncMock()
  420. with patch("backend.app.core.database.is_sqlite", return_value=False):
  421. await _migrate_normalize_printer_ids(mock_conn)
  422. sql = mock_conn.execute.call_args[0][0].text
  423. assert "::text = '[]'" in sql, f"Expected ::text cast in SQL, got: {sql}"
  424. assert "printer_ids" in sql
  425. class TestSpoolmanTableDialect:
  426. """Phase 1: active_print_spoolman and spool_usage_history use dialect-correct DDL.
  427. These tables were created with raw 'INTEGER PRIMARY KEY AUTOINCREMENT' (SQLite-only
  428. syntax) before the fix. Now they branch on is_sqlite() exactly like
  429. smart_plug_energy_snapshots.
  430. """
  431. @pytest.mark.asyncio
  432. async def test_active_print_spoolman_sqlite_creates_table(self):
  433. """SQLite: active_print_spoolman is created with valid SQLite DDL."""
  434. from sqlalchemy import text
  435. from sqlalchemy.ext.asyncio import create_async_engine
  436. from backend.app.core.database import _safe_execute
  437. sql = """
  438. CREATE TABLE IF NOT EXISTS active_print_spoolman (
  439. id INTEGER PRIMARY KEY AUTOINCREMENT,
  440. printer_id INTEGER NOT NULL,
  441. archive_id INTEGER NOT NULL,
  442. filament_usage TEXT NOT NULL,
  443. ams_trays TEXT NOT NULL,
  444. slot_to_tray TEXT,
  445. layer_usage TEXT,
  446. filament_properties TEXT,
  447. UNIQUE(printer_id, archive_id)
  448. )
  449. """
  450. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  451. async with engine.begin() as conn:
  452. await _safe_execute(conn, sql)
  453. result = await conn.execute(
  454. text("SELECT name FROM sqlite_master WHERE type='table' AND name='active_print_spoolman'")
  455. )
  456. assert result.fetchone() is not None, "Table must be created on SQLite"
  457. await engine.dispose()
  458. @pytest.mark.asyncio
  459. async def test_active_print_spoolman_postgres_sql_uses_serial(self):
  460. """PostgreSQL: active_print_spoolman SQL uses SERIAL PRIMARY KEY, not AUTOINCREMENT."""
  461. from unittest.mock import AsyncMock, MagicMock
  462. from backend.app.core.database import _safe_execute
  463. captured_sql: list[str] = []
  464. nested_cm = MagicMock()
  465. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  466. nested_cm.__aexit__ = AsyncMock(return_value=False)
  467. async def capturing_execute(sql_or_text, *args, **kwargs):
  468. captured_sql.append(str(sql_or_text))
  469. nested_cm.execute = AsyncMock(side_effect=capturing_execute)
  470. mock_conn = MagicMock()
  471. mock_conn.begin_nested.return_value = nested_cm
  472. mock_conn.execute = AsyncMock(side_effect=capturing_execute)
  473. # PG path SQL — same string as in run_migrations() when is_sqlite() is False
  474. pg_sql = """
  475. CREATE TABLE IF NOT EXISTS active_print_spoolman (
  476. id SERIAL PRIMARY KEY,
  477. printer_id INTEGER NOT NULL REFERENCES printers(id) ON DELETE CASCADE,
  478. archive_id INTEGER NOT NULL REFERENCES print_archives(id) ON DELETE CASCADE,
  479. filament_usage TEXT NOT NULL,
  480. ams_trays TEXT NOT NULL,
  481. slot_to_tray TEXT,
  482. layer_usage TEXT,
  483. filament_properties TEXT,
  484. UNIQUE(printer_id, archive_id)
  485. )
  486. """
  487. await _safe_execute(mock_conn, pg_sql)
  488. assert captured_sql, "execute must have been called"
  489. combined = " ".join(captured_sql)
  490. assert "SERIAL PRIMARY KEY" in combined
  491. assert "AUTOINCREMENT" not in combined
  492. @pytest.mark.asyncio
  493. async def test_spool_usage_history_sqlite_creates_table(self):
  494. """SQLite: spool_usage_history is created with valid SQLite DDL."""
  495. from sqlalchemy import text
  496. from sqlalchemy.ext.asyncio import create_async_engine
  497. from backend.app.core.database import _safe_execute
  498. sql = """
  499. CREATE TABLE IF NOT EXISTS spool_usage_history (
  500. id INTEGER PRIMARY KEY AUTOINCREMENT,
  501. spool_id INTEGER NOT NULL,
  502. printer_id INTEGER,
  503. print_name VARCHAR(500),
  504. weight_used REAL NOT NULL DEFAULT 0,
  505. percent_used INTEGER NOT NULL DEFAULT 0,
  506. status VARCHAR(20) NOT NULL DEFAULT 'completed',
  507. created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
  508. )
  509. """
  510. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  511. async with engine.begin() as conn:
  512. await _safe_execute(conn, sql)
  513. result = await conn.execute(
  514. text("SELECT name FROM sqlite_master WHERE type='table' AND name='spool_usage_history'")
  515. )
  516. assert result.fetchone() is not None, "Table must be created on SQLite"
  517. await engine.dispose()
  518. @pytest.mark.asyncio
  519. async def test_spool_usage_history_postgres_sql_uses_serial_and_timestamp(self):
  520. """PostgreSQL: spool_usage_history SQL uses SERIAL and TIMESTAMP, not AUTOINCREMENT/DATETIME."""
  521. from unittest.mock import AsyncMock, MagicMock
  522. from backend.app.core.database import _safe_execute
  523. captured_sql: list[str] = []
  524. nested_cm = MagicMock()
  525. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  526. nested_cm.__aexit__ = AsyncMock(return_value=False)
  527. async def capturing_execute(sql_or_text, *args, **kwargs):
  528. captured_sql.append(str(sql_or_text))
  529. nested_cm.execute = AsyncMock(side_effect=capturing_execute)
  530. mock_conn = MagicMock()
  531. mock_conn.begin_nested.return_value = nested_cm
  532. mock_conn.execute = AsyncMock(side_effect=capturing_execute)
  533. pg_sql = """
  534. CREATE TABLE IF NOT EXISTS spool_usage_history (
  535. id SERIAL PRIMARY KEY,
  536. spool_id INTEGER NOT NULL REFERENCES spool(id) ON DELETE CASCADE,
  537. printer_id INTEGER REFERENCES printers(id) ON DELETE SET NULL,
  538. print_name VARCHAR(500),
  539. weight_used REAL NOT NULL DEFAULT 0,
  540. percent_used INTEGER NOT NULL DEFAULT 0,
  541. status VARCHAR(20) NOT NULL DEFAULT 'completed',
  542. created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
  543. )
  544. """
  545. await _safe_execute(mock_conn, pg_sql)
  546. assert captured_sql, "execute must have been called"
  547. combined = " ".join(captured_sql)
  548. assert "SERIAL PRIMARY KEY" in combined
  549. assert "TIMESTAMP" in combined
  550. assert "AUTOINCREMENT" not in combined
  551. assert "DATETIME" not in combined
  552. class TestAutoLinkConstraintMigration:
  553. """Tests for _migrate_update_auto_link_constraint (Fall C / Azure support)."""
  554. @pytest.mark.asyncio
  555. async def test_new_constraint_allows_fall_c_sqlite(self):
  556. """New formula allows auto_link=TRUE with a custom claim (Fall C)."""
  557. from sqlalchemy import text
  558. from sqlalchemy.exc import IntegrityError
  559. from sqlalchemy.ext.asyncio import create_async_engine
  560. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  561. async with engine.begin() as conn:
  562. await conn.execute(
  563. text(
  564. "CREATE TABLE oidc_providers_ck ("
  565. "id INTEGER PRIMARY KEY, "
  566. "auto_link BOOLEAN, "
  567. "require_ev BOOLEAN, "
  568. "email_claim TEXT, "
  569. "CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)"
  570. ")"
  571. )
  572. )
  573. # Fall C: custom claim + auto_link + require_ev=FALSE must pass
  574. await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (1, 1, 0, 'upn')"))
  575. # Fall C: custom claim + auto_link + require_ev=TRUE must pass
  576. await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (2, 1, 1, 'preferred_username')"))
  577. await engine.dispose()
  578. @pytest.mark.asyncio
  579. async def test_new_constraint_blocks_fall_b_sqlite(self):
  580. """New formula still blocks Fall B (email_claim='email' + require_ev=FALSE + auto_link=TRUE)."""
  581. from sqlalchemy import text
  582. from sqlalchemy.exc import IntegrityError
  583. from sqlalchemy.ext.asyncio import create_async_engine
  584. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  585. async with engine.begin() as conn:
  586. await conn.execute(
  587. text(
  588. "CREATE TABLE oidc_providers_ck ("
  589. "id INTEGER PRIMARY KEY, "
  590. "auto_link BOOLEAN, "
  591. "require_ev BOOLEAN, "
  592. "email_claim TEXT, "
  593. "CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)"
  594. ")"
  595. )
  596. )
  597. async with engine.begin() as conn:
  598. with pytest.raises(IntegrityError):
  599. await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (1, 1, 0, 'email')"))
  600. await engine.dispose()
  601. @pytest.mark.asyncio
  602. async def test_constraint_migration_sqlite_recreates_table(self):
  603. """SQLite path recreates oidc_providers with new constraint when old formula is present."""
  604. from sqlalchemy import text
  605. from sqlalchemy.ext.asyncio import create_async_engine
  606. from backend.app.core.database import _migrate_update_auto_link_constraint
  607. # Create table with old constraint formula
  608. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  609. async with engine.begin() as conn:
  610. await conn.execute(
  611. text(
  612. "CREATE TABLE oidc_providers ("
  613. "id INTEGER NOT NULL PRIMARY KEY, "
  614. "name VARCHAR(100) NOT NULL UNIQUE, "
  615. "issuer_url VARCHAR(500) NOT NULL, "
  616. "client_id VARCHAR(255) NOT NULL, "
  617. "client_secret VARCHAR(512) NOT NULL, "
  618. "scopes VARCHAR(500), "
  619. "is_enabled BOOLEAN, "
  620. "auto_create_users BOOLEAN, "
  621. "auto_link_existing_accounts BOOLEAN DEFAULT 0, "
  622. "email_claim VARCHAR(64) DEFAULT 'email', "
  623. "require_email_verified BOOLEAN DEFAULT 1, "
  624. "icon_url TEXT, "
  625. "created_at DATETIME DEFAULT CURRENT_TIMESTAMP, "
  626. "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, "
  627. "CONSTRAINT ck_auto_link_requires_verified_email_claim "
  628. "CHECK (auto_link_existing_accounts = FALSE OR "
  629. "(require_email_verified = TRUE AND email_claim = 'email'))"
  630. ")"
  631. )
  632. )
  633. await conn.execute(
  634. text(
  635. "INSERT INTO oidc_providers (id, name, issuer_url, client_id, client_secret, "
  636. "scopes, is_enabled, auto_create_users, auto_link_existing_accounts, "
  637. "email_claim, require_email_verified, icon_url, created_at, updated_at) "
  638. "VALUES (1, 'TestIdP', 'https://idp.test', 'cid', 'secret', "
  639. "'openid email', 1, 0, 0, 'email', 1, NULL, "
  640. "CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
  641. )
  642. )
  643. async with engine.begin() as conn:
  644. with patch("backend.app.core.database.is_sqlite", return_value=True):
  645. await _migrate_update_auto_link_constraint(conn)
  646. # Verify data survived
  647. result = await conn.execute(text("SELECT id, name FROM oidc_providers"))
  648. rows = result.fetchall()
  649. assert len(rows) == 1
  650. assert rows[0][0] == 1
  651. # Verify new constraint: Fall C (auto_link=TRUE + custom claim) must now be insertable
  652. await conn.execute(
  653. text(
  654. "INSERT INTO oidc_providers (id, name, issuer_url, client_id, client_secret, "
  655. "scopes, is_enabled, auto_create_users, auto_link_existing_accounts, "
  656. "email_claim, require_email_verified, icon_url, created_at, updated_at) "
  657. "VALUES (2, 'AzureIdP', 'https://azure.test', 'cid2', 'secret', "
  658. "'openid', 1, 0, 1, 'upn', 1, NULL, "
  659. "CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
  660. )
  661. )
  662. # Verify schema has new formula
  663. schema = (
  664. await conn.execute(text("SELECT sql FROM sqlite_master WHERE type='table' AND name='oidc_providers'"))
  665. ).fetchone()[0]
  666. assert "require_email_verified = TRUE AND email_claim = 'email'" not in schema
  667. assert "email_claim != 'email'" in schema
  668. await engine.dispose()
  669. @pytest.mark.asyncio
  670. async def test_constraint_migration_postgres_drops_and_recreates(self):
  671. """PostgreSQL path calls DROP CONSTRAINT IF EXISTS then ADD CONSTRAINT with new formula."""
  672. from unittest.mock import AsyncMock, MagicMock, call
  673. from backend.app.core.database import _migrate_update_auto_link_constraint
  674. # Track all SQL statements passed to _safe_execute by capturing conn.execute calls
  675. executed_sqls: list[str] = []
  676. async def fake_safe_execute(conn, sql):
  677. executed_sqls.append(sql)
  678. nested_cm = MagicMock()
  679. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  680. nested_cm.__aexit__ = AsyncMock(return_value=False)
  681. nested_cm.execute = AsyncMock()
  682. mock_conn = MagicMock()
  683. mock_conn.begin_nested.return_value = nested_cm
  684. mock_conn.execute = AsyncMock()
  685. with (
  686. patch("backend.app.core.database.is_sqlite", return_value=False),
  687. patch("backend.app.core.database._safe_execute", side_effect=fake_safe_execute),
  688. ):
  689. await _migrate_update_auto_link_constraint(mock_conn)
  690. assert len(executed_sqls) == 2
  691. drop_sql, add_sql = executed_sqls
  692. assert "DROP CONSTRAINT IF EXISTS" in drop_sql.upper()
  693. assert "ck_auto_link_requires_verified_email_claim" in drop_sql
  694. assert "ADD CONSTRAINT" in add_sql.upper()
  695. assert "email_claim != 'email'" in add_sql
  696. assert "require_email_verified = TRUE AND email_claim = 'email'" not in add_sql
  697. @pytest.mark.asyncio
  698. async def test_constraint_migration_sqlite_count_guard_raises_on_mismatch(self):
  699. """RuntimeError is raised when the copied row count doesn't match the source."""
  700. from unittest.mock import AsyncMock, MagicMock, patch
  701. import pytest
  702. from backend.app.core.database import _migrate_update_auto_link_constraint
  703. _OLD_SQL = (
  704. "CREATE TABLE oidc_providers (id INTEGER NOT NULL, "
  705. "CONSTRAINT ck_auto_link_requires_verified_email_claim "
  706. "CHECK (auto_link_existing_accounts = FALSE OR "
  707. "(require_email_verified = TRUE AND email_claim = 'email')))"
  708. )
  709. async def fake_execute(stmt):
  710. sql = str(stmt)
  711. result = MagicMock()
  712. if "sqlite_master" in sql:
  713. result.fetchone.return_value = (_OLD_SQL,)
  714. elif "count(*)" in sql.lower() and "oidc_providers_v2" not in sql:
  715. result.scalar_one.return_value = 2 # source has 2 rows
  716. elif "count(*)" in sql.lower() and "oidc_providers_v2" in sql:
  717. result.scalar_one.return_value = 1 # copy only has 1 — mismatch
  718. else:
  719. result.fetchone.return_value = None
  720. return result
  721. nested_cm = MagicMock()
  722. nested_cm.__aenter__ = AsyncMock(return_value=None)
  723. nested_cm.__aexit__ = AsyncMock(return_value=False) # don't suppress exceptions
  724. mock_conn = MagicMock()
  725. mock_conn.execute = AsyncMock(side_effect=fake_execute)
  726. mock_conn.begin_nested.return_value = nested_cm
  727. with (
  728. patch("backend.app.core.database.is_sqlite", return_value=True),
  729. pytest.raises(RuntimeError, match="mismatch"),
  730. ):
  731. await _migrate_update_auto_link_constraint(mock_conn)