arbret/backend/main.py

304 lines
7.7 KiB
Python

from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord
from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
UserCreate,
UserLogin,
UserResponse,
get_password_hash,
get_user_by_email,
authenticate_user,
create_access_token,
get_current_user,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup: opens a transaction on the module-level ``engine`` and runs
    ``Base.metadata.create_all`` through ``run_sync``.

    Shutdown: disposes the engine so pooled connections are closed
    explicitly rather than being left to garbage collection.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Release pooled connections even if the application exits with an
        # error.
        await engine.dispose()
# Application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# CORS: only the listed origin may call the API from a browser.
# Credentials are allowed so the browser sends cookies on cross-origin
# requests from that origin.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["http://localhost:3000"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def set_auth_cookie(response: Response, token: str) -> None:
    """Attach ``token`` to ``response`` as the authentication cookie.

    The cookie is named ``COOKIE_NAME``, is HTTP-only, uses ``SameSite=Lax``
    and lives for ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Args:
        response: Outgoing response on which the cookie is set.
        token: Value stored in the cookie.
    """
    # set_cookie takes max_age in seconds; the setting is in minutes.
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        max_age=lifetime_seconds,
        httponly=True,
        samesite="lax",
        secure=False,  # Set to True in production with HTTPS
    )
# Auth endpoints
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create a user account and sign the new user in.

    Args:
        user_data: Email and password for the new account.
        response: Response that receives the auth cookie.
        db: Database session.

    Returns:
        The id and email of the newly created user.

    Raises:
        HTTPException: 400 if a user with this email already exists.
    """
    if await get_user_by_email(db, user_data.email):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    new_user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )
    db.add(new_user)
    await db.commit()
    # Reload the row so attributes such as the id are populated after commit.
    await db.refresh(new_user)

    token = create_access_token(data={"sub": str(new_user.id)})
    set_auth_cookie(response, token)
    return UserResponse(id=new_user.id, email=new_user.email)
@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate a user and set the auth cookie.

    Args:
        user_data: Submitted email and password.
        response: Response that receives the auth cookie.
        db: Database session.

    Returns:
        The id and email of the authenticated user.

    Raises:
        HTTPException: 401 if ``authenticate_user`` returns a falsy result.
    """
    account = await authenticate_user(db, user_data.email, user_data.password)
    if not account:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect email or password",
        )

    token = create_access_token(data={"sub": str(account.id)})
    set_auth_cookie(response, token)
    return UserResponse(id=account.id, email=account.email)
@app.post("/api/auth/logout")
async def logout(response: Response):
    """Sign the caller out by deleting the auth cookie.

    Args:
        response: Response on which the cookie deletion is issued.

    Returns:
        ``{"ok": True}``.
    """
    response.delete_cookie(COOKIE_NAME)
    return {"ok": True}
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the id and email of the user resolved by ``get_current_user``."""
    me = current_user
    return UserResponse(id=me.id, email=me.email)
# Counter endpoints
async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Fetch the counter row with id 1, creating it at value 0 if absent.

    Args:
        db: Database session. When the row has to be created, this function
            commits the session.

    Returns:
        The ``Counter`` row with id 1.
    """
    lookup = select(Counter).where(Counter.id == 1)
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing:
        return existing

    created = Counter(id=1, value=0)
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return the current counter value.

    Args:
        db: Database session.
        _current_user: Resolved through ``get_current_user``; not used in
            the body.

    Returns:
        ``{"value": <current counter value>}``.
    """
    current = await get_or_create_counter(db)
    return {"value": current.value}
@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment the counter by one and record the change for auditing.

    Args:
        db: Database session; committed by this endpoint.
        current_user: User credited with the increment in the audit record.

    Returns:
        ``{"value": <counter value after the increment>}``.
    """
    counter = await get_or_create_counter(db)
    # Re-read the row with SELECT ... FOR UPDATE so concurrent requests are
    # serialised on backends that support row locks. Without this, two
    # requests can read the same value and one increment is lost.
    await db.refresh(counter, with_for_update=True)

    value_before = counter.value
    value_after = value_before + 1
    counter.value = value_after

    db.add(
        CounterRecord(
            user_id=current_user.id,
            value_before=value_before,
            value_after=value_after,
        )
    )
    await db.commit()
    # Return the local value: reading counter.value after commit would hit
    # an expired attribute and trigger implicit IO when expire_on_commit is
    # enabled.
    return {"value": value_after}
# Sum endpoints
class SumRequest(BaseModel):
    """Request body for the sum endpoint: the two numbers to add."""

    a: float  # first operand
    b: float  # second operand
class SumResponse(BaseModel):
    """Response body for the sum endpoint: both operands and their sum."""

    a: float  # first operand, as received
    b: float  # second operand, as received
    result: float  # a + b
@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Add two numbers and store the calculation for auditing.

    Args:
        data: The two operands.
        db: Database session; committed by this endpoint.
        current_user: User recorded as the author of the calculation.

    Returns:
        The operands together with their sum.
    """
    total = data.a + data.b
    db.add(
        SumRecord(
            user_id=current_user.id,
            a=data.a,
            b=data.b,
            result=total,
        )
    )
    await db.commit()
    return SumResponse(a=data.a, b=data.b, result=total)
# Audit endpoints
from datetime import datetime
from typing import List
class CounterRecordResponse(BaseModel):
    """One counter audit entry as returned by the audit API."""

    id: int  # id of the audit record
    user_email: str  # email of the user who performed the increment
    value_before: int  # counter value before the increment
    value_after: int  # counter value after the increment
    created_at: datetime  # creation timestamp of the audit record
class SumRecordResponse(BaseModel):
    """One sum audit entry as returned by the audit API."""

    id: int  # id of the audit record
    user_email: str  # email of the user who requested the sum
    a: float  # first operand
    b: float  # second operand
    result: float  # stored result of the calculation
    created_at: datetime  # creation timestamp of the audit record
class PaginatedCounterRecords(BaseModel):
    """One page of counter audit entries plus pagination metadata."""

    records: List[CounterRecordResponse]  # entries on this page
    total: int  # number of counter audit records overall
    page: int  # 1-based index of this page
    per_page: int  # requested page size
    total_pages: int  # number of pages; 1 when there are no records
class PaginatedSumRecords(BaseModel):
    """One page of sum audit entries plus pagination metadata."""

    records: List[SumRecordResponse]  # entries on this page
    total: int  # number of sum audit records overall
    page: int  # 1-based index of this page
    per_page: int  # requested page size
    total_pages: int  # number of pages; 1 when there are no records
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return one page of counter audit records, newest first.

    Args:
        page: 1-based page number (at least 1).
        per_page: Page size, between 1 and 100.
        db: Database session.
        _current_user: Resolved through ``get_current_user``; not used in
            the body.

    Returns:
        The records on the requested page, each with the acting user's
        email, plus the total record count and page count. A page past the
        end yields an empty ``records`` list.
    """
    count_result = await db.execute(select(func.count(CounterRecord.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty table is still reported as one (empty) page.
    total_pages = max(1, (total + per_page - 1) // per_page)

    query = (
        select(CounterRecord, User.email)
        .join(User, CounterRecord.user_id == User.id)
        # created_at alone is not a total order: rows sharing a timestamp
        # could repeat or be skipped across pages. The id tiebreaker makes
        # OFFSET/LIMIT paging deterministic.
        .order_by(desc(CounterRecord.created_at), desc(CounterRecord.id))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    rows = (await db.execute(query)).all()

    records = [
        CounterRecordResponse(
            id=record.id,
            user_email=email,
            value_before=record.value_before,
            value_after=record.value_after,
            created_at=record.created_at,
        )
        for record, email in rows
    ]
    return PaginatedCounterRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return one page of sum audit records, newest first.

    Args:
        page: 1-based page number (at least 1).
        per_page: Page size, between 1 and 100.
        db: Database session.
        _current_user: Resolved through ``get_current_user``; not used in
            the body.

    Returns:
        The records on the requested page, each with the requesting user's
        email, plus the total record count and page count. A page past the
        end yields an empty ``records`` list.
    """
    count_result = await db.execute(select(func.count(SumRecord.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty table is still reported as one (empty) page.
    total_pages = max(1, (total + per_page - 1) // per_page)

    query = (
        select(SumRecord, User.email)
        .join(User, SumRecord.user_id == User.id)
        # created_at alone is not a total order: rows sharing a timestamp
        # could repeat or be skipped across pages. The id tiebreaker makes
        # OFFSET/LIMIT paging deterministic.
        .order_by(desc(SumRecord.created_at), desc(SumRecord.id))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    rows = (await db.execute(query)).all()

    records = [
        SumRecordResponse(
            id=record.id,
            user_email=email,
            a=record.a,
            b=record.b,
            result=record.result,
            created_at=record.created_at,
        )
        for record, email in rows
    ]
    return PaginatedSumRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )