|
@@ -316,7 +316,12 @@ class TestSafeExecutePattern:
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.asyncio
|
|
|
async def test_check_constraint_false_true_on_sqlite(self):
|
|
async def test_check_constraint_false_true_on_sqlite(self):
|
|
|
- """CheckConstraint with FALSE/TRUE literals is enforced on SQLite (3.23+)."""
|
|
|
|
|
|
|
+ """New constraint formula is enforced on SQLite (3.23+).
|
|
|
|
|
+
|
|
|
|
|
+ New: auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE
|
|
|
|
|
+ Blocks Fall B (auto_link=1 + email_claim='email' + require_ev=0).
|
|
|
|
|
+ Allows Fall A (email_claim='email' + require_ev=1) and Fall C (custom claim).
|
|
|
|
|
+ """
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy import text
|
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
@@ -330,29 +335,32 @@ class TestSafeExecutePattern:
|
|
|
auto_link BOOLEAN,
|
|
auto_link BOOLEAN,
|
|
|
require_ev BOOLEAN,
|
|
require_ev BOOLEAN,
|
|
|
email_claim TEXT,
|
|
email_claim TEXT,
|
|
|
- CHECK (auto_link = FALSE OR (require_ev = TRUE AND email_claim = 'email'))
|
|
|
|
|
|
|
+ CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)
|
|
|
)
|
|
)
|
|
|
""")
|
|
""")
|
|
|
)
|
|
)
|
|
|
- # Valid: auto_link=0 (FALSE)
|
|
|
|
|
|
|
+ # Valid: auto_link=0 (FALSE) — any combo allowed
|
|
|
await conn.execute(text("INSERT INTO ck_test VALUES (1, 0, 0, 'upn')"))
|
|
await conn.execute(text("INSERT INTO ck_test VALUES (1, 0, 0, 'upn')"))
|
|
|
- # Valid: auto_link=1, require_ev=1, email_claim='email'
|
|
|
|
|
|
|
+ # Valid: Fall A — auto_link=1, require_ev=1, email_claim='email'
|
|
|
await conn.execute(text("INSERT INTO ck_test VALUES (2, 1, 1, 'email')"))
|
|
await conn.execute(text("INSERT INTO ck_test VALUES (2, 1, 1, 'email')"))
|
|
|
|
|
+ # Valid: Fall C — auto_link=1, email_claim='upn' (require_ev irrelevant)
|
|
|
|
|
+ await conn.execute(text("INSERT INTO ck_test VALUES (3, 1, 0, 'upn')"))
|
|
|
|
|
+ await conn.execute(text("INSERT INTO ck_test VALUES (4, 1, 1, 'upn')"))
|
|
|
|
|
|
|
|
async with engine.begin() as conn:
|
|
async with engine.begin() as conn:
|
|
|
- # Invalid: auto_link=1 but conditions not met
|
|
|
|
|
|
|
+ # Invalid: Fall B — auto_link=1 + email_claim='email' + require_ev=0
|
|
|
with pytest.raises(IntegrityError):
|
|
with pytest.raises(IntegrityError):
|
|
|
- await conn.execute(text("INSERT INTO ck_test VALUES (3, 1, 0, 'email')"))
|
|
|
|
|
|
|
+ await conn.execute(text("INSERT INTO ck_test VALUES (5, 1, 0, 'email')"))
|
|
|
await engine.dispose()
|
|
await engine.dispose()
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.asyncio
|
|
|
async def test_auto_link_sec1_backfill_resets_unsafe_rows(self):
|
|
async def test_auto_link_sec1_backfill_resets_unsafe_rows(self):
|
|
|
- """SEC-1 backfill resets auto_link=TRUE on rows with unsafe combined state.
|
|
|
|
|
|
|
+ """SEC-1 backfill resets auto_link=TRUE only for Fall B (email_claim='email' + require_ev=FALSE).
|
|
|
|
|
|
|
|
Three cases:
|
|
Three cases:
|
|
|
- 1. auto_link=TRUE + require_ev=FALSE → reset to FALSE (unsafe: permissive mode)
|
|
|
|
|
- 2. auto_link=TRUE + custom claim → reset to FALSE (unsafe: no email_verified gate)
|
|
|
|
|
- 3. auto_link=TRUE + require_ev=TRUE + standard claim → unchanged (safe)
|
|
|
|
|
|
|
+ 1. auto_link=TRUE + email_claim='email' + require_ev=FALSE → reset to FALSE (Fall B, unsafe)
|
|
|
|
|
+ 2. auto_link=TRUE + custom claim + require_ev=TRUE → unchanged (Fall C, now allowed)
|
|
|
|
|
+ 3. auto_link=TRUE + email_claim='email' + require_ev=TRUE → unchanged (Fall A, safe)
|
|
|
"""
|
|
"""
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy import text
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
@@ -369,11 +377,11 @@ class TestSafeExecutePattern:
|
|
|
")"
|
|
")"
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
- # Row 1: unsafe — require_ev=FALSE
|
|
|
|
|
|
|
+ # Row 1: Fall B — email_claim='email' + require_ev=FALSE → must be reset
|
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (1, 1, 0, 'email')"))
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (1, 1, 0, 'email')"))
|
|
|
- # Row 2: unsafe — custom claim
|
|
|
|
|
|
|
+ # Row 2: Fall C — custom claim → must NOT be reset (now allowed)
|
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (2, 1, 1, 'preferred_username')"))
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (2, 1, 1, 'preferred_username')"))
|
|
|
- # Row 3: safe — require_ev=TRUE + standard claim
|
|
|
|
|
|
|
+ # Row 3: Fall A — email_claim='email' + require_ev=TRUE → must NOT be reset (always safe)
|
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (3, 1, 1, 'email')"))
|
|
await conn.execute(text("INSERT INTO oidc_providers VALUES (3, 1, 1, 'email')"))
|
|
|
|
|
|
|
|
async with conn.begin_nested():
|
|
async with conn.begin_nested():
|
|
@@ -381,7 +389,7 @@ class TestSafeExecutePattern:
|
|
|
text(
|
|
text(
|
|
|
"UPDATE oidc_providers SET auto_link_existing_accounts = FALSE "
|
|
"UPDATE oidc_providers SET auto_link_existing_accounts = FALSE "
|
|
|
"WHERE auto_link_existing_accounts = TRUE "
|
|
"WHERE auto_link_existing_accounts = TRUE "
|
|
|
- "AND (require_email_verified = FALSE OR email_claim != 'email')"
|
|
|
|
|
|
|
+ "AND email_claim = 'email' AND require_email_verified = FALSE"
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
|
|
@@ -389,9 +397,9 @@ class TestSafeExecutePattern:
|
|
|
rows = {r[0]: r[1] for r in result.fetchall()}
|
|
rows = {r[0]: r[1] for r in result.fetchall()}
|
|
|
await engine.dispose()
|
|
await engine.dispose()
|
|
|
|
|
|
|
|
- assert rows[1] == 0, "unsafe (require_ev=FALSE) row must be reset to FALSE"
|
|
|
|
|
- assert rows[2] == 0, "unsafe (custom claim) row must be reset to FALSE"
|
|
|
|
|
- assert rows[3] == 1, "safe row must remain TRUE"
|
|
|
|
|
|
|
+ assert rows[1] == 0, "Fall B (require_ev=FALSE) must be reset to FALSE"
|
|
|
|
|
+ assert rows[2] == 1, "Fall C (custom claim) must remain TRUE"
|
|
|
|
|
+ assert rows[3] == 1, "Fall A (require_ev=TRUE) must remain TRUE"
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
@pytest.mark.asyncio
|
|
|
async def test_safe_execute_reraises_does_not_exist_without_column(self):
|
|
async def test_safe_execute_reraises_does_not_exist_without_column(self):
|
|
@@ -534,3 +542,213 @@ class TestSafeExecutePattern:
|
|
|
sql = mock_conn.execute.call_args[0][0].text
|
|
sql = mock_conn.execute.call_args[0][0].text
|
|
|
assert "::text = '[]'" in sql, f"Expected ::text cast in SQL, got: {sql}"
|
|
assert "::text = '[]'" in sql, f"Expected ::text cast in SQL, got: {sql}"
|
|
|
assert "printer_ids" in sql
|
|
assert "printer_ids" in sql
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class TestAutoLinkConstraintMigration:
|
|
|
|
|
+ """Tests for _migrate_update_auto_link_constraint (Fall C / Azure support)."""
|
|
|
|
|
+
|
|
|
|
|
+ @pytest.mark.asyncio
|
|
|
|
|
+ async def test_new_constraint_allows_fall_c_sqlite(self):
|
|
|
|
|
+ """New formula allows auto_link=TRUE with a custom claim (Fall C)."""
|
|
|
|
|
+ from sqlalchemy import text
|
|
|
|
|
+ from sqlalchemy.exc import IntegrityError
|
|
|
|
|
+ from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
|
|
+
|
|
|
|
|
+ engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
|
|
|
|
+ async with engine.begin() as conn:
|
|
|
|
|
+ await conn.execute(
|
|
|
|
|
+ text(
|
|
|
|
|
+ "CREATE TABLE oidc_providers_ck ("
|
|
|
|
|
+ "id INTEGER PRIMARY KEY, "
|
|
|
|
|
+ "auto_link BOOLEAN, "
|
|
|
|
|
+ "require_ev BOOLEAN, "
|
|
|
|
|
+ "email_claim TEXT, "
|
|
|
|
|
+ "CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)"
|
|
|
|
|
+ ")"
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ # Fall C: custom claim + auto_link + require_ev=FALSE must pass
|
|
|
|
|
+ await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (1, 1, 0, 'upn')"))
|
|
|
|
|
+ # Fall C: custom claim + auto_link + require_ev=TRUE must pass
|
|
|
|
|
+ await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (2, 1, 1, 'preferred_username')"))
|
|
|
|
|
+ await engine.dispose()
|
|
|
|
|
+
|
|
|
|
|
+ @pytest.mark.asyncio
|
|
|
|
|
+ async def test_new_constraint_blocks_fall_b_sqlite(self):
|
|
|
|
|
+ """New formula still blocks Fall B (email_claim='email' + require_ev=FALSE + auto_link=TRUE)."""
|
|
|
|
|
+ from sqlalchemy import text
|
|
|
|
|
+ from sqlalchemy.exc import IntegrityError
|
|
|
|
|
+ from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
|
|
+
|
|
|
|
|
+ engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
|
|
|
|
+ async with engine.begin() as conn:
|
|
|
|
|
+ await conn.execute(
|
|
|
|
|
+ text(
|
|
|
|
|
+ "CREATE TABLE oidc_providers_ck ("
|
|
|
|
|
+ "id INTEGER PRIMARY KEY, "
|
|
|
|
|
+ "auto_link BOOLEAN, "
|
|
|
|
|
+ "require_ev BOOLEAN, "
|
|
|
|
|
+ "email_claim TEXT, "
|
|
|
|
|
+ "CHECK (auto_link = FALSE OR email_claim != 'email' OR require_ev = TRUE)"
|
|
|
|
|
+ ")"
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ async with engine.begin() as conn:
|
|
|
|
|
+ with pytest.raises(IntegrityError):
|
|
|
|
|
+ await conn.execute(text("INSERT INTO oidc_providers_ck VALUES (1, 1, 0, 'email')"))
|
|
|
|
|
+ await engine.dispose()
|
|
|
|
|
+
|
|
|
|
|
+ @pytest.mark.asyncio
|
|
|
|
|
+ async def test_constraint_migration_sqlite_recreates_table(self):
|
|
|
|
|
+ """SQLite path recreates oidc_providers with new constraint when old formula is present."""
|
|
|
|
|
+ from sqlalchemy import text
|
|
|
|
|
+ from sqlalchemy.ext.asyncio import create_async_engine
|
|
|
|
|
+
|
|
|
|
|
+ from backend.app.core.database import _migrate_update_auto_link_constraint
|
|
|
|
|
+
|
|
|
|
|
+ # Create table with old constraint formula
|
|
|
|
|
+ engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
|
|
|
|
+ async with engine.begin() as conn:
|
|
|
|
|
+ await conn.execute(
|
|
|
|
|
+ text(
|
|
|
|
|
+ "CREATE TABLE oidc_providers ("
|
|
|
|
|
+ "id INTEGER NOT NULL PRIMARY KEY, "
|
|
|
|
|
+ "name VARCHAR(100) NOT NULL UNIQUE, "
|
|
|
|
|
+ "issuer_url VARCHAR(500) NOT NULL, "
|
|
|
|
|
+ "client_id VARCHAR(255) NOT NULL, "
|
|
|
|
|
+ "client_secret VARCHAR(512) NOT NULL, "
|
|
|
|
|
+ "scopes VARCHAR(500), "
|
|
|
|
|
+ "is_enabled BOOLEAN, "
|
|
|
|
|
+ "auto_create_users BOOLEAN, "
|
|
|
|
|
+ "auto_link_existing_accounts BOOLEAN DEFAULT 0, "
|
|
|
|
|
+ "email_claim VARCHAR(64) DEFAULT 'email', "
|
|
|
|
|
+ "require_email_verified BOOLEAN DEFAULT 1, "
|
|
|
|
|
+ "icon_url TEXT, "
|
|
|
|
|
+ "created_at DATETIME DEFAULT CURRENT_TIMESTAMP, "
|
|
|
|
|
+ "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, "
|
|
|
|
|
+ "CONSTRAINT ck_auto_link_requires_verified_email_claim "
|
|
|
|
|
+ "CHECK (auto_link_existing_accounts = FALSE OR "
|
|
|
|
|
+ "(require_email_verified = TRUE AND email_claim = 'email'))"
|
|
|
|
|
+ ")"
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ await conn.execute(
|
|
|
|
|
+ text(
|
|
|
|
|
+ "INSERT INTO oidc_providers (id, name, issuer_url, client_id, client_secret, "
|
|
|
|
|
+ "scopes, is_enabled, auto_create_users, auto_link_existing_accounts, "
|
|
|
|
|
+ "email_claim, require_email_verified, icon_url, created_at, updated_at) "
|
|
|
|
|
+ "VALUES (1, 'TestIdP', 'https://idp.test', 'cid', 'secret', "
|
|
|
|
|
+ "'openid email', 1, 0, 0, 'email', 1, NULL, "
|
|
|
|
|
+ "CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ async with engine.begin() as conn:
|
|
|
|
|
+ with patch("backend.app.core.database.is_sqlite", return_value=True):
|
|
|
|
|
+ await _migrate_update_auto_link_constraint(conn)
|
|
|
|
|
+
|
|
|
|
|
+ # Verify data survived
|
|
|
|
|
+ result = await conn.execute(text("SELECT id, name FROM oidc_providers"))
|
|
|
|
|
+ rows = result.fetchall()
|
|
|
|
|
+ assert len(rows) == 1
|
|
|
|
|
+ assert rows[0][0] == 1
|
|
|
|
|
+
|
|
|
|
|
+ # Verify new constraint: Fall C (auto_link=TRUE + custom claim) must now be insertable
|
|
|
|
|
+ await conn.execute(
|
|
|
|
|
+ text(
|
|
|
|
|
+ "INSERT INTO oidc_providers (id, name, issuer_url, client_id, client_secret, "
|
|
|
|
|
+ "scopes, is_enabled, auto_create_users, auto_link_existing_accounts, "
|
|
|
|
|
+ "email_claim, require_email_verified, icon_url, created_at, updated_at) "
|
|
|
|
|
+ "VALUES (2, 'AzureIdP', 'https://azure.test', 'cid2', 'secret', "
|
|
|
|
|
+ "'openid', 1, 0, 1, 'upn', 1, NULL, "
|
|
|
|
|
+ "CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)"
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # Verify schema has new formula
|
|
|
|
|
+ schema = (
|
|
|
|
|
+ await conn.execute(text("SELECT sql FROM sqlite_master WHERE type='table' AND name='oidc_providers'"))
|
|
|
|
|
+ ).fetchone()[0]
|
|
|
|
|
+ assert "require_email_verified = TRUE AND email_claim = 'email'" not in schema
|
|
|
|
|
+ assert "email_claim != 'email'" in schema
|
|
|
|
|
+
|
|
|
|
|
+ await engine.dispose()
|
|
|
|
|
+
|
|
|
|
|
+ @pytest.mark.asyncio
|
|
|
|
|
+ async def test_constraint_migration_postgres_drops_and_recreates(self):
|
|
|
|
|
+ """PostgreSQL path calls DROP CONSTRAINT IF EXISTS then ADD CONSTRAINT with new formula."""
|
|
|
|
|
+ from unittest.mock import AsyncMock, MagicMock, call
|
|
|
|
|
+
|
|
|
|
|
+ from backend.app.core.database import _migrate_update_auto_link_constraint
|
|
|
|
|
+
|
|
|
|
|
+ # Track all SQL statements passed to _safe_execute by capturing conn.execute calls
|
|
|
|
|
+ executed_sqls: list[str] = []
|
|
|
|
|
+
|
|
|
|
|
+ async def fake_safe_execute(conn, sql):
|
|
|
|
|
+ executed_sqls.append(sql)
|
|
|
|
|
+
|
|
|
|
|
+ nested_cm = MagicMock()
|
|
|
|
|
+ nested_cm.__aenter__ = AsyncMock(return_value=nested_cm)
|
|
|
|
|
+ nested_cm.__aexit__ = AsyncMock(return_value=False)
|
|
|
|
|
+ nested_cm.execute = AsyncMock()
|
|
|
|
|
+
|
|
|
|
|
+ mock_conn = MagicMock()
|
|
|
|
|
+ mock_conn.begin_nested.return_value = nested_cm
|
|
|
|
|
+ mock_conn.execute = AsyncMock()
|
|
|
|
|
+
|
|
|
|
|
+ with (
|
|
|
|
|
+ patch("backend.app.core.database.is_sqlite", return_value=False),
|
|
|
|
|
+ patch("backend.app.core.database._safe_execute", side_effect=fake_safe_execute),
|
|
|
|
|
+ ):
|
|
|
|
|
+ await _migrate_update_auto_link_constraint(mock_conn)
|
|
|
|
|
+
|
|
|
|
|
+ assert len(executed_sqls) == 2
|
|
|
|
|
+ drop_sql, add_sql = executed_sqls
|
|
|
|
|
+ assert "DROP CONSTRAINT IF EXISTS" in drop_sql.upper()
|
|
|
|
|
+ assert "ck_auto_link_requires_verified_email_claim" in drop_sql
|
|
|
|
|
+ assert "ADD CONSTRAINT" in add_sql.upper()
|
|
|
|
|
+ assert "email_claim != 'email'" in add_sql
|
|
|
|
|
+ assert "require_email_verified = TRUE AND email_claim = 'email'" not in add_sql
|
|
|
|
|
+
|
|
|
|
|
+ @pytest.mark.asyncio
|
|
|
|
|
+ async def test_constraint_migration_sqlite_count_guard_raises_on_mismatch(self):
|
|
|
|
|
+ """RuntimeError is raised when the copied row count doesn't match the source."""
|
|
|
|
|
+ from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
|
+
|
|
|
|
|
+ import pytest
|
|
|
|
|
+
|
|
|
|
|
+ from backend.app.core.database import _migrate_update_auto_link_constraint
|
|
|
|
|
+
|
|
|
|
|
+ _OLD_SQL = (
|
|
|
|
|
+ "CREATE TABLE oidc_providers (id INTEGER NOT NULL, "
|
|
|
|
|
+ "CONSTRAINT ck_auto_link_requires_verified_email_claim "
|
|
|
|
|
+ "CHECK (auto_link_existing_accounts = FALSE OR "
|
|
|
|
|
+ "(require_email_verified = TRUE AND email_claim = 'email')))"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ async def fake_execute(stmt):
|
|
|
|
|
+ sql = str(stmt)
|
|
|
|
|
+ result = MagicMock()
|
|
|
|
|
+ if "sqlite_master" in sql:
|
|
|
|
|
+ result.fetchone.return_value = (_OLD_SQL,)
|
|
|
|
|
+ elif "count(*)" in sql.lower() and "oidc_providers_v2" not in sql:
|
|
|
|
|
+ result.scalar_one.return_value = 2 # source has 2 rows
|
|
|
|
|
+ elif "count(*)" in sql.lower() and "oidc_providers_v2" in sql:
|
|
|
|
|
+ result.scalar_one.return_value = 1 # copy only has 1 — mismatch
|
|
|
|
|
+ else:
|
|
|
|
|
+ result.fetchone.return_value = None
|
|
|
|
|
+ return result
|
|
|
|
|
+
|
|
|
|
|
+ nested_cm = MagicMock()
|
|
|
|
|
+ nested_cm.__aenter__ = AsyncMock(return_value=None)
|
|
|
|
|
+ nested_cm.__aexit__ = AsyncMock(return_value=False) # don't suppress exceptions
|
|
|
|
|
+
|
|
|
|
|
+ mock_conn = MagicMock()
|
|
|
|
|
+ mock_conn.execute = AsyncMock(side_effect=fake_execute)
|
|
|
|
|
+ mock_conn.begin_nested.return_value = nested_cm
|
|
|
|
|
+
|
|
|
|
|
+ with (
|
|
|
|
|
+ patch("backend.app.core.database.is_sqlite", return_value=True),
|
|
|
|
|
+ pytest.raises(RuntimeError, match="mismatch"),
|
|
|
|
|
+ ):
|
|
|
|
|
+ await _migrate_update_auto_link_constraint(mock_conn)
|