# auth.py
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, Header, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.app.core.config import settings
from backend.app.core.database import async_session, get_db
from backend.app.models.settings import Settings
from backend.app.models.user import User

if TYPE_CHECKING:
    from backend.app.models.api_key import APIKey
  16. # Password hashing
  17. # Use pbkdf2_sha256 instead of bcrypt to avoid 72-byte limit and passlib initialization issues
  18. # pbkdf2_sha256 is a secure password hashing algorithm without bcrypt's limitations
  19. pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
  20. # JWT settings
  21. SECRET_KEY = "bambuddy-secret-key-change-in-production" # TODO: Move to settings/env
  22. ALGORITHM = "HS256"
  23. ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
  24. # HTTP Bearer token
  25. security = HTTPBearer(auto_error=False)
  26. def verify_password(plain_password: str, hashed_password: str) -> bool:
  27. """Verify a password against a hash.
  28. Uses pbkdf2_sha256 which handles long passwords automatically.
  29. """
  30. return pwd_context.verify(plain_password, hashed_password)
  31. def get_password_hash(password: str) -> str:
  32. """Hash a password.
  33. Uses pbkdf2_sha256 which is secure and has no password length limit.
  34. """
  35. return pwd_context.hash(password)
  36. def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
  37. """Create a JWT access token."""
  38. to_encode = data.copy()
  39. if expires_delta:
  40. expire = datetime.utcnow() + expires_delta
  41. else:
  42. expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  43. to_encode.update({"exp": expire})
  44. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  45. return encoded_jwt
  46. async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
  47. """Get a user by username."""
  48. result = await db.execute(select(User).where(User.username == username))
  49. return result.scalar_one_or_none()
  50. async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
  51. """Authenticate a user by username and password."""
  52. user = await get_user_by_username(db, username)
  53. if not user:
  54. return None
  55. if not verify_password(password, user.password_hash):
  56. return None
  57. if not user.is_active:
  58. return None
  59. return user
  60. async def is_auth_enabled(db: AsyncSession) -> bool:
  61. """Check if authentication is enabled."""
  62. result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
  63. setting = result.scalar_one_or_none()
  64. return setting and setting.value.lower() == "true"
  65. async def get_current_user_optional(
  66. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  67. ) -> User | None:
  68. """Get the current authenticated user from JWT token, or None if not authenticated."""
  69. if credentials is None:
  70. return None
  71. try:
  72. token = credentials.credentials
  73. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  74. username: str = payload.get("sub")
  75. if username is None:
  76. return None
  77. except JWTError:
  78. return None
  79. async with async_session() as db:
  80. user = await get_user_by_username(db, username)
  81. if user is None or not user.is_active:
  82. return None
  83. return user
  84. async def get_current_user(
  85. credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]
  86. ) -> User:
  87. """Get the current authenticated user from JWT token."""
  88. credentials_exception = HTTPException(
  89. status_code=status.HTTP_401_UNAUTHORIZED,
  90. detail="Could not validate credentials",
  91. headers={"WWW-Authenticate": "Bearer"},
  92. )
  93. try:
  94. token = credentials.credentials
  95. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  96. username: str = payload.get("sub")
  97. if username is None:
  98. raise credentials_exception
  99. except JWTError:
  100. raise credentials_exception
  101. async with async_session() as db:
  102. user = await get_user_by_username(db, username)
  103. if user is None:
  104. raise credentials_exception
  105. if not user.is_active:
  106. raise HTTPException(
  107. status_code=status.HTTP_403_FORBIDDEN,
  108. detail="User account is disabled",
  109. )
  110. return user
  111. async def get_current_active_user(
  112. current_user: Annotated[User, Depends(get_current_user)]
  113. ) -> User:
  114. """Get the current active user (alias for clarity)."""
  115. return current_user
  116. async def require_auth_if_enabled(
  117. credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
  118. ) -> User | None:
  119. """Require authentication if auth is enabled, otherwise return None."""
  120. async with async_session() as db:
  121. auth_enabled = await is_auth_enabled(db)
  122. if not auth_enabled:
  123. return None
  124. if credentials is None:
  125. raise HTTPException(
  126. status_code=status.HTTP_401_UNAUTHORIZED,
  127. detail="Authentication required",
  128. headers={"WWW-Authenticate": "Bearer"},
  129. )
  130. try:
  131. token = credentials.credentials
  132. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  133. username: str = payload.get("sub")
  134. if username is None:
  135. raise HTTPException(
  136. status_code=status.HTTP_401_UNAUTHORIZED,
  137. detail="Could not validate credentials",
  138. headers={"WWW-Authenticate": "Bearer"},
  139. )
  140. except JWTError:
  141. raise HTTPException(
  142. status_code=status.HTTP_401_UNAUTHORIZED,
  143. detail="Could not validate credentials",
  144. headers={"WWW-Authenticate": "Bearer"},
  145. )
  146. user = await get_user_by_username(db, username)
  147. if user is None or not user.is_active:
  148. raise HTTPException(
  149. status_code=status.HTTP_401_UNAUTHORIZED,
  150. detail="Could not validate credentials",
  151. headers={"WWW-Authenticate": "Bearer"},
  152. )
  153. return user
  154. def require_role(required_role: str):
  155. """Dependency factory for role-based access control."""
  156. async def role_checker(
  157. current_user: Annotated[User, Depends(get_current_user)]
  158. ) -> User:
  159. if current_user.role != required_role:
  160. raise HTTPException(
  161. status_code=status.HTTP_403_FORBIDDEN,
  162. detail=f"Requires {required_role} role",
  163. )
  164. return current_user
  165. return role_checker
  166. def require_admin_if_auth_enabled():
  167. """Dependency factory that requires admin role if auth is enabled."""
  168. async def admin_checker(
  169. current_user: Annotated[User | None, Depends(require_auth_if_enabled)] = None,
  170. ) -> User | None:
  171. if current_user is None:
  172. return None # Auth not enabled, allow access
  173. if current_user.role != "admin":
  174. raise HTTPException(
  175. status_code=status.HTTP_403_FORBIDDEN,
  176. detail="Requires admin role",
  177. )
  178. return current_user
  179. return admin_checker
  180. def generate_api_key() -> tuple[str, str, str]:
  181. """Generate a new API key.
  182. Returns:
  183. tuple: (full_key, key_hash, key_prefix)
  184. - full_key: The complete API key (only shown once on creation)
  185. - key_hash: Hashed version for storage and verification
  186. - key_prefix: First 8 characters for display purposes
  187. """
  188. # Generate a secure random API key (32 bytes = 64 hex characters)
  189. full_key = f"bb_{secrets.token_urlsafe(32)}"
  190. key_hash = get_password_hash(full_key)
  191. key_prefix = full_key[:8] + "..." if len(full_key) > 8 else full_key
  192. return full_key, key_hash, key_prefix
  193. async def get_api_key(
  194. authorization: Annotated[str | None, Header(alias="Authorization")] = None,
  195. x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
  196. db: AsyncSession = Depends(get_db),
  197. ) -> "APIKey":
  198. """Get and validate API key from request headers.
  199. Checks both 'Authorization: Bearer <key>' and 'X-API-Key: <key>' headers.
  200. """
  201. from fastapi import HTTPException, status
  202. from backend.app.models.api_key import APIKey
  203. api_key_value = None
  204. if x_api_key:
  205. api_key_value = x_api_key
  206. elif authorization and authorization.startswith("Bearer "):
  207. api_key_value = authorization.replace("Bearer ", "")
  208. if not api_key_value:
  209. raise HTTPException(
  210. status_code=status.HTTP_401_UNAUTHORIZED,
  211. detail="API key required. Provide 'X-API-Key' header or 'Authorization: Bearer <key>'",
  212. )
  213. # Get all API keys and check them
  214. result = await db.execute(select(APIKey).where(APIKey.enabled.is_(True)))
  215. api_keys = result.scalars().all()
  216. for api_key in api_keys:
  217. # Check if key matches (verify against hash)
  218. if verify_password(api_key_value, api_key.key_hash):
  219. # Check expiration
  220. if api_key.expires_at and api_key.expires_at < datetime.now():
  221. raise HTTPException(
  222. status_code=status.HTTP_401_UNAUTHORIZED,
  223. detail="API key has expired",
  224. )
  225. # Update last_used timestamp
  226. api_key.last_used = datetime.now()
  227. await db.commit()
  228. return api_key
  229. raise HTTPException(
  230. status_code=status.HTTP_401_UNAUTHORIZED,
  231. detail="Invalid API key",
  232. )
  233. def check_permission(api_key: "APIKey", permission: str) -> None:
  234. """Check if API key has the required permission.
  235. Args:
  236. api_key: The API key object
  237. permission: One of 'queue', 'control_printer', 'read_status'
  238. Raises:
  239. HTTPException: If permission is not granted
  240. """
  241. from fastapi import HTTPException, status
  242. permission_map = {
  243. "queue": "can_queue",
  244. "control_printer": "can_control_printer",
  245. "read_status": "can_read_status",
  246. }
  247. if permission not in permission_map:
  248. raise HTTPException(
  249. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  250. detail=f"Unknown permission: {permission}",
  251. )
  252. attr_name = permission_map[permission]
  253. if not getattr(api_key, attr_name, False):
  254. raise HTTPException(
  255. status_code=status.HTTP_403_FORBIDDEN,
  256. detail=f"API key does not have '{permission}' permission",
  257. )
  258. def check_printer_access(api_key: "APIKey", printer_id: int) -> None:
  259. """Check if API key has access to the specified printer.
  260. Args:
  261. api_key: The API key object
  262. printer_id: The printer ID to check access for
  263. Raises:
  264. HTTPException: If access is denied
  265. """
  266. from fastapi import HTTPException, status
  267. # If printer_ids is None or empty, access to all printers
  268. if api_key.printer_ids is None or len(api_key.printer_ids) == 0:
  269. return
  270. # Check if printer_id is in allowed list
  271. if printer_id not in api_key.printer_ids:
  272. raise HTTPException(
  273. status_code=status.HTTP_403_FORBIDDEN,
  274. detail=f"API key does not have access to printer {printer_id}",
  275. )
  276. # Convenience dependencies - these are functions that return Depends objects
  277. def RequireAdmin():
  278. """Dependency that requires admin role."""
  279. return Depends(require_role("admin"))
  280. def RequireAdminIfAuthEnabled():
  281. """Dependency that requires admin role if auth is enabled."""
  282. return Depends(require_admin_if_auth_enabled())