arbret/backend/main.py

305 lines
7.7 KiB
Python
Raw Normal View History

2025-12-18 21:48:41 +01:00
from contextlib import asynccontextmanager
2025-12-18 22:51:43 +01:00
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
2025-12-18 21:37:28 +01:00
from fastapi.middleware.cors import CORSMiddleware
2025-12-18 22:51:43 +01:00
from sqlalchemy import select, func, desc
2025-12-18 21:48:41 +01:00
from sqlalchemy.ext.asyncio import AsyncSession
2025-12-18 21:37:28 +01:00
2025-12-18 22:51:43 +01:00
from pydantic import BaseModel
2025-12-18 21:48:41 +01:00
from database import engine, get_db, Base
2025-12-18 22:51:43 +01:00
from models import Counter, User, SumRecord, CounterRecord
2025-12-18 22:08:31 +01:00
from auth import (
2025-12-18 22:24:46 +01:00
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
2025-12-18 22:08:31 +01:00
UserCreate,
UserLogin,
UserResponse,
get_password_hash,
get_user_by_email,
authenticate_user,
create_access_token,
get_current_user,
)
2025-12-18 21:48:41 +01:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: make sure the ORM tables exist, then serve.

    ``Base.metadata.create_all`` only creates tables that are missing; it does
    not alter or migrate tables that already exist.
    """
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield
# Application instance; `lifespan` runs table creation on startup.
app = FastAPI(lifespan=lifespan)

# Only http://localhost:3000 may call this API cross-origin. Credentials are
# allowed because authentication uses a cookie (see set_auth_cookie), which
# browsers only send cross-origin when the server opts in.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
2025-12-18 22:24:46 +01:00
def set_auth_cookie(response: Response, token: str) -> None:
    """Store *token* on *response* as an httponly, SameSite=Lax cookie.

    The cookie's max-age is ACCESS_TOKEN_EXPIRE_MINUTES expressed in seconds.
    """
    lifetime_seconds = ACCESS_TOKEN_EXPIRE_MINUTES * 60
    response.set_cookie(
        COOKIE_NAME,
        token,
        max_age=lifetime_seconds,
        httponly=True,
        samesite="lax",
        # Must become True in production (HTTPS) so the cookie never travels
        # over plain HTTP.
        secure=False,
    )
2025-12-18 22:08:31 +01:00
# Auth endpoints
2025-12-18 22:24:46 +01:00
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create an account, log the new user in, and return their profile.

    On success the auth cookie is set on *response*.

    Raises:
        HTTPException(400): the email is already registered.
    """
    # Local import keeps this fix self-contained; only this endpoint needs it.
    from sqlalchemy.exc import IntegrityError

    existing_user = await get_user_by_email(db, user_data.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )
    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError:
        # A concurrent request can insert the same email between the lookup
        # above and this commit. Assuming User.email has a unique constraint
        # (NOTE(review): confirm in models), the database rejects the duplicate.
        # Report it the same way as the pre-check rather than failing with a 500.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        ) from None
    await db.refresh(user)
    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return UserResponse(id=user.id, email=user.email)
2025-12-18 22:08:31 +01:00
2025-12-18 22:24:46 +01:00
@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Check the submitted credentials and, if valid, set the auth cookie.

    Returns the authenticated user's profile.

    Raises:
        HTTPException(401): the email/password pair was rejected.
    """
    authenticated = await authenticate_user(db, user_data.email, user_data.password)
    if authenticated:
        token = create_access_token(data={"sub": str(authenticated.id)})
        set_auth_cookie(response, token)
        return UserResponse(id=authenticated.id, email=authenticated.email)
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect email or password",
    )
@app.post("/api/auth/logout")
async def logout(response: Response):
    """Remove the auth cookie from the client; succeeds even with no session."""
    response.delete_cookie(COOKIE_NAME)
    return dict(ok=True)
2025-12-18 22:08:31 +01:00
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)):
    """Return the profile of the user resolved by ``get_current_user``."""
    profile = UserResponse(id=current_user.id, email=current_user.email)
    return profile
# Counter endpoints
2025-12-18 21:48:41 +01:00
async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Fetch the single counter row (id 1), creating it with value 0 if absent.

    Creating the row commits the session and then reloads the new object.
    NOTE(review): two concurrent first-time calls could both try the insert.
    """
    lookup = await db.execute(select(Counter).where(Counter.id == 1))
    existing = lookup.scalar_one_or_none()
    if existing:
        return existing
    fresh = Counter(id=1, value=0)
    db.add(fresh)
    await db.commit()
    await db.refresh(fresh)
    return fresh
@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return ``{"value": <counter>}``; the caller must be authenticated."""
    current_value = (await get_or_create_counter(db)).value
    return {"value": current_value}
@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Add one to the counter and write a CounterRecord audit entry.

    Returns:
        ``{"value": <counter value after the increment>}``.
    """
    counter = await get_or_create_counter(db)
    # Let the database perform the increment (UPDATE ... SET value = value + 1)
    # instead of a read-modify-write in Python. With the Python version, two
    # concurrent requests could both read N and both write N + 1, losing an
    # increment.
    counter.value = Counter.value + 1
    await db.flush()
    # After a SQL-expression assignment is flushed, the attribute is expired.
    # Reload it explicitly, because AsyncSession cannot lazy-load implicitly.
    await db.refresh(counter)
    value_after = counter.value
    db.add(
        CounterRecord(
            user_id=current_user.id,
            value_before=value_after - 1,
            value_after=value_after,
        )
    )
    await db.commit()
    # Return the captured value rather than reading counter.value after the
    # commit, which would require a reload if the session expires on commit.
    return {"value": value_after}
2025-12-18 21:37:28 +01:00
2025-12-18 22:51:43 +01:00
# Sum endpoints
class SumRequest(BaseModel):
    # Request body for POST /api/sum: the two operands to add.
    a: float
    b: float
class SumResponse(BaseModel):
    # Response body for POST /api/sum: echoes the operands alongside a + b.
    a: float
    b: float
    result: float
@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Add ``data.a`` and ``data.b``, log the calculation, and echo it back.

    Raises:
        HTTPException(400): the sum is not a finite number.
    """
    import math

    result = data.a + data.b
    # Finite operands can still overflow (1e308 + 1e308 == inf), and NaN/inf
    # operands propagate into the result. FastAPI's default JSON response cannot
    # encode such values (json.dumps with allow_nan=False). Without this check,
    # the record would be committed and the client would then get a 500.
    if not math.isfinite(result):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Sum is not a finite number",
        )
    record = SumRecord(
        user_id=current_user.id,
        a=data.a,
        b=data.b,
        result=result,
    )
    db.add(record)
    await db.commit()
    return SumResponse(a=data.a, b=data.b, result=result)
# Audit endpoints
from datetime import datetime
from typing import List
class CounterRecordResponse(BaseModel):
    # One counter audit entry as returned by GET /api/audit/counter.
    id: int
    user_email: str  # email of the user who made the increment (joined from User)
    value_before: int  # counter value before the increment
    value_after: int  # counter value after the increment
    created_at: datetime  # copied from CounterRecord.created_at
class SumRecordResponse(BaseModel):
    # One logged sum calculation as returned by GET /api/audit/sum.
    id: int
    user_email: str  # email of the user who requested the sum (joined from User)
    a: float  # first operand
    b: float  # second operand
    result: float  # stored a + b
    created_at: datetime  # copied from SumRecord.created_at
class PaginatedCounterRecords(BaseModel):
    # One page of counter audit entries plus pagination metadata.
    records: List[CounterRecordResponse]  # newest first
    total: int  # number of counter records across all pages
    page: int  # 1-based page number that was requested
    per_page: int  # requested page size (validated to 1..100 by the endpoint)
    total_pages: int  # ceil(total / per_page); reported as 1 when total is 0
class PaginatedSumRecords(BaseModel):
    # One page of sum audit entries plus pagination metadata.
    records: List[SumRecordResponse]  # newest first
    total: int  # number of sum records across all pages
    page: int  # 1-based page number that was requested
    per_page: int  # requested page size (validated to 1..100 by the endpoint)
    total_pages: int  # ceil(total / per_page); reported as 1 when total is 0
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return one page of counter audit entries, newest first.

    ``page`` is 1-based. A page past the end returns an empty ``records`` list.
    ``total_pages`` is 1 when there are no records at all.
    """
    count_result = await db.execute(select(func.count(CounterRecord.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty log still reports one (empty) page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = (
        select(CounterRecord, User.email)
        .join(User, CounterRecord.user_id == User.id)
        # Nothing here makes created_at unique, and SQL leaves the order of
        # ties unspecified. Without the id tiebreaker, rows sharing a timestamp
        # could repeat or disappear between OFFSET pages.
        .order_by(desc(CounterRecord.created_at), desc(CounterRecord.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    records = [
        CounterRecordResponse(
            id=record.id,
            user_email=email,
            value_before=record.value_before,
            value_after=record.value_after,
            created_at=record.created_at,
        )
        for record, email in result.all()
    ]
    return PaginatedCounterRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(get_current_user),
):
    """Return one page of logged sum calculations, newest first.

    ``page`` is 1-based. A page past the end returns an empty ``records`` list.
    ``total_pages`` is 1 when there are no records at all.
    """
    count_result = await db.execute(select(func.count(SumRecord.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty log still reports one (empty) page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = (
        select(SumRecord, User.email)
        .join(User, SumRecord.user_id == User.id)
        # Nothing here makes created_at unique, and SQL leaves the order of
        # ties unspecified. Without the id tiebreaker, rows sharing a timestamp
        # could repeat or disappear between OFFSET pages.
        .order_by(desc(SumRecord.created_at), desc(SumRecord.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    records = [
        SumRecordResponse(
            id=record.id,
            user_email=email,
            a=record.a,
            b=record.b,
            result=record.result,
            created_at=record.created_at,
        )
        for record, email in result.all()
    ]
    return PaginatedSumRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )