push.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. """API routes for Web Push notifications."""
  2. import base64
  3. import json
  4. import logging
  5. from datetime import datetime
  6. from cryptography.hazmat.primitives import serialization
  7. from cryptography.hazmat.primitives.asymmetric import ec
  8. from fastapi import APIRouter, Depends, HTTPException
  9. from sqlalchemy import select
  10. from sqlalchemy.ext.asyncio import AsyncSession
  11. from backend.app.core.database import get_db
  12. from backend.app.models.notification import PushSubscription
  13. from backend.app.models.settings import Settings
  14. from backend.app.schemas.notification import (
  15. PushSubscriptionCreate,
  16. PushSubscriptionResponse,
  17. PushSubscriptionUpdate,
  18. VapidPublicKeyResponse,
  19. )
  20. logger = logging.getLogger(__name__)
  21. router = APIRouter(prefix="/push", tags=["push"])
  22. # Settings keys for VAPID
  23. VAPID_PRIVATE_KEY = "vapid_private_key"
  24. VAPID_PUBLIC_KEY = "vapid_public_key"
  25. VAPID_CLAIMS_EMAIL = "vapid_claims_email"
  26. def _generate_vapid_keys() -> tuple[str, str]:
  27. """Generate VAPID key pair using cryptography library."""
  28. # Generate private key
  29. private_key = ec.generate_private_key(ec.SECP256R1())
  30. # Get private key in PEM format
  31. private_pem = private_key.private_bytes(
  32. encoding=serialization.Encoding.PEM,
  33. format=serialization.PrivateFormat.PKCS8,
  34. encryption_algorithm=serialization.NoEncryption()
  35. ).decode("utf-8")
  36. # Get public key in uncompressed point format (X9.62)
  37. public_key = private_key.public_key()
  38. public_bytes = public_key.public_bytes(
  39. encoding=serialization.Encoding.X962,
  40. format=serialization.PublicFormat.UncompressedPoint
  41. )
  42. # Convert to URL-safe base64 (no padding)
  43. public_b64 = base64.urlsafe_b64encode(public_bytes).rstrip(b'=').decode('ascii')
  44. return private_pem, public_b64
  45. async def get_or_create_vapid_keys(db: AsyncSession) -> tuple[str, str]:
  46. """Get existing VAPID keys or generate new ones."""
  47. # Try to get existing keys
  48. result = await db.execute(
  49. select(Settings).where(Settings.key.in_([VAPID_PRIVATE_KEY, VAPID_PUBLIC_KEY]))
  50. )
  51. settings = {s.key: s.value for s in result.scalars().all()}
  52. if VAPID_PRIVATE_KEY in settings and VAPID_PUBLIC_KEY in settings:
  53. return settings[VAPID_PRIVATE_KEY], settings[VAPID_PUBLIC_KEY]
  54. # Generate new keys
  55. logger.info("Generating new VAPID keys for Web Push")
  56. private_key, public_key = _generate_vapid_keys()
  57. # Store keys in database
  58. for key, value in [(VAPID_PRIVATE_KEY, private_key), (VAPID_PUBLIC_KEY, public_key)]:
  59. existing = await db.execute(select(Settings).where(Settings.key == key))
  60. setting = existing.scalar_one_or_none()
  61. if setting:
  62. setting.value = value
  63. else:
  64. db.add(Settings(key=key, value=value))
  65. await db.commit()
  66. logger.info("VAPID keys generated and stored")
  67. return private_key, public_key
  68. async def get_vapid_claims_email(db: AsyncSession) -> str:
  69. """Get the email for VAPID claims (defaults to a placeholder)."""
  70. result = await db.execute(select(Settings).where(Settings.key == VAPID_CLAIMS_EMAIL))
  71. setting = result.scalar_one_or_none()
  72. return setting.value if setting else "mailto:bambuddy@localhost"
  73. @router.get("/vapid-public-key", response_model=VapidPublicKeyResponse)
  74. async def get_vapid_public_key(db: AsyncSession = Depends(get_db)):
  75. """Get the VAPID public key for push subscription."""
  76. _, public_key = await get_or_create_vapid_keys(db)
  77. return VapidPublicKeyResponse(public_key=public_key)
  78. @router.get("/subscriptions", response_model=list[PushSubscriptionResponse])
  79. async def list_subscriptions(db: AsyncSession = Depends(get_db)):
  80. """List all push subscriptions."""
  81. result = await db.execute(
  82. select(PushSubscription).order_by(PushSubscription.created_at.desc())
  83. )
  84. return result.scalars().all()
  85. @router.post("/subscribe", response_model=PushSubscriptionResponse)
  86. async def subscribe(
  87. subscription: PushSubscriptionCreate,
  88. db: AsyncSession = Depends(get_db),
  89. ):
  90. """Subscribe a browser to push notifications."""
  91. # Check if subscription already exists (by endpoint)
  92. result = await db.execute(
  93. select(PushSubscription).where(PushSubscription.endpoint == subscription.endpoint)
  94. )
  95. existing = result.scalar_one_or_none()
  96. if existing:
  97. # Update existing subscription
  98. existing.p256dh_key = subscription.p256dh_key
  99. existing.auth_key = subscription.auth_key
  100. existing.user_agent = subscription.user_agent
  101. if subscription.name:
  102. existing.name = subscription.name
  103. existing.enabled = True
  104. existing.updated_at = datetime.utcnow()
  105. await db.commit()
  106. await db.refresh(existing)
  107. logger.info(f"Updated push subscription: {existing.name or existing.id}")
  108. return existing
  109. # Create new subscription
  110. # Generate a default name from user agent if not provided
  111. name = subscription.name
  112. if not name and subscription.user_agent:
  113. # Extract browser name from user agent
  114. ua = subscription.user_agent.lower()
  115. if "chrome" in ua and "edg" not in ua:
  116. name = "Chrome"
  117. elif "firefox" in ua:
  118. name = "Firefox"
  119. elif "safari" in ua and "chrome" not in ua:
  120. name = "Safari"
  121. elif "edg" in ua:
  122. name = "Edge"
  123. else:
  124. name = "Browser"
  125. # Add device hint
  126. if "mobile" in ua or "android" in ua or "iphone" in ua:
  127. name += " (Mobile)"
  128. else:
  129. name += " (Desktop)"
  130. new_subscription = PushSubscription(
  131. endpoint=subscription.endpoint,
  132. p256dh_key=subscription.p256dh_key,
  133. auth_key=subscription.auth_key,
  134. name=name,
  135. user_agent=subscription.user_agent,
  136. enabled=True,
  137. )
  138. db.add(new_subscription)
  139. await db.commit()
  140. await db.refresh(new_subscription)
  141. logger.info(f"New push subscription created: {new_subscription.name or new_subscription.id}")
  142. return new_subscription
  143. @router.patch("/subscriptions/{subscription_id}", response_model=PushSubscriptionResponse)
  144. async def update_subscription(
  145. subscription_id: int,
  146. update: PushSubscriptionUpdate,
  147. db: AsyncSession = Depends(get_db),
  148. ):
  149. """Update a push subscription."""
  150. result = await db.execute(
  151. select(PushSubscription).where(PushSubscription.id == subscription_id)
  152. )
  153. subscription = result.scalar_one_or_none()
  154. if not subscription:
  155. raise HTTPException(status_code=404, detail="Subscription not found")
  156. if update.name is not None:
  157. subscription.name = update.name
  158. if update.enabled is not None:
  159. subscription.enabled = update.enabled
  160. subscription.updated_at = datetime.utcnow()
  161. await db.commit()
  162. await db.refresh(subscription)
  163. return subscription
  164. @router.delete("/subscriptions/{subscription_id}")
  165. async def delete_subscription(
  166. subscription_id: int,
  167. db: AsyncSession = Depends(get_db),
  168. ):
  169. """Delete a push subscription."""
  170. result = await db.execute(
  171. select(PushSubscription).where(PushSubscription.id == subscription_id)
  172. )
  173. subscription = result.scalar_one_or_none()
  174. if not subscription:
  175. raise HTTPException(status_code=404, detail="Subscription not found")
  176. await db.delete(subscription)
  177. await db.commit()
  178. return {"message": "Subscription deleted"}
  179. @router.post("/unsubscribe")
  180. async def unsubscribe_by_endpoint(
  181. endpoint: str,
  182. db: AsyncSession = Depends(get_db),
  183. ):
  184. """Unsubscribe by endpoint URL (called when browser unsubscribes)."""
  185. result = await db.execute(
  186. select(PushSubscription).where(PushSubscription.endpoint == endpoint)
  187. )
  188. subscription = result.scalar_one_or_none()
  189. if subscription:
  190. await db.delete(subscription)
  191. await db.commit()
  192. logger.info(f"Push subscription removed by endpoint: {subscription.name or subscription.id}")
  193. return {"message": "Unsubscribed"}
  194. @router.post("/test")
  195. async def test_push_notification(db: AsyncSession = Depends(get_db)):
  196. """Send a test push notification to all subscribed browsers."""
  197. from pywebpush import webpush, WebPushException
  198. private_key, public_key = await get_or_create_vapid_keys(db)
  199. claims_email = await get_vapid_claims_email(db)
  200. # Get all enabled subscriptions
  201. result = await db.execute(
  202. select(PushSubscription).where(PushSubscription.enabled == True)
  203. )
  204. subscriptions = result.scalars().all()
  205. if not subscriptions:
  206. raise HTTPException(status_code=400, detail="No push subscriptions found")
  207. success_count = 0
  208. error_count = 0
  209. errors = []
  210. for sub in subscriptions:
  211. subscription_info = {
  212. "endpoint": sub.endpoint,
  213. "keys": {
  214. "p256dh": sub.p256dh_key,
  215. "auth": sub.auth_key,
  216. },
  217. }
  218. payload = json.dumps({
  219. "title": "BamBuddy Test",
  220. "body": "Push notifications are working!",
  221. "url": "/",
  222. })
  223. try:
  224. webpush(
  225. subscription_info=subscription_info,
  226. data=payload,
  227. vapid_private_key=private_key,
  228. vapid_claims={"sub": claims_email},
  229. )
  230. sub.last_success = datetime.utcnow()
  231. success_count += 1
  232. logger.info(f"Test push sent to: {sub.name or sub.id}")
  233. except WebPushException as e:
  234. error_count += 1
  235. sub.last_error = str(e)
  236. sub.last_error_at = datetime.utcnow()
  237. errors.append(f"{sub.name or sub.id}: {str(e)}")
  238. logger.error(f"Push error for {sub.name or sub.id}: {e}")
  239. # If subscription is gone (410), remove it
  240. if e.response and e.response.status_code == 410:
  241. await db.delete(sub)
  242. logger.info(f"Removed expired subscription: {sub.name or sub.id}")
  243. await db.commit()
  244. return {
  245. "success": success_count > 0,
  246. "message": f"Sent to {success_count} device(s), {error_count} error(s)",
  247. "errors": errors if errors else None,
  248. }