| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- """API routes for Web Push notifications."""
- import base64
- import json
- import logging
- from datetime import datetime
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives.asymmetric import ec
- from fastapi import APIRouter, Depends, HTTPException
- from sqlalchemy import select
- from sqlalchemy.ext.asyncio import AsyncSession
- from backend.app.core.database import get_db
- from backend.app.models.notification import PushSubscription
- from backend.app.models.settings import Settings
- from backend.app.schemas.notification import (
- PushSubscriptionCreate,
- PushSubscriptionResponse,
- PushSubscriptionUpdate,
- VapidPublicKeyResponse,
- )
- logger = logging.getLogger(__name__)  # module-level logger named after this module
- router = APIRouter(prefix="/push", tags=["push"])  # every route below is registered under the /push prefix
- # Keys of the Settings rows that hold the VAPID configuration
- VAPID_PRIVATE_KEY = "vapid_private_key"  # private key, stored as a PEM string
- VAPID_PUBLIC_KEY = "vapid_public_key"  # public key, stored as unpadded URL-safe base64
- VAPID_CLAIMS_EMAIL = "vapid_claims_email"  # value for the "sub" claim; a placeholder is used when unset
def _generate_vapid_keys() -> tuple[str, str]:
    """Create a fresh VAPID key pair on the SECP256R1 curve.

    Returns:
        A ``(private_pem, public_b64)`` tuple: the private key as an
        unencrypted PKCS#8 PEM string, and the public key as an
        uncompressed X9.62 point encoded in URL-safe base64 with the
        ``=`` padding removed.
    """
    keypair = ec.generate_private_key(ec.SECP256R1())

    # Public half: raw uncompressed EC point, then unpadded URL-safe base64.
    point_bytes = keypair.public_key().public_bytes(
        encoding=serialization.Encoding.X962,
        format=serialization.PublicFormat.UncompressedPoint,
    )
    encoded_point = base64.urlsafe_b64encode(point_bytes).decode("ascii").rstrip("=")

    # Private half: PKCS#8 PEM with no passphrase.
    pem_bytes = keypair.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return pem_bytes.decode("utf-8"), encoded_point
async def get_or_create_vapid_keys(db: AsyncSession) -> tuple[str, str]:
    """Return the stored VAPID key pair, generating and persisting one if needed.

    Args:
        db: Async database session used to read and write ``Settings`` rows.

    Returns:
        A ``(private_key, public_key)`` tuple. When both keys are already
        stored they are returned unchanged; otherwise a new pair is
        generated, written to the settings table and returned.
    """
    result = await db.execute(
        select(Settings).where(Settings.key.in_([VAPID_PRIVATE_KEY, VAPID_PUBLIC_KEY]))
    )
    # Keep the ORM rows (not only their values) so that a partially stored
    # pair can be overwritten below without querying each key a second time.
    rows = {row.key: row for row in result.scalars().all()}

    if VAPID_PRIVATE_KEY in rows and VAPID_PUBLIC_KEY in rows:
        return rows[VAPID_PRIVATE_KEY].value, rows[VAPID_PUBLIC_KEY].value

    logger.info("Generating new VAPID keys for Web Push")
    private_key, public_key = _generate_vapid_keys()

    # Both halves are always rewritten together so the stored pair matches.
    for key, value in ((VAPID_PRIVATE_KEY, private_key), (VAPID_PUBLIC_KEY, public_key)):
        row = rows.get(key)
        if row is not None:
            row.value = value
        else:
            db.add(Settings(key=key, value=value))
    await db.commit()

    logger.info("VAPID keys generated and stored")
    return private_key, public_key
async def get_vapid_claims_email(db: AsyncSession) -> str:
    """Look up the contact address used for VAPID claims.

    Args:
        db: Async database session used to read the ``Settings`` row.

    Returns:
        The stored value, or ``"mailto:bambuddy@localhost"`` when no
        setting row exists.
    """
    query = select(Settings).where(Settings.key == VAPID_CLAIMS_EMAIL)
    row = (await db.execute(query)).scalar_one_or_none()
    if row:
        return row.value
    return "mailto:bambuddy@localhost"
@router.get("/vapid-public-key", response_model=VapidPublicKeyResponse)
async def get_vapid_public_key(db: AsyncSession = Depends(get_db)):
    """Return the VAPID public key that browsers need in order to subscribe.

    The key pair is created on first use by ``get_or_create_vapid_keys``.
    """
    key_pair = await get_or_create_vapid_keys(db)
    # The pair is (private, public); only the public half is exposed.
    return VapidPublicKeyResponse(public_key=key_pair[1])
@router.get("/subscriptions", response_model=list[PushSubscriptionResponse])
async def list_subscriptions(db: AsyncSession = Depends(get_db)):
    """Return every stored push subscription, newest first."""
    newest_first = select(PushSubscription).order_by(PushSubscription.created_at.desc())
    rows = await db.execute(newest_first)
    return rows.scalars().all()
def _default_subscription_name(user_agent: str) -> str:
    """Build a display name such as ``"Chrome (Desktop)"`` from a user agent."""
    ua = user_agent.lower()
    has_chrome = "chrome" in ua
    has_edge = "edg" in ua

    # Order matters: these tokens overlap, so each check excludes the
    # browsers that were tested before it.
    if has_chrome and not has_edge:
        browser = "Chrome"
    elif "firefox" in ua:
        browser = "Firefox"
    elif "safari" in ua and not has_chrome:
        browser = "Safari"
    elif has_edge:
        browser = "Edge"
    else:
        browser = "Browser"

    is_mobile = any(marker in ua for marker in ("mobile", "android", "iphone"))
    return browser + (" (Mobile)" if is_mobile else " (Desktop)")


@router.post("/subscribe", response_model=PushSubscriptionResponse)
async def subscribe(
    subscription: PushSubscriptionCreate,
    db: AsyncSession = Depends(get_db),
):
    """Register a browser for push notifications, or refresh its registration.

    Subscriptions are matched by endpoint. An existing one gets its keys and
    user agent replaced, is re-enabled, and is renamed only when the request
    supplies a name. Otherwise a new enabled subscription is stored; when no
    name is given but a user agent is, a default name is derived from it.

    Returns:
        The stored subscription after it has been refreshed from the database.
    """
    lookup = select(PushSubscription).where(PushSubscription.endpoint == subscription.endpoint)
    record = (await db.execute(lookup)).scalar_one_or_none()

    if not record:
        # First time this endpoint is seen: create a new row.
        display_name = subscription.name
        if not display_name and subscription.user_agent:
            display_name = _default_subscription_name(subscription.user_agent)

        record = PushSubscription(
            endpoint=subscription.endpoint,
            p256dh_key=subscription.p256dh_key,
            auth_key=subscription.auth_key,
            name=display_name,
            user_agent=subscription.user_agent,
            enabled=True,
        )
        db.add(record)
        await db.commit()
        await db.refresh(record)
        logger.info(f"New push subscription created: {record.name or record.id}")
        return record

    # Known endpoint: overwrite the key material and re-enable it.
    record.p256dh_key = subscription.p256dh_key
    record.auth_key = subscription.auth_key
    record.user_agent = subscription.user_agent
    if subscription.name:
        record.name = subscription.name
    record.enabled = True
    record.updated_at = datetime.utcnow()
    await db.commit()
    await db.refresh(record)
    logger.info(f"Updated push subscription: {record.name or record.id}")
    return record
@router.patch("/subscriptions/{subscription_id}", response_model=PushSubscriptionResponse)
async def update_subscription(
    subscription_id: int,
    update: PushSubscriptionUpdate,
    db: AsyncSession = Depends(get_db),
):
    """Change the name and/or enabled flag of one push subscription.

    Raises:
        HTTPException: 404 when no subscription has the given id.
    """
    query = select(PushSubscription).where(PushSubscription.id == subscription_id)
    target = (await db.execute(query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Subscription not found")

    # A field left as None in the payload is treated as "do not change".
    for field in ("name", "enabled"):
        new_value = getattr(update, field)
        if new_value is not None:
            setattr(target, field, new_value)
    target.updated_at = datetime.utcnow()

    await db.commit()
    await db.refresh(target)
    return target
@router.delete("/subscriptions/{subscription_id}")
async def delete_subscription(
    subscription_id: int,
    db: AsyncSession = Depends(get_db),
):
    """Remove one push subscription by id.

    Raises:
        HTTPException: 404 when no subscription has the given id.
    """
    query = select(PushSubscription).where(PushSubscription.id == subscription_id)
    target = (await db.execute(query)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Subscription not found")

    await db.delete(target)
    await db.commit()
    return {"message": "Subscription deleted"}
@router.post("/unsubscribe")
async def unsubscribe_by_endpoint(
    endpoint: str,
    db: AsyncSession = Depends(get_db),
):
    """Remove the subscription registered for a given endpoint URL, if any.

    The response is the same whether or not a matching subscription existed.
    """
    response = {"message": "Unsubscribed"}

    query = select(PushSubscription).where(PushSubscription.endpoint == endpoint)
    match = (await db.execute(query)).scalar_one_or_none()
    if not match:
        return response

    await db.delete(match)
    await db.commit()
    logger.info(f"Push subscription removed by endpoint: {match.name or match.id}")
    return response
@router.post("/test")
async def test_push_notification(db: AsyncSession = Depends(get_db)):
    """Send a test push notification to every enabled subscription.

    Each subscription records the outcome: ``last_success`` on delivery, or
    ``last_error`` / ``last_error_at`` on failure. A subscription whose push
    service answers 410 is deleted.

    Returns:
        A dict with ``success`` (True when at least one delivery worked),
        a human-readable ``message`` with the counts, and ``errors`` (a list
        of per-subscription error strings, or None when there were none).

    Raises:
        HTTPException: 400 when there are no enabled subscriptions.
    """
    import asyncio

    from pywebpush import WebPushException, webpush

    private_key, _ = await get_or_create_vapid_keys(db)
    claims_email = await get_vapid_claims_email(db)

    result = await db.execute(
        select(PushSubscription).where(PushSubscription.enabled == True)  # noqa: E712
    )
    subscriptions = result.scalars().all()
    if not subscriptions:
        raise HTTPException(status_code=400, detail="No push subscriptions found")

    # The payload is the same for every recipient, so serialize it once.
    payload = json.dumps({
        "title": "BamBuddy Test",
        "body": "Push notifications are working!",
        "url": "/",
    })

    success_count = 0
    errors = []
    for sub in subscriptions:
        label = sub.name or sub.id
        subscription_info = {
            "endpoint": sub.endpoint,
            "keys": {
                "p256dh": sub.p256dh_key,
                "auth": sub.auth_key,
            },
        }
        try:
            # webpush() is a synchronous call; run it in a worker thread so
            # the event loop is not blocked while waiting on the push service.
            await asyncio.to_thread(
                webpush,
                subscription_info=subscription_info,
                data=payload,
                vapid_private_key=private_key,
                # A fresh dict per call, as in the original code.
                # NOTE(review): pywebpush may add fields to this dict - confirm.
                vapid_claims={"sub": claims_email},
            )
        except WebPushException as e:
            sub.last_error = str(e)
            sub.last_error_at = datetime.utcnow()
            errors.append(f"{label}: {str(e)}")
            logger.error(f"Push error for {label}: {e}")

            # Compare against None explicitly: a requests.Response is falsy
            # for every 4xx/5xx status, so a plain truthiness test would
            # never let the 410 check run.
            if e.response is not None and e.response.status_code == 410:
                await db.delete(sub)
                logger.info(f"Removed expired subscription: {label}")
        else:
            sub.last_success = datetime.utcnow()
            success_count += 1
            logger.info(f"Test push sent to: {label}")

    await db.commit()

    return {
        "success": success_count > 0,
        "message": f"Sent to {success_count} device(s), {len(errors)} error(s)",
        "errors": errors or None,
    }
|