from contextlib import asynccontextmanager
from datetime import datetime, UTC
from typing import Callable, Generic, TypeVar

from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from sqlalchemy import select, func, desc, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from auth import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    COOKIE_NAME,
    UserCreate,
    UserLogin,
    UserResponse,
    get_password_hash,
    get_user_by_email,
    authenticate_user,
    create_access_token,
    get_current_user,
    require_permission,
    build_user_response,
)
from database import engine, get_db, Base
from models import (
    Counter,
    User,
    SumRecord,
    CounterRecord,
    Permission,
    Role,
    ROLE_REGULAR,
    Invite,
    InviteStatus,
)
from validation import validate_profile_fields
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format

R = TypeVar("R", bound=BaseModel)


async def paginate_with_user_email(
    db: AsyncSession,
    model: type[SumRecord] | type[CounterRecord],
    page: int,
    per_page: int,
    row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
    """
    Generic pagination helper for audit records that need user email.

    Rows are ordered newest first (by ``created_at``) and joined with the
    owning user's email; ``row_mapper(record, email)`` converts each row.

    Returns: (records, total, total_pages)
    """
    # Total row count of the audit table, used to compute the page count.
    count_result = await db.execute(select(func.count(model.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty table still reports a single (empty) page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = (
        select(model, User.email)
        .join(User, model.user_id == User.id)
        .order_by(desc(model.created_at))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    rows = result.all()

    records: list[R] = [row_mapper(record, email) for record, email in rows]
    return records, total, total_pages


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create any missing tables from the ORM metadata before serving requests."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


def set_auth_cookie(response: Response, token: str) -> None:
    """Store the access token in an HTTP-only cookie whose max-age matches the token lifetime setting."""
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        httponly=True,
        secure=False,  # Set to True in production with HTTPS
        samesite="lax",
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


async def get_default_role(db: AsyncSession) -> Role | None:
    """Get the default 'regular' role for new users."""
    result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
    return result.scalar_one_or_none()


# Invite check endpoint (public)


class InviteCheckResponse(BaseModel):
    """Response for invite check endpoint."""

    valid: bool
    status: str | None = None
    error: str | None = None


@app.get("/api/invites/{identifier}/check", response_model=InviteCheckResponse)
async def check_invite(
    identifier: str,
    db: AsyncSession = Depends(get_db),
):
    """Check if an invite is valid and can be used for signup."""
    normalized = normalize_identifier(identifier)

    # Validate format before querying database
    if not is_valid_identifier_format(normalized):
        return InviteCheckResponse(valid=False, error="Invalid invite code format")

    result = await db.execute(
        select(Invite).where(Invite.identifier == normalized)
    )
    invite = result.scalar_one_or_none()

    # Return same error for not found, spent, and revoked to avoid information leakage
    if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
        return InviteCheckResponse(valid=False, error="Invite not found")

    return InviteCheckResponse(valid=True, status=invite.status.value)


# Auth endpoints


class RegisterWithInvite(BaseModel):
    """Request model for registration with invite."""

    email: EmailStr
    password: str
    invite_identifier: str


@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: RegisterWithInvite,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Register a new user using an invite code.

    The new user gets the invite's godfather and (when it exists) the default
    'regular' role; the invite is marked SPENT in the same transaction and an
    auth cookie is set on ``response``.

    Raises:
        HTTPException(400): the invite code is malformed, unknown, spent or
            revoked (all reported identically), or the email is taken.
    """
    normalized_identifier = normalize_identifier(user_data.invite_identifier)

    # Same format check as check_invite: reject malformed codes without a DB hit.
    if not is_valid_identifier_format(normalized_identifier):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid invite code",
        )

    result = await db.execute(
        select(Invite).where(Invite.identifier == normalized_identifier)
    )
    invite = result.scalar_one_or_none()

    # Return same error for not found, spent, and revoked to avoid information leakage
    if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid invite code",
        )

    existing_user = await get_user_by_email(db, user_data.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        godfather_id=invite.godfather_id,
    )

    default_role = await get_default_role(db)
    if default_role:
        user.roles.append(default_role)

    db.add(user)
    try:
        await db.flush()  # Assigns user.id
    except IntegrityError:
        # A concurrent request registered the same email after the check above
        # (assumes a unique constraint on the user email — confirm in models).
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    # Claim the invite atomically: the UPDATE only matches while the invite is
    # still usable, so two concurrent registrations cannot both spend it.
    claim = await db.execute(
        update(Invite)
        .where(
            Invite.id == invite.id,
            Invite.status.not_in((InviteStatus.SPENT, InviteStatus.REVOKED)),
        )
        .values(
            status=InviteStatus.SPENT,
            used_by_id=user.id,
            spent_at=datetime.now(UTC),
        )
    )
    if claim.rowcount != 1:
        # Invite was spent or revoked between the read above and this UPDATE;
        # roll back so the just-flushed user row is discarded.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid invite code",
        )

    await db.commit()
    await db.refresh(user)

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return await build_user_response(user, db)


@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate with email/password and set the auth cookie.

    Raises:
        HTTPException(401): the credentials are not accepted by authenticate_user.
    """
    user = await authenticate_user(db, user_data.email, user_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )
    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return await build_user_response(user, db)


@app.post("/api/auth/logout")
async def logout(response: Response):
    """Clear the auth cookie."""
    response.delete_cookie(key=COOKIE_NAME)
    return {"ok": True}


@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the currently authenticated user."""
    return await build_user_response(current_user, db)


# Counter endpoints


async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Return the singleton counter row (id=1), creating and committing it with value 0 if absent."""
    result = await db.execute(select(Counter).where(Counter.id == 1))
    counter = result.scalar_one_or_none()
    if not counter:
        counter = Counter(id=1, value=0)
        db.add(counter)
        await db.commit()
        await db.refresh(counter)
    return counter


@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
):
    """Return the current counter value."""
    counter = await get_or_create_counter(db)
    return {"value": counter.value}


@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
):
    """Increment the counter by one and record the change in the audit log."""
    counter = await get_or_create_counter(db)
    value_before = counter.value
    # NOTE(review): read-modify-write in Python; concurrent increments may lose
    # updates unless the database serializes these transactions.
    counter.value += 1

    record = CounterRecord(
        user_id=current_user.id,
        value_before=value_before,
        value_after=counter.value,
    )
    db.add(record)

    await db.commit()
    return {"value": counter.value}


# Sum endpoints


class SumRequest(BaseModel):
    a: float
    b: float


class SumResponse(BaseModel):
    a: float
    b: float
    result: float


@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.USE_SUM)),
):
    """Add two numbers and record the calculation in the audit log."""
    result = data.a + data.b

    record = SumRecord(
        user_id=current_user.id,
        a=data.a,
        b=data.b,
        result=result,
    )
    db.add(record)
    await db.commit()

    return SumResponse(a=data.a, b=data.b, result=result)


# Audit endpoints


class CounterRecordResponse(BaseModel):
    id: int
    user_email: str
    value_before: int
    value_after: int
    created_at: datetime


class SumRecordResponse(BaseModel):
    id: int
    user_email: str
    a: float
    b: float
    result: float
    created_at: datetime


RecordT = TypeVar("RecordT", bound=BaseModel)


class PaginatedResponse(BaseModel, Generic[RecordT]):
    """Generic paginated response wrapper."""

    records: list[RecordT]
    total: int
    page: int
    per_page: int
    total_pages: int


PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]


def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
    """Convert a CounterRecord plus its user's email into the API response model."""
    return CounterRecordResponse(
        id=record.id,
        user_email=email,
        value_before=record.value_before,
        value_after=record.value_after,
        created_at=record.created_at,
    )


@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return a page of counter audit records, newest first."""
    records, total, total_pages = await paginate_with_user_email(
        db, CounterRecord, page, per_page, _map_counter_record
    )
    return PaginatedCounterRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
    """Convert a SumRecord plus its user's email into the API response model."""
    return SumRecordResponse(
        id=record.id,
        user_email=email,
        a=record.a,
        b=record.b,
        result=record.result,
        created_at=record.created_at,
    )


@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return a page of sum audit records, newest first."""
    records, total, total_pages = await paginate_with_user_email(
        db, SumRecord, page, per_page, _map_sum_record
    )
    return PaginatedSumRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


# Profile endpoints


class ProfileResponse(BaseModel):
    """Response model for profile data."""

    contact_email: str | None
    telegram: str | None
    signal: str | None
    nostr_npub: str | None
    godfather_email: str | None = None


class ProfileUpdate(BaseModel):
    """Request model for updating profile."""

    contact_email: str | None = None
    telegram: str | None = None
    signal: str | None = None
    nostr_npub: str | None = None


async def require_regular_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Dependency that requires the user to have the 'regular' role."""
    if ROLE_REGULAR not in current_user.role_names:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Profile access is only available to regular users",
        )
    return current_user


async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str | None:
    """Get the email of a godfather user by ID."""
    if not godfather_id:
        return None
    result = await db.execute(
        select(User.email).where(User.id == godfather_id)
    )
    return result.scalar_one_or_none()


@app.get("/api/profile", response_model=ProfileResponse)
async def get_profile(
    current_user: User = Depends(require_regular_user),
    db: AsyncSession = Depends(get_db),
):
    """Get the current user's profile (contact details and godfather)."""
    godfather_email = await get_godfather_email(db, current_user.godfather_id)

    return ProfileResponse(
        contact_email=current_user.contact_email,
        telegram=current_user.telegram,
        signal=current_user.signal,
        nostr_npub=current_user.nostr_npub,
        godfather_email=godfather_email,
    )


@app.put("/api/profile", response_model=ProfileResponse)
async def update_profile(
    data: ProfileUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_regular_user),
):
    """Update the current user's profile (contact details).

    Every field is overwritten (omitted fields become None).

    Raises:
        HTTPException(422): validate_profile_fields reported errors; they are
            returned under ``detail["field_errors"]``.
    """
    errors = validate_profile_fields(
        contact_email=data.contact_email,
        telegram=data.telegram,
        signal=data.signal,
        nostr_npub=data.nostr_npub,
    )

    if errors:
        raise HTTPException(
            status_code=422,
            detail={"field_errors": errors},
        )

    current_user.contact_email = data.contact_email
    current_user.telegram = data.telegram
    current_user.signal = data.signal
    current_user.nostr_npub = data.nostr_npub

    await db.commit()
    await db.refresh(current_user)

    godfather_email = await get_godfather_email(db, current_user.godfather_id)

    return ProfileResponse(
        contact_email=current_user.contact_email,
        telegram=current_user.telegram,
        signal=current_user.signal,
        nostr_npub=current_user.nostr_npub,
        godfather_email=godfather_email,
    )


# Invite endpoints


class InviteCreate(BaseModel):
    """Request model for creating an invite."""

    godfather_id: int


class InviteResponse(BaseModel):
    """Response model for invite data."""

    id: int
    identifier: str
    godfather_id: int
    godfather_email: str
    status: str
    used_by_id: int | None
    used_by_email: str | None
    created_at: datetime
    spent_at: datetime | None
    revoked_at: datetime | None


def build_invite_response(invite: Invite) -> InviteResponse:
    """Build an InviteResponse from an Invite with loaded relationships."""
    return InviteResponse(
        id=invite.id,
        identifier=invite.identifier,
        godfather_id=invite.godfather_id,
        godfather_email=invite.godfather.email,
        status=invite.status.value,
        used_by_id=invite.used_by_id,
        used_by_email=invite.used_by.email if invite.used_by else None,
        created_at=invite.created_at,
        spent_at=invite.spent_at,
        revoked_at=invite.revoked_at,
    )


MAX_INVITE_COLLISION_RETRIES = 3
@app.post("/api/admin/invites", response_model=InviteResponse)
async def create_invite(
    data: InviteCreate,
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """Create a new invite for a specified godfather user.

    A fresh identifier is generated per attempt; when the commit raises
    IntegrityError (treated as an identifier collision) the transaction is
    rolled back and retried, up to MAX_INVITE_COLLISION_RETRIES attempts.

    Raises:
        HTTPException(400): the godfather user does not exist.
        HTTPException(500): no attempt could be committed.
    """
    result = await db.execute(
        select(User.id).where(User.id == data.godfather_id)
    )
    godfather_id = result.scalar_one_or_none()
    if godfather_id is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Godfather user not found",
        )

    for attempt in range(MAX_INVITE_COLLISION_RETRIES):
        invite = Invite(
            identifier=generate_invite_identifier(),
            godfather_id=godfather_id,
            status=InviteStatus.READY,
        )
        db.add(invite)
        try:
            await db.commit()
        except IntegrityError:
            await db.rollback()
            if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Failed to generate unique invite code. Please try again.",
                )
            continue
        # Load godfather so build_invite_response can read its email.
        await db.refresh(invite, ["godfather"])
        return build_invite_response(invite)

    # Only reachable when MAX_INVITE_COLLISION_RETRIES < 1.
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Failed to create invite",
    )


class UserInviteResponse(BaseModel):
    """Response model for a user's invite (simpler than admin view)."""

    id: int
    identifier: str
    status: str
    used_by_email: str | None
    created_at: datetime
    spent_at: datetime | None


@app.get("/api/invites", response_model=list[UserInviteResponse])
async def get_my_invites(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.VIEW_OWN_INVITES)),
):
    """Get all invites owned by the current user, newest first."""
    result = await db.execute(
        select(Invite)
        .where(Invite.godfather_id == current_user.id)
        .order_by(desc(Invite.created_at))
    )
    invites = result.scalars().all()

    # Relies on used_by being eager-loaded (selectin) by the model mapping —
    # presumably configured in models; verify.
    return [
        UserInviteResponse(
            id=invite.id,
            identifier=invite.identifier,
            status=invite.status.value,
            used_by_email=invite.used_by.email if invite.used_by else None,
            created_at=invite.created_at,
            spent_at=invite.spent_at,
        )
        for invite in invites
    ]


# Admin Invite Management

PaginatedInviteRecords = PaginatedResponse[InviteResponse]


class AdminUserResponse(BaseModel):
    """Minimal user info for admin dropdowns."""

    id: int
    email: str


@app.get("/api/admin/users", response_model=list[AdminUserResponse])
async def list_users_for_admin(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """List all users for admin dropdowns (invite creation, etc.), ordered by email."""
    result = await db.execute(select(User.id, User.email).order_by(User.email))
    users = result.all()
    return [AdminUserResponse(id=u.id, email=u.email) for u in users]


@app.get("/api/admin/invites", response_model=PaginatedInviteRecords)
async def list_all_invites(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
    godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """List all invites with optional filtering and pagination.

    Raises:
        HTTPException(400): ``status`` is not a valid InviteStatus value.
    """
    query = select(Invite)
    count_query = select(func.count(Invite.id))

    if status_filter:
        # Only the enum parse can raise ValueError; keep the try that narrow.
        try:
            status_enum = InviteStatus(status_filter)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
            ) from None
        query = query.where(Invite.status == status_enum)
        count_query = count_query.where(Invite.status == status_enum)

    # `is not None` so an explicitly supplied ID is always honoured.
    if godfather_id is not None:
        query = query.where(Invite.godfather_id == godfather_id)
        count_query = count_query.where(Invite.godfather_id == godfather_id)

    count_result = await db.execute(count_query)
    total = count_result.scalar() or 0
    # Ceiling division; no matches still reports a single (empty) page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)

    result = await db.execute(query)
    invites = result.scalars().all()

    # build_invite_response reads godfather/used_by, expected to be eager-loaded.
    records = [build_invite_response(invite) for invite in invites]

    return PaginatedInviteRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


def _cannot_revoke_error(invite: Invite) -> HTTPException:
    """Build the 400 error returned when an invite is not in READY status."""
    return HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
    )


@app.post("/api/admin/invites/{invite_id}/revoke", response_model=InviteResponse)
async def revoke_invite(
    invite_id: int,
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """Revoke an invite. Only READY invites can be revoked.

    Raises:
        HTTPException(404): no invite with ``invite_id``.
        HTTPException(400): the invite is not READY (including when its status
            changed concurrently between the read and the update).
    """
    # Function-local import keeps this block independent of the file's import list.
    from sqlalchemy import update

    result = await db.execute(select(Invite).where(Invite.id == invite_id))
    invite = result.scalar_one_or_none()

    if not invite:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Invite not found",
        )

    if invite.status != InviteStatus.READY:
        raise _cannot_revoke_error(invite)

    # Conditional UPDATE: only flips the row while it is still READY, so an
    # invite spent by a concurrent registration is never overwritten.
    revoked = await db.execute(
        update(Invite)
        .where(Invite.id == invite.id, Invite.status == InviteStatus.READY)
        .values(status=InviteStatus.REVOKED, revoked_at=datetime.now(UTC))
    )
    if revoked.rowcount != 1:
        await db.rollback()
        await db.refresh(invite)  # Report the status that actually won.
        raise _cannot_revoke_error(invite)

    await db.commit()
    await db.refresh(invite)

    return build_invite_response(invite)