auth.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import secrets
  5. from datetime import datetime, timedelta, timezone
  6. from pathlib import Path
  7. from typing import Annotated
  8. import jwt
  9. from fastapi import Depends, Header, HTTPException, status
  10. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  11. from jwt.exceptions import PyJWTError as JWTError
  12. from passlib.context import CryptContext
  13. from sqlalchemy import func, select
  14. from sqlalchemy.ext.asyncio import AsyncSession
  15. from sqlalchemy.orm import selectinload
  16. from backend.app.core.database import async_session, get_db
  17. from backend.app.core.permissions import Permission
  18. from backend.app.models.api_key import APIKey
  19. from backend.app.models.settings import Settings
  20. from backend.app.models.user import User
  21. logger = logging.getLogger(__name__)
  22. # Password hashing
  23. # Use pbkdf2_sha256 instead of bcrypt to avoid 72-byte limit and passlib initialization issues
  24. # pbkdf2_sha256 is a secure password hashing algorithm without bcrypt's limitations
  25. pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
  26. def _get_jwt_secret() -> str:
  27. """Get the JWT secret key from environment, file, or generate a new one.
  28. Priority:
  29. 1. JWT_SECRET_KEY environment variable
  30. 2. .jwt_secret file in data directory
  31. 3. Generate new random secret and save to file
  32. Returns:
  33. The JWT secret key
  34. """
  35. # 1. Check environment variable first
  36. env_secret = os.environ.get("JWT_SECRET_KEY")
  37. if env_secret:
  38. logger.info("Using JWT secret from JWT_SECRET_KEY environment variable")
  39. return env_secret
  40. # 2. Check for secret file in data directory
  41. # Use DATA_DIR env var (same as rest of app), fallback to data/ subdirectory
  42. data_dir_env = os.environ.get("DATA_DIR")
  43. if data_dir_env:
  44. data_dir = Path(data_dir_env)
  45. else:
  46. # Fallback to data/ subdirectory under project root (not project root itself!)
  47. data_dir = Path(__file__).parent.parent.parent.parent / "data"
  48. secret_file = data_dir / ".jwt_secret"
  49. if secret_file.exists():
  50. try:
  51. secret = secret_file.read_text().strip()
  52. if secret and len(secret) >= 32:
  53. logger.info("Using JWT secret from %s", secret_file)
  54. return secret
  55. except OSError as e:
  56. logger.warning("Failed to read JWT secret file: %s", e)
  57. # 3. Generate new random secret
  58. new_secret = secrets.token_urlsafe(64)
  59. # Try to save it
  60. try:
  61. data_dir.mkdir(parents=True, exist_ok=True)
  62. # Note: CodeQL flags this as "clear-text storage of sensitive information" but this is
  63. # intentional and secure - JWT secrets must be readable by the app, we set 0600 permissions,
  64. # and this is standard practice for self-hosted applications (same as .env files).
  65. secret_file.write_text(new_secret) # nosec B105
  66. # Restrict permissions (owner read/write only)
  67. secret_file.chmod(0o600)
  68. logger.info("Generated new JWT secret and saved to %s", secret_file)
  69. except OSError as e:
  70. logger.warning(
  71. "Could not save JWT secret to file (%s). "
  72. "Secret will be regenerated on restart, invalidating existing tokens. "
  73. "Set JWT_SECRET_KEY environment variable for persistence.",
  74. e,
  75. )
  76. return new_secret
  77. # JWT settings
  78. SECRET_KEY = _get_jwt_secret()
  79. ALGORITHM = "HS256"
  80. ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
  81. # HTTP Bearer token
  82. security = HTTPBearer(auto_error=False)
  83. # --- Slicer download tokens ---
  84. # Short-lived tokens for slicer protocol handlers that can't send auth headers.
  85. # Maps token → (resource_key, expiry). resource_key = "archive:{id}" or "library:{id}".
  86. _slicer_tokens: dict[str, tuple[str, datetime]] = {}
  87. SLICER_TOKEN_EXPIRE_MINUTES = 5
  88. def create_slicer_download_token(resource_type: str, resource_id: int) -> str:
  89. """Create a short-lived download token for slicer protocol handlers."""
  90. # Cleanup expired tokens
  91. now = datetime.now(timezone.utc)
  92. expired = [k for k, (_, exp) in _slicer_tokens.items() if exp < now]
  93. for k in expired:
  94. del _slicer_tokens[k]
  95. token = secrets.token_urlsafe(24)
  96. resource_key = f"{resource_type}:{resource_id}"
  97. _slicer_tokens[token] = (resource_key, now + timedelta(minutes=SLICER_TOKEN_EXPIRE_MINUTES))
  98. return token
  99. def verify_slicer_download_token(token: str, resource_type: str, resource_id: int) -> bool:
  100. """Verify a slicer download token is valid for the given resource."""
  101. entry = _slicer_tokens.get(token)
  102. if not entry:
  103. return False
  104. resource_key, expiry = entry
  105. if datetime.now(timezone.utc) > expiry:
  106. del _slicer_tokens[token]
  107. return False
  108. expected_key = f"{resource_type}:{resource_id}"
  109. if resource_key != expected_key:
  110. return False
  111. # Token is single-use
  112. del _slicer_tokens[token]
  113. return True
  114. def verify_password(plain_password: str, hashed_password: str) -> bool:
  115. """Verify a password against a hash.
  116. Uses pbkdf2_sha256 which handles long passwords automatically.
  117. """
  118. return pwd_context.verify(plain_password, hashed_password)
  119. def get_password_hash(password: str) -> str:
  120. """Hash a password.
  121. Uses pbkdf2_sha256 which is secure and has no password length limit.
  122. """
  123. return pwd_context.hash(password)
  124. def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
  125. """Create a JWT access token."""
  126. to_encode = data.copy()
  127. if expires_delta:
  128. expire = datetime.now(timezone.utc) + expires_delta
  129. else:
  130. expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  131. to_encode.update({"exp": expire})
  132. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  133. return encoded_jwt
  134. async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
  135. """Get a user by username (case-insensitive) with groups loaded for permission checks."""
  136. result = await db.execute(
  137. select(User).where(func.lower(User.username) == func.lower(username)).options(selectinload(User.groups))
  138. )
  139. return result.scalar_one_or_none()
  140. async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
  141. """Get a user by email (case-insensitive) with groups loaded for permission checks."""
  142. result = await db.execute(
  143. select(User).where(func.lower(User.email) == func.lower(email)).options(selectinload(User.groups))
  144. )
  145. return result.scalar_one_or_none()
  146. async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
  147. """Authenticate a user by username and password.
  148. Username lookup is case-insensitive. Password is case-sensitive.
  149. """
  150. user = await get_user_by_username(db, username)
  151. if not user:
  152. return None
  153. if not verify_password(password, user.password_hash):
  154. return None
  155. if not user.is_active:
  156. return None
  157. return user
  158. async def authenticate_user_by_email(db: AsyncSession, email: str, password: str) -> User | None:
  159. """Authenticate a user by email and password.
  160. Email lookup is case-insensitive. Password is case-sensitive.
  161. """
  162. user = await get_user_by_email(db, email)
  163. if not user:
  164. return None
  165. if not verify_password(password, user.password_hash):
  166. return None
  167. if not user.is_active:
  168. return None
  169. return user
  170. async def is_auth_enabled(db: AsyncSession) -> bool:
  171. """Check if authentication is enabled."""
  172. try:
  173. result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
  174. setting = result.scalar_one_or_none()
  175. if setting is None:
  176. return False
  177. return setting.value.lower() == "true"
  178. except Exception:
  179. # If settings table doesn't exist or query fails, assume auth is disabled
  180. return False
  181. async def _validate_api_key(db: AsyncSession, api_key_value: str) -> APIKey | None:
  182. """Validate an API key and return the APIKey object if valid, None otherwise.
  183. This is an internal helper used by auth functions to check API keys.
  184. """
  185. try:
  186. result = await db.execute(select(APIKey).where(APIKey.enabled.is_(True)))
  187. api_keys = result.scalars().all()
  188. for api_key in api_keys:
  189. if verify_password(api_key_value, api_key.key_hash):
  190. # Check expiration
  191. if api_key.expires_at:
  192. expires = api_key.expires_at
  193. if expires.tzinfo is None:
  194. expires = expires.replace(tzinfo=timezone.utc)
  195. if expires < datetime.now(timezone.utc):
  196. return None # Expired
  197. # Update last_used timestamp
  198. api_key.last_used = datetime.now(timezone.utc)
  199. await db.commit()
  200. return api_key
  201. except Exception as e:
  202. logger.warning("API key validation error: %s", e)
  203. return None
  204. async def get_current_user_optional(
  205. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  206. ) -> User | None:
  207. """Get the current authenticated user from JWT token, or None if not authenticated."""
  208. if credentials is None:
  209. return None
  210. try:
  211. token = credentials.credentials
  212. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  213. username: str = payload.get("sub")
  214. if username is None:
  215. return None
  216. except JWTError:
  217. return None
  218. async with async_session() as db:
  219. user = await get_user_by_username(db, username)
  220. if user is None or not user.is_active:
  221. return None
  222. return user
  223. async def get_current_user(
  224. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  225. ) -> User:
  226. """Get the current authenticated user from JWT token."""
  227. credentials_exception = HTTPException(
  228. status_code=status.HTTP_401_UNAUTHORIZED,
  229. detail="Could not validate credentials",
  230. headers={"WWW-Authenticate": "Bearer"},
  231. )
  232. if credentials is None:
  233. raise credentials_exception
  234. try:
  235. token = credentials.credentials
  236. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  237. username: str = payload.get("sub")
  238. if username is None:
  239. raise credentials_exception
  240. except JWTError:
  241. raise credentials_exception
  242. async with async_session() as db:
  243. user = await get_user_by_username(db, username)
  244. if user is None:
  245. raise credentials_exception
  246. if not user.is_active:
  247. raise HTTPException(
  248. status_code=status.HTTP_403_FORBIDDEN,
  249. detail="User account is disabled",
  250. )
  251. return user
  252. async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
  253. """Get the current active user (alias for clarity)."""
  254. return current_user
async def require_auth_if_enabled(
    credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
    x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
) -> User | None:
    """Require authentication if auth is enabled, otherwise return None.

    Accepts both JWT tokens (via Authorization: Bearer header) and API keys
    (via X-API-Key header or Authorization: Bearer bb_xxx).

    Check order (the first success wins):
      1. Auth disabled -> None.
      2. Valid X-API-Key header -> None. An invalid X-API-Key is not rejected here;
         evaluation falls through to the Authorization header.
      3. Bearer token starting with "bb_" -> validated as an API key. Returns None
         on success, otherwise raises 401 "Invalid API key" (no JWT fallback).
      4. Any other Bearer token -> decoded as a JWT. Returns the active User named
         by its "sub" claim, otherwise raises 401.
      5. No Authorization credentials -> 401 "Authentication required".

    Returns:
        The authenticated User for JWT requests. None when auth is disabled or the
        request was authenticated with an API key; callers cannot tell those apart.

    Raises:
        HTTPException(401): any failed or missing credential, as listed above.
    """
    # One session per call, shared by the settings, API-key and user lookups.
    async with async_session() as db:
        auth_enabled = await is_auth_enabled(db)
        if not auth_enabled:
            return None
        # Check for API key first (X-API-Key header)
        if x_api_key:
            api_key = await _validate_api_key(db, x_api_key)
            if api_key:
                return None  # API key valid, allow access
            # Invalid X-API-Key: not rejected yet, fall through to the Authorization header.
        # Check for Bearer token (could be JWT or API key)
        if credentials is not None:
            token = credentials.credentials
            # Check if it's an API key (starts with bb_)
            if token.startswith("bb_"):
                api_key = await _validate_api_key(db, token)
                if api_key:
                    return None  # API key valid, allow access
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid API key",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            # Otherwise treat as JWT
            try:
                payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
                username: str = payload.get("sub")
                if username is None:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Could not validate credentials",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
            except JWTError:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Could not validate credentials",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            # Missing and disabled users get the same 401, so account state is not revealed.
            user = await get_user_by_username(db, username)
            if user is None or not user.is_active:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Could not validate credentials",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return user
        # No credentials provided
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"},
        )
  315. def require_role(required_role: str):
  316. """Dependency factory for role-based access control."""
  317. async def role_checker(current_user: Annotated[User, Depends(get_current_user)]) -> User:
  318. if current_user.role != required_role:
  319. raise HTTPException(
  320. status_code=status.HTTP_403_FORBIDDEN,
  321. detail=f"Requires {required_role} role",
  322. )
  323. return current_user
  324. return role_checker
  325. def require_admin_if_auth_enabled():
  326. """Dependency factory that requires admin role if auth is enabled."""
  327. async def admin_checker(
  328. current_user: Annotated[User | None, Depends(require_auth_if_enabled)] = None,
  329. ) -> User | None:
  330. if current_user is None:
  331. return None # Auth not enabled, allow access
  332. if current_user.role != "admin":
  333. raise HTTPException(
  334. status_code=status.HTTP_403_FORBIDDEN,
  335. detail="Requires admin role",
  336. )
  337. return current_user
  338. return admin_checker
  339. def generate_api_key() -> tuple[str, str, str]:
  340. """Generate a new API key.
  341. Returns:
  342. tuple: (full_key, key_hash, key_prefix)
  343. - full_key: The complete API key (only shown once on creation)
  344. - key_hash: Hashed version for storage and verification
  345. - key_prefix: First 8 characters for display purposes
  346. """
  347. # Generate a secure random API key (32 bytes = 64 hex characters)
  348. full_key = f"bb_{secrets.token_urlsafe(32)}"
  349. key_hash = get_password_hash(full_key)
  350. key_prefix = full_key[:8] + "..." if len(full_key) > 8 else full_key
  351. return full_key, key_hash, key_prefix
  352. async def get_api_key(
  353. authorization: Annotated[str | None, Header(alias="Authorization")] = None,
  354. x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
  355. db: AsyncSession = Depends(get_db),
  356. ) -> APIKey:
  357. """Get and validate API key from request headers.
  358. Checks both 'Authorization: Bearer <key>' and 'X-API-Key: <key>' headers.
  359. """
  360. api_key_value = None
  361. if x_api_key:
  362. api_key_value = x_api_key
  363. elif authorization and authorization.startswith("Bearer "):
  364. api_key_value = authorization.replace("Bearer ", "")
  365. if not api_key_value:
  366. raise HTTPException(
  367. status_code=status.HTTP_401_UNAUTHORIZED,
  368. detail="API key required. Provide 'X-API-Key' header or 'Authorization: Bearer <key>'",
  369. )
  370. # Get all API keys and check them
  371. result = await db.execute(select(APIKey).where(APIKey.enabled.is_(True)))
  372. api_keys = result.scalars().all()
  373. for api_key in api_keys:
  374. # Check if key matches (verify against hash)
  375. if verify_password(api_key_value, api_key.key_hash):
  376. # Check expiration
  377. if api_key.expires_at:
  378. expires = api_key.expires_at
  379. if expires.tzinfo is None:
  380. expires = expires.replace(tzinfo=timezone.utc)
  381. if expires < datetime.now(timezone.utc):
  382. raise HTTPException(
  383. status_code=status.HTTP_401_UNAUTHORIZED,
  384. detail="API key has expired",
  385. )
  386. # Update last_used timestamp
  387. api_key.last_used = datetime.now(timezone.utc)
  388. await db.commit()
  389. return api_key
  390. raise HTTPException(
  391. status_code=status.HTTP_401_UNAUTHORIZED,
  392. detail="Invalid API key",
  393. )
  394. def check_permission(api_key: APIKey, permission: str) -> None:
  395. """Check if API key has the required permission.
  396. Args:
  397. api_key: The API key object
  398. permission: One of 'queue', 'control_printer', 'read_status'
  399. Raises:
  400. HTTPException: If permission is not granted
  401. """
  402. permission_map = {
  403. "queue": "can_queue",
  404. "control_printer": "can_control_printer",
  405. "read_status": "can_read_status",
  406. }
  407. if permission not in permission_map:
  408. raise HTTPException(
  409. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  410. detail=f"Unknown permission: {permission}",
  411. )
  412. attr_name = permission_map[permission]
  413. if not getattr(api_key, attr_name, False):
  414. raise HTTPException(
  415. status_code=status.HTTP_403_FORBIDDEN,
  416. detail=f"API key does not have '{permission}' permission",
  417. )
  418. def check_printer_access(api_key: APIKey, printer_id: int) -> None:
  419. """Check if API key has access to the specified printer.
  420. Args:
  421. api_key: The API key object
  422. printer_id: The printer ID to check access for
  423. Raises:
  424. HTTPException: If access is denied
  425. """
  426. # If printer_ids is None or empty, access to all printers
  427. if api_key.printer_ids is None or len(api_key.printer_ids) == 0:
  428. return
  429. # Check if printer_id is in allowed list
  430. if printer_id not in api_key.printer_ids:
  431. raise HTTPException(
  432. status_code=status.HTTP_403_FORBIDDEN,
  433. detail=f"API key does not have access to printer {printer_id}",
  434. )
  435. # Convenience dependencies - these are functions that return Depends objects
  436. def RequireAdmin():
  437. """Dependency that requires admin role."""
  438. return Depends(require_role("admin"))
  439. def RequireAdminIfAuthEnabled():
  440. """Dependency that requires admin role if auth is enabled."""
  441. return Depends(require_admin_if_auth_enabled())
def require_permission(*permissions: str | Permission):
    """Dependency factory that requires user to have ALL specified permissions.

    Accepts both JWT tokens (via Authorization: Bearer header) and API keys
    (via X-API-Key header or Authorization: Bearer bb_xxx).

    Unlike require_permission_if_auth_enabled, this does not consult the
    auth_enabled setting: credentials are always required.

    NOTE(review): a valid API key is admitted without any permission check, and
    its can_* flags are not consulted here. Confirm this is intended.

    Args:
        *permissions: Permission strings or Permission enum values to require

    Returns:
        A dependency function that validates permissions. The dependency returns
        the User for JWT requests, or None for API-key requests.
    """
    # Convert Permission enums to strings
    perm_strings = [p.value if isinstance(p, Permission) else p for p in permissions]

    async def permission_checker(
        credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
        x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    ) -> User | None:
        async with async_session() as db:
            # Check for API key first (X-API-Key header)
            if x_api_key:
                api_key = await _validate_api_key(db, x_api_key)
                if api_key:
                    return None  # API key valid, allow access
                # Invalid X-API-Key: fall through to the Authorization header.
            credentials_exception = HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
            if credentials is None:
                raise credentials_exception
            token = credentials.credentials
            # Check if it's an API key (starts with bb_)
            if token.startswith("bb_"):
                api_key = await _validate_api_key(db, token)
                if api_key:
                    return None  # API key valid, allow access
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid API key",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            # Otherwise treat as JWT
            try:
                payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
                username: str = payload.get("sub")
                if username is None:
                    raise credentials_exception
            except JWTError:
                raise credentials_exception
            user = await get_user_by_username(db, username)
            if user is None or not user.is_active:
                raise credentials_exception
            # Authenticated but lacking a required permission -> 403, not 401.
            if not user.has_all_permissions(*perm_strings):
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Missing required permissions: {', '.join(perm_strings)}",
                )
            return user

    return permission_checker
def require_permission_if_auth_enabled(*permissions: str | Permission):
    """Dependency factory that checks permissions only if auth is enabled.

    This provides backward compatibility - when auth is disabled, all access is allowed.
    Accepts both JWT tokens (via Authorization: Bearer header) and API keys
    (via X-API-Key header or Authorization: Bearer bb_xxx).

    Check order (the first success wins):
      1. Auth disabled -> None.
      2. Valid X-API-Key header -> None. An invalid one falls through.
      3. Bearer "bb_..." -> API key. Returns None on success, else 401 "Invalid API key".
      4. Other Bearer token -> JWT. Returns the active User if they hold ALL the
         permissions; 401 on invalid credentials, 403 on missing permissions.
      5. No Authorization credentials -> 401 "Authentication required".

    NOTE(review): API keys are admitted without checking any permission. Confirm intended.

    Args:
        *permissions: Permission strings or Permission enum values to require

    Returns:
        A dependency function that validates permissions if auth is enabled
    """
    # Convert Permission enums to strings
    perm_strings = [p.value if isinstance(p, Permission) else p for p in permissions]

    async def permission_checker(
        credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
        x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    ) -> User | None:
        async with async_session() as db:
            auth_enabled = await is_auth_enabled(db)
            if not auth_enabled:
                return None  # Auth disabled, allow access
            # Check for API key first (X-API-Key header)
            if x_api_key:
                api_key = await _validate_api_key(db, x_api_key)
                if api_key:
                    return None  # API key valid, allow access
            # Check for Bearer token (could be JWT or API key)
            if credentials is not None:
                token = credentials.credentials
                # Check if it's an API key (starts with bb_)
                if token.startswith("bb_"):
                    api_key = await _validate_api_key(db, token)
                    if api_key:
                        return None  # API key valid, allow access
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Invalid API key",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                # Otherwise treat as JWT
                try:
                    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
                    username: str = payload.get("sub")
                    if username is None:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Could not validate credentials",
                            headers={"WWW-Authenticate": "Bearer"},
                        )
                except JWTError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Could not validate credentials",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                user = await get_user_by_username(db, username)
                if user is None or not user.is_active:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Could not validate credentials",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                # Authenticated but lacking a required permission -> 403, not 401.
                if not user.has_all_permissions(*perm_strings):
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail=f"Missing required permissions: {', '.join(perm_strings)}",
                    )
                return user
            # No credentials provided
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication required",
                headers={"WWW-Authenticate": "Bearer"},
            )

    return permission_checker
  573. def RequirePermission(*permissions: str | Permission):
  574. """Convenience dependency that requires ALL specified permissions."""
  575. return Depends(require_permission(*permissions))
  576. def RequirePermissionIfAuthEnabled(*permissions: str | Permission):
  577. """Convenience dependency that requires permissions if auth is enabled."""
  578. return Depends(require_permission_if_auth_enabled(*permissions))
def require_ownership_permission(
    all_permission: str | Permission,
    own_permission: str | Permission,
):
    """Dependency factory for ownership-based permission checks.

    - User with `all_permission` can modify any item
    - User with `own_permission` can only modify items where created_by_id == user.id
    - Ownerless items (created_by_id = null) require `all_permission`
    - API keys (via X-API-Key header or Bearer bb_xxx) get full access (can_modify_all=True)

    The ownership comparison itself is not done here; the route must apply it using
    the returned flag. Auth disabled also yields (None, True).

    Returns:
        A dependency function that returns (user, can_modify_all).
        - can_modify_all=True: user can modify any item
        - can_modify_all=False: user can only modify their own items

    The dependency raises HTTPException 401 for invalid or missing credentials,
    and 403 when the user holds neither permission.
    """
    # Normalise Permission enums to their string values once, at factory time.
    all_perm = all_permission.value if isinstance(all_permission, Permission) else all_permission
    own_perm = own_permission.value if isinstance(own_permission, Permission) else own_permission

    async def checker(
        credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
        x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
    ) -> tuple[User | None, bool]:
        """Returns (user, can_modify_all).

        - can_modify_all=True: user can modify any item
        - can_modify_all=False: user can only modify their own items
        """
        async with async_session() as db:
            auth_enabled = await is_auth_enabled(db)
            if not auth_enabled:
                return None, True  # Auth disabled, allow all
            # Check for API key first (X-API-Key header)
            if x_api_key:
                api_key = await _validate_api_key(db, x_api_key)
                if api_key:
                    return None, True  # API key valid, allow all
                # Invalid X-API-Key: fall through to the Authorization header.
            # Check for Bearer token (could be JWT or API key)
            if credentials is not None:
                token = credentials.credentials
                # Check if it's an API key (starts with bb_)
                if token.startswith("bb_"):
                    api_key = await _validate_api_key(db, token)
                    if api_key:
                        return None, True  # API key valid, allow all
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Invalid API key",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                # Otherwise treat as JWT
                try:
                    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
                    username: str = payload.get("sub")
                    if username is None:
                        raise HTTPException(
                            status_code=status.HTTP_401_UNAUTHORIZED,
                            detail="Could not validate credentials",
                            headers={"WWW-Authenticate": "Bearer"},
                        )
                except JWTError:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Could not validate credentials",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                user = await get_user_by_username(db, username)
                if user is None or not user.is_active:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Could not validate credentials",
                        headers={"WWW-Authenticate": "Bearer"},
                    )
                # The broader permission is checked first, so holders of both get full access.
                if user.has_permission(all_perm):
                    return user, True
                if user.has_permission(own_perm):
                    return user, False
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail=f"Missing permission: {own_perm} or {all_perm}",
                )
            # No credentials provided
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication required",
                headers={"WWW-Authenticate": "Bearer"},
            )

    return checker