mfa.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793
"""2FA (TOTP + Email OTP) and OIDC authentication routes.

Security model
--------------
* Pre-auth tokens : secrets.token_urlsafe(32) stored in the database
  (AuthEphemeralToken rows) with a 5-minute TTL. They are single-use and do
  NOT grant access to any protected resource.
* TOTP codes      : verified with pyotp (30-second window, ±1 step tolerance).
* Email OTP codes : 6-digit numeric, hashed with pbkdf2_sha256, 10-minute TTL,
  max 5 failed attempts per code before invalidation.
* Backup codes    : 10 × 8-char uppercase alphanumeric codes, each stored as a
  pbkdf2_sha256 hash, single-use.
* OIDC state      : secrets.token_urlsafe(32) bound to provider_id + nonce,
  10-minute TTL.
* OIDC exchange   : secrets.token_urlsafe(32), 2-minute TTL, single-use.
* Rate limiting   : max 5 failed 2FA verification attempts per user within
  15 minutes.
"""
  15. from __future__ import annotations
  16. import base64
  17. import hashlib
  18. import io
  19. import logging
  20. import os
  21. import re
  22. import secrets
  23. import string
  24. import urllib.parse
  25. from datetime import datetime, timedelta, timezone
  26. import httpx
  27. import jwt
  28. import pyotp
  29. from fastapi import APIRouter, Body, Depends, HTTPException, Query, Request, Response, status
  30. from fastapi.responses import RedirectResponse
  31. from jwt import PyJWKClient
  32. from passlib.context import CryptContext
  33. from sqlalchemy import delete, select
  34. from sqlalchemy.ext.asyncio import AsyncSession
  35. from sqlalchemy.orm import selectinload
  36. from backend.app.api.routes.settings import get_setting, set_setting
  37. from backend.app.core.auth import (
  38. ACCESS_TOKEN_EXPIRE_MINUTES,
  39. RequirePermissionIfAuthEnabled,
  40. create_access_token,
  41. get_current_active_user,
  42. get_user_by_email,
  43. get_user_by_username,
  44. is_auth_enabled,
  45. verify_password,
  46. )
  47. from backend.app.core.database import get_db
  48. from backend.app.core.permissions import Permission
  49. from backend.app.models.auth_ephemeral import AuthEphemeralToken, AuthRateLimitEvent, EventType, TokenType
  50. from backend.app.models.group import Group
  51. from backend.app.models.oidc_provider import OIDCProvider, UserOIDCLink
  52. from backend.app.models.user import User
  53. from backend.app.models.user_otp_code import UserOTPCode
  54. from backend.app.models.user_totp import UserTOTP
  55. from backend.app.schemas.auth import (
  56. AUTO_LINK_REQUIREMENTS_ERROR,
  57. AdminDisable2FARequest,
  58. BackupCodesResponse,
  59. EmailOTPDisableRequest,
  60. EmailOTPEnableConfirmRequest,
  61. EmailOTPSendRequest,
  62. GroupBrief,
  63. LoginResponse,
  64. OIDCAuthorizeResponse,
  65. OIDCExchangeRequest,
  66. OIDCLinkResponse,
  67. OIDCProviderCreate,
  68. OIDCProviderResponse,
  69. OIDCProviderUpdate,
  70. TOTPDisableRequest,
  71. TOTPEnableRequest,
  72. TOTPEnableResponse,
  73. TOTPSetupRequest,
  74. TOTPSetupResponse,
  75. TwoFAStatusResponse,
  76. TwoFAVerifyRequest,
  77. TwoFAVerifyResponse,
  78. UserResponse,
  79. )
  80. from backend.app.services.email_service import get_smtp_settings, send_email
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
  82. def _as_utc(dt: datetime) -> datetime:
  83. """Return *dt* with UTC timezone attached.
  84. SQLite/aiosqlite strips timezone info when reading DateTime(timezone=True)
  85. columns back – the stored value is always UTC, so we just re-attach the
  86. info when doing Python-level comparisons.
  87. """
  88. return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
# ---------------------------------------------------------------------------
# Passlib context (same scheme as auth.py)
# ---------------------------------------------------------------------------
# Used in this module to hash backup codes (_generate_backup_codes) and to
# verify them (disable_totp).
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# ---------------------------------------------------------------------------
# TTL / rate-limit constants
# ---------------------------------------------------------------------------
# Default threshold of check_rate_limit: failed 2FA verifications tolerated
# per user inside LOCKOUT_WINDOW.
MAX_2FA_ATTEMPTS = 5
# Not referenced in this part of the file; presumably passed to
# check_rate_limit by the login route — confirm.
MAX_LOGIN_ATTEMPTS = 10
# Sliding window over which failed attempts are counted by check_rate_limit.
LOCKOUT_WINDOW = timedelta(minutes=15)
# check_email_otp_send_rate allows at most MAX_EMAIL_OTP_SENDS recorded sends
# within EMAIL_OTP_SEND_WINDOW.
MAX_EMAIL_OTP_SENDS = 3
EMAIL_OTP_SEND_WINDOW = timedelta(minutes=10)
# Lifetime of a pre-auth token issued by create_pre_auth_token.
PRE_AUTH_TOKEN_TTL = timedelta(minutes=5)
# OIDC state / exchange token lifetimes (used outside this excerpt — confirm).
OIDC_STATE_TTL = timedelta(minutes=10)
OIDC_EXCHANGE_TTL = timedelta(minutes=2)
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
# Every route registered on this router is mounted under the /auth prefix.
router = APIRouter(prefix="/auth", tags=["2fa", "oidc"])
  108. # ---------------------------------------------------------------------------
  109. # Helper: user response
  110. # ---------------------------------------------------------------------------
  111. def _user_to_response(user: User) -> UserResponse:
  112. return UserResponse(
  113. id=user.id,
  114. username=user.username,
  115. email=user.email,
  116. role=user.role,
  117. is_active=user.is_active,
  118. is_admin=user.is_admin,
  119. groups=[GroupBrief(id=g.id, name=g.name) for g in user.groups],
  120. permissions=sorted(user.get_permissions()),
  121. created_at=user.created_at.isoformat(),
  122. )
  123. # ---------------------------------------------------------------------------
  124. # Helper: QR code generation
  125. # ---------------------------------------------------------------------------
  126. def _generate_totp_qr_b64(provisioning_uri: str) -> str:
  127. """Generate a base64-encoded PNG QR code for the given TOTP provisioning URI."""
  128. import qrcode # type: ignore
  129. qr = qrcode.QRCode(box_size=6, border=2)
  130. qr.add_data(provisioning_uri)
  131. qr.make(fit=True)
  132. img = qr.make_image(fill_color="black", back_color="white")
  133. buf = io.BytesIO()
  134. img.save(buf, format="PNG")
  135. return base64.b64encode(buf.getvalue()).decode()
  136. # ---------------------------------------------------------------------------
  137. # Helper: backup code generation
  138. # ---------------------------------------------------------------------------
  139. def _generate_backup_codes() -> tuple[list[str], list[str]]:
  140. """Return (plain_codes, hashed_codes) — 10 codes of 8 alphanumeric chars each."""
  141. alphabet = string.ascii_uppercase + string.digits
  142. plain = ["".join(secrets.choice(alphabet) for _ in range(8)) for _ in range(10)]
  143. hashed = [pwd_context.hash(c) for c in plain]
  144. return plain, hashed
  145. # ---------------------------------------------------------------------------
  146. # DB-backed pre-auth token helpers
  147. # ---------------------------------------------------------------------------
  148. async def create_pre_auth_token(db: AsyncSession, username: str, challenge_id: str | None = None) -> str:
  149. """Create a single-use pre-auth token stored in the DB.
  150. Pass ``challenge_id`` (from the HttpOnly 2fa_challenge cookie) to bind the
  151. token to the originating browser session. The same value must be present as
  152. a cookie on every subsequent call that consumes this token.
  153. """
  154. now = datetime.now(timezone.utc)
  155. # Prune expired tokens opportunistically (keep table small)
  156. await db.execute(
  157. delete(AuthEphemeralToken).where(
  158. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  159. AuthEphemeralToken.expires_at < now,
  160. )
  161. )
  162. token = secrets.token_urlsafe(32)
  163. db.add(
  164. AuthEphemeralToken(
  165. token=token,
  166. token_type=TokenType.PRE_AUTH,
  167. username=username,
  168. challenge_id=challenge_id,
  169. expires_at=now + PRE_AUTH_TOKEN_TTL,
  170. )
  171. )
  172. await db.commit()
  173. return token
  174. async def consume_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
  175. """Atomically validate and consume a pre-auth token. Returns username or None.
  176. Uses DELETE...RETURNING so two concurrent requests with the same token cannot
  177. both succeed — only the first DELETE finds the row.
  178. M5: When challenge_id is provided, also enforces the cookie-binding constraint
  179. so a stolen token cannot be replayed from a different browser session.
  180. """
  181. now = datetime.now(timezone.utc)
  182. result = await db.execute(
  183. delete(AuthEphemeralToken)
  184. .where(
  185. AuthEphemeralToken.token == token,
  186. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  187. AuthEphemeralToken.expires_at > now,
  188. )
  189. .returning(AuthEphemeralToken.username, AuthEphemeralToken.challenge_id)
  190. )
  191. row = result.one_or_none()
  192. if row is None:
  193. return None
  194. username, stored_challenge_id = row
  195. # Enforce client binding: if the token was issued with a challenge_id,
  196. # the caller must supply the matching value.
  197. if stored_challenge_id is not None and stored_challenge_id != challenge_id:
  198. await db.rollback()
  199. return None
  200. await db.commit()
  201. return username
  202. async def peek_pre_auth_token(db: AsyncSession, token: str, challenge_id: str | None = None) -> str | None:
  203. """Validate a pre-auth token and return the username WITHOUT consuming it.
  204. When the stored token has a ``challenge_id`` (client-binding cookie), the
  205. caller must supply the matching value. A mismatch is treated as an invalid
  206. token — no information leakage about whether the token itself exists.
  207. """
  208. now = datetime.now(timezone.utc)
  209. result = await db.execute(
  210. select(AuthEphemeralToken).where(
  211. AuthEphemeralToken.token == token,
  212. AuthEphemeralToken.token_type == TokenType.PRE_AUTH,
  213. AuthEphemeralToken.expires_at > now,
  214. )
  215. )
  216. eph = result.scalar_one_or_none()
  217. if eph is None:
  218. return None
  219. # Enforce client binding: if the token was issued with a challenge_id the
  220. # cookie must match. Treat a mismatch as if the token doesn't exist.
  221. if eph.challenge_id is not None and eph.challenge_id != challenge_id:
  222. return None
  223. return eph.username
  224. # ---------------------------------------------------------------------------
  225. # DB-backed rate-limiting helpers
  226. # ---------------------------------------------------------------------------
  227. async def check_rate_limit(
  228. db: AsyncSession,
  229. username: str,
  230. event_type: str = EventType.TWO_FA_ATTEMPT,
  231. max_attempts: int = MAX_2FA_ATTEMPTS,
  232. ) -> None:
  233. """Raise HTTP 429 if the user has exceeded the failed attempt limit.
  234. The username is normalised to lower-case so case-variant attempts
  235. (which all resolve to the same user) share the same rate-limit bucket.
  236. L-2: Known TOCTOU — the SELECT (count) and the subsequent INSERT
  237. (record_failed_attempt) are not atomic. Two concurrent requests can both
  238. read a count below the threshold and both proceed. This is an inherent
  239. trade-off of the event-log rate-limit pattern: fixing it would require
  240. a serialising lock (SELECT FOR UPDATE on a dedicated counter row), which
  241. adds contention and is not worth it for a soft rate-limit whose window is
  242. already measured in minutes. In practice the race window is microseconds
  243. and the limit can be slightly exceeded only under precise concurrent timing.
  244. """
  245. username_key = username.lower()
  246. now = datetime.now(timezone.utc)
  247. cutoff = now - LOCKOUT_WINDOW
  248. result = await db.execute(
  249. select(AuthRateLimitEvent).where(
  250. AuthRateLimitEvent.username == username_key,
  251. AuthRateLimitEvent.event_type == event_type,
  252. AuthRateLimitEvent.occurred_at > cutoff,
  253. )
  254. )
  255. recent_count = len(result.scalars().all())
  256. if recent_count >= max_attempts:
  257. raise HTTPException(
  258. status_code=status.HTTP_429_TOO_MANY_REQUESTS,
  259. detail="Too many failed attempts. Please try again later.",
  260. )
  261. async def record_failed_attempt(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
  262. """Record a failed attempt for rate-limiting purposes."""
  263. db.add(AuthRateLimitEvent(username=username.lower(), event_type=event_type))
  264. await db.commit()
  265. async def clear_failed_attempts(db: AsyncSession, username: str, event_type: str = EventType.TWO_FA_ATTEMPT) -> None:
  266. """Delete all recorded failed attempts for a user on successful verification."""
  267. await db.execute(
  268. delete(AuthRateLimitEvent).where(
  269. AuthRateLimitEvent.username == username.lower(),
  270. AuthRateLimitEvent.event_type == event_type,
  271. )
  272. )
  273. await db.commit()
  274. async def check_email_otp_send_rate(db: AsyncSession, username: str) -> None:
  275. """Raise HTTP 429 if the user has requested too many OTP emails recently.
  276. I1: This function only *checks* the limit. The caller is responsible for
  277. recording the slot via ``record_email_otp_send`` **after** the email has
  278. been sent successfully. This prevents failed sends from consuming a slot
  279. (wasting the user's quota) and makes it impossible to farm rate-limit events
  280. without actually triggering a send.
  281. """
  282. username_key = username.lower()
  283. now = datetime.now(timezone.utc)
  284. cutoff = now - EMAIL_OTP_SEND_WINDOW
  285. result = await db.execute(
  286. select(AuthRateLimitEvent).where(
  287. AuthRateLimitEvent.username == username_key,
  288. AuthRateLimitEvent.event_type == EventType.EMAIL_SEND,
  289. AuthRateLimitEvent.occurred_at > cutoff,
  290. )
  291. )
  292. recent_count = len(result.scalars().all())
  293. if recent_count >= MAX_EMAIL_OTP_SENDS:
  294. raise HTTPException(
  295. status_code=status.HTTP_429_TOO_MANY_REQUESTS,
  296. detail=f"Too many OTP email requests. Please wait {EMAIL_OTP_SEND_WINDOW.seconds // 60} minutes.",
  297. )
  298. async def record_email_otp_send(db: AsyncSession, username: str) -> None:
  299. """Record a successful OTP email send for rate-limiting purposes (I1).
  300. Must be called *after* the email has been sent successfully so that failed
  301. sends do not consume a slot from the user's quota.
  302. """
  303. db.add(AuthRateLimitEvent(username=username.lower(), event_type=EventType.EMAIL_SEND))
  304. await db.commit()
  305. # ---------------------------------------------------------------------------
  306. # TOTP replay-protection helper
  307. # ---------------------------------------------------------------------------
  308. def _assert_totp_not_replayed(totp_obj: pyotp.TOTP, totp_record: UserTOTP, code: str) -> None:
  309. """Raise HTTP 400 if this TOTP code was already accepted in its time window.
  310. M3 fix: store the counter of the *accepted* code rather than the current
  311. wall-clock counter. With valid_window=1, pyotp accepts codes from the
  312. previous 30-second step. Using timecode(now) would store the wrong counter
  313. when the previous-window code is accepted, allowing immediate replay.
  314. """
  315. # Determine which time-step the accepted code belongs to.
  316. now = datetime.now(timezone.utc)
  317. accepted_counter: int | None = None
  318. for offset in (0, -1): # current window first, then previous
  319. candidate_time = now.timestamp() + offset * totp_obj.interval
  320. candidate_counter = totp_obj.timecode(datetime.fromtimestamp(candidate_time, tz=timezone.utc))
  321. if totp_obj.at(candidate_counter) == code:
  322. accepted_counter = candidate_counter
  323. break
  324. if accepted_counter is None:
  325. accepted_counter = totp_obj.timecode(now) # fallback (should not happen after verify())
  326. totp_record.accept_counter(accepted_counter)
  327. # ---------------------------------------------------------------------------
  328. # OIDC helpers
  329. # ---------------------------------------------------------------------------
  330. _EMAIL_SHAPE_RE = re.compile(r"[^\s@]+@[^\s@]+\.[^\s@]+")
  331. def _is_valid_email_shaped(value: str | None) -> bool:
  332. # SEC-2: shape check for non-standard claims (upn, preferred_username).
  333. # Requires local@domain.tld — rejects "@", "x@", "@domain", "x@nodot".
  334. if not value or len(value) > 255:
  335. return False
  336. return _EMAIL_SHAPE_RE.fullmatch(value) is not None
  337. def _enforce_auto_link_safety(provider: OIDCProvider) -> None:
  338. """Raise HTTP 422 if auto_link_existing_accounts is on without safe email settings.
  339. SEC-1/SEC-6: auto_link is only safe when require_email_verified=True AND
  340. email_claim='email'. Called after ORM construction (create) and after the
  341. setattr loop (update) so partial-update bypasses are also caught.
  342. """
  343. if provider.auto_link_existing_accounts and (
  344. not provider.require_email_verified or provider.email_claim != "email"
  345. ):
  346. raise HTTPException(
  347. status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
  348. detail=AUTO_LINK_REQUIREMENTS_ERROR,
  349. )
  350. def _resolve_provider_email(provider: OIDCProvider, claims: dict, provider_sub: str) -> str | None:
  351. """Extract and normalise the email address from OIDC ID-token claims.
  352. Implements three resolution paths (Fall A/B/C):
  353. Fall C — custom email_claim (!= "email"): shape-check only, no email_verified gate.
  354. Recommended for Azure Entra ID (preferred_username or upn).
  355. Fall A — email_claim="email" + require_email_verified=True: strict, email_verified must be True.
  356. Fall B — email_claim="email" + require_email_verified=False: permissive, explicit False drops email.
  357. Returns a lowercase-stripped email string, or None when the claim is absent/invalid.
  358. """
  359. provider_id = provider.id
  360. raw_claim_value = claims.get(provider.email_claim)
  361. if raw_claim_value is not None and not isinstance(raw_claim_value, str):
  362. # TYPE-GUARD: non-string claim (e.g. list, int) would raise AttributeError on .lower().
  363. logger.warning(
  364. "OIDC provider %d: email_claim %r has unexpected type %s for sub=%r, ignoring",
  365. provider_id,
  366. provider.email_claim,
  367. type(raw_claim_value).__name__,
  368. provider_sub,
  369. )
  370. raw_claim_value = None
  371. raw_email: str | None = raw_claim_value.lower().strip() if raw_claim_value else None
  372. if provider.email_claim != "email":
  373. # Fall C: custom claim (preferred_username, upn, …) — no email_verified check.
  374. # SEC-2: _is_valid_email_shaped instead of bare '"@" in value'.
  375. # Recommended for Azure Entra ID: set email_claim="preferred_username" or "upn".
  376. if raw_email and _is_valid_email_shaped(raw_email):
  377. return raw_email
  378. if raw_email:
  379. logger.info(
  380. "OIDC provider %d: email_claim %r value failed shape check for sub=%r, ignoring",
  381. provider_id,
  382. provider.email_claim,
  383. provider_sub,
  384. )
  385. return None
  386. email_verified = claims.get("email_verified")
  387. if provider.require_email_verified:
  388. # Fall A: standard C1-Guard — fail closed unless email_verified is True.
  389. # SEC-3 normalisation applies; existing mixed-case provider_email records
  390. # were normalised to lowercase by run_migrations at startup.
  391. if email_verified is True:
  392. return raw_email
  393. if raw_email:
  394. logger.info(
  395. "OIDC provider %d: ignoring email for sub=%r because email_verified=%r",
  396. provider_id,
  397. provider_sub,
  398. email_verified,
  399. )
  400. return None
  401. # Fall B: permissive — explicit False drops email, absent/None keeps it.
  402. # Required for Azure Entra ID which never sends email_verified.
  403. if email_verified is False:
  404. return None
  405. if email_verified is not True:
  406. # SEC-5: log only when the permissive path actually fires (ev absent/None),
  407. # not on every successful login.
  408. logger.info(
  409. "OIDC provider %r (%d): accepting email for sub=%r without email_verified claim (permissive mode)",
  410. provider.name,
  411. provider.id,
  412. provider_sub,
  413. )
  414. return raw_email
  415. # ---------------------------------------------------------------------------
  416. # Settings helpers (email 2FA flag)
  417. # ---------------------------------------------------------------------------
  418. async def _get_email_2fa_enabled(db: AsyncSession, user_id: int) -> bool:
  419. val = await get_setting(db, f"user_{user_id}_email_2fa_enabled")
  420. return val == "true"
  421. async def _set_email_2fa_enabled(db: AsyncSession, user_id: int, enabled: bool) -> None:
  422. await set_setting(db, f"user_{user_id}_email_2fa_enabled", "true" if enabled else "false")
  423. # ===========================================================================
  424. # 2FA Endpoints
  425. # ===========================================================================
  426. @router.get("/2fa/status", response_model=TwoFAStatusResponse)
  427. async def get_2fa_status(
  428. current_user: User = Depends(get_current_active_user),
  429. db: AsyncSession = Depends(get_db),
  430. ) -> TwoFAStatusResponse:
  431. """Return the current 2FA configuration for the authenticated user."""
  432. result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
  433. totp_record = result.scalar_one_or_none()
  434. totp_enabled = totp_record is not None and totp_record.is_enabled
  435. backup_codes_remaining = len(totp_record.backup_code_hashes) if totp_record else 0
  436. email_otp_enabled = await _get_email_2fa_enabled(db, current_user.id)
  437. return TwoFAStatusResponse(
  438. totp_enabled=totp_enabled,
  439. email_otp_enabled=email_otp_enabled,
  440. backup_codes_remaining=backup_codes_remaining,
  441. )
  442. @router.post("/2fa/totp/setup", response_model=TOTPSetupResponse)
  443. async def setup_totp(
  444. body: TOTPSetupRequest | None = Body(default=None),
  445. current_user: User = Depends(get_current_active_user),
  446. db: AsyncSession = Depends(get_db),
  447. ) -> TOTPSetupResponse:
  448. """Initiate TOTP setup: generates a new secret and QR code.
  449. Creates (or replaces) a pending UserTOTP record with is_enabled=False.
  450. The caller must confirm with POST /auth/2fa/totp/enable.
  451. M-R7-A: If an *active* TOTP is already configured, the caller must supply
  452. the current TOTP code in the request body to confirm intent before the
  453. secret is overwritten (prevents silently locking out the real user).
  454. """
  455. if not await is_auth_enabled(db):
  456. raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Authentication is not enabled")
  457. # Upsert a pending TOTP record (is_enabled=False)
  458. existing = (await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))).scalar_one_or_none()
  459. # M-R7-A: Guard against silent TOTP replacement when one is already active.
  460. if existing and existing.is_enabled:
  461. await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  462. supplied_code = (body.code if body else None) or ""
  463. if not pyotp.TOTP(existing.secret).verify(supplied_code, valid_window=1):
  464. await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  465. raise HTTPException(
  466. status_code=status.HTTP_400_BAD_REQUEST,
  467. detail="Current TOTP code required to replace an active authenticator",
  468. )
  469. await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
  470. _assert_totp_not_replayed(pyotp.TOTP(existing.secret), existing, supplied_code)
  471. await db.flush() # L-3: persist last_totp_counter immediately to block replay
  472. secret = pyotp.random_base32()
  473. totp = pyotp.TOTP(secret)
  474. provisioning_uri = totp.provisioning_uri(name=current_user.username, issuer_name="Bambuddy")
  475. qr_b64 = _generate_totp_qr_b64(provisioning_uri)
  476. if existing:
  477. existing.secret = secret
  478. existing.is_enabled = False
  479. existing.backup_code_hashes = []
  480. else:
  481. db.add(UserTOTP(user_id=current_user.id, secret=secret, is_enabled=False))
  482. await db.commit()
  483. return TOTPSetupResponse(secret=secret, qr_code_b64=qr_b64, issuer="Bambuddy")
  @router.post("/2fa/totp/enable", response_model=TOTPEnableResponse)
  async def enable_totp(
      body: TOTPEnableRequest,
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> TOTPEnableResponse:
      """Activate TOTP once the user proves their authenticator app produces valid codes.

      Requires a UserTOTP row created by the setup endpoint. When ``body.code``
      verifies against the stored secret (``valid_window=1``: one time step of
      drift either way is tolerated), the row is marked enabled and 10 new
      single-use backup codes are generated. The plain codes appear only in this
      response; only their hashes are stored on the record.

      L-R7-A: guarded by the TWO_FA_ATTEMPT rate limit so the 6-digit
      confirmation code cannot be brute-forced.

      Raises:
          HTTPException 400: setup was never started, or the code is wrong.
      """
      # L-R7-A: bail out before any work if the attempt budget is exhausted.
      await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)

      lookup = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
      enrolment = lookup.scalar_one_or_none()
      if not enrolment:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="TOTP setup not initiated. Call /auth/2fa/totp/setup first.",
          )

      authenticator = pyotp.TOTP(enrolment.secret)
      if not authenticator.verify(body.code, valid_window=1):
          await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP code")

      await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
      shown_once, stored_hashes = _generate_backup_codes()
      enrolment.is_enabled = True
      enrolment.backup_code_hashes = stored_hashes
      await db.commit()
      return TOTPEnableResponse(
          message="TOTP enabled successfully. Store your backup codes in a safe place.",
          backup_codes=shown_once,
      )
  @router.post("/2fa/totp/disable")
  async def disable_totp(
      body: TOTPDisableRequest,
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Turn off TOTP for the caller after re-verifying a second factor.

      ``body.code`` may be either a current TOTP code or any one of the user's
      backup codes. On success the user's UserTOTP row (secret and backup-code
      hashes) is deleted.

      I10: rate-limited so a hijacked session cannot brute-force backup codes here.

      Raises:
          HTTPException 400: TOTP is not enabled, or the code matches neither a
              TOTP code nor a backup code.
      """
      await check_rate_limit(db, current_user.username)
      lookup = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
      record = lookup.scalar_one_or_none()
      if not (record and record.is_enabled):
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")

      authenticator = pyotp.TOTP(record.secret)
      if authenticator.verify(body.code, valid_window=1):
          _assert_totp_not_replayed(authenticator, record, body.code)
          await db.flush()  # L-3: write last_totp_counter now so the code cannot be replayed
      else:
          # Fall back to backup codes. The comprehension checks every stored hash
          # (no short-circuit), so timing does not reveal a match's position (L-R9-A).
          backup_hits = [h for h in record.backup_code_hashes if pwd_context.verify(body.code, h)]
          if not backup_hits:
              await record_failed_attempt(db, current_user.username)
              raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid code")

      await db.execute(delete(UserTOTP).where(UserTOTP.user_id == current_user.id))
      await db.commit()
      return {"message": "TOTP disabled"}
  @router.post("/2fa/totp/regenerate-backup-codes", response_model=BackupCodesResponse)
  async def regenerate_backup_codes(
      body: TOTPDisableRequest,
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> BackupCodesResponse:
      """Replace the caller's backup codes with 10 freshly generated ones.

      Requires proof of a second factor: a valid TOTP code OR one of the existing
      backup codes. M10: backup codes are accepted for consistency with
      disable_totp, so users who have lost their authenticator app but still hold
      backup codes can regenerate. Rate-limited to prevent brute-forcing from a
      hijacked session.

      The whole backup-code list is overwritten, which invalidates every old code,
      including one just used for verification. The new plain codes are returned
      once; only their hashes are persisted.

      Raises:
          HTTPException 400: TOTP is not enabled, or the code is neither a valid
              TOTP code nor a valid backup code.
      """
      await check_rate_limit(db, current_user.username)
      result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == current_user.id))
      totp_record = result.scalar_one_or_none()
      if not totp_record or not totp_record.is_enabled:
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled")
      totp_obj = pyotp.TOTP(totp_record.secret)
      if totp_obj.verify(body.code, valid_window=1):
          _assert_totp_not_replayed(totp_obj, totp_record, body.code)
          await db.flush()  # L-3: persist last_totp_counter immediately to block replay
      else:
          # Accept a backup code as an alternative (M10). Check every hash with no
          # early break (L-R9-A) so timing does not leak the matching position.
          backup_matched = False
          for hashed in totp_record.backup_code_hashes:
              if pwd_context.verify(body.code, hashed):
                  backup_matched = True
          if not backup_matched:
              await record_failed_attempt(db, current_user.username)
              raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid TOTP or backup code")
      # No need to remove the used backup code individually: overwriting the list
      # below invalidates all previous codes at once.
      plain_codes, hashed_codes = _generate_backup_codes()
      totp_record.backup_code_hashes = hashed_codes
      await db.commit()
      return BackupCodesResponse(
          backup_codes=plain_codes,
          message="Backup codes regenerated. Store them safely — they will not be shown again.",
      )
  @router.post("/2fa/email/enable")
  async def enable_email_otp(
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Begin enabling email OTP 2FA by mailing a one-time verification code.

      C5 (proof of possession): email 2FA is only switched on after the user shows
      they can read mail at the registered address. This step replaces any pending
      EMAIL_OTP_SETUP token for the user with a new random ``setup_token`` (valid
      for 10 minutes) that stores the hash of a 6-digit code, then emails the plain
      code. The client submits both to POST /auth/2fa/email/enable/confirm.

      H-3: the send-rate limit is checked first so this endpoint cannot be used to
      flood a mailbox.

      Returns:
          ``{"message": ..., "setup_token": ...}``.

      Raises:
          HTTPException 400: the user has no email address configured.
          HTTPException 500: SMTP is not configured, or sending the email failed.
      """
      await check_email_otp_send_rate(db, current_user.username)

      if not current_user.email:
          raise HTTPException(
              status_code=status.HTTP_400_BAD_REQUEST,
              detail="You must have an email address configured to enable email OTP 2FA",
          )
      mail_config = await get_smtp_settings(db)
      if not mail_config:
          raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")

      issued_at = datetime.now(timezone.utc)
      # Keep at most one pending setup challenge per user: drop earlier ones first.
      await db.execute(
          delete(AuthEphemeralToken).where(
              AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
              AuthEphemeralToken.username == current_user.username,
          )
      )

      otp = f"{secrets.randbelow(1_000_000):06d}"
      otp_hash = pwd_context.hash(otp)
      setup_token = secrets.token_urlsafe(32)
      db.add(
          AuthEphemeralToken(
              token=setup_token,
              token_type=TokenType.EMAIL_OTP_SETUP,
              username=current_user.username,
              nonce=otp_hash,  # the nonce column doubles as storage for the code hash
              expires_at=issued_at + timedelta(minutes=10),
          )
      )
      await db.commit()

      text_body = (
          f"Your Bambuddy email 2FA setup code is: {otp}\n\n"
          "Enter this code to confirm email-based two-factor authentication.\n"
          "The code expires in 10 minutes."
      )
      html_body = (
          "<p>To enable <strong>email-based two-factor authentication</strong> on your Bambuddy account, "
          "enter the code below:</p>"
          f"<h2 style='letter-spacing:4px'>{otp}</h2>"
          "<p>The code expires in <strong>10 minutes</strong>. "
          "If you did not request this, you can safely ignore this email.</p>"
      )
      try:
          send_email(
              smtp_settings=mail_config,
              to_email=current_user.email,
              subject="Verify your Bambuddy email address for 2FA",
              body_text=text_body,
              body_html=html_body,
          )
          await record_email_otp_send(db, current_user.username)
      except Exception as exc:
          logger.error("Failed to send email OTP setup code to user_id=%d: %s", current_user.id, exc)
          raise HTTPException(
              status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send verification email"
          )
      return {"message": "Verification code sent to your email address", "setup_token": setup_token}
  @router.post("/2fa/email/enable/confirm")
  async def confirm_enable_email_otp(
      body: EmailOTPEnableConfirmRequest,
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Finish enabling email OTP 2FA by checking the emailed setup code.

      H-2 fix (peek-then-consume): the setup token is looked up first and only
      deleted once ``body.code`` matches, so a wrong code does not burn it; retries
      are bounded by the rate limit (5 attempts / 15 min).
      M4: the TWO_FA_ATTEMPT rate limit stops brute-forcing of the 6-digit code.

      Raises:
          HTTPException 400: unknown, expired or already-consumed setup token, or
              a wrong verification code.
      """
      await check_rate_limit(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
      now = datetime.now(timezone.utc)

      # Filter identifying this user's setup token; expiry is checked on the peek only.
      owns_token = (
          AuthEphemeralToken.token == body.setup_token,
          AuthEphemeralToken.token_type == TokenType.EMAIL_OTP_SETUP,
          AuthEphemeralToken.username == current_user.username,
      )

      # Peek: find the token without deleting it.
      lookup = await db.execute(select(AuthEphemeralToken).where(*owns_token, AuthEphemeralToken.expires_at > now))
      pending = lookup.scalar_one_or_none()
      if pending is None:
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")

      # The code hash is stored in the token's nonce column.
      if not pwd_context.verify(body.code, pending.nonce):
          await record_failed_attempt(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")

      # Consume: DELETE ... RETURNING yields a row only for the request that wins a race.
      removed = await db.execute(delete(AuthEphemeralToken).where(*owns_token).returning(AuthEphemeralToken.id))
      if removed.one_or_none() is None:
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired setup token")

      await clear_failed_attempts(db, current_user.username, event_type=EventType.TWO_FA_ATTEMPT)
      await _set_email_2fa_enabled(db, current_user.id, True)
      await db.commit()
      return {"message": "Email OTP 2FA enabled"}
  @router.post("/2fa/email/disable")
  async def disable_email_otp(
      body: EmailOTPDisableRequest,
      current_user: User = Depends(get_current_active_user),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Switch off email-based OTP 2FA for the calling user.

      C6: users with a local password must re-enter it, so a hijacked session
      cannot silently strip a second factor. Users without a ``password_hash``
      (LDAP/OIDC-only) skip the password check.
      H-2: rate-limited, and wrong passwords are recorded, so this endpoint cannot
      be used to brute-force the password.

      Raises:
          HTTPException 401: the supplied password is wrong.
      """
      await check_rate_limit(db, current_user.username)
      stored_hash = current_user.password_hash
      if stored_hash and not verify_password(body.password, stored_hash):
          await record_failed_attempt(db, current_user.username)
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password")
      await _set_email_2fa_enabled(db, current_user.id, False)
      await db.commit()
      return {"message": "Email OTP 2FA disabled"}
  @router.post("/2fa/email/send")
  async def send_email_otp(
      request: Request,
      body: EmailOTPSendRequest,
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Email a fresh 6-digit login OTP and rotate the caller's pre-auth token.

      Requires a valid pre_auth_token from the login flow, checked together with
      the ``2fa_challenge`` cookie. Any still-unused OTP codes for the user are
      marked used before the new one is staged. The old pre-auth token is consumed
      only after the email has been sent; a fresh token bound to the same cookie is
      returned for the verify step.

      Raises:
          HTTPException 401: invalid/expired pre-auth token, or unknown/inactive user.
          HTTPException 400: the user has no email address configured.
          HTTPException 500: SMTP is not configured, or sending failed.
      """
      challenge_id = request.cookies.get("2fa_challenge")
      # Peek rather than consume: a rate-limit rejection below must not burn the token.
      username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
      if not username:
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

      # Throttle before consuming anything, to prevent OTP email flooding.
      await check_email_otp_send_rate(db, username)

      user = await get_user_by_username(db, username)
      if not user or not user.is_active:
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
      if not user.email:
          raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no email address configured")

      smtp = await get_smtp_settings(db)
      if not smtp:
          raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email service is not configured")

      # Stage (not yet commit) the invalidation of every outstanding code for this user.
      otp_table = UserOTPCode.__table__  # type: ignore[attr-defined]
      await db.execute(
          otp_table.update()
          .where(UserOTPCode.user_id == user.id, UserOTPCode.used.is_(False))
          .values(used=True)
      )

      # Generate the code and stage its hashed record alongside the invalidation.
      login_code = f"{secrets.randbelow(1_000_000):06d}"
      db.add(
          UserOTPCode(
              user_id=user.id,
              code_hash=pwd_context.hash(login_code),
              attempts=0,
              used=False,
              expires_at=datetime.now(timezone.utc) + timedelta(minutes=UserOTPCode.OTP_TTL_MINUTES),
          )
      )

      plain_text = f"Your Bambuddy login code is: {login_code}\n\nThis code expires in {UserOTPCode.OTP_TTL_MINUTES} minutes and can only be used once."
      html = (
          "<p>Your <strong>Bambuddy</strong> login verification code is:</p>"
          f"<h2 style='letter-spacing:4px'>{login_code}</h2>"
          f"<p>This code expires in <strong>{UserOTPCode.OTP_TTL_MINUTES} minutes</strong> and can only be used once.</p>"
          "<p>If you did not request this code, you can safely ignore this email.</p>"
      )

      # M2: send BEFORE consuming the pre-auth token. On failure the exception
      # propagates with the session uncommitted, so the staged OTP record is
      # discarded and the original token stays usable for a retry.
      try:
          send_email(
              smtp_settings=smtp,
              to_email=user.email,
              subject="Your Bambuddy verification code",
              body_text=plain_text,
              body_html=html,
          )
          await record_email_otp_send(db, username)
      except Exception as exc:
          logger.error("Failed to send OTP email to user_id=%d: %s", user.id, exc)
          raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to send OTP email")

      # Email is out: atomically consume the old token (this also commits the
      # staged OTP record), then issue a replacement for the verify step.
      if not await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id):
          # Lost a race with another request, or the token expired in between.
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

      # Bind the new token to the same cookie so the binding carries into verify.
      next_token = await create_pre_auth_token(db, username, challenge_id=challenge_id)
      return {"message": "Code sent to your email address", "pre_auth_token": next_token}
  @router.post("/2fa/verify", response_model=TwoFAVerifyResponse)
  async def verify_2fa(
      request: Request,
      body: TwoFAVerifyRequest,
      db: AsyncSession = Depends(get_db),
  ) -> TwoFAVerifyResponse:
      """Verify a 2FA code and exchange the pre_auth_token for a full JWT.

      Accepted methods: ``totp``, ``email``, ``backup`` (any value other than
      ``totp``/``email`` is handled by the backup-code branch).

      The pre_auth_token is NOT consumed on failed verification attempts, so the
      user can retry without restarting the login flow. It is only consumed once
      verification succeeds, which prevents token replay after success. Every
      failed verification is recorded against the username's rate limit.

      Raises:
          HTTPException 401: invalid/expired/already-consumed pre-auth token,
              unknown or inactive user, or a wrong, expired or exhausted code.
          HTTPException 400: ``totp``/``backup`` used but TOTP is not enabled.
      """
      # Peek without consuming — bad codes must not burn the session token.
      # Pass the HttpOnly challenge cookie so the binding check is enforced.
      challenge_id = request.cookies.get("2fa_challenge")
      username = await peek_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
      if not username:
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")
      await check_rate_limit(db, username)
      user = await get_user_by_username(db, username)
      if not user or not user.is_active:
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")

      # Backup method only: the record and list index of the matched backup code.
      # The code is removed only after the pre-auth token has been consumed (M1).
      backup_record: UserTOTP | None = None
      backup_index: int | None = None

      method = body.method
      if method == "totp":
          result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
          totp_record = result.scalar_one_or_none()
          if not totp_record or not totp_record.is_enabled:
              await record_failed_attempt(db, username)
              raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
          totp_obj = pyotp.TOTP(totp_record.secret)
          if not totp_obj.verify(body.code, valid_window=1):
              await record_failed_attempt(db, username)
              raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid TOTP code")
          _assert_totp_not_replayed(totp_obj, totp_record, body.code)
          await db.flush()  # L-3: persist last_totp_counter immediately to block replay
      elif method == "email":
          now = datetime.now(timezone.utc)
          # Take only the newest live code. Without LIMIT 1, two unused codes
          # (e.g. from concurrent /2fa/email/send requests) would make
          # scalar_one_or_none() raise MultipleResultsFound and fail with a 500.
          result = await db.execute(
              select(UserOTPCode)
              .where(UserOTPCode.user_id == user.id)
              .where(UserOTPCode.used.is_(False))
              .where(UserOTPCode.expires_at > now)
              .order_by(UserOTPCode.created_at.desc())
              .limit(1)
          )
          otp_record = result.scalar_one_or_none()
          if not otp_record:
              await record_failed_attempt(db, username)
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED, detail="No valid OTP code found. Request a new one."
              )
          if otp_record.attempts >= UserOTPCode.MAX_ATTEMPTS:
              # Burn the code permanently once its per-code attempt budget is spent.
              otp_record.consume()
              await db.commit()
              await record_failed_attempt(db, username)
              raise HTTPException(
                  status_code=status.HTTP_401_UNAUTHORIZED, detail="OTP code has been invalidated after too many attempts"
              )
          if not pwd_context.verify(body.code, otp_record.code_hash):
              otp_record.attempts += 1
              await db.commit()
              await record_failed_attempt(db, username)
              raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid OTP code")
          otp_record.consume()
          await db.commit()
      else:  # method == "backup"
          result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
          backup_record = result.scalar_one_or_none()
          if not backup_record or not backup_record.is_enabled:
              await record_failed_attempt(db, username)
              raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="TOTP is not enabled for this user")
          # Always iterate all codes — no early break (L-R9-A: constant iteration
          # count prevents timing oracle based on used-code position in the list).
          for idx, hashed in enumerate(backup_record.backup_code_hashes):
              if pwd_context.verify(body.code, hashed) and backup_index is None:
                  backup_index = idx
          if backup_index is None:
              await record_failed_attempt(db, username)
              raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid backup code")

      # C-1 / M1: atomically consume the pre-auth token before granting a session.
      # None means a concurrent request already consumed it, so reject to prevent
      # double-use. For the backup method this happens BEFORE the backup code is
      # removed, so two racing requests that both passed verification cannot both
      # be granted a session.
      consumed_username = await consume_pre_auth_token(db, body.pre_auth_token, challenge_id=challenge_id)
      if not consumed_username:
          raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired pre-auth token")

      if backup_record is not None and backup_index is not None:
          # Remove the used backup code now that the token is atomically consumed.
          backup_record.backup_code_hashes = [
              c for i, c in enumerate(backup_record.backup_code_hashes) if i != backup_index
          ]
          await db.commit()

      await clear_failed_attempts(db, username)
      access_token = create_access_token(
          data={"sub": user.username},
          expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
      )
      # Reload with groups for permission calculation
      result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
      user = result.scalar_one()
      return TwoFAVerifyResponse(
          access_token=access_token,
          token_type="bearer",
          user=_user_to_response(user),
      )
  @router.delete("/2fa/admin/{user_id}")
  async def admin_disable_2fa(
      user_id: int,
      body: AdminDisable2FARequest = Body(default_factory=AdminDisable2FARequest),
      current_user: User | None = RequirePermissionIfAuthEnabled(Permission.USERS_UPDATE),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Admin endpoint: disable all 2FA (TOTP and email OTP) for a given user.

      Nit 3: requires the admin's own password as a re-auth step, matching how
      disable_email_otp protects a user's own 2FA removal. OIDC/LDAP-only admins
      (no local password_hash) are exempt. Like disable_email_otp (H-2), the
      re-auth check is rate-limited and wrong passwords are recorded, so a
      hijacked admin session cannot use this endpoint as an unthrottled password
      oracle.

      Side effects: deletes the target's UserTOTP row, disables email 2FA, marks
      all of the target's OTP codes used, and (I2) bumps ``password_changed_at``
      on the target user when that user exists.

      Raises:
          HTTPException 401: the admin password is missing or wrong.
      """
      # Nit 3: Re-auth — admin must supply their own password.
      if current_user and current_user.password_hash:
          # Throttle and record failures (parity with disable_email_otp).
          await check_rate_limit(db, current_user.username)
          if not body.admin_password or not verify_password(body.admin_password, current_user.password_hash):
              await record_failed_attempt(db, current_user.username)
              raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Admin password required")
      # Delete TOTP record
      await db.execute(delete(UserTOTP).where(UserTOTP.user_id == user_id))
      # Disable email 2FA setting
      await _set_email_2fa_enabled(db, user_id, False)
      # Invalidate all OTP codes
      await db.execute(
          UserOTPCode.__table__.update()  # type: ignore[attr-defined]
          .where(UserOTPCode.user_id == user_id)
          .values(used=True)
      )
      # I2: Invalidate existing JWTs for the target user by bumping password_changed_at.
      # Without this, a stolen token remains valid after 2FA removal.
      target_user = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
      if target_user:
          target_user.password_changed_at = datetime.now(timezone.utc)
      await db.commit()
      actor = current_user.username if current_user else "anonymous"
      logger.info("Admin %s disabled all 2FA for user_id=%d", actor, user_id)
      return {"message": "2FA disabled for user"}
  949. # ===========================================================================
  950. # OIDC Endpoints
  951. # ===========================================================================
  @router.get("/oidc/providers", response_model=list[OIDCProviderResponse])
  async def list_oidc_providers(
      db: AsyncSession = Depends(get_db),
  ) -> list[OIDCProviderResponse]:
      """List the OIDC providers that are currently enabled (public endpoint).

      Returns:
          One ``OIDCProviderResponse`` per provider whose ``is_enabled`` is true.
      """
      enabled_only = select(OIDCProvider).where(OIDCProvider.is_enabled.is_(True))
      rows = (await db.execute(enabled_only)).scalars().all()
      return list(map(OIDCProviderResponse.model_validate, rows))
  @router.get("/oidc/providers/all", response_model=list[OIDCProviderResponse])
  async def list_all_oidc_providers(
      _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ),
      db: AsyncSession = Depends(get_db),
  ) -> list[OIDCProviderResponse]:
      """List every OIDC provider, disabled ones included (admin only).

      Access is guarded by ``RequirePermissionIfAuthEnabled(Permission.SETTINGS_READ)``.
      """
      every_provider = (await db.execute(select(OIDCProvider))).scalars().all()
      return [OIDCProviderResponse.model_validate(entry) for entry in every_provider]
  @router.post("/oidc/providers", response_model=OIDCProviderResponse, status_code=status.HTTP_201_CREATED)
  async def create_oidc_provider(
      body: OIDCProviderCreate,
      _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
      db: AsyncSession = Depends(get_db),
  ) -> OIDCProviderResponse:
      """Create and persist a new OIDC provider (admin only).

      Trailing slashes are stripped from ``issuer_url`` before storing. The
      auto-link safety guard runs on the unsaved provider, so an unsafe
      combination is rejected before anything is written.

      Returns:
          The stored provider, refreshed from the database after commit.
      """
      normalized_issuer = body.issuer_url.rstrip("/")
      new_provider = OIDCProvider(
          name=body.name,
          issuer_url=normalized_issuer,
          client_id=body.client_id,
          client_secret=body.client_secret,
          scopes=body.scopes,
          is_enabled=body.is_enabled,
          auto_create_users=body.auto_create_users,
          auto_link_existing_accounts=body.auto_link_existing_accounts,
          email_claim=body.email_claim,
          require_email_verified=body.require_email_verified,
          icon_url=body.icon_url,
      )
      # SEC-1 + SEC-6: runtime mirror of the OIDCProviderCreate model_validator in
      # schemas/auth.py, for any future path that skips Pydantic (direct ORM, scripts).
      _enforce_auto_link_safety(new_provider)
      db.add(new_provider)
      await db.commit()
      await db.refresh(new_provider)
      return OIDCProviderResponse.model_validate(new_provider)
  @router.put("/oidc/providers/{provider_id}", response_model=OIDCProviderResponse)
  async def update_oidc_provider(
      provider_id: int,
      body: OIDCProviderUpdate,
      _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
      db: AsyncSession = Depends(get_db),
  ) -> OIDCProviderResponse:
      """Apply a partial update to an existing OIDC provider (admin only).

      Only fields that are not ``None`` in the request are written, so a field
      cannot be cleared through this endpoint. A non-empty ``issuer_url`` has its
      trailing slashes stripped.

      Raises:
          HTTPException 404: no provider with ``provider_id`` exists.
      """
      lookup = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
      provider = lookup.scalar_one_or_none()
      if not provider:
          raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")

      changes = body.model_dump(exclude_none=True)
      if changes.get("issuer_url"):
          changes["issuer_url"] = changes["issuer_url"].rstrip("/")
      for attr_name, new_value in changes.items():
          setattr(provider, attr_name, new_value)

      # SEC-1 + SEC-6: validate the merged state (stored values plus this update),
      # since fields that are each valid alone may be unsafe in combination.
      _enforce_auto_link_safety(provider)
      await db.commit()
      await db.refresh(provider)
      return OIDCProviderResponse.model_validate(provider)
  @router.delete("/oidc/providers/{provider_id}")
  async def delete_oidc_provider(
      provider_id: int,
      _: User | None = RequirePermissionIfAuthEnabled(Permission.SETTINGS_UPDATE),
      db: AsyncSession = Depends(get_db),
  ) -> dict:
      """Delete an OIDC provider (admin only).

      NOTE(review): removal of the provider's user links is assumed to happen via
      an ORM/DB cascade on OIDCProvider; it is not done explicitly here. Confirm
      against the model definition.

      Raises:
          HTTPException 404: no provider with ``provider_id`` exists.
      """
      lookup = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
      doomed = lookup.scalar_one_or_none()
      if not doomed:
          raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found")
      await db.delete(doomed)
      await db.commit()
      return {"message": "Provider deleted"}
  @router.get("/oidc/authorize/{provider_id}", response_model=OIDCAuthorizeResponse)
  async def oidc_authorize(
      provider_id: int,
      db: AsyncSession = Depends(get_db),
  ) -> OIDCAuthorizeResponse:
      """Return the OIDC authorization URL for the given provider.

      The endpoint performs these steps:

      1. Load the enabled provider.
      2. Fetch its discovery document.
      3. Prune expired OIDC state rows.
      4. Persist a new state row holding the state, nonce and PKCE
         ``code_verifier``, expiring after ``OIDC_STATE_TTL``.
      5. Build an authorization-code request URL with an S256 PKCE challenge.

      Any query component already present on the provider's
      ``authorization_endpoint`` is preserved.

      Raises:
          HTTPException 404: the provider does not exist or is disabled.
          HTTPException 502: the discovery document cannot be fetched, is not a
              JSON object, or lacks a usable ``authorization_endpoint``.
      """
      result = await db.execute(
          select(OIDCProvider).where(OIDCProvider.id == provider_id).where(OIDCProvider.is_enabled.is_(True))
      )
      provider = result.scalar_one_or_none()
      if not provider:
          raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Provider not found or not enabled")

      # Fetch discovery document
      discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
      try:
          async with httpx.AsyncClient(timeout=10) as client:
              resp = await client.get(discovery_url)
              resp.raise_for_status()
              discovery = resp.json()
      except Exception as exc:
          logger.error("Failed to fetch OIDC discovery for provider %d: %s", provider_id, exc)
          raise HTTPException(
              status_code=status.HTTP_502_BAD_GATEWAY, detail="Failed to fetch OIDC discovery document"
          ) from exc

      # Valid JSON that is not an object (e.g. a list) would otherwise crash on .get() with a 500.
      authorization_endpoint = discovery.get("authorization_endpoint") if isinstance(discovery, dict) else None
      if not authorization_endpoint:
          raise HTTPException(
              status_code=status.HTTP_502_BAD_GATEWAY, detail="OIDC discovery document missing authorization_endpoint"
          )
      # B2: only accept an http(s) URL string. The resulting auth_url is returned to
      # the client, so other schemes (e.g. ``javascript:``) and non-string JSON
      # values are rejected.
      if not isinstance(authorization_endpoint, str) or not authorization_endpoint.startswith(("https://", "http://")):
          logger.warning("OIDC discovery authorization_endpoint has invalid scheme: %s", authorization_endpoint)
          raise HTTPException(
              status_code=status.HTTP_502_BAD_GATEWAY,
              detail="OIDC discovery document contains invalid authorization_endpoint",
          )

      external_url = await _get_base_external_url(db)
      redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
      now = datetime.now(timezone.utc)
      # Prune expired OIDC states from the DB
      await db.execute(
          delete(AuthEphemeralToken).where(
              AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
              AuthEphemeralToken.expires_at < now,
          )
      )
      state = secrets.token_urlsafe(32)
      nonce = secrets.token_urlsafe(32)
      # PKCE (S256) – required by PocketID and recommended for all OIDC flows.
      # token_urlsafe(48) yields a 64-char verifier (RFC 7636 allows 43-128 chars).
      code_verifier = secrets.token_urlsafe(48)
      code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
      db.add(
          AuthEphemeralToken(
              token=state,
              token_type=TokenType.OIDC_STATE,
              provider_id=provider_id,
              nonce=nonce,
              code_verifier=code_verifier,
              expires_at=now + OIDC_STATE_TTL,
          )
      )
      await db.commit()
      params = urllib.parse.urlencode(
          {
              "response_type": "code",
              "client_id": provider.client_id,
              "redirect_uri": redirect_uri,
              "scope": provider.scopes,
              "state": state,
              "nonce": nonce,
              "code_challenge": code_challenge,
              "code_challenge_method": "S256",
          }
      )
      # RFC 6749 §3.1: the endpoint URI MAY already contain a query component
      # (e.g. Azure AD B2C's ``?p=<policy>``), which MUST be retained when adding
      # parameters. Blindly appending "?..." would produce a second "?" instead.
      endpoint_parts = urllib.parse.urlsplit(authorization_endpoint)
      merged_query = f"{endpoint_parts.query}&{params}" if endpoint_parts.query else params
      auth_url = urllib.parse.urlunsplit(endpoint_parts._replace(query=merged_query))
      return OIDCAuthorizeResponse(auth_url=auth_url)
  1107. @router.get("/oidc/callback")
  1108. async def oidc_callback(
  1109. code: str | None = Query(default=None, max_length=2048),
  1110. state: str | None = Query(default=None, max_length=2048),
  1111. error: str | None = Query(default=None, max_length=256),
  1112. db: AsyncSession = Depends(get_db),
  1113. ) -> RedirectResponse:
  1114. """Handle the OIDC authorization code callback from the identity provider."""
  1115. external_url = await _get_base_external_url(db)
  1116. frontend_error_url = f"{external_url}/?oidc_error="
  1117. try:
  1118. if error:
  1119. logger.warning("OIDC callback received error: %s", error)
  1120. return RedirectResponse(url=f"{frontend_error_url}oidc_provider_error", status_code=302)
  1121. if not code or not state:
  1122. return RedirectResponse(url=f"{frontend_error_url}missing_parameters", status_code=302)
  1123. # Atomically validate and consume OIDC state from DB (I6: single-use enforcement).
  1124. # DELETE...RETURNING ensures concurrent callbacks with the same state token
  1125. # cannot both succeed — only the first DELETE finds the row.
  1126. now = datetime.now(timezone.utc)
  1127. state_del = await db.execute(
  1128. delete(AuthEphemeralToken)
  1129. .where(
  1130. AuthEphemeralToken.token == state,
  1131. AuthEphemeralToken.token_type == TokenType.OIDC_STATE,
  1132. AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
  1133. )
  1134. .returning(
  1135. AuthEphemeralToken.provider_id,
  1136. AuthEphemeralToken.nonce,
  1137. AuthEphemeralToken.code_verifier,
  1138. )
  1139. )
  1140. state_row = state_del.one_or_none()
  1141. if state_row is None:
  1142. await db.rollback()
  1143. return RedirectResponse(url=f"{frontend_error_url}invalid_state", status_code=302)
  1144. provider_id, nonce, code_verifier = state_row
  1145. await db.commit()
  1146. # Load provider
  1147. result = await db.execute(select(OIDCProvider).where(OIDCProvider.id == provider_id))
  1148. provider = result.scalar_one_or_none()
  1149. if not provider:
  1150. return RedirectResponse(url=f"{frontend_error_url}provider_not_found", status_code=302)
  1151. redirect_uri = f"{external_url}/api/v1/auth/oidc/callback"
  1152. # ── Step 1: Fetch discovery document ────────────────────────────────
  1153. discovery_url = f"{provider.issuer_url.rstrip('/')}/.well-known/openid-configuration"
  1154. try:
  1155. async with httpx.AsyncClient(timeout=10) as client:
  1156. disc_resp = await client.get(discovery_url)
  1157. disc_resp.raise_for_status()
  1158. discovery = disc_resp.json()
  1159. except Exception as exc:
  1160. logger.error("OIDC discovery fetch failed for provider %d: %s", provider_id, exc)
  1161. return RedirectResponse(url=f"{frontend_error_url}discovery_failed", status_code=302)
  1162. token_endpoint = discovery.get("token_endpoint")
  1163. jwks_uri = discovery.get("jwks_uri")
  1164. if not token_endpoint or not jwks_uri:
  1165. return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
  1166. # L-R7-C: Reject non-HTTP(S) URLs in the discovery document to prevent
  1167. # SSRF via crafted responses (e.g. file://, gopher://, internal schemes).
  1168. if not token_endpoint.startswith(("https://", "http://")) or not jwks_uri.startswith(("https://", "http://")):
  1169. logger.warning(
  1170. "OIDC discovery document contains non-HTTP URL(s): token=%s jwks=%s", token_endpoint, jwks_uri
  1171. )
  1172. return RedirectResponse(url=f"{frontend_error_url}invalid_discovery_document", status_code=302)
  1173. # ── Step 2: Exchange authorization code for tokens ───────────────────
  1174. token_form: dict[str, str] = {
  1175. "grant_type": "authorization_code",
  1176. "code": code,
  1177. "redirect_uri": redirect_uri,
  1178. "client_id": provider.client_id,
  1179. }
  1180. if provider.client_secret:
  1181. token_form["client_secret"] = provider.client_secret
  1182. if code_verifier:
  1183. token_form["code_verifier"] = code_verifier
  1184. try:
  1185. async with httpx.AsyncClient(timeout=15) as client:
  1186. token_resp = await client.post(
  1187. token_endpoint,
  1188. data=token_form,
  1189. headers={"Accept": "application/json"},
  1190. )
  1191. except Exception as exc:
  1192. logger.error("OIDC token exchange request failed for provider %d: %s", provider_id, exc)
  1193. return RedirectResponse(url=f"{frontend_error_url}token_exchange_network_error", status_code=302)
  1194. if not token_resp.is_success:
  1195. try:
  1196. err_body = token_resp.json()
  1197. oidc_err = err_body.get("error", "")
  1198. oidc_desc = err_body.get("error_description", "")
  1199. except Exception:
  1200. oidc_err = ""
  1201. oidc_desc = token_resp.text[:200]
  1202. logger.error(
  1203. "OIDC token exchange HTTP %d for provider %d. redirect_uri=%r error=%r desc=%r",
  1204. token_resp.status_code,
  1205. provider_id,
  1206. redirect_uri,
  1207. oidc_err,
  1208. oidc_desc,
  1209. )
  1210. # Encode the OIDC error code into the redirect so the user sees it in the toast.
  1211. # URL-encode the value to prevent query-parameter injection from provider responses.
  1212. raw_err = oidc_err[:40] if oidc_err else str(token_resp.status_code)
  1213. safe_err = urllib.parse.quote(raw_err, safe="")
  1214. return RedirectResponse(
  1215. url=f"{frontend_error_url}token_exchange_{safe_err}",
  1216. status_code=302,
  1217. )
  1218. try:
  1219. token_data = token_resp.json()
  1220. except Exception as exc:
  1221. logger.error("OIDC token exchange non-JSON response for provider %d: %s", provider_id, exc)
  1222. return RedirectResponse(url=f"{frontend_error_url}token_exchange_bad_response", status_code=302)
  1223. id_token = token_data.get("id_token")
  1224. if not id_token:
  1225. # Only log the keys present — values may contain secrets (access_token, etc.)
  1226. logger.error(
  1227. "OIDC token response missing id_token for provider %d; keys present: %s",
  1228. provider_id,
  1229. list(token_data.keys()),
  1230. )
  1231. return RedirectResponse(url=f"{frontend_error_url}no_id_token", status_code=302)
  1232. # ── Step 3: Fetch JWKS and validate ID token ─────────────────────────
  1233. # Use the issuer from the discovery document as the canonical value (OIDC Core
  1234. # §3.1.3.7 requires iss == discovery issuer exactly). We strip trailing slashes
  1235. # from both sides because some providers (e.g. Authentik, older PocketID versions)
  1236. # are inconsistent between the discovery issuer and the JWT iss claim.
  1237. discovery_issuer: str = discovery.get("issuer", provider.issuer_url).rstrip("/")
  1238. try:
  1239. async with httpx.AsyncClient(timeout=10) as jwks_http:
  1240. jwks_resp = await jwks_http.get(jwks_uri)
  1241. jwks_resp.raise_for_status()
  1242. jwks_data = jwks_resp.json()
  1243. jwks_client = PyJWKClient(jwks_uri)
  1244. jwks_client.fetch_data = lambda: jwks_data # type: ignore[method-assign]
  1245. signing_key = jwks_client.get_signing_key_from_jwt(id_token)
  1246. # M-3: Decode without built-in issuer check, then compare normalised
  1247. # (both sides rstrip("/")) to handle providers like Authentik that include
  1248. # a trailing slash in iss but not in the discovery issuer, or vice-versa.
  1249. claims = jwt.decode(
  1250. id_token,
  1251. signing_key.key,
  1252. algorithms=["RS256", "ES256", "RS384", "ES384", "RS512"],
  1253. audience=provider.client_id,
  1254. options={"verify_iss": False},
  1255. )
  1256. token_iss = claims.get("iss", "").rstrip("/")
  1257. if token_iss != discovery_issuer:
  1258. raise jwt.exceptions.InvalidIssuerError("Invalid issuer")
  1259. except Exception as exc:
  1260. logger.error("OIDC JWT validation failed for provider %d: %s", provider_id, exc, exc_info=True)
  1261. return RedirectResponse(url=f"{frontend_error_url}token_validation_failed", status_code=302)
  1262. # Verify nonce — fail closed: we always send a nonce, so the provider must echo it.
  1263. # Skipping the check when nonce is absent would allow CSRF on non-nonce providers.
  1264. token_nonce = claims.get("nonce")
  1265. if token_nonce is None or token_nonce != nonce:
  1266. logger.warning("OIDC nonce mismatch for provider %d (present=%r)", provider_id, token_nonce is not None)
  1267. return RedirectResponse(url=f"{frontend_error_url}nonce_mismatch", status_code=302)
  1268. provider_sub: str = claims.get("sub", "")
  1269. if not provider_sub:
  1270. return RedirectResponse(url=f"{frontend_error_url}missing_sub_claim", status_code=302)
  1271. # SEC-3: resolve email via Fall A/B/C logic (see _resolve_provider_email).
  1272. provider_email = _resolve_provider_email(provider, claims, provider_sub)
  1273. # ── Step 4: Resolve / create user ────────────────────────────────────
  1274. try:
  1275. # 1. Look up existing OIDC link
  1276. link_result = await db.execute(
  1277. select(UserOIDCLink)
  1278. .where(UserOIDCLink.provider_id == provider_id)
  1279. .where(UserOIDCLink.provider_user_id == provider_sub)
  1280. )
  1281. link = link_result.scalar_one_or_none()
  1282. user: User | None = None
  1283. if link:
  1284. # Existing link → load the linked user
  1285. user_result = await db.execute(
  1286. select(User).where(User.id == link.user_id).options(selectinload(User.groups))
  1287. )
  1288. user = user_result.scalar_one_or_none()
  1289. else:
  1290. # 2. No OIDC link yet — check for an existing user with the same email.
  1291. # Use case-insensitive matching (func.lower) so that "User@Example.com"
  1292. # and "user@example.com" are treated as the same identity, preventing
  1293. # an attacker-controlled IdP from bypassing the auto-link guard by
  1294. # registering the target email with different casing.
  1295. email_user: User | None = None
  1296. if provider_email:
  1297. email_user = await get_user_by_email(db, provider_email)
  1298. if email_user and provider.auto_link_existing_accounts:
  1299. # M-4: Only auto-link when the provider has auto_link_existing_accounts
  1300. # enabled. Operators can disable this to require explicit account linking,
  1301. # preventing an attacker-controlled IdP from hijacking local accounts.
  1302. #
  1303. # M-NEW-6: Refuse auto-link if the target user already has any OIDC
  1304. # link (to any provider). Without this guard an attacker who controls
  1305. # a second OIDC provider with auto_link enabled could add themselves as
  1306. # a second IdP for a user that already authenticates via a legitimate
  1307. # provider, effectively taking over the account.
  1308. existing_links_result = await db.execute(
  1309. select(UserOIDCLink).where(UserOIDCLink.user_id == email_user.id)
  1310. )
  1311. has_existing_oidc_link = existing_links_result.scalar_one_or_none() is not None
  1312. if has_existing_oidc_link:
  1313. logger.warning(
  1314. "Auto-link rejected for user '%s': already linked to another OIDC provider",
  1315. email_user.username,
  1316. )
  1317. return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
  1318. db.add(
  1319. UserOIDCLink(
  1320. user_id=email_user.id,
  1321. provider_id=provider_id,
  1322. provider_user_id=provider_sub,
  1323. provider_email=provider_email,
  1324. )
  1325. )
  1326. await db.commit()
  1327. user = email_user
  1328. logger.info(
  1329. "Auto-linked existing user '%s' to OIDC provider %d via email match",
  1330. email_user.username,
  1331. provider_id,
  1332. )
  1333. elif provider.auto_create_users:
  1334. # 3. No existing user — create one
  1335. if provider_email:
  1336. raw = provider_email.split("@")[0]
  1337. else:
  1338. raw = provider_sub[:30]
  1339. candidate = re.sub(r"[^a-zA-Z0-9._-]", "", raw)[:30] or "oidcuser"
  1340. username = candidate
  1341. counter = 1
  1342. while True:
  1343. existing = await get_user_by_username(db, username)
  1344. if not existing:
  1345. break
  1346. username = f"{candidate}{counter}"
  1347. counter += 1
  1348. # I9: Assign new OIDC users to the default "Viewers" group so they
  1349. # have read-only access rather than starting with no permissions.
  1350. # Fetch the group BEFORE creating the user so we can set the
  1351. # relationship before flush — accessing new_user.groups after a
  1352. # flush triggers a lazy-load which fails in async context.
  1353. viewers_result = await db.execute(select(Group).where(Group.name == "Viewers"))
  1354. viewers_group = viewers_result.scalar_one_or_none()
  1355. new_user = User(
  1356. username=username,
  1357. email=provider_email,
  1358. # M-1: auth_source="oidc" prevents local password-reset flow
  1359. # for users who should only authenticate via OIDC.
  1360. auth_source="oidc",
  1361. password_hash=None, # OIDC users never use password auth
  1362. role="user",
  1363. is_active=True,
  1364. groups=[viewers_group] if viewers_group else [],
  1365. )
  1366. db.add(new_user)
  1367. await db.flush()
  1368. db.add(
  1369. UserOIDCLink(
  1370. user_id=new_user.id,
  1371. provider_id=provider_id,
  1372. provider_user_id=provider_sub,
  1373. provider_email=provider_email,
  1374. )
  1375. )
  1376. await db.commit()
  1377. user_result = await db.execute(
  1378. select(User).where(User.id == new_user.id).options(selectinload(User.groups))
  1379. )
  1380. user = user_result.scalar_one()
  1381. logger.info("Auto-created user '%s' via OIDC provider %d", username, provider_id)
  1382. else:
  1383. return RedirectResponse(url=f"{frontend_error_url}no_linked_account", status_code=302)
  1384. if not user or not user.is_active:
  1385. return RedirectResponse(url=f"{frontend_error_url}account_inactive", status_code=302)
  1386. # Issue an OIDC exchange token (short-lived, single-use) stored in DB.
  1387. # I7: Opportunistically prune expired exchange tokens to keep the table small.
  1388. now2 = datetime.now(timezone.utc)
  1389. await db.execute(
  1390. delete(AuthEphemeralToken).where(
  1391. AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
  1392. AuthEphemeralToken.expires_at < now2,
  1393. )
  1394. )
  1395. exchange_token = secrets.token_urlsafe(32)
  1396. db.add(
  1397. AuthEphemeralToken(
  1398. token=exchange_token,
  1399. token_type=TokenType.OIDC_EXCHANGE,
  1400. username=user.username,
  1401. expires_at=now2 + OIDC_EXCHANGE_TTL,
  1402. )
  1403. )
  1404. await db.commit()
  1405. # H-4: Use a URL fragment (#) instead of a query parameter so the exchange
  1406. # token is never sent to the server in the Referer header or server logs.
  1407. return RedirectResponse(url=f"{external_url}/login#oidc_token={exchange_token}", status_code=302)
  1408. except Exception as exc:
  1409. logger.error("OIDC user resolution failed for provider %d: %s", provider_id, exc, exc_info=True)
  1410. try:
  1411. await db.rollback()
  1412. except Exception as rb_exc:
  1413. logger.error("DB rollback failed after OIDC user-resolution error: %s", rb_exc, exc_info=True)
  1414. return RedirectResponse(url=f"{frontend_error_url}user_resolution_failed", status_code=302)
  1415. except Exception as exc:
  1416. # L-1: Log the exception class name internally but never expose it in the
  1417. # redirect URL — leaking exception names aids attacker reconnaissance.
  1418. logger.error("Unexpected error in OIDC callback (%s): %s", type(exc).__name__, exc, exc_info=True)
  1419. try:
  1420. return RedirectResponse(url=f"{frontend_error_url}internal_error", status_code=302)
  1421. except Exception as redirect_exc:
  1422. logger.error("Failed to construct error redirect in OIDC callback: %s", redirect_exc, exc_info=True)
  1423. raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="OIDC callback failed")
  1424. @router.post("/oidc/exchange", response_model=LoginResponse)
  1425. async def oidc_exchange(
  1426. body: OIDCExchangeRequest,
  1427. raw_request: Request,
  1428. response: Response,
  1429. db: AsyncSession = Depends(get_db),
  1430. ) -> LoginResponse:
  1431. """Exchange an OIDC exchange token (from the callback redirect) for a full JWT.
  1432. C4: If the resolved user has 2FA enabled the exchange returns a pre_auth_token
  1433. (requires_2fa=True) instead of a full JWT. The frontend must then complete the
  1434. 2FA step exactly as it would after a password-based login.
  1435. """
  1436. now = datetime.now(timezone.utc)
  1437. # Atomically consume the exchange token (DELETE...RETURNING prevents replay).
  1438. consume_result = await db.execute(
  1439. delete(AuthEphemeralToken)
  1440. .where(
  1441. AuthEphemeralToken.token == body.oidc_token,
  1442. AuthEphemeralToken.token_type == TokenType.OIDC_EXCHANGE,
  1443. AuthEphemeralToken.expires_at > now, # reject expired tokens atomically
  1444. )
  1445. .returning(AuthEphemeralToken.username)
  1446. )
  1447. row = consume_result.one_or_none()
  1448. if row is None:
  1449. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired OIDC exchange token")
  1450. (username,) = row
  1451. await db.commit()
  1452. user = await get_user_by_username(db, username)
  1453. if not user or not user.is_active:
  1454. raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive")
  1455. # Reload with groups
  1456. result = await db.execute(select(User).where(User.id == user.id).options(selectinload(User.groups)))
  1457. user = result.scalar_one()
  1458. # C4: Check whether the user has any 2FA method enabled.
  1459. totp_result = await db.execute(select(UserTOTP).where(UserTOTP.user_id == user.id))
  1460. totp_record = totp_result.scalar_one_or_none()
  1461. totp_enabled = totp_record is not None and totp_record.is_enabled
  1462. email_2fa_enabled = await _get_email_2fa_enabled(db, user.id)
  1463. if totp_enabled or email_2fa_enabled:
  1464. # User has 2FA — issue a pre_auth_token bound to this browser session via
  1465. # an HttpOnly cookie (H-A: mirrors the cookie-binding done in auth.py:login).
  1466. two_fa_methods: list[str] = []
  1467. if totp_enabled:
  1468. two_fa_methods.append("totp")
  1469. if email_2fa_enabled:
  1470. two_fa_methods.append("email")
  1471. if totp_enabled:
  1472. two_fa_methods.append("backup")
  1473. challenge_id = secrets.token_urlsafe(32)
  1474. pre_auth_token = await create_pre_auth_token(db, user.username, challenge_id=challenge_id)
  1475. response.set_cookie(
  1476. key="2fa_challenge",
  1477. value=challenge_id,
  1478. httponly=True,
  1479. secure=raw_request.url.scheme == "https",
  1480. samesite="lax",
  1481. max_age=300,
  1482. path="/api/v1/auth/2fa",
  1483. )
  1484. return LoginResponse(
  1485. requires_2fa=True,
  1486. pre_auth_token=pre_auth_token,
  1487. two_fa_methods=two_fa_methods,
  1488. user=_user_to_response(user),
  1489. )
  1490. access_token = create_access_token(
  1491. data={"sub": user.username},
  1492. expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
  1493. )
  1494. return LoginResponse(
  1495. access_token=access_token,
  1496. token_type="bearer",
  1497. user=_user_to_response(user),
  1498. requires_2fa=False,
  1499. )
  1500. @router.get("/oidc/links", response_model=list[OIDCLinkResponse])
  1501. async def list_oidc_links(
  1502. current_user: User = Depends(get_current_active_user),
  1503. db: AsyncSession = Depends(get_db),
  1504. ) -> list[OIDCLinkResponse]:
  1505. """List all OIDC provider links for the current user."""
  1506. result = await db.execute(
  1507. select(UserOIDCLink).where(UserOIDCLink.user_id == current_user.id).options(selectinload(UserOIDCLink.provider))
  1508. )
  1509. links = result.scalars().all()
  1510. return [
  1511. OIDCLinkResponse(
  1512. id=link.id,
  1513. provider_id=link.provider_id,
  1514. provider_name=link.provider.name,
  1515. provider_email=link.provider_email,
  1516. created_at=link.created_at.isoformat(),
  1517. )
  1518. for link in links
  1519. ]
  1520. @router.delete("/oidc/links/{provider_id}")
  1521. async def remove_oidc_link(
  1522. provider_id: int,
  1523. current_user: User = Depends(get_current_active_user),
  1524. db: AsyncSession = Depends(get_db),
  1525. ) -> dict:
  1526. """Remove the OIDC link between the current user and a provider."""
  1527. result = await db.execute(
  1528. select(UserOIDCLink)
  1529. .where(UserOIDCLink.user_id == current_user.id)
  1530. .where(UserOIDCLink.provider_id == provider_id)
  1531. )
  1532. link = result.scalar_one_or_none()
  1533. if not link:
  1534. raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="OIDC link not found")
  1535. await db.delete(link)
  1536. await db.commit()
  1537. return {"message": "OIDC link removed"}
  1538. # ---------------------------------------------------------------------------
  1539. # Internal helpers
  1540. # ---------------------------------------------------------------------------
  1541. async def _get_base_external_url(db: AsyncSession) -> str:
  1542. """Return the base external URL (no trailing slash, no /login suffix)."""
  1543. external_url = await get_setting(db, "external_url")
  1544. if external_url:
  1545. return external_url.rstrip("/")
  1546. return os.environ.get("APP_URL", "http://localhost:5173").rstrip("/")