"""FastAPI application exposing auth, counter, sum and audit endpoints."""

from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, List, Tuple

from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel

from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord
from auth import (
    ACCESS_TOKEN_EXPIRE_MINUTES,
    COOKIE_NAME,
    UserCreate,
    UserLogin,
    UserResponse,
    get_password_hash,
    get_user_by_email,
    authenticate_user,
    create_access_token,
    get_current_user,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create all tables registered on ``Base.metadata`` before serving."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)


def set_auth_cookie(response: Response, token: str) -> None:
    """Attach *token* to *response* as the auth cookie.

    The cookie is HTTP-only, ``SameSite=lax``, and its ``max_age`` is
    ``ACCESS_TOKEN_EXPIRE_MINUTES`` converted to seconds.
    """
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        httponly=True,
        secure=False,  # Set to True in production with HTTPS
        samesite="lax",
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


# Auth endpoints


@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create a user, set the auth cookie, and return the new user.

    Raises:
        HTTPException: 400 if a user with the same email already exists.
    """
    existing_user = await get_user_by_email(db, user_data.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )
    db.add(user)
    await db.commit()
    # Reload the row after the commit so ``user.id`` can be read below.
    await db.refresh(user)

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return UserResponse(id=user.id, email=user.email)


@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate by email/password, set the auth cookie, return the user.

    Raises:
        HTTPException: 401 if ``authenticate_user`` returns a falsy value.
    """
    user = await authenticate_user(db, user_data.email, user_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return UserResponse(id=user.id, email=user.email)


@app.post("/api/auth/logout")
async def logout(response: Response):
    """Delete the auth cookie and return ``{"ok": True}``."""
    response.delete_cookie(key=COOKIE_NAME)
    return {"ok": True}


@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the id and email of the user resolved by ``get_current_user``."""
    return UserResponse(id=current_user.id, email=current_user.email)


# Counter endpoints


async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Return the ``Counter`` row with ``id == 1``, creating it at 0 if absent."""
    result = await db.execute(select(Counter).where(Counter.id == 1))
    counter = result.scalar_one_or_none()
    if not counter:
        counter = Counter(id=1, value=0)
        db.add(counter)
        await db.commit()
        await db.refresh(counter)
    return counter


@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return the current counter value as ``{"value": <int>}``."""
    counter = await get_or_create_counter(db)
    return {"value": counter.value}


@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment the counter by one, record the change, return the new value."""
    # Read the user id before any commit: depending on how the session is
    # configured, a commit may expire loaded ORM attributes, and reloading
    # them implicitly is not possible on an async session.
    user_id = current_user.id

    counter = await get_or_create_counter(db)
    # NOTE(review): this is a read-modify-write without row locking, so
    # concurrent requests may lose increments - consider an atomic UPDATE.
    value_before = counter.value
    value_after = value_before + 1
    counter.value = value_after

    db.add(
        CounterRecord(
            user_id=user_id,
            value_before=value_before,
            value_after=value_after,
        )
    )
    await db.commit()
    # Return the locally computed value instead of touching ``counter.value``
    # after the commit, for the same attribute-expiry reason as above.
    return {"value": value_after}


# Sum endpoints


class SumRequest(BaseModel):
    """Request body for ``/api/sum``: the two operands."""

    a: float
    b: float


class SumResponse(BaseModel):
    """Response body for ``/api/sum``: the operands and their sum."""

    a: float
    b: float
    result: float


@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Add ``data.a`` and ``data.b``, store a ``SumRecord``, return the sum."""
    result = data.a + data.b
    record = SumRecord(
        user_id=current_user.id,
        a=data.a,
        b=data.b,
        result=result,
    )
    db.add(record)
    await db.commit()
    return SumResponse(a=data.a, b=data.b, result=result)


# Audit endpoints


class CounterRecordResponse(BaseModel):
    """One counter audit entry, including the acting user's email."""

    id: int
    user_email: str
    value_before: int
    value_after: int
    created_at: datetime


class SumRecordResponse(BaseModel):
    """One sum audit entry, including the acting user's email."""

    id: int
    user_email: str
    a: float
    b: float
    result: float
    created_at: datetime


class PaginatedCounterRecords(BaseModel):
    """A page of counter audit entries plus pagination metadata."""

    records: List[CounterRecordResponse]
    total: int
    page: int
    per_page: int
    total_pages: int


class PaginatedSumRecords(BaseModel):
    """A page of sum audit entries plus pagination metadata."""

    records: List[SumRecordResponse]
    total: int
    page: int
    per_page: int
    total_pages: int


def _total_pages(total: int, per_page: int) -> int:
    """Return the number of pages for *total* items; at least 1."""
    if total <= 0:
        return 1
    # Ceiling division without floats.
    return (total + per_page - 1) // per_page


async def _fetch_audit_page(
    db: AsyncSession,
    model: Any,
    page: int,
    per_page: int,
) -> Tuple[list, int, int]:
    """Fetch one page of *model* rows joined with the owning user's email.

    *model* must expose ``id``, ``user_id`` and ``created_at`` columns.

    Returns:
        ``(rows, total, total_pages)`` where each row is a
        ``(record, email)`` pair, newest first.
    """
    count_result = await db.execute(select(func.count(model.id)))
    total = count_result.scalar() or 0

    query = (
        select(model, User.email)
        .join(User, model.user_id == User.id)
        # ``id`` breaks ties between equal timestamps so that rows cannot
        # shift between pages from one request to the next.
        .order_by(desc(model.created_at), desc(model.id))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    result = await db.execute(query)
    return result.all(), total, _total_pages(total, per_page)


@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return a page of counter audit records, newest first."""
    rows, total, total_pages = await _fetch_audit_page(
        db, CounterRecord, page, per_page
    )
    records = [
        CounterRecordResponse(
            id=record.id,
            user_email=email,
            value_before=record.value_before,
            value_after=record.value_after,
            created_at=record.created_at,
        )
        for record, email in rows
    ]
    return PaginatedCounterRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )


@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return a page of sum audit records, newest first."""
    rows, total, total_pages = await _fetch_audit_page(
        db, SumRecord, page, per_page
    )
    records = [
        SumRecordResponse(
            id=record.id,
            user_email=email,
            a=record.a,
            b=record.b,
            result=record.result,
            created_at=record.created_at,
        )
        for record, email in rows
    ]
    return PaginatedSumRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )