auth.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import re
  2. from typing import Literal
  3. from pydantic import BaseModel, Field, field_validator, model_validator
  4. def _validate_password_complexity(v: str) -> str:
  5. """Enforce minimum password complexity (M-C).
  6. Requires at least one uppercase letter, one lowercase letter, one digit,
  7. and one special character in addition to the min_length=8 Field constraint.
  8. """
  9. if not re.search(r"[A-Z]", v):
  10. raise ValueError("Password must contain at least one uppercase letter")
  11. if not re.search(r"[a-z]", v):
  12. raise ValueError("Password must contain at least one lowercase letter")
  13. if not re.search(r"\d", v):
  14. raise ValueError("Password must contain at least one digit")
  15. if not re.search(r"[^A-Za-z0-9]", v):
  16. raise ValueError("Password must contain at least one special character")
  17. return v
  18. class GroupBrief(BaseModel):
  19. """Brief group info for embedding in user responses."""
  20. id: int
  21. name: str
  22. class Config:
  23. from_attributes = True
  24. class LoginRequest(BaseModel):
  25. username: str = Field(..., max_length=150)
  26. password: str = Field(..., max_length=256)
  27. class LoginResponse(BaseModel):
  28. access_token: str | None = None
  29. token_type: str = "bearer"
  30. user: "UserResponse | None" = None
  31. # Set when 2FA is required; the frontend must call /auth/2fa/verify
  32. requires_2fa: bool = False
  33. pre_auth_token: str | None = None
  34. two_fa_methods: list[str] = []
  35. class UserCreate(BaseModel):
  36. username: str = Field(..., max_length=150)
  37. password: str | None = Field(default=None, max_length=256) # M-NEW-4: cap before pbkdf2
  38. email: str | None = Field(default=None, max_length=254) # L-NEW-5: RFC 5321 max
  39. role: str = "user"
  40. group_ids: list[int] | None = None
  41. @field_validator("password")
  42. @classmethod
  43. def validate_password(cls, v: str | None) -> str | None:
  44. if v is not None:
  45. _validate_password_complexity(v)
  46. return v
  47. class UserUpdate(BaseModel):
  48. username: str | None = Field(default=None, max_length=150)
  49. password: str | None = Field(default=None, max_length=256) # M-NEW-4: cap before pbkdf2
  50. email: str | None = Field(default=None, max_length=254) # L-NEW-5: RFC 5321 max
  51. role: str | None = None
  52. is_active: bool | None = None
  53. group_ids: list[int] | None = None
  54. @field_validator("password")
  55. @classmethod
  56. def validate_password(cls, v: str | None) -> str | None:
  57. if v is not None:
  58. _validate_password_complexity(v)
  59. return v
  60. class UserResponse(BaseModel):
  61. id: int
  62. username: str
  63. email: str | None = None
  64. role: str # Deprecated, kept for backward compatibility
  65. is_active: bool
  66. is_admin: bool # Computed from role and group membership
  67. auth_source: str = "local" # "local" or "ldap"
  68. groups: list[GroupBrief] = []
  69. permissions: list[str] = [] # All permissions from groups
  70. created_at: str
  71. class Config:
  72. from_attributes = True
  73. class ChangePasswordRequest(BaseModel):
  74. current_password: str = Field(..., max_length=256) # M-NEW-3: cap before pbkdf2
  75. new_password: str = Field(..., min_length=8, max_length=256)
  76. @field_validator("new_password")
  77. @classmethod
  78. def validate_new_password(cls, v: str) -> str:
  79. return _validate_password_complexity(v)
  80. class SetupRequest(BaseModel):
  81. auth_enabled: bool
  82. admin_username: str | None = Field(default=None, max_length=150)
  83. admin_password: str | None = Field(default=None, max_length=256)
  84. # Password complexity is NOT validated at the schema layer. When re-enabling auth
  85. # with an existing admin user (or when LDAP is the auth backend), the frontend
  86. # still sends whatever is in the password field but the route ignores it.
  87. # Enforcing complexity here would reject those legitimate flows. The route body
  88. # applies the check only when a brand-new local admin is actually being created.
  89. class SetupResponse(BaseModel):
  90. auth_enabled: bool
  91. admin_created: bool | None = None
  92. class ForgotPasswordRequest(BaseModel):
  93. email: str = Field(..., max_length=254) # L-NEW-1: RFC 5321 max; caps memory/CPU before lookup
  94. class ForgotPasswordConfirmRequest(BaseModel):
  95. token: str = Field(..., max_length=128)
  96. new_password: str = Field(..., min_length=8, max_length=256)
  97. @field_validator("new_password")
  98. @classmethod
  99. def validate_new_password(cls, v: str) -> str:
  100. return _validate_password_complexity(v)
  101. class ForgotPasswordResponse(BaseModel):
  102. message: str
  103. class ResetPasswordRequest(BaseModel):
  104. user_id: int
  105. class ResetPasswordResponse(BaseModel):
  106. message: str
  107. class SMTPSettings(BaseModel):
  108. smtp_host: str
  109. smtp_port: int
  110. smtp_username: str | None = None # Optional when auth is disabled
  111. smtp_password: str | None = None # Optional for read operations or when auth is disabled
  112. smtp_security: str = "starttls" # 'starttls', 'ssl', 'none'
  113. smtp_auth_enabled: bool = True
  114. smtp_from_email: str
  115. smtp_from_name: str = "BamBuddy"
  116. # Deprecated field for backward compatibility
  117. smtp_use_tls: bool | None = None
  118. class TestSMTPRequest(BaseModel):
  119. test_recipient: str
  120. class TestSMTPResponse(BaseModel):
  121. success: bool
  122. message: str
  123. # ---------------------------------------------------------------------------
  124. # 2FA / MFA schemas
  125. # ---------------------------------------------------------------------------
  126. class TwoFAStatusResponse(BaseModel):
  127. totp_enabled: bool
  128. email_otp_enabled: bool
  129. backup_codes_remaining: int
  130. class TOTPSetupResponse(BaseModel):
  131. """Returned when a user initiates TOTP setup. The frontend should display
  132. the QR code image (base64 PNG) and ask the user to scan it, then call
  133. /auth/2fa/totp/enable with a valid code to confirm."""
  134. secret: str # base32 secret (shown as fallback text)
  135. qr_code_b64: str # base64-encoded PNG of the QR code
  136. issuer: str
  137. class TOTPSetupRequest(BaseModel):
  138. """Optional body for POST /auth/2fa/totp/setup.
  139. Only required when re-initialising setup while an active TOTP record exists.
  140. Provide the current TOTP code (from the existing authenticator app) to
  141. confirm intent — mirrors the verification requirement in disable_totp.
  142. """
  143. code: str | None = Field(default=None, max_length=8) # L-NEW-2: bound before pyotp
  144. class TOTPEnableRequest(BaseModel):
  145. code: str # 6-digit TOTP code from the authenticator app
  146. @field_validator("code")
  147. @classmethod
  148. def validate_code(cls, v: str) -> str:
  149. v = v.strip()
  150. if not v.isdigit() or len(v) != 6:
  151. raise ValueError("TOTP code must be exactly 6 digits")
  152. return v
  153. class TOTPEnableResponse(BaseModel):
  154. message: str
  155. backup_codes: list[str] # plain-text codes shown once; user must save them
  156. class TOTPDisableRequest(BaseModel):
  157. """Requires a valid TOTP code OR a backup code to disable TOTP."""
  158. code: str = Field(..., max_length=128)
  159. class BackupCodesResponse(BaseModel):
  160. backup_codes: list[str]
  161. message: str
  162. class EmailOTPEnableRequest(BaseModel):
  163. """No body required — email is taken from the authenticated user's profile."""
  164. pass
  165. class TwoFAVerifyRequest(BaseModel):
  166. pre_auth_token: str = Field(..., max_length=128)
  167. # TOTP/email codes are 6 digits; backup codes are 8 uppercase alphanumeric chars.
  168. # max_length=8 prevents excessively long inputs from reaching pbkdf2/pyotp.
  169. code: str = Field(..., min_length=6, max_length=8)
  170. method: Literal["totp", "email", "backup"] = "totp"
  171. @field_validator("code")
  172. @classmethod
  173. def validate_code_format(cls, v: str) -> str:
  174. v = v.strip()
  175. if not re.match(r"^[A-Za-z0-9]{6,8}$", v):
  176. raise ValueError("Code must be 6–8 alphanumeric characters")
  177. return v.upper() # normalise backup codes to uppercase
  178. class TwoFAVerifyResponse(BaseModel):
  179. access_token: str
  180. token_type: str = "bearer"
  181. user: "UserResponse"
  182. class EmailOTPSendRequest(BaseModel):
  183. pre_auth_token: str = Field(..., max_length=128)
  184. class EmailOTPEnableConfirmRequest(BaseModel):
  185. """Body for the second step of email OTP enable: verify the proof-of-possession code."""
  186. setup_token: str = Field(..., max_length=128)
  187. # L-NEW-3: email OTP setup codes are always exactly 6 digits; reject anything else.
  188. code: str = Field(..., min_length=6, max_length=6)
  189. @field_validator("code")
  190. @classmethod
  191. def validate_code_digits(cls, v: str) -> str:
  192. v = v.strip()
  193. if not v.isdigit() or len(v) != 6:
  194. raise ValueError("Email OTP setup code must be exactly 6 digits")
  195. return v
  196. class EmailOTPDisableRequest(BaseModel):
  197. """Requires the account password to disable email OTP."""
  198. password: str = Field(..., max_length=256)
  199. class AdminDisable2FARequest(BaseModel):
  200. """Admin must supply their own password as re-auth before disabling 2FA for another user.
  201. OIDC/LDAP-only admins (no local password_hash) are exempt from this check.
  202. """
  203. admin_password: str | None = Field(default=None, max_length=256)
  204. # ---------------------------------------------------------------------------
  205. # OIDC schemas
  206. # ---------------------------------------------------------------------------
  207. AUTO_LINK_REQUIREMENTS_ERROR = (
  208. "auto_link_existing_accounts requires require_email_verified=True when email_claim='email'"
  209. )
  210. def _validate_email_claim_name(v: str) -> str:
  211. # Accepts only alphanumeric/underscore/hyphen claim names starting with a letter —
  212. # prevents log injection and limits the attack surface of operator-supplied claim names.
  213. if not re.fullmatch(r"[a-zA-Z][a-zA-Z0-9_\-]{0,63}", v):
  214. raise ValueError("Invalid claim name")
  215. return v
  216. def _validate_icon_url(v: str | None) -> str | None:
  217. """Reject non-HTTPS icon URLs to prevent SSRF / mixed-content issues."""
  218. if v is None:
  219. return v
  220. if not v.startswith("https://"):
  221. raise ValueError("icon_url must start with https://")
  222. return v
  223. def _validate_issuer_url(v: str | None) -> str | None:
  224. """Nit4: Reject non-HTTPS issuer URLs and private/loopback/link-local hosts.
  225. HTTP is no longer accepted — OIDC providers must be reachable over TLS.
  226. Private-network and loopback addresses are rejected to prevent SSRF attacks
  227. where an admin-supplied URL could reach internal services.
  228. """
  229. import ipaddress
  230. from urllib.parse import urlparse
  231. if v is None:
  232. return v
  233. if not v.startswith("https://"):
  234. raise ValueError("issuer_url must start with https://")
  235. host = urlparse(v).hostname or ""
  236. try:
  237. addr = ipaddress.ip_address(host)
  238. if addr.is_private or addr.is_loopback or addr.is_link_local:
  239. raise ValueError("issuer_url must not point to a private, loopback, or link-local address")
  240. except ValueError as exc:
  241. if "issuer_url" in str(exc):
  242. raise
  243. # hostname is a domain name, not a bare IP — that's fine
  244. return v
  245. def _validate_scopes(v: str | None) -> str | None:
  246. """Nit5: Require that the 'openid' scope is present.
  247. The OpenID Connect spec mandates the 'openid' scope; without it the
  248. response is plain OAuth2, not OIDC, and claims like sub/email are not
  249. guaranteed.
  250. """
  251. if v is None:
  252. return v
  253. scope_list = v.split()
  254. if "openid" not in scope_list:
  255. raise ValueError("scopes must include 'openid'")
  256. return v
  257. class OIDCProviderCreate(BaseModel):
  258. name: str = Field(..., max_length=100) # L-NEW-4
  259. issuer_url: str
  260. client_id: str = Field(..., max_length=256) # L-NEW-4
  261. client_secret: str = Field(..., max_length=512) # L-NEW-4: Fernet input bounded
  262. scopes: str = Field(default="openid email profile", max_length=256) # L-NEW-4
  263. is_enabled: bool = True
  264. auto_create_users: bool = False
  265. auto_link_existing_accounts: bool = False # M-2: conservative default, opt-in only
  266. email_claim: str = Field(default="email", max_length=64)
  267. require_email_verified: bool = True
  268. icon_url: str | None = None
  269. default_group_id: int | None = None
  270. @field_validator("issuer_url")
  271. @classmethod
  272. def validate_issuer_url(cls, v: str) -> str:
  273. result = _validate_issuer_url(v)
  274. if result is None:
  275. raise ValueError("issuer_url is required")
  276. return result
  277. @field_validator("scopes")
  278. @classmethod
  279. def validate_scopes(cls, v: str) -> str:
  280. result = _validate_scopes(v)
  281. if result is None:
  282. raise ValueError("scopes is required")
  283. return result
  284. @field_validator("email_claim")
  285. @classmethod
  286. def validate_email_claim(cls, v: str) -> str:
  287. return _validate_email_claim_name(v)
  288. @field_validator("icon_url")
  289. @classmethod
  290. def validate_icon_url(cls, v: str | None) -> str | None:
  291. return _validate_icon_url(v)
  292. # SEC-1: auto_link with email_claim='email' requires require_email_verified=True.
  293. # Fall B (require_email_verified=False + email_claim='email') accepts absent email_verified → account-takeover risk.
  294. # Fall C (custom claim != 'email') is safe: no email_verified gate on that path regardless of require_email_verified.
  295. @model_validator(mode="after")
  296. def check_auto_link_requires_verified(self) -> "OIDCProviderCreate":
  297. if self.auto_link_existing_accounts and self.email_claim == "email" and not self.require_email_verified:
  298. raise ValueError(AUTO_LINK_REQUIREMENTS_ERROR)
  299. return self
  300. class OIDCProviderUpdate(BaseModel):
  301. name: str | None = Field(default=None, max_length=100)
  302. issuer_url: str | None = None
  303. @field_validator("issuer_url")
  304. @classmethod
  305. def validate_issuer_url(cls, v: str | None) -> str | None:
  306. return _validate_issuer_url(v)
  307. client_id: str | None = Field(default=None, max_length=256)
  308. client_secret: str | None = Field(default=None, max_length=512)
  309. scopes: str | None = Field(default=None, max_length=256)
  310. is_enabled: bool | None = None
  311. auto_create_users: bool | None = None
  312. auto_link_existing_accounts: bool | None = None
  313. email_claim: str | None = Field(default=None, max_length=64)
  314. require_email_verified: bool | None = None
  315. icon_url: str | None = None
  316. default_group_id: int | None = None
  317. @field_validator("scopes")
  318. @classmethod
  319. def validate_scopes(cls, v: str | None) -> str | None:
  320. return _validate_scopes(v)
  321. @field_validator("email_claim")
  322. @classmethod
  323. def validate_email_claim(cls, v: str | None) -> str | None:
  324. if v is None:
  325. return None
  326. return _validate_email_claim_name(v)
  327. @field_validator("icon_url")
  328. @classmethod
  329. def validate_icon_url(cls, v: str | None) -> str | None:
  330. return _validate_icon_url(v)
  331. # SEC-1 (schema-level): blocks only when auto_link=True + email_claim='email' + require_email_verified=False
  332. # arrive in the same request. email_claim=None means the request leaves it unchanged (still 'email' by default),
  333. # so that is also treated as 'email'. Partial updates spanning two requests are caught by the
  334. # Combined-State-Guard in the route handler after the setattr loop.
  335. @model_validator(mode="after")
  336. def check_auto_link_requires_verified(self) -> "OIDCProviderUpdate":
  337. if (
  338. self.auto_link_existing_accounts is True
  339. and self.require_email_verified is False
  340. and (self.email_claim is None or self.email_claim == "email")
  341. ):
  342. raise ValueError(AUTO_LINK_REQUIREMENTS_ERROR)
  343. return self
  344. class OIDCProviderResponse(BaseModel):
  345. id: int
  346. name: str
  347. issuer_url: str
  348. client_id: str
  349. scopes: str
  350. is_enabled: bool
  351. auto_create_users: bool
  352. auto_link_existing_accounts: bool = False
  353. email_claim: str = "email"
  354. require_email_verified: bool = True
  355. icon_url: str | None = None
  356. default_group_id: int | None = None
  357. class Config:
  358. from_attributes = True
  359. class OIDCAuthorizeResponse(BaseModel):
  360. auth_url: str
  361. class OIDCExchangeRequest(BaseModel):
  362. oidc_token: str = Field(..., max_length=128)
  363. class OIDCLinkResponse(BaseModel):
  364. id: int
  365. provider_id: int
  366. provider_name: str
  367. provider_email: str | None = None
  368. created_at: str