"""2FA (TOTP + Email OTP) and OIDC authentication routes. Security model -------------- * Pre-auth tokens : secrets.token_urlsafe(32) stored in-memory with a 5-minute TTL. They are single-use and do NOT grant access to any protected resource. * TOTP codes : verified with pyotp (30-second window, ±1 step tolerance). * Email OTP codes : 6-digit numeric, hashed with pbkdf2_sha256, 10-minute TTL, max 5 failed attempts per code before invalidation. * Backup codes : 10 × 8-char alphanumeric codes, each stored as pbkdf2_sha256 hash, single-use. * OIDC state : secrets.token_urlsafe(32) bound to provider_id + nonce, 10-minute TTL. * OIDC exchange : secrets.token_urlsafe(32), 2-minute TTL, single-use. * Rate limiting : max 5 failed 2FA verification attempts per user within 15 minutes. """ from __future__ import annotations import base64 import hashlib import io import logging import os import re import secrets import string import urllib.parse from datetime import datetime, timedelta, timezone import httpx import jwt import pyotp from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request, Response, status from fastapi.responses import RedirectResponse from jwt import PyJWKClient from passlib.context import CryptContext from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload, undefer from backend.app.api.routes._oidc_helpers import assert_safe_public_https_url from backend.app.api.routes.settings import get_setting, set_setting from backend.app.core.auth import ( ACCESS_TOKEN_EXPIRE_MINUTES, RequirePermissionIfAuthEnabled, create_access_token, get_current_active_user, get_user_by_email, get_user_by_username, is_auth_enabled, verify_password, ) from backend.app.core.database import get_db from backend.app.core.permissions import Permission from backend.app.models.auth_ephemeral import AuthEphemeralToken, AuthRateLimitEvent, EventType, TokenType from backend.app.models.group import Group from backend.app.models.oidc_provider import OIDCProvider, UserOIDCLink from backend.app.models.user import User from backend.app.models.user_otp_code import UserOTPCode from backend.app.models.user_totp import UserTOTP from backend.app.schemas.auth import ( AUTO_LINK_REQUIREMENTS_ERROR, AdminDisable2FARequest, BackupCodesResponse, EmailOTPDisableRequest, EmailOTPEnableConfirmRequest, EmailOTPSendRequest, GroupBrief, LoginResponse, OIDCAuthorizeResponse, OIDCExchangeRequest, OIDCLinkResponse, OIDCProviderCreate, OIDCProviderResponse, OIDCProviderUpdate, TOTPDisableRequest, TOTPEnableRequest, TOTPEnableResponse, TOTPSetupRequest, TOTPSetupResponse, TwoFAStatusResponse, TwoFAVerifyRequest, TwoFAVerifyResponse, UserResponse, ) from backend.app.services.email_service import get_smtp_settings, send_email from backend.app.services.oidc_icon import OIDCIconError, fetch_icon logger = logging.getLogger(__name__) def _redact_url_for_log(url: str) -> str: """Return ``scheme://host/path`` with query string and fragment stripped. Admin-supplied icon URLs are usually CDN paths, but nothing stops an admin from pasting a presigned URL whose query string carries an ``X-Amz-Signature`` / OAuth token / etc. Operators need a forensic trail without those secrets ending up in log files. """ try: parsed = urllib.parse.urlparse(url) except ValueError: return "" netloc = parsed.netloc or "" return f"{parsed.scheme}://{netloc}{parsed.path}" async def _fetch_icon_or_400(icon_url: str) -> tuple[bytes, str, str]: """Validate URL + fetch icon, mapping any failure to HTTPException(400). Centralises the SSRF guard + fetcher invocation so create/update/refresh all behave identically — admin always gets a 400 with a precise reason, never a 500 / opaque server error. Both failure paths log at WARNING so operators have a forensic trail later — without these log lines the admin's UI toast was the only record of the failure (#1333 review). """ try: assert_safe_public_https_url(icon_url) except ValueError as exc: logger.warning("OIDC icon URL rejected by SSRF guard: url=%s reason=%s", _redact_url_for_log(icon_url), exc) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc try: return await fetch_icon(icon_url) except OIDCIconError as exc: logger.warning("OIDC icon fetch failed: url=%s reason=%s", _redact_url_for_log(icon_url), exc) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc def _build_provider_response(provider: OIDCProvider) -> OIDCProviderResponse: """Build OIDCProviderResponse via ``from_attributes``. The required ``has_icon`` field is supplied by ``OIDCProvider.has_icon`` (a property reading the non-deferred ``icon_content_type`` column).""" return OIDCProviderResponse.model_validate(provider) def _etag_matches(if_none_match: str | None, etag_raw: str | None) -> bool: """RFC 7232 §3.2 If-None-Match comparison. Supports: * ``*`` wildcard — matches any current representation when the resource exists (and it does here; we wouldn't have an etag otherwise). * Multiple comma-separated tokens. * Weak-validator prefix ``W/`` (RFC 7232 §2.3) — accepted on GET since cached representations of a static byte-blob are byte-identical. Returns False on missing header or missing stored etag. """ if not if_none_match or not etag_raw: return False quoted = f'"{etag_raw}"' tokens = [t.strip() for t in if_none_match.split(",")] if "*" in tokens: return True return any(tok.removeprefix("W/") == quoted for tok in tokens) def _as_utc(dt: datetime) -> datetime: """Return *dt* with UTC timezone attached. SQLite/aiosqlite strips timezone info when reading DateTime(timezone=True) columns back – the stored value is always UTC, so we just re-attach the info when doing Python-level comparisons. """ return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc) # --------------------------------------------------------------------------- # Passlib context (same scheme as auth.py) # --------------------------------------------------------------------------- pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto") # --------------------------------------------------------------------------- # TTL / rate-limit constants # --------------------------------------------------------------------------- MAX_2FA_ATTEMPTS = 5 MAX_LOGIN_ATTEMPTS = 10 LOCKOUT_WINDOW = timedelta(minutes=15) MAX_EMAIL_OTP_SENDS = 3 EMAIL_OTP_SEND_WINDOW = timedelta(minutes=10) PRE_AUTH_TOKEN_TTL = timedelta(minutes=5) OIDC_STATE_TTL = timedelta(minutes=10) OIDC_EXCHANGE_TTL = timedelta(minutes=2) # --------------------------------------------------------------------------- # Router # --------------------------------------------------------------------------- router = APIRouter(prefix="/auth", tags=["2fa", "oidc"]) # --------------------------------------------------------------------------- # Helper: user response # --------------------------------------------------------------------------- def _user_to_response(user: User) -> UserResponse: return UserResponse( id=user.id, username=user.username, email=user.email, role=user.role, is_active=user.is_active, is_admin=user.is_admin, groups=[GroupBrief(id=g.id, name=g.name) for g in user.groups], permissions=sorted(user.get_permissions()), created_at=user.created_at.isoformat(), ) # --------------------------------------------------------------------------- # Helper: QR code generation # --------------------------------------------------------------------------- def _generate_totp_qr_b64(provisioning_uri: str) -> str: """Generate a base64-encoded PNG QR code for the given TOTP provisioning URI.""" import qrcode # type: ignore qr = qrcode.QRCode(box_size=6, border=2) qr.add_data(provisioning_uri) qr.make(fit=True) img = qr.make_image(fill_color="black", back_color="white") buf = io.BytesIO() img.save(buf, format="PNG") return base64.b64encode(buf.getvalue()).decode() # --------------------------------------------------------------------------- # Helper: backup code generation # --------------------------------------------------------------------------- def _generate_backup_codes() -> tuple[list[str], list[str]]: """Return (plain_codes, hashed_codes) — 10 codes of 8 alphanumeric chars each.""" alphabet = string.ascii_uppercase + string.digits plain = ["".join(secrets.choice(alphabet) for _ in range(8)) for _ in range(10)] hashed = [pwd_context.hash(c) for c in plain] return plain, hashed # --------------------------------------------------------------------------- # DB-backed pre-auth token helpers # --------------------------------------------------------------------------- async def create_pre_auth_token(db: AsyncSession, username: str, challenge_id: str | None = None) -> str: """Create a single-use pre-auth token stored in the DB. Pass ``challenge_id`` (from the HttpOnly 2fa_challenge cookie) to bind the token to the originating browser session. The same value must be present as a cookie on every subsequent call that consumes this token. """ now = datetime.now(timezone.utc) # Prune expired tokens opportunistically (keep table small) await db.execute( delete(AuthEphemeralToken).where( AuthEphemeralToken.token_type == TokenType.PRE_AUTH, AuthEphemeralToken.expires_at < now, ) ) token = secrets.token_urlsafe(32) db.add( AuthEphemeralToken( token=token, token_type=TokenType.PRE_AUTH, username=username, challenge_id=challenge_id, expires_at=now + PRE_AUTH_TOKEN_TTL, ) ) await db.commit() return token async def consume_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None: """Atomically validate and consume a pre-auth token. Returns username or None. Uses DELETE...RETURNING so two concurrent requests with the same token cannot both succeed — only the first DELETE finds the row. M5: When challenge_id is provided, also enforces the cookie-binding constraint so a stolen token cannot be replayed from a different browser session. """ now = datetime.now(timezone.utc) result = await db.execute( delete(AuthEphemeralToken) .where( AuthEphemeralToken.token == token, AuthEphemeralToken.token_type == TokenType.PRE_AUTH, AuthEphemeralToken.expires_at > now, ) .returning(AuthEphemeralToken.username, AuthEphemeralToken.challenge_id) ) row = result.one_or_none() if row is None: return None username, stored_challenge_id = row # Enforce client binding: if the token was issued with a challenge_id, # the caller must supply the matching value. if stored_challenge_id is not None and stored_challenge_id != challenge_id: await db.rollback() return None await db.commit() return username async def peek_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None: """Validate a pre-auth token and return the username WITHOUT consuming it. When the stored token has a ``challenge_id`` (client-binding cookie), the caller must supply the matching value. A mismatch is treated as an invalid token — no information leakage about whether the token itself exists. """ now = datetime.now(timezone.utc) result = await db.execute( select(AuthEphemeralToken).where( AuthEphemeralToken.token == token, AuthEphemeralToken.token_type == TokenType.PRE_AUTH, AuthEphemeralToken.expires_at > now, ) ) eph = result.scalar_one_or_none() if eph is None: return None # Enforce client binding: if the token was issued with a challenge_id the # cookie must match. Treat a mismatch as if the token doesn't exist. if eph.challenge_id is not None and eph.challenge_id != challenge_id: return None return eph.username # --------------------------------------------------------------------------- # DB-backed rate-limiting helpers # --------------------------------------------------------------------------- async def check_rate_limit( db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT, max_attempts: int = MAX_2FA_ATTEMPTS, ) -> None: """Raise HTTP 429 if the user has exceeded the failed attempt limit. The username is normalised to lower-case so case-variant attempts (which all resolve to the same user) share the same rate-limit bucket. L-2: Known TOCTOU — the SELECT (count) and the subsequent INSERT (record_failed_attempt) are not atomic. Two concurrent requests can both read a count below the threshold and both proceed. This is an inherent trade-off of the event-log rate-limit pattern: fixing it would require a serialising lock (SELECT FOR UPDATE on a dedicated counter row), which adds contention and is not worth it for a soft rate-limit whose window is already measured in minutes. In practice the race window is microseconds and the limit can be slightly exceeded only under precise concurrent timing. """ username_key = username.lower() now = datetime.now(timezone.utc) cutoff = now - LOCKOUT_WINDOW result = await db.execute( select(AuthRateLimitEvent).where( AuthRateLimitEvent.username == username_key, AuthRateLimitEvent.event_type == event_type, AuthRateLimitEvent.occurred_at > cutoff, ) ) recent_count = len(result.scalars().all()) if recent_count >= max_attempts: raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many failed attempts. Please try again later.", ) async def record_failed_attempt(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None: """Record a failed attempt for rate-limiting purposes.""" db.add(AuthRateLimitEvent(username=username.lower(), event_type=event_type)) await db.commit() async def clear_failed_attempts(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None: """Delete all recorded failed attempts for a user on successful verification.""" await db.execute( delete(AuthRateLimitEvent).where( AuthRateLimitEvent.username == username.lower(), AuthRateLimitEvent.event_type == event_type, ) ) await db.commit() async def check_email_otp_send_rate(db: AsyncSession, username: str) -> None: """Raise HTTP 429 if the user has requested too many OTP emails recently. I1: This function only *checks* the limit. The caller is responsible for recording the slot via ``record_email_otp_send`` **after** the email has been sent successfully. This prevents failed sends from consuming a slot (wasting the user's quota) and makes it impossible to farm rate-limit events without actually triggering a send. """ username_key = username.lower() now = datetime.now(timezone.utc) cutoff = now - EMAIL_OTP_SEND_WINDOW result = await db.execute( select(AuthRateLimitEvent).where( AuthRateLimitEvent.username == username_key, AuthRateLimitEvent.event_type == EventType.EMAIL_SEND, AuthRateLimitEvent.occurred_at > cutoff, ) ) recent_count = len(result.scalars().all()) if recent_count >= MAX_EMAIL_OTP_SENDS: raise HTTPException( status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=f"Too many OTP email requests. Please wait {EMAIL_OTP_SEND_WINDOW.seconds // 60} minutes.", ) async def record_email_otp_send(db: AsyncSession, username: str) -> None: """Record a successful OTP email send for rate-limiting purposes (I1). Must be called *after* the email has been sent successfully so that failed sends do not consume a slot from the user's quota. """ db.add(AuthRateLimitEvent(username=username.lower(), event_type=EventType.EMAIL_SEND)) await db.commit() # --------------------------------------------------------------------------- # TOTP replay-protection helper # --------------------------------------------------------------------------- def _assert_totp_not_replayed(totp_obj: pyotp.TOTP, totp_record: UserTOTP, code: str) -> None: """Raise HTTP 400 if this TOTP code was already accepted in its time window. M3 fix: store the counter of the *accepted* code rather than the current wall-clock counter. With valid_window=1, pyotp accepts codes from the previous 30-second step. Using timecode(now) would store the wrong counter when the previous-window code is accepted, allowing immediate replay. """ # Determine which time-step the accepted code belongs to. now = datetime.now(timezone.utc) accepted_counter: int | None = None for offset in (0, -1): # current window first, then previous candidate_time = now.timestamp() + offset * totp_obj.interval candidate_counter = totp_obj.timecode(datetime.fromtimestamp(candidate_time, tz=timezone.utc)) if totp_obj.at(candidate_counter) == code: accepted_counter = candidate_counter break if accepted_counter is None: accepted_counter = totp_obj.timecode(now) # fallback (should not happen after verify()) totp_record.accept_counter(accepted_counter) # --------------------------------------------------------------------------- # OIDC helpers # --------------------------------------------------------------------------- _EMAIL_SHAPE_RE = re.compile(r"[^\s@]+@[^\s@]+\.[^\s@]+") def _is_valid_email_shaped(value: str | None) -> bool: # SEC-2: shape check for non-standard claims (upn, preferred_username). # Requires local@domain.tld — rejects "@", "x@", "@domain", "x@nodot". if not value or len(value) > 255: return False return _EMAIL_SHAPE_RE.fullmatch(value) is not None def _enforce_auto_link_safety(provider: OIDCProvider) -> None: """Raise HTTP 422 if auto_link_existing_accounts is on with an unsafe combined state. SEC-1: only Fall B (email_claim='email' + require_email_verified=False) is unsafe — an attacker-controlled IdP could present an unverified email that matches a local account. Fall C (custom claim) never performs an email_verified check, so auto_link is safe there. Called after ORM construction (create) and after the setattr loop (update). """ if provider.auto_link_existing_accounts and provider.email_claim == "email" and not provider.require_email_verified: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=AUTO_LINK_REQUIREMENTS_ERROR, ) def _resolve_provider_email(provider: OIDCProvider, claims: dict, provider_sub: str) -> str | None: """Extract and normalise the email address from OIDC ID-token claims. Implements three resolution paths (Fall A/B/C): Fall C — custom email_claim (!= "email"): shape-check only, no email_verified gate. Recommended for Azure Entra ID (preferred_username or upn). Fall A — email_claim="email" + require_email_verified=True: strict, email_verified must be True. Fall B — email_claim="email" + require_email_verified=False: permissive, explicit False drops email. Returns a lowercase-stripped email string, or None when the claim is absent/invalid. """ provider_id = provider.id raw_claim_value = claims.get(provider.email_claim) if raw_claim_value is not None and not isinstance(raw_claim_value, str): # TYPE-GUARD: non-string claim (e.g. list, int) would raise AttributeError on .lower(). logger.warning( "OIDC provider %d: email_claim %r has unexpected type %s for sub=%r, ignoring", provider_id, provider.email_claim, type(raw_claim_value).__name__, provider_sub, ) raw_claim_value = None raw_email: str | None = raw_claim_value.lower().strip() if raw_claim_value else None if provider.email_claim != "email": # Fall C: custom claim (preferred_username, upn, …) — no email_verified check. # SEC-2: _is_valid_email_shaped instead of bare '"@" in value'. # Recommended for Azure Entra ID: set email_claim="preferred_username" or "upn". if raw_email and _is_valid_email_shaped(raw_email): return raw_email if raw_email: logger.warning( "OIDC provider %d: email_claim %r value failed shape check for sub=%r, ignoring", provider_id, provider.email_claim, provider_sub, ) return None email_verified = claims.get("email_verified") if provider.require_email_verified: # Fall A: standard C1-Guard — fail closed unless email_verified is True. # SEC-2: apply shape check to standard email claim — providers may set # email_verified=True on non-email values (e.g. numeric user IDs). # SEC-3 normalisation applies; existing mixed-case provider_email records # were normalised to lowercase by run_migrations at startup. if raw_email and not _is_valid_email_shaped(raw_email): logger.warning( "OIDC provider %d: email claim failed shape check for sub=%r, ignoring", provider_id, provider_sub, ) return None if email_verified is True: return raw_email if raw_email: logger.info( "OIDC provider %d: ignoring email for sub=%r because email_verified=%r", provider_id, provider_sub, email_verified, ) return None # Fall B: permissive — explicit False drops email, absent/None keeps it. # Required for Azure Entra ID which never sends email_verified. # SEC-2: apply shape check before the email_verified=False drop so malformed # values are rejected regardless of the email_verified claim. if raw_email and not _is_valid_email_shaped(raw_email): logger.warning( "OIDC provider %d: email claim failed shape check for sub=%r, ignoring", provider_id, provider_sub, ) return None if email_verified is False: return None if email_verified is not True: # SEC-5: log only when the permissive path actually fires (ev absent/None), # not on every successful login. logger.info( "OIDC provider %r (%d): accepting email for sub=%r without email_verified claim (permissive mode)", provider.name, provider.id, provider_sub, ) return raw_email # --------------------------------------------------------------------------- # Settings helpers (email 2FA flag) # --------------------------------------------------------------------------- async def _get_email_2fa_enabled(db: AsyncSession, user_id: int) -> bool: val = await get_setting(db, f"user_{user_id}_email_2fa_enabled") return val == "true" async def _set_email_2fa_enabled(db: AsyncSession, user_id: int, enabled: bool) -> None: await set_setting(db, f"user_{user_id}_email_2fa_enabled", "true" if enabled else "false") # =========================================================================== # 2FA Endpoints # =========================================================================== @router.get("/2fa/status", response_model=TwoFAStatusResponse) async def get_2fa_status( current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> TwoFAStatusResponse: """Return the current 2FA configuration for the authenticated user.""" result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id)) totp_record = result.scalar_one_or_none() totp_enabled = totp_record is not None and totp_record.is_enabled backup_codes_remaining = len(totp_record.backup_code_hashes) if totp_record else 0 email_otp_enabled = await _get_email_2fa_enabled(db, current_user.id) return TwoFAStatusResponse( totp_enabled=totp_enabled, email_otp_enabled=email_otp_enabled, backup_codes_remaining=backup_codes_remaining, ) @router.post("/2fa/totp/setup", response_model=TOTPSetupResponse) async def setup_totp( body: TOTPSetupRequest | None = Body(default=None), current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> TOTPSetupResponse: """Initiate TOTP setup: generates a new secret and QR code. Creates (or replaces) a pending UserTOTP record with is_enabled=False. The caller must confirm with POST /auth/2fa/totp/enable. M-R7-A: If an *active* TOTP is already configured, the caller must supply the current TOTP code in the request body to confirm intent before the secret is overwritten (prevents silently locking out the real user). """ if not await is_auth_enabled(db): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Authentication is not enabled") # Upsert a pending TOTP record (is_enabled=False) existing = (await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))).scalar_one_or_none() # M-R7-A: Guard against silent TOTP replacement when one is already active. if existing and existing.is_enabled: await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) supplied_code = (body.code if body else None) or "" # S4: narrow the RuntimeError catch to ONLY the property access — that # is the single line that raises on key-loss. The previous wide try # block also covered record_failed_attempt, clear_failed_attempts, # and _assert_totp_not_replayed, so a future RuntimeError from any # of those would have been misreported as "TOTP secret unavailable". try: secret_plain = existing.secret except RuntimeError: logger.exception("TOTP decryption failed for user_id=%s", current_user.id) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="TOTP secret unavailable", ) if not pyotp.TOTP(secret_plain).verify(supplied_code, valid_window=1): await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Current TOTP code required to replace an active authenticator", ) await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) _assert_totp_not_replayed(pyotp.TOTP(secret_plain), existing, supplied_code) await db.flush() # L-3: persist last_totp_counter immediately to block replay secret = pyotp.random_base32() totp = pyotp.TOTP(secret) provisioning_uri = totp.provisioning_uri(name=current_user.username, issuer_name="Bambuddy") qr_b64 = _generate_totp_qr_b64(provisioning_uri) if existing: existing.secret = secret existing.is_enabled = False existing.backup_code_hashes = [] else: db.add(UserTOTP(user_id=current_user.id, secret=secret, is_enabled=False)) await db.commit() return TOTPSetupResponse(secret=secret, qr_code_b64=qr_b64, issuer="Bambuddy") @router.post("/2fa/totp/enable", response_model=TOTPEnableResponse) async def enable_totp( body: TOTPEnableRequest, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> TOTPEnableResponse: """Confirm TOTP setup by verifying a code from the authenticator app. On success, enables TOTP and returns 10 single-use backup codes (shown once). L-R7-A: Rate-limited to prevent brute-forcing the 6-digit confirmation code. """ # L-R7-A: Rate-limit the enable step to prevent brute-forcing the 6-digit code. await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id)) totp_record = result.scalar_one_or_none() if not totp_record: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP setup not initiated. Call /auth/2fa/totp/setup first." ) try: totp_verify = pyotp.TOTP(totp_record.secret).verify(body.code, valid_window=1) except RuntimeError: logger.exception("TOTP decryption failed for user_id=%s", totp_record.user_id) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="TOTP secret unavailable") if not totp_verify: await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP code") await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) plain_codes, hashed_codes = _generate_backup_codes() totp_record.is_enabled = True totp_record.backup_code_hashes = hashed_codes await db.commit() return TOTPEnableResponse( message="TOTP enabled successfully. Store your backup codes in a safe place.", backup_codes=plain_codes, ) @router.post("/2fa/totp/disable") async def disable_totp( body: TOTPDisableRequest, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> dict: """Disable TOTP by verifying a valid TOTP code or a backup code. I10: Rate-limited to prevent backup-code brute-forcing from a hijacked session. """ await check_rate_limit(db, current_user.username) result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id)) totp_record = result.scalar_one_or_none() if not totp_record or not totp_record.is_enabled: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled") # Accept either a valid TOTP code or a valid backup code. When the secret # cannot be decrypted (encryption key lost), fall through to the backup- # code path so the user can still disable 2FA with their printed codes. totp_obj: pyotp.TOTP | None = None code_valid = False decryption_failed = False try: totp_obj = pyotp.TOTP(totp_record.secret) code_valid = totp_obj.verify(body.code, valid_window=1) except RuntimeError: # S3: track that the failure was server-side so we don't penalise # the user with a fail-counter increment for a problem they can't fix. decryption_failed = True logger.exception( "TOTP decryption failed for user_id=%s — falling through to backup-code check", totp_record.user_id, ) if code_valid and totp_obj is not None: _assert_totp_not_replayed(totp_obj, totp_record, body.code) await db.flush() # L-3: persist last_totp_counter immediately to block replay else: # Check backup codes — always iterate all entries (L-R9-A: no early break # to avoid timing oracle based on code position in the list). for hashed in totp_record.backup_code_hashes: if pwd_context.verify(body.code, hashed): code_valid = True if not code_valid: # S3: skip the fail-counter debit when the cause was a server-side # decryption failure (key loss / rotation). The user submitted a # wrong backup code on top of a broken TOTP, but locking them out # of the recovery path for an admin's mistake is not the right move. if not decryption_failed: await record_failed_attempt(db, current_user.username) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code") await db.execute(delete(UserTOTP).where(UserTOTP.user_id == current_user.id)) await db.commit() return {"message": "TOTP disabled"} @router.post("/2fa/totp/regenerate-backup-codes", response_model=BackupCodesResponse) async def regenerate_backup_codes( body: TOTPDisableRequest, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> BackupCodesResponse: """Generate 10 new backup codes. Requires a valid TOTP code OR a backup code. M10: Accepts backup codes for consistency with disable_totp — users who have lost their authenticator app but still have backup codes can regenerate. Rate-limited to prevent brute-forcing from a hijacked session. """ await check_rate_limit(db, current_user.username) result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id)) totp_record = result.scalar_one_or_none() if not totp_record or not totp_record.is_enabled: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled") # Same recovery contract as disable_totp: when the TOTP secret cannot be # decrypted, fall through to the backup-code branch so the user can # rotate their codes with a printed backup code. totp_obj: pyotp.TOTP | None = None code_valid = False decryption_failed = False try: totp_obj = pyotp.TOTP(totp_record.secret) code_valid = totp_obj.verify(body.code, valid_window=1) except RuntimeError: # S3: track server-side failure so we skip the fail-counter debit. decryption_failed = True logger.exception( "TOTP decryption failed for user_id=%s — falling through to backup-code check", totp_record.user_id, ) if code_valid and totp_obj is not None: _assert_totp_not_replayed(totp_obj, totp_record, body.code) await db.flush() # L-3: persist last_totp_counter immediately to block replay else: # Accept a backup code as an alternative (M10) matched_index: int | None = None for idx, hashed in enumerate(totp_record.backup_code_hashes): if pwd_context.verify(body.code, hashed) and matched_index is None: matched_index = idx if matched_index is None: # S3: skip fail-counter debit when the cause was a server-side # decryption failure (key loss / rotation). if not decryption_failed: await record_failed_attempt(db, current_user.username) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP or backup code") # Remove the used backup code totp_record.backup_code_hashes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index] plain_codes, hashed_codes = _generate_backup_codes() totp_record.backup_code_hashes = hashed_codes await db.commit() return BackupCodesResponse( backup_codes=plain_codes, message="Backup codes regenerated. Store them safely — they will not be shown again.", ) @router.post("/2fa/email/enable") async def enable_email_otp( current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> dict: """Step 1 of email OTP enable: send a verification code to the user's email. C5: Proof of possession — the user must prove they control the registered email address before email 2FA is activated. Returns a ``setup_token`` that must be passed to POST /auth/2fa/email/enable/confirm together with the received code. H-3: Rate-limited to prevent email flooding via repeated calls to this endpoint. """ await check_email_otp_send_rate(db, current_user.username) if not current_user.email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="You must have an email address configured to enable email OTP 2FA", ) smtp_settings = await get_smtp_settings(db) if not smtp_settings: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured") # Generate and store the setup token (reuse AuthEphemeralToken with type "email_otp_setup") now = datetime.now(timezone.utc) # Prune any existing pending setup tokens for this user await db.execute( delete(AuthEphemeralToken).where( AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP, AuthEphemeralToken.username == current_user.username, ) ) code = str(secrets.randbelow(1_000_000)).zfill(6) code_hash = pwd_context.hash(code) setup_token = secrets.token_urlsafe(32) db.add( AuthEphemeralToken( token=setup_token, token_type=TokenType.EMAIL_OTP_SETUP, username=current_user.username, # Reuse the nonce field to store the code hash nonce=code_hash, expires_at=now + timedelta(minutes=10), ) ) await db.commit() try: send_email( smtp_settings=smtp_settings, to_email=current_user.email, subject="Verify your Bambuddy email address for 2FA", body_text=( f"Your Bambuddy email 2FA setup code is: {code}\n\n" "Enter this code to confirm email-based two-factor authentication.\n" "The code expires in 10 minutes." ), body_html=( "

To enable email-based two-factor authentication on your Bambuddy account, " "enter the code below:

" f"

{code}

" "

The code expires in 10 minutes. " "If you did not request this, you can safely ignore this email.

" ), ) await record_email_otp_send(db, current_user.username) except Exception as exc: logger.error("Failed to send email OTP setup code to user_id=%d: %s", current_user.id, exc) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send verification email" ) return {"message": "Verification code sent to your email address", "setup_token": setup_token} @router.post("/2fa/email/enable/confirm") async def confirm_enable_email_otp( body: EmailOTPEnableConfirmRequest, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> dict: """Step 2 of email OTP enable: verify the code and activate email 2FA. H-2 fix: Uses peek-then-consume so a wrong code does NOT burn the setup token. The token is only deleted after successful code verification, allowing retries up to the rate limit (5 attempts / 15 min). M4: Rate-limited to prevent brute-forcing the 6-digit setup code. """ await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) now = datetime.now(timezone.utc) # --- Peek: validate token without consuming --- peek_result = await db.execute( select(AuthEphemeralToken).where( AuthEphemeralToken.token == body.setup_token, AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP, AuthEphemeralToken.username == current_user.username, AuthEphemeralToken.expires_at > now, ) ) eph = peek_result.scalar_one_or_none() if eph is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token") code_hash = eph.nonce # code hash stored in the nonce field # --- Verify code before consuming the token --- if not pwd_context.verify(body.code, code_hash): await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code") # --- Atomically consume the token now that the code is correct --- # DELETE...RETURNING prevents a concurrent request from using the same token. del_result = await db.execute( delete(AuthEphemeralToken) .where( AuthEphemeralToken.token == body.setup_token, AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP, AuthEphemeralToken.username == current_user.username, ) .returning(AuthEphemeralToken.id) ) if del_result.one_or_none() is None: # Concurrent request consumed it between peek and delete — treat as invalid. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token") await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT) await _set_email_2fa_enabled(db, current_user.id, True) await db.commit() return {"message": "Email OTP 2FA enabled"} @router.post("/2fa/email/disable") async def disable_email_otp( body: EmailOTPDisableRequest, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> dict: """Disable email-based OTP 2FA for the current user. C6: Re-authentication required — the caller must supply their account password to prevent a hijacked session from silently removing a second factor. LDAP/OIDC-only users (no local password) are exempt from this check. H-2: Rate-limited to prevent brute-forcing the password via this endpoint. """ await check_rate_limit(db, current_user.username) if current_user.password_hash: if not verify_password(body.password, current_user.password_hash): await record_failed_attempt(db, current_user.username) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") await _set_email_2fa_enabled(db, current_user.id, False) await db.commit() return {"message": "Email OTP 2FA disabled"} @router.post("/2fa/email/send") async def send_email_otp( request: Request, body: EmailOTPSendRequest, db: AsyncSession = Depends(get_db), ) -> dict: """Send a 6-digit OTP code to the user's email address. Requires a valid pre_auth_token obtained during the login flow. """ # Peek (validate without consuming) first so a rate-limit rejection does not # permanently burn the caller's pre-auth token. challenge_id = request.cookies.get("2fa_challenge") username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id) if not username: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token") # Enforce rate limit BEFORE consuming the token to prevent OTP email flooding. await check_email_otp_send_rate(db, username) user = await get_user_by_username(db, username) if not user or not user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") if not user.email: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no email address configured") smtp_settings = await get_smtp_settings(db) if not smtp_settings: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured") # Invalidate all existing unused OTP codes for this user (staged, not yet committed) await db.execute( UserOTPCode.__table__.update() # type: ignore[attr-defined] .where(UserOTPCode.user_id == user.id) .where(UserOTPCode.used.is_(False)) .values(used=True) ) # Generate a 6-digit code and stage the record (not committed yet) code = str(secrets.randbelow(1_000_000)).zfill(6) code_hash = pwd_context.hash(code) expires_at = datetime.now(timezone.utc) + timedelta(minutes=UserOTPCode.OTP_TTL_MINUTES) otp_record = UserOTPCode( user_id=user.id, code_hash=code_hash, attempts=0, used=False, expires_at=expires_at, ) db.add(otp_record) # M2: Send the email BEFORE consuming the pre-auth token. # If the send fails we raise an exception here; the session is uncommitted so # the OTP record is discarded and the original token remains valid for retry. try: send_email( smtp_settings=smtp_settings, to_email=user.email, subject="Your Bambuddy verification code", body_text=f"Your Bambuddy login code is: {code}\n\nThis code expires in {UserOTPCode.OTP_TTL_MINUTES} minutes and can only be used once.", body_html=( f"

Your Bambuddy login verification code is:

" f"

{code}

" f"

This code expires in {UserOTPCode.OTP_TTL_MINUTES} minutes and can only be used once.

" f"

If you did not request this code, you can safely ignore this email.

" ), ) await record_email_otp_send(db, username) except Exception as exc: logger.error("Failed to send OTP email to user_id=%d: %s", user.id, exc) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send OTP email") # Email sent — now atomically consume the old token (this also commits the # staged OTP record) and issue a fresh token for the verify step. consumed = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id) if not consumed: # Raced with another request or token just expired — treat as invalid. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token") # Re-issue a fresh pre-auth token bound to the same cookie so the binding # carries forward through the email → verify step. fresh_token = await create_pre_auth_token(db, username, challenge_id=challenge_id) # Return the fresh pre-auth token so the frontend can proceed to verify return {"message": "Code sent to your email address", "pre_auth_token": fresh_token} @router.post("/2fa/verify", response_model=TwoFAVerifyResponse) async def verify_2fa( request: Request, body: TwoFAVerifyRequest, db: AsyncSession = Depends(get_db), ) -> TwoFAVerifyResponse: """Verify a 2FA code and exchange the pre_auth_token for a full JWT. Accepted methods: ``totp``, ``email``, ``backup``. The pre_auth_token is NOT consumed on failed verification attempts so the user can retry without restarting the login flow. It is only consumed once verification succeeds, preventing token replay after success. """ # Peek without consuming — bad codes must not burn the session token. # Pass the HttpOnly challenge cookie so the binding check is enforced. challenge_id = request.cookies.get("2fa_challenge") username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id) if not username: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token") await check_rate_limit(db, username) user = await get_user_by_username(db, username) if not user or not user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") method = body.method if method == "totp": result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id)) totp_record = result.scalar_one_or_none() if not totp_record or not totp_record.is_enabled: await record_failed_attempt(db, username) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user") try: totp_obj = pyotp.TOTP(totp_record.secret) except RuntimeError: logger.exception("TOTP decryption failed for user_id=%s", totp_record.user_id) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="TOTP secret unavailable") if not totp_obj.verify(body.code, valid_window=1): await record_failed_attempt(db, username) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid TOTP code") _assert_totp_not_replayed(totp_obj, totp_record, body.code) await db.flush() # L-3: persist last_totp_counter immediately to block replay elif method == "email": now = datetime.now(timezone.utc) result = await db.execute( select(UserOTPCode) .where(UserOTPCode.user_id == user.id) .where(UserOTPCode.used.is_(False)) .where(UserOTPCode.expires_at > now) .order_by(UserOTPCode.created_at.desc()) ) otp_record = result.scalar_one_or_none() if not otp_record: await record_failed_attempt(db, username) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="No valid OTP code found. Request a new one." ) if otp_record.attempts >= UserOTPCode.MAX_ATTEMPTS: otp_record.consume() await db.commit() await record_failed_attempt(db, username) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="OTP code has been invalidated after too many attempts" ) if not pwd_context.verify(body.code, otp_record.code_hash): otp_record.attempts += 1 await db.commit() await record_failed_attempt(db, username) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid OTP code") otp_record.consume() await db.commit() else: # method == "backup" result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id)) totp_record = result.scalar_one_or_none() if not totp_record or not totp_record.is_enabled: await record_failed_attempt(db, username) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user") # Always iterate all codes — no early break (L-R9-A: constant iteration # count prevents timing oracle based on used-code position in the list). matched_index: int | None = None for idx, hashed in enumerate(totp_record.backup_code_hashes): if pwd_context.verify(body.code, hashed) and matched_index is None: matched_index = idx if matched_index is None: await record_failed_attempt(db, username) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid backup code") # M1: Consume the pre-auth token FIRST (atomic single-use enforcement). # Only if that succeeds do we remove the backup code — this prevents a race # where two concurrent requests both pass code verification but only one # should be granted a session. consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id) if not consumed_username: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token") # Remove the used backup code now that the token is atomically consumed. updated_codes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index] totp_record.backup_code_hashes = updated_codes await db.commit() await clear_failed_attempts(db, username) access_token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), ) result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups))) user = result.scalar_one() return TwoFAVerifyResponse(access_token=access_token, token_type="bearer", user=_user_to_response(user)) # Verification succeeded (TOTP or email) — consume the pre-auth token. # C-1: Check the return value; if None the token was already consumed by a # concurrent request (race condition) — reject to prevent double-use. consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id) if not consumed_username: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token") await clear_failed_attempts(db, username) access_token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), ) # Reload with groups for permission calculation result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups))) user = result.scalar_one() return TwoFAVerifyResponse( access_token=access_token, token_type="bearer", user=_user_to_response(user), ) @router.delete("/2fa/admin/{user_id}") async def admin_disable_2fa( user_id: int, body: AdminDisable2FARequest = Body(default_factory=AdminDisable2FARequest), current_user: User | None = RequirePermissionIfAuthEnabled(Permission.USERS_UPDATE), db: AsyncSession = Depends(get_db), ) -> dict: """Admin endpoint: disable all 2FA for a given user. Nit 3: Requires the admin's own password as a re-auth step (matching how disable_email_otp protects a user's own 2FA removal). OIDC/LDAP-only admins (no local password_hash) are exempt. """ # Nit 3: Re-auth — admin must supply their own password. if current_user and current_user.password_hash: if not body.admin_password or not verify_password(body.admin_password, current_user.password_hash): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin password required") # Delete TOTP record await db.execute(delete(UserTOTP).where(UserTOTP.user_id == user_id)) # Disable email 2FA setting await _set_email_2fa_enabled(db, user_id, False) # Invalidate all OTP codes await db.execute( UserOTPCode.__table__.update() # type: ignore[attr-defined] .where(UserOTPCode.user_id == user_id) .values(used=True) ) # I2: Invalidate existing JWTs for the target user by bumping password_changed_at. # Without this, a stolen token remains valid after 2FA removal. target_user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none() if target_user: target_user.password_changed_at = datetime.now(timezone.utc) await db.commit() actor = current_user.username if current_user else "anonymous" logger.info("Admin %s disabled all 2FA for user_id=%d", actor, user_id) return {"message": "2FA disabled for user"} # =========================================================================== # OIDC Endpoints # =========================================================================== @router.get("/oidc/providers", response_model=list[OIDCProviderResponse]) async def list_oidc_providers( db: AsyncSession = Depends(get_db), ) -> list[OIDCProviderResponse]: """List all enabled OIDC providers (public). The login page renders icons via /oidc/providers/{id}/icon — `icon_data` stays deferred so this list query never pulls the BLOB. """ result = await db.execute(select(OIDCProvider).where(OIDCProvider.is_enabled.is_(True))) providers = result.scalars().all() return [_build_provider_response(p) for p in providers] @router.get("/oidc/providers/all", response_model=list[OIDCProviderResponse]) async def list_all_oidc_providers( _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ), db: AsyncSession = Depends(get_db), ) -> list[OIDCProviderResponse]: """List ALL OIDC providers including disabled ones (admin only).""" result2 = await db.execute(select(OIDCProvider)) providers = result2.scalars().all() return [_build_provider_response(p) for p in providers] @router.post("/oidc/providers", response_model=OIDCProviderResponse, status_code=status.HTTP_201_CREATED) async def create_oidc_provider( body: OIDCProviderCreate, _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE), db: AsyncSession = Depends(get_db), ) -> OIDCProviderResponse: """Create a new OIDC provider (admin only). If `icon_url` is supplied, the icon is fetched server-side and cached in the BLOB columns (#1333). A fetch failure aborts the create with 400 — no half-configured provider is left in the DB. """ if body.default_group_id is not None: grp_chk = await db.execute(select(Group).where(Group.id == body.default_group_id)) if not grp_chk.scalar_one_or_none(): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="default_group_id references a non-existent group", ) # Fetch the icon BEFORE creating the row so a failure leaves the DB clean. icon_data: bytes | None = None icon_content_type: str | None = None icon_etag: str | None = None if body.icon_url: icon_data, icon_content_type, icon_etag = await _fetch_icon_or_400(body.icon_url) provider = OIDCProvider( name=body.name, issuer_url=body.issuer_url.rstrip("/"), client_id=body.client_id, client_secret=body.client_secret, scopes=body.scopes, is_enabled=body.is_enabled, auto_create_users=body.auto_create_users, auto_link_existing_accounts=body.auto_link_existing_accounts, email_claim=body.email_claim, require_email_verified=body.require_email_verified, icon_url=body.icon_url, icon_data=icon_data, icon_content_type=icon_content_type, icon_etag=icon_etag, default_group_id=body.default_group_id, ) # SEC-1 + SEC-6: runtime guard mirrors the OIDCProviderCreate model_validator in schemas/auth.py. # Catches any future path that bypasses Pydantic validation (direct ORM, scripts). _enforce_auto_link_safety(provider) db.add(provider) await db.commit() await db.refresh(provider) return _build_provider_response(provider) @router.put("/oidc/providers/{provider_id}", response_model=OIDCProviderResponse) async def update_oidc_provider( provider_id: int, body: OIDCProviderUpdate, _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE), db: AsyncSession = Depends(get_db), ) -> OIDCProviderResponse: """Update an existing OIDC provider (admin only). Icon refetch fires when: 1. The submitted `icon_url` differs from the stored one (URL changed), OR 2. The submitted `icon_url` equals the stored one AND `icon_content_type` is NULL — this is the upgrade-path edge case: old providers carry `icon_url` but no cached bytes until the admin first saves them. On fetch failure the request aborts with 400 *before* commit, so the existing cached bytes (if any) remain untouched. """ result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id)) provider = result2.scalar_one_or_none() if not provider: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found") if body.default_group_id is not None: grp_chk = await db.execute(select(Group).where(Group.id == body.default_group_id)) if not grp_chk.scalar_one_or_none(): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="default_group_id references a non-existent group", ) dumped = body.model_dump(exclude_none=True) # Decide whether an icon refetch is needed BEFORE mutating the ORM object, # so the comparison sees provider.icon_url / icon_content_type as they are # in the database. new_icon_url = dumped.get("icon_url") needs_icon_refetch = new_icon_url is not None and ( new_icon_url != provider.icon_url or provider.icon_content_type is None ) # Fetch FIRST. If the upstream is unreachable or SSRF-blocked, _fetch_icon_or_400 # raises HTTPException(400) here — provider attributes are still untouched, so # the in-memory ORM object stays consistent on the way out (and the DB row is # safe regardless via get_db()'s rollback). fetched_icon: tuple[bytes, str, str] | None = None if needs_icon_refetch: fetched_icon = await _fetch_icon_or_400(new_icon_url) # Explicit `icon_url: null` in the PUT body means "clear the icon". # The exclude_none=True dump above drops None values, which would # otherwise silently ignore this request. Check model_fields_set on # the unfiltered body to distinguish "client cleared it" from "client # didn't include this field at all". if "icon_url" in body.model_fields_set and body.icon_url is None: provider.icon_url = None provider.icon_data = None provider.icon_content_type = None provider.icon_etag = None for field, value in dumped.items(): if field == "issuer_url" and value: value = value.rstrip("/") setattr(provider, field, value) if fetched_icon is not None: provider.icon_data, provider.icon_content_type, provider.icon_etag = fetched_icon # SEC-1 + SEC-6: Combined-State-Guard after setattr loop. # Checks the final in-memory state (DB values + newly set values combined) to catch # partial updates that each pass schema validation individually but are unsafe together. _enforce_auto_link_safety(provider) await db.commit() await db.refresh(provider) return _build_provider_response(provider) @router.delete("/oidc/providers/{provider_id}") async def delete_oidc_provider( provider_id: int, _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE), db: AsyncSession = Depends(get_db), ) -> dict: """Delete an OIDC provider and all its user links (admin only).""" result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id)) provider = result2.scalar_one_or_none() if not provider: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found") await db.delete(provider) await db.commit() return {"message": "Provider deleted"} # --------------------------------------------------------------------------- # OIDC provider icon proxy (#1333) # --------------------------------------------------------------------------- @router.get("/oidc/providers/{provider_id}/icon") async def get_oidc_provider_icon( provider_id: int, if_none_match: str | None = Header(default=None, alias="If-None-Match"), db: AsyncSession = Depends(get_db), ) -> Response: """Serve the cached icon for an enabled OIDC provider (public, no auth). Unauthenticated because ```` tags cannot send Authorization headers and the login page renders these icons before the user is signed in — the same justification as ``/api/v1/makerworld/thumbnail``. The SSRF guard runs at admin-config time (create/update/refresh), not here. Disabled providers respond 404 to avoid leaking their existence to anonymous callers (mirrors ``GET /oidc/providers`` which filters on ``is_enabled``). """ result = await db.execute( select(OIDCProvider) .options(undefer(OIDCProvider.icon_data)) .where(OIDCProvider.id == provider_id, OIDCProvider.is_enabled.is_(True)) ) provider = result.scalar_one_or_none() if provider is None or provider.icon_content_type is None or provider.icon_data is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Icon not found") etag_value = f'"{provider.icon_etag}"' cache_headers = {"ETag": etag_value, "Cache-Control": "public, max-age=3600"} if _etag_matches(if_none_match, provider.icon_etag): return Response(status_code=status.HTTP_304_NOT_MODIFIED, headers=cache_headers) return Response( content=provider.icon_data, media_type=provider.icon_content_type, headers=cache_headers, ) @router.delete("/oidc/providers/{provider_id}/icon", status_code=status.HTTP_204_NO_CONTENT) async def delete_oidc_provider_icon( provider_id: int, _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE), db: AsyncSession = Depends(get_db), ) -> Response: """Remove the icon entirely for a provider (admin only). Clears all four icon columns — ``icon_url`` plus the three cached-bytes columns. "Remove icon" means the whole record is gone, not just the cache; without this the admin form would still show the URL while the login page rendered a blank fallback (confusing half-state). To re-add an icon the admin re-types the URL in the edit form. """ result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id)) provider = result.scalar_one_or_none() if provider is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found") # Setting deferred columns is safe — no read happens, just a write. provider.icon_url = None provider.icon_data = None provider.icon_content_type = None provider.icon_etag = None await db.commit() return Response(status_code=status.HTTP_204_NO_CONTENT) @router.post("/oidc/providers/{provider_id}/icon/refresh", response_model=OIDCProviderResponse) async def refresh_oidc_provider_icon( provider_id: int, _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE), db: AsyncSession = Depends(get_db), ) -> OIDCProviderResponse: """Refetch the icon from the stored `icon_url` (admin only). Used when: - The IdP changed its icon and the admin wants Bambuddy to pick up the new bytes. - An upgrade left the provider with an `icon_url` but no cached bytes (covered automatically by `update_oidc_provider` too, but this gives the UI an explicit "Refresh" button). Failure to refetch returns 400 *before* commit, so the previously cached bytes survive intact. """ result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id)) provider = result.scalar_one_or_none() if provider is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found") if not provider.icon_url: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Provider has no icon_url to refresh", ) icon_data, icon_content_type, icon_etag = await _fetch_icon_or_400(provider.icon_url) provider.icon_data = icon_data provider.icon_content_type = icon_content_type provider.icon_etag = icon_etag await db.commit() await db.refresh(provider) return _build_provider_response(provider) @router.get("/oidc/authorize/{provider_id}", response_model=OIDCAuthorizeResponse) async def oidc_authorize( provider_id: int, db: AsyncSession = Depends(get_db), ) -> OIDCAuthorizeResponse: """Return the OIDC authorization URL for the given provider.""" result = await db.execute( select(OIDCProvider).where(OIDCProvider.id == provider_id).where(OIDCProvider.is_enabled.is_(True)) ) provider = result.scalar_one_or_none() if not provider: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found or not enabled") # Fetch discovery document discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration" try: async with httpx.AsyncClient(timeout=10) as client: resp = await client.get(discovery_url) resp.raise_for_status() discovery = resp.json() except Exception as exc: logger.error("Failed to fetch OIDC discovery for provider %d: %s", provider_id, exc) raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch OIDC discovery document") authorization_endpoint = discovery.get("authorization_endpoint") if not authorization_endpoint: raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document missing authorization_endpoint" ) # B2: SSRF guard — reject non-HTTP(S) schemes in the authorization endpoint if not authorization_endpoint.startswith(("https://", "http://")): logger.warning("OIDC discovery authorization_endpoint has invalid scheme: %s", authorization_endpoint) raise HTTPException( status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document contains invalid authorization_endpoint", ) external_url = await _get_base_external_url(db) redirect_uri = f"{external_url}/api/v1/auth/oidc/callback" now = datetime.now(timezone.utc) # Prune expired OIDC states from the DB await db.execute( delete(AuthEphemeralToken).where( AuthEphemeralToken.token_type == TokenType.OIDC_STATE, AuthEphemeralToken.expires_at < now, ) ) state = secrets.token_urlsafe(32) nonce = secrets.token_urlsafe(32) # PKCE (S256) – required by PocketID and recommended for all OIDC flows code_verifier = secrets.token_urlsafe(48) # 64-char URL-safe string code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() db.add( AuthEphemeralToken( token=state, token_type=TokenType.OIDC_STATE, provider_id=provider_id, nonce=nonce, code_verifier=code_verifier, expires_at=now + OIDC_STATE_TTL, ) ) await db.commit() params = urllib.parse.urlencode( { "response_type": "code", "client_id": provider.client_id, "redirect_uri": redirect_uri, "scope": provider.scopes, "state": state, "nonce": nonce, "code_challenge": code_challenge, "code_challenge_method": "S256", } ) auth_url = f"{authorization_endpoint}?{params}" return OIDCAuthorizeResponse(auth_url=auth_url) @router.get("/oidc/callback") async def oidc_callback( code: str | None = Query(default=None, max_length=2048), state: str | None = Query(default=None, max_length=2048), error: str | None = Query(default=None, max_length=256), db: AsyncSession = Depends(get_db), ) -> RedirectResponse: """Handle the OIDC authorization code callback from the identity provider.""" external_url = await _get_base_external_url(db) frontend_error_url = f"{external_url}/?oidc_error=" try: if error: logger.warning("OIDC callback received error: %s", error) return RedirectResponse(url=f"{frontend_error_url}oidc_provider_error", status_code=302) if not code or not state: return RedirectResponse(url=f"{frontend_error_url}missing_parameters", status_code=302) # Atomically validate and consume OIDC state from DB (I6: single-use enforcement). # DELETE...RETURNING ensures concurrent callbacks with the same state token # cannot both succeed — only the first DELETE finds the row. now = datetime.now(timezone.utc) state_del = await db.execute( delete(AuthEphemeralToken) .where( AuthEphemeralToken.token == state, AuthEphemeralToken.token_type == TokenType.OIDC_STATE, AuthEphemeralToken.expires_at > now, # reject expired tokens atomically ) .returning( AuthEphemeralToken.provider_id, AuthEphemeralToken.nonce, AuthEphemeralToken.code_verifier, ) ) state_row = state_del.one_or_none() if state_row is None: await db.rollback() return RedirectResponse(url=f"{frontend_error_url}invalid_state", status_code=302) provider_id, nonce, code_verifier = state_row await db.commit() # Load provider result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id)) provider = result.scalar_one_or_none() if not provider: return RedirectResponse(url=f"{frontend_error_url}provider_not_found", status_code=302) redirect_uri = f"{external_url}/api/v1/auth/oidc/callback" # ── Step 1: Fetch discovery document ──────────────────────────────── discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration" try: async with httpx.AsyncClient(timeout=10) as client: disc_resp = await client.get(discovery_url) disc_resp.raise_for_status() discovery = disc_resp.json() except Exception as exc: logger.error("OIDC discovery fetch failed for provider %d: %s", provider_id, exc) return RedirectResponse(url=f"{frontend_error_url}discovery_failed", status_code=302) token_endpoint = discovery.get("token_endpoint") jwks_uri = discovery.get("jwks_uri") if not token_endpoint or not jwks_uri: return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302) # L-R7-C: Reject non-HTTP(S) URLs in the discovery document to prevent # SSRF via crafted responses (e.g. file://, gopher://, internal schemes). if not token_endpoint.startswith(("https://", "http://")) or not jwks_uri.startswith(("https://", "http://")): logger.warning( "OIDC discovery document contains non-HTTP URL(s): token=%s jwks=%s", token_endpoint, jwks_uri ) return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302) # ── Step 2: Exchange authorization code for tokens ─────────────────── token_form: dict[str, str] = { "grant_type": "authorization_code", "code": code, "redirect_uri": redirect_uri, "client_id": provider.client_id, } if provider.client_secret: token_form["client_secret"] = provider.client_secret if code_verifier: token_form["code_verifier"] = code_verifier try: async with httpx.AsyncClient(timeout=15) as client: token_resp = await client.post( token_endpoint, data=token_form, headers={"Accept": "application/json"}, ) except Exception as exc: logger.error("OIDC token exchange request failed for provider %d: %s", provider_id, exc) return RedirectResponse(url=f"{frontend_error_url}token_exchange_network_error", status_code=302) if not token_resp.is_success: try: err_body = token_resp.json() oidc_err = err_body.get("error", "") oidc_desc = err_body.get("error_description", "") except Exception: oidc_err = "" oidc_desc = token_resp.text[:200] logger.error( "OIDC token exchange HTTP %d for provider %d. redirect_uri=%r error=%r desc=%r", token_resp.status_code, provider_id, redirect_uri, oidc_err, oidc_desc, ) # Encode the OIDC error code into the redirect so the user sees it in the toast. # URL-encode the value to prevent query-parameter injection from provider responses. raw_err = oidc_err[:40] if oidc_err else str(token_resp.status_code) safe_err = urllib.parse.quote(raw_err, safe="") return RedirectResponse( url=f"{frontend_error_url}token_exchange_{safe_err}", status_code=302, ) try: token_data = token_resp.json() except Exception as exc: logger.error("OIDC token exchange non-JSON response for provider %d: %s", provider_id, exc) return RedirectResponse(url=f"{frontend_error_url}token_exchange_bad_response", status_code=302) id_token = token_data.get("id_token") if not id_token: # Only log the keys present — values may contain secrets (access_token, etc.) logger.error( "OIDC token response missing id_token for provider %d; keys present: %s", provider_id, list(token_data.keys()), ) return RedirectResponse(url=f"{frontend_error_url}no_id_token", status_code=302) # ── Step 3: Fetch JWKS and validate ID token ───────────────────────── # Use the issuer from the discovery document as the canonical value (OIDC Core # §3.1.3.7 requires iss == discovery issuer exactly). We strip trailing slashes # from both sides because some providers (e.g. Authentik, older PocketID versions) # are inconsistent between the discovery issuer and the JWT iss claim. discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/") try: async with httpx.AsyncClient(timeout=10) as jwks_http: jwks_resp = await jwks_http.get(jwks_uri) jwks_resp.raise_for_status() jwks_data = jwks_resp.json() jwks_client = PyJWKClient(jwks_uri) jwks_client.fetch_data = lambda: jwks_data # type: ignore[method-assign] signing_key = jwks_client.get_signing_key_from_jwt(id_token) # M-3: Decode without built-in issuer check, then compare normalised # (both sides rstrip("/")) to handle providers like Authentik that include # a trailing slash in iss but not in the discovery issuer, or vice-versa. claims = jwt.decode( id_token, signing_key.key, algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"], audience=provider.client_id, options={"verify_iss": False}, ) token_iss = claims.get("iss", "").rstrip("/") if token_iss != discovery_issuer: raise jwt.exceptions.InvalidIssuerError("Invalid issuer") except Exception as exc: logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True) return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302) # Verify nonce — fail closed: we always send a nonce, so the provider must echo it. # Skipping the check when nonce is absent would allow CSRF on non-nonce providers. token_nonce = claims.get("nonce") if token_nonce is None or token_nonce != nonce: logger.warning("OIDC nonce mismatch for provider %d (present=%r)", provider_id, token_nonce is not None) return RedirectResponse(url=f"{frontend_error_url}nonce_mismatch", status_code=302) provider_sub: str = claims.get("sub", "") if not provider_sub: return RedirectResponse(url=f"{frontend_error_url}missing_sub_claim", status_code=302) # SEC-3: resolve email via Fall A/B/C logic (see _resolve_provider_email). provider_email = _resolve_provider_email(provider, claims, provider_sub) # ── Step 4: Resolve / create user ──────────────────────────────────── try: # 1. Look up existing OIDC link link_result = await db.execute( select(UserOIDCLink) .where(UserOIDCLink.provider_id == provider_id) .where(UserOIDCLink.provider_user_id == provider_sub) ) link = link_result.scalar_one_or_none() user: User | None = None if link: # Existing link → load the linked user user_result = await db.execute( select(User).where(User.id == link.user_id).options(selectinload(User.groups)) ) user = user_result.scalar_one_or_none() else: # 2. No OIDC link yet — check for an existing user with the same email. # Use case-insensitive matching (func.lower) so that "User@Example.com" # and "user@example.com" are treated as the same identity, preventing # an attacker-controlled IdP from bypassing the auto-link guard by # registering the target email with different casing. email_user: User | None = None if provider_email: email_user = await get_user_by_email(db, provider_email) if email_user and provider.auto_link_existing_accounts: # M-4: Only auto-link when the provider has auto_link_existing_accounts # enabled. Operators can disable this to require explicit account linking, # preventing an attacker-controlled IdP from hijacking local accounts. # # M-NEW-6: Refuse auto-link if the target user already has any OIDC # link (to any provider). Without this guard an attacker who controls # a second OIDC provider with auto_link enabled could add themselves as # a second IdP for a user that already authenticates via a legitimate # provider, effectively taking over the account. existing_links_result = await db.execute( select(UserOIDCLink).where(UserOIDCLink.user_id == email_user.id) ) has_existing_oidc_link = existing_links_result.scalar_one_or_none() is not None if has_existing_oidc_link: logger.warning( "Auto-link rejected for user '%s': already linked to another OIDC provider", email_user.username, ) return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302) db.add( UserOIDCLink( user_id=email_user.id, provider_id=provider_id, provider_user_id=provider_sub, provider_email=provider_email, ) ) await db.commit() user = email_user logger.info( "Auto-linked existing user '%s' to OIDC provider %d via email match", email_user.username, provider_id, ) elif provider.auto_create_users: # 3. No existing user — create one if provider_email: raw = provider_email.split("@")[0] else: # Prefer a human-readable IdP claim over the opaque sub. # isinstance guards are required: claims may carry non-string # values (e.g. a list) that would break .strip(). # Sanitization is applied per-candidate so that a value that # strips to empty (e.g. "!!!") correctly falls through to the # next candidate rather than silently becoming "oidcuser". _pref = claims.get("preferred_username") _name = claims.get("name") raw = "" if isinstance(_pref, str): raw = re.sub(r"[^a-zA-Z0-9._-]", "", _pref.strip())[:30] if not raw and isinstance(_name, str): raw = re.sub(r"[^a-zA-Z0-9._-]", "", _name.strip())[:30] if not raw: raw = provider_sub[:30] candidate = re.sub(r"[^a-zA-Z0-9._-]", "", raw)[:30] or "oidcuser" username = candidate counter = 1 while True: existing = await get_user_by_username(db, username) if not existing: break username = f"{candidate}{counter}" counter += 1 # I9: Assign new OIDC users to a group before flush — accessing # new_user.groups after a flush triggers a lazy-load which fails # in async context. Resolution order: # 1. provider.default_group_id (operator-configured) # 2. "Viewers" (system fallback for read-only access) # 3. no group (last resort if Viewers was deleted) # SQLite does not enforce ON DELETE SET NULL, so a dangling # default_group_id returns None here and falls through to Viewers. default_group: Group | None = None if provider.default_group_id is not None: dg_result = await db.execute(select(Group).where(Group.id == provider.default_group_id)) default_group = dg_result.scalar_one_or_none() if default_group is None: viewers_result = await db.execute(select(Group).where(Group.name == "Viewers")) default_group = viewers_result.scalar_one_or_none() new_user = User( username=username, email=provider_email, # M-1: auth_source="oidc" prevents local password-reset flow # for users who should only authenticate via OIDC. auth_source="oidc", password_hash=None, # OIDC users never use password auth role="user", is_active=True, groups=[default_group] if default_group else [], ) db.add(new_user) await db.flush() db.add( UserOIDCLink( user_id=new_user.id, provider_id=provider_id, provider_user_id=provider_sub, provider_email=provider_email, ) ) await db.commit() user_result = await db.execute( select(User).where(User.id == new_user.id).options(selectinload(User.groups)) ) user = user_result.scalar_one() logger.info("Auto-created user '%s' via OIDC provider %d", username, provider_id) else: return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302) if not user or not user.is_active: return RedirectResponse(url=f"{frontend_error_url}account_inactive", status_code=302) # Issue an OIDC exchange token (short-lived, single-use) stored in DB. # I7: Opportunistically prune expired exchange tokens to keep the table small. now2 = datetime.now(timezone.utc) await db.execute( delete(AuthEphemeralToken).where( AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE, AuthEphemeralToken.expires_at < now2, ) ) exchange_token = secrets.token_urlsafe(32) db.add( AuthEphemeralToken( token=exchange_token, token_type=TokenType.OIDC_EXCHANGE, username=user.username, expires_at=now2 + OIDC_EXCHANGE_TTL, ) ) await db.commit() # H-4: Use a URL fragment (#) instead of a query parameter so the exchange # token is never sent to the server in the Referer header or server logs. return RedirectResponse(url=f"{external_url}/login#oidc_token={exchange_token}", status_code=302) except Exception as exc: logger.error("OIDC user resolution failed for provider %d: %s", provider_id, exc, exc_info=True) try: await db.rollback() except Exception as rb_exc: logger.error("DB rollback failed after OIDC user-resolution error: %s", rb_exc, exc_info=True) return RedirectResponse(url=f"{frontend_error_url}user_resolution_failed", status_code=302) except Exception as exc: # L-1: Log the exception class name internally but never expose it in the # redirect URL — leaking exception names aids attacker reconnaissance. logger.error("Unexpected error in OIDC callback (%s): %s", type(exc).__name__, exc, exc_info=True) try: return RedirectResponse(url=f"{frontend_error_url}internal_error", status_code=302) except Exception as redirect_exc: logger.error("Failed to construct error redirect in OIDC callback: %s", redirect_exc, exc_info=True) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OIDC callback failed") @router.post("/oidc/exchange", response_model=LoginResponse) async def oidc_exchange( body: OIDCExchangeRequest, raw_request: Request, response: Response, db: AsyncSession = Depends(get_db), ) -> LoginResponse: """Exchange an OIDC exchange token (from the callback redirect) for a full JWT. C4: If the resolved user has 2FA enabled the exchange returns a pre_auth_token (requires_2fa=True) instead of a full JWT. The frontend must then complete the 2FA step exactly as it would after a password-based login. """ now = datetime.now(timezone.utc) # Atomically consume the exchange token (DELETE...RETURNING prevents replay). consume_result = await db.execute( delete(AuthEphemeralToken) .where( AuthEphemeralToken.token == body.oidc_token, AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE, AuthEphemeralToken.expires_at > now, # reject expired tokens atomically ) .returning(AuthEphemeralToken.username) ) row = consume_result.one_or_none() if row is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired OIDC exchange token") (username,) = row await db.commit() user = await get_user_by_username(db, username) if not user or not user.is_active: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") # Reload with groups result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups))) user = result.scalar_one() # C4: Check whether the user has any 2FA method enabled. totp_result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id)) totp_record = totp_result.scalar_one_or_none() totp_enabled = totp_record is not None and totp_record.is_enabled email_2fa_enabled = await _get_email_2fa_enabled(db, user.id) if totp_enabled or email_2fa_enabled: # User has 2FA — issue a pre_auth_token bound to this browser session via # an HttpOnly cookie (H-A: mirrors the cookie-binding done in auth.py:login). two_fa_methods: list[str] = [] if totp_enabled: two_fa_methods.append("totp") if email_2fa_enabled: two_fa_methods.append("email") if totp_enabled: two_fa_methods.append("backup") challenge_id = secrets.token_urlsafe(32) pre_auth_token = await create_pre_auth_token(db, user.username, challenge_id=challenge_id) response.set_cookie( key="2fa_challenge", value=challenge_id, httponly=True, secure=raw_request.url.scheme == "https", samesite="lax", max_age=300, path="/api/v1/auth/2fa", ) return LoginResponse( requires_2fa=True, pre_auth_token=pre_auth_token, two_fa_methods=two_fa_methods, user=_user_to_response(user), ) access_token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), ) return LoginResponse( access_token=access_token, token_type="bearer", user=_user_to_response(user), requires_2fa=False, ) @router.get("/oidc/links", response_model=list[OIDCLinkResponse]) async def list_oidc_links( current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> list[OIDCLinkResponse]: """List all OIDC provider links for the current user.""" result = await db.execute( select(UserOIDCLink).where(UserOIDCLink.user_id == current_user.id).options(selectinload(UserOIDCLink.provider)) ) links = result.scalars().all() # Defensive null-check on link.provider: on PostgreSQL the FK cascade # ensures provider exists, but SQLite ships with FK enforcement off, so # a deleted provider could in theory leave the link briefly orphan until # the next init_db() cleanup runs. Returning "" instead of # crashing keeps the endpoint usable in that edge case (#1285 follow-up). return [ OIDCLinkResponse( id=link.id, provider_id=link.provider_id, provider_name=link.provider.name if link.provider else "", provider_email=link.provider_email, created_at=link.created_at.isoformat(), ) for link in links ] @router.delete("/oidc/links/{provider_id}") async def remove_oidc_link( provider_id: int, current_user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_db), ) -> dict: """Remove the OIDC link between the current user and a provider.""" result = await db.execute( select(UserOIDCLink) .where(UserOIDCLink.user_id == current_user.id) .where(UserOIDCLink.provider_id == provider_id) ) link = result.scalar_one_or_none() if not link: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="OIDC link not found") await db.delete(link) await db.commit() return {"message": "OIDC link removed"} # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- async def _get_base_external_url(db: AsyncSession) -> str: """Return the base external URL (no trailing slash, no /login suffix).""" external_url = await get_setting(db, "external_url") if external_url: return external_url.rstrip("/") return os.environ.get("APP_URL", "http://localhost:5173").rstrip("/")