|
|
@@ -1,350 +1,105 @@
|
|
|
+import hashlib
|
|
|
import secrets
|
|
|
-from datetime import datetime, timedelta
|
|
|
-from typing import TYPE_CHECKING, Annotated
|
|
|
+from datetime import datetime
|
|
|
|
|
|
-from fastapi import Depends, Header, HTTPException, status
|
|
|
-from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
-from jose import JWTError, jwt
|
|
|
-from passlib.context import CryptContext
|
|
|
+from fastapi import Depends, Header, HTTPException
|
|
|
from sqlalchemy import select
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
-from backend.app.core.database import async_session, get_db
|
|
|
-from backend.app.models.settings import Settings
|
|
|
-from backend.app.models.user import User
|
|
|
+from backend.app.core.database import get_db
|
|
|
+from backend.app.models.api_key import APIKey
|
|
|
|
|
|
-if TYPE_CHECKING:
|
|
|
- from backend.app.models.api_key import APIKey
|
|
|
|
|
|
-# Password hashing
|
|
|
-# Use pbkdf2_sha256 instead of bcrypt to avoid 72-byte limit and passlib initialization issues
|
|
|
-# pbkdf2_sha256 is a secure password hashing algorithm without bcrypt's limitations
|
|
|
-pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
|
|
-
|
|
|
-# JWT settings
|
|
|
-SECRET_KEY = "bambuddy-secret-key-change-in-production" # TODO: Move to settings/env
|
|
|
-ALGORITHM = "HS256"
|
|
|
-ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
|
|
-
|
|
|
-# HTTP Bearer token
|
|
|
-security = HTTPBearer(auto_error=False)
|
|
|
-
|
|
|
-
|
|
|
-def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
|
- """Verify a password against a hash.
|
|
|
+def generate_api_key() -> tuple[str, str, str]:
|
|
|
+ """Generate a new API key.
|
|
|
|
|
|
- Uses pbkdf2_sha256 which handles long passwords automatically.
|
|
|
+ Returns:
|
|
|
+ Tuple of (full_key, key_hash, key_prefix)
|
|
|
"""
|
|
|
- return pwd_context.verify(plain_password, hashed_password)
|
|
|
+ # Generate a random 32-byte key and encode as hex (64 chars)
|
|
|
+ full_key = f"bb_{secrets.token_hex(32)}"
|
|
|
+ key_hash = hashlib.sha256(full_key.encode()).hexdigest()
|
|
|
+ key_prefix = full_key[:11] # "bb_" + first 8 chars of token
|
|
|
+ return full_key, key_hash, key_prefix
|
|
|
|
|
|
|
|
|
-def get_password_hash(password: str) -> str:
|
|
|
- """Hash a password.
|
|
|
+def hash_api_key(key: str) -> str:
|
|
|
+ """Hash an API key for comparison."""
|
|
|
+ return hashlib.sha256(key.encode()).hexdigest()
|
|
|
|
|
|
- Uses pbkdf2_sha256 which is secure and has no password length limit.
|
|
|
- """
|
|
|
- return pwd_context.hash(password)
|
|
|
|
|
|
+async def get_api_key(
|
|
|
+ x_api_key: str = Header(..., alias="X-API-Key"),
|
|
|
+ db: AsyncSession = Depends(get_db),
|
|
|
+) -> APIKey:
|
|
|
+ """Verify API key and return the key record.
|
|
|
|
|
|
-def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
|
|
|
- """Create a JWT access token."""
|
|
|
- to_encode = data.copy()
|
|
|
- if expires_delta:
|
|
|
- expire = datetime.utcnow() + expires_delta
|
|
|
- else:
|
|
|
- expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
- to_encode.update({"exp": expire})
|
|
|
- encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
- return encoded_jwt
|
|
|
+ Raises HTTPException if key is invalid, disabled, or expired.
|
|
|
+ """
|
|
|
+ key_hash = hash_api_key(x_api_key)
|
|
|
|
|
|
+ result = await db.execute(select(APIKey).where(APIKey.key_hash == key_hash))
|
|
|
+ api_key = result.scalar_one_or_none()
|
|
|
|
|
|
-async def get_user_by_username(db: AsyncSession, username: str) -> User | None:
|
|
|
- """Get a user by username."""
|
|
|
- result = await db.execute(select(User).where(User.username == username))
|
|
|
- return result.scalar_one_or_none()
|
|
|
+ if not api_key:
|
|
|
+ raise HTTPException(status_code=401, detail="Invalid API key")
|
|
|
|
|
|
+ if not api_key.enabled:
|
|
|
+ raise HTTPException(status_code=403, detail="API key is disabled")
|
|
|
|
|
|
-async def authenticate_user(db: AsyncSession, username: str, password: str) -> User | None:
|
|
|
- """Authenticate a user by username and password."""
|
|
|
- user = await get_user_by_username(db, username)
|
|
|
- if not user:
|
|
|
- return None
|
|
|
- if not verify_password(password, user.password_hash):
|
|
|
- return None
|
|
|
- if not user.is_active:
|
|
|
- return None
|
|
|
- return user
|
|
|
+ if api_key.expires_at and api_key.expires_at < datetime.utcnow():
|
|
|
+ raise HTTPException(status_code=403, detail="API key has expired")
|
|
|
|
|
|
+ # Update last_used timestamp
|
|
|
+ api_key.last_used = datetime.utcnow()
|
|
|
|
|
|
-async def is_auth_enabled(db: AsyncSession) -> bool:
|
|
|
- """Check if authentication is enabled."""
|
|
|
- try:
|
|
|
- result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
|
|
|
- setting = result.scalar_one_or_none()
|
|
|
- return setting and setting.value.lower() == "true"
|
|
|
- except Exception:
|
|
|
- # If settings table doesn't exist or query fails, assume auth is disabled
|
|
|
- return False
|
|
|
+ return api_key
|
|
|
|
|
|
|
|
|
-async def get_current_user_optional(
|
|
|
- credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
|
|
|
-) -> User | None:
|
|
|
- """Get the current authenticated user from JWT token, or None if not authenticated."""
|
|
|
- if credentials is None:
|
|
|
+async def get_optional_api_key(
|
|
|
+ x_api_key: str | None = Header(None, alias="X-API-Key"),
|
|
|
+ db: AsyncSession = Depends(get_db),
|
|
|
+) -> APIKey | None:
|
|
|
+ """Get API key if provided, return None otherwise."""
|
|
|
+ if not x_api_key:
|
|
|
return None
|
|
|
|
|
|
try:
|
|
|
- token = credentials.credentials
|
|
|
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
- username: str = payload.get("sub")
|
|
|
- if username is None:
|
|
|
- return None
|
|
|
- except JWTError:
|
|
|
+ return await get_api_key(x_api_key, db)
|
|
|
+ except HTTPException:
|
|
|
return None
|
|
|
|
|
|
- async with async_session() as db:
|
|
|
- user = await get_user_by_username(db, username)
|
|
|
- if user is None or not user.is_active:
|
|
|
- return None
|
|
|
- return user
|
|
|
-
|
|
|
-
|
|
|
-async def get_current_user(credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]) -> User:
|
|
|
- """Get the current authenticated user from JWT token."""
|
|
|
- credentials_exception = HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Could not validate credentials",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
- try:
|
|
|
- token = credentials.credentials
|
|
|
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
- username: str = payload.get("sub")
|
|
|
- if username is None:
|
|
|
- raise credentials_exception
|
|
|
- except JWTError:
|
|
|
- raise credentials_exception
|
|
|
-
|
|
|
- async with async_session() as db:
|
|
|
- user = await get_user_by_username(db, username)
|
|
|
- if user is None:
|
|
|
- raise credentials_exception
|
|
|
- if not user.is_active:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- detail="User account is disabled",
|
|
|
- )
|
|
|
- return user
|
|
|
-
|
|
|
-
|
|
|
-async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]) -> User:
|
|
|
- """Get the current active user (alias for clarity)."""
|
|
|
- return current_user
|
|
|
-
|
|
|
-
|
|
|
-async def require_auth_if_enabled(
|
|
|
- credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)] = None,
|
|
|
-) -> User | None:
|
|
|
- """Require authentication if auth is enabled, otherwise return None."""
|
|
|
- async with async_session() as db:
|
|
|
- auth_enabled = await is_auth_enabled(db)
|
|
|
- if not auth_enabled:
|
|
|
- return None
|
|
|
-
|
|
|
- if credentials is None:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Authentication required",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
-
|
|
|
- try:
|
|
|
- token = credentials.credentials
|
|
|
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
|
- username: str = payload.get("sub")
|
|
|
- if username is None:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Could not validate credentials",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
- except JWTError:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Could not validate credentials",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
-
|
|
|
- user = await get_user_by_username(db, username)
|
|
|
- if user is None or not user.is_active:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Could not validate credentials",
|
|
|
- headers={"WWW-Authenticate": "Bearer"},
|
|
|
- )
|
|
|
- return user
|
|
|
-
|
|
|
-
|
|
|
-def require_role(required_role: str):
|
|
|
- """Dependency factory for role-based access control."""
|
|
|
-
|
|
|
- async def role_checker(current_user: Annotated[User, Depends(get_current_user)]) -> User:
|
|
|
- if current_user.role != required_role:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- detail=f"Requires {required_role} role",
|
|
|
- )
|
|
|
- return current_user
|
|
|
-
|
|
|
- return role_checker
|
|
|
|
|
|
-
|
|
|
-def require_admin_if_auth_enabled():
|
|
|
- """Dependency factory that requires admin role if auth is enabled."""
|
|
|
-
|
|
|
- async def admin_checker(
|
|
|
- current_user: Annotated[User | None, Depends(require_auth_if_enabled)] = None,
|
|
|
- ) -> User | None:
|
|
|
- if current_user is None:
|
|
|
- return None # Auth not enabled, allow access
|
|
|
- if current_user.role != "admin":
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- detail="Requires admin role",
|
|
|
- )
|
|
|
- return current_user
|
|
|
-
|
|
|
- return admin_checker
|
|
|
-
|
|
|
-
|
|
|
-def generate_api_key() -> tuple[str, str, str]:
|
|
|
- """Generate a new API key.
|
|
|
-
|
|
|
- Returns:
|
|
|
- tuple: (full_key, key_hash, key_prefix)
|
|
|
- - full_key: The complete API key (only shown once on creation)
|
|
|
- - key_hash: Hashed version for storage and verification
|
|
|
- - key_prefix: First 8 characters for display purposes
|
|
|
- """
|
|
|
- # Generate a secure random API key (32 bytes = 64 hex characters)
|
|
|
- full_key = f"bb_{secrets.token_urlsafe(32)}"
|
|
|
- key_hash = get_password_hash(full_key)
|
|
|
- key_prefix = full_key[:8] + "..." if len(full_key) > 8 else full_key
|
|
|
- return full_key, key_hash, key_prefix
|
|
|
-
|
|
|
-
|
|
|
-async def get_api_key(
|
|
|
- authorization: Annotated[str | None, Header(alias="Authorization")] = None,
|
|
|
- x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
|
|
|
- db: AsyncSession = Depends(get_db),
|
|
|
-) -> "APIKey":
|
|
|
- """Get and validate API key from request headers.
|
|
|
-
|
|
|
- Checks both 'Authorization: Bearer <key>' and 'X-API-Key: <key>' headers.
|
|
|
- """
|
|
|
- from fastapi import HTTPException, status
|
|
|
-
|
|
|
- from backend.app.models.api_key import APIKey
|
|
|
-
|
|
|
- api_key_value = None
|
|
|
- if x_api_key:
|
|
|
- api_key_value = x_api_key
|
|
|
- elif authorization and authorization.startswith("Bearer "):
|
|
|
- api_key_value = authorization.replace("Bearer ", "")
|
|
|
-
|
|
|
- if not api_key_value:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="API key required. Provide 'X-API-Key' header or 'Authorization: Bearer <key>'",
|
|
|
- )
|
|
|
-
|
|
|
- # Get all API keys and check them
|
|
|
- result = await db.execute(select(APIKey).where(APIKey.enabled.is_(True)))
|
|
|
- api_keys = result.scalars().all()
|
|
|
-
|
|
|
- for api_key in api_keys:
|
|
|
- # Check if key matches (verify against hash)
|
|
|
- if verify_password(api_key_value, api_key.key_hash):
|
|
|
- # Check expiration
|
|
|
- if api_key.expires_at and api_key.expires_at < datetime.now():
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="API key has expired",
|
|
|
- )
|
|
|
- # Update last_used timestamp
|
|
|
- api_key.last_used = datetime.now()
|
|
|
- await db.commit()
|
|
|
- return api_key
|
|
|
-
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
- detail="Invalid API key",
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-def check_permission(api_key: "APIKey", permission: str) -> None:
|
|
|
- """Check if API key has the required permission.
|
|
|
+def check_permission(api_key: APIKey, permission: str) -> None:
|
|
|
+ """Check if API key has a specific permission.
|
|
|
|
|
|
Args:
|
|
|
- api_key: The API key object
|
|
|
+ api_key: The API key record
|
|
|
permission: One of 'queue', 'control_printer', 'read_status'
|
|
|
|
|
|
- Raises:
|
|
|
- HTTPException: If permission is not granted
|
|
|
+ Raises HTTPException if permission is denied.
|
|
|
"""
|
|
|
- from fastapi import HTTPException, status
|
|
|
-
|
|
|
permission_map = {
|
|
|
- "queue": "can_queue",
|
|
|
- "control_printer": "can_control_printer",
|
|
|
- "read_status": "can_read_status",
|
|
|
+ "queue": api_key.can_queue,
|
|
|
+ "control_printer": api_key.can_control_printer,
|
|
|
+ "read_status": api_key.can_read_status,
|
|
|
}
|
|
|
|
|
|
if permission not in permission_map:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
|
- detail=f"Unknown permission: {permission}",
|
|
|
- )
|
|
|
+ raise HTTPException(status_code=500, detail=f"Unknown permission: {permission}")
|
|
|
|
|
|
- attr_name = permission_map[permission]
|
|
|
- if not getattr(api_key, attr_name, False):
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- detail=f"API key does not have '{permission}' permission",
|
|
|
- )
|
|
|
+ if not permission_map[permission]:
|
|
|
+ raise HTTPException(status_code=403, detail=f"API key does not have '{permission}' permission")
|
|
|
|
|
|
|
|
|
-def check_printer_access(api_key: "APIKey", printer_id: int) -> None:
|
|
|
- """Check if API key has access to the specified printer.
|
|
|
+def check_printer_access(api_key: APIKey, printer_id: int) -> None:
|
|
|
+ """Check if API key has access to a specific printer.
|
|
|
|
|
|
Args:
|
|
|
- api_key: The API key object
|
|
|
- printer_id: The printer ID to check access for
|
|
|
+ api_key: The API key record
|
|
|
+ printer_id: The printer ID to check
|
|
|
|
|
|
- Raises:
|
|
|
- HTTPException: If access is denied
|
|
|
+ Raises HTTPException if access is denied.
|
|
|
"""
|
|
|
- from fastapi import HTTPException, status
|
|
|
-
|
|
|
- # If printer_ids is None or empty, access to all printers
|
|
|
- if api_key.printer_ids is None or len(api_key.printer_ids) == 0:
|
|
|
- return
|
|
|
-
|
|
|
- # Check if printer_id is in allowed list
|
|
|
- if printer_id not in api_key.printer_ids:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
- detail=f"API key does not have access to printer {printer_id}",
|
|
|
- )
|
|
|
-
|
|
|
-
|
|
|
-# Convenience dependencies - these are functions that return Depends objects
|
|
|
-def RequireAdmin():
|
|
|
- """Dependency that requires admin role."""
|
|
|
- return Depends(require_role("admin"))
|
|
|
-
|
|
|
-
|
|
|
-def RequireAdminIfAuthEnabled():
|
|
|
- """Dependency that requires admin role if auth is enabled."""
|
|
|
- return Depends(require_admin_if_auth_enabled())
|
|
|
+ if api_key.printer_ids is not None and printer_id not in api_key.printer_ids:
|
|
|
+ raise HTTPException(status_code=403, detail=f"API key does not have access to printer {printer_id}")
|