# test_db_dialect.py (13 KB)
  1. """Unit tests for database dialect helpers and PostgreSQL compatibility."""
  2. from unittest.mock import AsyncMock, patch
  3. import pytest
  4. class TestDialectDetection:
  5. """Test is_sqlite() and is_postgres() detection."""
  6. def test_sqlite_detected(self):
  7. with patch("backend.app.core.config.settings") as mock_settings:
  8. mock_settings.database_url = "sqlite+aiosqlite:///path/to/db.sqlite"
  9. from backend.app.core.db_dialect import is_postgres, is_sqlite
  10. assert is_sqlite() is True
  11. assert is_postgres() is False
  12. def test_postgres_detected(self):
  13. with patch("backend.app.core.config.settings") as mock_settings:
  14. mock_settings.database_url = "postgresql+asyncpg://user:pass@host:5432/db"
  15. from backend.app.core.db_dialect import is_postgres, is_sqlite
  16. assert is_postgres() is True
  17. assert is_sqlite() is False
  18. class TestRunPragma:
  19. """Test that PRAGMAs only run on SQLite."""
  20. @pytest.mark.asyncio
  21. async def test_pragma_runs_on_sqlite(self):
  22. with patch("backend.app.core.db_dialect.is_sqlite", return_value=True):
  23. from backend.app.core.db_dialect import run_pragma
  24. mock_conn = AsyncMock()
  25. await run_pragma(mock_conn, "PRAGMA journal_mode = WAL")
  26. mock_conn.execute.assert_called_once()
  27. @pytest.mark.asyncio
  28. async def test_pragma_skipped_on_postgres(self):
  29. with patch("backend.app.core.db_dialect.is_sqlite", return_value=False):
  30. from backend.app.core.db_dialect import run_pragma
  31. mock_conn = AsyncMock()
  32. await run_pragma(mock_conn, "PRAGMA journal_mode = WAL")
  33. mock_conn.execute.assert_not_called()
  34. class TestTimezoneStripping:
  35. """Test that the before_cursor_execute event strips timezone info."""
  36. def test_strip_aware_datetime(self):
  37. """Verify the timezone stripping logic works correctly."""
  38. import datetime
  39. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  40. naive = aware.replace(tzinfo=None)
  41. def _strip(val):
  42. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  43. return val.replace(tzinfo=None)
  44. return val
  45. assert _strip(aware) == naive
  46. assert _strip(aware).tzinfo is None
  47. assert _strip(naive) == naive
  48. assert _strip("not a datetime") == "not a datetime"
  49. assert _strip(None) is None
  50. def test_strip_in_dict_params(self):
  51. """Verify timezone stripping works on dict parameters."""
  52. import datetime
  53. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  54. def _strip(val):
  55. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  56. return val.replace(tzinfo=None)
  57. return val
  58. params = {"name": "test", "created_at": aware, "count": 5}
  59. result = {k: _strip(v) for k, v in params.items()}
  60. assert result["created_at"].tzinfo is None
  61. assert result["name"] == "test"
  62. assert result["count"] == 5
  63. def test_strip_in_tuple_params(self):
  64. """Verify timezone stripping works on tuple parameters."""
  65. import datetime
  66. aware = datetime.datetime(2026, 4, 3, 10, 0, 0, tzinfo=datetime.timezone.utc)
  67. def _strip(val):
  68. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  69. return val.replace(tzinfo=None)
  70. return val
  71. params = ("test", aware, 5)
  72. result = tuple(_strip(v) for v in params)
  73. assert result[1].tzinfo is None
  74. assert result[0] == "test"
  75. def test_naive_datetime_unchanged(self):
  76. """Naive datetimes should pass through untouched."""
  77. import datetime
  78. naive = datetime.datetime(2026, 4, 3, 10, 0, 0)
  79. def _strip(val):
  80. if isinstance(val, datetime.datetime) and val.tzinfo is not None:
  81. return val.replace(tzinfo=None)
  82. return val
  83. result = _strip(naive)
  84. assert result == naive
  85. assert result.tzinfo is None
  86. class TestCrossDatabaseConversion:
  87. """Test SQLite→Postgres type conversion logic used in cross-database import."""
  88. def test_boolean_conversion(self):
  89. """SQLite stores booleans as 0/1, Postgres needs Python bool."""
  90. assert bool(0) is False
  91. assert bool(1) is True
  92. def test_datetime_string_conversion(self):
  93. """SQLite stores datetimes as strings, Postgres needs datetime objects."""
  94. from datetime import datetime
  95. val = "2026-04-02 11:01:52.105147"
  96. result = datetime.fromisoformat(val)
  97. assert result.year == 2026
  98. assert result.month == 4
  99. assert result.microsecond == 105147
  100. def test_datetime_with_timezone_string(self):
  101. """SQLite may store timezone-aware strings."""
  102. from datetime import datetime
  103. val = "2026-04-02T11:01:52+00:00"
  104. result = datetime.fromisoformat(val)
  105. assert result.year == 2026
  106. def test_json_serialization_for_backup(self):
  107. """JSON/list/dict values must be serialized for SQLite backup."""
  108. import json
  109. values = [{"key": "val"}, [1, 2, 3], "plain string", 42, None]
  110. for val in values:
  111. if isinstance(val, (list, dict)):
  112. serialized = json.dumps(val)
  113. assert isinstance(serialized, str)
  114. else:
  115. assert val == val # noqa: PLR0124 — no conversion needed
  116. class TestSafeExecutePattern:
  117. """Test _safe_execute error handling logic."""
  118. def test_safe_execute_catches_expected_exceptions(self):
  119. """Verify _safe_execute catches both OperationalError and ProgrammingError."""
  120. from sqlalchemy.exc import OperationalError, ProgrammingError
  121. for exc_type in (OperationalError, ProgrammingError):
  122. try:
  123. raise exc_type("test", [], Exception("column already exists"))
  124. except (OperationalError, ProgrammingError):
  125. pass
  126. def test_safe_execute_would_not_catch_integrity_error(self):
  127. """IntegrityError should NOT be caught by _safe_execute."""
  128. from sqlalchemy.exc import IntegrityError, OperationalError, ProgrammingError
  129. with pytest.raises(IntegrityError):
  130. try:
  131. raise IntegrityError("test", [], Exception("unique violation"))
  132. except (OperationalError, ProgrammingError):
  133. pass
  134. @pytest.mark.asyncio
  135. async def test_safe_execute_reraises_non_idempotency_errors(self):
  136. """Non-idempotency errors must propagate so startup fails loudly."""
  137. from sqlalchemy.exc import OperationalError
  138. from sqlalchemy.ext.asyncio import create_async_engine
  139. from backend.app.core.database import _safe_execute
  140. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  141. async with engine.begin() as conn:
  142. with pytest.raises(OperationalError):
  143. await _safe_execute(conn, "SELECT * FROM nonexistent_table_xyz")
  144. await engine.dispose()
  145. @pytest.mark.asyncio
  146. async def test_safe_execute_swallows_already_exists(self):
  147. """Idempotency errors (already exists) must be silently ignored."""
  148. from sqlalchemy import text
  149. from sqlalchemy.ext.asyncio import create_async_engine
  150. from backend.app.core.database import _safe_execute
  151. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  152. async with engine.begin() as conn:
  153. await conn.execute(text("CREATE TABLE t (id INTEGER)"))
  154. # Second CREATE must not raise
  155. await _safe_execute(conn, "CREATE TABLE t (id INTEGER)")
  156. await engine.dispose()
  157. @pytest.mark.asyncio
  158. async def test_provider_email_lowercasing_migration(self):
  159. """SEC-3: provider_email normalisation lowers mixed-case values, leaves NULL intact.
  160. The production migration runs this UPDATE directly (not via _safe_execute)
  161. so any failure is always fatal and visible at startup.
  162. """
  163. from sqlalchemy import text
  164. from sqlalchemy.ext.asyncio import create_async_engine
  165. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  166. async with engine.begin() as conn:
  167. await conn.execute(text("CREATE TABLE user_oidc_links (id INTEGER PRIMARY KEY, provider_email TEXT)"))
  168. await conn.execute(text("INSERT INTO user_oidc_links VALUES (1, 'User@Example.COM')"))
  169. await conn.execute(text("INSERT INTO user_oidc_links VALUES (2, 'already@lower.com')"))
  170. await conn.execute(text("INSERT INTO user_oidc_links VALUES (3, NULL)"))
  171. async with conn.begin_nested():
  172. await conn.execute(
  173. text(
  174. "UPDATE user_oidc_links SET provider_email = LOWER(provider_email) "
  175. "WHERE provider_email IS NOT NULL AND provider_email != LOWER(provider_email)"
  176. )
  177. )
  178. result = await conn.execute(text("SELECT provider_email FROM user_oidc_links ORDER BY id"))
  179. rows = [r[0] for r in result.fetchall()]
  180. await engine.dispose()
  181. assert rows[0] == "user@example.com"
  182. assert rows[1] == "already@lower.com"
  183. assert rows[2] is None
  184. @pytest.mark.asyncio
  185. async def test_safe_execute_swallows_no_such_column_for_rename(self):
  186. """'no such column' is swallowed for RENAME COLUMN idempotency.
  187. When a column has already been renamed, re-running the RENAME COLUMN
  188. migration raises 'no such column' — that must be silently swallowed.
  189. DML safety is guaranteed by never passing DML through _safe_execute.
  190. """
  191. from sqlalchemy.ext.asyncio import create_async_engine
  192. from backend.app.core.database import _safe_execute
  193. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  194. async with engine.begin() as conn:
  195. await conn.execute(__import__("sqlalchemy").text("CREATE TABLE t (id INTEGER, new_col INTEGER)"))
  196. # Column 'old_col' does not exist — simulates re-running a RENAME COLUMN migration
  197. # Must NOT raise.
  198. await _safe_execute(conn, "ALTER TABLE t RENAME COLUMN old_col TO new_col")
  199. await engine.dispose()
  200. @pytest.mark.asyncio
  201. async def test_safe_execute_swallows_duplicate_key(self):
  202. """'duplicate key' errors (PostgreSQL unique-constraint violations on re-run)
  203. must be silently swallowed for idempotent DDL migrations."""
  204. from unittest.mock import AsyncMock, MagicMock
  205. from sqlalchemy.exc import OperationalError
  206. from backend.app.core.database import _safe_execute
  207. fake_exc = OperationalError("duplicate key value violates unique constraint", [], Exception())
  208. # begin_nested() is called synchronously (not awaited) and returns an
  209. # async context manager. Use MagicMock so the call returns a regular
  210. # object, then attach __aenter__/__aexit__ for the async with protocol.
  211. nested_cm = MagicMock()
  212. nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
  213. # Raise on execute inside the context, simulating PG duplicate key
  214. nested_cm.execute = AsyncMock(side_effect=fake_exc)
  215. nested_cm.__aexit__ = AsyncMock(return_value=False)
  216. mock_conn = MagicMock()
  217. mock_conn.begin_nested.return_value = nested_cm
  218. mock_conn.execute = AsyncMock(side_effect=fake_exc)
  219. # Must NOT raise — "duplicate key" is in the swallow-list
  220. await _safe_execute(mock_conn, "CREATE UNIQUE INDEX ...")
  221. @pytest.mark.asyncio
  222. async def test_check_constraint_false_true_on_sqlite(self):
  223. """CheckConstraint with FALSE/TRUE literals is enforced on SQLite (3.23+)."""
  224. from sqlalchemy import text
  225. from sqlalchemy.exc import IntegrityError
  226. from sqlalchemy.ext.asyncio import create_async_engine
  227. engine = create_async_engine("sqlite+aiosqlite:///:memory:")
  228. async with engine.begin() as conn:
  229. await conn.execute(
  230. text("""
  231. CREATE TABLE ck_test (
  232. id INTEGER PRIMARY KEY,
  233. auto_link BOOLEAN,
  234. require_ev BOOLEAN,
  235. email_claim TEXT,
  236. CHECK (auto_link = FALSE OR (require_ev = TRUE AND email_claim = 'email'))
  237. )
  238. """)
  239. )
  240. # Valid: auto_link=0 (FALSE)
  241. await conn.execute(text("INSERT INTO ck_test VALUES (1, 0, 0, 'upn')"))
  242. # Valid: auto_link=1, require_ev=1, email_claim='email'
  243. await conn.execute(text("INSERT INTO ck_test VALUES (2, 1, 1, 'email')"))
  244. async with engine.begin() as conn:
  245. # Invalid: auto_link=1 but conditions not met
  246. with pytest.raises(IntegrityError):
  247. await conn.execute(text("INSERT INTO ck_test VALUES (3, 1, 0, 'email')"))
  248. await engine.dispose()