finish branch

This commit is contained in:
counterweight 2025-12-19 00:12:43 +01:00
parent 66bc4c5a45
commit 40ca82bb45
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
11 changed files with 139 additions and 128 deletions

View file

@ -1,6 +1,6 @@
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List
from typing import Any, Callable, Generic, TypeVar
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
@ -9,7 +9,43 @@ from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord, Permission, Role
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR
R = TypeVar("R", bound=BaseModel)
async def paginate_with_user_email(
db: AsyncSession,
model: type[SumRecord] | type[CounterRecord],
page: int,
per_page: int,
row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
"""
Generic pagination helper for audit records that need user email.
Returns: (records, total, total_pages)
"""
# Get total count
count_result = await db.execute(select(func.count(model.id)))
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(model, User.email)
.join(User, model.user_id == User.id)
.order_by(desc(model.created_at))
.offset(offset)
.limit(per_page)
)
result = await db.execute(query)
rows = result.all()
records: list[R] = [row_mapper(record, email) for record, email in rows]
return records, total, total_pages
from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
@ -57,7 +93,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
async def get_default_role(db: AsyncSession) -> Role | None:
"""Get the default 'regular' role for new users."""
result = await db.execute(select(Role).where(Role.name == "regular"))
result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
return result.scalar_one_or_none()
@ -214,20 +250,30 @@ class SumRecordResponse(BaseModel):
created_at: datetime
class PaginatedCounterRecords(BaseModel):
records: List[CounterRecordResponse]
RecordT = TypeVar("RecordT", bound=BaseModel)
class PaginatedResponse(BaseModel, Generic[RecordT]):
"""Generic paginated response wrapper."""
records: list[RecordT]
total: int
page: int
per_page: int
total_pages: int
class PaginatedSumRecords(BaseModel):
records: List[SumRecordResponse]
total: int
page: int
per_page: int
total_pages: int
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
return CounterRecordResponse(
id=record.id,
user_email=email,
value_before=record.value_before,
value_after=record.value_after,
created_at=record.created_at,
)
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
@ -237,34 +283,9 @@ async def get_counter_records(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
# Get total count
count_result = await db.execute(select(func.count(CounterRecord.id)))
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(CounterRecord, User.email)
.join(User, CounterRecord.user_id == User.id)
.order_by(desc(CounterRecord.created_at))
.offset(offset)
.limit(per_page)
records, total, total_pages = await paginate_with_user_email(
db, CounterRecord, page, per_page, _map_counter_record
)
result = await db.execute(query)
rows = result.all()
records = [
CounterRecordResponse(
id=record.id,
user_email=email,
value_before=record.value_before,
value_after=record.value_after,
created_at=record.created_at,
)
for record, email in rows
]
return PaginatedCounterRecords(
records=records,
total=total,
@ -274,6 +295,17 @@ async def get_counter_records(
)
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
return SumRecordResponse(
id=record.id,
user_email=email,
a=record.a,
b=record.b,
result=record.result,
created_at=record.created_at,
)
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
page: int = Query(1, ge=1),
@ -281,35 +313,9 @@ async def get_sum_records(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
# Get total count
count_result = await db.execute(select(func.count(SumRecord.id)))
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(SumRecord, User.email)
.join(User, SumRecord.user_id == User.id)
.order_by(desc(SumRecord.created_at))
.offset(offset)
.limit(per_page)
records, total, total_pages = await paginate_with_user_email(
db, SumRecord, page, per_page, _map_sum_record
)
result = await db.execute(query)
rows = result.all()
records = [
SumRecordResponse(
id=record.id,
user_email=email,
a=record.a,
b=record.b,
result=record.result,
created_at=record.created_at,
)
for record, email in rows
]
return PaginatedSumRecords(
records=records,
total=total,