tests passing

This commit is contained in:
counterweight 2025-12-18 23:33:32 +01:00
parent 322bdd3e6e
commit b173b47925
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
18 changed files with 1414 additions and 93 deletions

View file

@ -1,12 +1,15 @@
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord
from models import Counter, User, SumRecord, CounterRecord, Permission, Role
from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
@ -18,6 +21,8 @@ from auth import (
authenticate_user,
create_access_token,
get_current_user,
require_permission,
build_user_response,
)
@ -50,6 +55,12 @@ def set_auth_cookie(response: Response, token: str) -> None:
)
async def get_default_role(db: AsyncSession) -> Role | None:
"""Get the default 'regular' role for new users."""
result = await db.execute(select(Role).where(Role.name == "regular"))
return result.scalar_one_or_none()
# Auth endpoints
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
@ -68,13 +79,19 @@ async def register(
email=user_data.email,
hashed_password=get_password_hash(user_data.password),
)
# Assign default role if it exists
default_role = await get_default_role(db)
if default_role:
user.roles.append(default_role)
db.add(user)
await db.commit()
await db.refresh(user)
access_token = create_access_token(data={"sub": str(user.id)})
set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email)
return await build_user_response(user, db)
@app.post("/api/auth/login", response_model=UserResponse)
@ -92,7 +109,7 @@ async def login(
access_token = create_access_token(data={"sub": str(user.id)})
set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email)
return await build_user_response(user, db)
@app.post("/api/auth/logout")
@ -102,8 +119,11 @@ async def logout(response: Response):
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)):
return UserResponse(id=current_user.id, email=current_user.email)
async def get_me(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await build_user_response(current_user, db)
# Counter endpoints
@ -121,7 +141,7 @@ async def get_or_create_counter(db: AsyncSession) -> Counter:
@app.get("/api/counter")
async def get_counter(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user),
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
):
counter = await get_or_create_counter(db)
return {"value": counter.value}
@ -130,7 +150,7 @@ async def get_counter(
@app.post("/api/counter/increment")
async def increment_counter(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
):
counter = await get_or_create_counter(db)
value_before = counter.value
@ -162,7 +182,7 @@ class SumResponse(BaseModel):
async def calculate_sum(
data: SumRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
current_user: User = Depends(require_permission(Permission.USE_SUM)),
):
result = data.a + data.b
record = SumRecord(
@ -177,10 +197,6 @@ async def calculate_sum(
# Audit endpoints
from datetime import datetime
from typing import List
class CounterRecordResponse(BaseModel):
id: int
user_email: str
@ -219,7 +235,7 @@ async def get_counter_records(
page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
# Get total count
count_result = await db.execute(select(func.count(CounterRecord.id)))
@ -263,7 +279,7 @@ async def get_sum_records(
page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
# Get total count
count_result = await db.execute(select(func.count(SumRecord.id)))
@ -301,4 +317,3 @@ async def get_sum_records(
per_page=per_page,
total_pages=total_pages,
)