| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823 |
- """Security tests for the 8 coverage gaps identified in the maintainer review.
- Gap 1: encryption.py has zero tests
- Gap 2: JWT revocation (revoke_jti, is_jti_revoked, _is_token_fresh) untested
- Gap 3: OIDC exchange token replay untested
- Gap 4: OIDC email_verified claim handling untested
- Gap 5: Email OTP max-attempts invalidation untested
- Gap 6: OIDC callback error redirects (SSRF protection) undertested
- Gap 7: Login rate limiting untested
- Gap 8: challenge_id cookie binding untested
- """
- from __future__ import annotations
- import base64
- import secrets
- import time
- from datetime import datetime, timedelta, timezone
- from unittest.mock import AsyncMock, MagicMock, patch
- import jwt as pyjwt
- import pytest
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives.asymmetric import rsa
- from httpx import AsyncClient
- from sqlalchemy.ext.asyncio import AsyncSession
- from backend.app.models.auth_ephemeral import AuthEphemeralToken
- from backend.app.models.user import User
- AUTH_SETUP_URL = "/api/v1/auth/setup"
- LOGIN_URL = "/api/v1/auth/login"
- LOGOUT_URL = "/api/v1/auth/logout"
- ME_URL = "/api/v1/auth/me"
- def _auth_header(token: str) -> dict[str, str]:
- return {"Authorization": f"Bearer {token}"}
- def _norm_pw(password: str) -> str:
- """Ensure password meets complexity requirements (I4: SetupRequest now validates)."""
- if not any(c.isupper() for c in password):
- password = password[0].upper() + password[1:]
- if not any(c not in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" for c in password):
- password = password + "!"
- return password
- async def _setup_and_login(client: AsyncClient, username: str, password: str) -> str:
- password = _norm_pw(password)
- await client.post(
- AUTH_SETUP_URL,
- json={"auth_enabled": True, "admin_username": username, "admin_password": password},
- )
- resp = await client.post(LOGIN_URL, json={"username": username, "password": password})
- assert resp.status_code == 200
- return resp.json()["access_token"]
- def _make_test_rsa_key():
- def _b64url(n: int, length: int) -> str:
- return base64.urlsafe_b64encode(n.to_bytes(length, "big")).rstrip(b"=").decode()
- private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
- private_pem = private_key.private_bytes(
- serialization.Encoding.PEM,
- serialization.PrivateFormat.TraditionalOpenSSL,
- serialization.NoEncryption(),
- )
- pub_numbers = private_key.public_key().public_numbers()
- jwks = {
- "keys": [
- {
- "kty": "RSA",
- "use": "sig",
- "alg": "RS256",
- "kid": "test-kid-1",
- "n": _b64url(pub_numbers.n, 256),
- "e": _b64url(pub_numbers.e, 3),
- }
- ]
- }
- return private_pem, jwks
- # ===========================================================================
- # Gap 1: encryption.py unit tests
- # ===========================================================================
- class TestEncryption:
- """encrypt/decrypt round-trips, plaintext passthrough, RuntimeError on missing key.
- The ``mfa_encryption_isolation`` autouse fixture (conftest.py) resets the
- ``encryption`` module's globals before/after each test and points
- ``DATA_DIR`` at a tmp path, so individual tests only need to set
- ``MFA_ENCRYPTION_KEY`` when they want a specific key in scope.
- """
- def test_encrypt_decrypt_roundtrip_with_key(self, monkeypatch):
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- test_key = Fernet.generate_key().decode()
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", test_key)
- # Force re-initialisation now that the env var is set.
- enc_mod._fernet_instance = None
- ciphertext = enc_mod.mfa_encrypt("my-totp-secret")
- assert ciphertext.startswith("fernet:")
- assert enc_mod.mfa_decrypt(ciphertext) == "my-totp-secret"
- def test_plaintext_passthrough_without_key(self, monkeypatch):
- # Force the auto-bootstrap into the legacy "no key available" branch
- # by patching _load_or_generate_key directly. This is more robust than
- # chmod tricks (which root bypasses) when verifying the plaintext path.
- import backend.app.core.encryption as enc_mod
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- result = enc_mod.mfa_encrypt("plaintext-secret")
- assert result == "plaintext-secret"
- assert enc_mod.mfa_decrypt("plaintext-secret") == "plaintext-secret"
- def test_decrypt_raises_runtime_error_without_key_for_encrypted_value(self, monkeypatch):
- import backend.app.core.encryption as enc_mod
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- with pytest.raises(RuntimeError, match="MFA_ENCRYPTION_KEY must be set"):
- enc_mod.mfa_decrypt("fernet:gAAAAA-fake-ciphertext")
- # ------------------------------------------------------------------
- # Auto-bootstrap tests for _load_or_generate_key
- # ------------------------------------------------------------------
- def test_load_or_generate_key_uses_env_when_set(self, monkeypatch, tmp_path):
- """Valid env var → key_source == 'env', no file written."""
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- valid_key = Fernet.generate_key().decode()
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", valid_key)
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- key, source = enc_mod._load_or_generate_key()
- assert key == valid_key
- assert source == "env"
- assert not (tmp_path / ".mfa_encryption_key").exists()
- def test_invalid_env_key_falls_through_to_file(self, monkeypatch, tmp_path, caplog):
- """Invalid env var → logger.error + file fallback (auto-generated)."""
- import logging
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", "not-a-valid-fernet-key")
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- with caplog.at_level(logging.ERROR, logger="backend.app.core.encryption"):
- key, source = enc_mod._load_or_generate_key()
- assert source == "generated"
- assert key is not None
- assert (tmp_path / ".mfa_encryption_key").exists()
- assert any("not a valid Fernet key" in rec.message for rec in caplog.records)
- def test_load_or_generate_key_reads_existing_file(self, monkeypatch, tmp_path):
- """File present in DATA_DIR + no env var → key_source == 'file'."""
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- existing_key = Fernet.generate_key().decode()
- key_file = tmp_path / ".mfa_encryption_key"
- key_file.write_text(existing_key)
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- key, source = enc_mod._load_or_generate_key()
- assert key == existing_key
- assert source == "file"
- def test_load_or_generate_key_creates_file_with_0600(self, monkeypatch, tmp_path):
- """Neither env nor file → new key generated, file mode is 0o600."""
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- key, source = enc_mod._load_or_generate_key()
- assert source == "generated"
- assert enc_mod._validate_fernet_key(key)
- key_file = tmp_path / ".mfa_encryption_key"
- assert key_file.exists()
- # Mode bits LSB are 0o600 — owner read+write only.
- assert (key_file.stat().st_mode & 0o777) == 0o600
- def test_load_or_generate_key_returns_none_on_write_oserror(self, monkeypatch, tmp_path, caplog):
- """When DATA_DIR can't be written to (auto-generate path), return (None, 'none_write_failed').
- S1: write now uses os.open(O_EXCL|O_CREAT, 0o600) instead of write_text — patch
- os.write to simulate the OS-level failure. S8: source distinguishes write-failed
- from corrupted to drive accurate operator messaging.
- """
- import logging
- import os
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- original_write = os.write
- def _raising_write(fd, data):
- # Best-effort: trigger OSError specifically for the key write.
- raise OSError("simulated read-only filesystem")
- monkeypatch.setattr(os, "write", _raising_write)
- with caplog.at_level(logging.ERROR, logger="backend.app.core.encryption"):
- key, source = enc_mod._load_or_generate_key()
- # Restore os.write so the rest of the test suite is unaffected.
- monkeypatch.setattr(os, "write", original_write)
- assert key is None
- assert source == "none_write_failed"
- assert any("Could not save MFA encryption key" in rec.message for rec in caplog.records)
- def test_load_or_generate_key_returns_none_on_read_oserror(self, monkeypatch, tmp_path, caplog):
- """B4: existing key file but read fails (e.g. permission denied) → (None, 'none_corrupted').
- Critical: must NOT regenerate a new key, which would destroy access to
- every row already encrypted under the existing key. S8: 'none_corrupted'
- marks the cause so operators see the right diagnostic.
- """
- import logging
- from pathlib import Path
- import backend.app.core.encryption as enc_mod
- # Pre-create a key file so we hit the existing-file branch.
- key_file = tmp_path / ".mfa_encryption_key"
- key_file.write_text("placeholder-content")
- original_size = key_file.stat().st_size
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- original_read_text = Path.read_text
- def _raising_read_text(self, *args, **kwargs):
- if self.name == ".mfa_encryption_key":
- raise OSError("simulated permission denied")
- return original_read_text(self, *args, **kwargs)
- monkeypatch.setattr(Path, "read_text", _raising_read_text)
- with caplog.at_level(logging.ERROR, logger="backend.app.core.encryption"):
- key, source = enc_mod._load_or_generate_key()
- assert key is None
- assert source == "none_corrupted"
- # Critical: file must not have been overwritten with a new key.
- assert key_file.exists()
- assert key_file.stat().st_size == original_size
- assert any("Failed to read existing MFA key file" in rec.message for rec in caplog.records)
- assert any("Refusing to regenerate" in rec.message for rec in caplog.records)
- def test_get_key_source_reflects_active_source(self, monkeypatch, tmp_path):
- """get_key_source() returns the source detected on the most recent _get_fernet() call."""
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- # Trigger lazy initialisation
- enc_mod.mfa_encrypt("anything")
- assert enc_mod.get_key_source() == "env"
- def test_corrupted_key_file_returns_none_without_overwrite(self, monkeypatch, tmp_path, caplog):
- """A1: invalid key file content → (None, 'none_corrupted'), file not overwritten.
- S8: 'none_corrupted' (vs 'none_write_failed') so operators get the right
- diagnostic and don't see a misleading 'DATA_DIR not writable' warning.
- """
- import logging
- import backend.app.core.encryption as enc_mod
- key_file = tmp_path / ".mfa_encryption_key"
- key_file.write_text("invalid_content")
- original_mtime = key_file.stat().st_mtime
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- with caplog.at_level(logging.ERROR, logger="backend.app.core.encryption"):
- key, source = enc_mod._load_or_generate_key()
- assert key is None
- assert source == "none_corrupted"
- assert key_file.exists(), "file must not be deleted"
- assert key_file.stat().st_mtime == original_mtime, "file must not be overwritten"
- assert any("not a valid Fernet key" in rec.message for rec in caplog.records)
- assert any("Refusing to overwrite" in rec.message for rec in caplog.records)
- def test_auto_generate_fileexistserror_returns_none_corrupted(self, monkeypatch, tmp_path, caplog):
- """S1: O_EXCL race — file appears between exists() check and open() →
- return (None, 'none_corrupted') without overwriting."""
- import logging
- import os
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- enc_mod._fernet_instance = None
- original_open = os.open
- def _excl_raise(path, flags, mode=0o777):
- if str(path).endswith(".mfa_encryption_key") and (flags & os.O_EXCL):
- raise FileExistsError(17, "File exists", str(path))
- return original_open(path, flags, mode)
- monkeypatch.setattr(os, "open", _excl_raise)
- with caplog.at_level(logging.ERROR, logger="backend.app.core.encryption"):
- key, source = enc_mod._load_or_generate_key()
- assert key is None
- assert source == "none_corrupted"
- assert any("Race detected" in rec.message for rec in caplog.records)
- # ===========================================================================
- # Gap 2: JWT revocation — revoke_jti, is_jti_revoked, _is_token_fresh, /me
- # ===========================================================================
- class TestJWTRevocation:
- """JWT revocation and token freshness checks."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_revoke_jti_and_is_jti_revoked(self, async_client: AsyncClient, db_session: AsyncSession):
- """revoke_jti stores the JTI; is_jti_revoked returns True afterwards."""
- from backend.app.core.auth import is_jti_revoked, revoke_jti
- test_jti = secrets.token_urlsafe(16)
- expires = datetime.now(timezone.utc) + timedelta(hours=1)
- assert not await is_jti_revoked(test_jti)
- await revoke_jti(test_jti, expires, username="testuser")
- assert await is_jti_revoked(test_jti)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_revoke_jti_idempotent(self, async_client: AsyncClient):
- """Double-revocation of the same JTI should not raise."""
- from backend.app.core.auth import is_jti_revoked, revoke_jti
- jti = secrets.token_urlsafe(16)
- expires = datetime.now(timezone.utc) + timedelta(hours=1)
- await revoke_jti(jti, expires)
- await revoke_jti(jti, expires) # must not raise
- assert await is_jti_revoked(jti)
- def test_is_token_fresh_rejects_none_iat(self):
- """_is_token_fresh returns False when iat is None (I1 hard cutoff)."""
- from backend.app.core.auth import _is_token_fresh
- user = MagicMock()
- user.password_changed_at = None
- assert _is_token_fresh(None, user) is False
- def test_is_token_fresh_rejects_token_before_password_change(self):
- """_is_token_fresh returns False when iat predates password_changed_at."""
- from backend.app.core.auth import _is_token_fresh
- now = datetime.now(timezone.utc)
- user = MagicMock()
- user.password_changed_at = now
- old_iat = (now - timedelta(hours=1)).timestamp()
- assert _is_token_fresh(old_iat, user) is False
- def test_is_token_fresh_accepts_token_after_password_change(self):
- """_is_token_fresh returns True when iat is after password_changed_at."""
- from backend.app.core.auth import _is_token_fresh
- now = datetime.now(timezone.utc)
- user = MagicMock()
- user.password_changed_at = now - timedelta(hours=1)
- recent_iat = now.timestamp()
- assert _is_token_fresh(recent_iat, user) is True
- def test_is_token_fresh_returns_true_when_no_password_change(self):
- """_is_token_fresh returns True when password_changed_at is None (I2 migration not yet run)."""
- from backend.app.core.auth import _is_token_fresh
- user = MagicMock()
- user.password_changed_at = None
- assert _is_token_fresh(time.time(), user) is True
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_me_endpoint_rejects_token_after_logout(self, async_client: AsyncClient):
- """After logout, the bearer token must be rejected by /me (B1 + revocation)."""
- token = await _setup_and_login(async_client, "sec_logout_me", "sec_logout_me1")
- # Token works before logout
- me_resp = await async_client.get(ME_URL, headers=_auth_header(token))
- assert me_resp.status_code == 200
- # Logout
- logout_resp = await async_client.post(LOGOUT_URL, headers=_auth_header(token))
- assert logout_resp.status_code == 200
- # Token must now be rejected
- me_after = await async_client.get(ME_URL, headers=_auth_header(token))
- assert me_after.status_code == 401
- # ===========================================================================
- # Gap 3: OIDC exchange token replay
- # ===========================================================================
- class TestOIDCExchangeReplay:
- """A single-use OIDC exchange token cannot be redeemed twice."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_exchange_token_is_single_use(self, async_client: AsyncClient, db_session: AsyncSession):
- """The second call to /oidc/exchange with the same token returns 401."""
- exchange_token = secrets.token_urlsafe(32)
- db_session.add(
- AuthEphemeralToken(
- token=exchange_token,
- token_type="oidc_exchange",
- username="oidc_replay_user",
- expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
- )
- )
- await db_session.commit()
- # Seed the user so the exchange can resolve it
- from backend.app.core.auth import get_password_hash
- from backend.app.core.database import async_session, seed_default_groups
- async with async_session() as db:
- result = await db.execute(__import__("sqlalchemy").select(User).where(User.username == "oidc_replay_user"))
- if result.scalar_one_or_none() is None:
- db.add(
- User(
- username="oidc_replay_user",
- password_hash=get_password_hash("pw"),
- is_active=True,
- )
- )
- await db.commit()
- first = await async_client.post("/api/v1/auth/oidc/exchange", json={"oidc_token": exchange_token})
- assert first.status_code == 200
- second = await async_client.post("/api/v1/auth/oidc/exchange", json={"oidc_token": exchange_token})
- assert second.status_code == 401
- # ===========================================================================
- # Gap 4: OIDC email_verified claim handling
- # ===========================================================================
- class TestOIDCEmailVerified:
- """email_verified: False/absent must not link OIDC identity to an existing email."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_unverified_email_does_not_link_to_existing_user(
- self, async_client: AsyncClient, db_session: AsyncSession
- ):
- """If email_verified is False, the OIDC callback must not auto-link by email."""
- private_pem, jwks_data = _make_test_rsa_key()
- issuer = "https://idp.evtest.example.com"
- client_id = "ev-client"
- nonce = secrets.token_urlsafe(16)
- now = int(time.time())
- id_token = pyjwt.encode(
- {
- "sub": "ev-sub-new",
- "iss": issuer,
- "aud": client_id,
- "nonce": nonce,
- "email": "existing@example.com",
- "email_verified": False, # <-- must be ignored
- "iat": now,
- "exp": now + 300,
- },
- private_pem,
- algorithm="RS256",
- headers={"kid": "test-kid-1"},
- )
- admin_token = await _setup_and_login(async_client, "ev_admin", "ev_admin1")
- # Create existing user with the same email (use strong password for validator)
- create_user_resp = await async_client.post(
- "/api/v1/users",
- json={"username": "existing_email_user", "password": "Str0ng!Pass", "email": "existing@example.com"},
- headers=_auth_header(admin_token),
- )
- assert create_user_resp.status_code in (200, 201), create_user_resp.json()
- # Create OIDC provider
- create_resp = await async_client.post(
- "/api/v1/auth/oidc/providers",
- json={
- "name": "EV-IdP",
- "issuer_url": issuer,
- "client_id": client_id,
- "client_secret": "secret",
- "scopes": "openid email",
- "is_enabled": True,
- "auto_create_users": True,
- },
- headers=_auth_header(admin_token),
- )
- assert create_resp.status_code == 201
- provider_id = create_resp.json()["id"]
- state = secrets.token_urlsafe(32)
- code_verifier = secrets.token_urlsafe(48)
- db_session.add(
- AuthEphemeralToken(
- token=state,
- token_type="oidc_state",
- provider_id=provider_id,
- nonce=nonce,
- code_verifier=code_verifier,
- expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
- )
- )
- await db_session.commit()
- discovery_doc = {
- "issuer": issuer,
- "authorization_endpoint": f"{issuer}/auth",
- "token_endpoint": f"{issuer}/token",
- "jwks_uri": f"{issuer}/.well-known/jwks.json",
- }
- class _MockResp:
- def __init__(self, data):
- self._data = data
- self.status_code = 200
- self.is_success = True
- self.text = str(data)
- def json(self):
- return self._data
- def raise_for_status(self):
- pass
- class _MockHttpxClientEV:
- def __init__(self, *args, **kwargs):
- pass
- async def __aenter__(self):
- return self
- async def __aexit__(self, *_):
- pass
- async def get(self, url, **kwargs):
- if "jwks" in url:
- return _MockResp(jwks_data)
- return _MockResp(discovery_doc)
- async def post(self, url, **kwargs):
- return _MockResp({"access_token": "mock", "token_type": "Bearer", "id_token": id_token})
- with patch("backend.app.api.routes.mfa.httpx.AsyncClient", _MockHttpxClientEV):
- await async_client.get(
- f"/api/v1/auth/oidc/callback?code=test-code&state={state}",
- follow_redirects=False,
- )
- # Callback must NOT link to the existing_email_user — a new user is created
- # instead (because the email claim was ignored due to email_verified=False).
- # Either a new user is provisioned (redirect with oidc_token) or the callback
- # fails. In either case, the existing user must not have an OIDC link.
- from sqlalchemy import select as sa_select
- from backend.app.models.oidc_provider import UserOIDCLink
- link_result = await db_session.execute(
- sa_select(UserOIDCLink)
- .join(User, UserOIDCLink.user_id == User.id)
- .where(User.email == "existing@example.com")
- )
- link = link_result.scalar_one_or_none()
- assert link is None, "Existing user must not be auto-linked when email_verified is False"
- # ===========================================================================
- # Gap 5: Email OTP max-attempts invalidation
- # ===========================================================================
- class TestEmailOTPMaxAttempts:
- """After MAX_ATTEMPTS wrong codes, the OTP is permanently invalidated."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_email_otp_invalidated_after_max_attempts(self, async_client: AsyncClient, db_session: AsyncSession):
- from passlib.context import CryptContext
- from sqlalchemy import select as sa_select
- from backend.app.models.user_otp_code import UserOTPCode
- _pwd_ctx = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
- admin_token = await _setup_and_login(async_client, "otp_max_admin", "otp_max_admin1")
- # Enable email OTP for admin user
- result = await db_session.execute(sa_select(User).where(User.username == "otp_max_admin"))
- user = result.scalar_one()
- user.email = "otpmax@example.com"
- await db_session.commit()
- setup_code = "123456"
- from backend.app.models.auth_ephemeral import AuthEphemeralToken as AET
- setup_token = secrets.token_urlsafe(32)
- db_session.add(
- AET(
- token=setup_token,
- token_type="email_otp_setup",
- username="otp_max_admin",
- nonce=_pwd_ctx.hash(setup_code),
- expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
- )
- )
- await db_session.commit()
- await async_client.post(
- "/api/v1/auth/2fa/email/enable/confirm",
- json={"setup_token": setup_token, "code": setup_code},
- headers=_auth_header(admin_token),
- )
- # Login to get pre_auth_token
- login_resp = await async_client.post(
- LOGIN_URL, json={"username": "otp_max_admin", "password": "Otp_max_admin1"}
- )
- pre_auth_token = login_resp.json()["pre_auth_token"]
- # Insert an OTP record directly (bypassing SMTP)
- real_code = "654321"
- otp = UserOTPCode(
- user_id=user.id,
- code_hash=_pwd_ctx.hash(real_code),
- attempts=0,
- used=False,
- expires_at=datetime.now(timezone.utc) + timedelta(minutes=10),
- )
- db_session.add(otp)
- await db_session.commit()
- # Submit MAX_ATTEMPTS wrong codes
- from backend.app.api.routes.mfa import MAX_2FA_ATTEMPTS
- for _ in range(MAX_2FA_ATTEMPTS):
- r = await async_client.post(
- "/api/v1/auth/2fa/verify",
- json={"pre_auth_token": pre_auth_token, "code": "000000", "method": "email"},
- )
- # Each attempt must fail with 401
- assert r.status_code == 401
- # After max attempts, the correct code is also rejected (either OTP
- # invalidated → 401, or rate limit hit → 429). Either means locked out.
- final = await async_client.post(
- "/api/v1/auth/2fa/verify",
- json={"pre_auth_token": pre_auth_token, "code": real_code, "method": "email"},
- )
- assert final.status_code in (401, 429), f"Expected lockout, got {final.status_code}: {final.json()}"
- # ===========================================================================
- # Gap 6: OIDC callback SSRF protection — invalid authorization_endpoint scheme
- # ===========================================================================
- class TestOIDCSSRFProtection:
- """authorization_endpoint with non-http(s) scheme must be rejected."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_invalid_authorization_endpoint_scheme_rejected(
- self, async_client: AsyncClient, db_session: AsyncSession
- ):
- issuer = "https://idp.ssrf.example.com"
- client_id = "ssrf-client"
- admin_token = await _setup_and_login(async_client, "ssrf_admin", "ssrf_admin1")
- create_resp = await async_client.post(
- "/api/v1/auth/oidc/providers",
- json={
- "name": "SSRF-IdP",
- "issuer_url": issuer,
- "client_id": client_id,
- "client_secret": "secret",
- "scopes": "openid",
- "is_enabled": True,
- "auto_create_users": False,
- },
- headers=_auth_header(admin_token),
- )
- assert create_resp.status_code == 201
- provider_id = create_resp.json()["id"]
- # Discovery doc returns a javascript: authorization_endpoint
- malicious_discovery = {
- "issuer": issuer,
- "authorization_endpoint": "javascript:alert(1)", # <-- malicious
- "token_endpoint": f"{issuer}/token",
- "jwks_uri": f"{issuer}/.well-known/jwks.json",
- }
- class _MockResp:
- def __init__(self, data):
- self._data = data
- self.status_code = 200
- self.is_success = True
- self.text = str(data)
- def json(self):
- return self._data
- def raise_for_status(self):
- pass
- class _MockHttpxClientSSRF:
- def __init__(self, *args, **kwargs):
- pass
- async def __aenter__(self):
- return self
- async def __aexit__(self, *_):
- pass
- async def get(self, url, **kwargs):
- return _MockResp(malicious_discovery)
- async def post(self, url, **kwargs):
- return _MockResp({})
- with patch("backend.app.api.routes.mfa.httpx.AsyncClient", _MockHttpxClientSSRF):
- # oidc_authorize uses a path parameter, not query param
- authorize_resp = await async_client.get(
- f"/api/v1/auth/oidc/authorize/{provider_id}",
- follow_redirects=False,
- )
- # Must be rejected with 502 — B2 guard rejects invalid authorization_endpoint scheme
- assert authorize_resp.status_code == 502, authorize_resp.json()
- detail = authorize_resp.json().get("detail", "").lower()
- assert "authorization_endpoint" in detail or "invalid" in detail
- # ===========================================================================
- # Gap 7: Login rate limiting
- # ===========================================================================
- class TestLoginRateLimiting:
- """10+ failed logins for the same username must return 429."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_excessive_failed_logins_return_429(self, async_client: AsyncClient):
- from backend.app.api.routes.mfa import MAX_LOGIN_ATTEMPTS
- # Setup auth but do NOT log in
- await async_client.post(
- AUTH_SETUP_URL,
- json={"auth_enabled": True, "admin_username": "ratelimit_user", "admin_password": "Ratelimit_pw1"},
- )
- status_codes = []
- for _ in range(MAX_LOGIN_ATTEMPTS + 2):
- resp = await async_client.post(
- LOGIN_URL,
- json={"username": "ratelimit_user", "password": "wrong_password"},
- )
- status_codes.append(resp.status_code)
- # The last attempts must be 429 (Too Many Requests)
- assert status_codes[-1] == 429, f"Expected 429 after {MAX_LOGIN_ATTEMPTS} failures, got: {status_codes}"
- # ===========================================================================
- # Gap 8: challenge_id cookie binding
- # ===========================================================================
- class TestChallengeIdCookieBinding:
- """A pre-auth token stolen from session A cannot be used from session B."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_pre_auth_token_rejected_without_matching_cookie(
- self, async_client: AsyncClient, db_session: AsyncSession
- ):
- import pyotp
- from passlib.context import CryptContext
- _pwd_ctx = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
- # Set up user with TOTP
- await _setup_and_login(async_client, "cookie_bind_user", "cookie_bind_pw1")
- secret = pyotp.random_base32()
- totp_obj = pyotp.TOTP(secret)
- from sqlalchemy import select as sa_select
- from backend.app.models.user_totp import UserTOTP
- result = await db_session.execute(sa_select(User).where(User.username == "cookie_bind_user"))
- user = result.scalar_one()
- db_session.add(UserTOTP(user_id=user.id, secret=secret, is_enabled=True))
- await db_session.commit()
- # Login from "session A" — gets a pre_auth_token and a 2fa_challenge cookie
- login_resp = await async_client.post(
- LOGIN_URL, json={"username": "cookie_bind_user", "password": "Cookie_bind_pw1"}
- )
- assert login_resp.status_code == 200
- assert login_resp.json()["requires_2fa"] is True
- pre_auth_token = login_resp.json()["pre_auth_token"]
- # The async_client jar now holds the 2fa_challenge cookie for session A
- # Simulate session B by creating a new client WITHOUT the cookie
- from httpx import ASGITransport, AsyncClient as FreshClient
- from backend.app.main import app
- async with FreshClient(transport=ASGITransport(app=app), base_url="http://test") as session_b:
- # Attempt to use session A's pre_auth_token from session B (no cookie)
- verify_resp = await session_b.post(
- "/api/v1/auth/2fa/verify",
- json={
- "pre_auth_token": pre_auth_token,
- "code": totp_obj.now(),
- "method": "totp",
- },
- )
- # Must be rejected — pre_auth_token is bound to session A's cookie
- assert verify_resp.status_code == 401, (
- f"Expected 401 for token replay from cookieless session, got {verify_resp.status_code}: "
- f"{verify_resp.json()}"
- )
- # ===========================================================================
- # C2: Security-header middleware
- # ===========================================================================
- class TestSecurityHeaders:
- """Every HTTP response must include standard security headers (C2)."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_security_headers_present(self, async_client: AsyncClient):
- """GET /api/v1/auth/me (unauthenticated → 401) still carries security headers."""
- resp = await async_client.get(ME_URL)
- assert resp.status_code == 401 # sanity — no auth token
- assert resp.headers.get("x-content-type-options") == "nosniff"
- assert resp.headers.get("x-frame-options") == "SAMEORIGIN"
- assert resp.headers.get("referrer-policy") == "strict-origin-when-cross-origin"
- csp = resp.headers.get("content-security-policy", "")
- assert "default-src 'self'" in csp
- assert "script-src 'self'" in csp
- assert "frame-ancestors 'none'" in csp
- assert "object-src 'none'" in csp
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_hsts_absent_for_http(self, async_client: AsyncClient):
- """HSTS must NOT be set over plain HTTP (test transport uses http)."""
- resp = await async_client.get(ME_URL)
- assert "strict-transport-security" not in resp.headers
- # ===========================================================================
- # I3: Rate-limit bucket interaction — IP spray vs. username spray
- # ===========================================================================
- class TestRateLimitBuckets:
- """IP-spray and username-spray must each trip the correct independent bucket."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_ip_spray_trips_ip_bucket(self, async_client: AsyncClient):
- """20 failed logins from one IP across 20 different usernames trips the IP bucket.
- Each per-username bucket only has 1 failure (well below MAX_LOGIN_ATTEMPTS=10),
- so the username bucket is never the reason for the 429.
- """
- from unittest.mock import patch as _patch
- unique_ip = "10.99.1.1"
- # Ensure auth is enabled
- await async_client.post(
- AUTH_SETUP_URL,
- json={"auth_enabled": True, "admin_username": "spray_ip_admin", "admin_password": "SprayIp_admin1"},
- )
- status_codes: list[int] = []
- with _patch("backend.app.api.routes.auth._get_client_ip", return_value=unique_ip):
- for i in range(22):
- resp = await async_client.post(
- LOGIN_URL,
- json={"username": f"spray_ip_victim_{i}", "password": "wrong"},
- )
- status_codes.append(resp.status_code)
- # The first 20 attempts fail with 401; the 21st+ must be 429 (IP bucket full)
- assert status_codes[-1] == 429, f"Expected 429 after 20 IP-spray failures, got: {status_codes}"
- # No single username saw more than one attempt → username buckets not tripped
- non_429 = [c for c in status_codes[:-2] if c == 429]
- assert not non_429, f"Username bucket triggered early: {status_codes}"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_username_spray_trips_username_bucket(self, async_client: AsyncClient):
- """One username targeted from 10+ different IPs trips the username bucket.
- Each per-IP bucket only sees 1 failure, so no IP bucket is tripped.
- The username bucket (max 10) is what fires the 429.
- """
- from unittest.mock import patch as _patch
- from backend.app.api.routes.mfa import MAX_LOGIN_ATTEMPTS
- # Ensure auth is enabled
- await async_client.post(
- AUTH_SETUP_URL,
- json={
- "auth_enabled": True,
- "admin_username": "spray_uname_admin",
- "admin_password": "SprayUname_admin1",
- },
- )
- target_username = "spray_uname_victim"
- status_codes: list[int] = []
- for i in range(MAX_LOGIN_ATTEMPTS + 2):
- rotating_ip = f"10.99.2.{i + 1}"
- with _patch("backend.app.api.routes.auth._get_client_ip", return_value=rotating_ip):
- resp = await async_client.post(
- LOGIN_URL,
- json={"username": target_username, "password": "wrong"},
- )
- status_codes.append(resp.status_code)
- # After MAX_LOGIN_ATTEMPTS failures for same username the bucket fires
- assert status_codes[-1] == 429, (
- f"Expected 429 after {MAX_LOGIN_ATTEMPTS} username-spray failures, got: {status_codes}"
- )
- # ============================================================================
- # TestEncryptLegacyMigration
- # ============================================================================
- class TestEncryptLegacyMigration:
- """Re-encryption migration of legacy plaintext OIDC + TOTP rows.
- The migration runs against its own ``async_session`` factory (not the
- ``db_session`` fixture) so each test patches the module-level factory to
- point at the test-engine before invoking the helper. ``db_session`` is
- used to seed and to verify state via the same engine.
- """
- @staticmethod
- def _patch_module_session(monkeypatch, db_session):
- """Bind ``database.async_session`` to the test engine for one test."""
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
- from backend.app.core import database as db_mod
- test_factory = async_sessionmaker(db_session.bind, class_=AsyncSession, expire_on_commit=False)
- monkeypatch.setattr(db_mod, "async_session", test_factory)
- @staticmethod
- def _set_active_key(monkeypatch):
- """Configure a valid Fernet key for the migration to use."""
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_encrypts_plaintext_oidc_secret(self, db_session, monkeypatch):
- from sqlalchemy import select
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- provider = OIDCProvider(
- name="LegacyProv",
- issuer_url="https://legacy.example.com",
- client_id="cid",
- _client_secret_enc="legacy-plaintext",
- scopes="openid email profile",
- is_enabled=True,
- )
- db_session.add(provider)
- await db_session.commit()
- await _migrate_encrypt_legacy_secrets()
- # Re-fetch on a fresh row state
- await db_session.refresh(provider)
- assert provider._client_secret_enc.startswith("fernet:")
- # Decrypted value matches the original plaintext
- assert provider.client_secret == "legacy-plaintext"
- # Sanity: a SELECT also sees the encrypted value
- result = await db_session.execute(select(OIDCProvider).where(OIDCProvider.id == provider.id))
- fetched = result.scalar_one()
- assert fetched._client_secret_enc.startswith("fernet:")
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_skips_already_encrypted_rows(self, db_session, monkeypatch):
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- # Use the property setter so the value is encrypted up front.
- provider = OIDCProvider(
- name="EncProv",
- issuer_url="https://enc.example.com",
- client_id="cid",
- client_secret="already-encrypted",
- scopes="openid email profile",
- is_enabled=True,
- )
- db_session.add(provider)
- await db_session.commit()
- original_enc = provider._client_secret_enc
- await _migrate_encrypt_legacy_secrets()
- await _migrate_encrypt_legacy_secrets() # idempotent
- await db_session.refresh(provider)
- # Value unchanged across two migration runs (still the same ciphertext).
- assert provider._client_secret_enc == original_enc
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_no_op_when_key_unset(self, db_session, monkeypatch):
- import backend.app.core.encryption as enc_mod
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- # Force "no key" branch
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- provider = OIDCProvider(
- name="NoKeyProv",
- issuer_url="https://nokey.example.com",
- client_id="cid",
- _client_secret_enc="still-plaintext",
- scopes="openid email profile",
- is_enabled=True,
- )
- db_session.add(provider)
- await db_session.commit()
- await _migrate_encrypt_legacy_secrets()
- await db_session.refresh(provider)
- # Migration should have early-returned; plaintext untouched.
- assert provider._client_secret_enc == "still-plaintext"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_handles_mixed_state(self, db_session, monkeypatch):
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- legacy = OIDCProvider(
- name="LegacyMix",
- issuer_url="https://l.example.com",
- client_id="c1",
- _client_secret_enc="plain-mix",
- scopes="openid email profile",
- )
- encrypted = OIDCProvider(
- name="EncMix",
- issuer_url="https://e.example.com",
- client_id="c2",
- client_secret="encrypted-mix", # uses setter
- scopes="openid email profile",
- )
- db_session.add_all([legacy, encrypted])
- await db_session.commit()
- original_encrypted = encrypted._client_secret_enc
- await _migrate_encrypt_legacy_secrets()
- await db_session.refresh(legacy)
- await db_session.refresh(encrypted)
- assert legacy._client_secret_enc.startswith("fernet:")
- assert legacy.client_secret == "plain-mix"
- assert encrypted._client_secret_enc == original_encrypted
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_encrypts_plaintext_totp_secret(self, db_session, monkeypatch):
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.user import User
- from backend.app.models.user_totp import UserTOTP
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- user = User(username="totpuser1219", email="t@example.com", password_hash="x")
- db_session.add(user)
- await db_session.flush()
- totp = UserTOTP(user_id=user.id, _secret_enc="JBSWY3DPEHPK3PXP", is_enabled=True)
- db_session.add(totp)
- await db_session.commit()
- await _migrate_encrypt_legacy_secrets()
- await db_session.refresh(totp)
- assert totp._secret_enc.startswith("fernet:")
- assert totp.secret == "JBSWY3DPEHPK3PXP"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_logs_count_of_rows_re_encrypted(self, db_session, monkeypatch, caplog):
- import logging
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- from backend.app.models.user import User
- from backend.app.models.user_totp import UserTOTP
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- provider = OIDCProvider(
- name="LegacyLog",
- issuer_url="https://log.example.com",
- client_id="c",
- _client_secret_enc="p",
- scopes="openid email profile",
- )
- user = User(username="logger1219", email="l@example.com", password_hash="x")
- db_session.add_all([provider, user])
- await db_session.flush()
- totp = UserTOTP(user_id=user.id, _secret_enc="JBSWY3DPEHPK3PXP", is_enabled=True)
- db_session.add(totp)
- await db_session.commit()
- with caplog.at_level(logging.INFO, logger="backend.app.core.database"):
- await _migrate_encrypt_legacy_secrets()
- # The migration logs once with both counts.
- assert any(
- "Re-encrypted legacy plaintext secrets" in rec.message
- and "1 OIDC client_secret(s)" in rec.message
- and "1 TOTP secret(s)" in rec.message
- for rec in caplog.records
- )
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_continues_on_row_error(self, db_session, monkeypatch, caplog):
- """B2: per-row commit semantics — when one row fails to re-encrypt,
- OTHER successfully-encrypted rows must remain committed and the
- failure surfaces via get_migration_error_count.
- Replaces the previous "rollback all" behaviour: a single poison row
- used to block every successful re-encryption on every startup forever.
- """
- import logging
- import backend.app.core.encryption as enc_mod # noqa: F401
- from backend.app.core.database import (
- _migrate_encrypt_legacy_secrets,
- get_migration_error_count,
- )
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- good = OIDCProvider(
- name="GoodRow",
- issuer_url="https://good.example.com",
- client_id="c1",
- _client_secret_enc="plaintext-good",
- scopes="openid email profile",
- )
- bad = OIDCProvider(
- name="BadRow",
- issuer_url="https://bad.example.com",
- client_id="c2",
- _client_secret_enc="plaintext-bad",
- scopes="openid email profile",
- )
- db_session.add_all([good, bad])
- await db_session.commit()
- original_bad = bad._client_secret_enc
- # Force the setter on the SECOND row to raise — patch at the model's
- # import location so the property setter picks up the patched function.
- import backend.app.models.oidc_provider as oidc_mod
- real_encrypt = oidc_mod.mfa_encrypt
- call_count = [0]
- def _sometimes_raise(value):
- call_count[0] += 1
- if call_count[0] == 2:
- raise RuntimeError("simulated encrypt failure")
- return real_encrypt(value)
- monkeypatch.setattr(oidc_mod, "mfa_encrypt", _sometimes_raise)
- with caplog.at_level(logging.ERROR, logger="backend.app.core.database"):
- await _migrate_encrypt_legacy_secrets()
- # B2: per-row commit — good IS encrypted, bad is unchanged.
- await db_session.refresh(good)
- await db_session.refresh(bad)
- assert good._client_secret_enc.startswith("fernet:"), (
- "good row must be successfully re-encrypted (per-row commit)"
- )
- assert bad._client_secret_enc == original_bad, "bad row must remain unchanged (savepoint-style isolation)"
- assert get_migration_error_count() == 1, "the skipped row must be exposed via get_migration_error_count"
- assert any("skipping" in rec.message.lower() for rec in caplog.records)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_migration_logs_no_op_when_all_encrypted(self, db_session, monkeypatch, caplog):
- """A2: when all rows are already encrypted, migration logs a debug no-op."""
- import logging
- from backend.app.core.database import _migrate_encrypt_legacy_secrets
- from backend.app.models.oidc_provider import OIDCProvider
- self._patch_module_session(monkeypatch, db_session)
- self._set_active_key(monkeypatch)
- provider = OIDCProvider(
- name="AlreadyEnc",
- issuer_url="https://ae.example.com",
- client_id="cae",
- client_secret="already-encrypted",
- scopes="openid email profile",
- )
- db_session.add(provider)
- await db_session.commit()
- with caplog.at_level(logging.DEBUG, logger="backend.app.core.database"):
- await _migrate_encrypt_legacy_secrets()
- assert any("no rows needed re-encryption" in rec.message for rec in caplog.records)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_init_db_propagates_unexpected_migration_error(self, monkeypatch, tmp_path):
- """B3: an unexpected error from _migrate_encrypt_legacy_secrets must
- surface (re-raise) instead of being silently swallowed.
- Pins the contract introduced for B3: a startup-fatal error like a
- session-creation failure must fail the lifespan / CLI / restore
- handler explicitly, never run the app with half-migrated rows.
- Implementation note: we patch _migrate_encrypt_legacy_secrets itself
- rather than poking the inner read phase, because that is the contract
- boundary the rest of the codebase relies on (init_db -> migration).
- """
- from sqlalchemy import event
- from sqlalchemy.ext.asyncio import create_async_engine
- import backend.app.core.database as db_mod
- from backend.app.core.config import settings
- # init_db() uses the module-level `engine`, which was bound at import
- # time to settings.database_url — that resolves to the real shared
- # bambuddy.db at the project root (or, when DATABASE_URL is set, the
- # configured Postgres). The autouse DATA_DIR fixture runs too late to
- # influence either. Letting this test write to that real DB makes it
- # (a) non-hermetic and (b) flake under `-n 30` with "database is
- # locked" when two workers race on the file. Substitute an isolated
- # per-test SQLite engine — and override settings.database_url for
- # this test so the is_sqlite() / is_postgres() dialect guards inside
- # run_migrations pick the SQLite path against this engine.
- test_db_url = f"sqlite+aiosqlite:///{tmp_path / 'init_db_test.db'}"
- test_engine = create_async_engine(test_db_url, echo=False)
- event.listen(test_engine.sync_engine, "connect", db_mod._set_sqlite_pragmas)
- monkeypatch.setattr(db_mod, "engine", test_engine)
- monkeypatch.setattr(settings, "database_url", test_db_url)
- async def boom():
- raise RuntimeError("simulated startup-fatal failure")
- # Stub out the rest of init_db so we exercise only the migration step.
- # init_db opens the engine.begin() block, runs metadata.create_all,
- # run_migrations, then awaits _migrate_encrypt_legacy_secrets — the
- # only call we want to fail.
- monkeypatch.setattr(db_mod, "_migrate_encrypt_legacy_secrets", boom)
- monkeypatch.setattr(db_mod, "seed_notification_templates", lambda: _noop_async())
- monkeypatch.setattr(db_mod, "seed_default_groups", lambda: _noop_async())
- monkeypatch.setattr(db_mod, "seed_spool_catalog", lambda: _noop_async())
- monkeypatch.setattr(db_mod, "seed_color_catalog", lambda: _noop_async())
- try:
- with pytest.raises(RuntimeError, match="simulated startup-fatal failure"):
- await db_mod.init_db()
- finally:
- await test_engine.dispose()
- async def _noop_async():
- """Helper for tests that need to stub out `seed_*` async coroutines."""
- return None
- # ============================================================================
- # TestEncryptionStatusEndpoint
- # ============================================================================
- class TestEncryptionStatusEndpoint:
- """GET /api/v1/auth/encryption-status: key source, counts, decryption_broken."""
- STATUS_URL = "/api/v1/auth/encryption-status"
- async def _create_admin_and_login(self, async_client: AsyncClient) -> str:
- """Bootstrap auth + return a Bearer token for an admin."""
- await async_client.post(
- "/api/v1/auth/setup",
- json={
- "auth_enabled": True,
- "admin_username": "admin1219",
- "admin_password": "Admin1219!Pass",
- },
- )
- login = await async_client.post(
- "/api/v1/auth/login",
- json={"username": "admin1219", "password": "Admin1219!Pass"},
- )
- assert login.status_code == 200, login.text
- return login.json()["access_token"]
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_reports_env_source(self, async_client, monkeypatch):
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- token = await self._create_admin_and_login(async_client)
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["key_configured"] is True
- assert data["key_source"] == "env"
- assert data["decryption_broken"] is False
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_reports_file_source(self, async_client, monkeypatch, tmp_path):
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- token = await self._create_admin_and_login(async_client)
- # Pre-place a valid key file in DATA_DIR.
- key_file = tmp_path / ".mfa_encryption_key"
- key_file.write_text(Fernet.generate_key().decode())
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.delenv("MFA_ENCRYPTION_KEY", raising=False)
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["key_source"] == "file"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_reports_generated_source(self, async_client, monkeypatch, tmp_path):
- import backend.app.core.encryption as enc_mod
- token = await self._create_admin_and_login(async_client)
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.delenv("MFA_ENCRYPTION_KEY", raising=False)
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["key_source"] == "generated"
- assert (tmp_path / ".mfa_encryption_key").exists()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_reports_none_source(self, async_client, monkeypatch):
- import backend.app.core.encryption as enc_mod
- token = await self._create_admin_and_login(async_client)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["key_configured"] is False
- assert data["key_source"] == "none"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_counts_legacy_rows(self, async_client, db_session, monkeypatch):
- from backend.app.models.oidc_provider import OIDCProvider
- token = await self._create_admin_and_login(async_client)
- provider = OIDCProvider(
- name="LegacyStatus",
- issuer_url="https://ls.example.com",
- client_id="c",
- _client_secret_enc="plaintext-no-prefix",
- scopes="openid email profile",
- )
- db_session.add(provider)
- await db_session.commit()
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["legacy_plaintext_rows"]["oidc_providers"] >= 1
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_counts_encrypted_rows(self, async_client, db_session, monkeypatch):
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- from backend.app.models.oidc_provider import OIDCProvider
- token = await self._create_admin_and_login(async_client)
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- provider = OIDCProvider(
- name="EncStatus",
- issuer_url="https://es.example.com",
- client_id="c",
- client_secret="real-secret", # via setter → encrypted
- scopes="openid email profile",
- )
- db_session.add(provider)
- await db_session.commit()
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["encrypted_rows"]["oidc_providers"] >= 1
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_warns_on_encrypted_rows_without_key(self, async_client, db_session, monkeypatch):
- """Gap 2: encrypted rows present but no key loadable → decryption_broken=true."""
- import backend.app.core.encryption as enc_mod
- from backend.app.models.oidc_provider import OIDCProvider
- token = await self._create_admin_and_login(async_client)
- # Insert a row whose value is already prefixed (simulates a previously-encrypted row).
- provider = OIDCProvider(
- name="BrokenEnc",
- issuer_url="https://be.example.com",
- client_id="c",
- _client_secret_enc="fernet:gAAAAA-fake-but-prefixed",
- scopes="openid email profile",
- )
- db_session.add(provider)
- await db_session.commit()
- # Now disable key loading so decryption is impossible.
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200
- data = resp.json()
- assert data["key_configured"] is False
- assert data["encrypted_rows"]["oidc_providers"] >= 1
- assert data["decryption_broken"] is True
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_requires_settings_read_permission(self, async_client, db_session):
- """Non-admin without settings:read permission gets 403."""
- from backend.app.models.user import User
- await self._create_admin_and_login(async_client)
- # Create a low-privilege user (no group → no permissions in default seed).
- from backend.app.core.auth import get_password_hash
- viewer = User(
- username="viewer1219",
- email="viewer1219@example.com",
- password_hash=get_password_hash("Viewer1219!Pass"),
- role="user",
- is_active=True,
- )
- db_session.add(viewer)
- await db_session.commit()
- login = await async_client.post(
- "/api/v1/auth/login",
- json={"username": "viewer1219", "password": "Viewer1219!Pass"},
- )
- assert login.status_code == 200, login.text
- token = login.json().get("access_token")
- assert token is not None, f"Expected access_token in login response, got: {login.json()}"
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 403
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_returns_500_on_db_error(self, async_client, monkeypatch):
- """A8: SQLAlchemyError during count queries → 500 with static message."""
- from unittest.mock import AsyncMock
- from sqlalchemy.exc import SQLAlchemyError
- token = await self._create_admin_and_login(async_client)
- async def _raise(*args, **kwargs):
- raise SQLAlchemyError("simulated DB failure")
- monkeypatch.setattr("sqlalchemy.ext.asyncio.AsyncSession.execute", AsyncMock(side_effect=_raise))
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 500
- assert "encryption status" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_returns_403_for_viewer_in_viewers_group(self, async_client, db_session):
- """S2: a user in the Viewers group (has SETTINGS_READ but NOT SETTINGS_UPDATE)
- must get 403 — encryption-status is admin/operator only.
- """
- from sqlalchemy import insert, select
- from backend.app.core.auth import get_password_hash
- from backend.app.models.group import Group, user_groups
- from backend.app.models.user import User
- # Bootstrap auth (creates default groups via setup endpoint).
- await self._create_admin_and_login(async_client)
- # Create a user explicitly in the Viewers group — it has SETTINGS_READ
- # but not SETTINGS_UPDATE, which is the discriminator for S2.
- viewer = User(
- username="viewer_s2",
- email="viewer_s2@example.com",
- password_hash=get_password_hash("ViewerS2!Pass1"),
- role="user",
- is_active=True,
- )
- db_session.add(viewer)
- await db_session.flush()
- viewers_group = (await db_session.execute(select(Group).where(Group.name == "Viewers"))).scalar_one_or_none()
- assert viewers_group is not None, "Viewers group must be seeded by setup"
- # Insert the association row directly to avoid touching the lazy
- # `viewer.groups` relationship (which would trigger an implicit
- # IO inside an active async transaction and fail with MissingGreenlet).
- await db_session.execute(insert(user_groups).values(user_id=viewer.id, group_id=viewers_group.id))
- await db_session.commit()
- login = await async_client.post(
- "/api/v1/auth/login",
- json={"username": "viewer_s2", "password": "ViewerS2!Pass1"},
- )
- assert login.status_code == 200, login.text
- token = login.json()["access_token"]
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 403, "S2: Viewers (SETTINGS_READ only) must NOT be able to read encryption-status"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_decryption_broken_when_wrong_key_active(self, async_client, db_session, monkeypatch):
- """B4: key is configured but cannot decrypt existing rows → decryption_broken=True.
- This is the "wrong key" state that the legacy computed_field check
- missed — operator pasted a different valid Fernet key (rotation,
- cross-deployment restore, env override). Status used to show GREEN
- while every encrypted row was unrecoverable.
- """
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- from backend.app.models.oidc_provider import OIDCProvider
- token = await self._create_admin_and_login(async_client)
- # Insert a row whose value is fernet-prefixed but encrypted under a
- # DIFFERENT key (the prefix matches, but decrypt will throw).
- provider = OIDCProvider(
- name="WrongKeyEnc",
- issuer_url="https://wk.example.com",
- client_id="c",
- _client_secret_enc=("fernet:" + Fernet(Fernet.generate_key()).encrypt(b"original").decode()),
- scopes="openid email profile",
- )
- db_session.add(provider)
- await db_session.commit()
- # Now activate a DIFFERENT key — sample-decrypt must fail.
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200, resp.text
- data = resp.json()
- assert data["key_configured"] is True, "different key is still 'configured'"
- assert data["encrypted_rows"]["oidc_providers"] >= 1
- assert data["decryption_broken"] is True, "B4: sample-decrypt must detect wrong-key state"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_decryption_broken_with_only_totp_rows(self, async_client, db_session, monkeypatch):
- """B4: the sample-decrypt fallback to UserTOTP fires when there are no
- encrypted OIDC rows but TOTP rows exist. The OIDC-only test above
- proves the primary path; this pins the second branch in the same
- try-block so a future refactor of the row-source switch can't silently
- regress wrong-key detection for TOTP-only deployments.
- """
- from cryptography.fernet import Fernet
- from sqlalchemy import select
- import backend.app.core.encryption as enc_mod
- from backend.app.models.user import User
- from backend.app.models.user_totp import UserTOTP
- token = await self._create_admin_and_login(async_client)
- # Look up the admin user created by login so we can attach a TOTP row.
- admin_row = await db_session.execute(select(User).where(User.username == "admin1219"))
- admin = admin_row.scalar_one()
- # Seed a UserTOTP row encrypted under key A. No OIDC rows exist, so
- # the endpoint's first branch (oidc_providers > 0) misses and the
- # sample falls through to UserTOTP.
- key_a_ciphertext = Fernet(Fernet.generate_key()).encrypt(b"original-totp-secret").decode()
- db_session.add(UserTOTP(user_id=admin.id, _secret_enc=f"fernet:{key_a_ciphertext}", is_enabled=True))
- await db_session.commit()
- # Activate a DIFFERENT key — the TOTP-fallback sample-decrypt must fail.
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200, resp.text
- data = resp.json()
- assert data["key_configured"] is True
- assert data["encrypted_rows"]["oidc_providers"] == 0, "test premise: no OIDC rows so TOTP branch fires"
- assert data["encrypted_rows"]["user_totp"] >= 1
- assert data["decryption_broken"] is True, "B4: TOTP-fallback sample-decrypt must detect wrong-key state"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_status_surfaces_real_migration_error_count(self, async_client, db_session, monkeypatch, caplog):
- """B2: a real migration with a poison row produces an error_count that
- flows through to the endpoint's `migration_error_count` field.
- Replaces an earlier tautology that patched the module-level counter
- directly. The chained version verifies the full path: poison row →
- per-row migration skip → ``get_migration_error_count()`` →
- ``GET /encryption-status``.
- """
- import logging
- from backend.app.core.database import _migrate_encrypt_legacy_secrets, get_migration_error_count
- from backend.app.models.oidc_provider import OIDCProvider
- token = await self._create_admin_and_login(async_client)
- # Bind the migration's session factory to the test engine and activate a key.
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
- from backend.app.core import database as db_mod
- test_factory = async_sessionmaker(db_session.bind, class_=AsyncSession, expire_on_commit=False)
- monkeypatch.setattr(db_mod, "async_session", test_factory)
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- # Two legacy plaintext rows; force the SECOND row's encrypt call to raise.
- db_session.add_all(
- [
- OIDCProvider(
- name="GoodRow",
- issuer_url="https://good.example.com",
- client_id="c1",
- _client_secret_enc="plaintext-good",
- scopes="openid email profile",
- ),
- OIDCProvider(
- name="BadRow",
- issuer_url="https://bad.example.com",
- client_id="c2",
- _client_secret_enc="plaintext-bad",
- scopes="openid email profile",
- ),
- ]
- )
- await db_session.commit()
- import backend.app.models.oidc_provider as oidc_mod
- real_encrypt = oidc_mod.mfa_encrypt
- call_count = [0]
- def _sometimes_raise(value):
- call_count[0] += 1
- if call_count[0] == 2:
- raise RuntimeError("simulated encrypt failure")
- return real_encrypt(value)
- monkeypatch.setattr(oidc_mod, "mfa_encrypt", _sometimes_raise)
- with caplog.at_level(logging.ERROR, logger="backend.app.core.database"):
- await _migrate_encrypt_legacy_secrets()
- # Sanity: the migration's own counter saw the failure.
- assert get_migration_error_count() == 1
- # The endpoint must surface the same number — full path pinned, not just the getter.
- resp = await async_client.get(self.STATUS_URL, headers={"Authorization": f"Bearer {token}"})
- assert resp.status_code == 200, resp.text
- data = resp.json()
- assert data["migration_error_count"] == 1, (
- "endpoint must report the actual migration outcome, not just read a stub global"
- )
- # ============================================================================
- # TestEncryptionRoundtrip (E2E)
- # ============================================================================
- class TestEncryptionRoundtrip:
- """End-to-end: writes via the property setter store ciphertext at the column
- level; reads via the property getter return the original plaintext."""
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_oidc_provider_secret_encrypted_at_rest_e2e(self, db_session, monkeypatch):
- from cryptography.fernet import Fernet
- from sqlalchemy import select
- import backend.app.core.encryption as enc_mod
- from backend.app.models.oidc_provider import OIDCProvider
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- provider = OIDCProvider(
- name="E2E_OIDC",
- issuer_url="https://e2e.example.com",
- client_id="cid",
- client_secret="my-real-client-secret", # via setter → encrypted
- scopes="openid email profile",
- is_enabled=True,
- )
- db_session.add(provider)
- await db_session.commit()
- # Raw column read: must be ciphertext, not the plaintext.
- result = await db_session.execute(select(OIDCProvider).where(OIDCProvider.id == provider.id))
- fetched = result.scalar_one()
- assert fetched._client_secret_enc.startswith("fernet:")
- assert fetched._client_secret_enc != "my-real-client-secret"
- # Property read: returns original plaintext.
- assert fetched.client_secret == "my-real-client-secret"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_totp_secret_encrypted_at_rest_e2e(self, db_session, monkeypatch):
- from cryptography.fernet import Fernet
- from sqlalchemy import select
- import backend.app.core.encryption as enc_mod
- from backend.app.models.user import User
- from backend.app.models.user_totp import UserTOTP
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- user = User(username="e2etotp1219", email="e@example.com", password_hash="x")
- db_session.add(user)
- await db_session.flush()
- totp = UserTOTP(user_id=user.id, secret="JBSWY3DPEHPK3PXP", is_enabled=True)
- db_session.add(totp)
- await db_session.commit()
- result = await db_session.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
- fetched = result.scalar_one()
- assert fetched._secret_enc.startswith("fernet:")
- assert fetched._secret_enc != "JBSWY3DPEHPK3PXP"
- assert fetched.secret == "JBSWY3DPEHPK3PXP"
- # ============================================================================
- # TestBackupKeyFiles
- # Verifies that .mfa_encryption_key is included in backup ZIPs (so backups
- # are self-contained) and restored with chmod 0600 — and that path-traversal
- # payloads in a malicious ZIP are rejected.
- # ============================================================================
- class TestBackupKeyFiles:
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_backup_includes_mfa_encryption_key_when_present(self, async_client, monkeypatch, tmp_path):
- import zipfile
- from backend.app.api.routes.settings import create_backup_zip
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- # Ensure `app_settings.base_dir` follows DATA_DIR for this test by
- # patching the module attribute (config caches it at import time).
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- key_path = tmp_path / ".mfa_encryption_key"
- key_path.write_text("test-key-content")
- zip_path, _filename = await create_backup_zip(output_path=tmp_path)
- try:
- with zipfile.ZipFile(zip_path) as zf:
- names = zf.namelist()
- assert ".mfa_encryption_key" in names
- assert zf.read(".mfa_encryption_key").decode() == "test-key-content"
- finally:
- zip_path.unlink(missing_ok=True)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_backup_skips_mfa_encryption_key_when_absent(self, async_client, monkeypatch, tmp_path):
- import zipfile
- from backend.app.api.routes.settings import create_backup_zip
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # No .mfa_encryption_key written — must not crash.
- zip_path, _filename = await create_backup_zip(output_path=tmp_path)
- try:
- with zipfile.ZipFile(zip_path) as zf:
- names = zf.namelist()
- assert ".mfa_encryption_key" not in names
- finally:
- zip_path.unlink(missing_ok=True)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_writes_key_files_with_chmod_0600(self, async_client, monkeypatch, tmp_path):
- """T1: restore endpoint writes key file with mode 0o600.
- Bypasses the SQLite-copy step via patches so execution reaches the
- key-write code unconditionally — the previous version used a stub
- ``b"SQLite format 3"`` which made ``sqlite3.backup()`` fail and the
- key-write code never ran.
- """
- import io
- import zipfile
- from unittest.mock import AsyncMock, patch
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Build a minimal ZIP with a stub DB and the key file.
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr("bambuddy.db", b"SQLite format 3")
- zf.writestr(".mfa_encryption_key", "test-restored-key")
- buf.seek(0)
- with (
- patch("backend.app.core.db_dialect.is_sqlite", return_value=False),
- patch(
- "backend.app.api.routes.settings._import_sqlite_to_postgres",
- new_callable=AsyncMock,
- ),
- patch("backend.app.core.database.close_all_connections", new_callable=AsyncMock),
- patch("backend.app.core.database.reinitialize_database", new_callable=AsyncMock),
- patch("backend.app.core.database.init_db", new_callable=AsyncMock),
- ):
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 200
- restored_key = tmp_path / ".mfa_encryption_key"
- assert restored_key.exists()
- assert restored_key.read_text() == "test-restored-key"
- assert (restored_key.stat().st_mode & 0o777) == 0o600
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_handles_missing_key_files(self, async_client, monkeypatch, tmp_path):
- """T2: ZIP without key file → restore succeeds, no key written to DATA_DIR."""
- import io
- import zipfile
- from unittest.mock import AsyncMock, patch
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr("bambuddy.db", b"SQLite format 3")
- # Intentionally no .mfa_encryption_key entry.
- buf.seek(0)
- with (
- patch("backend.app.core.db_dialect.is_sqlite", return_value=False),
- patch(
- "backend.app.api.routes.settings._import_sqlite_to_postgres",
- new_callable=AsyncMock,
- ),
- patch("backend.app.core.database.close_all_connections", new_callable=AsyncMock),
- patch("backend.app.core.database.reinitialize_database", new_callable=AsyncMock),
- patch("backend.app.core.database.init_db", new_callable=AsyncMock),
- ):
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 200
- assert not (tmp_path / ".mfa_encryption_key").exists()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_aborts_db_swap_when_key_write_fails(self, async_client, monkeypatch, tmp_path):
- """B1: when MFA key write fails, restore must abort BEFORE the database
- swap so the live DB is not left with rows encrypted under a key that
- no longer exists on disk."""
- import io
- import os
- import zipfile
- from unittest.mock import AsyncMock, patch
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Build ZIP with a key file that we will fail to write to DATA_DIR.
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr("bambuddy.db", b"SQLite format 3 backup data")
- zf.writestr(".mfa_encryption_key", "backup-key-content")
- buf.seek(0)
- # Track whether the database swap functions were called.
- # If B1 is correct, key-write failure aborts BEFORE these run.
- import_pg_mock = AsyncMock()
- reinit_mock = AsyncMock()
- init_mock = AsyncMock()
- original_open = os.open
- def _key_write_fails(path, flags, mode=0o777, **kwargs):
- # `shutil.rmtree` calls os.open(... dir_fd=...) during temp-dir
- # cleanup — accept and forward any extra kwargs so the mock
- # doesn't break the cleanup path.
- if str(path).endswith(".mfa_encryption_key.restore-tmp"):
- raise OSError(28, "No space left on device", str(path))
- return original_open(path, flags, mode, **kwargs)
- with (
- patch("backend.app.core.db_dialect.is_sqlite", return_value=False),
- patch(
- "backend.app.api.routes.settings._import_sqlite_to_postgres",
- import_pg_mock,
- ),
- patch("backend.app.core.database.close_all_connections", new_callable=AsyncMock),
- patch("backend.app.core.database.reinitialize_database", reinit_mock),
- patch("backend.app.core.database.init_db", init_mock),
- ):
- monkeypatch.setattr(os, "open", _key_write_fails)
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 500
- assert "Database is unchanged" in resp.json().get("detail", "")
- # Database swap functions must NOT have been called — the abort
- # happens before that step.
- import_pg_mock.assert_not_awaited()
- reinit_mock.assert_not_awaited()
- init_mock.assert_not_awaited()
- # No partial key file should be left behind.
- assert not (tmp_path / ".mfa_encryption_key").exists()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_resets_encryption_singleton_after_key_replace(self, async_client, monkeypatch, tmp_path):
- """B1: after a successful key replace, the encryption singleton must be
- cleared so init_db's re-encryption migration picks up the restored key
- instead of the cached Fernet from the previous key.
- """
- import io
- import zipfile
- from unittest.mock import AsyncMock, patch
- from cryptography.fernet import Fernet
- import backend.app.core.encryption as enc_mod
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Pre-warm the singleton with an "old" key so we can detect the reset.
- old_key = Fernet.generate_key().decode()
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", old_key)
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- # Trigger lazy load → singleton holds the old Fernet.
- assert enc_mod.is_encryption_active() is True
- assert enc_mod._fernet_instance is not None
- old_fernet_obj = enc_mod._fernet_instance
- # Build ZIP that delivers a DIFFERENT key file.
- new_key = Fernet.generate_key().decode()
- assert new_key != old_key
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr("bambuddy.db", b"SQLite format 3 backup data")
- zf.writestr(".mfa_encryption_key", new_key)
- buf.seek(0)
- with (
- patch("backend.app.core.db_dialect.is_sqlite", return_value=False),
- patch(
- "backend.app.api.routes.settings._import_sqlite_to_postgres",
- new_callable=AsyncMock,
- ),
- patch("backend.app.core.database.close_all_connections", new_callable=AsyncMock),
- patch("backend.app.core.database.reinitialize_database", new_callable=AsyncMock),
- patch("backend.app.core.database.init_db", new_callable=AsyncMock),
- ):
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 200, resp.text
- # The singleton must have been invalidated. The exact post-state depends
- # on whether init_db (mocked) re-loaded the singleton, but the cached
- # _fernet_instance reference from before the restore must not be the
- # active one any more.
- assert enc_mod._fernet_instance is None or enc_mod._fernet_instance is not old_fernet_obj, (
- "B1: encryption singleton must be reset after key replace so init_db's migration picks up the restored key"
- )
- # The key file must be on disk with the new content.
- restored = (tmp_path / ".mfa_encryption_key").read_text()
- assert restored == new_key
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_rejects_path_traversal_in_zip(self, async_client, monkeypatch, tmp_path):
- """A4: ZIP with path-traversal entry → HTTP 400, no file written outside temp dir."""
- import io
- import zipfile
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Build ZIP with a relative path-traversal entry.
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr("../etc/passwd", "root:x:0:0")
- zf.writestr("bambuddy.db", b"SQLite format 3")
- buf.seek(0)
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 400
- assert "unsafe path" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_rejects_prefix_collision_zipslip(self, async_client, monkeypatch, tmp_path):
- """T1: ZIP entry with prefix-collision path must be rejected.
- A startswith() check would accept '/tmp/abc_evil/file' when the
- extraction root was '/tmp/abc' — is_relative_to correctly rejects it.
- The restore handler creates a tempfile.TemporaryDirectory inside the
- system temp dir; we craft an entry that resolves to a sibling path
- whose name starts with the temp dir's basename.
- """
- import io
- import zipfile
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Use a path with traversal — the resolved path will share the parent
- # temp directory's basename as a prefix but NOT be inside the
- # extraction root. We don't know the random extraction-root name at
- # ZIP-build time, so we pick a literal "../poc-evil-prefix-collision/"
- # which traverses up one level from the extraction root and lands in
- # a sibling directory. is_relative_to() must reject this; a naive
- # startswith() against the parent's parent would accept it.
- evil_name = "../escaped-prefix-collision/poc.txt"
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- zf.writestr(evil_name, "pwned")
- zf.writestr("bambuddy.db", b"SQLite format 3\x00")
- buf.seek(0)
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 400
- assert "unsafe path" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_restore_rejects_absolute_path_in_zip(self, async_client, monkeypatch, tmp_path):
- """B1: ZIP with an absolute path entry must be rejected by is_relative_to check."""
- import io
- import zipfile
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- buf = io.BytesIO()
- with zipfile.ZipFile(buf, "w") as zf:
- # Absolute path in the archive — extracts outside temp_path on
- # systems where (temp_path / "/etc/passwd") resolves to /etc/passwd.
- zf.writestr("/etc/passwd", "root:x:0:0")
- zf.writestr("bambuddy.db", b"SQLite format 3")
- buf.seek(0)
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", buf, "application/zip")},
- )
- assert resp.status_code == 400
- assert "unsafe path" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_backup_fails_when_key_file_unreadable(self, async_client, monkeypatch, tmp_path):
- """A5: OSError while copying key file propagates out of create_backup_zip."""
- import shutil
- from backend.app.api.routes.settings import create_backup_zip
- from backend.app.core.config import settings as app_settings
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- (tmp_path / ".mfa_encryption_key").write_text("key")
- original_copy2 = shutil.copy2
- def _raise_on_key(src, dst):
- if ".mfa_encryption_key" in str(src):
- raise OSError("simulated unreadable key file")
- return original_copy2(src, dst)
- monkeypatch.setattr(shutil, "copy2", _raise_on_key)
- import pytest as _pytest
- with _pytest.raises(OSError, match="simulated unreadable"):
- await create_backup_zip(output_path=tmp_path)
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_backup_restore_roundtrip_preserves_encrypted_oidc_secret(
- self, async_client, db_session, monkeypatch, tmp_path
- ):
- """T3: encrypt → backup → simulate key loss → restore → decrypt.
- Verifies the user-facing promise that local backup ZIPs are
- self-contained: an OIDC client_secret encrypted under one key still
- decrypts after restore even when the running install no longer has
- the key on disk or in the env. Exercises the B1 key-first restore
- path and the B4 sample-decrypt status check together.
- """
- import zipfile
- from pathlib import Path
- from unittest.mock import AsyncMock, patch
- from cryptography.fernet import Fernet
- from sqlalchemy import select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.settings import create_backup_zip
- from backend.app.core.config import settings as app_settings
- from backend.app.models.oidc_provider import OIDCProvider
- # 1. Pin a key, encrypt an OIDC secret via the property setter.
- key = Fernet.generate_key().decode()
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", key)
- monkeypatch.setenv("DATA_DIR", str(tmp_path))
- monkeypatch.setattr(app_settings, "base_dir", tmp_path)
- # Persist the key file too, so create_backup_zip picks it up.
- (tmp_path / ".mfa_encryption_key").write_text(key)
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- provider = OIDCProvider(
- name="RoundtripProv",
- issuer_url="https://rt.example.com",
- client_id="cid",
- client_secret="my-original-secret", # via setter -> encrypted
- scopes="openid email profile",
- is_enabled=True,
- )
- db_session.add(provider)
- await db_session.commit()
- original_id = provider.id
- assert provider._client_secret_enc.startswith("fernet:")
- # 2. Create a backup ZIP (must include .mfa_encryption_key).
- zip_path, _ = await create_backup_zip(output_path=tmp_path)
- try:
- with zipfile.ZipFile(zip_path) as zf:
- names = zf.namelist()
- assert ".mfa_encryption_key" in names, "T3: backup ZIP must include the key file"
- # 3. Simulate key loss: delete the key file from DATA_DIR, drop
- # the env var, reset the cached fernet singleton.
- (tmp_path / ".mfa_encryption_key").unlink()
- monkeypatch.delenv("MFA_ENCRYPTION_KEY", raising=False)
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- # 4. Restore the ZIP via the endpoint. Mock out the DB-swap
- # (we keep the live in-memory test DB) and init_db side effects
- # so this test focuses on the key-restore path.
- with (
- patch("backend.app.core.db_dialect.is_sqlite", return_value=False),
- patch(
- "backend.app.api.routes.settings._import_sqlite_to_postgres",
- new_callable=AsyncMock,
- ),
- patch("backend.app.core.database.close_all_connections", new_callable=AsyncMock),
- patch("backend.app.core.database.reinitialize_database", new_callable=AsyncMock),
- patch("backend.app.core.database.init_db", new_callable=AsyncMock),
- open(zip_path, "rb") as f,
- ):
- resp = await async_client.post(
- "/api/v1/settings/restore",
- files={"file": ("backup.zip", f, "application/zip")},
- )
- assert resp.status_code == 200, resp.text
- # 5. Reset the singleton again (B1 already does this in production,
- # but here init_db is mocked so we explicitly invalidate).
- enc_mod._fernet_instance = None
- enc_mod._key_source = None
- # 6. The key file must be back on disk with restrictive permissions.
- restored = Path(tmp_path) / ".mfa_encryption_key"
- assert restored.exists(), "T3: key file must be restored to DATA_DIR"
- assert (restored.stat().st_mode & 0o777) == 0o600
- # 7. Decryption works again — the property getter must return the
- # original plaintext, proving the restored key matches the
- # cipher in the (still in-memory) DB row.
- result = await db_session.execute(select(OIDCProvider).where(OIDCProvider.id == original_id))
- restored_provider = result.scalar_one()
- assert restored_provider.client_secret == "my-original-secret"
- finally:
- zip_path.unlink(missing_ok=True)
- # ============================================================================
- # TestTOTPDecryptionBroken (C9)
- # Verifies the decryption-broken state (encrypted TOTP row + no key) for each
- # TOTP endpoint. Behaviour differs between recovery-aware and non-recovery
- # endpoints:
- # - setup_totp / enable_totp / verify_2fa: HTTP 500 (no backup-code path).
- # - disable_totp / regenerate_backup_codes: fall through to the backup-code
- # branch — HTTP 200 with a valid backup code, HTTP 400 without.
- # ============================================================================
- class TestTOTPDecryptionBroken:
- """C9: RuntimeError from mfa_decrypt — 500 for non-recovery endpoints,
- backup-code fall-through for disable_totp / regenerate_backup_codes."""
- async def _setup_admin_and_totp_user(self, async_client, db_session):
- """Create admin (enables auth), log in as admin, add TOTP record with fernet secret."""
- from backend.app.models.user_totp import UserTOTP
- admin_username = f"admin_c9_{secrets.token_hex(4)}"
- setup = await async_client.post(
- "/api/v1/auth/setup",
- json={
- "auth_enabled": True,
- "admin_username": admin_username,
- "admin_password": "Admin_C9_Pass1!",
- },
- )
- assert setup.status_code in (200, 201), setup.text
- login = await async_client.post(
- "/api/v1/auth/login",
- json={"username": admin_username, "password": "Admin_C9_Pass1!"},
- )
- assert login.status_code == 200, login.text
- token = login.json()["access_token"]
- # Get the admin user_id from the /me endpoint
- me = await async_client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
- assert me.status_code == 200
- user_id = me.json()["id"]
- # Insert a TOTP row with a fernet-prefixed secret directly (no key needed for insert).
- totp = UserTOTP(
- user_id=user_id,
- _secret_enc="fernet:gAAAAA-not-really-encrypted",
- is_enabled=True,
- )
- db_session.add(totp)
- await db_session.commit()
- return token, admin_username, user_id
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_enable_totp_returns_500_when_decryption_broken(self, async_client, db_session, monkeypatch):
- """C9: enable endpoint → 500 when TOTP secret is encrypted but key unavailable."""
- import backend.app.core.encryption as enc_mod
- token, _, _ = await self._setup_admin_and_totp_user(async_client, db_session)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- # enable_totp requires setup-but-not-yet-enabled state; force is_enabled=False
- from sqlalchemy import select as _select
- from backend.app.models.user_totp import UserTOTP
- result = await db_session.execute(_select(UserTOTP))
- for t in result.scalars().all():
- t.is_enabled = False
- await db_session.commit()
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/enable",
- json={"code": "123456"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 500
- assert "unavailable" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_disable_totp_returns_400_when_decryption_broken_and_no_backup_codes(
- self, async_client, db_session, monkeypatch
- ):
- """B2a + S3: disable falls through to backup-code branch when TOTP secret
- cannot be decrypted; with no backup codes seeded, the request is
- rejected as an invalid code (400), not a server error.
- S3: AND the failed-attempt counter must NOT be incremented — the
- cause was a server-side key loss, not a user mistake.
- """
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.models.auth_ephemeral import AuthRateLimitEvent
- token, admin_username, _ = await self._setup_admin_and_totp_user(async_client, db_session)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/disable",
- json={"code": "123456"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 400
- assert "invalid" in resp.json().get("detail", "").lower()
- # S3: no fail-counter debit on server-side key loss.
- events = (
- (
- await db_session.execute(
- _select(AuthRateLimitEvent).where(AuthRateLimitEvent.username == admin_username.lower())
- )
- )
- .scalars()
- .all()
- )
- assert len(events) == 0, "S3: must not debit fail-counter on key-loss"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_regenerate_backup_codes_returns_400_when_decryption_broken_and_no_backup_codes(
- self, async_client, db_session, monkeypatch
- ):
- """B2b + S3: regenerate-backup-codes falls through to backup-code branch when
- TOTP secret cannot be decrypted; with no backup codes seeded, the
- request is rejected as an invalid code (400) AND the fail-counter
- is NOT incremented (S3: server-side cause, not user mistake).
- """
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.models.auth_ephemeral import AuthRateLimitEvent
- token, admin_username, _ = await self._setup_admin_and_totp_user(async_client, db_session)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/regenerate-backup-codes",
- json={"code": "123456"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 400
- assert "invalid" in resp.json().get("detail", "").lower()
- events = (
- (
- await db_session.execute(
- _select(AuthRateLimitEvent).where(AuthRateLimitEvent.username == admin_username.lower())
- )
- )
- .scalars()
- .all()
- )
- assert len(events) == 0, "S3: must not debit fail-counter on key-loss"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_disable_totp_succeeds_via_backup_code_when_decryption_broken(
- self, async_client, db_session, monkeypatch
- ):
- """B2a: a valid backup code disables TOTP even when the secret cannot
- be decrypted — recovery path for users who lost the encryption key."""
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.mfa import _generate_backup_codes
- from backend.app.models.user_totp import UserTOTP
- token, _, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- # Seed a real backup-code hash on the existing TOTP row.
- plain_codes, hashed_codes = _generate_backup_codes()
- result = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- totp = result.scalar_one()
- totp.backup_code_hashes = hashed_codes
- await db_session.commit()
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/disable",
- json={"code": plain_codes[0]},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 200, resp.text
- # The TOTP row must have been deleted.
- result_after = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- assert result_after.scalar_one_or_none() is None
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_regenerate_backup_codes_succeeds_via_backup_code_when_decryption_broken(
- self, async_client, db_session, monkeypatch
- ):
- """B2b: a valid backup code rotates the codes even when the secret
- cannot be decrypted — recovery path mirrors disable_totp."""
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.mfa import _generate_backup_codes
- from backend.app.models.user_totp import UserTOTP
- token, _, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- plain_codes, hashed_codes = _generate_backup_codes()
- result = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- totp = result.scalar_one()
- totp.backup_code_hashes = hashed_codes
- await db_session.commit()
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/regenerate-backup-codes",
- json={"code": plain_codes[0]},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 200, resp.text
- body = resp.json()
- assert "backup_codes" in body
- assert len(body["backup_codes"]) == 10
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_disable_totp_wrong_code_with_seeded_hashes_returns_400_and_debits_counter(
- self, async_client, db_session, monkeypatch
- ):
- """T2: with backup_code_hashes seeded AND a working encryption key,
- a wrong code is rejected (400) AND the fail-counter IS incremented.
- This pins the behaviour that a future refactor swallowing
- compare_digest mismatches would still let the existing 'no codes
- configured' tests pass — only this assertion exercises the actual
- pwd_context.verify mismatch path.
- """
- from cryptography.fernet import Fernet
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.mfa import _generate_backup_codes
- from backend.app.models.auth_ephemeral import AuthRateLimitEvent
- from backend.app.models.user_totp import UserTOTP
- # Active key — secret can be decrypted, this is NOT key-loss.
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- token, admin_username, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- # Replace stub fernet:-prefixed value with a real encrypted secret so
- # disable_totp's TOTP-decrypt path doesn't throw, AND seed real hashes.
- result = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- totp = result.scalar_one()
- totp.secret = "JBSWY3DPEHPK3PXP" # via setter -> mfa_encrypt
- plain_codes, hashed_codes = _generate_backup_codes()
- totp.backup_code_hashes = hashed_codes
- await db_session.commit()
- # Submit a code that matches NEITHER the TOTP nor any backup-code hash.
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/disable",
- json={"code": "WRONGCD1"}, # wrong but well-formed
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 400
- assert "invalid" in resp.json().get("detail", "").lower()
- # T2 + S3: with key intact, the fail-counter MUST increment for a
- # real wrong-code attempt (this is the user-error path, not key-loss).
- events = (
- (
- await db_session.execute(
- _select(AuthRateLimitEvent).where(AuthRateLimitEvent.username == admin_username.lower())
- )
- )
- .scalars()
- .all()
- )
- assert len(events) >= 1, "T2: with key intact, wrong code must debit the fail-counter"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_regenerate_backup_codes_wrong_code_with_seeded_hashes_returns_400_and_debits_counter(
- self, async_client, db_session, monkeypatch
- ):
- """T2: same as the disable_totp variant for /regenerate-backup-codes."""
- from cryptography.fernet import Fernet
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.mfa import _generate_backup_codes
- from backend.app.models.auth_ephemeral import AuthRateLimitEvent
- from backend.app.models.user_totp import UserTOTP
- monkeypatch.setenv("MFA_ENCRYPTION_KEY", Fernet.generate_key().decode())
- enc_mod._fernet_instance = None
- token, admin_username, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- result = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- totp = result.scalar_one()
- totp.secret = "JBSWY3DPEHPK3PXP"
- plain_codes, hashed_codes = _generate_backup_codes()
- totp.backup_code_hashes = hashed_codes
- await db_session.commit()
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/regenerate-backup-codes",
- json={"code": "WRONGCD2"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 400
- assert "invalid" in resp.json().get("detail", "").lower()
- events = (
- (
- await db_session.execute(
- _select(AuthRateLimitEvent).where(AuthRateLimitEvent.username == admin_username.lower())
- )
- )
- .scalars()
- .all()
- )
- assert len(events) >= 1, "T2: with key intact, wrong code must debit the fail-counter"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_disable_totp_wrong_code_with_seeded_hashes_at_keyloss_no_counter_debit(
- self, async_client, db_session, monkeypatch
- ):
- """T2 + S3 cross-check: with hashes seeded but encryption key gone,
- a wrong code returns 400 BUT the fail-counter MUST NOT increment.
- This is the dual of the test above — same wrong-code 400 outcome,
- but the counter debit is gated on the cause of failure (server-side
- key loss must NOT penalise the user).
- """
- from sqlalchemy import select as _select
- import backend.app.core.encryption as enc_mod
- from backend.app.api.routes.mfa import _generate_backup_codes
- from backend.app.models.auth_ephemeral import AuthRateLimitEvent
- from backend.app.models.user_totp import UserTOTP
- token, admin_username, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- # Seed real hashes on the existing TOTP row.
- result = await db_session.execute(_select(UserTOTP).where(UserTOTP.user_id == user_id))
- totp = result.scalar_one()
- plain_codes, hashed_codes = _generate_backup_codes()
- totp.backup_code_hashes = hashed_codes
- await db_session.commit()
- # Now simulate key loss.
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/disable",
- json={"code": "WRONGCD3"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 400
- # S3: counter MUST be unchanged — this is a server-side problem.
- events = (
- (
- await db_session.execute(
- _select(AuthRateLimitEvent).where(AuthRateLimitEvent.username == admin_username.lower())
- )
- )
- .scalars()
- .all()
- )
- assert len(events) == 0, "S3: must not debit fail-counter when cause is server-side key-loss"
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_setup_totp_returns_500_when_decryption_broken(self, async_client, db_session, monkeypatch):
- """B3: setup endpoint → 500 when an active TOTP secret can't be decrypted.
- Replacing an active authenticator requires verifying the current TOTP
- code; with no recovery (backup-code) path on this endpoint, the only
- safe outcome is a 500 surface to the operator.
- """
- import backend.app.core.encryption as enc_mod
- token, _, _ = await self._setup_admin_and_totp_user(async_client, db_session)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- resp = await async_client.post(
- "/api/v1/auth/2fa/totp/setup",
- json={"code": "123456"},
- headers={"Authorization": f"Bearer {token}"},
- )
- assert resp.status_code == 500
- assert "unavailable" in resp.json().get("detail", "").lower()
- @pytest.mark.asyncio
- @pytest.mark.integration
- async def test_verify_2fa_returns_500_when_decryption_broken(self, async_client, db_session, monkeypatch):
- """C9: verify endpoint (TOTP method) → 500 when TOTP secret unreadable."""
- from datetime import datetime, timedelta, timezone
- import backend.app.core.encryption as enc_mod
- from backend.app.models.auth_ephemeral import AuthEphemeralToken
- token, admin_username, user_id = await self._setup_admin_and_totp_user(async_client, db_session)
- monkeypatch.setattr(enc_mod, "_load_or_generate_key", lambda: (None, "none"))
- enc_mod._fernet_instance = None
- # Create a pre_auth token to simulate the post-login 2FA challenge step.
- raw_token = secrets.token_urlsafe(32)
- ephemeral = AuthEphemeralToken(
- token=raw_token,
- token_type="pre_auth",
- username=admin_username,
- expires_at=datetime.now(timezone.utc) + timedelta(minutes=5),
- )
- db_session.add(ephemeral)
- await db_session.commit()
- resp = await async_client.post(
- "/api/v1/auth/2fa/verify",
- json={"pre_auth_token": raw_token, "method": "totp", "code": "123456"},
- )
- assert resp.status_code == 500
- assert "unavailable" in resp.json().get("detail", "").lower()
|