| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848 |
- """2FA (TOTP + Email OTP) and OIDC authentication routes.
- Security model
- --------------
- * Pre-auth tokens : secrets.token_urlsafe(32) stored in-memory with a 5-minute TTL.
- They are single-use and do NOT grant access to any protected resource.
- * TOTP codes : verified with pyotp (30-second window, ±1 step tolerance).
- * Email OTP codes : 6-digit numeric, hashed with pbkdf2_sha256, 10-minute TTL,
- max 5 failed attempts per code before invalidation.
- * Backup codes : 10 × 8-char alphanumeric codes, each stored as pbkdf2_sha256 hash,
- single-use.
- * OIDC state : secrets.token_urlsafe(32) bound to provider_id + nonce, 10-minute TTL.
- * OIDC exchange : secrets.token_urlsafe(32), 2-minute TTL, single-use.
- * Rate limiting : max 5 failed 2FA verification attempts per user within 15 minutes.
- """
- from __future__ import annotations
- import base64
- import hashlib
- import io
- import logging
- import os
- import re
- import secrets
- import string
- import urllib.parse
- from datetime import datetime, timedelta, timezone
- import httpx
- import jwt
- import pyotp
- from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Response, status
- from fastapi.responses import RedirectResponse
- from jwt import PyJWKClient
- from passlib.context import CryptContext
- from sqlalchemy import delete, select
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import selectinload
- from backend.app.api.routes.settings import get_setting, set_setting
- from backend.app.core.auth import (
- ACCESS_TOKEN_EXPIRE_MINUTES,
- RequirePermissionIfAuthEnabled,
- create_access_token,
- get_current_active_user,
- get_user_by_email,
- get_user_by_username,
- is_auth_enabled,
- verify_password,
- )
- from backend.app.core.database import get_db
- from backend.app.core.permissions import Permission
- from backend.app.models.auth_ephemeral import AuthEphemeralToken, AuthRateLimitEvent, EventType, TokenType
- from backend.app.models.group import Group
- from backend.app.models.oidc_provider import OIDCProvider, UserOIDCLink
- from backend.app.models.user import User
- from backend.app.models.user_otp_code import UserOTPCode
- from backend.app.models.user_totp import UserTOTP
- from backend.app.schemas.auth import (
- AUTO_LINK_REQUIREMENTS_ERROR,
- AdminDisable2FARequest,
- BackupCodesResponse,
- EmailOTPDisableRequest,
- EmailOTPEnableConfirmRequest,
- EmailOTPSendRequest,
- GroupBrief,
- LoginResponse,
- OIDCAuthorizeResponse,
- OIDCExchangeRequest,
- OIDCLinkResponse,
- OIDCProviderCreate,
- OIDCProviderResponse,
- OIDCProviderUpdate,
- TOTPDisableRequest,
- TOTPEnableRequest,
- TOTPEnableResponse,
- TOTPSetupRequest,
- TOTPSetupResponse,
- TwoFAStatusResponse,
- TwoFAVerifyRequest,
- TwoFAVerifyResponse,
- UserResponse,
- )
- from backend.app.services.email_service import get_smtp_settings, send_email
- logger = logging.getLogger(__name__)
- def _as_utc(dt: datetime) -> datetime:
- """Return *dt* with UTC timezone attached.
- SQLite/aiosqlite strips timezone info when reading DateTime(timezone=True)
- columns back – the stored value is always UTC, so we just re-attach the
- info when doing Python-level comparisons.
- """
- return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
- # ---------------------------------------------------------------------------
- # Passlib context (same scheme as auth.py)
- # ---------------------------------------------------------------------------
- pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
- # ---------------------------------------------------------------------------
- # TTL / rate-limit constants
- # ---------------------------------------------------------------------------
- MAX_2FA_ATTEMPTS = 5
- MAX_LOGIN_ATTEMPTS = 10
- LOCKOUT_WINDOW = timedelta(minutes=15)
- MAX_EMAIL_OTP_SENDS = 3
- EMAIL_OTP_SEND_WINDOW = timedelta(minutes=10)
- PRE_AUTH_TOKEN_TTL = timedelta(minutes=5)
- OIDC_STATE_TTL = timedelta(minutes=10)
- OIDC_EXCHANGE_TTL = timedelta(minutes=2)
- # ---------------------------------------------------------------------------
- # Router
- # ---------------------------------------------------------------------------
- router = APIRouter(prefix="/auth", tags=["2fa", "oidc"])
- # ---------------------------------------------------------------------------
- # Helper: user response
- # ---------------------------------------------------------------------------
- def _user_to_response(user: User) -> UserResponse:
- return UserResponse(
- id=user.id,
- username=user.username,
- email=user.email,
- role=user.role,
- is_active=user.is_active,
- is_admin=user.is_admin,
- groups=[GroupBrief(id=g.id, name=g.name) for g in user.groups],
- permissions=sorted(user.get_permissions()),
- created_at=user.created_at.isoformat(),
- )
- # ---------------------------------------------------------------------------
- # Helper: QR code generation
- # ---------------------------------------------------------------------------
- def _generate_totp_qr_b64(provisioning_uri: str) -> str:
- """Generate a base64-encoded PNG QR code for the given TOTP provisioning URI."""
- import qrcode # type: ignore
- qr = qrcode.QRCode(box_size=6, border=2)
- qr.add_data(provisioning_uri)
- qr.make(fit=True)
- img = qr.make_image(fill_color="black", back_color="white")
- buf = io.BytesIO()
- img.save(buf, format="PNG")
- return base64.b64encode(buf.getvalue()).decode()
- # ---------------------------------------------------------------------------
- # Helper: backup code generation
- # ---------------------------------------------------------------------------
- def _generate_backup_codes() -> tuple[list[str], list[str]]:
- """Return (plain_codes, hashed_codes) — 10 codes of 8 alphanumeric chars each."""
- alphabet = string.ascii_uppercase + string.digits
- plain = ["".join(secrets.choice(alphabet) for _ in range(8)) for _ in range(10)]
- hashed = [pwd_context.hash(c) for c in plain]
- return plain, hashed
- # ---------------------------------------------------------------------------
- # DB-backed pre-auth token helpers
- # ---------------------------------------------------------------------------
- async def create_pre_auth_token(db: AsyncSession, username: str, challenge_id: str | None = None) -> str:
- """Create a single-use pre-auth token stored in the DB.
- Pass ``challenge_id`` (from the HttpOnly 2fa_challenge cookie) to bind the
- token to the originating browser session. The same value must be present as
- a cookie on every subsequent call that consumes this token.
- """
- now = datetime.now(timezone.utc)
- # Prune expired tokens opportunistically (keep table small)
- await db.execute(
- delete(AuthEphemeralToken).where(
- AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
- AuthEphemeralToken.expires_at < now,
- )
- )
- token = secrets.token_urlsafe(32)
- db.add(
- AuthEphemeralToken(
- token=token,
- token_type=TokenType.PRE_AUTH,
- username=username,
- challenge_id=challenge_id,
- expires_at=now + PRE_AUTH_TOKEN_TTL,
- )
- )
- await db.commit()
- return token
- async def consume_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
- """Atomically validate and consume a pre-auth token. Returns username or None.
- Uses DELETE...RETURNING so two concurrent requests with the same token cannot
- both succeed — only the first DELETE finds the row.
- M5: When challenge_id is provided, also enforces the cookie-binding constraint
- so a stolen token cannot be replayed from a different browser session.
- """
- now = datetime.now(timezone.utc)
- result = await db.execute(
- delete(AuthEphemeralToken)
- .where(
- AuthEphemeralToken.token == token,
- AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
- AuthEphemeralToken.expires_at > now,
- )
- .returning(AuthEphemeralToken.username, AuthEphemeralToken.challenge_id)
- )
- row = result.one_or_none()
- if row is None:
- return None
- username, stored_challenge_id = row
- # Enforce client binding: if the token was issued with a challenge_id,
- # the caller must supply the matching value.
- if stored_challenge_id is not None and stored_challenge_id != challenge_id:
- await db.rollback()
- return None
- await db.commit()
- return username
- async def peek_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
- """Validate a pre-auth token and return the username WITHOUT consuming it.
- When the stored token has a ``challenge_id`` (client-binding cookie), the
- caller must supply the matching value. A mismatch is treated as an invalid
- token — no information leakage about whether the token itself exists.
- """
- now = datetime.now(timezone.utc)
- result = await db.execute(
- select(AuthEphemeralToken).where(
- AuthEphemeralToken.token == token,
- AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
- AuthEphemeralToken.expires_at > now,
- )
- )
- eph = result.scalar_one_or_none()
- if eph is None:
- return None
- # Enforce client binding: if the token was issued with a challenge_id the
- # cookie must match. Treat a mismatch as if the token doesn't exist.
- if eph.challenge_id is not None and eph.challenge_id != challenge_id:
- return None
- return eph.username
- # ---------------------------------------------------------------------------
- # DB-backed rate-limiting helpers
- # ---------------------------------------------------------------------------
- async def check_rate_limit(
- db: AsyncSession,
- username: str,
- event_type: str = EventType.TWO_FA_ATTEMPT,
- max_attempts: int = MAX_2FA_ATTEMPTS,
- ) -> None:
- """Raise HTTP 429 if the user has exceeded the failed attempt limit.
- The username is normalised to lower-case so case-variant attempts
- (which all resolve to the same user) share the same rate-limit bucket.
- L-2: Known TOCTOU — the SELECT (count) and the subsequent INSERT
- (record_failed_attempt) are not atomic. Two concurrent requests can both
- read a count below the threshold and both proceed. This is an inherent
- trade-off of the event-log rate-limit pattern: fixing it would require
- a serialising lock (SELECT FOR UPDATE on a dedicated counter row), which
- adds contention and is not worth it for a soft rate-limit whose window is
- already measured in minutes. In practice the race window is microseconds
- and the limit can be slightly exceeded only under precise concurrent timing.
- """
- username_key = username.lower()
- now = datetime.now(timezone.utc)
- cutoff = now - LOCKOUT_WINDOW
- result = await db.execute(
- select(AuthRateLimitEvent).where(
- AuthRateLimitEvent.username == username_key,
- AuthRateLimitEvent.event_type == event_type,
- AuthRateLimitEvent.occurred_at > cutoff,
- )
- )
- recent_count = len(result.scalars().all())
- if recent_count >= max_attempts:
- raise HTTPException(
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
- detail="Too many failed attempts. Please try again later.",
- )
- async def record_failed_attempt(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
- """Record a failed attempt for rate-limiting purposes."""
- db.add(AuthRateLimitEvent(username=username.lower(), event_type=event_type))
- await db.commit()
- async def clear_failed_attempts(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
- """Delete all recorded failed attempts for a user on successful verification."""
- await db.execute(
- delete(AuthRateLimitEvent).where(
- AuthRateLimitEvent.username == username.lower(),
- AuthRateLimitEvent.event_type == event_type,
- )
- )
- await db.commit()
- async def check_email_otp_send_rate(db: AsyncSession, username: str) -> None:
- """Raise HTTP 429 if the user has requested too many OTP emails recently.
- I1: This function only *checks* the limit. The caller is responsible for
- recording the slot via ``record_email_otp_send`` **after** the email has
- been sent successfully. This prevents failed sends from consuming a slot
- (wasting the user's quota) and makes it impossible to farm rate-limit events
- without actually triggering a send.
- """
- username_key = username.lower()
- now = datetime.now(timezone.utc)
- cutoff = now - EMAIL_OTP_SEND_WINDOW
- result = await db.execute(
- select(AuthRateLimitEvent).where(
- AuthRateLimitEvent.username == username_key,
- AuthRateLimitEvent.event_type == EventType.EMAIL_SEND,
- AuthRateLimitEvent.occurred_at > cutoff,
- )
- )
- recent_count = len(result.scalars().all())
- if recent_count >= MAX_EMAIL_OTP_SENDS:
- raise HTTPException(
- status_code=status.HTTP_429_TOO_MANY_REQUESTS,
- detail=f"Too many OTP email requests. Please wait {EMAIL_OTP_SEND_WINDOW.seconds // 60} minutes.",
- )
- async def record_email_otp_send(db: AsyncSession, username: str) -> None:
- """Record a successful OTP email send for rate-limiting purposes (I1).
- Must be called *after* the email has been sent successfully so that failed
- sends do not consume a slot from the user's quota.
- """
- db.add(AuthRateLimitEvent(username=username.lower(), event_type=EventType.EMAIL_SEND))
- await db.commit()
- # ---------------------------------------------------------------------------
- # TOTP replay-protection helper
- # ---------------------------------------------------------------------------
- def _assert_totp_not_replayed(totp_obj: pyotp.TOTP, totp_record: UserTOTP, code: str) -> None:
- """Raise HTTP 400 if this TOTP code was already accepted in its time window.
- M3 fix: store the counter of the *accepted* code rather than the current
- wall-clock counter. With valid_window=1, pyotp accepts codes from the
- previous 30-second step. Using timecode(now) would store the wrong counter
- when the previous-window code is accepted, allowing immediate replay.
- """
- # Determine which time-step the accepted code belongs to.
- now = datetime.now(timezone.utc)
- accepted_counter: int | None = None
- for offset in (0, -1): # current window first, then previous
- candidate_time = now.timestamp() + offset * totp_obj.interval
- candidate_counter = totp_obj.timecode(datetime.fromtimestamp(candidate_time, tz=timezone.utc))
- if totp_obj.at(candidate_counter) == code:
- accepted_counter = candidate_counter
- break
- if accepted_counter is None:
- accepted_counter = totp_obj.timecode(now) # fallback (should not happen after verify())
- totp_record.accept_counter(accepted_counter)
- # ---------------------------------------------------------------------------
- # OIDC helpers
- # ---------------------------------------------------------------------------
- _EMAIL_SHAPE_RE = re.compile(r"[^\s@]+@[^\s@]+\.[^\s@]+")
- def _is_valid_email_shaped(value: str | None) -> bool:
- # SEC-2: shape check for non-standard claims (upn, preferred_username).
- # Requires local@domain.tld — rejects "@", "x@", "@domain", "x@nodot".
- if not value or len(value) > 255:
- return False
- return _EMAIL_SHAPE_RE.fullmatch(value) is not None
- def _enforce_auto_link_safety(provider: OIDCProvider) -> None:
- """Raise HTTP 422 if auto_link_existing_accounts is on with an unsafe combined state.
- SEC-1: only Fall B (email_claim='email' + require_email_verified=False) is unsafe —
- an attacker-controlled IdP could present an unverified email that matches a local account.
- Fall C (custom claim) never performs an email_verified check, so auto_link is safe there.
- Called after ORM construction (create) and after the setattr loop (update).
- """
- if provider.auto_link_existing_accounts and provider.email_claim == "email" and not provider.require_email_verified:
- raise HTTPException(
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- detail=AUTO_LINK_REQUIREMENTS_ERROR,
- )
- def _resolve_provider_email(provider: OIDCProvider, claims: dict, provider_sub: str) -> str | None:
- """Extract and normalise the email address from OIDC ID-token claims.
- Implements three resolution paths (Fall A/B/C):
- Fall C — custom email_claim (!= "email"): shape-check only, no email_verified gate.
- Recommended for Azure Entra ID (preferred_username or upn).
- Fall A — email_claim="email" + require_email_verified=True: strict, email_verified must be True.
- Fall B — email_claim="email" + require_email_verified=False: permissive, explicit False drops email.
- Returns a lowercase-stripped email string, or None when the claim is absent/invalid.
- """
- provider_id = provider.id
- raw_claim_value = claims.get(provider.email_claim)
- if raw_claim_value is not None and not isinstance(raw_claim_value, str):
- # TYPE-GUARD: non-string claim (e.g. list, int) would raise AttributeError on .lower().
- logger.warning(
- "OIDC provider %d: email_claim %r has unexpected type %s for sub=%r, ignoring",
- provider_id,
- provider.email_claim,
- type(raw_claim_value).__name__,
- provider_sub,
- )
- raw_claim_value = None
- raw_email: str | None = raw_claim_value.lower().strip() if raw_claim_value else None
- if provider.email_claim != "email":
- # Fall C: custom claim (preferred_username, upn, …) — no email_verified check.
- # SEC-2: _is_valid_email_shaped instead of bare '"@" in value'.
- # Recommended for Azure Entra ID: set email_claim="preferred_username" or "upn".
- if raw_email and _is_valid_email_shaped(raw_email):
- return raw_email
- if raw_email:
- logger.warning(
- "OIDC provider %d: email_claim %r value failed shape check for sub=%r, ignoring",
- provider_id,
- provider.email_claim,
- provider_sub,
- )
- return None
- email_verified = claims.get("email_verified")
- if provider.require_email_verified:
- # Fall A: standard C1-Guard — fail closed unless email_verified is True.
- # SEC-2: apply shape check to standard email claim — providers may set
- # email_verified=True on non-email values (e.g. numeric user IDs).
- # SEC-3 normalisation applies; existing mixed-case provider_email records
- # were normalised to lowercase by run_migrations at startup.
- if raw_email and not _is_valid_email_shaped(raw_email):
- logger.warning(
- "OIDC provider %d: email claim failed shape check for sub=%r, ignoring",
- provider_id,
- provider_sub,
- )
- return None
- if email_verified is True:
- return raw_email
- if raw_email:
- logger.info(
- "OIDC provider %d: ignoring email for sub=%r because email_verified=%r",
- provider_id,
- provider_sub,
- email_verified,
- )
- return None
- # Fall B: permissive — explicit False drops email, absent/None keeps it.
- # Required for Azure Entra ID which never sends email_verified.
- # SEC-2: apply shape check before the email_verified=False drop so malformed
- # values are rejected regardless of the email_verified claim.
- if raw_email and not _is_valid_email_shaped(raw_email):
- logger.warning(
- "OIDC provider %d: email claim failed shape check for sub=%r, ignoring",
- provider_id,
- provider_sub,
- )
- return None
- if email_verified is False:
- return None
- if email_verified is not True:
- # SEC-5: log only when the permissive path actually fires (ev absent/None),
- # not on every successful login.
- logger.info(
- "OIDC provider %r (%d): accepting email for sub=%r without email_verified claim (permissive mode)",
- provider.name,
- provider.id,
- provider_sub,
- )
- return raw_email
- # ---------------------------------------------------------------------------
- # Settings helpers (email 2FA flag)
- # ---------------------------------------------------------------------------
- async def _get_email_2fa_enabled(db: AsyncSession, user_id: int) -> bool:
- val = await get_setting(db, f"user_{user_id}_email_2fa_enabled")
- return val == "true"
- async def _set_email_2fa_enabled(db: AsyncSession, user_id: int, enabled: bool) -> None:
- await set_setting(db, f"user_{user_id}_email_2fa_enabled", "true" if enabled else "false")
- # ===========================================================================
- # 2FA Endpoints
- # ===========================================================================
- @router.get("/2fa/status", response_model=TwoFAStatusResponse)
- async def get_2fa_status(
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> TwoFAStatusResponse:
- """Return the current 2FA configuration for the authenticated user."""
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
- totp_record = result.scalar_one_or_none()
- totp_enabled = totp_record is not None and totp_record.is_enabled
- backup_codes_remaining = len(totp_record.backup_code_hashes) if totp_record else 0
- email_otp_enabled = await _get_email_2fa_enabled(db, current_user.id)
- return TwoFAStatusResponse(
- totp_enabled=totp_enabled,
- email_otp_enabled=email_otp_enabled,
- backup_codes_remaining=backup_codes_remaining,
- )
- @router.post("/2fa/totp/setup", response_model=TOTPSetupResponse)
- async def setup_totp(
- body: TOTPSetupRequest | None = Body(default=None),
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> TOTPSetupResponse:
- """Initiate TOTP setup: generates a new secret and QR code.
- Creates (or replaces) a pending UserTOTP record with is_enabled=False.
- The caller must confirm with POST /auth/2fa/totp/enable.
- M-R7-A: If an *active* TOTP is already configured, the caller must supply
- the current TOTP code in the request body to confirm intent before the
- secret is overwritten (prevents silently locking out the real user).
- """
- if not await is_auth_enabled(db):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Authentication is not enabled")
- # Upsert a pending TOTP record (is_enabled=False)
- existing = (await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))).scalar_one_or_none()
- # M-R7-A: Guard against silent TOTP replacement when one is already active.
- if existing and existing.is_enabled:
- await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- supplied_code = (body.code if body else None) or ""
- if not pyotp.TOTP(existing.secret).verify(supplied_code, valid_window=1):
- await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Current TOTP code required to replace an active authenticator",
- )
- await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- _assert_totp_not_replayed(pyotp.TOTP(existing.secret), existing, supplied_code)
- await db.flush() # L-3: persist last_totp_counter immediately to block replay
- secret = pyotp.random_base32()
- totp = pyotp.TOTP(secret)
- provisioning_uri = totp.provisioning_uri(name=current_user.username, issuer_name="Bambuddy")
- qr_b64 = _generate_totp_qr_b64(provisioning_uri)
- if existing:
- existing.secret = secret
- existing.is_enabled = False
- existing.backup_code_hashes = []
- else:
- db.add(UserTOTP(user_id=current_user.id, secret=secret, is_enabled=False))
- await db.commit()
- return TOTPSetupResponse(secret=secret, qr_code_b64=qr_b64, issuer="Bambuddy")
- @router.post("/2fa/totp/enable", response_model=TOTPEnableResponse)
- async def enable_totp(
- body: TOTPEnableRequest,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> TOTPEnableResponse:
- """Confirm TOTP setup by verifying a code from the authenticator app.
- On success, enables TOTP and returns 10 single-use backup codes (shown once).
- L-R7-A: Rate-limited to prevent brute-forcing the 6-digit confirmation code.
- """
- # L-R7-A: Rate-limit the enable step to prevent brute-forcing the 6-digit code.
- await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
- totp_record = result.scalar_one_or_none()
- if not totp_record:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP setup not initiated. Call /auth/2fa/totp/setup first."
- )
- if not pyotp.TOTP(totp_record.secret).verify(body.code, valid_window=1):
- await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP code")
- await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- plain_codes, hashed_codes = _generate_backup_codes()
- totp_record.is_enabled = True
- totp_record.backup_code_hashes = hashed_codes
- await db.commit()
- return TOTPEnableResponse(
- message="TOTP enabled successfully. Store your backup codes in a safe place.",
- backup_codes=plain_codes,
- )
- @router.post("/2fa/totp/disable")
- async def disable_totp(
- body: TOTPDisableRequest,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Disable TOTP by verifying a valid TOTP code or a backup code.
- I10: Rate-limited to prevent backup-code brute-forcing from a hijacked session.
- """
- await check_rate_limit(db, current_user.username)
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
- totp_record = result.scalar_one_or_none()
- if not totp_record or not totp_record.is_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")
- # Accept either a valid TOTP code or a valid backup code
- totp_obj = pyotp.TOTP(totp_record.secret)
- code_valid = totp_obj.verify(body.code, valid_window=1)
- if code_valid:
- _assert_totp_not_replayed(totp_obj, totp_record, body.code)
- await db.flush() # L-3: persist last_totp_counter immediately to block replay
- else:
- # Check backup codes — always iterate all entries (L-R9-A: no early break
- # to avoid timing oracle based on code position in the list).
- for hashed in totp_record.backup_code_hashes:
- if pwd_context.verify(body.code, hashed):
- code_valid = True
- if not code_valid:
- await record_failed_attempt(db, current_user.username)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code")
- await db.execute(delete(UserTOTP).where(UserTOTP.user_id == current_user.id))
- await db.commit()
- return {"message": "TOTP disabled"}
- @router.post("/2fa/totp/regenerate-backup-codes", response_model=BackupCodesResponse)
- async def regenerate_backup_codes(
- body: TOTPDisableRequest,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> BackupCodesResponse:
- """Generate 10 new backup codes. Requires a valid TOTP code OR a backup code.
- M10: Accepts backup codes for consistency with disable_totp — users who have
- lost their authenticator app but still have backup codes can regenerate.
- Rate-limited to prevent brute-forcing from a hijacked session.
- """
- await check_rate_limit(db, current_user.username)
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
- totp_record = result.scalar_one_or_none()
- if not totp_record or not totp_record.is_enabled:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")
- totp_obj = pyotp.TOTP(totp_record.secret)
- code_valid = totp_obj.verify(body.code, valid_window=1)
- if code_valid:
- _assert_totp_not_replayed(totp_obj, totp_record, body.code)
- await db.flush() # L-3: persist last_totp_counter immediately to block replay
- else:
- # Accept a backup code as an alternative (M10)
- matched_index: int | None = None
- for idx, hashed in enumerate(totp_record.backup_code_hashes):
- if pwd_context.verify(body.code, hashed) and matched_index is None:
- matched_index = idx
- if matched_index is None:
- await record_failed_attempt(db, current_user.username)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP or backup code")
- # Remove the used backup code
- totp_record.backup_code_hashes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index]
- plain_codes, hashed_codes = _generate_backup_codes()
- totp_record.backup_code_hashes = hashed_codes
- await db.commit()
- return BackupCodesResponse(
- backup_codes=plain_codes,
- message="Backup codes regenerated. Store them safely — they will not be shown again.",
- )
- @router.post("/2fa/email/enable")
- async def enable_email_otp(
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Step 1 of email OTP enable: send a verification code to the user's email.
- C5: Proof of possession — the user must prove they control the registered email
- address before email 2FA is activated. Returns a ``setup_token`` that must be
- passed to POST /auth/2fa/email/enable/confirm together with the received code.
- H-3: Rate-limited to prevent email flooding via repeated calls to this endpoint.
- """
- await check_email_otp_send_rate(db, current_user.username)
- if not current_user.email:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="You must have an email address configured to enable email OTP 2FA",
- )
- smtp_settings = await get_smtp_settings(db)
- if not smtp_settings:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")
- # Generate and store the setup token (reuse AuthEphemeralToken with type "email_otp_setup")
- now = datetime.now(timezone.utc)
- # Prune any existing pending setup tokens for this user
- await db.execute(
- delete(AuthEphemeralToken).where(
- AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
- AuthEphemeralToken.username == current_user.username,
- )
- )
- code = str(secrets.randbelow(1_000_000)).zfill(6)
- code_hash = pwd_context.hash(code)
- setup_token = secrets.token_urlsafe(32)
- db.add(
- AuthEphemeralToken(
- token=setup_token,
- token_type=TokenType.EMAIL_OTP_SETUP,
- username=current_user.username,
- # Reuse the nonce field to store the code hash
- nonce=code_hash,
- expires_at=now + timedelta(minutes=10),
- )
- )
- await db.commit()
- try:
- send_email(
- smtp_settings=smtp_settings,
- to_email=current_user.email,
- subject="Verify your Bambuddy email address for 2FA",
- body_text=(
- f"Your Bambuddy email 2FA setup code is: {code}\n\n"
- "Enter this code to confirm email-based two-factor authentication.\n"
- "The code expires in 10 minutes."
- ),
- body_html=(
- "<p>To enable <strong>email-based two-factor authentication</strong> on your Bambuddy account, "
- "enter the code below:</p>"
- f"<h2 style='letter-spacing:4px'>{code}</h2>"
- "<p>The code expires in <strong>10 minutes</strong>. "
- "If you did not request this, you can safely ignore this email.</p>"
- ),
- )
- await record_email_otp_send(db, current_user.username)
- except Exception as exc:
- logger.error("Failed to send email OTP setup code to user_id=%d: %s", current_user.id, exc)
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send verification email"
- )
- return {"message": "Verification code sent to your email address", "setup_token": setup_token}
- @router.post("/2fa/email/enable/confirm")
- async def confirm_enable_email_otp(
- body: EmailOTPEnableConfirmRequest,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Step 2 of email OTP enable: verify the code and activate email 2FA.
- H-2 fix: Uses peek-then-consume so a wrong code does NOT burn the setup token.
- The token is only deleted after successful code verification, allowing retries
- up to the rate limit (5 attempts / 15 min).
- M4: Rate-limited to prevent brute-forcing the 6-digit setup code.
- """
- await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- now = datetime.now(timezone.utc)
- # --- Peek: validate token without consuming ---
- peek_result = await db.execute(
- select(AuthEphemeralToken).where(
- AuthEphemeralToken.token == body.setup_token,
- AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
- AuthEphemeralToken.username == current_user.username,
- AuthEphemeralToken.expires_at > now,
- )
- )
- eph = peek_result.scalar_one_or_none()
- if eph is None:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")
- code_hash = eph.nonce # code hash stored in the nonce field
- # --- Verify code before consuming the token ---
- if not pwd_context.verify(body.code, code_hash):
- await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")
- # --- Atomically consume the token now that the code is correct ---
- # DELETE...RETURNING prevents a concurrent request from using the same token.
- del_result = await db.execute(
- delete(AuthEphemeralToken)
- .where(
- AuthEphemeralToken.token == body.setup_token,
- AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
- AuthEphemeralToken.username == current_user.username,
- )
- .returning(AuthEphemeralToken.id)
- )
- if del_result.one_or_none() is None:
- # Concurrent request consumed it between peek and delete — treat as invalid.
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")
- await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
- await _set_email_2fa_enabled(db, current_user.id, True)
- await db.commit()
- return {"message": "Email OTP 2FA enabled"}
- @router.post("/2fa/email/disable")
- async def disable_email_otp(
- body: EmailOTPDisableRequest,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Disable email-based OTP 2FA for the current user.
- C6: Re-authentication required — the caller must supply their account password
- to prevent a hijacked session from silently removing a second factor.
- LDAP/OIDC-only users (no local password) are exempt from this check.
- H-2: Rate-limited to prevent brute-forcing the password via this endpoint.
- """
- await check_rate_limit(db, current_user.username)
- if current_user.password_hash:
- if not verify_password(body.password, current_user.password_hash):
- await record_failed_attempt(db, current_user.username)
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password")
- await _set_email_2fa_enabled(db, current_user.id, False)
- await db.commit()
- return {"message": "Email OTP 2FA disabled"}
- @router.post("/2fa/email/send")
- async def send_email_otp(
- request: Request,
- body: EmailOTPSendRequest,
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Send a 6-digit OTP code to the user's email address.
- Requires a valid pre_auth_token obtained during the login flow.
- """
- # Peek (validate without consuming) first so a rate-limit rejection does not
- # permanently burn the caller's pre-auth token.
- challenge_id = request.cookies.get("2fa_challenge")
- username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
- if not username:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
- # Enforce rate limit BEFORE consuming the token to prevent OTP email flooding.
- await check_email_otp_send_rate(db, username)
- user = await get_user_by_username(db, username)
- if not user or not user.is_active:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
- if not user.email:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no email address configured")
- smtp_settings = await get_smtp_settings(db)
- if not smtp_settings:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")
- # Invalidate all existing unused OTP codes for this user (staged, not yet committed)
- await db.execute(
- UserOTPCode.__table__.update() # type: ignore[attr-defined]
- .where(UserOTPCode.user_id == user.id)
- .where(UserOTPCode.used.is_(False))
- .values(used=True)
- )
- # Generate a 6-digit code and stage the record (not committed yet)
- code = str(secrets.randbelow(1_000_000)).zfill(6)
- code_hash = pwd_context.hash(code)
- expires_at = datetime.now(timezone.utc) + timedelta(minutes=UserOTPCode.OTP_TTL_MINUTES)
- otp_record = UserOTPCode(
- user_id=user.id,
- code_hash=code_hash,
- attempts=0,
- used=False,
- expires_at=expires_at,
- )
- db.add(otp_record)
- # M2: Send the email BEFORE consuming the pre-auth token.
- # If the send fails we raise an exception here; the session is uncommitted so
- # the OTP record is discarded and the original token remains valid for retry.
- try:
- send_email(
- smtp_settings=smtp_settings,
- to_email=user.email,
- subject="Your Bambuddy verification code",
- body_text=f"Your Bambuddy login code is: {code}\n\nThis code expires in {UserOTPCode.OTP_TTL_MINUTES} minutes and can only be used once.",
- body_html=(
- f"<p>Your <strong>Bambuddy</strong> login verification code is:</p>"
- f"<h2 style='letter-spacing:4px'>{code}</h2>"
- f"<p>This code expires in <strong>{UserOTPCode.OTP_TTL_MINUTES} minutes</strong> and can only be used once.</p>"
- f"<p>If you did not request this code, you can safely ignore this email.</p>"
- ),
- )
- await record_email_otp_send(db, username)
- except Exception as exc:
- logger.error("Failed to send OTP email to user_id=%d: %s", user.id, exc)
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send OTP email")
- # Email sent — now atomically consume the old token (this also commits the
- # staged OTP record) and issue a fresh token for the verify step.
- consumed = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
- if not consumed:
- # Raced with another request or token just expired — treat as invalid.
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
- # Re-issue a fresh pre-auth token bound to the same cookie so the binding
- # carries forward through the email → verify step.
- fresh_token = await create_pre_auth_token(db, username, challenge_id=challenge_id)
- # Return the fresh pre-auth token so the frontend can proceed to verify
- return {"message": "Code sent to your email address", "pre_auth_token": fresh_token}
- @router.post("/2fa/verify", response_model=TwoFAVerifyResponse)
- async def verify_2fa(
- request: Request,
- body: TwoFAVerifyRequest,
- db: AsyncSession = Depends(get_db),
- ) -> TwoFAVerifyResponse:
- """Verify a 2FA code and exchange the pre_auth_token for a full JWT.
- Accepted methods: ``totp``, ``email``, ``backup``.
- The pre_auth_token is NOT consumed on failed verification attempts so the
- user can retry without restarting the login flow. It is only consumed once
- verification succeeds, preventing token replay after success.
- """
- # Peek without consuming — bad codes must not burn the session token.
- # Pass the HttpOnly challenge cookie so the binding check is enforced.
- challenge_id = request.cookies.get("2fa_challenge")
- username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
- if not username:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
- await check_rate_limit(db, username)
- user = await get_user_by_username(db, username)
- if not user or not user.is_active:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
- method = body.method
- if method == "totp":
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
- totp_record = result.scalar_one_or_none()
- if not totp_record or not totp_record.is_enabled:
- await record_failed_attempt(db, username)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
- totp_obj = pyotp.TOTP(totp_record.secret)
- if not totp_obj.verify(body.code, valid_window=1):
- await record_failed_attempt(db, username)
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid TOTP code")
- _assert_totp_not_replayed(totp_obj, totp_record, body.code)
- await db.flush() # L-3: persist last_totp_counter immediately to block replay
- elif method == "email":
- now = datetime.now(timezone.utc)
- result = await db.execute(
- select(UserOTPCode)
- .where(UserOTPCode.user_id == user.id)
- .where(UserOTPCode.used.is_(False))
- .where(UserOTPCode.expires_at > now)
- .order_by(UserOTPCode.created_at.desc())
- )
- otp_record = result.scalar_one_or_none()
- if not otp_record:
- await record_failed_attempt(db, username)
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="No valid OTP code found. Request a new one."
- )
- if otp_record.attempts >= UserOTPCode.MAX_ATTEMPTS:
- otp_record.consume()
- await db.commit()
- await record_failed_attempt(db, username)
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="OTP code has been invalidated after too many attempts"
- )
- if not pwd_context.verify(body.code, otp_record.code_hash):
- otp_record.attempts += 1
- await db.commit()
- await record_failed_attempt(db, username)
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid OTP code")
- otp_record.consume()
- await db.commit()
- else: # method == "backup"
- result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
- totp_record = result.scalar_one_or_none()
- if not totp_record or not totp_record.is_enabled:
- await record_failed_attempt(db, username)
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
- # Always iterate all codes — no early break (L-R9-A: constant iteration
- # count prevents timing oracle based on used-code position in the list).
- matched_index: int | None = None
- for idx, hashed in enumerate(totp_record.backup_code_hashes):
- if pwd_context.verify(body.code, hashed) and matched_index is None:
- matched_index = idx
- if matched_index is None:
- await record_failed_attempt(db, username)
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid backup code")
- # M1: Consume the pre-auth token FIRST (atomic single-use enforcement).
- # Only if that succeeds do we remove the backup code — this prevents a race
- # where two concurrent requests both pass code verification but only one
- # should be granted a session.
- consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
- if not consumed_username:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
- # Remove the used backup code now that the token is atomically consumed.
- updated_codes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index]
- totp_record.backup_code_hashes = updated_codes
- await db.commit()
- await clear_failed_attempts(db, username)
- access_token = create_access_token(
- data={"sub": user.username},
- expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
- )
- result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
- user = result.scalar_one()
- return TwoFAVerifyResponse(access_token=access_token, token_type="bearer", user=_user_to_response(user))
- # Verification succeeded (TOTP or email) — consume the pre-auth token.
- # C-1: Check the return value; if None the token was already consumed by a
- # concurrent request (race condition) — reject to prevent double-use.
- consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
- if not consumed_username:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
- await clear_failed_attempts(db, username)
- access_token = create_access_token(
- data={"sub": user.username},
- expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
- )
- # Reload with groups for permission calculation
- result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
- user = result.scalar_one()
- return TwoFAVerifyResponse(
- access_token=access_token,
- token_type="bearer",
- user=_user_to_response(user),
- )
- @router.delete("/2fa/admin/{user_id}")
- async def admin_disable_2fa(
- user_id: int,
- body: AdminDisable2FARequest = Body(default_factory=AdminDisable2FARequest),
- current_user: User | None = RequirePermissionIfAuthEnabled(Permission.USERS_UPDATE),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Admin endpoint: disable all 2FA for a given user.
- Nit 3: Requires the admin's own password as a re-auth step (matching how
- disable_email_otp protects a user's own 2FA removal). OIDC/LDAP-only admins
- (no local password_hash) are exempt.
- """
- # Nit 3: Re-auth — admin must supply their own password.
- if current_user and current_user.password_hash:
- if not body.admin_password or not verify_password(body.admin_password, current_user.password_hash):
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin password required")
- # Delete TOTP record
- await db.execute(delete(UserTOTP).where(UserTOTP.user_id == user_id))
- # Disable email 2FA setting
- await _set_email_2fa_enabled(db, user_id, False)
- # Invalidate all OTP codes
- await db.execute(
- UserOTPCode.__table__.update() # type: ignore[attr-defined]
- .where(UserOTPCode.user_id == user_id)
- .values(used=True)
- )
- # I2: Invalidate existing JWTs for the target user by bumping password_changed_at.
- # Without this, a stolen token remains valid after 2FA removal.
- target_user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
- if target_user:
- target_user.password_changed_at = datetime.now(timezone.utc)
- await db.commit()
- actor = current_user.username if current_user else "anonymous"
- logger.info("Admin %s disabled all 2FA for user_id=%d", actor, user_id)
- return {"message": "2FA disabled for user"}
- # ===========================================================================
- # OIDC Endpoints
- # ===========================================================================
- @router.get("/oidc/providers", response_model=list[OIDCProviderResponse])
- async def list_oidc_providers(
- db: AsyncSession = Depends(get_db),
- ) -> list[OIDCProviderResponse]:
- """List all enabled OIDC providers (public)."""
- result = await db.execute(select(OIDCProvider).where(OIDCProvider.is_enabled.is_(True)))
- providers = result.scalars().all()
- return [OIDCProviderResponse.model_validate(p) for p in providers]
- @router.get("/oidc/providers/all", response_model=list[OIDCProviderResponse])
- async def list_all_oidc_providers(
- _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ),
- db: AsyncSession = Depends(get_db),
- ) -> list[OIDCProviderResponse]:
- """List ALL OIDC providers including disabled ones (admin only)."""
- result2 = await db.execute(select(OIDCProvider))
- providers = result2.scalars().all()
- return [OIDCProviderResponse.model_validate(p) for p in providers]
- @router.post("/oidc/providers", response_model=OIDCProviderResponse, status_code=status.HTTP_201_CREATED)
- async def create_oidc_provider(
- body: OIDCProviderCreate,
- _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
- db: AsyncSession = Depends(get_db),
- ) -> OIDCProviderResponse:
- """Create a new OIDC provider (admin only)."""
- if body.default_group_id is not None:
- grp_chk = await db.execute(select(Group).where(Group.id == body.default_group_id))
- if not grp_chk.scalar_one_or_none():
- raise HTTPException(
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- detail="default_group_id references a non-existent group",
- )
- provider = OIDCProvider(
- name=body.name,
- issuer_url=body.issuer_url.rstrip("/"),
- client_id=body.client_id,
- client_secret=body.client_secret,
- scopes=body.scopes,
- is_enabled=body.is_enabled,
- auto_create_users=body.auto_create_users,
- auto_link_existing_accounts=body.auto_link_existing_accounts,
- email_claim=body.email_claim,
- require_email_verified=body.require_email_verified,
- icon_url=body.icon_url,
- default_group_id=body.default_group_id,
- )
- # SEC-1 + SEC-6: runtime guard mirrors the OIDCProviderCreate model_validator in schemas/auth.py.
- # Catches any future path that bypasses Pydantic validation (direct ORM, scripts).
- _enforce_auto_link_safety(provider)
- db.add(provider)
- await db.commit()
- await db.refresh(provider)
- return OIDCProviderResponse.model_validate(provider)
- @router.put("/oidc/providers/{provider_id}", response_model=OIDCProviderResponse)
- async def update_oidc_provider(
- provider_id: int,
- body: OIDCProviderUpdate,
- _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
- db: AsyncSession = Depends(get_db),
- ) -> OIDCProviderResponse:
- """Update an existing OIDC provider (admin only)."""
- result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
- provider = result2.scalar_one_or_none()
- if not provider:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")
- if body.default_group_id is not None:
- grp_chk = await db.execute(select(Group).where(Group.id == body.default_group_id))
- if not grp_chk.scalar_one_or_none():
- raise HTTPException(
- status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
- detail="default_group_id references a non-existent group",
- )
- for field, value in body.model_dump(exclude_none=True).items():
- if field == "issuer_url" and value:
- value = value.rstrip("/")
- setattr(provider, field, value)
- # SEC-1 + SEC-6: Combined-State-Guard after setattr loop.
- # Checks the final in-memory state (DB values + newly set values combined) to catch
- # partial updates that each pass schema validation individually but are unsafe together.
- _enforce_auto_link_safety(provider)
- await db.commit()
- await db.refresh(provider)
- return OIDCProviderResponse.model_validate(provider)
- @router.delete("/oidc/providers/{provider_id}")
- async def delete_oidc_provider(
- provider_id: int,
- _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Delete an OIDC provider and all its user links (admin only)."""
- result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
- provider = result2.scalar_one_or_none()
- if not provider:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")
- await db.delete(provider)
- await db.commit()
- return {"message": "Provider deleted"}
- @router.get("/oidc/authorize/{provider_id}", response_model=OIDCAuthorizeResponse)
- async def oidc_authorize(
- provider_id: int,
- db: AsyncSession = Depends(get_db),
- ) -> OIDCAuthorizeResponse:
- """Return the OIDC authorization URL for the given provider."""
- result = await db.execute(
- select(OIDCProvider).where(OIDCProvider.id == provider_id).where(OIDCProvider.is_enabled.is_(True))
- )
- provider = result.scalar_one_or_none()
- if not provider:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found or not enabled")
- # Fetch discovery document
- discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
- try:
- async with httpx.AsyncClient(timeout=10) as client:
- resp = await client.get(discovery_url)
- resp.raise_for_status()
- discovery = resp.json()
- except Exception as exc:
- logger.error("Failed to fetch OIDC discovery for provider %d: %s", provider_id, exc)
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch OIDC discovery document")
- authorization_endpoint = discovery.get("authorization_endpoint")
- if not authorization_endpoint:
- raise HTTPException(
- status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document missing authorization_endpoint"
- )
- # B2: SSRF guard — reject non-HTTP(S) schemes in the authorization endpoint
- if not authorization_endpoint.startswith(("https://", "http://")):
- logger.warning("OIDC discovery authorization_endpoint has invalid scheme: %s", authorization_endpoint)
- raise HTTPException(
- status_code=status.HTTP_502_BAD_GATEWAY,
- detail="OIDC discovery document contains invalid authorization_endpoint",
- )
- external_url = await _get_base_external_url(db)
- redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
- now = datetime.now(timezone.utc)
- # Prune expired OIDC states from the DB
- await db.execute(
- delete(AuthEphemeralToken).where(
- AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
- AuthEphemeralToken.expires_at < now,
- )
- )
- state = secrets.token_urlsafe(32)
- nonce = secrets.token_urlsafe(32)
- # PKCE (S256) – required by PocketID and recommended for all OIDC flows
- code_verifier = secrets.token_urlsafe(48) # 64-char URL-safe string
- code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
- db.add(
- AuthEphemeralToken(
- token=state,
- token_type=TokenType.OIDC_STATE,
- provider_id=provider_id,
- nonce=nonce,
- code_verifier=code_verifier,
- expires_at=now + OIDC_STATE_TTL,
- )
- )
- await db.commit()
- params = urllib.parse.urlencode(
- {
- "response_type": "code",
- "client_id": provider.client_id,
- "redirect_uri": redirect_uri,
- "scope": provider.scopes,
- "state": state,
- "nonce": nonce,
- "code_challenge": code_challenge,
- "code_challenge_method": "S256",
- }
- )
- auth_url = f"{authorization_endpoint}?{params}"
- return OIDCAuthorizeResponse(auth_url=auth_url)
- @router.get("/oidc/callback")
- async def oidc_callback(
- code: str | None = Query(default=None, max_length=2048),
- state: str | None = Query(default=None, max_length=2048),
- error: str | None = Query(default=None, max_length=256),
- db: AsyncSession = Depends(get_db),
- ) -> RedirectResponse:
- """Handle the OIDC authorization code callback from the identity provider."""
- external_url = await _get_base_external_url(db)
- frontend_error_url = f"{external_url}/?oidc_error="
- try:
- if error:
- logger.warning("OIDC callback received error: %s", error)
- return RedirectResponse(url=f"{frontend_error_url}oidc_provider_error", status_code=302)
- if not code or not state:
- return RedirectResponse(url=f"{frontend_error_url}missing_parameters", status_code=302)
- # Atomically validate and consume OIDC state from DB (I6: single-use enforcement).
- # DELETE...RETURNING ensures concurrent callbacks with the same state token
- # cannot both succeed — only the first DELETE finds the row.
- now = datetime.now(timezone.utc)
- state_del = await db.execute(
- delete(AuthEphemeralToken)
- .where(
- AuthEphemeralToken.token == state,
- AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
- AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
- )
- .returning(
- AuthEphemeralToken.provider_id,
- AuthEphemeralToken.nonce,
- AuthEphemeralToken.code_verifier,
- )
- )
- state_row = state_del.one_or_none()
- if state_row is None:
- await db.rollback()
- return RedirectResponse(url=f"{frontend_error_url}invalid_state", status_code=302)
- provider_id, nonce, code_verifier = state_row
- await db.commit()
- # Load provider
- result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
- provider = result.scalar_one_or_none()
- if not provider:
- return RedirectResponse(url=f"{frontend_error_url}provider_not_found", status_code=302)
- redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
- # ── Step 1: Fetch discovery document ────────────────────────────────
- discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
- try:
- async with httpx.AsyncClient(timeout=10) as client:
- disc_resp = await client.get(discovery_url)
- disc_resp.raise_for_status()
- discovery = disc_resp.json()
- except Exception as exc:
- logger.error("OIDC discovery fetch failed for provider %d: %s", provider_id, exc)
- return RedirectResponse(url=f"{frontend_error_url}discovery_failed", status_code=302)
- token_endpoint = discovery.get("token_endpoint")
- jwks_uri = discovery.get("jwks_uri")
- if not token_endpoint or not jwks_uri:
- return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
- # L-R7-C: Reject non-HTTP(S) URLs in the discovery document to prevent
- # SSRF via crafted responses (e.g. file://, gopher://, internal schemes).
- if not token_endpoint.startswith(("https://", "http://")) or not jwks_uri.startswith(("https://", "http://")):
- logger.warning(
- "OIDC discovery document contains non-HTTP URL(s): token=%s jwks=%s", token_endpoint, jwks_uri
- )
- return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
- # ── Step 2: Exchange authorization code for tokens ───────────────────
- token_form: dict[str, str] = {
- "grant_type": "authorization_code",
- "code": code,
- "redirect_uri": redirect_uri,
- "client_id": provider.client_id,
- }
- if provider.client_secret:
- token_form["client_secret"] = provider.client_secret
- if code_verifier:
- token_form["code_verifier"] = code_verifier
- try:
- async with httpx.AsyncClient(timeout=15) as client:
- token_resp = await client.post(
- token_endpoint,
- data=token_form,
- headers={"Accept": "application/json"},
- )
- except Exception as exc:
- logger.error("OIDC token exchange request failed for provider %d: %s", provider_id, exc)
- return RedirectResponse(url=f"{frontend_error_url}token_exchange_network_error", status_code=302)
- if not token_resp.is_success:
- try:
- err_body = token_resp.json()
- oidc_err = err_body.get("error", "")
- oidc_desc = err_body.get("error_description", "")
- except Exception:
- oidc_err = ""
- oidc_desc = token_resp.text[:200]
- logger.error(
- "OIDC token exchange HTTP %d for provider %d. redirect_uri=%r error=%r desc=%r",
- token_resp.status_code,
- provider_id,
- redirect_uri,
- oidc_err,
- oidc_desc,
- )
- # Encode the OIDC error code into the redirect so the user sees it in the toast.
- # URL-encode the value to prevent query-parameter injection from provider responses.
- raw_err = oidc_err[:40] if oidc_err else str(token_resp.status_code)
- safe_err = urllib.parse.quote(raw_err, safe="")
- return RedirectResponse(
- url=f"{frontend_error_url}token_exchange_{safe_err}",
- status_code=302,
- )
- try:
- token_data = token_resp.json()
- except Exception as exc:
- logger.error("OIDC token exchange non-JSON response for provider %d: %s", provider_id, exc)
- return RedirectResponse(url=f"{frontend_error_url}token_exchange_bad_response", status_code=302)
- id_token = token_data.get("id_token")
- if not id_token:
- # Only log the keys present — values may contain secrets (access_token, etc.)
- logger.error(
- "OIDC token response missing id_token for provider %d; keys present: %s",
- provider_id,
- list(token_data.keys()),
- )
- return RedirectResponse(url=f"{frontend_error_url}no_id_token", status_code=302)
- # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
- # Use the issuer from the discovery document as the canonical value (OIDC Core
- # §3.1.3.7 requires iss == discovery issuer exactly). We strip trailing slashes
- # from both sides because some providers (e.g. Authentik, older PocketID versions)
- # are inconsistent between the discovery issuer and the JWT iss claim.
- discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
- try:
- async with httpx.AsyncClient(timeout=10) as jwks_http:
- jwks_resp = await jwks_http.get(jwks_uri)
- jwks_resp.raise_for_status()
- jwks_data = jwks_resp.json()
- jwks_client = PyJWKClient(jwks_uri)
- jwks_client.fetch_data = lambda: jwks_data # type: ignore[method-assign]
- signing_key = jwks_client.get_signing_key_from_jwt(id_token)
- # M-3: Decode without built-in issuer check, then compare normalised
- # (both sides rstrip("/")) to handle providers like Authentik that include
- # a trailing slash in iss but not in the discovery issuer, or vice-versa.
- claims = jwt.decode(
- id_token,
- signing_key.key,
- algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
- audience=provider.client_id,
- options={"verify_iss": False},
- )
- token_iss = claims.get("iss", "").rstrip("/")
- if token_iss != discovery_issuer:
- raise jwt.exceptions.InvalidIssuerError("Invalid issuer")
- except Exception as exc:
- logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
- return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)
- # Verify nonce — fail closed: we always send a nonce, so the provider must echo it.
- # Skipping the check when nonce is absent would allow CSRF on non-nonce providers.
- token_nonce = claims.get("nonce")
- if token_nonce is None or token_nonce != nonce:
- logger.warning("OIDC nonce mismatch for provider %d (present=%r)", provider_id, token_nonce is not None)
- return RedirectResponse(url=f"{frontend_error_url}nonce_mismatch", status_code=302)
- provider_sub: str = claims.get("sub", "")
- if not provider_sub:
- return RedirectResponse(url=f"{frontend_error_url}missing_sub_claim", status_code=302)
- # SEC-3: resolve email via Fall A/B/C logic (see _resolve_provider_email).
- provider_email = _resolve_provider_email(provider, claims, provider_sub)
- # ── Step 4: Resolve / create user ────────────────────────────────────
- try:
- # 1. Look up existing OIDC link
- link_result = await db.execute(
- select(UserOIDCLink)
- .where(UserOIDCLink.provider_id == provider_id)
- .where(UserOIDCLink.provider_user_id == provider_sub)
- )
- link = link_result.scalar_one_or_none()
- user: User | None = None
- if link:
- # Existing link → load the linked user
- user_result = await db.execute(
- select(User).where(User.id == link.user_id).options(selectinload(User.groups))
- )
- user = user_result.scalar_one_or_none()
- else:
- # 2. No OIDC link yet — check for an existing user with the same email.
- # Use case-insensitive matching (func.lower) so that "User@Example.com"
- # and "user@example.com" are treated as the same identity, preventing
- # an attacker-controlled IdP from bypassing the auto-link guard by
- # registering the target email with different casing.
- email_user: User | None = None
- if provider_email:
- email_user = await get_user_by_email(db, provider_email)
- if email_user and provider.auto_link_existing_accounts:
- # M-4: Only auto-link when the provider has auto_link_existing_accounts
- # enabled. Operators can disable this to require explicit account linking,
- # preventing an attacker-controlled IdP from hijacking local accounts.
- #
- # M-NEW-6: Refuse auto-link if the target user already has any OIDC
- # link (to any provider). Without this guard an attacker who controls
- # a second OIDC provider with auto_link enabled could add themselves as
- # a second IdP for a user that already authenticates via a legitimate
- # provider, effectively taking over the account.
- existing_links_result = await db.execute(
- select(UserOIDCLink).where(UserOIDCLink.user_id == email_user.id)
- )
- has_existing_oidc_link = existing_links_result.scalar_one_or_none() is not None
- if has_existing_oidc_link:
- logger.warning(
- "Auto-link rejected for user '%s': already linked to another OIDC provider",
- email_user.username,
- )
- return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
- db.add(
- UserOIDCLink(
- user_id=email_user.id,
- provider_id=provider_id,
- provider_user_id=provider_sub,
- provider_email=provider_email,
- )
- )
- await db.commit()
- user = email_user
- logger.info(
- "Auto-linked existing user '%s' to OIDC provider %d via email match",
- email_user.username,
- provider_id,
- )
- elif provider.auto_create_users:
- # 3. No existing user — create one
- if provider_email:
- raw = provider_email.split("@")[0]
- else:
- # Prefer a human-readable IdP claim over the opaque sub.
- # isinstance guards are required: claims may carry non-string
- # values (e.g. a list) that would break .strip().
- # Sanitization is applied per-candidate so that a value that
- # strips to empty (e.g. "!!!") correctly falls through to the
- # next candidate rather than silently becoming "oidcuser".
- _pref = claims.get("preferred_username")
- _name = claims.get("name")
- raw = ""
- if isinstance(_pref, str):
- raw = re.sub(r"[^a-zA-Z0-9._-]", "", _pref.strip())[:30]
- if not raw and isinstance(_name, str):
- raw = re.sub(r"[^a-zA-Z0-9._-]", "", _name.strip())[:30]
- if not raw:
- raw = provider_sub[:30]
- candidate = re.sub(r"[^a-zA-Z0-9._-]", "", raw)[:30] or "oidcuser"
- username = candidate
- counter = 1
- while True:
- existing = await get_user_by_username(db, username)
- if not existing:
- break
- username = f"{candidate}{counter}"
- counter += 1
- # I9: Assign new OIDC users to a group before flush — accessing
- # new_user.groups after a flush triggers a lazy-load which fails
- # in async context. Resolution order:
- # 1. provider.default_group_id (operator-configured)
- # 2. "Viewers" (system fallback for read-only access)
- # 3. no group (last resort if Viewers was deleted)
- # SQLite does not enforce ON DELETE SET NULL, so a dangling
- # default_group_id returns None here and falls through to Viewers.
- default_group: Group | None = None
- if provider.default_group_id is not None:
- dg_result = await db.execute(select(Group).where(Group.id == provider.default_group_id))
- default_group = dg_result.scalar_one_or_none()
- if default_group is None:
- viewers_result = await db.execute(select(Group).where(Group.name == "Viewers"))
- default_group = viewers_result.scalar_one_or_none()
- new_user = User(
- username=username,
- email=provider_email,
- # M-1: auth_source="oidc" prevents local password-reset flow
- # for users who should only authenticate via OIDC.
- auth_source="oidc",
- password_hash=None, # OIDC users never use password auth
- role="user",
- is_active=True,
- groups=[default_group] if default_group else [],
- )
- db.add(new_user)
- await db.flush()
- db.add(
- UserOIDCLink(
- user_id=new_user.id,
- provider_id=provider_id,
- provider_user_id=provider_sub,
- provider_email=provider_email,
- )
- )
- await db.commit()
- user_result = await db.execute(
- select(User).where(User.id == new_user.id).options(selectinload(User.groups))
- )
- user = user_result.scalar_one()
- logger.info("Auto-created user '%s' via OIDC provider %d", username, provider_id)
- else:
- return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
- if not user or not user.is_active:
- return RedirectResponse(url=f"{frontend_error_url}account_inactive", status_code=302)
- # Issue an OIDC exchange token (short-lived, single-use) stored in DB.
- # I7: Opportunistically prune expired exchange tokens to keep the table small.
- now2 = datetime.now(timezone.utc)
- await db.execute(
- delete(AuthEphemeralToken).where(
- AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
- AuthEphemeralToken.expires_at < now2,
- )
- )
- exchange_token = secrets.token_urlsafe(32)
- db.add(
- AuthEphemeralToken(
- token=exchange_token,
- token_type=TokenType.OIDC_EXCHANGE,
- username=user.username,
- expires_at=now2 + OIDC_EXCHANGE_TTL,
- )
- )
- await db.commit()
- # H-4: Use a URL fragment (#) instead of a query parameter so the exchange
- # token is never sent to the server in the Referer header or server logs.
- return RedirectResponse(url=f"{external_url}/login#oidc_token={exchange_token}", status_code=302)
- except Exception as exc:
- logger.error("OIDC user resolution failed for provider %d: %s", provider_id, exc, exc_info=True)
- try:
- await db.rollback()
- except Exception as rb_exc:
- logger.error("DB rollback failed after OIDC user-resolution error: %s", rb_exc, exc_info=True)
- return RedirectResponse(url=f"{frontend_error_url}user_resolution_failed", status_code=302)
- except Exception as exc:
- # L-1: Log the exception class name internally but never expose it in the
- # redirect URL — leaking exception names aids attacker reconnaissance.
- logger.error("Unexpected error in OIDC callback (%s): %s", type(exc).__name__, exc, exc_info=True)
- try:
- return RedirectResponse(url=f"{frontend_error_url}internal_error", status_code=302)
- except Exception as redirect_exc:
- logger.error("Failed to construct error redirect in OIDC callback: %s", redirect_exc, exc_info=True)
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OIDC callback failed")
- @router.post("/oidc/exchange", response_model=LoginResponse)
- async def oidc_exchange(
- body: OIDCExchangeRequest,
- raw_request: Request,
- response: Response,
- db: AsyncSession = Depends(get_db),
- ) -> LoginResponse:
- """Exchange an OIDC exchange token (from the callback redirect) for a full JWT.
- C4: If the resolved user has 2FA enabled the exchange returns a pre_auth_token
- (requires_2fa=True) instead of a full JWT. The frontend must then complete the
- 2FA step exactly as it would after a password-based login.
- """
- now = datetime.now(timezone.utc)
- # Atomically consume the exchange token (DELETE...RETURNING prevents replay).
- consume_result = await db.execute(
- delete(AuthEphemeralToken)
- .where(
- AuthEphemeralToken.token == body.oidc_token,
- AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
- AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
- )
- .returning(AuthEphemeralToken.username)
- )
- row = consume_result.one_or_none()
- if row is None:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired OIDC exchange token")
- (username,) = row
- await db.commit()
- user = await get_user_by_username(db, username)
- if not user or not user.is_active:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
- # Reload with groups
- result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
- user = result.scalar_one()
- # C4: Check whether the user has any 2FA method enabled.
- totp_result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
- totp_record = totp_result.scalar_one_or_none()
- totp_enabled = totp_record is not None and totp_record.is_enabled
- email_2fa_enabled = await _get_email_2fa_enabled(db, user.id)
- if totp_enabled or email_2fa_enabled:
- # User has 2FA — issue a pre_auth_token bound to this browser session via
- # an HttpOnly cookie (H-A: mirrors the cookie-binding done in auth.py:login).
- two_fa_methods: list[str] = []
- if totp_enabled:
- two_fa_methods.append("totp")
- if email_2fa_enabled:
- two_fa_methods.append("email")
- if totp_enabled:
- two_fa_methods.append("backup")
- challenge_id = secrets.token_urlsafe(32)
- pre_auth_token = await create_pre_auth_token(db, user.username, challenge_id=challenge_id)
- response.set_cookie(
- key="2fa_challenge",
- value=challenge_id,
- httponly=True,
- secure=raw_request.url.scheme == "https",
- samesite="lax",
- max_age=300,
- path="/api/v1/auth/2fa",
- )
- return LoginResponse(
- requires_2fa=True,
- pre_auth_token=pre_auth_token,
- two_fa_methods=two_fa_methods,
- user=_user_to_response(user),
- )
- access_token = create_access_token(
- data={"sub": user.username},
- expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
- )
- return LoginResponse(
- access_token=access_token,
- token_type="bearer",
- user=_user_to_response(user),
- requires_2fa=False,
- )
- @router.get("/oidc/links", response_model=list[OIDCLinkResponse])
- async def list_oidc_links(
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> list[OIDCLinkResponse]:
- """List all OIDC provider links for the current user."""
- result = await db.execute(
- select(UserOIDCLink).where(UserOIDCLink.user_id == current_user.id).options(selectinload(UserOIDCLink.provider))
- )
- links = result.scalars().all()
- return [
- OIDCLinkResponse(
- id=link.id,
- provider_id=link.provider_id,
- provider_name=link.provider.name,
- provider_email=link.provider_email,
- created_at=link.created_at.isoformat(),
- )
- for link in links
- ]
- @router.delete("/oidc/links/{provider_id}")
- async def remove_oidc_link(
- provider_id: int,
- current_user: User = Depends(get_current_active_user),
- db: AsyncSession = Depends(get_db),
- ) -> dict:
- """Remove the OIDC link between the current user and a provider."""
- result = await db.execute(
- select(UserOIDCLink)
- .where(UserOIDCLink.user_id == current_user.id)
- .where(UserOIDCLink.provider_id == provider_id)
- )
- link = result.scalar_one_or_none()
- if not link:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="OIDC link not found")
- await db.delete(link)
- await db.commit()
- return {"message": "OIDC link removed"}
- # ---------------------------------------------------------------------------
- # Internal helpers
- # ---------------------------------------------------------------------------
- async def _get_base_external_url(db: AsyncSession) -> str:
- """Return the base external URL (no trailing slash, no /login suffix)."""
- external_url = await get_setting(db, "external_url")
- if external_url:
- return external_url.rstrip("/")
- return os.environ.get("APP_URL", "http://localhost:5173").rstrip("/")
|