"""FastAPI application exposing auth, counter, sum, audit and profile endpoints."""

from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Callable, Generic, TypeVar

from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession

# Local imports keep their original relative order (database, models,
# validation, auth) so module import side effects are unchanged.
from database import engine, get_db, Base
from models import (
    Counter,
    User,
    SumRecord,
    CounterRecord,
    Permission,
    Role,
    ROLE_REGULAR,
)
from validation import validate_profile_fields
from auth import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    COOKIE_NAME,
    UserCreate,
    UserLogin,
    UserResponse,
    get_password_hash,
    get_user_by_email,
    authenticate_user,
    create_access_token,
    get_current_user,
    require_permission,
    build_user_response,
)

R = TypeVar("R", bound=BaseModel)


async def paginate_with_user_email(
    db: AsyncSession,
    model: type[SumRecord] | type[CounterRecord],
    page: int,
    per_page: int,
    row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
    """Return one page of audit records together with each owner's email.

    Args:
        db: Session used to run the count and page queries.
        model: Audit model to page through (``SumRecord`` or ``CounterRecord``).
        page: 1-based page number.
        per_page: Maximum number of records on a page.
        row_mapper: Called as ``row_mapper(record, email)`` for every row to
            build the response object.

    Returns:
        ``(records, total, total_pages)``. ``total_pages`` is 1 when there are
        no records at all.
    """
    # NOTE(review): the count is taken over `model` alone while the page query
    # inner-joins User; the two agree only if every record has a matching
    # user - confirm against the schema.
    count_result = await db.execute(select(func.count(model.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty table still reports a single (empty) page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = (
        select(model, User.email)
        .join(User, model.user_id == User.id)
        # Newest first. The id tiebreaker makes the order deterministic when
        # several rows share a created_at value, so rows are neither repeated
        # nor skipped between consecutive pages.
        .order_by(desc(model.created_at), desc(model.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    records: list[R] = [
        row_mapper(record, email) for record, email in result.all()
    ]
    return records, total, total_pages


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run ``Base.metadata.create_all`` on the engine before serving requests."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


def set_auth_cookie(response: Response, token: str) -> None:
    """Attach the access token to ``response`` as an HTTP-only cookie.

    The cookie lifetime is ``ACCESS_TOKEN_EXPIRE_MINUTES`` converted to seconds.
    """
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        httponly=True,
        secure=False,  # Set to True in production with HTTPS
        samesite="lax",
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


async def get_default_role(db: AsyncSession) -> Role | None:
    """Get the default 'regular' role for new users, or None if it is absent."""
    result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
    return result.scalar_one_or_none()


# Auth endpoints


@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create a user, set the auth cookie and return the new user.

    Raises:
        HTTPException: 400 when the email is already registered.
    """
    existing_user = await get_user_by_email(db, user_data.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )

    # Assign default role if it exists
    default_role = await get_default_role(db)
    if default_role:
        user.roles.append(default_role)

    db.add(user)
    await db.commit()
    await db.refresh(user)

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return await build_user_response(user, db)


@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate by email and password, set the auth cookie, return the user.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy value.
    """
    user = await authenticate_user(db, user_data.email, user_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return await build_user_response(user, db)


@app.post("/api/auth/logout")
async def logout(response: Response):
    """Delete the auth cookie from the response."""
    response.delete_cookie(key=COOKIE_NAME)
    return {"ok": True}


@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the user resolved by the ``get_current_user`` dependency."""
    return await build_user_response(current_user, db)


# Counter endpoints


async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Return the counter row with id 1, inserting it with value 0 if missing."""
    result = await db.execute(select(Counter).where(Counter.id == 1))
    counter = result.scalar_one_or_none()
    if not counter:
        counter = Counter(id=1, value=0)
        db.add(counter)
        await db.commit()
        await db.refresh(counter)
    return counter


@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
):
    """Return the current counter value as ``{"value": ...}``."""
    counter = await get_or_create_counter(db)
    return {"value": counter.value}


@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(
        require_permission(Permission.INCREMENT_COUNTER)
    ),
):
    """Add one to the counter and store a ``CounterRecord`` for the change.

    The new counter value and the audit record are committed together.
    """
    counter = await get_or_create_counter(db)
    # NOTE(review): read-modify-write in Python; concurrent requests could
    # presumably lose an increment - confirm whether locking is needed.
    value_before = counter.value
    counter.value += 1

    record = CounterRecord(
        user_id=current_user.id,
        value_before=value_before,
        value_after=counter.value,
    )
    db.add(record)

    await db.commit()
    return {"value": counter.value}


# Sum endpoints


class SumRequest(BaseModel):
    """Request body for the sum endpoint: the two operands."""

    a: float
    b: float


class SumResponse(BaseModel):
    """Response body for the sum endpoint: the operands and their sum."""

    a: float
    b: float
    result: float


@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.USE_SUM)),
):
    """Add ``data.a`` and ``data.b``, store a ``SumRecord`` and return the result."""
    result = data.a + data.b

    record = SumRecord(
        user_id=current_user.id,
        a=data.a,
        b=data.b,
        result=result,
    )
    db.add(record)
    await db.commit()

    return SumResponse(a=data.a, b=data.b, result=result)


# Audit endpoints


class CounterRecordResponse(BaseModel):
    """One counter audit entry as returned by the audit endpoint."""

    id: int
    user_email: str
    value_before: int
    value_after: int
    created_at: datetime


class SumRecordResponse(BaseModel):
    """One sum audit entry as returned by the audit endpoint."""

    id: int
    user_email: str
    a: float
    b: float
    result: float
    created_at: datetime


RecordT = TypeVar("RecordT", bound=BaseModel)


class PaginatedResponse(BaseModel, Generic[RecordT]):
    """Generic paginated response wrapper."""

    records: list[RecordT]
    total: int
    page: int
    per_page: int
    total_pages: int


PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]


def _map_counter_record(
    record: CounterRecord, email: str
) -> CounterRecordResponse:
    """Build the response object for one counter record and its user's email."""
    return CounterRecordResponse(
        id=record.id,
        user_email=email,
        value_before=record.value_before,
        value_after=record.value_after,
        created_at=record.created_at,
    )


@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of counter audit records, newest first."""
    records, total, total_pages = await paginate_with_user_email(
        db, CounterRecord, page, per_page, _map_counter_record
    )
    return PaginatedCounterRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
    """Build the response object for one sum record and its user's email."""
    return SumRecordResponse(
        id=record.id,
        user_email=email,
        a=record.a,
        b=record.b,
        result=record.result,
        created_at=record.created_at,
    )


@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of sum audit records, newest first."""
    records, total, total_pages = await paginate_with_user_email(
        db, SumRecord, page, per_page, _map_sum_record
    )
    return PaginatedSumRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


# Profile endpoints


class ProfileResponse(BaseModel):
    """Response model for profile data."""

    contact_email: str | None
    telegram: str | None
    signal: str | None
    nostr_npub: str | None


class ProfileUpdate(BaseModel):
    """Request model for updating profile."""

    contact_email: str | None = None
    telegram: str | None = None
    signal: str | None = None
    nostr_npub: str | None = None


def _build_profile_response(user: User) -> ProfileResponse:
    """Copy the four contact fields of ``user`` into a ``ProfileResponse``."""
    return ProfileResponse(
        contact_email=user.contact_email,
        telegram=user.telegram,
        signal=user.signal,
        nostr_npub=user.nostr_npub,
    )


async def require_regular_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Dependency that requires the user to have the 'regular' role.

    Raises:
        HTTPException: 403 when ``ROLE_REGULAR`` is not among the user's
            role names.
    """
    if ROLE_REGULAR not in current_user.role_names:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Profile access is only available to regular users",
        )
    return current_user


@app.get("/api/profile", response_model=ProfileResponse)
async def get_profile(
    current_user: User = Depends(require_regular_user),
):
    """Get the current user's profile (contact details)."""
    return _build_profile_response(current_user)


@app.put("/api/profile", response_model=ProfileResponse)
async def update_profile(
    data: ProfileUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_regular_user),
):
    """Update the current user's profile (contact details).

    All four fields are overwritten with the submitted values, so a field
    omitted from the request body is stored as ``None``.

    Raises:
        HTTPException: 422 with ``{"field_errors": ...}`` when
            ``validate_profile_fields`` reports errors.
    """
    # Validate all fields
    errors = validate_profile_fields(
        contact_email=data.contact_email,
        telegram=data.telegram,
        signal=data.signal,
        nostr_npub=data.nostr_npub,
    )

    if errors:
        raise HTTPException(
            status_code=422,
            detail={"field_errors": errors},
        )

    # Update fields
    current_user.contact_email = data.contact_email
    current_user.telegram = data.telegram
    current_user.signal = data.signal
    current_user.nostr_npub = data.nostr_npub

    await db.commit()
    await db.refresh(current_user)

    return _build_profile_response(current_user)