auth.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
from __future__ import annotations

import secrets
from datetime import datetime, timedelta, timezone
from typing import Annotated

import jwt
from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt.exceptions import PyJWTError as JWTError
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from backend.app.core.database import async_session, get_db
from backend.app.core.permissions import Permission
from backend.app.models.api_key import APIKey
from backend.app.models.settings import Settings
from backend.app.models.user import User
  18. # Password hashing
  19. # Use pbkdf2_sha256 instead of bcrypt to avoid 72-byte limit and passlib initialization issues
  20. # pbkdf2_sha256 is a secure password hashing algorithm without bcrypt's limitations
  21. pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
  22. # JWT settings
  23. SECRET_KEY = "bambuddy-secret-key-change-in-production" # TODO: Move to settings/env
  24. ALGORITHM = "HS256"
  25. ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
  26. # HTTP Bearer token
  27. security = HTTPBearer(auto_error=False)
  28. def verify_password(plain_password: str, hashed_password: str) -> bool:
  29. """Verify a password against a hash.
  30. Uses pbkdf2_sha256 which handles long passwords automatically.
  31. """
  32. return pwd_context.verify(plain_password, hashed_password)
  33. def get_password_hash(password: str) -> str:
  34. """Hash a password.
  35. Uses pbkdf2_sha256 which is secure and has no password length limit.
  36. """
  37. return pwd_context.hash(password)
  38. def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
  39. """Create a JWT access token."""
  40. to_encode = data.copy()
  41. if expires_delta:
  42. expire = datetime.utcnow() + expires_delta
  43. else:
  44. expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  45. to_encode.update({"exp": expire})
  46. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  47. return encoded_jwt
  48. async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
  49. """Get a user by username with groups loaded for permission checks."""
  50. result = await db.execute(select(User).where(User.username == username).options(selectinload(User.groups)))
  51. return result.scalar_one_or_none()
  52. async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
  53. """Authenticate a user by username and password."""
  54. user = await get_user_by_username(db, username)
  55. if not user:
  56. return None
  57. if not verify_password(password, user.password_hash):
  58. return None
  59. if not user.is_active:
  60. return None
  61. return user
  62. async def is_auth_enabled(db: AsyncSession) -> bool:
  63. """Check if authentication is enabled."""
  64. try:
  65. result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
  66. setting = result.scalar_one_or_none()
  67. if setting is None:
  68. return False
  69. return setting.value.lower() == "true"
  70. except Exception:
  71. # If settings table doesn't exist or query fails, assume auth is disabled
  72. return False
  73. async def get_current_user_optional(
  74. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  75. ) -> User | None:
  76. """Get the current authenticated user from JWT token, or None if not authenticated."""
  77. if credentials is None:
  78. return None
  79. try:
  80. token = credentials.credentials
  81. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  82. username: str = payload.get("sub")
  83. if username is None:
  84. return None
  85. except JWTError:
  86. return None
  87. async with async_session() as db:
  88. user = await get_user_by_username(db, username)
  89. if user is None or not user.is_active:
  90. return None
  91. return user
  92. async def get_current_user(
  93. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  94. ) -> User:
  95. """Get the current authenticated user from JWT token."""
  96. credentials_exception = HTTPException(
  97. status_code=status.HTTP_401_UNAUTHORIZED,
  98. detail="Could not validate credentials",
  99. headers={"WWW-Authenticate": "Bearer"},
  100. )
  101. if credentials is None:
  102. raise credentials_exception
  103. try:
  104. token = credentials.credentials
  105. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  106. username: str = payload.get("sub")
  107. if username is None:
  108. raise credentials_exception
  109. except JWTError:
  110. raise credentials_exception
  111. async with async_session() as db:
  112. user = await get_user_by_username(db, username)
  113. if user is None:
  114. raise credentials_exception
  115. if not user.is_active:
  116. raise HTTPException(
  117. status_code=status.HTTP_403_FORBIDDEN,
  118. detail="User account is disabled",
  119. )
  120. return user
  121. async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
  122. """Get the current active user (alias for clarity)."""
  123. return current_user
  124. async def require_auth_if_enabled(
  125. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  126. ) -> User | None:
  127. """Require authentication if auth is enabled, otherwise return None."""
  128. async with async_session() as db:
  129. auth_enabled = await is_auth_enabled(db)
  130. if not auth_enabled:
  131. return None
  132. if credentials is None:
  133. raise HTTPException(
  134. status_code=status.HTTP_401_UNAUTHORIZED,
  135. detail="Authentication required",
  136. headers={"WWW-Authenticate": "Bearer"},
  137. )
  138. try:
  139. token = credentials.credentials
  140. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  141. username: str = payload.get("sub")
  142. if username is None:
  143. raise HTTPException(
  144. status_code=status.HTTP_401_UNAUTHORIZED,
  145. detail="Could not validate credentials",
  146. headers={"WWW-Authenticate": "Bearer"},
  147. )
  148. except JWTError:
  149. raise HTTPException(
  150. status_code=status.HTTP_401_UNAUTHORIZED,
  151. detail="Could not validate credentials",
  152. headers={"WWW-Authenticate": "Bearer"},
  153. )
  154. user = await get_user_by_username(db, username)
  155. if user is None or not user.is_active:
  156. raise HTTPException(
  157. status_code=status.HTTP_401_UNAUTHORIZED,
  158. detail="Could not validate credentials",
  159. headers={"WWW-Authenticate": "Bearer"},
  160. )
  161. return user
  162. def require_role(required_role: str):
  163. """Dependency factory for role-based access control."""
  164. async def role_checker(current_user: Annotated[User, Depends(get_current_user)]) -> User:
  165. if current_user.role != required_role:
  166. raise HTTPException(
  167. status_code=status.HTTP_403_FORBIDDEN,
  168. detail=f"Requires {required_role} role",
  169. )
  170. return current_user
  171. return role_checker
  172. def require_admin_if_auth_enabled():
  173. """Dependency factory that requires admin role if auth is enabled."""
  174. async def admin_checker(
  175. current_user: Annotated[User | None, Depends(require_auth_if_enabled)] = None,
  176. ) -> User | None:
  177. if current_user is None:
  178. return None # Auth not enabled, allow access
  179. if current_user.role != "admin":
  180. raise HTTPException(
  181. status_code=status.HTTP_403_FORBIDDEN,
  182. detail="Requires admin role",
  183. )
  184. return current_user
  185. return admin_checker
  186. def generate_api_key() -> tuple[str, str, str]:
  187. """Generate a new API key.
  188. Returns:
  189. tuple: (full_key, key_hash, key_prefix)
  190. - full_key: The complete API key (only shown once on creation)
  191. - key_hash: Hashed version for storage and verification
  192. - key_prefix: First 8 characters for display purposes
  193. """
  194. # Generate a secure random API key (32 bytes = 64 hex characters)
  195. full_key = f"bb_{secrets.token_urlsafe(32)}"
  196. key_hash = get_password_hash(full_key)
  197. key_prefix = full_key[:8] + "..." if len(full_key) > 8 else full_key
  198. return full_key, key_hash, key_prefix
  199. async def get_api_key(
  200. authorization: Annotated[str | None, Header(alias="Authorization")] = None,
  201. x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
  202. db: AsyncSession = Depends(get_db),
  203. ) -> APIKey:
  204. """Get and validate API key from request headers.
  205. Checks both 'Authorization: Bearer <key>' and 'X-API-Key: <key>' headers.
  206. """
  207. api_key_value = None
  208. if x_api_key:
  209. api_key_value = x_api_key
  210. elif authorization and authorization.startswith("Bearer "):
  211. api_key_value = authorization.replace("Bearer ", "")
  212. if not api_key_value:
  213. raise HTTPException(
  214. status_code=status.HTTP_401_UNAUTHORIZED,
  215. detail="API key required. Provide 'X-API-Key' header or 'Authorization: Bearer <key>'",
  216. )
  217. # Get all API keys and check them
  218. result = await db.execute(select(APIKey).where(APIKey.enabled.is_(True)))
  219. api_keys = result.scalars().all()
  220. for api_key in api_keys:
  221. # Check if key matches (verify against hash)
  222. if verify_password(api_key_value, api_key.key_hash):
  223. # Check expiration
  224. if api_key.expires_at and api_key.expires_at < datetime.now():
  225. raise HTTPException(
  226. status_code=status.HTTP_401_UNAUTHORIZED,
  227. detail="API key has expired",
  228. )
  229. # Update last_used timestamp
  230. api_key.last_used = datetime.now()
  231. await db.commit()
  232. return api_key
  233. raise HTTPException(
  234. status_code=status.HTTP_401_UNAUTHORIZED,
  235. detail="Invalid API key",
  236. )
  237. def check_permission(api_key: APIKey, permission: str) -> None:
  238. """Check if API key has the required permission.
  239. Args:
  240. api_key: The API key object
  241. permission: One of 'queue', 'control_printer', 'read_status'
  242. Raises:
  243. HTTPException: If permission is not granted
  244. """
  245. permission_map = {
  246. "queue": "can_queue",
  247. "control_printer": "can_control_printer",
  248. "read_status": "can_read_status",
  249. }
  250. if permission not in permission_map:
  251. raise HTTPException(
  252. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  253. detail=f"Unknown permission: {permission}",
  254. )
  255. attr_name = permission_map[permission]
  256. if not getattr(api_key, attr_name, False):
  257. raise HTTPException(
  258. status_code=status.HTTP_403_FORBIDDEN,
  259. detail=f"API key does not have '{permission}' permission",
  260. )
  261. def check_printer_access(api_key: APIKey, printer_id: int) -> None:
  262. """Check if API key has access to the specified printer.
  263. Args:
  264. api_key: The API key object
  265. printer_id: The printer ID to check access for
  266. Raises:
  267. HTTPException: If access is denied
  268. """
  269. # If printer_ids is None or empty, access to all printers
  270. if api_key.printer_ids is None or len(api_key.printer_ids) == 0:
  271. return
  272. # Check if printer_id is in allowed list
  273. if printer_id not in api_key.printer_ids:
  274. raise HTTPException(
  275. status_code=status.HTTP_403_FORBIDDEN,
  276. detail=f"API key does not have access to printer {printer_id}",
  277. )
  278. # Convenience dependencies - these are functions that return Depends objects
  279. def RequireAdmin():
  280. """Dependency that requires admin role."""
  281. return Depends(require_role("admin"))
  282. def RequireAdminIfAuthEnabled():
  283. """Dependency that requires admin role if auth is enabled."""
  284. return Depends(require_admin_if_auth_enabled())
  285. def require_permission(*permissions: str | Permission):
  286. """Dependency factory that requires user to have ALL specified permissions.
  287. Args:
  288. *permissions: Permission strings or Permission enum values to require
  289. Returns:
  290. A dependency function that validates permissions
  291. """
  292. # Convert Permission enums to strings
  293. perm_strings = [p.value if isinstance(p, Permission) else p for p in permissions]
  294. async def permission_checker(
  295. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  296. ) -> User:
  297. credentials_exception = HTTPException(
  298. status_code=status.HTTP_401_UNAUTHORIZED,
  299. detail="Could not validate credentials",
  300. headers={"WWW-Authenticate": "Bearer"},
  301. )
  302. if credentials is None:
  303. raise credentials_exception
  304. try:
  305. token = credentials.credentials
  306. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  307. username: str = payload.get("sub")
  308. if username is None:
  309. raise credentials_exception
  310. except JWTError:
  311. raise credentials_exception
  312. async with async_session() as db:
  313. user = await get_user_by_username(db, username)
  314. if user is None or not user.is_active:
  315. raise credentials_exception
  316. if not user.has_all_permissions(*perm_strings):
  317. raise HTTPException(
  318. status_code=status.HTTP_403_FORBIDDEN,
  319. detail=f"Missing required permissions: {', '.join(perm_strings)}",
  320. )
  321. return user
  322. return permission_checker
  323. def require_permission_if_auth_enabled(*permissions: str | Permission):
  324. """Dependency factory that checks permissions only if auth is enabled.
  325. This provides backward compatibility - when auth is disabled, all access is allowed.
  326. Args:
  327. *permissions: Permission strings or Permission enum values to require
  328. Returns:
  329. A dependency function that validates permissions if auth is enabled
  330. """
  331. # Convert Permission enums to strings
  332. perm_strings = [p.value if isinstance(p, Permission) else p for p in permissions]
  333. async def permission_checker(
  334. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  335. ) -> User | None:
  336. async with async_session() as db:
  337. auth_enabled = await is_auth_enabled(db)
  338. if not auth_enabled:
  339. return None # Auth disabled, allow access
  340. if credentials is None:
  341. raise HTTPException(
  342. status_code=status.HTTP_401_UNAUTHORIZED,
  343. detail="Authentication required",
  344. headers={"WWW-Authenticate": "Bearer"},
  345. )
  346. try:
  347. token = credentials.credentials
  348. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  349. username: str = payload.get("sub")
  350. if username is None:
  351. raise HTTPException(
  352. status_code=status.HTTP_401_UNAUTHORIZED,
  353. detail="Could not validate credentials",
  354. headers={"WWW-Authenticate": "Bearer"},
  355. )
  356. except JWTError:
  357. raise HTTPException(
  358. status_code=status.HTTP_401_UNAUTHORIZED,
  359. detail="Could not validate credentials",
  360. headers={"WWW-Authenticate": "Bearer"},
  361. )
  362. user = await get_user_by_username(db, username)
  363. if user is None or not user.is_active:
  364. raise HTTPException(
  365. status_code=status.HTTP_401_UNAUTHORIZED,
  366. detail="Could not validate credentials",
  367. headers={"WWW-Authenticate": "Bearer"},
  368. )
  369. if not user.has_all_permissions(*perm_strings):
  370. raise HTTPException(
  371. status_code=status.HTTP_403_FORBIDDEN,
  372. detail=f"Missing required permissions: {', '.join(perm_strings)}",
  373. )
  374. return user
  375. return permission_checker
  376. def RequirePermission(*permissions: str | Permission):
  377. """Convenience dependency that requires ALL specified permissions."""
  378. return Depends(require_permission(*permissions))
  379. def RequirePermissionIfAuthEnabled(*permissions: str | Permission):
  380. """Convenience dependency that requires permissions if auth is enabled."""
  381. return Depends(require_permission_if_auth_enabled(*permissions))
  382. def require_ownership_permission(
  383. all_permission: str | Permission,
  384. own_permission: str | Permission,
  385. ):
  386. """Dependency factory for ownership-based permission checks.
  387. - User with `all_permission` can modify any item
  388. - User with `own_permission` can only modify items where created_by_id == user.id
  389. - Ownerless items (created_by_id = null) require `all_permission`
  390. Returns:
  391. A dependency function that returns (user, can_modify_all).
  392. - can_modify_all=True: user can modify any item
  393. - can_modify_all=False: user can only modify their own items
  394. """
  395. all_perm = all_permission.value if isinstance(all_permission, Permission) else all_permission
  396. own_perm = own_permission.value if isinstance(own_permission, Permission) else own_permission
  397. async def checker(
  398. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  399. ) -> tuple[User | None, bool]:
  400. """Returns (user, can_modify_all).
  401. - can_modify_all=True: user can modify any item
  402. - can_modify_all=False: user can only modify their own items
  403. """
  404. async with async_session() as db:
  405. auth_enabled = await is_auth_enabled(db)
  406. if not auth_enabled:
  407. return None, True # Auth disabled, allow all
  408. if credentials is None:
  409. raise HTTPException(
  410. status_code=status.HTTP_401_UNAUTHORIZED,
  411. detail="Authentication required",
  412. headers={"WWW-Authenticate": "Bearer"},
  413. )
  414. try:
  415. token = credentials.credentials
  416. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  417. username: str = payload.get("sub")
  418. if username is None:
  419. raise HTTPException(
  420. status_code=status.HTTP_401_UNAUTHORIZED,
  421. detail="Could not validate credentials",
  422. headers={"WWW-Authenticate": "Bearer"},
  423. )
  424. except JWTError:
  425. raise HTTPException(
  426. status_code=status.HTTP_401_UNAUTHORIZED,
  427. detail="Could not validate credentials",
  428. headers={"WWW-Authenticate": "Bearer"},
  429. )
  430. user = await get_user_by_username(db, username)
  431. if user is None or not user.is_active:
  432. raise HTTPException(
  433. status_code=status.HTTP_401_UNAUTHORIZED,
  434. detail="Could not validate credentials",
  435. headers={"WWW-Authenticate": "Bearer"},
  436. )
  437. if user.has_permission(all_perm):
  438. return user, True
  439. if user.has_permission(own_perm):
  440. return user, False
  441. raise HTTPException(
  442. status_code=status.HTTP_403_FORBIDDEN,
  443. detail=f"Missing permission: {own_perm} or {all_perm}",
  444. )
  445. return checker