mfa.py 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693
  1. """2FA (TOTP + Email OTP) and OIDC authentication routes.
  2. Security model
  3. --------------
  4. * Pre-auth tokens : secrets.token_urlsafe(32) stored in the database (AuthEphemeralToken rows) with a 5-minute TTL.
  5. They are single-use and do NOT grant access to any protected resource.
  6. * TOTP codes : verified with pyotp (30-second window, ±1 step tolerance).
  7. * Email OTP codes : 6-digit numeric, hashed with pbkdf2_sha256, 10-minute TTL,
  8. max 5 failed attempts per code before invalidation.
  9. * Backup codes : 10 × 8-char alphanumeric codes, each stored as pbkdf2_sha256 hash,
  10. single-use.
  11. * OIDC state : secrets.token_urlsafe(32) bound to provider_id + nonce, 10-minute TTL.
  12. * OIDC exchange : secrets.token_urlsafe(32), 2-minute TTL, single-use.
  13. * Rate limiting : max 5 failed 2FA verification attempts per user within 15 minutes.
  14. """
  15. from __future__ import annotations
  16. import base64
  17. import hashlib
  18. import io
  19. import logging
  20. import os
  21. import re
  22. import secrets
  23. import string
  24. import urllib.parse
  25. from datetime import datetime, timedelta, timezone
  26. import httpx
  27. import jwt
  28. import pyotp
  29. from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Response, status
  30. from fastapi.responses import RedirectResponse
  31. from jwt import PyJWKClient
  32. from passlib.context import CryptContext
  33. from sqlalchemy import delete, select
  34. from sqlalchemy.ext.asyncio import AsyncSession
  35. from sqlalchemy.orm import selectinload
  36. from backend.app.api.routes.settings import get_setting, set_setting
  37. from backend.app.core.auth import (
  38. ACCESS_TOKEN_EXPIRE_MINUTES,
  39. RequirePermissionIfAuthEnabled,
  40. create_access_token,
  41. get_current_active_user,
  42. get_user_by_email,
  43. get_user_by_username,
  44. is_auth_enabled,
  45. verify_password,
  46. )
  47. from backend.app.core.database import get_db
  48. from backend.app.core.permissions import Permission
  49. from backend.app.models.auth_ephemeral import AuthEphemeralToken, AuthRateLimitEvent, EventType, TokenType
  50. from backend.app.models.group import Group
  51. from backend.app.models.oidc_provider import OIDCProvider, UserOIDCLink
  52. from backend.app.models.user import User
  53. from backend.app.models.user_otp_code import UserOTPCode
  54. from backend.app.models.user_totp import UserTOTP
  55. from backend.app.schemas.auth import (
  56. AdminDisable2FARequest,
  57. BackupCodesResponse,
  58. EmailOTPDisableRequest,
  59. EmailOTPEnableConfirmRequest,
  60. EmailOTPSendRequest,
  61. GroupBrief,
  62. LoginResponse,
  63. OIDCAuthorizeResponse,
  64. OIDCExchangeRequest,
  65. OIDCLinkResponse,
  66. OIDCProviderCreate,
  67. OIDCProviderResponse,
  68. OIDCProviderUpdate,
  69. TOTPDisableRequest,
  70. TOTPEnableRequest,
  71. TOTPEnableResponse,
  72. TOTPSetupRequest,
  73. TOTPSetupResponse,
  74. TwoFAStatusResponse,
  75. TwoFAVerifyRequest,
  76. TwoFAVerifyResponse,
  77. UserResponse,
  78. )
  79. from backend.app.services.email_service import get_smtp_settings, send_email
  80. logger = logging.getLogger(__name__)
  81. def _as_utc(dt: datetime) -> datetime:
  82. """Return *dt* with UTC timezone attached.
  83. SQLite/aiosqlite strips timezone info when reading DateTime(timezone=True)
  84. columns back – the stored value is always UTC, so we just re-attach the
  85. info when doing Python-level comparisons.
  86. """
  87. return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
  88. # ---------------------------------------------------------------------------
  89. # Passlib context (same scheme as auth.py)
  90. # ---------------------------------------------------------------------------
  91. pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
  92. # ---------------------------------------------------------------------------
  93. # TTL / rate-limit constants
  94. # ---------------------------------------------------------------------------
  95. MAX_2FA_ATTEMPTS = 5
  96. MAX_LOGIN_ATTEMPTS = 10
  97. LOCKOUT_WINDOW = timedelta(minutes=15)
  98. MAX_EMAIL_OTP_SENDS = 3
  99. EMAIL_OTP_SEND_WINDOW = timedelta(minutes=10)
  100. PRE_AUTH_TOKEN_TTL = timedelta(minutes=5)
  101. OIDC_STATE_TTL = timedelta(minutes=10)
  102. OIDC_EXCHANGE_TTL = timedelta(minutes=2)
  103. # ---------------------------------------------------------------------------
  104. # Router
  105. # ---------------------------------------------------------------------------
  106. router = APIRouter(prefix="/auth", tags=["2fa", "oidc"])
  107. # ---------------------------------------------------------------------------
  108. # Helper: user response
  109. # ---------------------------------------------------------------------------
  110. def _user_to_response(user: User) -> UserResponse:
  111. return UserResponse(
  112. id=user.id,
  113. username=user.username,
  114. email=user.email,
  115. role=user.role,
  116. is_active=user.is_active,
  117. is_admin=user.is_admin,
  118. groups=[GroupBrief(id=g.id, name=g.name) for g in user.groups],
  119. permissions=sorted(user.get_permissions()),
  120. created_at=user.created_at.isoformat(),
  121. )
  122. # ---------------------------------------------------------------------------
  123. # Helper: QR code generation
  124. # ---------------------------------------------------------------------------
  125. def _generate_totp_qr_b64(provisioning_uri: str) -> str:
  126. """Generate a base64-encoded PNG QR code for the given TOTP provisioning URI."""
  127. import qrcode # type: ignore
  128. qr = qrcode.QRCode(box_size=6, border=2)
  129. qr.add_data(provisioning_uri)
  130. qr.make(fit=True)
  131. img = qr.make_image(fill_color="black", back_color="white")
  132. buf = io.BytesIO()
  133. img.save(buf, format="PNG")
  134. return base64.b64encode(buf.getvalue()).decode()
  135. # ---------------------------------------------------------------------------
  136. # Helper: backup code generation
  137. # ---------------------------------------------------------------------------
  138. def _generate_backup_codes() -> tuple[list[str], list[str]]:
  139. """Return (plain_codes, hashed_codes) — 10 codes of 8 alphanumeric chars each."""
  140. alphabet = string.ascii_uppercase + string.digits
  141. plain = ["".join(secrets.choice(alphabet) for _ in range(8)) for _ in range(10)]
  142. hashed = [pwd_context.hash(c) for c in plain]
  143. return plain, hashed
  144. # ---------------------------------------------------------------------------
  145. # DB-backed pre-auth token helpers
  146. # ---------------------------------------------------------------------------
  147. async def create_pre_auth_token(db: AsyncSession, username: str, challenge_id: str | None = None) -> str:
  148. """Create a single-use pre-auth token stored in the DB.
  149. Pass ``challenge_id`` (from the HttpOnly 2fa_challenge cookie) to bind the
  150. token to the originating browser session. The same value must be present as
  151. a cookie on every subsequent call that consumes this token.
  152. """
  153. now = datetime.now(timezone.utc)
  154. # Prune expired tokens opportunistically (keep table small)
  155. await db.execute(
  156. delete(AuthEphemeralToken).where(
  157. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  158. AuthEphemeralToken.expires_at < now,
  159. )
  160. )
  161. token = secrets.token_urlsafe(32)
  162. db.add(
  163. AuthEphemeralToken(
  164. token=token,
  165. token_type=TokenType.PRE_AUTH,
  166. username=username,
  167. challenge_id=challenge_id,
  168. expires_at=now + PRE_AUTH_TOKEN_TTL,
  169. )
  170. )
  171. await db.commit()
  172. return token
  173. async def consume_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
  174. """Atomically validate and consume a pre-auth token. Returns username or None.
  175. Uses DELETE...RETURNING so two concurrent requests with the same token cannot
  176. both succeed — only the first DELETE finds the row.
  177. M5: When challenge_id is provided, also enforces the cookie-binding constraint
  178. so a stolen token cannot be replayed from a different browser session.
  179. """
  180. now = datetime.now(timezone.utc)
  181. result = await db.execute(
  182. delete(AuthEphemeralToken)
  183. .where(
  184. AuthEphemeralToken.token == token,
  185. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  186. AuthEphemeralToken.expires_at > now,
  187. )
  188. .returning(AuthEphemeralToken.username, AuthEphemeralToken.challenge_id)
  189. )
  190. row = result.one_or_none()
  191. if row is None:
  192. return None
  193. username, stored_challenge_id = row
  194. # Enforce client binding: if the token was issued with a challenge_id,
  195. # the caller must supply the matching value.
  196. if stored_challenge_id is not None and stored_challenge_id != challenge_id:
  197. await db.rollback()
  198. return None
  199. await db.commit()
  200. return username
  201. async def peek_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
  202. """Validate a pre-auth token and return the username WITHOUT consuming it.
  203. When the stored token has a ``challenge_id`` (client-binding cookie), the
  204. caller must supply the matching value. A mismatch is treated as an invalid
  205. token — no information leakage about whether the token itself exists.
  206. """
  207. now = datetime.now(timezone.utc)
  208. result = await db.execute(
  209. select(AuthEphemeralToken).where(
  210. AuthEphemeralToken.token == token,
  211. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  212. AuthEphemeralToken.expires_at > now,
  213. )
  214. )
  215. eph = result.scalar_one_or_none()
  216. if eph is None:
  217. return None
  218. # Enforce client binding: if the token was issued with a challenge_id the
  219. # cookie must match. Treat a mismatch as if the token doesn't exist.
  220. if eph.challenge_id is not None and eph.challenge_id != challenge_id:
  221. return None
  222. return eph.username
  223. # ---------------------------------------------------------------------------
  224. # DB-backed rate-limiting helpers
  225. # ---------------------------------------------------------------------------
  226. async def check_rate_limit(
  227. db: AsyncSession,
  228. username: str,
  229. event_type: str = EventType.TWO_FA_ATTEMPT,
  230. max_attempts: int = MAX_2FA_ATTEMPTS,
  231. ) -> None:
  232. """Raise HTTP 429 if the user has exceeded the failed attempt limit.
  233. The username is normalised to lower-case so case-variant attempts
  234. (which all resolve to the same user) share the same rate-limit bucket.
  235. L-2: Known TOCTOU — the SELECT (count) and the subsequent INSERT
  236. (record_failed_attempt) are not atomic. Two concurrent requests can both
  237. read a count below the threshold and both proceed. This is an inherent
  238. trade-off of the event-log rate-limit pattern: fixing it would require
  239. a serialising lock (SELECT FOR UPDATE on a dedicated counter row), which
  240. adds contention and is not worth it for a soft rate-limit whose window is
  241. already measured in minutes. In practice the race window is microseconds
  242. and the limit can be slightly exceeded only under precise concurrent timing.
  243. """
  244. username_key = username.lower()
  245. now = datetime.now(timezone.utc)
  246. cutoff = now - LOCKOUT_WINDOW
  247. result = await db.execute(
  248. select(AuthRateLimitEvent).where(
  249. AuthRateLimitEvent.username == username_key,
  250. AuthRateLimitEvent.event_type == event_type,
  251. AuthRateLimitEvent.occurred_at > cutoff,
  252. )
  253. )
  254. recent_count = len(result.scalars().all())
  255. if recent_count >= max_attempts:
  256. raise HTTPException(
  257. status_code=status.HTTP_429_TOO_MANY_REQUESTS,
  258. detail="Too many failed attempts. Please try again later.",
  259. )
  260. async def record_failed_attempt(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
  261. """Record a failed attempt for rate-limiting purposes."""
  262. db.add(AuthRateLimitEvent(username=username.lower(), event_type=event_type))
  263. await db.commit()
  264. async def clear_failed_attempts(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
  265. """Delete all recorded failed attempts for a user on successful verification."""
  266. await db.execute(
  267. delete(AuthRateLimitEvent).where(
  268. AuthRateLimitEvent.username == username.lower(),
  269. AuthRateLimitEvent.event_type == event_type,
  270. )
  271. )
  272. await db.commit()
  273. async def check_email_otp_send_rate(db: AsyncSession, username: str) -> None:
  274. """Raise HTTP 429 if the user has requested too many OTP emails recently.
  275. I1: This function only *checks* the limit. The caller is responsible for
  276. recording the slot via ``record_email_otp_send`` **after** the email has
  277. been sent successfully. This prevents failed sends from consuming a slot
  278. (wasting the user's quota) and makes it impossible to farm rate-limit events
  279. without actually triggering a send.
  280. """
  281. username_key = username.lower()
  282. now = datetime.now(timezone.utc)
  283. cutoff = now - EMAIL_OTP_SEND_WINDOW
  284. result = await db.execute(
  285. select(AuthRateLimitEvent).where(
  286. AuthRateLimitEvent.username == username_key,
  287. AuthRateLimitEvent.event_type == EventType.EMAIL_SEND,
  288. AuthRateLimitEvent.occurred_at > cutoff,
  289. )
  290. )
  291. recent_count = len(result.scalars().all())
  292. if recent_count >= MAX_EMAIL_OTP_SENDS:
  293. raise HTTPException(
  294. status_code=status.HTTP_429_TOO_MANY_REQUESTS,
  295. detail=f"Too many OTP email requests. Please wait {EMAIL_OTP_SEND_WINDOW.seconds // 60} minutes.",
  296. )
  297. async def record_email_otp_send(db: AsyncSession, username: str) -> None:
  298. """Record a successful OTP email send for rate-limiting purposes (I1).
  299. Must be called *after* the email has been sent successfully so that failed
  300. sends do not consume a slot from the user's quota.
  301. """
  302. db.add(AuthRateLimitEvent(username=username.lower(), event_type=EventType.EMAIL_SEND))
  303. await db.commit()
  304. # ---------------------------------------------------------------------------
  305. # TOTP replay-protection helper
  306. # ---------------------------------------------------------------------------
  307. def _assert_totp_not_replayed(totp_obj: pyotp.TOTP, totp_record: UserTOTP, code: str) -> None:
  308. """Raise HTTP 400 if this TOTP code was already accepted in its time window.
  309. M3 fix: store the counter of the *accepted* code rather than the current
  310. wall-clock counter. With valid_window=1, pyotp accepts codes from the
  311. previous 30-second step. Using timecode(now) would store the wrong counter
  312. when the previous-window code is accepted, allowing immediate replay.
  313. """
  314. # Determine which time-step the accepted code belongs to.
  315. now = datetime.now(timezone.utc)
  316. accepted_counter: int | None = None
  317. for offset in (0, -1): # current window first, then previous
  318. candidate_time = now.timestamp() + offset * totp_obj.interval
  319. candidate_counter = totp_obj.timecode(datetime.fromtimestamp(candidate_time, tz=timezone.utc))
  320. if totp_obj.at(candidate_counter) == code:
  321. accepted_counter = candidate_counter
  322. break
  323. if accepted_counter is None:
  324. accepted_counter = totp_obj.timecode(now) # fallback (should not happen after verify())
  325. totp_record.accept_counter(accepted_counter)
  326. # ---------------------------------------------------------------------------
  327. # Settings helpers (email 2FA flag)
  328. # ---------------------------------------------------------------------------
  329. async def _get_email_2fa_enabled(db: AsyncSession, user_id: int) -> bool:
  330. val = await get_setting(db, f"user_{user_id}_email_2fa_enabled")
  331. return val == "true"
  332. async def _set_email_2fa_enabled(db: AsyncSession, user_id: int, enabled: bool) -> None:
  333. await set_setting(db, f"user_{user_id}_email_2fa_enabled", "true" if enabled else "false")
  334. # ===========================================================================
  335. # 2FA Endpoints
  336. # ===========================================================================
  337. @router.get("/2fa/status", response_model=TwoFAStatusResponse)
  338. async def get_2fa_status(
  339. current_user: User = Depends(get_current_active_user),
  340. db: AsyncSession = Depends(get_db),
  341. ) -> TwoFAStatusResponse:
  342. """Return the current 2FA configuration for the authenticated user."""
  343. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
  344. totp_record = result.scalar_one_or_none()
  345. totp_enabled = totp_record is not None and totp_record.is_enabled
  346. backup_codes_remaining = len(totp_record.backup_code_hashes) if totp_record else 0
  347. email_otp_enabled = await _get_email_2fa_enabled(db, current_user.id)
  348. return TwoFAStatusResponse(
  349. totp_enabled=totp_enabled,
  350. email_otp_enabled=email_otp_enabled,
  351. backup_codes_remaining=backup_codes_remaining,
  352. )
  353. @router.post("/2fa/totp/setup", response_model=TOTPSetupResponse)
  354. async def setup_totp(
  355. body: TOTPSetupRequest | None = Body(default=None),
  356. current_user: User = Depends(get_current_active_user),
  357. db: AsyncSession = Depends(get_db),
  358. ) -> TOTPSetupResponse:
  359. """Initiate TOTP setup: generates a new secret and QR code.
  360. Creates (or replaces) a pending UserTOTP record with is_enabled=False.
  361. The caller must confirm with POST /auth/2fa/totp/enable.
  362. M-R7-A: If an *active* TOTP is already configured, the caller must supply
  363. the current TOTP code in the request body to confirm intent before the
  364. secret is overwritten (prevents silently locking out the real user).
  365. """
  366. if not await is_auth_enabled(db):
  367. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Authentication is not enabled")
  368. # Upsert a pending TOTP record (is_enabled=False)
  369. existing = (await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))).scalar_one_or_none()
  370. # M-R7-A: Guard against silent TOTP replacement when one is already active.
  371. if existing and existing.is_enabled:
  372. await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  373. supplied_code = (body.code if body else None) or ""
  374. if not pyotp.TOTP(existing.secret).verify(supplied_code, valid_window=1):
  375. await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  376. raise HTTPException(
  377. status_code=status.HTTP_400_BAD_REQUEST,
  378. detail="Current TOTP code required to replace an active authenticator",
  379. )
  380. await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  381. _assert_totp_not_replayed(pyotp.TOTP(existing.secret), existing, supplied_code)
  382. await db.flush() # L-3: persist last_totp_counter immediately to block replay
  383. secret = pyotp.random_base32()
  384. totp = pyotp.TOTP(secret)
  385. provisioning_uri = totp.provisioning_uri(name=current_user.username, issuer_name="Bambuddy")
  386. qr_b64 = _generate_totp_qr_b64(provisioning_uri)
  387. if existing:
  388. existing.secret = secret
  389. existing.is_enabled = False
  390. existing.backup_code_hashes = []
  391. else:
  392. db.add(UserTOTP(user_id=current_user.id, secret=secret, is_enabled=False))
  393. await db.commit()
  394. return TOTPSetupResponse(secret=secret, qr_code_b64=qr_b64, issuer="Bambuddy")
  395. @router.post("/2fa/totp/enable", response_model=TOTPEnableResponse)
  396. async def enable_totp(
  397. body: TOTPEnableRequest,
  398. current_user: User = Depends(get_current_active_user),
  399. db: AsyncSession = Depends(get_db),
  400. ) -> TOTPEnableResponse:
  401. """Confirm TOTP setup by verifying a code from the authenticator app.
  402. On success, enables TOTP and returns 10 single-use backup codes (shown once).
  403. L-R7-A: Rate-limited to prevent brute-forcing the 6-digit confirmation code.
  404. """
  405. # L-R7-A: Rate-limit the enable step to prevent brute-forcing the 6-digit code.
  406. await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  407. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
  408. totp_record = result.scalar_one_or_none()
  409. if not totp_record:
  410. raise HTTPException(
  411. status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP setup not initiated. Call /auth/2fa/totp/setup first."
  412. )
  413. if not pyotp.TOTP(totp_record.secret).verify(body.code, valid_window=1):
  414. await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  415. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP code")
  416. await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  417. plain_codes, hashed_codes = _generate_backup_codes()
  418. totp_record.is_enabled = True
  419. totp_record.backup_code_hashes = hashed_codes
  420. await db.commit()
  421. return TOTPEnableResponse(
  422. message="TOTP enabled successfully. Store your backup codes in a safe place.",
  423. backup_codes=plain_codes,
  424. )
  425. @router.post("/2fa/totp/disable")
  426. async def disable_totp(
  427. body: TOTPDisableRequest,
  428. current_user: User = Depends(get_current_active_user),
  429. db: AsyncSession = Depends(get_db),
  430. ) -> dict:
  431. """Disable TOTP by verifying a valid TOTP code or a backup code.
  432. I10: Rate-limited to prevent backup-code brute-forcing from a hijacked session.
  433. """
  434. await check_rate_limit(db, current_user.username)
  435. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
  436. totp_record = result.scalar_one_or_none()
  437. if not totp_record or not totp_record.is_enabled:
  438. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")
  439. # Accept either a valid TOTP code or a valid backup code
  440. totp_obj = pyotp.TOTP(totp_record.secret)
  441. code_valid = totp_obj.verify(body.code, valid_window=1)
  442. if code_valid:
  443. _assert_totp_not_replayed(totp_obj, totp_record, body.code)
  444. await db.flush() # L-3: persist last_totp_counter immediately to block replay
  445. else:
  446. # Check backup codes — always iterate all entries (L-R9-A: no early break
  447. # to avoid timing oracle based on code position in the list).
  448. for hashed in totp_record.backup_code_hashes:
  449. if pwd_context.verify(body.code, hashed):
  450. code_valid = True
  451. if not code_valid:
  452. await record_failed_attempt(db, current_user.username)
  453. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code")
  454. await db.execute(delete(UserTOTP).where(UserTOTP.user_id == current_user.id))
  455. await db.commit()
  456. return {"message": "TOTP disabled"}
  457. @router.post("/2fa/totp/regenerate-backup-codes", response_model=BackupCodesResponse)
  458. async def regenerate_backup_codes(
  459. body: TOTPDisableRequest,
  460. current_user: User = Depends(get_current_active_user),
  461. db: AsyncSession = Depends(get_db),
  462. ) -> BackupCodesResponse:
  463. """Generate 10 new backup codes. Requires a valid TOTP code OR a backup code.
  464. M10: Accepts backup codes for consistency with disable_totp — users who have
  465. lost their authenticator app but still have backup codes can regenerate.
  466. Rate-limited to prevent brute-forcing from a hijacked session.
  467. """
  468. await check_rate_limit(db, current_user.username)
  469. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
  470. totp_record = result.scalar_one_or_none()
  471. if not totp_record or not totp_record.is_enabled:
  472. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")
  473. totp_obj = pyotp.TOTP(totp_record.secret)
  474. code_valid = totp_obj.verify(body.code, valid_window=1)
  475. if code_valid:
  476. _assert_totp_not_replayed(totp_obj, totp_record, body.code)
  477. await db.flush() # L-3: persist last_totp_counter immediately to block replay
  478. else:
  479. # Accept a backup code as an alternative (M10)
  480. matched_index: int | None = None
  481. for idx, hashed in enumerate(totp_record.backup_code_hashes):
  482. if pwd_context.verify(body.code, hashed) and matched_index is None:
  483. matched_index = idx
  484. if matched_index is None:
  485. await record_failed_attempt(db, current_user.username)
  486. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP or backup code")
  487. # Remove the used backup code
  488. totp_record.backup_code_hashes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index]
  489. plain_codes, hashed_codes = _generate_backup_codes()
  490. totp_record.backup_code_hashes = hashed_codes
  491. await db.commit()
  492. return BackupCodesResponse(
  493. backup_codes=plain_codes,
  494. message="Backup codes regenerated. Store them safely — they will not be shown again.",
  495. )
  496. @router.post("/2fa/email/enable")
  497. async def enable_email_otp(
  498. current_user: User = Depends(get_current_active_user),
  499. db: AsyncSession = Depends(get_db),
  500. ) -> dict:
  501. """Step 1 of email OTP enable: send a verification code to the user's email.
  502. C5: Proof of possession — the user must prove they control the registered email
  503. address before email 2FA is activated. Returns a ``setup_token`` that must be
  504. passed to POST /auth/2fa/email/enable/confirm together with the received code.
  505. H-3: Rate-limited to prevent email flooding via repeated calls to this endpoint.
  506. """
  507. await check_email_otp_send_rate(db, current_user.username)
  508. if not current_user.email:
  509. raise HTTPException(
  510. status_code=status.HTTP_400_BAD_REQUEST,
  511. detail="You must have an email address configured to enable email OTP 2FA",
  512. )
  513. smtp_settings = await get_smtp_settings(db)
  514. if not smtp_settings:
  515. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")
  516. # Generate and store the setup token (reuse AuthEphemeralToken with type "email_otp_setup")
  517. now = datetime.now(timezone.utc)
  518. # Prune any existing pending setup tokens for this user
  519. await db.execute(
  520. delete(AuthEphemeralToken).where(
  521. AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
  522. AuthEphemeralToken.username == current_user.username,
  523. )
  524. )
  525. code = str(secrets.randbelow(1_000_000)).zfill(6)
  526. code_hash = pwd_context.hash(code)
  527. setup_token = secrets.token_urlsafe(32)
  528. db.add(
  529. AuthEphemeralToken(
  530. token=setup_token,
  531. token_type=TokenType.EMAIL_OTP_SETUP,
  532. username=current_user.username,
  533. # Reuse the nonce field to store the code hash
  534. nonce=code_hash,
  535. expires_at=now + timedelta(minutes=10),
  536. )
  537. )
  538. await db.commit()
  539. try:
  540. send_email(
  541. smtp_settings=smtp_settings,
  542. to_email=current_user.email,
  543. subject="Verify your Bambuddy email address for 2FA",
  544. body_text=(
  545. f"Your Bambuddy email 2FA setup code is: {code}\n\n"
  546. "Enter this code to confirm email-based two-factor authentication.\n"
  547. "The code expires in 10 minutes."
  548. ),
  549. body_html=(
  550. "<p>To enable <strong>email-based two-factor authentication</strong> on your Bambuddy account, "
  551. "enter the code below:</p>"
  552. f"<h2 style='letter-spacing:4px'>{code}</h2>"
  553. "<p>The code expires in <strong>10 minutes</strong>. "
  554. "If you did not request this, you can safely ignore this email.</p>"
  555. ),
  556. )
  557. await record_email_otp_send(db, current_user.username)
  558. except Exception as exc:
  559. logger.error("Failed to send email OTP setup code to user_id=%d: %s", current_user.id, exc)
  560. raise HTTPException(
  561. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send verification email"
  562. )
  563. return {"message": "Verification code sent to your email address", "setup_token": setup_token}
  564. @router.post("/2fa/email/enable/confirm")
  565. async def confirm_enable_email_otp(
  566. body: EmailOTPEnableConfirmRequest,
  567. current_user: User = Depends(get_current_active_user),
  568. db: AsyncSession = Depends(get_db),
  569. ) -> dict:
  570. """Step 2 of email OTP enable: verify the code and activate email 2FA.
  571. H-2 fix: Uses peek-then-consume so a wrong code does NOT burn the setup token.
  572. The token is only deleted after successful code verification, allowing retries
  573. up to the rate limit (5 attempts / 15 min).
  574. M4: Rate-limited to prevent brute-forcing the 6-digit setup code.
  575. """
  576. await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  577. now = datetime.now(timezone.utc)
  578. # --- Peek: validate token without consuming ---
  579. peek_result = await db.execute(
  580. select(AuthEphemeralToken).where(
  581. AuthEphemeralToken.token == body.setup_token,
  582. AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
  583. AuthEphemeralToken.username == current_user.username,
  584. AuthEphemeralToken.expires_at > now,
  585. )
  586. )
  587. eph = peek_result.scalar_one_or_none()
  588. if eph is None:
  589. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")
  590. code_hash = eph.nonce # code hash stored in the nonce field
  591. # --- Verify code before consuming the token ---
  592. if not pwd_context.verify(body.code, code_hash):
  593. await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  594. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")
  595. # --- Atomically consume the token now that the code is correct ---
  596. # DELETE...RETURNING prevents a concurrent request from using the same token.
  597. del_result = await db.execute(
  598. delete(AuthEphemeralToken)
  599. .where(
  600. AuthEphemeralToken.token == body.setup_token,
  601. AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
  602. AuthEphemeralToken.username == current_user.username,
  603. )
  604. .returning(AuthEphemeralToken.id)
  605. )
  606. if del_result.one_or_none() is None:
  607. # Concurrent request consumed it between peek and delete — treat as invalid.
  608. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")
  609. await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  610. await _set_email_2fa_enabled(db, current_user.id, True)
  611. await db.commit()
  612. return {"message": "Email OTP 2FA enabled"}
  613. @router.post("/2fa/email/disable")
  614. async def disable_email_otp(
  615. body: EmailOTPDisableRequest,
  616. current_user: User = Depends(get_current_active_user),
  617. db: AsyncSession = Depends(get_db),
  618. ) -> dict:
  619. """Disable email-based OTP 2FA for the current user.
  620. C6: Re-authentication required — the caller must supply their account password
  621. to prevent a hijacked session from silently removing a second factor.
  622. LDAP/OIDC-only users (no local password) are exempt from this check.
  623. H-2: Rate-limited to prevent brute-forcing the password via this endpoint.
  624. """
  625. await check_rate_limit(db, current_user.username)
  626. if current_user.password_hash:
  627. if not verify_password(body.password, current_user.password_hash):
  628. await record_failed_attempt(db, current_user.username)
  629. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password")
  630. await _set_email_2fa_enabled(db, current_user.id, False)
  631. await db.commit()
  632. return {"message": "Email OTP 2FA disabled"}
  633. @router.post("/2fa/email/send")
  634. async def send_email_otp(
  635. request: Request,
  636. body: EmailOTPSendRequest,
  637. db: AsyncSession = Depends(get_db),
  638. ) -> dict:
  639. """Send a 6-digit OTP code to the user's email address.
  640. Requires a valid pre_auth_token obtained during the login flow.
  641. """
  642. # Peek (validate without consuming) first so a rate-limit rejection does not
  643. # permanently burn the caller's pre-auth token.
  644. challenge_id = request.cookies.get("2fa_challenge")
  645. username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
  646. if not username:
  647. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
  648. # Enforce rate limit BEFORE consuming the token to prevent OTP email flooding.
  649. await check_email_otp_send_rate(db, username)
  650. user = await get_user_by_username(db, username)
  651. if not user or not user.is_active:
  652. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
  653. if not user.email:
  654. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no email address configured")
  655. smtp_settings = await get_smtp_settings(db)
  656. if not smtp_settings:
  657. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")
  658. # Invalidate all existing unused OTP codes for this user (staged, not yet committed)
  659. await db.execute(
  660. UserOTPCode.__table__.update() # type: ignore[attr-defined]
  661. .where(UserOTPCode.user_id == user.id)
  662. .where(UserOTPCode.used.is_(False))
  663. .values(used=True)
  664. )
  665. # Generate a 6-digit code and stage the record (not committed yet)
  666. code = str(secrets.randbelow(1_000_000)).zfill(6)
  667. code_hash = pwd_context.hash(code)
  668. expires_at = datetime.now(timezone.utc) + timedelta(minutes=UserOTPCode.OTP_TTL_MINUTES)
  669. otp_record = UserOTPCode(
  670. user_id=user.id,
  671. code_hash=code_hash,
  672. attempts=0,
  673. used=False,
  674. expires_at=expires_at,
  675. )
  676. db.add(otp_record)
  677. # M2: Send the email BEFORE consuming the pre-auth token.
  678. # If the send fails we raise an exception here; the session is uncommitted so
  679. # the OTP record is discarded and the original token remains valid for retry.
  680. try:
  681. send_email(
  682. smtp_settings=smtp_settings,
  683. to_email=user.email,
  684. subject="Your Bambuddy verification code",
  685. body_text=f"Your Bambuddy login code is: {code}\n\nThis code expires in {UserOTPCode.OTP_TTL_MINUTES} minutes and can only be used once.",
  686. body_html=(
  687. f"<p>Your <strong>Bambuddy</strong> login verification code is:</p>"
  688. f"<h2 style='letter-spacing:4px'>{code}</h2>"
  689. f"<p>This code expires in <strong>{UserOTPCode.OTP_TTL_MINUTES} minutes</strong> and can only be used once.</p>"
  690. f"<p>If you did not request this code, you can safely ignore this email.</p>"
  691. ),
  692. )
  693. await record_email_otp_send(db, username)
  694. except Exception as exc:
  695. logger.error("Failed to send OTP email to user_id=%d: %s", user.id, exc)
  696. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send OTP email")
  697. # Email sent — now atomically consume the old token (this also commits the
  698. # staged OTP record) and issue a fresh token for the verify step.
  699. consumed = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
  700. if not consumed:
  701. # Raced with another request or token just expired — treat as invalid.
  702. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
  703. # Re-issue a fresh pre-auth token bound to the same cookie so the binding
  704. # carries forward through the email → verify step.
  705. fresh_token = await create_pre_auth_token(db, username, challenge_id=challenge_id)
  706. # Return the fresh pre-auth token so the frontend can proceed to verify
  707. return {"message": "Code sent to your email address", "pre_auth_token": fresh_token}
  708. @router.post("/2fa/verify", response_model=TwoFAVerifyResponse)
  709. async def verify_2fa(
  710. request: Request,
  711. body: TwoFAVerifyRequest,
  712. db: AsyncSession = Depends(get_db),
  713. ) -> TwoFAVerifyResponse:
  714. """Verify a 2FA code and exchange the pre_auth_token for a full JWT.
  715. Accepted methods: ``totp``, ``email``, ``backup``.
  716. The pre_auth_token is NOT consumed on failed verification attempts so the
  717. user can retry without restarting the login flow. It is only consumed once
  718. verification succeeds, preventing token replay after success.
  719. """
  720. # Peek without consuming — bad codes must not burn the session token.
  721. # Pass the HttpOnly challenge cookie so the binding check is enforced.
  722. challenge_id = request.cookies.get("2fa_challenge")
  723. username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
  724. if not username:
  725. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
  726. await check_rate_limit(db, username)
  727. user = await get_user_by_username(db, username)
  728. if not user or not user.is_active:
  729. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
  730. method = body.method
  731. if method == "totp":
  732. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
  733. totp_record = result.scalar_one_or_none()
  734. if not totp_record or not totp_record.is_enabled:
  735. await record_failed_attempt(db, username)
  736. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
  737. totp_obj = pyotp.TOTP(totp_record.secret)
  738. if not totp_obj.verify(body.code, valid_window=1):
  739. await record_failed_attempt(db, username)
  740. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid TOTP code")
  741. _assert_totp_not_replayed(totp_obj, totp_record, body.code)
  742. await db.flush() # L-3: persist last_totp_counter immediately to block replay
  743. elif method == "email":
  744. now = datetime.now(timezone.utc)
  745. result = await db.execute(
  746. select(UserOTPCode)
  747. .where(UserOTPCode.user_id == user.id)
  748. .where(UserOTPCode.used.is_(False))
  749. .where(UserOTPCode.expires_at > now)
  750. .order_by(UserOTPCode.created_at.desc())
  751. )
  752. otp_record = result.scalar_one_or_none()
  753. if not otp_record:
  754. await record_failed_attempt(db, username)
  755. raise HTTPException(
  756. status_code=status.HTTP_401_UNAUTHORIZED, detail="No valid OTP code found. Request a new one."
  757. )
  758. if otp_record.attempts >= UserOTPCode.MAX_ATTEMPTS:
  759. otp_record.consume()
  760. await db.commit()
  761. await record_failed_attempt(db, username)
  762. raise HTTPException(
  763. status_code=status.HTTP_401_UNAUTHORIZED, detail="OTP code has been invalidated after too many attempts"
  764. )
  765. if not pwd_context.verify(body.code, otp_record.code_hash):
  766. otp_record.attempts += 1
  767. await db.commit()
  768. await record_failed_attempt(db, username)
  769. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid OTP code")
  770. otp_record.consume()
  771. await db.commit()
  772. else: # method == "backup"
  773. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
  774. totp_record = result.scalar_one_or_none()
  775. if not totp_record or not totp_record.is_enabled:
  776. await record_failed_attempt(db, username)
  777. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
  778. # Always iterate all codes — no early break (L-R9-A: constant iteration
  779. # count prevents timing oracle based on used-code position in the list).
  780. matched_index: int | None = None
  781. for idx, hashed in enumerate(totp_record.backup_code_hashes):
  782. if pwd_context.verify(body.code, hashed) and matched_index is None:
  783. matched_index = idx
  784. if matched_index is None:
  785. await record_failed_attempt(db, username)
  786. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid backup code")
  787. # M1: Consume the pre-auth token FIRST (atomic single-use enforcement).
  788. # Only if that succeeds do we remove the backup code — this prevents a race
  789. # where two concurrent requests both pass code verification but only one
  790. # should be granted a session.
  791. consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
  792. if not consumed_username:
  793. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
  794. # Remove the used backup code now that the token is atomically consumed.
  795. updated_codes = [c for i, c in enumerate(totp_record.backup_code_hashes) if i != matched_index]
  796. totp_record.backup_code_hashes = updated_codes
  797. await db.commit()
  798. await clear_failed_attempts(db, username)
  799. access_token = create_access_token(
  800. data={"sub": user.username},
  801. expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
  802. )
  803. result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
  804. user = result.scalar_one()
  805. return TwoFAVerifyResponse(access_token=access_token, token_type="bearer", user=_user_to_response(user))
  806. # Verification succeeded (TOTP or email) — consume the pre-auth token.
  807. # C-1: Check the return value; if None the token was already consumed by a
  808. # concurrent request (race condition) — reject to prevent double-use.
  809. consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
  810. if not consumed_username:
  811. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
  812. await clear_failed_attempts(db, username)
  813. access_token = create_access_token(
  814. data={"sub": user.username},
  815. expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
  816. )
  817. # Reload with groups for permission calculation
  818. result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
  819. user = result.scalar_one()
  820. return TwoFAVerifyResponse(
  821. access_token=access_token,
  822. token_type="bearer",
  823. user=_user_to_response(user),
  824. )
  825. @router.delete("/2fa/admin/{user_id}")
  826. async def admin_disable_2fa(
  827. user_id: int,
  828. body: AdminDisable2FARequest = Body(default_factory=AdminDisable2FARequest),
  829. current_user: User | None = RequirePermissionIfAuthEnabled(Permission.USERS_UPDATE),
  830. db: AsyncSession = Depends(get_db),
  831. ) -> dict:
  832. """Admin endpoint: disable all 2FA for a given user.
  833. Nit 3: Requires the admin's own password as a re-auth step (matching how
  834. disable_email_otp protects a user's own 2FA removal). OIDC/LDAP-only admins
  835. (no local password_hash) are exempt.
  836. """
  837. # Nit 3: Re-auth — admin must supply their own password.
  838. if current_user and current_user.password_hash:
  839. if not body.admin_password or not verify_password(body.admin_password, current_user.password_hash):
  840. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin password required")
  841. # Delete TOTP record
  842. await db.execute(delete(UserTOTP).where(UserTOTP.user_id == user_id))
  843. # Disable email 2FA setting
  844. await _set_email_2fa_enabled(db, user_id, False)
  845. # Invalidate all OTP codes
  846. await db.execute(
  847. UserOTPCode.__table__.update() # type: ignore[attr-defined]
  848. .where(UserOTPCode.user_id == user_id)
  849. .values(used=True)
  850. )
  851. # I2: Invalidate existing JWTs for the target user by bumping password_changed_at.
  852. # Without this, a stolen token remains valid after 2FA removal.
  853. target_user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
  854. if target_user:
  855. target_user.password_changed_at = datetime.now(timezone.utc)
  856. await db.commit()
  857. actor = current_user.username if current_user else "anonymous"
  858. logger.info("Admin %s disabled all 2FA for user_id=%d", actor, user_id)
  859. return {"message": "2FA disabled for user"}
  860. # ===========================================================================
  861. # OIDC Endpoints
  862. # ===========================================================================
  863. @router.get("/oidc/providers", response_model=list[OIDCProviderResponse])
  864. async def list_oidc_providers(
  865. db: AsyncSession = Depends(get_db),
  866. ) -> list[OIDCProviderResponse]:
  867. """List all enabled OIDC providers (public)."""
  868. result = await db.execute(select(OIDCProvider).where(OIDCProvider.is_enabled.is_(True)))
  869. providers = result.scalars().all()
  870. return [OIDCProviderResponse.model_validate(p) for p in providers]
  871. @router.get("/oidc/providers/all", response_model=list[OIDCProviderResponse])
  872. async def list_all_oidc_providers(
  873. _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ),
  874. db: AsyncSession = Depends(get_db),
  875. ) -> list[OIDCProviderResponse]:
  876. """List ALL OIDC providers including disabled ones (admin only)."""
  877. result2 = await db.execute(select(OIDCProvider))
  878. providers = result2.scalars().all()
  879. return [OIDCProviderResponse.model_validate(p) for p in providers]
  880. @router.post("/oidc/providers", response_model=OIDCProviderResponse, status_code=status.HTTP_201_CREATED)
  881. async def create_oidc_provider(
  882. body: OIDCProviderCreate,
  883. _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
  884. db: AsyncSession = Depends(get_db),
  885. ) -> OIDCProviderResponse:
  886. """Create a new OIDC provider (admin only)."""
  887. provider = OIDCProvider(
  888. name=body.name,
  889. issuer_url=body.issuer_url.rstrip("/"),
  890. client_id=body.client_id,
  891. client_secret=body.client_secret,
  892. scopes=body.scopes,
  893. is_enabled=body.is_enabled,
  894. auto_create_users=body.auto_create_users,
  895. icon_url=body.icon_url,
  896. )
  897. db.add(provider)
  898. await db.commit()
  899. await db.refresh(provider)
  900. return OIDCProviderResponse.model_validate(provider)
  901. @router.put("/oidc/providers/{provider_id}", response_model=OIDCProviderResponse)
  902. async def update_oidc_provider(
  903. provider_id: int,
  904. body: OIDCProviderUpdate,
  905. _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
  906. db: AsyncSession = Depends(get_db),
  907. ) -> OIDCProviderResponse:
  908. """Update an existing OIDC provider (admin only)."""
  909. result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
  910. provider = result2.scalar_one_or_none()
  911. if not provider:
  912. raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")
  913. for field, value in body.model_dump(exclude_none=True).items():
  914. if field == "issuer_url" and value:
  915. value = value.rstrip("/")
  916. setattr(provider, field, value)
  917. await db.commit()
  918. await db.refresh(provider)
  919. return OIDCProviderResponse.model_validate(provider)
  920. @router.delete("/oidc/providers/{provider_id}")
  921. async def delete_oidc_provider(
  922. provider_id: int,
  923. _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
  924. db: AsyncSession = Depends(get_db),
  925. ) -> dict:
  926. """Delete an OIDC provider and all its user links (admin only)."""
  927. result2 = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
  928. provider = result2.scalar_one_or_none()
  929. if not provider:
  930. raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")
  931. await db.delete(provider)
  932. await db.commit()
  933. return {"message": "Provider deleted"}
  934. @router.get("/oidc/authorize/{provider_id}", response_model=OIDCAuthorizeResponse)
  935. async def oidc_authorize(
  936. provider_id: int,
  937. db: AsyncSession = Depends(get_db),
  938. ) -> OIDCAuthorizeResponse:
  939. """Return the OIDC authorization URL for the given provider."""
  940. result = await db.execute(
  941. select(OIDCProvider).where(OIDCProvider.id == provider_id).where(OIDCProvider.is_enabled.is_(True))
  942. )
  943. provider = result.scalar_one_or_none()
  944. if not provider:
  945. raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found or not enabled")
  946. # Fetch discovery document
  947. discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
  948. try:
  949. async with httpx.AsyncClient(timeout=10) as client:
  950. resp = await client.get(discovery_url)
  951. resp.raise_for_status()
  952. discovery = resp.json()
  953. except Exception as exc:
  954. logger.error("Failed to fetch OIDC discovery for provider %d: %s", provider_id, exc)
  955. raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch OIDC discovery document")
  956. authorization_endpoint = discovery.get("authorization_endpoint")
  957. if not authorization_endpoint:
  958. raise HTTPException(
  959. status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document missing authorization_endpoint"
  960. )
  961. # B2: SSRF guard — reject non-HTTP(S) schemes in the authorization endpoint
  962. if not authorization_endpoint.startswith(("https://", "http://")):
  963. logger.warning("OIDC discovery authorization_endpoint has invalid scheme: %s", authorization_endpoint)
  964. raise HTTPException(
  965. status_code=status.HTTP_502_BAD_GATEWAY,
  966. detail="OIDC discovery document contains invalid authorization_endpoint",
  967. )
  968. external_url = await _get_base_external_url(db)
  969. redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
  970. now = datetime.now(timezone.utc)
  971. # Prune expired OIDC states from the DB
  972. await db.execute(
  973. delete(AuthEphemeralToken).where(
  974. AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
  975. AuthEphemeralToken.expires_at < now,
  976. )
  977. )
  978. state = secrets.token_urlsafe(32)
  979. nonce = secrets.token_urlsafe(32)
  980. # PKCE (S256) – required by PocketID and recommended for all OIDC flows
  981. code_verifier = secrets.token_urlsafe(48) # 64-char URL-safe string
  982. code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
  983. db.add(
  984. AuthEphemeralToken(
  985. token=state,
  986. token_type=TokenType.OIDC_STATE,
  987. provider_id=provider_id,
  988. nonce=nonce,
  989. code_verifier=code_verifier,
  990. expires_at=now + OIDC_STATE_TTL,
  991. )
  992. )
  993. await db.commit()
  994. params = urllib.parse.urlencode(
  995. {
  996. "response_type": "code",
  997. "client_id": provider.client_id,
  998. "redirect_uri": redirect_uri,
  999. "scope": provider.scopes,
  1000. "state": state,
  1001. "nonce": nonce,
  1002. "code_challenge": code_challenge,
  1003. "code_challenge_method": "S256",
  1004. }
  1005. )
  1006. auth_url = f"{authorization_endpoint}?{params}"
  1007. return OIDCAuthorizeResponse(auth_url=auth_url)
  1008. @router.get("/oidc/callback")
  1009. async def oidc_callback(
  1010. code: str | None = Query(default=None, max_length=2048),
  1011. state: str | None = Query(default=None, max_length=2048),
  1012. error: str | None = Query(default=None, max_length=256),
  1013. db: AsyncSession = Depends(get_db),
  1014. ) -> RedirectResponse:
  1015. """Handle the OIDC authorization code callback from the identity provider."""
  1016. external_url = await _get_base_external_url(db)
  1017. frontend_error_url = f"{external_url}/?oidc_error="
  1018. try:
  1019. if error:
  1020. logger.warning("OIDC callback received error: %s", error)
  1021. return RedirectResponse(url=f"{frontend_error_url}oidc_provider_error", status_code=302)
  1022. if not code or not state:
  1023. return RedirectResponse(url=f"{frontend_error_url}missing_parameters", status_code=302)
  1024. # Atomically validate and consume OIDC state from DB (I6: single-use enforcement).
  1025. # DELETE...RETURNING ensures concurrent callbacks with the same state token
  1026. # cannot both succeed — only the first DELETE finds the row.
  1027. now = datetime.now(timezone.utc)
  1028. state_del = await db.execute(
  1029. delete(AuthEphemeralToken)
  1030. .where(
  1031. AuthEphemeralToken.token == state,
  1032. AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
  1033. AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
  1034. )
  1035. .returning(
  1036. AuthEphemeralToken.provider_id,
  1037. AuthEphemeralToken.nonce,
  1038. AuthEphemeralToken.code_verifier,
  1039. )
  1040. )
  1041. state_row = state_del.one_or_none()
  1042. if state_row is None:
  1043. await db.rollback()
  1044. return RedirectResponse(url=f"{frontend_error_url}invalid_state", status_code=302)
  1045. provider_id, nonce, code_verifier = state_row
  1046. await db.commit()
  1047. # Load provider
  1048. result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
  1049. provider = result.scalar_one_or_none()
  1050. if not provider:
  1051. return RedirectResponse(url=f"{frontend_error_url}provider_not_found", status_code=302)
  1052. redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
  1053. # ── Step 1: Fetch discovery document ────────────────────────────────
  1054. discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
  1055. try:
  1056. async with httpx.AsyncClient(timeout=10) as client:
  1057. disc_resp = await client.get(discovery_url)
  1058. disc_resp.raise_for_status()
  1059. discovery = disc_resp.json()
  1060. except Exception as exc:
  1061. logger.error("OIDC discovery fetch failed for provider %d: %s", provider_id, exc)
  1062. return RedirectResponse(url=f"{frontend_error_url}discovery_failed", status_code=302)
  1063. token_endpoint = discovery.get("token_endpoint")
  1064. jwks_uri = discovery.get("jwks_uri")
  1065. if not token_endpoint or not jwks_uri:
  1066. return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
  1067. # L-R7-C: Reject non-HTTP(S) URLs in the discovery document to prevent
  1068. # SSRF via crafted responses (e.g. file://, gopher://, internal schemes).
  1069. if not token_endpoint.startswith(("https://", "http://")) or not jwks_uri.startswith(("https://", "http://")):
  1070. logger.warning(
  1071. "OIDC discovery document contains non-HTTP URL(s): token=%s jwks=%s", token_endpoint, jwks_uri
  1072. )
  1073. return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
  1074. # ── Step 2: Exchange authorization code for tokens ───────────────────
  1075. token_form: dict[str, str] = {
  1076. "grant_type": "authorization_code",
  1077. "code": code,
  1078. "redirect_uri": redirect_uri,
  1079. "client_id": provider.client_id,
  1080. }
  1081. if provider.client_secret:
  1082. token_form["client_secret"] = provider.client_secret
  1083. if code_verifier:
  1084. token_form["code_verifier"] = code_verifier
  1085. try:
  1086. async with httpx.AsyncClient(timeout=15) as client:
  1087. token_resp = await client.post(
  1088. token_endpoint,
  1089. data=token_form,
  1090. headers={"Accept": "application/json"},
  1091. )
  1092. except Exception as exc:
  1093. logger.error("OIDC token exchange request failed for provider %d: %s", provider_id, exc)
  1094. return RedirectResponse(url=f"{frontend_error_url}token_exchange_network_error", status_code=302)
  1095. if not token_resp.is_success:
  1096. try:
  1097. err_body = token_resp.json()
  1098. oidc_err = err_body.get("error", "")
  1099. oidc_desc = err_body.get("error_description", "")
  1100. except Exception:
  1101. oidc_err = ""
  1102. oidc_desc = token_resp.text[:200]
  1103. logger.error(
  1104. "OIDC token exchange HTTP %d for provider %d. redirect_uri=%r error=%r desc=%r",
  1105. token_resp.status_code,
  1106. provider_id,
  1107. redirect_uri,
  1108. oidc_err,
  1109. oidc_desc,
  1110. )
  1111. # Encode the OIDC error code into the redirect so the user sees it in the toast.
  1112. # URL-encode the value to prevent query-parameter injection from provider responses.
  1113. raw_err = oidc_err[:40] if oidc_err else str(token_resp.status_code)
  1114. safe_err = urllib.parse.quote(raw_err, safe="")
  1115. return RedirectResponse(
  1116. url=f"{frontend_error_url}token_exchange_{safe_err}",
  1117. status_code=302,
  1118. )
  1119. try:
  1120. token_data = token_resp.json()
  1121. except Exception as exc:
  1122. logger.error("OIDC token exchange non-JSON response for provider %d: %s", provider_id, exc)
  1123. return RedirectResponse(url=f"{frontend_error_url}token_exchange_bad_response", status_code=302)
  1124. id_token = token_data.get("id_token")
  1125. if not id_token:
  1126. # Only log the keys present — values may contain secrets (access_token, etc.)
  1127. logger.error(
  1128. "OIDC token response missing id_token for provider %d; keys present: %s",
  1129. provider_id,
  1130. list(token_data.keys()),
  1131. )
  1132. return RedirectResponse(url=f"{frontend_error_url}no_id_token", status_code=302)
  1133. # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
  1134. # Use the issuer from the discovery document as the canonical value (OIDC Core
  1135. # §3.1.3.7 requires iss == discovery issuer exactly). We strip trailing slashes
  1136. # from both sides because some providers (e.g. Authentik, older PocketID versions)
  1137. # are inconsistent between the discovery issuer and the JWT iss claim.
  1138. discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
  1139. try:
  1140. async with httpx.AsyncClient(timeout=10) as jwks_http:
  1141. jwks_resp = await jwks_http.get(jwks_uri)
  1142. jwks_resp.raise_for_status()
  1143. jwks_data = jwks_resp.json()
  1144. jwks_client = PyJWKClient(jwks_uri)
  1145. jwks_client.fetch_data = lambda: jwks_data # type: ignore[method-assign]
  1146. signing_key = jwks_client.get_signing_key_from_jwt(id_token)
  1147. # M-3: Decode without built-in issuer check, then compare normalised
  1148. # (both sides rstrip("/")) to handle providers like Authentik that include
  1149. # a trailing slash in iss but not in the discovery issuer, or vice-versa.
  1150. claims = jwt.decode(
  1151. id_token,
  1152. signing_key.key,
  1153. algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
  1154. audience=provider.client_id,
  1155. options={"verify_iss": False},
  1156. )
  1157. token_iss = claims.get("iss", "").rstrip("/")
  1158. if token_iss != discovery_issuer:
  1159. raise jwt.exceptions.InvalidIssuerError("Invalid issuer")
  1160. except Exception as exc:
  1161. logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
  1162. return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)
  1163. # Verify nonce — fail closed: we always send a nonce, so the provider must echo it.
  1164. # Skipping the check when nonce is absent would allow CSRF on non-nonce providers.
  1165. token_nonce = claims.get("nonce")
  1166. if token_nonce is None or token_nonce != nonce:
  1167. logger.warning("OIDC nonce mismatch for provider %d (present=%r)", provider_id, token_nonce is not None)
  1168. return RedirectResponse(url=f"{frontend_error_url}nonce_mismatch", status_code=302)
  1169. provider_sub: str = claims.get("sub", "")
  1170. if not provider_sub:
  1171. return RedirectResponse(url=f"{frontend_error_url}missing_sub_claim", status_code=302)
  1172. # C1: Only trust the email claim when the provider explicitly marks it verified.
  1173. # Treating absent email_verified as verified enables account-takeover: an attacker
  1174. # could register an unverified email with an IdP and auto-link to an existing account.
  1175. # Fail closed: require email_verified == True; absent/False both drop the email.
  1176. raw_email: str | None = claims.get("email")
  1177. email_verified = claims.get("email_verified")
  1178. if email_verified is not True:
  1179. if raw_email:
  1180. logger.info(
  1181. "OIDC provider %d: ignoring email for sub=%r because email_verified=%r",
  1182. provider_id,
  1183. provider_sub,
  1184. email_verified,
  1185. )
  1186. provider_email: str | None = None
  1187. else:
  1188. provider_email = raw_email
  1189. # ── Step 4: Resolve / create user ────────────────────────────────────
  1190. try:
  1191. # 1. Look up existing OIDC link
  1192. link_result = await db.execute(
  1193. select(UserOIDCLink)
  1194. .where(UserOIDCLink.provider_id == provider_id)
  1195. .where(UserOIDCLink.provider_user_id == provider_sub)
  1196. )
  1197. link = link_result.scalar_one_or_none()
  1198. user: User | None = None
  1199. if link:
  1200. # Existing link → load the linked user
  1201. user_result = await db.execute(
  1202. select(User).where(User.id == link.user_id).options(selectinload(User.groups))
  1203. )
  1204. user = user_result.scalar_one_or_none()
  1205. else:
  1206. # 2. No OIDC link yet — check for an existing user with the same email.
  1207. # Use case-insensitive matching (func.lower) so that "User@Example.com"
  1208. # and "user@example.com" are treated as the same identity, preventing
  1209. # an attacker-controlled IdP from bypassing the auto-link guard by
  1210. # registering the target email with different casing.
  1211. email_user: User | None = None
  1212. if provider_email:
  1213. email_user = await get_user_by_email(db, provider_email)
  1214. if email_user and provider.auto_link_existing_accounts:
  1215. # M-4: Only auto-link when the provider has auto_link_existing_accounts
  1216. # enabled. Operators can disable this to require explicit account linking,
  1217. # preventing an attacker-controlled IdP from hijacking local accounts.
  1218. #
  1219. # M-NEW-6: Refuse auto-link if the target user already has any OIDC
  1220. # link (to any provider). Without this guard an attacker who controls
  1221. # a second OIDC provider with auto_link enabled could add themselves as
  1222. # a second IdP for a user that already authenticates via a legitimate
  1223. # provider, effectively taking over the account.
  1224. existing_links_result = await db.execute(
  1225. select(UserOIDCLink).where(UserOIDCLink.user_id == email_user.id)
  1226. )
  1227. has_existing_oidc_link = existing_links_result.scalar_one_or_none() is not None
  1228. if has_existing_oidc_link:
  1229. logger.warning(
  1230. "Auto-link rejected for user '%s': already linked to another OIDC provider",
  1231. email_user.username,
  1232. )
  1233. return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
  1234. db.add(
  1235. UserOIDCLink(
  1236. user_id=email_user.id,
  1237. provider_id=provider_id,
  1238. provider_user_id=provider_sub,
  1239. provider_email=provider_email,
  1240. )
  1241. )
  1242. await db.commit()
  1243. user = email_user
  1244. logger.info(
  1245. "Auto-linked existing user '%s' to OIDC provider %d via email match",
  1246. email_user.username,
  1247. provider_id,
  1248. )
  1249. elif provider.auto_create_users:
  1250. # 3. No existing user — create one
  1251. if provider_email:
  1252. raw = provider_email.split("@")[0]
  1253. else:
  1254. raw = provider_sub[:30]
  1255. candidate = re.sub(r"[^a-zA-Z0-9._-]", "", raw)[:30] or "oidcuser"
  1256. username = candidate
  1257. counter = 1
  1258. while True:
  1259. existing = await get_user_by_username(db, username)
  1260. if not existing:
  1261. break
  1262. username = f"{candidate}{counter}"
  1263. counter += 1
  1264. # I9: Assign new OIDC users to the default "Viewers" group so they
  1265. # have read-only access rather than starting with no permissions.
  1266. # Fetch the group BEFORE creating the user so we can set the
  1267. # relationship before flush — accessing new_user.groups after a
  1268. # flush triggers a lazy-load which fails in async context.
  1269. viewers_result = await db.execute(select(Group).where(Group.name == "Viewers"))
  1270. viewers_group = viewers_result.scalar_one_or_none()
  1271. new_user = User(
  1272. username=username,
  1273. email=provider_email,
  1274. # M-1: auth_source="oidc" prevents local password-reset flow
  1275. # for users who should only authenticate via OIDC.
  1276. auth_source="oidc",
  1277. password_hash=None, # OIDC users never use password auth
  1278. role="user",
  1279. is_active=True,
  1280. groups=[viewers_group] if viewers_group else [],
  1281. )
  1282. db.add(new_user)
  1283. await db.flush()
  1284. db.add(
  1285. UserOIDCLink(
  1286. user_id=new_user.id,
  1287. provider_id=provider_id,
  1288. provider_user_id=provider_sub,
  1289. provider_email=provider_email,
  1290. )
  1291. )
  1292. await db.commit()
  1293. user_result = await db.execute(
  1294. select(User).where(User.id == new_user.id).options(selectinload(User.groups))
  1295. )
  1296. user = user_result.scalar_one()
  1297. logger.info("Auto-created user '%s' via OIDC provider %d", username, provider_id)
  1298. else:
  1299. return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
  1300. if not user or not user.is_active:
  1301. return RedirectResponse(url=f"{frontend_error_url}account_inactive", status_code=302)
  1302. # Issue an OIDC exchange token (short-lived, single-use) stored in DB.
  1303. # I7: Opportunistically prune expired exchange tokens to keep the table small.
  1304. now2 = datetime.now(timezone.utc)
  1305. await db.execute(
  1306. delete(AuthEphemeralToken).where(
  1307. AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
  1308. AuthEphemeralToken.expires_at < now2,
  1309. )
  1310. )
  1311. exchange_token = secrets.token_urlsafe(32)
  1312. db.add(
  1313. AuthEphemeralToken(
  1314. token=exchange_token,
  1315. token_type=TokenType.OIDC_EXCHANGE,
  1316. username=user.username,
  1317. expires_at=now2 + OIDC_EXCHANGE_TTL,
  1318. )
  1319. )
  1320. await db.commit()
  1321. # H-4: Use a URL fragment (#) instead of a query parameter so the exchange
  1322. # token is never sent to the server in the Referer header or server logs.
  1323. return RedirectResponse(url=f"{external_url}/login#oidc_token={exchange_token}", status_code=302)
  1324. except Exception as exc:
  1325. logger.error("OIDC user resolution failed for provider %d: %s", provider_id, exc, exc_info=True)
  1326. try:
  1327. await db.rollback()
  1328. except Exception as rb_exc:
  1329. logger.error("DB rollback failed after OIDC user-resolution error: %s", rb_exc, exc_info=True)
  1330. return RedirectResponse(url=f"{frontend_error_url}user_resolution_failed", status_code=302)
  1331. except Exception as exc:
  1332. # L-1: Log the exception class name internally but never expose it in the
  1333. # redirect URL — leaking exception names aids attacker reconnaissance.
  1334. logger.error("Unexpected error in OIDC callback (%s): %s", type(exc).__name__, exc, exc_info=True)
  1335. try:
  1336. return RedirectResponse(url=f"{frontend_error_url}internal_error", status_code=302)
  1337. except Exception:
  1338. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OIDC callback failed")
  1339. @router.post("/oidc/exchange", response_model=LoginResponse)
  1340. async def oidc_exchange(
  1341. body: OIDCExchangeRequest,
  1342. raw_request: Request,
  1343. response: Response,
  1344. db: AsyncSession = Depends(get_db),
  1345. ) -> LoginResponse:
  1346. """Exchange an OIDC exchange token (from the callback redirect) for a full JWT.
  1347. C4: If the resolved user has 2FA enabled the exchange returns a pre_auth_token
  1348. (requires_2fa=True) instead of a full JWT. The frontend must then complete the
  1349. 2FA step exactly as it would after a password-based login.
  1350. """
  1351. now = datetime.now(timezone.utc)
  1352. # Atomically consume the exchange token (DELETE...RETURNING prevents replay).
  1353. consume_result = await db.execute(
  1354. delete(AuthEphemeralToken)
  1355. .where(
  1356. AuthEphemeralToken.token == body.oidc_token,
  1357. AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
  1358. AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
  1359. )
  1360. .returning(AuthEphemeralToken.username)
  1361. )
  1362. row = consume_result.one_or_none()
  1363. if row is None:
  1364. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired OIDC exchange token")
  1365. (username,) = row
  1366. await db.commit()
  1367. user = await get_user_by_username(db, username)
  1368. if not user or not user.is_active:
  1369. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
  1370. # Reload with groups
  1371. result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
  1372. user = result.scalar_one()
  1373. # C4: Check whether the user has any 2FA method enabled.
  1374. totp_result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
  1375. totp_record = totp_result.scalar_one_or_none()
  1376. totp_enabled = totp_record is not None and totp_record.is_enabled
  1377. email_2fa_enabled = await _get_email_2fa_enabled(db, user.id)
  1378. if totp_enabled or email_2fa_enabled:
  1379. # User has 2FA — issue a pre_auth_token bound to this browser session via
  1380. # an HttpOnly cookie (H-A: mirrors the cookie-binding done in auth.py:login).
  1381. two_fa_methods: list[str] = []
  1382. if totp_enabled:
  1383. two_fa_methods.append("totp")
  1384. if email_2fa_enabled:
  1385. two_fa_methods.append("email")
  1386. if totp_enabled:
  1387. two_fa_methods.append("backup")
  1388. challenge_id = secrets.token_urlsafe(32)
  1389. pre_auth_token = await create_pre_auth_token(db, user.username, challenge_id=challenge_id)
  1390. response.set_cookie(
  1391. key="2fa_challenge",
  1392. value=challenge_id,
  1393. httponly=True,
  1394. secure=raw_request.url.scheme == "https",
  1395. samesite="lax",
  1396. max_age=300,
  1397. path="/api/v1/auth/2fa",
  1398. )
  1399. return LoginResponse(
  1400. requires_2fa=True,
  1401. pre_auth_token=pre_auth_token,
  1402. two_fa_methods=two_fa_methods,
  1403. user=_user_to_response(user),
  1404. )
  1405. access_token = create_access_token(
  1406. data={"sub": user.username},
  1407. expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
  1408. )
  1409. return LoginResponse(
  1410. access_token=access_token,
  1411. token_type="bearer",
  1412. user=_user_to_response(user),
  1413. requires_2fa=False,
  1414. )
  1415. @router.get("/oidc/links", response_model=list[OIDCLinkResponse])
  1416. async def list_oidc_links(
  1417. current_user: User = Depends(get_current_active_user),
  1418. db: AsyncSession = Depends(get_db),
  1419. ) -> list[OIDCLinkResponse]:
  1420. """List all OIDC provider links for the current user."""
  1421. result = await db.execute(
  1422. select(UserOIDCLink).where(UserOIDCLink.user_id == current_user.id).options(selectinload(UserOIDCLink.provider))
  1423. )
  1424. links = result.scalars().all()
  1425. return [
  1426. OIDCLinkResponse(
  1427. id=link.id,
  1428. provider_id=link.provider_id,
  1429. provider_name=link.provider.name,
  1430. provider_email=link.provider_email,
  1431. created_at=link.created_at.isoformat(),
  1432. )
  1433. for link in links
  1434. ]
  1435. @router.delete("/oidc/links/{provider_id}")
  1436. async def remove_oidc_link(
  1437. provider_id: int,
  1438. current_user: User = Depends(get_current_active_user),
  1439. db: AsyncSession = Depends(get_db),
  1440. ) -> dict:
  1441. """Remove the OIDC link between the current user and a provider."""
  1442. result = await db.execute(
  1443. select(UserOIDCLink)
  1444. .where(UserOIDCLink.user_id == current_user.id)
  1445. .where(UserOIDCLink.provider_id == provider_id)
  1446. )
  1447. link = result.scalar_one_or_none()
  1448. if not link:
  1449. raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="OIDC link not found")
  1450. await db.delete(link)
  1451. await db.commit()
  1452. return {"message": "OIDC link removed"}
  1453. # ---------------------------------------------------------------------------
  1454. # Internal helpers
  1455. # ---------------------------------------------------------------------------
  1456. async def _get_base_external_url(db: AsyncSession) -> str:
  1457. """Return the base external URL (no trailing slash, no /login suffix)."""
  1458. external_url = await get_setting(db, "external_url")
  1459. if external_url:
  1460. return external_url.rstrip("/")
  1461. return os.environ.get("APP_URL", "http://localhost:5173").rstrip("/")