# auth.py (9.9 KB)

import logging
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.app.core.auth import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    authenticate_user,
    create_access_token,
    get_current_active_user,
    get_password_hash,
    get_user_by_username,
)
from backend.app.core.database import get_db
from backend.app.models.settings import Settings
from backend.app.models.user import User
from backend.app.schemas.auth import (
    LoginRequest,
    LoginResponse,
    SetupRequest,
    SetupResponse,
    UserResponse,
)
  17. router = APIRouter(prefix="/auth", tags=["authentication"])
  18. async def is_auth_enabled(db: AsyncSession) -> bool:
  19. """Check if authentication is enabled."""
  20. result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
  21. setting = result.scalar_one_or_none()
  22. if setting is None:
  23. return False
  24. return setting.value.lower() == "true"
  25. async def set_auth_enabled(db: AsyncSession, enabled: bool) -> None:
  26. """Set authentication enabled status."""
  27. from sqlalchemy import func
  28. from sqlalchemy.dialects.sqlite import insert as sqlite_insert
  29. stmt = sqlite_insert(Settings).values(key="auth_enabled", value="true" if enabled else "false")
  30. stmt = stmt.on_conflict_do_update(
  31. index_elements=["key"], set_={"value": "true" if enabled else "false", "updated_at": func.now()}
  32. )
  33. await db.execute(stmt)
  34. # Note: Don't commit here - let get_db handle it or commit explicitly in the route
  35. async def is_setup_completed(db: AsyncSession) -> bool:
  36. """Check if setup has been completed."""
  37. result = await db.execute(select(Settings).where(Settings.key == "setup_completed"))
  38. setting = result.scalar_one_or_none()
  39. return setting and setting.value.lower() == "true"
  40. async def set_setup_completed(db: AsyncSession, completed: bool) -> None:
  41. """Set setup completed status."""
  42. from sqlalchemy import func
  43. from sqlalchemy.dialects.sqlite import insert as sqlite_insert
  44. stmt = sqlite_insert(Settings).values(key="setup_completed", value="true" if completed else "false")
  45. stmt = stmt.on_conflict_do_update(
  46. index_elements=["key"], set_={"value": "true" if completed else "false", "updated_at": func.now()}
  47. )
  48. await db.execute(stmt)
  49. # Note: Don't commit here - let get_db handle it or commit explicitly in the route
  50. @router.post("/setup", response_model=SetupResponse)
  51. async def setup_auth(request: SetupRequest, db: AsyncSession = Depends(get_db)):
  52. """First-time setup: enable/disable authentication and create admin user."""
  53. import logging
  54. logger = logging.getLogger(__name__)
  55. try:
  56. # Check if auth is already configured (prevent re-setup)
  57. result = await db.execute(select(Settings).where(Settings.key == "auth_enabled"))
  58. _existing_setting = result.scalar_one_or_none()
  59. # Check if users exist
  60. user_count_result = await db.execute(select(User))
  61. _user_count = len(user_count_result.scalars().all())
  62. # if _existing_setting and _user_count > 0:
  63. # # Auth already configured and users exist - prevent re-setup
  64. # raise HTTPException(
  65. # status_code=status.HTTP_400_BAD_REQUEST,
  66. # detail="Authentication is already configured. Use user management to modify users.",
  67. # )
  68. # If auth_enabled is true but no users exist, allow re-setup (recovery scenario)
  69. admin_created = False
  70. if request.auth_enabled:
  71. # Check if admin users already exist
  72. admin_users_result = await db.execute(select(User).where(User.role == "admin"))
  73. existing_admin_users = list(admin_users_result.scalars().all())
  74. has_admin_users = len(existing_admin_users) > 0
  75. if has_admin_users:
  76. # Admin users already exist, just enable auth (don't create new admin)
  77. logger.info(
  78. f"Admin users already exist ({len(existing_admin_users)} found), enabling authentication without creating new admin"
  79. )
  80. admin_created = False
  81. else:
  82. # No admin users exist, require admin credentials to create first admin
  83. if not request.admin_username or not request.admin_password:
  84. raise HTTPException(
  85. status_code=status.HTTP_400_BAD_REQUEST,
  86. detail="Admin username and password are required when enabling authentication (no admin users exist)",
  87. )
  88. # Check if username already exists (shouldn't happen if no admin users exist, but check anyway)
  89. existing_user = await get_user_by_username(db, request.admin_username)
  90. if existing_user:
  91. raise HTTPException(
  92. status_code=status.HTTP_400_BAD_REQUEST,
  93. detail="User with this username already exists",
  94. )
  95. # Create admin user FIRST (before enabling auth)
  96. try:
  97. logger.info(f"Creating admin user: {request.admin_username}")
  98. admin_user = User(
  99. username=request.admin_username,
  100. password_hash=get_password_hash(request.admin_password),
  101. role="admin",
  102. is_active=True,
  103. )
  104. db.add(admin_user)
  105. logger.info(f"Admin user added to session: {request.admin_username}")
  106. admin_created = True
  107. except Exception as e:
  108. await db.rollback()
  109. logger.error(f"Failed to create admin user: {e}", exc_info=True)
  110. raise HTTPException(
  111. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  112. detail=f"Failed to create admin user: {str(e)}",
  113. )
  114. # Set auth enabled and mark setup as completed
  115. await set_auth_enabled(db, request.auth_enabled)
  116. await set_setup_completed(db, True)
  117. await db.commit()
  118. if admin_created:
  119. await db.refresh(admin_user)
  120. logger.info(f"Admin user created successfully: {admin_user.id}")
  121. logger.info(f"Setup completed: auth_enabled={request.auth_enabled}, admin_created={admin_created}")
  122. return SetupResponse(auth_enabled=request.auth_enabled, admin_created=admin_created)
  123. except HTTPException:
  124. raise
  125. except Exception as e:
  126. logger.error(f"Setup error: {e}", exc_info=True)
  127. await db.rollback()
  128. raise HTTPException(
  129. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  130. detail=f"Setup failed: {str(e)}",
  131. )
  132. @router.get("/status")
  133. async def get_auth_status(db: AsyncSession = Depends(get_db)):
  134. """Get authentication status (public endpoint)."""
  135. auth_enabled = await is_auth_enabled(db)
  136. setup_completed = await is_setup_completed(db)
  137. # Only require setup if it hasn't been completed yet
  138. requires_setup = not setup_completed
  139. return {"auth_enabled": auth_enabled, "requires_setup": requires_setup}
  140. @router.post("/disable", response_model=dict)
  141. async def disable_auth(
  142. current_user: User = Depends(get_current_active_user),
  143. db: AsyncSession = Depends(get_db),
  144. ):
  145. """Disable authentication (admin only)."""
  146. import logging
  147. logger = logging.getLogger(__name__)
  148. # Only admins can disable authentication
  149. if current_user.role != "admin":
  150. raise HTTPException(
  151. status_code=status.HTTP_403_FORBIDDEN,
  152. detail="Only admins can disable authentication",
  153. )
  154. try:
  155. await set_auth_enabled(db, False)
  156. await db.commit()
  157. logger.info(f"Authentication disabled by admin user: {current_user.username}")
  158. return {"message": "Authentication disabled successfully", "auth_enabled": False}
  159. except Exception as e:
  160. await db.rollback()
  161. logger.error(f"Failed to disable authentication: {e}", exc_info=True)
  162. raise HTTPException(
  163. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  164. detail=f"Failed to disable authentication: {str(e)}",
  165. )
  166. @router.post("/login", response_model=LoginResponse)
  167. async def login(request: LoginRequest, db: AsyncSession = Depends(get_db)):
  168. """Login and get access token."""
  169. # Check if auth is enabled
  170. auth_enabled = await is_auth_enabled(db)
  171. if not auth_enabled:
  172. raise HTTPException(
  173. status_code=status.HTTP_400_BAD_REQUEST,
  174. detail="Authentication is not enabled",
  175. )
  176. user = await authenticate_user(db, request.username, request.password)
  177. if not user:
  178. raise HTTPException(
  179. status_code=status.HTTP_401_UNAUTHORIZED,
  180. detail="Incorrect username or password",
  181. headers={"WWW-Authenticate": "Bearer"},
  182. )
  183. access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
  184. access_token = create_access_token(data={"sub": user.username}, expires_delta=access_token_expires)
  185. return LoginResponse(
  186. access_token=access_token,
  187. token_type="bearer",
  188. user=UserResponse(
  189. id=user.id,
  190. username=user.username,
  191. role=user.role,
  192. is_active=user.is_active,
  193. created_at=user.created_at.isoformat(),
  194. ),
  195. )
  196. @router.get("/me", response_model=UserResponse)
  197. async def get_current_user_info(current_user: User = Depends(get_current_active_user)):
  198. """Get current user information."""
  199. return UserResponse(
  200. id=current_user.id,
  201. username=current_user.username,
  202. role=current_user.role,
  203. is_active=current_user.is_active,
  204. created_at=current_user.created_at.isoformat(),
  205. )
  206. @router.post("/logout")
  207. async def logout():
  208. """Logout (client should discard token)."""
  209. return {"message": "Logged out successfully"}