"""Invite routes for public check, user invites, and admin management.""" from datetime import datetime, UTC from fastapi import APIRouter, Depends, HTTPException, status, Query from sqlalchemy import select, func, desc from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format from models import User, Invite, InviteStatus, Permission from schemas import ( InviteCheckResponse, InviteCreate, InviteResponse, UserInviteResponse, PaginatedInviteRecords, AdminUserResponse, ) router = APIRouter(prefix="/api/invites", tags=["invites"]) admin_router = APIRouter(prefix="/api/admin", tags=["admin"]) MAX_INVITE_COLLISION_RETRIES = 3 def build_invite_response(invite: Invite) -> InviteResponse: """Build an InviteResponse from an Invite with loaded relationships.""" return InviteResponse( id=invite.id, identifier=invite.identifier, godfather_id=invite.godfather_id, godfather_email=invite.godfather.email, status=invite.status.value, used_by_id=invite.used_by_id, used_by_email=invite.used_by.email if invite.used_by else None, created_at=invite.created_at, spent_at=invite.spent_at, revoked_at=invite.revoked_at, ) @router.get("/{identifier}/check", response_model=InviteCheckResponse) async def check_invite( identifier: str, db: AsyncSession = Depends(get_db), ) -> InviteCheckResponse: """Check if an invite is valid and can be used for signup.""" normalized = normalize_identifier(identifier) # Validate format before querying database if not is_valid_identifier_format(normalized): return InviteCheckResponse(valid=False, error="Invalid invite code format") result = await db.execute( select(Invite).where(Invite.identifier == normalized) ) invite = result.scalar_one_or_none() # Return same error for not found, spent, and revoked to avoid information leakage if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED): return InviteCheckResponse(valid=False, error="Invite not found") return InviteCheckResponse(valid=True, status=invite.status.value) @router.get("", response_model=list[UserInviteResponse]) async def get_my_invites( db: AsyncSession = Depends(get_db), current_user: User = Depends(require_permission(Permission.VIEW_OWN_INVITES)), ) -> list[UserInviteResponse]: """Get all invites owned by the current user.""" result = await db.execute( select(Invite) .where(Invite.godfather_id == current_user.id) .order_by(desc(Invite.created_at)) ) invites = result.scalars().all() # Use preloaded used_by relationship (selectin loading) return [ UserInviteResponse( id=invite.id, identifier=invite.identifier, status=invite.status.value, used_by_email=invite.used_by.email if invite.used_by else None, created_at=invite.created_at, spent_at=invite.spent_at, ) for invite in invites ] @admin_router.get("/users", response_model=list[AdminUserResponse]) async def list_users_for_admin( db: AsyncSession = Depends(get_db), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), ) -> list[AdminUserResponse]: """List all users for admin dropdowns (invite creation, etc.).""" result = await db.execute(select(User.id, User.email).order_by(User.email)) users = result.all() return [AdminUserResponse(id=u.id, email=u.email) for u in users] @admin_router.post("/invites", response_model=InviteResponse) async def create_invite( data: InviteCreate, db: AsyncSession = Depends(get_db), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), ) -> InviteResponse: """Create a new invite for a specified godfather user.""" # Validate godfather exists result = await db.execute( select(User.id).where(User.id == data.godfather_id) ) godfather_id = result.scalar_one_or_none() if not godfather_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Godfather user not found", ) # Try to create invite with retry on collision invite: Invite | None = None for attempt in range(MAX_INVITE_COLLISION_RETRIES): identifier = generate_invite_identifier() invite = Invite( identifier=identifier, godfather_id=godfather_id, status=InviteStatus.READY, ) db.add(invite) try: await db.commit() await db.refresh(invite, ["godfather"]) break except IntegrityError: await db.rollback() if attempt == MAX_INVITE_COLLISION_RETRIES - 1: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to generate unique invite code. Please try again.", ) if invite is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create invite", ) return build_invite_response(invite) @admin_router.get("/invites", response_model=PaginatedInviteRecords) async def list_all_invites( page: int = Query(1, ge=1), per_page: int = Query(10, ge=1, le=100), status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"), godfather_id: int | None = Query(None, description="Filter by godfather user ID"), db: AsyncSession = Depends(get_db), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), ) -> PaginatedInviteRecords: """List all invites with optional filtering and pagination.""" # Build query query = select(Invite) count_query = select(func.count(Invite.id)) # Apply filters if status_filter: try: status_enum = InviteStatus(status_filter) query = query.where(Invite.status == status_enum) count_query = count_query.where(Invite.status == status_enum) except ValueError: raise HTTPException( status_code=400, detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked", ) if godfather_id: query = query.where(Invite.godfather_id == godfather_id) count_query = count_query.where(Invite.godfather_id == godfather_id) # Get total count count_result = await db.execute(count_query) total = count_result.scalar() or 0 total_pages = (total + per_page - 1) // per_page if total > 0 else 1 # Get paginated invites (relationships loaded via selectin) offset = (page - 1) * per_page query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page) result = await db.execute(query) invites = result.scalars().all() # Build responses using preloaded relationships records = [build_invite_response(invite) for invite in invites] return PaginatedInviteRecords( records=records, total=total, page=page, per_page=per_page, total_pages=total_pages, ) @admin_router.post("/invites/{invite_id}/revoke", response_model=InviteResponse) async def revoke_invite( invite_id: int, db: AsyncSession = Depends(get_db), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), ) -> InviteResponse: """Revoke an invite. Only READY invites can be revoked.""" result = await db.execute(select(Invite).where(Invite.id == invite_id)) invite = result.scalar_one_or_none() if not invite: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Invite not found", ) if invite.status != InviteStatus.READY: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.", ) invite.status = InviteStatus.REVOKED invite.revoked_at = datetime.now(UTC) await db.commit() await db.refresh(invite) return build_invite_response(invite)