tests passing
This commit is contained in:
parent
322bdd3e6e
commit
b173b47925
18 changed files with 1414 additions and 93 deletions
|
|
@ -1,12 +1,15 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from pydantic import BaseModel
|
||||
from database import engine, get_db, Base
|
||||
from models import Counter, User, SumRecord, CounterRecord
|
||||
from models import Counter, User, SumRecord, CounterRecord, Permission, Role
|
||||
from auth import (
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
COOKIE_NAME,
|
||||
|
|
@ -18,6 +21,8 @@ from auth import (
|
|||
authenticate_user,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
require_permission,
|
||||
build_user_response,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -50,6 +55,12 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
|||
)
|
||||
|
||||
|
||||
async def get_default_role(db: AsyncSession) -> Role | None:
|
||||
"""Get the default 'regular' role for new users."""
|
||||
result = await db.execute(select(Role).where(Role.name == "regular"))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
# Auth endpoints
|
||||
@app.post("/api/auth/register", response_model=UserResponse)
|
||||
async def register(
|
||||
|
|
@ -68,13 +79,19 @@ async def register(
|
|||
email=user_data.email,
|
||||
hashed_password=get_password_hash(user_data.password),
|
||||
)
|
||||
|
||||
# Assign default role if it exists
|
||||
default_role = await get_default_role(db)
|
||||
if default_role:
|
||||
user.roles.append(default_role)
|
||||
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
set_auth_cookie(response, access_token)
|
||||
return UserResponse(id=user.id, email=user.email)
|
||||
return await build_user_response(user, db)
|
||||
|
||||
|
||||
@app.post("/api/auth/login", response_model=UserResponse)
|
||||
|
|
@ -92,7 +109,7 @@ async def login(
|
|||
|
||||
access_token = create_access_token(data={"sub": str(user.id)})
|
||||
set_auth_cookie(response, access_token)
|
||||
return UserResponse(id=user.id, email=user.email)
|
||||
return await build_user_response(user, db)
|
||||
|
||||
|
||||
@app.post("/api/auth/logout")
|
||||
|
|
@ -102,8 +119,11 @@ async def logout(response: Response):
|
|||
|
||||
|
||||
@app.get("/api/auth/me", response_model=UserResponse)
|
||||
async def get_me(current_user: User = Depends(get_current_user)):
|
||||
return UserResponse(id=current_user.id, email=current_user.email)
|
||||
async def get_me(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
return await build_user_response(current_user, db)
|
||||
|
||||
|
||||
# Counter endpoints
|
||||
|
|
@ -121,7 +141,7 @@ async def get_or_create_counter(db: AsyncSession) -> Counter:
|
|||
@app.get("/api/counter")
|
||||
async def get_counter(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(get_current_user),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
|
||||
):
|
||||
counter = await get_or_create_counter(db)
|
||||
return {"value": counter.value}
|
||||
|
|
@ -130,7 +150,7 @@ async def get_counter(
|
|||
@app.post("/api/counter/increment")
|
||||
async def increment_counter(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
|
||||
):
|
||||
counter = await get_or_create_counter(db)
|
||||
value_before = counter.value
|
||||
|
|
@ -162,7 +182,7 @@ class SumResponse(BaseModel):
|
|||
async def calculate_sum(
|
||||
data: SumRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Depends(require_permission(Permission.USE_SUM)),
|
||||
):
|
||||
result = data.a + data.b
|
||||
record = SumRecord(
|
||||
|
|
@ -177,10 +197,6 @@ async def calculate_sum(
|
|||
|
||||
|
||||
# Audit endpoints
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
|
||||
class CounterRecordResponse(BaseModel):
|
||||
id: int
|
||||
user_email: str
|
||||
|
|
@ -219,7 +235,7 @@ async def get_counter_records(
|
|||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(get_current_user),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||
):
|
||||
# Get total count
|
||||
count_result = await db.execute(select(func.count(CounterRecord.id)))
|
||||
|
|
@ -263,7 +279,7 @@ async def get_sum_records(
|
|||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(get_current_user),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||
):
|
||||
# Get total count
|
||||
count_result = await db.execute(select(func.count(SumRecord.id)))
|
||||
|
|
@ -301,4 +317,3 @@ async def get_sum_records(
|
|||
per_page=per_page,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue