| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
7771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153 |
- """2FA (TOTP + Email OTP) and OIDC authentication routes.
- Security model
- --------------
- * Pre-auth tokens : secrets.token_urlsafe(32) stored in the database with a 5-minute TTL.
- They are single-use and do NOT grant access to any protected resource.
- * TOTP codes : verified with pyotp (30-second window, ±1 step tolerance).
- * Email OTP codes : 6-digit numeric, hashed with pbkdf2_sha256, 10-minute TTL,
- max 5 failed attempts per code before invalidation.
- * Backup codes : 10 × 8-char alphanumeric codes, each stored as pbkdf2_sha256 hash,
- single-use.
- * OIDC state : secrets.token_urlsafe(32) bound to provider_id + nonce, 10-minute TTL.
- * OIDC exchange : secrets.token_urlsafe(32), 2-minute TTL, single-use.
- * Rate limiting : max 5 failed 2FA verification attempts per user within 15 minutes.
- """
- from __future__ import annotations
- import base64
- import hashlib
- import io
- import logging
- import os
- import re
- import secrets
- import string
- import urllib.parse
- from datetime import datetime, timedelta, timezone
- import httpx
- import jwt
- import pyotp
- from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request, Response, status
- from fastapi.responses import RedirectResponse
- from jwt import PyJWKClient
- from passlib.context import CryptContext
- from sqlalchemy import delete, select
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy.orm import selectinload, undefer
- from backend.app.api.routes._oidc_helpers import assert_safe_public_https_url
- from backend.app.api.routes.settings import get_setting, set_setting
- from backend.app.core.auth import (
- ACCESS_TOKEN_EXPIRE_MINUTES,
- RequirePermissionIfAuthEnabled,
- create_access_token,
- get_current_active_user,
- get_user_by_email,
- get_user_by_username,
- is_auth_enabled,
- verify_password,
- )
- from backend.app.core.database import get_db
- from backend.app.core.permissions import Permission
- from backend.app.models.auth_ephemeral import AuthEphemeralToken, AuthRateLimitEvent, EventType, TokenType
- from backend.app.models.group import Group
- from backend.app.models.oidc_provider import OIDCProvider, UserOIDCLink
- from backend.app.models.user import User
- from backend.app.models.user_otp_code import UserOTPCode
- from backend.app.models.user_totp import UserTOTP
- from backend.app.schemas.auth import (
- AUTO_LINK_REQUIREMENTS_ERROR,
- AdminDisable2FARequest,
- BackupCodesResponse,
- EmailOTPDisableRequest,
- EmailOTPEnableConfirmRequest,
- EmailOTPSendRequest,
- GroupBrief,
- LoginResponse,
- OIDCAuthorizeResponse,
- OIDCExchangeRequest,
- OIDCLinkResponse,
- OIDCProviderCreate,
- OIDCProviderResponse,
- OIDCProviderUpdate,
- TOTPDisableRequest,
- TOTPEnableRequest,
- TOTPEnableResponse,
- TOTPSetupRequest,
- TOTPSetupResponse,
- TwoFAStatusResponse,
- TwoFAVerifyRequest,
- TwoFAVerifyResponse,
- UserResponse,
- )
- from backend.app.services.email_service import get_smtp_settings, send_email
- from backend.app.services.oidc_icon import OIDCIconError, fetch_icon
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
- def _redact_url_for_log(url: str) -> str:
- """Return ``scheme://host/path`` with query string and fragment stripped.
- Admin-supplied icon URLs are usually CDN paths, but nothing stops an
- admin from pasting a presigned URL whose query string carries an
- ``X-Amz-Signature`` / OAuth token / etc. Operators need a forensic
- trail without those secrets ending up in log files.
- """
- try:
- parsed = urllib.parse.urlparse(url)
- except ValueError:
- return "<unparseable>"
- netloc = parsed.netloc or "<no-host>"
- return f"{parsed.scheme}://{netloc}{parsed.path}"
async def _fetch_icon_or_400(icon_url: str) -> tuple[bytes, str, str]:
    """Run the SSRF guard and then fetch the icon, turning failures into HTTP 400.

    Create / update / refresh all go through this one helper so they behave
    the same way: the admin receives a 400 carrying the precise reason
    instead of a 500 or an opaque server error.

    Each failure path also emits a WARNING (with the URL redacted) so that
    operators keep a forensic trail beyond the admin's UI toast.

    Raises:
        HTTPException: 400 when the URL is rejected by
            ``assert_safe_public_https_url`` (``ValueError``) or when
            ``fetch_icon`` raises ``OIDCIconError``.
    """
    # Step 1: URL must pass the SSRF guard.
    try:
        assert_safe_public_https_url(icon_url)
    except ValueError as ssrf_error:
        logger.warning(
            "OIDC icon URL rejected by SSRF guard: url=%s reason=%s",
            _redact_url_for_log(icon_url),
            ssrf_error,
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(ssrf_error)) from ssrf_error

    # Step 2: the actual download.
    try:
        fetched = await fetch_icon(icon_url)
    except OIDCIconError as fetch_error:
        logger.warning(
            "OIDC icon fetch failed: url=%s reason=%s",
            _redact_url_for_log(icon_url),
            fetch_error,
        )
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(fetch_error)) from fetch_error
    return fetched
def _build_provider_response(provider: OIDCProvider) -> OIDCProviderResponse:
    """Convert an ``OIDCProvider`` row into its ``OIDCProviderResponse`` schema.

    The conversion relies on ``from_attributes``; the mandatory ``has_icon``
    field is provided by the ``OIDCProvider.has_icon`` property, which reads
    the non-deferred ``icon_content_type`` column.
    """
    response = OIDCProviderResponse.model_validate(provider)
    return response
- def _etag_matches(if_none_match: str | None, etag_raw: str | None) -> bool:
- """RFC 7232 §3.2 If-None-Match comparison.
- Supports:
- * ``*`` wildcard — matches any current representation when the resource
- exists (and it does here; we wouldn't have an etag otherwise).
- * Multiple comma-separated tokens.
- * Weak-validator prefix ``W/`` (RFC 7232 §2.3) — accepted on GET since
- cached representations of a static byte-blob are byte-identical.
- Returns False on missing header or missing stored etag.
- """
- if not if_none_match or not etag_raw:
- return False
- quoted = f'"{etag_raw}"'
- tokens = [t.strip() for t in if_none_match.split(",")]
- if "*" in tokens:
- return True
- return any(tok.removeprefix("W/") == quoted for tok in tokens)
- def _as_utc(dt: datetime) -> datetime:
- """Return *dt* with UTC timezone attached.
- SQLite/aiosqlite strips timezone info when reading DateTime(timezone=True)
- columns back – the stored value is always UTC, so we just re-attach the
- info when doing Python-level comparisons.
- """
- return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
- # ---------------------------------------------------------------------------
- # Passlib context (same scheme as auth.py)
- # ---------------------------------------------------------------------------
# Hashing context (pbkdf2_sha256); in this module it hashes the TOTP backup codes.
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# ---------------------------------------------------------------------------
# TTL / rate-limit constants
# ---------------------------------------------------------------------------
# Default ``max_attempts`` of check_rate_limit: failed attempts allowed per window.
MAX_2FA_ATTEMPTS = 5
# NOTE(review): not referenced in the visible part of this file — presumably the
# failed-login ceiling passed to check_rate_limit by a caller; confirm at call sites.
MAX_LOGIN_ATTEMPTS = 10
# Look-back window check_rate_limit uses when counting failed attempts.
LOCKOUT_WINDOW = timedelta(minutes=15)
# Ceiling on OTP e-mails per EMAIL_OTP_SEND_WINDOW (see check_email_otp_send_rate).
MAX_EMAIL_OTP_SENDS = 3
# Look-back window for counting OTP e-mail sends.
EMAIL_OTP_SEND_WINDOW = timedelta(minutes=10)
# Lifetime of a pre-auth token (applied in create_pre_auth_token).
PRE_AUTH_TOKEN_TTL = timedelta(minutes=5)
# Lifetime of the OIDC state token, per the module docstring (10-minute TTL).
OIDC_STATE_TTL = timedelta(minutes=10)
# Lifetime of the single-use OIDC exchange token, per the module docstring (2-minute TTL).
OIDC_EXCHANGE_TTL = timedelta(minutes=2)
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
# All routes in this module are mounted under the ``/auth`` prefix.
router = APIRouter(prefix="/auth", tags=["2fa", "oidc"])
- # ---------------------------------------------------------------------------
- # Helper: user response
- # ---------------------------------------------------------------------------
def _user_to_response(user: User) -> UserResponse:
    """Map a ``User`` onto the ``UserResponse`` schema.

    Groups are reduced to ``GroupBrief`` (id + name), permissions are returned
    sorted, and ``created_at`` is serialised with ``isoformat()``.
    """
    payload = {
        "id": user.id,
        "username": user.username,
        "email": user.email,
        "role": user.role,
        "is_active": user.is_active,
        "is_admin": user.is_admin,
        "groups": [GroupBrief(id=group.id, name=group.name) for group in user.groups],
        "permissions": sorted(user.get_permissions()),
        "created_at": user.created_at.isoformat(),
    }
    return UserResponse(**payload)
- # ---------------------------------------------------------------------------
- # Helper: QR code generation
- # ---------------------------------------------------------------------------
def _generate_totp_qr_b64(provisioning_uri: str) -> str:
    """Render *provisioning_uri* as a QR code and return the PNG as base64 text."""
    import qrcode  # type: ignore

    code = qrcode.QRCode(box_size=6, border=2)
    code.add_data(provisioning_uri)
    code.make(fit=True)

    # Serialise the rendered image to PNG in memory, then base64-encode the bytes.
    png_buffer = io.BytesIO()
    code.make_image(fill_color="black", back_color="white").save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
- # ---------------------------------------------------------------------------
- # Helper: backup code generation
- # ---------------------------------------------------------------------------
def _generate_backup_codes() -> tuple[list[str], list[str]]:
    """Create 10 backup codes of 8 uppercase-alphanumeric characters each.

    Returns:
        ``(plain_codes, hashed_codes)`` — the second list holds the
        ``pwd_context`` hash of each plain code, in the same order.
    """
    charset = string.ascii_uppercase + string.digits

    def _new_code() -> str:
        # One code: 8 characters drawn with the CSPRNG-backed ``secrets.choice``.
        return "".join(secrets.choice(charset) for _ in range(8))

    plain_codes = [_new_code() for _ in range(10)]
    hashed_codes = [pwd_context.hash(code) for code in plain_codes]
    return plain_codes, hashed_codes
- # ---------------------------------------------------------------------------
- # DB-backed pre-auth token helpers
- # ---------------------------------------------------------------------------
async def create_pre_auth_token(db: AsyncSession, username: str, challenge_id: str | None = None) -> str:
    """Issue a single-use pre-auth token and persist it in the database.

    ``challenge_id`` (taken from the HttpOnly 2fa_challenge cookie) binds the
    token to the browser session that started the login. Every later call that
    consumes the token must present the same value as a cookie.

    Returns:
        The newly generated token string.
    """
    issued_at = datetime.now(timezone.utc)

    # Opportunistic cleanup: drop already-expired pre-auth tokens to keep the table small.
    purge_expired = delete(AuthEphemeralToken).where(
        AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
        AuthEphemeralToken.expires_at < issued_at,
    )
    await db.execute(purge_expired)

    new_token = secrets.token_urlsafe(32)
    token_row = AuthEphemeralToken(
        token=new_token,
        token_type=TokenType.PRE_AUTH,
        username=username,
        challenge_id=challenge_id,
        expires_at=issued_at + PRE_AUTH_TOKEN_TTL,
    )
    db.add(token_row)
    await db.commit()
    return new_token
async def consume_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
    """Validate and consume a pre-auth token in one step; return its username or None.

    The lookup is a DELETE ... RETURNING, so when two requests race with the
    same token only the first DELETE finds the row and succeeds.

    M5: a token that was issued with a ``challenge_id`` is only honoured when
    the caller supplies the same value, so a stolen token cannot be replayed
    from a different browser session. On a mismatch the DELETE is rolled back
    and None is returned.
    """
    consume_stmt = (
        delete(AuthEphemeralToken)
        .where(
            AuthEphemeralToken.token == token,
            AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
            AuthEphemeralToken.expires_at > datetime.now(timezone.utc),
        )
        .returning(AuthEphemeralToken.username, AuthEphemeralToken.challenge_id)
    )
    deleted_row = (await db.execute(consume_stmt)).one_or_none()
    if deleted_row is None:
        return None

    username, bound_challenge = deleted_row
    # Client binding: an unbound token passes; a bound one needs the matching cookie value.
    binding_ok = bound_challenge is None or bound_challenge == challenge_id
    if not binding_ok:
        await db.rollback()
        return None

    await db.commit()
    return username
async def peek_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
    """Look up a pre-auth token and return its username without consuming it.

    If the stored token carries a ``challenge_id`` (client-binding cookie) the
    caller has to supply the same value. A mismatch is reported exactly like an
    unknown token, so nothing leaks about whether the token exists.
    """
    lookup = select(AuthEphemeralToken).where(
        AuthEphemeralToken.token == token,
        AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
        AuthEphemeralToken.expires_at > datetime.now(timezone.utc),
    )
    token_row = (await db.execute(lookup)).scalar_one_or_none()
    if token_row is None:
        return None

    # Client binding: unbound tokens pass, bound tokens need the matching cookie value.
    bound_challenge = token_row.challenge_id
    if bound_challenge is None or bound_challenge == challenge_id:
        return token_row.username
    return None
- # ---------------------------------------------------------------------------
- # DB-backed rate-limiting helpers
- # ---------------------------------------------------------------------------
async def check_rate_limit(
    db: AsyncSession,
    username: str,
    event_type: str = EventType.TWO_FA_ATTEMPT,
    max_attempts: int = MAX_2FA_ATTEMPTS,
) -> None:
    """Reject the request with HTTP 429 once the failed-attempt limit is reached.

    Usernames are lower-cased first so that case variants — which all resolve
    to the same user — count against a single rate-limit bucket.

    L-2 (known TOCTOU): counting here and inserting in
    ``record_failed_attempt`` are separate, non-atomic steps, so two concurrent
    requests may both see a count under the threshold and both continue.
    Closing the gap would need a serialising lock (SELECT FOR UPDATE on a
    dedicated counter row); the added contention is not considered worth it
    for a soft limit whose window is measured in minutes. The race window is
    microseconds, so the limit can only be slightly exceeded under precise
    concurrent timing.

    Raises:
        HTTPException: 429 when ``max_attempts`` or more events of
            ``event_type`` were recorded within ``LOCKOUT_WINDOW``.
    """
    window_start = datetime.now(timezone.utc) - LOCKOUT_WINDOW
    recent_events_query = select(AuthRateLimitEvent).where(
        AuthRateLimitEvent.username == username.lower(),
        AuthRateLimitEvent.event_type == event_type,
        AuthRateLimitEvent.occurred_at > window_start,
    )
    recent_events = (await db.execute(recent_events_query)).scalars().all()
    if len(recent_events) < max_attempts:
        return
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail="Too many failed attempts. Please try again later.",
    )
async def record_failed_attempt(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
    """Persist one failed-attempt event (lower-cased username) for rate limiting."""
    failure_event = AuthRateLimitEvent(username=username.lower(), event_type=event_type)
    db.add(failure_event)
    await db.commit()
async def clear_failed_attempts(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
    """Remove every stored failed-attempt event of *event_type* for the user.

    Intended to be called after a successful verification.
    """
    wipe_events = delete(AuthRateLimitEvent).where(
        AuthRateLimitEvent.username == username.lower(),
        AuthRateLimitEvent.event_type == event_type,
    )
    await db.execute(wipe_events)
    await db.commit()
async def check_email_otp_send_rate(db: AsyncSession, username: str) -> None:
    """Raise HTTP 429 if the user has requested too many OTP emails recently.

    I1: This function only *checks* the limit. The caller is responsible for
    recording the slot via ``record_email_otp_send`` **after** the email has
    been sent successfully. This prevents failed sends from consuming a slot
    (wasting the user's quota) and makes it impossible to farm rate-limit
    events without actually triggering a send.

    Args:
        db: Session used to count recent ``EMAIL_SEND`` events.
        username: Account name; lower-cased before the lookup.

    Raises:
        HTTPException: 429 when ``MAX_EMAIL_OTP_SENDS`` or more sends were
            recorded within ``EMAIL_OTP_SEND_WINDOW``.
    """
    username_key = username.lower()
    cutoff = datetime.now(timezone.utc) - EMAIL_OTP_SEND_WINDOW
    result = await db.execute(
        select(AuthRateLimitEvent).where(
            AuthRateLimitEvent.username == username_key,
            AuthRateLimitEvent.event_type == EventType.EMAIL_SEND,
            AuthRateLimitEvent.occurred_at > cutoff,
        )
    )
    recent_count = len(result.scalars().all())
    if recent_count >= MAX_EMAIL_OTP_SENDS:
        # ``timedelta.seconds`` is only the seconds *component* (the days part is
        # dropped), so derive the minutes from total_seconds() to stay correct
        # for any configured window length.
        wait_minutes = int(EMAIL_OTP_SEND_WINDOW.total_seconds()) // 60
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Too many OTP email requests. Please wait {wait_minutes} minutes.",
        )
async def record_email_otp_send(db: AsyncSession, username: str) -> None:
    """Store one OTP-email send event for rate limiting (I1).

    Call this only *after* the email went out successfully, so that a failed
    send does not use up one of the user's slots.
    """
    send_event = AuthRateLimitEvent(username=username.lower(), event_type=EventType.EMAIL_SEND)
    db.add(send_event)
    await db.commit()
- # ---------------------------------------------------------------------------
- # TOTP replay-protection helper
- # ---------------------------------------------------------------------------
- def _assert_totp_not_replayed(totp_obj: pyotp.TOTP, totp_record: UserTOTP, code: str) -> None:
- """Raise HTTP 400 if this TOTP code was already accepted in its time window.
- M3 fix: store the counter of the *accepted* code rather than the current
- wall-clock counter. With valid_window=1, pyotp accepts codes from the
- previous 30-second step. Using timecode(now) would store the wrong counter
- when the previous-window code is accepted, allowing immediate replay.
- """
- # Determine which time-step the accepted code belongs to.
- now = datetime.now(timezone.utc)
- accepted_counter: int | None = None
- for offset in (0, -1): # current window first, then previous
- candidate_time = now.timestamp() + offset * totp_obj.interval
- candidate_counter = totp_obj.timecode(datetime.fromtimestamp(candidate_time, tz=timezone.utc))
- if totp_obj.at(candidate_counter) == code:
- accepted_counter = candidate_counter
- break
- if accepted_counter is None:
- accepted_counter = totp_obj.timecode(now) # fallback (should not happen after verify())
- totp_record.accept_counter(accepted_counter)
- # ---------------------------------------------------------------------------
- # OIDC helpers
- # ---------------------------------------------------------------------------
- _EMAIL_SHAPE_RE = re.compile(r"[^\s@]+@[^\s@]+\.[^\s@]+")
- def _is_valid_email_shaped(value: str | None) -> bool:
- # SEC-2: shape check for non-standard claims (upn, preferred_username).
- # Requires local@domain.tld — rejects "@", "x@", "@domain", "x@nodot".
- if not value or len(value) > 255:
- return False
- return _EMAIL_SHAPE_RE.fullmatch(value) is not None
def _enforce_auto_link_safety(provider: OIDCProvider) -> None:
    """Reject (HTTP 422) a provider whose auto-link settings form an unsafe combination.

    SEC-1: the only unsafe state is Fall B — ``email_claim == "email"`` together
    with ``require_email_verified`` off — because an attacker-controlled IdP
    could then present an unverified email matching a local account. Fall C
    (custom claim) never checks ``email_verified``, so auto-linking is
    considered safe there.

    Invoked after ORM construction on create and after the setattr loop on update.
    """
    if not provider.auto_link_existing_accounts:
        return
    if provider.email_claim != "email":
        return
    if provider.require_email_verified:
        return
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail=AUTO_LINK_REQUIREMENTS_ERROR,
    )
def _resolve_provider_email(provider: OIDCProvider, claims: dict, provider_sub: str) -> str | None:
    """Pull the email address out of the OIDC ID-token claims and normalise it.

    Three resolution paths exist (Fall A/B/C):

    * Fall C — custom ``email_claim`` (anything but "email"): shape check only,
      no ``email_verified`` gate. Recommended for Azure Entra ID
      (``preferred_username`` or ``upn``).
    * Fall A — ``email_claim="email"`` and ``require_email_verified=True``:
      strict; ``email_verified`` has to be exactly True.
    * Fall B — ``email_claim="email"`` and ``require_email_verified=False``:
      permissive; only an explicit ``email_verified=False`` drops the email.

    Returns:
        The lower-cased, stripped email, or None when the claim is missing or invalid.
    """
    provider_id = provider.id
    claim_name = provider.email_claim

    claim_value = claims.get(claim_name)
    if claim_value is not None and not isinstance(claim_value, str):
        # TYPE-GUARD: a list/int/... claim would blow up on .lower() below.
        logger.warning(
            "OIDC provider %d: email_claim %r has unexpected type %s for sub=%r, ignoring",
            provider_id,
            claim_name,
            type(claim_value).__name__,
            provider_sub,
        )
        claim_value = None

    email: str | None = claim_value.lower().strip() if claim_value else None
    # SEC-2: every path applies the same shape check to a non-empty value.
    shape_ok = bool(email) and _is_valid_email_shaped(email)

    # ---- Fall C: custom claim (preferred_username, upn, ...), no email_verified gate.
    if claim_name != "email":
        if shape_ok:
            return email
        if email:
            logger.warning(
                "OIDC provider %d: email_claim %r value failed shape check for sub=%r, ignoring",
                provider_id,
                claim_name,
                provider_sub,
            )
        return None

    # ---- Standard "email" claim (Fall A and Fall B).
    # A malformed value is rejected regardless of email_verified: providers may
    # set email_verified=True on non-email values (e.g. numeric user IDs).
    if email and not shape_ok:
        logger.warning(
            "OIDC provider %d: email claim failed shape check for sub=%r, ignoring",
            provider_id,
            provider_sub,
        )
        return None

    email_verified = claims.get("email_verified")

    if provider.require_email_verified:
        # Fall A: fail closed unless email_verified is literally True.
        if email_verified is True:
            return email
        if email:
            logger.info(
                "OIDC provider %d: ignoring email for sub=%r because email_verified=%r",
                provider_id,
                provider_sub,
                email_verified,
            )
        return None

    # Fall B: explicit False drops the email; an absent/None claim keeps it.
    # Needed for Azure Entra ID, which never sends email_verified.
    if email_verified is False:
        return None
    if email_verified is not True:
        # SEC-5: log only when the permissive path really fires, not on every login.
        logger.info(
            "OIDC provider %r (%d): accepting email for sub=%r without email_verified claim (permissive mode)",
            provider.name,
            provider_id,
            provider_sub,
        )
    return email
- # ---------------------------------------------------------------------------
- # Settings helpers (email 2FA flag)
- # ---------------------------------------------------------------------------
async def _get_email_2fa_enabled(db: AsyncSession, user_id: int) -> bool:
    """Read the per-user email-2FA flag; True only when the stored setting is "true"."""
    setting_key = f"user_{user_id}_email_2fa_enabled"
    stored_value = await get_setting(db, setting_key)
    return stored_value == "true"
async def _set_email_2fa_enabled(db: AsyncSession, user_id: int, enabled: bool) -> None:
    """Write the per-user email-2FA flag as the string "true" or "false"."""
    flag_text = "true" if enabled else "false"
    await set_setting(db, f"user_{user_id}_email_2fa_enabled", flag_text)
- # ===========================================================================
- # 2FA Endpoints
- # ===========================================================================
@router.get("/2fa/status", response_model=TwoFAStatusResponse)
async def get_2fa_status(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> TwoFAStatusResponse:
    """Report the authenticated user's 2FA setup.

    Covers whether TOTP is enabled, whether email OTP is enabled, and how many
    backup-code hashes are still stored.
    """
    totp_query = select(UserTOTP).where(UserTOTP.user_id == current_user.id)
    totp_record = (await db.execute(totp_query)).scalar_one_or_none()

    # No TOTP row at all means: not enabled, no backup codes.
    totp_enabled = False if totp_record is None else totp_record.is_enabled
    codes_left = 0 if not totp_record else len(totp_record.backup_code_hashes)
    email_otp_enabled = await _get_email_2fa_enabled(db, current_user.id)

    return TwoFAStatusResponse(
        totp_enabled=totp_enabled,
        email_otp_enabled=email_otp_enabled,
        backup_codes_remaining=codes_left,
    )
@router.post("/2fa/totp/setup", response_model=TOTPSetupResponse)
async def setup_totp(
    body: TOTPSetupRequest | None = Body(default=None),
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> TOTPSetupResponse:
    """Start TOTP enrolment by issuing a fresh secret plus its QR code.

    A pending ``UserTOTP`` row (``is_enabled=False``) is created, or the
    existing row is overwritten. The caller confirms afterwards with
    POST /auth/2fa/totp/enable.

    M-R7-A: when an *active* TOTP already exists, the request body must carry
    the current TOTP code as proof of intent before the secret is replaced, so
    the real user cannot be locked out silently.

    Raises:
        HTTPException: 400 when authentication is disabled or the supplied
            current code is wrong; 500 when the stored secret cannot be read;
            whatever ``check_rate_limit`` raises when the user is rate-limited.
    """
    if not await is_auth_enabled(db):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Authentication is not enabled")

    # Look up the user's TOTP row (pending or active), if any.
    lookup = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
    existing = lookup.scalar_one_or_none()

    # M-R7-A: replacing an active authenticator requires the current code.
    if existing and existing.is_enabled:
        await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
        body_code = body.code if body else None
        supplied_code = body_code or ""

        # S4: only the property access sits inside the try — it is the single
        # line that raises RuntimeError on key loss. Anything else raising
        # RuntimeError must not be misreported as "TOTP secret unavailable".
        try:
            current_secret = existing.secret
        except RuntimeError:
            logger.exception("TOTP decryption failed for user_id=%s", current_user.id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="TOTP secret unavailable",
            )

        active_totp = pyotp.TOTP(current_secret)
        code_is_valid = active_totp.verify(supplied_code, valid_window=1)
        if not code_is_valid:
            await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current TOTP code required to replace an active authenticator",
            )

        await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
        _assert_totp_not_replayed(active_totp, existing, supplied_code)
        await db.flush()  # L-3: persist last_totp_counter immediately to block replay

    # Fresh secret, provisioning URI and QR code for the authenticator app.
    new_secret = pyotp.random_base32()
    enrolment_uri = pyotp.TOTP(new_secret).provisioning_uri(name=current_user.username, issuer_name="Bambuddy")
    qr_png_b64 = _generate_totp_qr_b64(enrolment_uri)

    if not existing:
        db.add(UserTOTP(user_id=current_user.id, secret=new_secret, is_enabled=False))
    else:
        # Reset the row to a pending state with the new secret and no backup codes.
        existing.secret = new_secret
        existing.is_enabled = False
        existing.backup_code_hashes = []
    await db.commit()

    return TOTPSetupResponse(secret=new_secret, qr_code_b64=qr_png_b64, issuer="Bambuddy")
@router.post("/2fa/totp/enable", response_model=TOTPEnableResponse)
async def enable_totp(
    body: TOTPEnableRequest,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> TOTPEnableResponse:
    """Confirm TOTP setup by verifying a code from the authenticator app.

    On success the TOTP record is marked enabled and a freshly generated set
    of backup codes is stored (hashed) and returned in plain text once.

    L-R7-A: Rate-limited to prevent brute-forcing the 6-digit confirmation code.

    Args:
        body: Request carrying the ``code`` shown by the authenticator app.
        current_user: The authenticated user enabling TOTP.
        db: Database session.

    Returns:
        TOTPEnableResponse with a confirmation message and the plain backup codes.

    Raises:
        HTTPException 400: no TOTP record exists yet, or the code is invalid.
        HTTPException 500: the stored TOTP secret could not be read.
    """
    # L-R7-A: Rate-limit the enable step to prevent brute-forcing the 6-digit code.
    await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)

    result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
    totp_record = result.scalar_one_or_none()
    if not totp_record:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP setup not initiated. Call /auth/2fa/totp/setup first."
        )

    # Guard ONLY the secret property access (same narrowing as setup_totp's S4):
    # a RuntimeError raised anywhere else must not be misreported as a
    # secret-unavailable condition.
    try:
        secret_plain = totp_record.secret
    except RuntimeError as exc:
        logger.exception("TOTP decryption failed for user_id=%s", totp_record.user_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="TOTP secret unavailable"
        ) from exc

    if not pyotp.TOTP(secret_plain).verify(body.code, valid_window=1):
        await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP code")

    await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)

    # Only the hashes are persisted; the plain codes leave the server exactly once.
    plain_codes, hashed_codes = _generate_backup_codes()
    totp_record.is_enabled = True
    totp_record.backup_code_hashes = hashed_codes
    await db.commit()

    return TOTPEnableResponse(
        message="TOTP enabled successfully. Store your backup codes in a safe place.",
        backup_codes=plain_codes,
    )
@router.post("/2fa/totp/disable")
async def disable_totp(
    body: TOTPDisableRequest,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Remove the user's TOTP record after checking a TOTP code or a backup code.

    I10: Rate-limited to prevent backup-code brute-forcing from a hijacked session.

    Raises:
        HTTPException 400: TOTP is not enabled, or the supplied code matches
            neither the authenticator nor any backup code.
    """
    await check_rate_limit(db, current_user.username)

    lookup = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
    totp_record = lookup.scalar_one_or_none()
    if not (totp_record and totp_record.is_enabled):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")

    # Either a TOTP code or a backup code is accepted. If the secret cannot be
    # decrypted (encryption key lost) the TOTP attempt is skipped and the
    # backup-code path still lets the user disable 2FA with printed codes.
    totp_obj: pyotp.TOTP | None = None
    code_valid = False
    decryption_failed = False
    try:
        totp_obj = pyotp.TOTP(totp_record.secret)
        code_valid = totp_obj.verify(body.code, valid_window=1)
    except RuntimeError:
        # S3: remember that the failure was server-side so the user is not
        # charged a failed attempt for something they cannot fix.
        decryption_failed = True
        logger.exception(
            "TOTP decryption failed for user_id=%s — falling through to backup-code check",
            totp_record.user_id,
        )

    if code_valid and totp_obj is not None:
        _assert_totp_not_replayed(totp_obj, totp_record, body.code)
        await db.flush()  # L-3: persist last_totp_counter immediately to block replay
    else:
        # L-R9-A: every stored hash is verified (list, not a short-circuiting
        # generator) so the position of a match cannot be inferred from timing.
        backup_hits = [pwd_context.verify(body.code, hashed) for hashed in totp_record.backup_code_hashes]
        if any(backup_hits):
            code_valid = True

    if not code_valid:
        # S3: no fail-counter debit when the root cause was a server-side
        # decryption failure (key loss / rotation); the recovery path must
        # stay open to the user.
        if not decryption_failed:
            await record_failed_attempt(db, current_user.username)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code")

    await db.execute(delete(UserTOTP).where(UserTOTP.user_id == current_user.id))
    await db.commit()
    return {"message": "TOTP disabled"}
@router.post("/2fa/totp/regenerate-backup-codes", response_model=BackupCodesResponse)
async def regenerate_backup_codes(
    body: TOTPDisableRequest,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> BackupCodesResponse:
    """Replace the user's backup codes with a freshly generated set.

    Requires a valid TOTP code OR a backup code.

    M10: Accepts backup codes for consistency with disable_totp — users who have
    lost their authenticator app but still have backup codes can regenerate.
    Rate-limited to prevent brute-forcing from a hijacked session.

    Args:
        body: Request carrying the TOTP or backup ``code``.
        current_user: The authenticated user.
        db: Database session.

    Returns:
        BackupCodesResponse with the new plain-text codes (only hashes are stored).

    Raises:
        HTTPException 400: TOTP is not enabled, or the code is invalid.
    """
    await check_rate_limit(db, current_user.username)

    result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
    totp_record = result.scalar_one_or_none()
    if not totp_record or not totp_record.is_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")

    # Same recovery contract as disable_totp: when the TOTP secret cannot be
    # decrypted, fall through to the backup-code branch so the user can
    # rotate their codes with a printed backup code.
    totp_obj: pyotp.TOTP | None = None
    code_valid = False
    decryption_failed = False
    try:
        totp_obj = pyotp.TOTP(totp_record.secret)
        code_valid = totp_obj.verify(body.code, valid_window=1)
    except RuntimeError:
        # S3: track server-side failure so we skip the fail-counter debit.
        decryption_failed = True
        logger.exception(
            "TOTP decryption failed for user_id=%s — falling through to backup-code check",
            totp_record.user_id,
        )

    if code_valid and totp_obj is not None:
        _assert_totp_not_replayed(totp_obj, totp_record, body.code)
        await db.flush()  # L-3: persist last_totp_counter immediately to block replay
    else:
        # Accept a backup code as an alternative (M10). Every hash is verified
        # (list, not a short-circuiting generator) so no early exit occurs.
        backup_hits = [pwd_context.verify(body.code, hashed) for hashed in totp_record.backup_code_hashes]
        if not any(backup_hits):
            # S3: skip fail-counter debit when the cause was a server-side
            # decryption failure (key loss / rotation).
            if not decryption_failed:
                await record_failed_attempt(db, current_user.username)
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP or backup code")
        # The matched backup code needs no separate removal: the whole list is
        # replaced below, which invalidates it along with every other old code.

    plain_codes, hashed_codes = _generate_backup_codes()
    totp_record.backup_code_hashes = hashed_codes
    await db.commit()

    return BackupCodesResponse(
        backup_codes=plain_codes,
        message="Backup codes regenerated. Store them safely — they will not be shown again.",
    )
@router.post("/2fa/email/enable")
async def enable_email_otp(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Email OTP enable, step 1: mail a verification code to the user's address.

    C5: Proof of possession — email 2FA is only activated once the user shows
    control of the registered address. The returned ``setup_token`` must be sent
    to POST /auth/2fa/email/enable/confirm together with the received code.

    H-3: Rate-limited to prevent email flooding via repeated calls to this endpoint.

    Raises:
        HTTPException 400: the user has no email address.
        HTTPException 500: SMTP is not configured, or sending failed.
    """
    await check_email_otp_send_rate(db, current_user.username)

    if not current_user.email:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="You must have an email address configured to enable email OTP 2FA",
        )

    smtp_settings = await get_smtp_settings(db)
    if not smtp_settings:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")

    issued_at = datetime.now(timezone.utc)

    # Only one pending setup token per user: drop any earlier ones first.
    stale_tokens = delete(AuthEphemeralToken).where(
        AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
        AuthEphemeralToken.username == current_user.username,
    )
    await db.execute(stale_tokens)

    code = str(secrets.randbelow(1_000_000)).zfill(6)
    code_hash = pwd_context.hash(code)
    setup_token = secrets.token_urlsafe(32)

    # The setup token is an AuthEphemeralToken of type EMAIL_OTP_SETUP; its
    # nonce column is reused to hold the hash of the emailed code.
    pending = AuthEphemeralToken(
        token=setup_token,
        token_type=TokenType.EMAIL_OTP_SETUP,
        username=current_user.username,
        nonce=code_hash,
        expires_at=issued_at + timedelta(minutes=10),
    )
    db.add(pending)
    await db.commit()

    text_body = (
        f"Your Bambuddy email 2FA setup code is: {code}\n\n"
        "Enter this code to confirm email-based two-factor authentication.\n"
        "The code expires in 10 minutes."
    )
    html_body = (
        "<p>To enable <strong>email-based two-factor authentication</strong> on your Bambuddy account, "
        "enter the code below:</p>"
        f"<h2 style='letter-spacing:4px'>{code}</h2>"
        "<p>The code expires in <strong>10 minutes</strong>. "
        "If you did not request this, you can safely ignore this email.</p>"
    )

    try:
        send_email(
            smtp_settings=smtp_settings,
            to_email=current_user.email,
            subject="Verify your Bambuddy email address for 2FA",
            body_text=text_body,
            body_html=html_body,
        )
        await record_email_otp_send(db, current_user.username)
    except Exception as exc:
        logger.error("Failed to send email OTP setup code to user_id=%d: %s", current_user.id, exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send verification email"
        )

    return {"message": "Verification code sent to your email address", "setup_token": setup_token}
@router.post("/2fa/email/enable/confirm")
async def confirm_enable_email_otp(
    body: EmailOTPEnableConfirmRequest,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Email OTP enable, step 2: check the emailed code and switch email 2FA on.

    H-2 fix: peek-then-consume — a wrong code does NOT burn the setup token.
    The token row is deleted only after the code verifies, so the user may
    retry up to the rate limit (5 attempts / 15 min).

    M4: Rate-limited to prevent brute-forcing the 6-digit setup code.

    Raises:
        HTTPException 400: the setup token is unknown/expired/already used,
            or the verification code is wrong.
    """
    await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
    now = datetime.now(timezone.utc)

    # Predicate identifying this user's setup token; shared by peek and delete.
    owned_setup_token = (
        AuthEphemeralToken.token == body.setup_token,
        AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
        AuthEphemeralToken.username == current_user.username,
    )

    # --- Peek: the token must exist and be unexpired; nothing is consumed yet ---
    peeked = await db.execute(
        select(AuthEphemeralToken).where(*owned_setup_token, AuthEphemeralToken.expires_at > now)
    )
    eph = peeked.scalar_one_or_none()
    if eph is None:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")

    # --- Verify: the nonce column holds the hash of the emailed code ---
    stored_code_hash = eph.nonce
    if not pwd_context.verify(body.code, stored_code_hash):
        await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")

    # --- Consume: DELETE...RETURNING so a concurrent request cannot reuse the token ---
    consume_stmt = delete(AuthEphemeralToken).where(*owned_setup_token).returning(AuthEphemeralToken.id)
    deleted = await db.execute(consume_stmt)
    if deleted.one_or_none() is None:
        # Another request consumed the token between peek and delete.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")

    await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
    await _set_email_2fa_enabled(db, current_user.id, True)
    await db.commit()
    return {"message": "Email OTP 2FA enabled"}
@router.post("/2fa/email/disable")
async def disable_email_otp(
    body: EmailOTPDisableRequest,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Turn off email-based OTP 2FA for the calling user.

    C6: Re-authentication required — the account password must be supplied so a
    hijacked session cannot silently remove a second factor. Users without a
    local password hash (LDAP/OIDC-only) skip this check.

    H-2: Rate-limited to prevent brute-forcing the password via this endpoint.

    Raises:
        HTTPException 401: a local password exists and the supplied one is wrong.
    """
    await check_rate_limit(db, current_user.username)

    has_local_password = bool(current_user.password_hash)
    if has_local_password and not verify_password(body.password, current_user.password_hash):
        await record_failed_attempt(db, current_user.username)
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password")

    await _set_email_2fa_enabled(db, current_user.id, False)
    await db.commit()
    return {"message": "Email OTP 2FA disabled"}
@router.post("/2fa/email/send")
async def send_email_otp(
    request: Request,
    body: EmailOTPSendRequest,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Email a 6-digit login OTP to the user identified by a pre-auth token.

    The caller must present a valid pre_auth_token obtained during the login
    flow. On success the token is consumed and a fresh one is returned for the
    verify step.

    Raises:
        HTTPException 401: the pre-auth token is invalid/expired/raced, or the
            user is missing or inactive.
        HTTPException 400: the user has no email address.
        HTTPException 500: SMTP is not configured, or sending failed.
    """
    # Validate the token without consuming it, so a rate-limit rejection below
    # does not permanently burn the caller's pre-auth token.
    challenge_id = request.cookies.get("2fa_challenge")
    username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

    # Rate limit comes BEFORE token consumption to prevent OTP email flooding.
    await check_email_otp_send_rate(db, username)

    user = await get_user_by_username(db, username)
    if not (user and user.is_active):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
    if not user.email:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no email address configured")

    smtp_settings = await get_smtp_settings(db)
    if not smtp_settings:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")

    # Mark every still-unused OTP code of this user as used (staged, not yet committed).
    retire_old_codes = (
        UserOTPCode.__table__.update()  # type: ignore[attr-defined]
        .where(UserOTPCode.user_id == user.id, UserOTPCode.used.is_(False))
        .values(used=True)
    )
    await db.execute(retire_old_codes)

    # Stage the new code; only its hash is stored.
    ttl_minutes = UserOTPCode.OTP_TTL_MINUTES
    code = str(secrets.randbelow(1_000_000)).zfill(6)
    code_hash = pwd_context.hash(code)
    expires_at = datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes)
    db.add(
        UserOTPCode(
            user_id=user.id,
            code_hash=code_hash,
            attempts=0,
            used=False,
            expires_at=expires_at,
        )
    )

    text_body = (
        f"Your Bambuddy login code is: {code}\n\n"
        f"This code expires in {ttl_minutes} minutes and can only be used once."
    )
    html_body = (
        f"<p>Your <strong>Bambuddy</strong> login verification code is:</p>"
        f"<h2 style='letter-spacing:4px'>{code}</h2>"
        f"<p>This code expires in <strong>{ttl_minutes} minutes</strong> and can only be used once.</p>"
        f"<p>If you did not request this code, you can safely ignore this email.</p>"
    )

    # M2: the email goes out BEFORE the pre-auth token is consumed. A send
    # failure raises here, while the token is still unconsumed and available
    # for a retry.
    try:
        send_email(
            smtp_settings=smtp_settings,
            to_email=user.email,
            subject="Your Bambuddy verification code",
            body_text=text_body,
            body_html=html_body,
        )
        await record_email_otp_send(db, username)
    except Exception as exc:
        logger.error("Failed to send OTP email to user_id=%d: %s", user.id, exc)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send OTP email")

    # Email sent — now atomically consume the old token (per the original
    # author's note this also commits the staged OTP record).
    consumed = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
    if not consumed:
        # Raced with another request or token just expired — treat as invalid.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

    # The replacement token is bound to the same challenge cookie so the
    # binding carries forward through the email → verify step.
    fresh_token = await create_pre_auth_token(db, username, challenge_id=challenge_id)
    return {"message": "Code sent to your email address", "pre_auth_token": fresh_token}
@router.post("/2fa/verify", response_model=TwoFAVerifyResponse)
async def verify_2fa(
    request: Request,
    body: TwoFAVerifyRequest,
    db: AsyncSession = Depends(get_db),
) -> TwoFAVerifyResponse:
    """Verify a 2FA code and exchange the pre_auth_token for a full JWT.

    Accepted methods: ``totp``, ``email``, ``backup``.

    The pre_auth_token is NOT consumed on failed verification attempts so the
    user can retry without restarting the login flow. It is only consumed once
    verification succeeds, preventing token replay after success.

    Args:
        request: Incoming request; the ``2fa_challenge`` cookie is read from it.
        body: Carries ``pre_auth_token``, ``method`` and ``code``.
        db: Database session.

    Returns:
        TwoFAVerifyResponse with a bearer access token and the user payload.

    Raises:
        HTTPException 401: invalid/expired/raced pre-auth token, unknown or
            inactive user, or an invalid TOTP / email / backup code.
        HTTPException 400: TOTP is not enabled (``totp`` and ``backup`` methods).
        HTTPException 500: the stored TOTP secret could not be read.
    """
    # Peek without consuming — bad codes must not burn the session token.
    # Pass the HttpOnly challenge cookie so the binding check is enforced.
    challenge_id = request.cookies.get("2fa_challenge")
    username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

    await check_rate_limit(db, username)

    user = await get_user_by_username(db, username)
    if not user or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")

    method = body.method
    # Set only by the backup-code branch: the record and the index of the
    # matched hash, removed after the pre-auth token has been consumed.
    backup_record: UserTOTP | None = None
    backup_index: int | None = None

    if method == "totp":
        result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
        totp_record = result.scalar_one_or_none()
        if not totp_record or not totp_record.is_enabled:
            await record_failed_attempt(db, username)
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
        try:
            totp_obj = pyotp.TOTP(totp_record.secret)
        except RuntimeError:
            logger.exception("TOTP decryption failed for user_id=%s", totp_record.user_id)
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="TOTP secret unavailable")
        if not totp_obj.verify(body.code, valid_window=1):
            await record_failed_attempt(db, username)
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid TOTP code")
        _assert_totp_not_replayed(totp_obj, totp_record, body.code)
        await db.flush()  # L-3: persist last_totp_counter immediately to block replay

    elif method == "email":
        now = datetime.now(timezone.utc)
        # limit(1): take the newest live code. Without it, two unused and
        # unexpired rows (e.g. from concurrent sends) make scalar_one_or_none()
        # raise MultipleResultsFound and the login fails with a 500.
        result = await db.execute(
            select(UserOTPCode)
            .where(UserOTPCode.user_id == user.id)
            .where(UserOTPCode.used.is_(False))
            .where(UserOTPCode.expires_at > now)
            .order_by(UserOTPCode.created_at.desc())
            .limit(1)
        )
        otp_record = result.scalar_one_or_none()
        if not otp_record:
            await record_failed_attempt(db, username)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="No valid OTP code found. Request a new one."
            )
        if otp_record.attempts >= UserOTPCode.MAX_ATTEMPTS:
            # Too many wrong guesses against this code: retire it for good.
            otp_record.consume()
            await db.commit()
            await record_failed_attempt(db, username)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="OTP code has been invalidated after too many attempts"
            )
        if not pwd_context.verify(body.code, otp_record.code_hash):
            otp_record.attempts += 1
            await db.commit()
            await record_failed_attempt(db, username)
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid OTP code")
        otp_record.consume()
        await db.commit()

    else:  # method == "backup"
        result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
        totp_record = result.scalar_one_or_none()
        if not totp_record or not totp_record.is_enabled:
            await record_failed_attempt(db, username)
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
        # Always iterate all codes — no early break (L-R9-A: constant iteration
        # count prevents timing oracle based on used-code position in the list).
        matched_index: int | None = None
        for idx, hashed in enumerate(totp_record.backup_code_hashes):
            if pwd_context.verify(body.code, hashed) and matched_index is None:
                matched_index = idx
        if matched_index is None:
            await record_failed_attempt(db, username)
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid backup code")
        backup_record = totp_record
        backup_index = matched_index

    # Verification succeeded — consume the pre-auth token.
    # C-1 / M1: check the return value; if falsy the token was already consumed
    # by a concurrent request — reject to prevent double-use. For backup codes
    # this runs BEFORE the code is removed, so two concurrent requests that both
    # passed code verification cannot both obtain a session.
    consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
    if not consumed_username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

    if backup_record is not None and backup_index is not None:
        # Remove the used backup code now that the token is atomically consumed.
        backup_record.backup_code_hashes = [
            c for i, c in enumerate(backup_record.backup_code_hashes) if i != backup_index
        ]
        await db.commit()

    await clear_failed_attempts(db, username)

    access_token = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    # Reload with groups for permission calculation
    result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
    user = result.scalar_one()

    return TwoFAVerifyResponse(
        access_token=access_token,
        token_type="bearer",
        user=_user_to_response(user),
    )
@router.delete("/2fa/admin/{user_id}")
async def admin_disable_2fa(
    user_id: int,
    body: AdminDisable2FARequest = Body(default_factory=AdminDisable2FARequest),
    current_user: User | None = RequirePermissionIfAuthEnabled(Permission.USERS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Admin endpoint: strip every second factor from the given user.

    Nit 3: the admin re-authenticates with their own password, mirroring how
    disable_email_otp protects a user's own 2FA removal. Admins without a local
    password_hash (OIDC/LDAP-only) are exempt.

    Raises:
        HTTPException 401: the admin has a local password and the supplied
            ``admin_password`` is missing or wrong.
    """
    # Nit 3: re-auth applies only when there is a caller with a local password.
    if current_user and current_user.password_hash:
        password_ok = bool(body.admin_password) and verify_password(
            body.admin_password, current_user.password_hash
        )
        if not password_ok:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin password required")

    # 1) Drop the TOTP record.
    await db.execute(delete(UserTOTP).where(UserTOTP.user_id == user_id))

    # 2) Switch the email 2FA setting off.
    await _set_email_2fa_enabled(db, user_id, False)

    # 3) Mark all of the user's OTP codes as used.
    retire_codes = (
        UserOTPCode.__table__.update()  # type: ignore[attr-defined]
        .where(UserOTPCode.user_id == user_id)
        .values(used=True)
    )
    await db.execute(retire_codes)

    # I2: bump password_changed_at so existing JWTs for the target user stop
    # being valid; otherwise a stolen token would survive the 2FA removal.
    target_lookup = await db.execute(select(User).where(User.id == user_id))
    target_user = target_lookup.scalar_one_or_none()
    if target_user:
        target_user.password_changed_at = datetime.now(timezone.utc)

    await db.commit()

    if current_user:
        actor = current_user.username
    else:
        actor = "anonymous"
    logger.info("Admin %s disabled all 2FA for user_id=%d", actor, user_id)
    return {"message": "2FA disabled for user"}
- # ===========================================================================
- # OIDC Endpoints
- # ===========================================================================
@router.get("/oidc/providers", response_model=list[OIDCProviderResponse])
async def list_oidc_providers(
    db: AsyncSession = Depends(get_db),
) -> list[OIDCProviderResponse]:
    """Return every enabled OIDC provider (public endpoint).

    Icons are served separately via /oidc/providers/{id}/icon; `icon_data`
    stays deferred so this list query never pulls the BLOB.
    """
    enabled_query = select(OIDCProvider).where(OIDCProvider.is_enabled.is_(True))
    rows = (await db.execute(enabled_query)).scalars().all()
    responses = []
    for provider in rows:
        responses.append(_build_provider_response(provider))
    return responses
@router.get("/oidc/providers/all", response_model=list[OIDCProviderResponse])
async def list_all_oidc_providers(
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ),
    db: AsyncSession = Depends(get_db),
) -> list[OIDCProviderResponse]:
    """Return ALL OIDC providers, disabled ones included (admin only)."""
    rows = (await db.execute(select(OIDCProvider))).scalars().all()
    responses = []
    for provider in rows:
        responses.append(_build_provider_response(provider))
    return responses
@router.post("/oidc/providers", response_model=OIDCProviderResponse, status_code=status.HTTP_201_CREATED)
async def create_oidc_provider(
    body: OIDCProviderCreate,
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> OIDCProviderResponse:
    """Register a new OIDC provider (admin only).

    When `icon_url` is given, the icon is fetched server-side and cached in
    the BLOB columns (#1333). The fetch happens before the row is added, so a
    failed fetch aborts the create and no half-configured provider is stored.

    Raises:
        HTTPException 422: `default_group_id` names a group that does not exist.
    """
    group_id = body.default_group_id
    if group_id is not None:
        group_lookup = await db.execute(select(Group).where(Group.id == group_id))
        if not group_lookup.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="default_group_id references a non-existent group",
            )

    # Icon first, row second: a fetch failure must leave the DB clean.
    icon_data: bytes | None = None
    icon_content_type: str | None = None
    icon_etag: str | None = None
    if body.icon_url:
        fetched = await _fetch_icon_or_400(body.icon_url)
        icon_data, icon_content_type, icon_etag = fetched

    provider = OIDCProvider(
        name=body.name,
        issuer_url=body.issuer_url.rstrip("/"),
        client_id=body.client_id,
        client_secret=body.client_secret,
        scopes=body.scopes,
        is_enabled=body.is_enabled,
        auto_create_users=body.auto_create_users,
        auto_link_existing_accounts=body.auto_link_existing_accounts,
        email_claim=body.email_claim,
        require_email_verified=body.require_email_verified,
        icon_url=body.icon_url,
        icon_data=icon_data,
        icon_content_type=icon_content_type,
        icon_etag=icon_etag,
        default_group_id=body.default_group_id,
    )

    # SEC-1 + SEC-6: runtime guard mirroring the OIDCProviderCreate model_validator
    # in schemas/auth.py, for any future path that bypasses Pydantic validation
    # (direct ORM, scripts).
    _enforce_auto_link_safety(provider)

    db.add(provider)
    await db.commit()
    await db.refresh(provider)
    return _build_provider_response(provider)
@router.put("/oidc/providers/{provider_id}", response_model=OIDCProviderResponse)
async def update_oidc_provider(
    provider_id: int,
    body: OIDCProviderUpdate,
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> OIDCProviderResponse:
    """Apply a partial update to an existing OIDC provider (admin only).

    The icon is refetched when either:
      1. the submitted `icon_url` differs from the stored one, OR
      2. the submitted `icon_url` equals the stored one while
         `icon_content_type` is NULL — the upgrade-path case where an old
         provider carries `icon_url` but no cached bytes until first saved.

    The fetch runs before any attribute is changed and before commit, so a
    failed fetch leaves previously cached bytes (if any) untouched.

    Raises:
        HTTPException 404: no provider with `provider_id`.
        HTTPException 422: `default_group_id` names a group that does not exist.
    """
    lookup = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
    provider = lookup.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")

    group_id = body.default_group_id
    if group_id is not None:
        group_lookup = await db.execute(select(Group).where(Group.id == group_id))
        if not group_lookup.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail="default_group_id references a non-existent group",
            )

    changes = body.model_dump(exclude_none=True)

    # The refetch decision is taken BEFORE the ORM object is mutated so the
    # comparison sees icon_url / icon_content_type as currently stored.
    requested_icon_url = changes.get("icon_url")
    if requested_icon_url is None:
        needs_icon_refetch = False
    else:
        needs_icon_refetch = requested_icon_url != provider.icon_url or provider.icon_content_type is None

    # Fetch FIRST: if _fetch_icon_or_400 raises, no provider attribute has
    # been touched yet, so the in-memory object stays consistent on the way out.
    fetched_icon: tuple[bytes, str, str] | None = None
    if needs_icon_refetch:
        fetched_icon = await _fetch_icon_or_400(requested_icon_url)

    # An explicit `icon_url: null` in the body means "clear the icon". The
    # exclude_none dump drops it, so model_fields_set on the unfiltered body is
    # what tells "client cleared it" apart from "client did not send it".
    icon_explicitly_cleared = body.icon_url is None and "icon_url" in body.model_fields_set
    if icon_explicitly_cleared:
        provider.icon_url = None
        provider.icon_data = None
        provider.icon_content_type = None
        provider.icon_etag = None

    for field, value in changes.items():
        if field == "issuer_url" and value:
            value = value.rstrip("/")
        setattr(provider, field, value)

    if fetched_icon is not None:
        icon_bytes, icon_mime, icon_tag = fetched_icon
        provider.icon_data = icon_bytes
        provider.icon_content_type = icon_mime
        provider.icon_etag = icon_tag

    # SEC-1 + SEC-6: combined-state guard after the setattr loop. It inspects
    # the final in-memory state (stored values plus the new ones) to catch
    # partial updates that pass schema validation alone but are unsafe together.
    _enforce_auto_link_safety(provider)

    await db.commit()
    await db.refresh(provider)
    return _build_provider_response(provider)
@router.delete("/oidc/providers/{provider_id}")
async def delete_oidc_provider(
    provider_id: int,
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Delete an OIDC provider (admin only).

    Only the provider row is deleted explicitly here. Removal of the
    provider's user links is presumably handled by cascade rules on the
    model or database; verify against the model definition.

    Raises:
        HTTPException 404: no provider with ``provider_id`` exists.
    """
    query = select(OIDCProvider).where(OIDCProvider.id == provider_id)
    target = (await db.execute(query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Provider deleted"}
- # ---------------------------------------------------------------------------
- # OIDC provider icon proxy (#1333)
- # ---------------------------------------------------------------------------
@router.get("/oidc/providers/{provider_id}/icon")
async def get_oidc_provider_icon(
    provider_id: int,
    if_none_match: str | None = Header(default=None, alias="If-None-Match"),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Serve the cached icon bytes of an enabled OIDC provider (public, no auth).

    The route takes no auth dependency: ``<img>`` tags cannot send an
    Authorization header, and the login page shows these icons before sign-in.
    The bytes served are whatever is cached in the database; no outbound
    request is made from this handler.

    Only enabled providers are matched, so a disabled provider answers 404
    just like a missing one and its existence is not revealed to anonymous
    callers.

    Returns:
        304 with cache headers when ``If-None-Match`` matches the stored
        ETag, otherwise 200 with the icon bytes and stored content type.

    Raises:
        HTTPException 404: provider missing or disabled, or no icon cached.
    """
    query = (
        select(OIDCProvider)
        .options(undefer(OIDCProvider.icon_data))
        .where(OIDCProvider.id == provider_id, OIDCProvider.is_enabled.is_(True))
    )
    provider = (await db.execute(query)).scalar_one_or_none()

    has_icon = (
        provider is not None
        and provider.icon_content_type is not None
        and provider.icon_data is not None
    )
    if not has_icon:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Icon not found")

    # The stored ETag is wrapped in double quotes for the header value.
    cache_headers = {
        "ETag": f'"{provider.icon_etag}"',
        "Cache-Control": "public, max-age=3600",
    }
    if _etag_matches(if_none_match, provider.icon_etag):
        return Response(status_code=status.HTTP_304_NOT_MODIFIED, headers=cache_headers)

    return Response(
        content=provider.icon_data,
        media_type=provider.icon_content_type,
        headers=cache_headers,
    )
@router.delete("/oidc/providers/{provider_id}/icon", status_code=status.HTTP_204_NO_CONTENT)
async def delete_oidc_provider_icon(
    provider_id: int,
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> Response:
    """Remove a provider's icon completely (admin only).

    All four icon columns are reset: ``icon_url`` and the three cached-bytes
    columns. Clearing only the cache would leave a half-state where the URL
    is still stored but there is nothing to serve. To bring an icon back,
    the URL has to be set again on the provider.

    Raises:
        HTTPException 404: no provider with ``provider_id`` exists.
    """
    query = select(OIDCProvider).where(OIDCProvider.id == provider_id)
    provider = (await db.execute(query)).scalar_one_or_none()
    if provider is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")

    # Write-only assignments: the columns are overwritten without reading
    # their current values, so deferred columns need no load.
    for column in ("icon_url", "icon_data", "icon_content_type", "icon_etag"):
        setattr(provider, column, None)

    await db.commit()
    return Response(status_code=status.HTTP_204_NO_CONTENT)
@router.post("/oidc/providers/{provider_id}/icon/refresh", response_model=OIDCProviderResponse)
async def refresh_oidc_provider_icon(
    provider_id: int,
    _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
    db: AsyncSession = Depends(get_db),
) -> OIDCProviderResponse:
    """Download the icon again from the provider's stored ``icon_url`` (admin only).

    Typical uses:
      - The identity provider changed its icon and the cached bytes are stale.
      - A provider has an ``icon_url`` but no cached bytes yet.

    The download happens before any column is written, so a failure (raised
    by ``_fetch_icon_or_400``) leaves the previously cached bytes in place.

    Raises:
        HTTPException 404: no provider with ``provider_id`` exists.
        HTTPException 400: the provider has no ``icon_url`` configured.
    """
    query = select(OIDCProvider).where(OIDCProvider.id == provider_id)
    provider = (await db.execute(query)).scalar_one_or_none()
    if provider is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")

    stored_url = provider.icon_url
    if not stored_url:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Provider has no icon_url to refresh",
        )

    fetched = await _fetch_icon_or_400(stored_url)
    provider.icon_data, provider.icon_content_type, provider.icon_etag = fetched

    await db.commit()
    await db.refresh(provider)
    return _build_provider_response(provider)
@router.get("/oidc/authorize/{provider_id}", response_model=OIDCAuthorizeResponse)
async def oidc_authorize(
    provider_id: int,
    db: AsyncSession = Depends(get_db),
) -> OIDCAuthorizeResponse:
    """Return the OIDC authorization URL for the given provider.

    Steps:
      1. Load the provider (enabled providers only).
      2. Fetch its discovery document and read ``authorization_endpoint``.
      3. Prune expired OIDC state rows and store a fresh state row holding
         the nonce and PKCE code verifier.
      4. Build the authorization URL with state, nonce and an S256 PKCE
         challenge.

    Raises:
        HTTPException 404: provider missing or disabled.
        HTTPException 502: discovery document unreachable, not a JSON
            object, or carrying a missing or invalid ``authorization_endpoint``.
    """
    result = await db.execute(
        select(OIDCProvider).where(OIDCProvider.id == provider_id).where(OIDCProvider.is_enabled.is_(True))
    )
    provider = result.scalar_one_or_none()
    if not provider:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found or not enabled")

    # Fetch discovery document
    discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            resp = await client.get(discovery_url)
            resp.raise_for_status()
            discovery = resp.json()
        # A JSON array or scalar would otherwise crash on .get() below and
        # surface as a 500 instead of a gateway error.
        if not isinstance(discovery, dict):
            raise ValueError("discovery document is not a JSON object")
    except Exception as exc:
        logger.error("Failed to fetch OIDC discovery for provider %d: %s", provider_id, exc)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch OIDC discovery document"
        ) from exc

    authorization_endpoint = discovery.get("authorization_endpoint")
    if not authorization_endpoint:
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document missing authorization_endpoint"
        )

    # B2: SSRF guard — reject non-HTTP(S) schemes in the authorization endpoint.
    # Non-string values are rejected here too, rather than failing on .startswith().
    if not isinstance(authorization_endpoint, str) or not authorization_endpoint.startswith(
        ("https://", "http://")
    ):
        logger.warning("OIDC discovery authorization_endpoint has invalid scheme: %s", authorization_endpoint)
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail="OIDC discovery document contains invalid authorization_endpoint",
        )

    external_url = await _get_base_external_url(db)
    redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"

    now = datetime.now(timezone.utc)

    # Prune expired OIDC states from the DB
    await db.execute(
        delete(AuthEphemeralToken).where(
            AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
            AuthEphemeralToken.expires_at < now,
        )
    )

    state = secrets.token_urlsafe(32)
    nonce = secrets.token_urlsafe(32)

    # PKCE (S256) – required by PocketID and recommended for all OIDC flows
    code_verifier = secrets.token_urlsafe(48)  # 64-char URL-safe string
    code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()

    db.add(
        AuthEphemeralToken(
            token=state,
            token_type=TokenType.OIDC_STATE,
            provider_id=provider_id,
            nonce=nonce,
            code_verifier=code_verifier,
            expires_at=now + OIDC_STATE_TTL,
        )
    )
    await db.commit()

    params = urllib.parse.urlencode(
        {
            "response_type": "code",
            "client_id": provider.client_id,
            "redirect_uri": redirect_uri,
            "scope": provider.scopes,
            "state": state,
            "nonce": nonce,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }
    )

    # RFC 6749 §3.1: the authorization endpoint MAY already carry a query
    # component, which MUST be retained when adding parameters. Appending
    # "?" unconditionally would produce a URL with two "?" in that case.
    if "?" not in authorization_endpoint:
        joiner = "?"
    elif authorization_endpoint.endswith(("?", "&")):
        joiner = ""
    else:
        joiner = "&"
    auth_url = f"{authorization_endpoint}{joiner}{params}"
    return OIDCAuthorizeResponse(auth_url=auth_url)
@router.get("/oidc/callback")
async def oidc_callback(
    code: str | None = Query(default=None, max_length=2048),
    state: str | None = Query(default=None, max_length=2048),
    error: str | None = Query(default=None, max_length=256),
    db: AsyncSession = Depends(get_db),
) -> RedirectResponse:
    """Handle the OIDC authorization code callback from the identity provider.

    Every outcome is a 302 redirect. Failures go to
    ``{external_url}/?oidc_error=<code>``; success goes to
    ``{external_url}/login#oidc_token=<exchange token>``.

    Flow:
      0. Validate the query parameters and atomically consume the state row
         (single use, unexpired), yielding provider id, nonce and PKCE verifier.
      1. Fetch the provider's discovery document.
      2. Exchange the authorization code for tokens.
      3. Fetch the JWKS and validate the ID token: signature, audience,
         normalised issuer, nonce, and presence of ``sub``.
      4. Resolve the local user: existing link, auto-link by email, or
         auto-create, depending on provider settings.
      5. Store a short-lived exchange token and redirect to the frontend.

    Raises:
        HTTPException 500: only if building the generic error redirect fails.
    """
    external_url = await _get_base_external_url(db)
    frontend_error_url = f"{external_url}/?oidc_error="

    try:
        if error:
            logger.warning("OIDC callback received error: %s", error)
            return RedirectResponse(url=f"{frontend_error_url}oidc_provider_error", status_code=302)

        if not code or not state:
            return RedirectResponse(url=f"{frontend_error_url}missing_parameters", status_code=302)

        # Atomically validate and consume OIDC state from DB (I6: single-use enforcement).
        # DELETE...RETURNING ensures concurrent callbacks with the same state token
        # cannot both succeed — only the first DELETE finds the row.
        now = datetime.now(timezone.utc)
        state_del = await db.execute(
            delete(AuthEphemeralToken)
            .where(
                AuthEphemeralToken.token == state,
                AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
                AuthEphemeralToken.expires_at > now,  # reject expired tokens atomically
            )
            .returning(
                AuthEphemeralToken.provider_id,
                AuthEphemeralToken.nonce,
                AuthEphemeralToken.code_verifier,
            )
        )
        state_row = state_del.one_or_none()
        if state_row is None:
            await db.rollback()
            return RedirectResponse(url=f"{frontend_error_url}invalid_state", status_code=302)
        provider_id, nonce, code_verifier = state_row
        await db.commit()

        # Load provider
        result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
        provider = result.scalar_one_or_none()
        if not provider:
            return RedirectResponse(url=f"{frontend_error_url}provider_not_found", status_code=302)

        redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"

        # ── Step 1: Fetch discovery document ────────────────────────────────
        discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
        try:
            async with httpx.AsyncClient(timeout=10) as client:
                disc_resp = await client.get(discovery_url)
                disc_resp.raise_for_status()
                discovery = disc_resp.json()
        except Exception as exc:
            logger.error("OIDC discovery fetch failed for provider %d: %s", provider_id, exc)
            return RedirectResponse(url=f"{frontend_error_url}discovery_failed", status_code=302)

        token_endpoint = discovery.get("token_endpoint")
        jwks_uri = discovery.get("jwks_uri")
        if not token_endpoint or not jwks_uri:
            return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)

        # L-R7-C: Reject non-HTTP(S) URLs in the discovery document to prevent
        # SSRF via crafted responses (e.g. file://, gopher://, internal schemes).
        if not token_endpoint.startswith(("https://", "http://")) or not jwks_uri.startswith(("https://", "http://")):
            logger.warning(
                "OIDC discovery document contains non-HTTP URL(s): token=%s jwks=%s", token_endpoint, jwks_uri
            )
            return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)

        # ── Step 2: Exchange authorization code for tokens ───────────────────
        token_form: dict[str, str] = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
            "client_id": provider.client_id,
        }
        if provider.client_secret:
            token_form["client_secret"] = provider.client_secret
        if code_verifier:
            token_form["code_verifier"] = code_verifier

        try:
            async with httpx.AsyncClient(timeout=15) as client:
                token_resp = await client.post(
                    token_endpoint,
                    data=token_form,
                    headers={"Accept": "application/json"},
                )
        except Exception as exc:
            logger.error("OIDC token exchange request failed for provider %d: %s", provider_id, exc)
            return RedirectResponse(url=f"{frontend_error_url}token_exchange_network_error", status_code=302)

        if not token_resp.is_success:
            try:
                err_body = token_resp.json()
                oidc_err = err_body.get("error", "")
                oidc_desc = err_body.get("error_description", "")
            except Exception:
                # Body is not a JSON object: keep a truncated raw body for the log only.
                oidc_err = ""
                oidc_desc = token_resp.text[:200]
            logger.error(
                "OIDC token exchange HTTP %d for provider %d. redirect_uri=%r error=%r desc=%r",
                token_resp.status_code,
                provider_id,
                redirect_uri,
                oidc_err,
                oidc_desc,
            )
            # Encode the OIDC error code into the redirect so the user sees it in the toast.
            # URL-encode the value to prevent query-parameter injection from provider responses.
            raw_err = oidc_err[:40] if oidc_err else str(token_resp.status_code)
            safe_err = urllib.parse.quote(raw_err, safe="")
            return RedirectResponse(
                url=f"{frontend_error_url}token_exchange_{safe_err}",
                status_code=302,
            )

        try:
            token_data = token_resp.json()
        except Exception as exc:
            logger.error("OIDC token exchange non-JSON response for provider %d: %s", provider_id, exc)
            return RedirectResponse(url=f"{frontend_error_url}token_exchange_bad_response", status_code=302)

        id_token = token_data.get("id_token")
        if not id_token:
            # Only log the keys present — values may contain secrets (access_token, etc.)
            logger.error(
                "OIDC token response missing id_token for provider %d; keys present: %s",
                provider_id,
                list(token_data.keys()),
            )
            return RedirectResponse(url=f"{frontend_error_url}no_id_token", status_code=302)

        # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
        # Use the issuer from the discovery document as the canonical value (OIDC Core
        # §3.1.3.7 requires iss == discovery issuer exactly). We strip trailing slashes
        # from both sides because some providers (e.g. Authentik, older PocketID versions)
        # are inconsistent between the discovery issuer and the JWT iss claim.
        discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
        try:
            async with httpx.AsyncClient(timeout=10) as jwks_http:
                jwks_resp = await jwks_http.get(jwks_uri)
                jwks_resp.raise_for_status()
                jwks_data = jwks_resp.json()

            # The key set was already downloaded above; overriding fetch_data
            # makes the JWKS client use that data instead of fetching again.
            jwks_client = PyJWKClient(jwks_uri)
            jwks_client.fetch_data = lambda: jwks_data  # type: ignore[method-assign]
            signing_key = jwks_client.get_signing_key_from_jwt(id_token)

            # M-3: Decode without built-in issuer check, then compare normalised
            # (both sides rstrip("/")) to handle providers like Authentik that include
            # a trailing slash in iss but not in the discovery issuer, or vice-versa.
            claims = jwt.decode(
                id_token,
                signing_key.key,
                algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
                audience=provider.client_id,
                options={"verify_iss": False},
            )
            token_iss = claims.get("iss", "").rstrip("/")
            if token_iss != discovery_issuer:
                raise jwt.exceptions.InvalidIssuerError("Invalid issuer")
        except Exception as exc:
            logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
            return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)

        # Verify nonce — fail closed: we always send a nonce, so the provider must echo it.
        # Skipping the check when nonce is absent would allow CSRF on non-nonce providers.
        token_nonce = claims.get("nonce")
        if token_nonce is None or token_nonce != nonce:
            logger.warning("OIDC nonce mismatch for provider %d (present=%r)", provider_id, token_nonce is not None)
            return RedirectResponse(url=f"{frontend_error_url}nonce_mismatch", status_code=302)

        provider_sub: str = claims.get("sub", "")
        if not provider_sub:
            return RedirectResponse(url=f"{frontend_error_url}missing_sub_claim", status_code=302)

        # SEC-3: resolve email via Fall A/B/C logic (see _resolve_provider_email).
        provider_email = _resolve_provider_email(provider, claims, provider_sub)

        # ── Step 4: Resolve / create user ────────────────────────────────────
        try:
            # 1. Look up existing OIDC link
            link_result = await db.execute(
                select(UserOIDCLink)
                .where(UserOIDCLink.provider_id == provider_id)
                .where(UserOIDCLink.provider_user_id == provider_sub)
            )
            link = link_result.scalar_one_or_none()

            user: User | None = None

            if link:
                # Existing link → load the linked user
                user_result = await db.execute(
                    select(User).where(User.id == link.user_id).options(selectinload(User.groups))
                )
                user = user_result.scalar_one_or_none()
            else:
                # 2. No OIDC link yet — check for an existing user with the same email.
                #    Use case-insensitive matching (func.lower) so that "User@Example.com"
                #    and "user@example.com" are treated as the same identity, preventing
                #    an attacker-controlled IdP from bypassing the auto-link guard by
                #    registering the target email with different casing.
                email_user: User | None = None
                if provider_email:
                    email_user = await get_user_by_email(db, provider_email)

                if email_user and provider.auto_link_existing_accounts:
                    # M-4: Only auto-link when the provider has auto_link_existing_accounts
                    # enabled. Operators can disable this to require explicit account linking,
                    # preventing an attacker-controlled IdP from hijacking local accounts.
                    #
                    # M-NEW-6: Refuse auto-link if the target user already has any OIDC
                    # link (to any provider). Without this guard an attacker who controls
                    # a second OIDC provider with auto_link enabled could add themselves as
                    # a second IdP for a user that already authenticates via a legitimate
                    # provider, effectively taking over the account.
                    #
                    # Existence check only: LIMIT 1 + first() so that a user holding
                    # several links is rejected cleanly. scalar_one_or_none() on the
                    # unbounded query would raise MultipleResultsFound and surface
                    # as a generic user_resolution_failed error.
                    existing_links_result = await db.execute(
                        select(UserOIDCLink).where(UserOIDCLink.user_id == email_user.id).limit(1)
                    )
                    has_existing_oidc_link = existing_links_result.scalars().first() is not None
                    if has_existing_oidc_link:
                        logger.warning(
                            "Auto-link rejected for user '%s': already linked to another OIDC provider",
                            email_user.username,
                        )
                        return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
                    db.add(
                        UserOIDCLink(
                            user_id=email_user.id,
                            provider_id=provider_id,
                            provider_user_id=provider_sub,
                            provider_email=provider_email,
                        )
                    )
                    await db.commit()
                    user = email_user
                    logger.info(
                        "Auto-linked existing user '%s' to OIDC provider %d via email match",
                        email_user.username,
                        provider_id,
                    )
                elif provider.auto_create_users:
                    # 3. No existing user — create one
                    if provider_email:
                        raw = provider_email.split("@")[0]
                    else:
                        # Prefer a human-readable IdP claim over the opaque sub.
                        # isinstance guards are required: claims may carry non-string
                        # values (e.g. a list) that would break .strip().
                        # Sanitization is applied per-candidate so that a value that
                        # strips to empty (e.g. "!!!") correctly falls through to the
                        # next candidate rather than silently becoming "oidcuser".
                        _pref = claims.get("preferred_username")
                        _name = claims.get("name")
                        raw = ""
                        if isinstance(_pref, str):
                            raw = re.sub(r"[^a-zA-Z0-9._-]", "", _pref.strip())[:30]
                        if not raw and isinstance(_name, str):
                            raw = re.sub(r"[^a-zA-Z0-9._-]", "", _name.strip())[:30]
                        if not raw:
                            raw = provider_sub[:30]
                    candidate = re.sub(r"[^a-zA-Z0-9._-]", "", raw)[:30] or "oidcuser"
                    username = candidate
                    counter = 1
                    # Append an increasing numeric suffix until the username is free.
                    while True:
                        existing = await get_user_by_username(db, username)
                        if not existing:
                            break
                        username = f"{candidate}{counter}"
                        counter += 1

                    # I9: Assign new OIDC users to a group before flush — accessing
                    # new_user.groups after a flush triggers a lazy-load which fails
                    # in async context. Resolution order:
                    #   1. provider.default_group_id (operator-configured)
                    #   2. "Viewers" (system fallback for read-only access)
                    #   3. no group (last resort if Viewers was deleted)
                    # SQLite does not enforce ON DELETE SET NULL, so a dangling
                    # default_group_id returns None here and falls through to Viewers.
                    default_group: Group | None = None
                    if provider.default_group_id is not None:
                        dg_result = await db.execute(select(Group).where(Group.id == provider.default_group_id))
                        default_group = dg_result.scalar_one_or_none()
                    if default_group is None:
                        viewers_result = await db.execute(select(Group).where(Group.name == "Viewers"))
                        default_group = viewers_result.scalar_one_or_none()

                    new_user = User(
                        username=username,
                        email=provider_email,
                        # M-1: auth_source="oidc" prevents local password-reset flow
                        # for users who should only authenticate via OIDC.
                        auth_source="oidc",
                        password_hash=None,  # OIDC users never use password auth
                        role="user",
                        is_active=True,
                        groups=[default_group] if default_group else [],
                    )
                    db.add(new_user)
                    # Flush so new_user.id is populated for the link row below.
                    await db.flush()

                    db.add(
                        UserOIDCLink(
                            user_id=new_user.id,
                            provider_id=provider_id,
                            provider_user_id=provider_sub,
                            provider_email=provider_email,
                        )
                    )
                    await db.commit()

                    user_result = await db.execute(
                        select(User).where(User.id == new_user.id).options(selectinload(User.groups))
                    )
                    user = user_result.scalar_one()
                    logger.info("Auto-created user '%s' via OIDC provider %d", username, provider_id)
                else:
                    return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)

            if not user or not user.is_active:
                return RedirectResponse(url=f"{frontend_error_url}account_inactive", status_code=302)

            # Issue an OIDC exchange token (short-lived, single-use) stored in DB.
            # I7: Opportunistically prune expired exchange tokens to keep the table small.
            now2 = datetime.now(timezone.utc)
            await db.execute(
                delete(AuthEphemeralToken).where(
                    AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
                    AuthEphemeralToken.expires_at < now2,
                )
            )
            exchange_token = secrets.token_urlsafe(32)
            db.add(
                AuthEphemeralToken(
                    token=exchange_token,
                    token_type=TokenType.OIDC_EXCHANGE,
                    username=user.username,
                    expires_at=now2 + OIDC_EXCHANGE_TTL,
                )
            )
            await db.commit()

            # H-4: Use a URL fragment (#) instead of a query parameter so the exchange
            # token is never sent to the server in the Referer header or server logs.
            return RedirectResponse(url=f"{external_url}/login#oidc_token={exchange_token}", status_code=302)

        except Exception as exc:
            logger.error("OIDC user resolution failed for provider %d: %s", provider_id, exc, exc_info=True)
            try:
                await db.rollback()
            except Exception as rb_exc:
                logger.error("DB rollback failed after OIDC user-resolution error: %s", rb_exc, exc_info=True)
            return RedirectResponse(url=f"{frontend_error_url}user_resolution_failed", status_code=302)

    except Exception as exc:
        # L-1: Log the exception class name internally but never expose it in the
        # redirect URL — leaking exception names aids attacker reconnaissance.
        logger.error("Unexpected error in OIDC callback (%s): %s", type(exc).__name__, exc, exc_info=True)
        try:
            return RedirectResponse(url=f"{frontend_error_url}internal_error", status_code=302)
        except Exception as redirect_exc:
            logger.error("Failed to construct error redirect in OIDC callback: %s", redirect_exc, exc_info=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OIDC callback failed"
            ) from redirect_exc
@router.post("/oidc/exchange", response_model=LoginResponse)
async def oidc_exchange(
    body: OIDCExchangeRequest,
    raw_request: Request,
    response: Response,
    db: AsyncSession = Depends(get_db),
) -> LoginResponse:
    """Trade the single-use OIDC exchange token from the callback redirect for a login.

    The token row is deleted as it is read, so it cannot be replayed.

    C4: when the resolved user has TOTP or email 2FA enabled, no access token
    is issued. The response carries ``requires_2fa=True``, a pre-auth token
    and the available methods, and a ``2fa_challenge`` HttpOnly cookie is set.
    Otherwise a bearer access token is returned.

    Raises:
        HTTPException 401: token unknown or expired, or user missing/inactive.
    """
    now = datetime.now(timezone.utc)

    # DELETE ... RETURNING consumes the token and reads its username in one
    # statement; expiry is enforced in the same WHERE clause.
    consume_stmt = (
        delete(AuthEphemeralToken)
        .where(
            AuthEphemeralToken.token == body.oidc_token,
            AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
            AuthEphemeralToken.expires_at > now,
        )
        .returning(AuthEphemeralToken.username)
    )
    consumed = (await db.execute(consume_stmt)).one_or_none()
    if consumed is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired OIDC exchange token")
    (username,) = consumed
    await db.commit()

    user = await get_user_by_username(db, username)
    if not user or not user.is_active:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")

    # Load the user again with groups eagerly attached.
    reloaded = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
    user = reloaded.scalar_one()

    # C4: find out which second factors, if any, are switched on.
    totp_lookup = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
    totp_record = totp_lookup.scalar_one_or_none()
    totp_enabled = totp_record is not None and totp_record.is_enabled
    email_2fa_enabled = await _get_email_2fa_enabled(db, user.id)

    if not (totp_enabled or email_2fa_enabled):
        # No second factor configured: issue the full access token right away.
        access_token = create_access_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        )
        return LoginResponse(
            access_token=access_token,
            token_type="bearer",
            user=_user_to_response(user),
            requires_2fa=False,
        )

    # Advertised order is totp, email, backup; "backup" is offered only
    # alongside TOTP.
    method_flags = (
        ("totp", totp_enabled),
        ("email", email_2fa_enabled),
        ("backup", totp_enabled),
    )
    two_fa_methods: list[str] = [method for method, enabled in method_flags if enabled]

    # H-A: the pre-auth token is tied to this browser through an HttpOnly
    # cookie carrying the challenge id.
    challenge_id = secrets.token_urlsafe(32)
    pre_auth_token = await create_pre_auth_token(db, user.username, challenge_id=challenge_id)
    response.set_cookie(
        key="2fa_challenge",
        value=challenge_id,
        httponly=True,
        secure=raw_request.url.scheme == "https",
        samesite="lax",
        max_age=300,
        path="/api/v1/auth/2fa",
    )
    return LoginResponse(
        requires_2fa=True,
        pre_auth_token=pre_auth_token,
        two_fa_methods=two_fa_methods,
        user=_user_to_response(user),
    )
@router.get("/oidc/links", response_model=list[OIDCLinkResponse])
async def list_oidc_links(
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> list[OIDCLinkResponse]:
    """Return every OIDC provider link that belongs to the current user.

    Each link's provider is loaded eagerly. A link whose provider cannot be
    resolved is reported with the provider name ``"<deleted>"`` instead of
    failing the whole request.
    """
    query = (
        select(UserOIDCLink)
        .where(UserOIDCLink.user_id == current_user.id)
        .options(selectinload(UserOIDCLink.provider))
    )
    rows = (await db.execute(query)).scalars().all()

    # link.provider may be None when the provider row is gone but the link
    # row is still present (e.g. a database without FK enforcement, before
    # cleanup has run). Fall back to a placeholder name in that case.
    responses: list[OIDCLinkResponse] = []
    for row in rows:
        if row.provider:
            display_name = row.provider.name
        else:
            display_name = "<deleted>"
        responses.append(
            OIDCLinkResponse(
                id=row.id,
                provider_id=row.provider_id,
                provider_name=display_name,
                provider_email=row.provider_email,
                created_at=row.created_at.isoformat(),
            )
        )
    return responses
@router.delete("/oidc/links/{provider_id}")
async def remove_oidc_link(
    provider_id: int,
    current_user: User = Depends(get_current_active_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Delete the current user's link to the given OIDC provider.

    Raises:
        HTTPException 404: the current user has no link to ``provider_id``.
    """
    query = (
        select(UserOIDCLink)
        .where(UserOIDCLink.user_id == current_user.id)
        .where(UserOIDCLink.provider_id == provider_id)
    )
    existing_link = (await db.execute(query)).scalar_one_or_none()
    if not existing_link:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="OIDC link not found")

    await db.delete(existing_link)
    await db.commit()
    return {"message": "OIDC link removed"}
- # ---------------------------------------------------------------------------
- # Internal helpers
- # ---------------------------------------------------------------------------
async def _get_base_external_url(db: AsyncSession) -> str:
    """Return the application's external base URL without trailing slashes.

    The ``external_url`` setting wins when it is set and non-empty. Otherwise
    the ``APP_URL`` environment variable is used, defaulting to
    ``http://localhost:5173``.
    """
    configured = await get_setting(db, "external_url")
    base = configured if configured else os.environ.get("APP_URL", "http://localhost:5173")
    return base.rstrip("/")
|