2025-12-18 21:48:41 +01:00
|
|
|
from contextlib import asynccontextmanager
|
2025-12-18 23:33:32 +01:00
|
|
|
from datetime import datetime
|
2025-12-19 11:08:19 +01:00
|
|
|
from typing import Callable, Generic, TypeVar
|
2025-12-18 23:33:32 +01:00
|
|
|
|
2025-12-18 22:51:43 +01:00
|
|
|
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
|
2025-12-18 21:37:28 +01:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-12-18 23:33:32 +01:00
|
|
|
from pydantic import BaseModel
|
2025-12-18 22:51:43 +01:00
|
|
|
from sqlalchemy import select, func, desc
|
2025-12-18 21:48:41 +01:00
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2025-12-18 21:37:28 +01:00
|
|
|
|
2025-12-19 11:08:19 +01:00
|
|
|
from auth import (
|
|
|
|
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
|
|
|
|
COOKIE_NAME,
|
|
|
|
|
UserCreate,
|
|
|
|
|
UserLogin,
|
|
|
|
|
UserResponse,
|
|
|
|
|
get_password_hash,
|
|
|
|
|
get_user_by_email,
|
|
|
|
|
authenticate_user,
|
|
|
|
|
create_access_token,
|
|
|
|
|
get_current_user,
|
|
|
|
|
require_permission,
|
|
|
|
|
build_user_response,
|
|
|
|
|
)
|
2025-12-18 21:48:41 +01:00
|
|
|
from database import engine, get_db, Base
|
2025-12-19 00:12:43 +01:00
|
|
|
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR
|
2025-12-19 10:12:55 +01:00
|
|
|
from validation import validate_profile_fields
|
2025-12-19 00:12:43 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
R = TypeVar("R", bound=BaseModel)


async def paginate_with_user_email(
    db: AsyncSession,
    model: type[SumRecord] | type[CounterRecord],
    page: int,
    per_page: int,
    row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
    """
    Generic pagination helper for audit records that need user email.

    Records are returned newest first (by ``created_at``, ties broken by ``id``).

    Args:
        db: Session used for both the count query and the page query.
        model: Audit record ORM class; it must expose ``id``, ``user_id`` and
            ``created_at`` (all referenced below).
        page: 1-based page number.
        per_page: Number of records per page.
        row_mapper: Called as ``row_mapper(record, email)`` for every row to
            build the response item.

    Returns: (records, total, total_pages)
        ``total_pages`` is at least 1, so an empty table still reports one page.
    """
    # Count over the same inner join as the page query, so ``total`` only
    # includes records that can actually appear on some page.
    count_query = (
        select(func.count(model.id))
        .select_from(model)
        .join(User, model.user_id == User.id)
    )
    count_result = await db.execute(count_query)
    total = count_result.scalar() or 0
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    # Get paginated records with user email
    offset = (page - 1) * per_page
    query = (
        select(model, User.email)
        .join(User, model.user_id == User.id)
        # ``created_at`` alone is not unique; ``id`` makes the order total so
        # OFFSET pagination never repeats or skips rows across pages.
        .order_by(desc(model.created_at), desc(model.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    rows = result.all()

    records: list[R] = [row_mapper(record, email) for record, email in rows]
    return records, total, total_pages
|
2025-12-18 21:48:41 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: create all ORM tables before serving requests.

    ``Base.metadata.create_all`` runs inside a single ``engine.begin()``
    transaction at startup; nothing runs after the ``yield`` (shutdown).
    """
    async with engine.begin() as connection:
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)
    yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)

# Authentication is cookie-based (see set_auth_cookie), so cross-origin calls
# from the browser client must be allowed to carry credentials.
# NOTE(review): the single origin is presumably the frontend dev server — confirm
# and make configurable before deploying elsewhere.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["http://localhost:3000"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 22:24:46 +01:00
|
|
|
def set_auth_cookie(response: Response, token: str) -> None:
    """Store the access token on *response* as an HttpOnly, SameSite=lax cookie.

    The cookie's ``max_age`` is ACCESS_TOKEN_EXPIRE_MINUTES converted to seconds.
    """
    cookie_options = {
        "httponly": True,
        "secure": False,  # Set to True in production with HTTPS
        "samesite": "lax",
        "max_age": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    }
    response.set_cookie(key=COOKIE_NAME, value=token, **cookie_options)
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 23:33:32 +01:00
|
|
|
async def get_default_role(db: AsyncSession) -> Role | None:
    """Return the Role named ROLE_REGULAR (given to new users), or None if it doesn't exist."""
    query = select(Role).where(Role.name == ROLE_REGULAR)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 22:08:31 +01:00
|
|
|
# Auth endpoints
|
2025-12-18 22:24:46 +01:00
|
|
|
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: UserCreate,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create an account, give it the default role, and log the new user in.

    On success the auth cookie is set on *response* and the new user's details
    are returned.

    Raises:
        HTTPException 400: the email is already registered. This covers both the
            up-front lookup and a rejection by the database at commit time, when a
            concurrent request registered the same email in between.
    """
    # Local import keeps this endpoint self-contained.
    from sqlalchemy.exc import IntegrityError

    existing_user = await get_user_by_email(db, user_data.email)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        )

    user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
    )

    # Assign default role if it exists
    default_role = await get_default_role(db)
    if default_role:
        user.roles.append(default_role)

    db.add(user)
    try:
        await db.commit()
    except IntegrityError as exc:
        # The lookup above and this INSERT are not atomic, so a parallel
        # registration can win the race; report it like the normal duplicate case
        # instead of letting it escape as a 500.
        # NOTE(review): assumes the unique constraint on User.email is the only
        # constraint that can fail here — confirm in models.
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email already registered",
        ) from exc
    await db.refresh(user)

    access_token = create_access_token(data={"sub": str(user.id)})
    set_auth_cookie(response, access_token)
    return await build_user_response(user, db)
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
2025-12-18 22:24:46 +01:00
|
|
|
@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Check the submitted credentials, set the auth cookie, and return the user.

    Raises:
        HTTPException 401: authenticate_user returned a falsy value.
    """
    authenticated = await authenticate_user(db, user_data.email, user_data.password)
    if not authenticated:
        raise HTTPException(
            detail="Incorrect email or password",
            status_code=status.HTTP_401_UNAUTHORIZED,
        )

    token = create_access_token(data={"sub": str(authenticated.id)})
    set_auth_cookie(response, token)
    user_response = await build_user_response(authenticated, db)
    return user_response
|
2025-12-18 22:24:46 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/auth/logout")
async def logout(response: Response):
    """Log the caller out by expiring the auth cookie; always reports success."""
    response.delete_cookie(COOKIE_NAME)
    return dict(ok=True)
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the authenticated caller's own user details."""
    me = await build_user_response(current_user, db)
    return me
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Counter endpoints
|
2025-12-18 21:48:41 +01:00
|
|
|
async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Fetch the singleton Counter row (id 1), creating it with value 0 if missing.

    A newly created row is committed and refreshed before being returned.
    """
    lookup = await db.execute(select(Counter).where(Counter.id == 1))
    existing = lookup.scalar_one_or_none()
    if existing:
        return existing

    created = Counter(id=1, value=0)
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
):
    """Return ``{"value": <counter>}``; requires the VIEW_COUNTER permission."""
    current_value = (await get_or_create_counter(db)).value
    return {"value": current_value}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
):
    """Increment the shared counter by one and record a CounterRecord audit entry.

    Requires the INCREMENT_COUNTER permission. Returns ``{"value": <new value>}``.
    """
    counter = await get_or_create_counter(db)

    # Let the database do the increment (UPDATE ... SET value = value + 1).
    # A read-modify-write in Python lets two concurrent requests both read N
    # and both write N + 1, silently losing one increment.
    counter.value = Counter.value + 1
    await db.flush()
    # After flushing a SQL expression the attribute is expired; reload it
    # explicitly because AsyncSession cannot lazy-load on attribute access.
    await db.refresh(counter, attribute_names=["value"])
    value_after = counter.value

    record = CounterRecord(
        user_id=current_user.id,
        # Our UPDATE holds the write lock until commit (row lock in PostgreSQL,
        # database lock in SQLite), so the pre-increment value is exactly one less.
        value_before=value_after - 1,
        value_after=value_after,
    )
    db.add(record)
    await db.commit()

    # Use the captured value: reading ``counter`` after commit could require a
    # reload if the session expires objects on commit.
    return {"value": value_after}
|
2025-12-18 21:37:28 +01:00
|
|
|
|
2025-12-18 22:51:43 +01:00
|
|
|
|
|
|
|
|
# Sum endpoints
|
|
|
|
|
class SumRequest(BaseModel):
    """Request body for POST /api/sum: the two operands to add."""

    a: float  # first operand
    b: float  # second operand
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SumResponse(BaseModel):
    """Response body for POST /api/sum: the operands echoed back plus their sum."""

    a: float  # first operand, as received
    b: float  # second operand, as received
    result: float  # a + b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.USE_SUM)),
):
    """Add ``data.a`` and ``data.b``, store a SumRecord audit entry, and return the sum.

    Requires the USE_SUM permission.

    Raises:
        HTTPException 422: an operand is NaN/Infinity or the sum overflows to
            infinity. Such values cannot be rendered by Starlette's JSONResponse
            (``allow_nan=False``), so they are rejected before anything is stored.
    """
    import math

    result = data.a + data.b
    # ``result`` is finite exactly when both operands are finite and the
    # addition did not overflow, so this one check covers every bad case.
    if not math.isfinite(result):
        raise HTTPException(
            status_code=422,
            detail="Operands and their sum must be finite numbers",
        )

    record = SumRecord(
        user_id=current_user.id,
        a=data.a,
        b=data.b,
        result=result,
    )
    db.add(record)
    await db.commit()
    return SumResponse(a=data.a, b=data.b, result=result)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Audit endpoints
|
|
|
|
|
class CounterRecordResponse(BaseModel):
    """One counter audit entry as returned by GET /api/audit/counter."""

    id: int  # CounterRecord primary key
    user_email: str  # email of the user who performed the increment
    value_before: int  # counter value before the increment
    value_after: int  # counter value after the increment
    created_at: datetime  # when the record was created
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SumRecordResponse(BaseModel):
    """One sum audit entry as returned by GET /api/audit/sum."""

    id: int  # SumRecord primary key
    user_email: str  # email of the user who requested the sum
    a: float  # first operand
    b: float  # second operand
    result: float  # stored sum of a and b
    created_at: datetime  # when the record was created
|
|
|
|
|
|
|
|
|
|
|
2025-12-19 00:12:43 +01:00
|
|
|
# Item type carried by a PaginatedResponse page.
RecordT = TypeVar("RecordT", bound=BaseModel)


class PaginatedResponse(BaseModel, Generic[RecordT]):
    """Generic paginated response wrapper."""

    records: list[RecordT]  # items on the requested page
    total: int  # number of records across all pages
    page: int  # 1-based page number that was requested
    per_page: int  # page size that was requested
    total_pages: int  # page count; the audit endpoints report at least 1


# Concrete response models used by the audit endpoints.
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
    """Build the API view of one counter audit record, tagged with its user's email."""
    fields = {
        "id": record.id,
        "user_email": email,
        "value_before": record.value_before,
        "value_after": record.value_after,
        "created_at": record.created_at,
    }
    return CounterRecordResponse(**fields)
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 22:51:43 +01:00
|
|
|
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of counter audit records, newest first (needs VIEW_AUDIT)."""
    page_items, total, total_pages = await paginate_with_user_email(
        db,
        CounterRecord,
        page,
        per_page,
        _map_counter_record,
    )
    return PaginatedCounterRecords(
        page=page,
        per_page=per_page,
        total=total,
        total_pages=total_pages,
        records=page_items,
    )
|
|
|
|
|
|
|
|
|
|
|
2025-12-19 00:12:43 +01:00
|
|
|
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
    """Build the API view of one sum audit record, tagged with its user's email."""
    fields = {
        "id": record.id,
        "user_email": email,
        "a": record.a,
        "b": record.b,
        "result": record.result,
        "created_at": record.created_at,
    }
    return SumRecordResponse(**fields)
|
|
|
|
|
|
|
|
|
|
|
2025-12-18 22:51:43 +01:00
|
|
|
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of sum audit records, newest first (needs VIEW_AUDIT)."""
    page_items, total, total_pages = await paginate_with_user_email(
        db,
        SumRecord,
        page,
        per_page,
        _map_sum_record,
    )
    return PaginatedSumRecords(
        page=page,
        per_page=per_page,
        total=total,
        total_pages=total_pages,
        records=page_items,
    )
|
2025-12-19 10:12:55 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Profile endpoints
|
|
|
|
|
class ProfileResponse(BaseModel):
    """Response model for profile data.

    Each field mirrors the same-named contact attribute on User and may be None.
    """

    contact_email: str | None
    telegram: str | None
    signal: str | None
    nostr_npub: str | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ProfileUpdate(BaseModel):
    """Request model for updating profile.

    Every field defaults to None, and update_profile writes all four fields,
    so a field omitted from the request is stored as None.
    """

    contact_email: str | None = None
    telegram: str | None = None
    signal: str | None = None
    nostr_npub: str | None = None
|
|
|
|
|
|
|
|
|
|
|
2025-12-19 10:38:15 +01:00
|
|
|
async def require_regular_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """FastAPI dependency that only lets through users holding the ROLE_REGULAR role.

    Raises:
        HTTPException 403: ROLE_REGULAR is not among ``current_user.role_names``.
    """
    if ROLE_REGULAR in current_user.role_names:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Profile access is only available to regular users",
    )
|
2025-12-19 10:12:55 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/profile", response_model=ProfileResponse)
async def get_profile(
    current_user: User = Depends(require_regular_user),
):
    """Return the caller's contact details; restricted to regular users."""
    contact_fields = ("contact_email", "telegram", "signal", "nostr_npub")
    return ProfileResponse(
        **{name: getattr(current_user, name) for name in contact_fields}
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.put("/api/profile", response_model=ProfileResponse)
async def update_profile(
    data: ProfileUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_regular_user),
):
    """Validate and save the caller's contact details, then return the stored values.

    All four fields are overwritten with the submitted values (None included).

    Raises:
        HTTPException 422: validate_profile_fields reported errors; the detail is
            ``{"field_errors": <errors>}`` and nothing is written.
    """
    contact_fields = ("contact_email", "telegram", "signal", "nostr_npub")
    submitted = {name: getattr(data, name) for name in contact_fields}

    field_errors = validate_profile_fields(**submitted)
    if field_errors:
        raise HTTPException(status_code=422, detail={"field_errors": field_errors})

    for name, value in submitted.items():
        setattr(current_user, name, value)
    await db.commit()
    await db.refresh(current_user)

    return ProfileResponse(
        **{name: getattr(current_user, name) for name in contact_fields}
    )
|