auth.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  import re
  from typing import Literal

  from pydantic import BaseModel, ConfigDict, Field, field_validator
  def _validate_password_complexity(v: str, *, min_length: int = 8) -> str:
      """Enforce minimum password complexity (M-C).

      Requires at least ``min_length`` characters plus at least one uppercase
      letter, one lowercase letter, one digit, and one special character
      (any character that is not an ASCII letter or digit).

      The length rule is checked here as well, instead of relying only on a
      ``Field(min_length=8)`` constraint, because not every schema that calls
      this helper declares that constraint.

      Args:
          v: Candidate password.
          min_length: Minimum number of characters required (default 8).

      Returns:
          ``v`` unchanged when every rule is satisfied.

      Raises:
          ValueError: Naming the first rule the password fails.
      """
      if len(v) < min_length:
          raise ValueError(f"Password must be at least {min_length} characters long")
      if not re.search(r"[A-Z]", v):
          raise ValueError("Password must contain at least one uppercase letter")
      if not re.search(r"[a-z]", v):
          raise ValueError("Password must contain at least one lowercase letter")
      if not re.search(r"\d", v):
          raise ValueError("Password must contain at least one digit")
      if not re.search(r"[^A-Za-z0-9]", v):
          raise ValueError("Password must contain at least one special character")
      return v
  class GroupBrief(BaseModel):
      """Brief group info for embedding in user responses."""

      # Pydantic v2 style config (replaces the deprecated inner ``class Config``):
      # lets model_validate() read fields from object attributes.
      model_config = ConfigDict(from_attributes=True)

      id: int
      name: str
  class LoginRequest(BaseModel):
      """Username/password pair posted to the login endpoint."""

      # Both fields are required (no default given); the caps bound input size.
      username: str = Field(max_length=150)
      password: str = Field(max_length=256)
  class LoginResponse(BaseModel):
      """Response to a login attempt.

      Every field has a default. When ``requires_2fa`` is true, the frontend
      must continue the flow through /auth/2fa/verify.
      """

      access_token: str | None = Field(default=None)
      token_type: str = Field(default="bearer")
      user: "UserResponse | None" = Field(default=None)

      # Second-factor hand-off: populated when 2FA is required; the frontend
      # must call /auth/2fa/verify.
      requires_2fa: bool = Field(default=False)
      pre_auth_token: str | None = Field(default=None)
      two_fa_methods: list[str] = Field(default=[])
  class UserCreate(BaseModel):
      """Payload for creating a user account."""

      username: str = Field(..., max_length=150)
      # M-NEW-4: cap before pbkdf2. min_length=8 matches ChangePasswordRequest and
      # ForgotPasswordConfirmRequest; it applies only when a password is supplied.
      password: str | None = Field(default=None, min_length=8, max_length=256)
      email: str | None = Field(default=None, max_length=254)  # L-NEW-5: RFC 5321 max
      role: str = "user"
      group_ids: list[int] | None = None

      @field_validator("password")
      @classmethod
      def validate_password(cls, v: str | None) -> str | None:
          """Apply the shared complexity rules when a password is provided."""
          if v is not None:
              _validate_password_complexity(v)
          return v
  class UserUpdate(BaseModel):
      """Partial update for a user account; every field is optional."""

      username: str | None = Field(default=None, max_length=150)
      # M-NEW-4: cap before pbkdf2. min_length=8 matches the other
      # password-setting schemas; None (no change requested) still passes.
      password: str | None = Field(default=None, min_length=8, max_length=256)
      email: str | None = Field(default=None, max_length=254)  # L-NEW-5: RFC 5321 max
      role: str | None = None
      is_active: bool | None = None
      group_ids: list[int] | None = None

      @field_validator("password")
      @classmethod
      def validate_password(cls, v: str | None) -> str | None:
          """Apply the shared complexity rules when a password is provided."""
          if v is not None:
              _validate_password_complexity(v)
          return v
  class UserResponse(BaseModel):
      """User representation returned by the API."""

      # Pydantic v2 style config (replaces the deprecated inner ``class Config``):
      # lets model_validate() read fields from object attributes.
      model_config = ConfigDict(from_attributes=True)

      id: int
      username: str
      email: str | None = None
      role: str  # Deprecated, kept for backward compatibility
      is_active: bool
      is_admin: bool  # Computed from role and group membership
      auth_source: str = "local"  # "local" or "ldap"
      groups: list[GroupBrief] = []
      permissions: list[str] = []  # All permissions from groups
      created_at: str
  class ChangePasswordRequest(BaseModel):
      """Password change: the current password plus its replacement."""

      # M-NEW-3: cap before pbkdf2
      current_password: str = Field(max_length=256)
      new_password: str = Field(min_length=8, max_length=256)

      @field_validator("new_password")
      @classmethod
      def validate_new_password(cls, v: str) -> str:
          """Run the shared complexity rules on the replacement password."""
          _validate_password_complexity(v)
          return v
  class SetupRequest(BaseModel):
      """Initial-setup payload: the auth toggle plus optional admin credentials."""

      auth_enabled: bool
      admin_username: str | None = Field(default=None, max_length=150)
      # min_length=8 matches the other password-setting schemas; it applies only
      # when a password is actually supplied (None still passes).
      admin_password: str | None = Field(default=None, min_length=8, max_length=256)

      @field_validator("admin_password")
      @classmethod
      def validate_admin_password(cls, v: str | None) -> str | None:
          """Apply the shared complexity rules when an admin password is given."""
          if v is not None:
              _validate_password_complexity(v)
          return v
  class SetupResponse(BaseModel):
      """Outcome of the initial-setup call."""

      auth_enabled: bool
      admin_created: bool | None = Field(default=None)
  class ForgotPasswordRequest(BaseModel):
      """Email address for which a password reset is requested."""

      # L-NEW-1: RFC 5321 max; caps memory/CPU before lookup
      email: str = Field(max_length=254)
  class ForgotPasswordConfirmRequest(BaseModel):
      """Reset token together with the new password to set."""

      token: str = Field(max_length=128)
      new_password: str = Field(min_length=8, max_length=256)

      @field_validator("new_password")
      @classmethod
      def validate_new_password(cls, v: str) -> str:
          """Run the shared complexity rules on the new password."""
          _validate_password_complexity(v)
          return v
  class ForgotPasswordResponse(BaseModel):
      """Human-readable result of a forgot-password request."""

      message: str
  class ResetPasswordRequest(BaseModel):
      """Identifies the user whose password should be reset."""

      user_id: int
  class ResetPasswordResponse(BaseModel):
      """Human-readable result of a password reset."""

      message: str
  class SMTPSettings(BaseModel):
      """Outgoing-mail (SMTP) configuration."""

      smtp_host: str
      smtp_port: int
      # Credentials may be omitted when SMTP auth is disabled; the password may
      # also be absent on read operations.
      smtp_username: str | None = Field(default=None)
      smtp_password: str | None = Field(default=None)
      smtp_security: str = Field(default="starttls")  # 'starttls', 'ssl', 'none'
      smtp_auth_enabled: bool = Field(default=True)
      smtp_from_email: str
      smtp_from_name: str = Field(default="BamBuddy")
      # Deprecated; retained only for backward compatibility.
      smtp_use_tls: bool | None = Field(default=None)
  class TestSMTPRequest(BaseModel):
      """Recipient address for an SMTP test message."""

      # RFC 5321 max, matching the other email fields in this module; bounds
      # input before it reaches the mail client.
      test_recipient: str = Field(..., max_length=254)
  class TestSMTPResponse(BaseModel):
      """Result of sending an SMTP test message."""

      success: bool
      message: str
  124. # ---------------------------------------------------------------------------
  125. # 2FA / MFA schemas
  126. # ---------------------------------------------------------------------------
  class TwoFAStatusResponse(BaseModel):
      """Current second-factor configuration of a user."""

      totp_enabled: bool
      email_otp_enabled: bool
      backup_codes_remaining: int
  class TOTPSetupResponse(BaseModel):
      """Data returned when a user starts TOTP enrolment.

      The frontend displays the QR code (base64 PNG) for scanning, then calls
      /auth/2fa/totp/enable with a valid code to confirm.
      """

      secret: str  # base32 secret, shown as fallback text
      qr_code_b64: str  # base64-encoded PNG of the QR code
      issuer: str
  class TOTPSetupRequest(BaseModel):
      """Optional body for POST /auth/2fa/totp/setup.

      Needed only when re-initialising setup while an active TOTP record
      exists: supply the current code from the existing authenticator app to
      confirm intent, mirroring the check performed by disable_totp.
      """

      # L-NEW-2: bound before pyotp
      code: str | None = Field(None, max_length=8)
  class TOTPEnableRequest(BaseModel):
      """Confirmation code that completes TOTP enrolment."""

      # 6-digit TOTP code from the authenticator app. The cap (with slack for
      # surrounding whitespace, which is stripped below) bounds input before
      # pyotp, matching the other 2FA code fields.
      code: str = Field(..., max_length=32)

      @field_validator("code")
      @classmethod
      def validate_code(cls, v: str) -> str:
          """Strip whitespace and require exactly six ASCII digits."""
          v = v.strip()
          # str.isdigit() alone also accepts non-ASCII digits such as "١" or "²",
          # which can never match a TOTP code; insist on plain 0-9.
          if not (v.isascii() and v.isdigit()) or len(v) != 6:
              raise ValueError("TOTP code must be exactly 6 digits")
          return v
  class TOTPEnableResponse(BaseModel):
      """Confirmation of TOTP enrolment plus the freshly issued backup codes."""

      message: str
      # Plain-text codes displayed only once; the user must store them.
      backup_codes: list[str]
  class TOTPDisableRequest(BaseModel):
      """Proof required to switch TOTP off: a current TOTP code or a backup code."""

      code: str = Field(max_length=128)
  class BackupCodesResponse(BaseModel):
      """A list of backup codes accompanied by a status message."""

      backup_codes: list[str]
      message: str
  class EmailOTPEnableRequest(BaseModel):
      """Empty body: the email address comes from the authenticated user's profile."""
  class TwoFAVerifyRequest(BaseModel):
      """Second-factor code submitted together with a pre-auth token."""

      pre_auth_token: str = Field(max_length=128)
      # TOTP/email codes are 6 digits; backup codes are 8 uppercase alphanumeric
      # chars. The upper bound keeps oversized input away from pbkdf2/pyotp.
      code: str = Field(min_length=6, max_length=8)
      method: Literal["totp", "email", "backup"] = "totp"

      @field_validator("code")
      @classmethod
      def validate_code_format(cls, v: str) -> str:
          """Trim, require 6-8 ASCII alphanumerics, and return it upper-cased."""
          trimmed = v.strip()
          if re.fullmatch(r"[A-Za-z0-9]{6,8}", trimmed) is None:
              raise ValueError("Code must be 6–8 alphanumeric characters")
          # Normalise backup codes to uppercase.
          return trimmed.upper()
  class TwoFAVerifyResponse(BaseModel):
      """Access token and user record issued after a successful second factor."""

      access_token: str
      token_type: str = Field(default="bearer")
      user: "UserResponse"
  class EmailOTPSendRequest(BaseModel):
      """Pre-auth token identifying the login for which an email OTP is sent."""

      pre_auth_token: str = Field(max_length=128)
  class EmailOTPEnableConfirmRequest(BaseModel):
      """Body for the second step of email OTP enable: verify the proof-of-possession code."""

      setup_token: str = Field(..., max_length=128)
      # L-NEW-3: email OTP setup codes are always exactly 6 digits; reject anything else.
      code: str = Field(..., min_length=6, max_length=6)

      @field_validator("code")
      @classmethod
      def validate_code_digits(cls, v: str) -> str:
          """Strip whitespace and require exactly six ASCII digits."""
          v = v.strip()
          # str.isdigit() alone accepts non-ASCII digits (e.g. Arabic-Indic or
          # superscripts), which are "anything else" per L-NEW-3.
          if not (v.isascii() and v.isdigit()) or len(v) != 6:
              raise ValueError("Email OTP setup code must be exactly 6 digits")
          return v
  class EmailOTPDisableRequest(BaseModel):
      """Account password, required as confirmation before email OTP is turned off."""

      password: str = Field(max_length=256)
  class AdminDisable2FARequest(BaseModel):
      """Re-authentication for an admin disabling another user's 2FA.

      The admin supplies their own password. Admins authenticated only via
      OIDC/LDAP (no local password_hash) are exempt, hence the field is optional.
      """

      admin_password: str | None = Field(None, max_length=256)
  205. # ---------------------------------------------------------------------------
  206. # OIDC schemas
  207. # ---------------------------------------------------------------------------
  208. def _validate_icon_url(v: str | None) -> str | None:
  209. """Reject non-HTTPS icon URLs to prevent SSRF / mixed-content issues."""
  210. if v is None:
  211. return v
  212. if not v.startswith("https://"):
  213. raise ValueError("icon_url must start with https://")
  214. return v
  215. def _validate_issuer_url(v: str | None) -> str | None:
  216. """Nit4: Reject non-HTTPS issuer URLs and private/loopback/link-local hosts.
  217. HTTP is no longer accepted — OIDC providers must be reachable over TLS.
  218. Private-network and loopback addresses are rejected to prevent SSRF attacks
  219. where an admin-supplied URL could reach internal services.
  220. """
  221. import ipaddress
  222. from urllib.parse import urlparse
  223. if v is None:
  224. return v
  225. if not v.startswith("https://"):
  226. raise ValueError("issuer_url must start with https://")
  227. host = urlparse(v).hostname or ""
  228. try:
  229. addr = ipaddress.ip_address(host)
  230. if addr.is_private or addr.is_loopback or addr.is_link_local:
  231. raise ValueError("issuer_url must not point to a private, loopback, or link-local address")
  232. except ValueError as exc:
  233. if "issuer_url" in str(exc):
  234. raise
  235. # hostname is a domain name, not a bare IP — that's fine
  236. return v
  237. def _validate_scopes(v: str | None) -> str | None:
  238. """Nit5: Require that the 'openid' scope is present.
  239. The OpenID Connect spec mandates the 'openid' scope; without it the
  240. response is plain OAuth2, not OIDC, and claims like sub/email are not
  241. guaranteed.
  242. """
  243. if v is None:
  244. return v
  245. scope_list = v.split()
  246. if "openid" not in scope_list:
  247. raise ValueError("scopes must include 'openid'")
  248. return v
  class OIDCProviderCreate(BaseModel):
      """Payload for registering a new OIDC identity provider."""

      name: str = Field(..., max_length=100)  # L-NEW-4
      # Bounded like every other string field (L-NEW-4) so oversized URLs are
      # rejected before the HTTPS/SSRF checks run.
      issuer_url: str = Field(..., max_length=2048)
      client_id: str = Field(..., max_length=256)  # L-NEW-4
      client_secret: str = Field(..., max_length=512)  # L-NEW-4: Fernet input bounded
      scopes: str = Field(default="openid email profile", max_length=256)  # L-NEW-4
      is_enabled: bool = True
      auto_create_users: bool = False
      auto_link_existing_accounts: bool = False  # M-2: conservative default, opt-in only
      icon_url: str | None = Field(default=None, max_length=2048)

      # The shared helpers raise ValueError on failure and otherwise return their
      # input unchanged, so returning ``v`` directly replaces the former
      # ``assert result is not None`` narrowing (asserts vanish under ``-O``).
      @field_validator("issuer_url")
      @classmethod
      def validate_issuer_url(cls, v: str) -> str:
          """Require an HTTPS issuer that does not target an internal host."""
          _validate_issuer_url(v)
          return v

      @field_validator("scopes")
      @classmethod
      def validate_scopes(cls, v: str) -> str:
          """Require the 'openid' scope."""
          _validate_scopes(v)
          return v

      @field_validator("icon_url")
      @classmethod
      def validate_icon_url(cls, v: str | None) -> str | None:
          """Require HTTPS for the optional icon URL."""
          return _validate_icon_url(v)
  class OIDCProviderUpdate(BaseModel):
      """Partial update for an OIDC provider; every field is optional."""

      name: str | None = Field(default=None, max_length=100)
      # Bounded like the other string fields (L-NEW-4).
      issuer_url: str | None = Field(default=None, max_length=2048)
      client_id: str | None = Field(default=None, max_length=256)
      client_secret: str | None = Field(default=None, max_length=512)
      scopes: str | None = Field(default=None, max_length=256)
      is_enabled: bool | None = None
      auto_create_users: bool | None = None
      auto_link_existing_accounts: bool | None = None
      icon_url: str | None = Field(default=None, max_length=2048)

      @field_validator("issuer_url")
      @classmethod
      def validate_issuer_url(cls, v: str | None) -> str | None:
          """Require an HTTPS issuer that does not target an internal host."""
          return _validate_issuer_url(v)

      @field_validator("scopes")
      @classmethod
      def validate_scopes(cls, v: str | None) -> str | None:
          """Require the 'openid' scope when scopes are being changed."""
          return _validate_scopes(v)

      @field_validator("icon_url")
      @classmethod
      def validate_icon_url(cls, v: str | None) -> str | None:
          """Require HTTPS for the optional icon URL."""
          return _validate_icon_url(v)
  class OIDCProviderResponse(BaseModel):
      """OIDC provider as returned by the API (the client secret is not included)."""

      # Pydantic v2 style config (replaces the deprecated inner ``class Config``):
      # lets model_validate() read fields from object attributes.
      model_config = ConfigDict(from_attributes=True)

      id: int
      name: str
      issuer_url: str
      client_id: str
      scopes: str
      is_enabled: bool
      auto_create_users: bool
      auto_link_existing_accounts: bool = False
      icon_url: str | None = None
  class OIDCAuthorizeResponse(BaseModel):
      """Authorization URL to which the browser should be sent."""

      auth_url: str
  class OIDCExchangeRequest(BaseModel):
      """Short-lived OIDC token to be exchanged by the backend."""

      oidc_token: str = Field(max_length=128)
  class OIDCLinkResponse(BaseModel):
      """A link between a local account and an OIDC provider identity."""

      id: int
      provider_id: int
      provider_name: str
      provider_email: str | None = Field(default=None)
      created_at: str