arbret/backend/main.py

766 lines
22 KiB
Python

from contextlib import asynccontextmanager
from datetime import datetime, UTC
from typing import Callable, Generic, TypeVar
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession
from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
UserCreate,
UserLogin,
UserResponse,
get_password_hash,
get_user_by_email,
authenticate_user,
create_access_token,
get_current_user,
require_permission,
build_user_response,
)
from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR, Invite, InviteStatus
from validation import validate_profile_fields
from invite_utils import generate_invite_identifier, normalize_identifier
R = TypeVar("R", bound=BaseModel)


async def paginate_with_user_email(
    db: AsyncSession,
    model: type[SumRecord] | type[CounterRecord],
    page: int,
    per_page: int,
    row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
    """
    Fetch one page of audit records together with each record's user email.

    Args:
        db: Async session used for both the count query and the page query.
        model: Audit model to page through; its ``id``, ``user_id`` and
            ``created_at`` columns are used.
        page: 1-based page number (page 1 starts at offset 0).
        per_page: Maximum number of rows on a page.
        row_mapper: Called as ``row_mapper(record, email)`` for every row.

    Returns: (records, total, total_pages). ``total_pages`` is 1 when there
    are no records at all.
    """
    # Total number of rows of this model (not restricted to the current page).
    count_result = await db.execute(select(func.count(model.id)))
    total = count_result.scalar() or 0
    # Ceiling division; an empty table is still reported as a single page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    query = (
        select(model, User.email)
        .join(User, model.user_id == User.id)
        # created_at alone is not a total order: rows sharing a timestamp could
        # be repeated or skipped between pages. The id tie-break makes the
        # ordering deterministic.
        .order_by(desc(model.created_at), desc(model.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    records: list[R] = [row_mapper(record, email) for record, email in result.all()]
    return records, total, total_pages
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create every table registered on ``Base.metadata`` before the app runs."""
    # NOTE(review): the original indentation was lost; the yield is placed
    # after the ``async with`` block so the connection is released before the
    # application starts serving - confirm against the repository.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield
# Application instance; ``lifespan`` is passed in as the lifespan handler.
app = FastAPI(lifespan=lifespan)

# CORS: a single allowed origin (presumably the local frontend dev server),
# with credentials enabled and every method and header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def set_auth_cookie(response: Response, token: str, *, secure: bool = False) -> None:
    """Store the access token on *response* as an HTTP-only cookie.

    Args:
        response: Response object the cookie is set on.
        token: Encoded access token used as the cookie value.
        secure: Value for the cookie's ``Secure`` attribute. Defaults to
            False (the previous hard-coded value); pass True in production
            with HTTPS.
    """
    response.set_cookie(
        key=COOKIE_NAME,
        value=token,
        httponly=True,
        secure=secure,
        samesite="lax",
        # max_age is expressed in seconds, the configured lifetime in minutes.
        max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )
async def get_default_role(db: AsyncSession) -> Role | None:
    """Look up the role named ``ROLE_REGULAR``; returns None if it does not exist."""
    role_query = select(Role).where(Role.name == ROLE_REGULAR)
    return (await db.execute(role_query)).scalar_one_or_none()
# Invite check endpoint (public)
class InviteCheckResponse(BaseModel):
    """Payload returned by the invite check endpoint."""

    # Whether the invite can be used for signup.
    valid: bool
    # Status value of the invite; stays None when no invite was found.
    status: str | None = None
    # Explanation of why the invite is not usable; None for a valid invite.
    error: str | None = None
@app.get("/api/invites/{identifier}/check", response_model=InviteCheckResponse)
async def check_invite(
    identifier: str,
    db: AsyncSession = Depends(get_db),
):
    """Report whether the invite with this identifier can be used for signup.

    The identifier is normalized before the lookup. A missing, spent or
    revoked invite yields ``valid=False`` with an error message; any other
    status is reported as valid.
    """
    lookup = select(Invite).where(Invite.identifier == normalize_identifier(identifier))
    invite = (await db.execute(lookup)).scalar_one_or_none()
    if not invite:
        return InviteCheckResponse(valid=False, error="Invite not found")

    # Statuses that make an invite unusable, with the message shown to the client.
    rejection_messages = {
        InviteStatus.SPENT: "This invite has already been used",
        InviteStatus.REVOKED: "This invite has been revoked",
    }
    message = rejection_messages.get(invite.status)
    if message is not None:
        return InviteCheckResponse(
            valid=False,
            status=invite.status.value,
            error=message,
        )
    return InviteCheckResponse(valid=True, status=invite.status.value)
# Auth endpoints
class RegisterWithInvite(BaseModel):
    """Body of a registration request that redeems an invite."""

    # Email address for the new account.
    email: EmailStr
    # Plain-text password; it is hashed before being stored.
    password: str
    # Invite identifier as typed by the user; normalized before lookup.
    invite_identifier: str
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
    user_data: RegisterWithInvite,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Create an account by redeeming an invite and log the new user in.

    Raises:
        HTTPException: 400 when the invite is unknown, spent or revoked, or
            when the email address is already registered.
    """
    # NOTE(review): the invite is read and later marked spent without a row
    # lock; whether concurrent registrations with one invite are prevented
    # elsewhere cannot be seen from here - confirm.

    def bad_request(detail: str) -> HTTPException:
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=detail,
        )

    # Resolve and validate the invite.
    invite_code = normalize_identifier(user_data.invite_identifier)
    invite = (
        await db.execute(select(Invite).where(Invite.identifier == invite_code))
    ).scalar_one_or_none()
    if not invite:
        raise bad_request("Invalid invite code")
    if invite.status == InviteStatus.SPENT:
        raise bad_request("This invite has already been used")
    if invite.status == InviteStatus.REVOKED:
        raise bad_request("This invite has been revoked")

    # The email address must not belong to an existing account.
    if await get_user_by_email(db, user_data.email):
        raise bad_request("Email already registered")

    # Build the account; the invite's godfather becomes the user's godfather.
    new_user = User(
        email=user_data.email,
        hashed_password=get_password_hash(user_data.password),
        godfather_id=invite.godfather_id,
    )
    regular_role = await get_default_role(db)
    if regular_role:
        new_user.roles.append(regular_role)
    db.add(new_user)
    await db.flush()  # populates new_user.id for the invite update below

    # Consume the invite in the same transaction as the user creation.
    invite.status = InviteStatus.SPENT
    invite.used_by_id = new_user.id
    invite.spent_at = datetime.now(UTC)
    await db.commit()
    await db.refresh(new_user)

    # Log the new user in right away.
    token = create_access_token(data={"sub": str(new_user.id)})
    set_auth_cookie(response, token)
    return await build_user_response(new_user, db)
@app.post("/api/auth/login", response_model=UserResponse)
async def login(
    user_data: UserLogin,
    response: Response,
    db: AsyncSession = Depends(get_db),
):
    """Authenticate by email and password and set the auth cookie.

    Raises:
        HTTPException: 401 when ``authenticate_user`` returns a falsy result.
    """
    account = await authenticate_user(db, user_data.email, user_data.password)
    if account:
        token = create_access_token(data={"sub": str(account.id)})
        set_auth_cookie(response, token)
        return await build_user_response(account, db)
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect email or password",
    )
@app.post("/api/auth/logout")
async def logout(response: Response):
    """Delete the auth cookie and acknowledge with ``{"ok": True}``."""
    acknowledgement = {"ok": True}
    response.delete_cookie(key=COOKIE_NAME)
    return acknowledgement
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the response model built for the currently authenticated user."""
    me = await build_user_response(current_user, db)
    return me
# Counter endpoints
async def get_or_create_counter(db: AsyncSession) -> Counter:
    """Return the Counter row with id 1, creating it with value 0 if absent.

    Creating the row commits the session.
    """
    lookup = select(Counter).where(Counter.id == 1)
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing:
        return existing

    created = Counter(id=1, value=0)
    db.add(created)
    await db.commit()
    await db.refresh(created)
    return created
@app.get("/api/counter")
async def get_counter(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
):
    """Return the current counter value as ``{"value": ...}``.

    Guarded by the ``require_permission(Permission.VIEW_COUNTER)`` dependency.
    """
    current_value = (await get_or_create_counter(db)).value
    return {"value": current_value}
@app.post("/api/counter/increment")
async def increment_counter(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
):
    """Add one to the counter and store a CounterRecord for the change.

    The record holds the acting user's id and the value before and after the
    increment. Returns the new value as ``{"value": ...}``.
    """
    counter = await get_or_create_counter(db)
    previous_value = counter.value
    counter.value = previous_value + 1
    db.add(
        CounterRecord(
            user_id=current_user.id,
            value_before=previous_value,
            value_after=counter.value,
        )
    )
    await db.commit()
    return {"value": counter.value}
# Sum endpoints
class SumRequest(BaseModel):
    """Request body for the sum endpoint: the two operands."""

    a: float  # first operand
    b: float  # second operand
class SumResponse(BaseModel):
    """Response of the sum endpoint: both operands and their sum."""

    a: float  # first operand, as received
    b: float  # second operand, as received
    result: float  # a + b
@app.post("/api/sum", response_model=SumResponse)
async def calculate_sum(
    data: SumRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.USE_SUM)),
):
    """Add ``data.a`` and ``data.b``, store a SumRecord and return the result.

    The stored record holds the acting user's id, both operands and the sum.
    """
    total = data.a + data.b
    db.add(
        SumRecord(
            user_id=current_user.id,
            a=data.a,
            b=data.b,
            result=total,
        )
    )
    await db.commit()
    return SumResponse(a=data.a, b=data.b, result=total)
# Audit endpoints
class CounterRecordResponse(BaseModel):
    """One counter audit entry as returned by the audit API."""

    id: int
    # Email of the user the record is joined to via ``user_id``.
    user_email: str
    # Counter value before and after the recorded increment.
    value_before: int
    value_after: int
    created_at: datetime
class SumRecordResponse(BaseModel):
    """One sum audit entry as returned by the audit API."""

    id: int
    # Email of the user the record is joined to via ``user_id``.
    user_email: str
    # Operands and result copied from the stored record.
    a: float
    b: float
    result: float
    created_at: datetime
RecordT = TypeVar("RecordT", bound=BaseModel)


class PaginatedResponse(BaseModel, Generic[RecordT]):
    """Envelope for one page of ``RecordT`` items plus paging metadata."""

    # Items on the requested page.
    records: list[RecordT]
    # Number of items across all pages.
    total: int
    # Page number and page size that were requested.
    page: int
    per_page: int
    # Number of pages available.
    total_pages: int


# Concrete page types used as response models by the audit endpoints.
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
    """Convert a CounterRecord row and its user's email into the API model."""
    fields = {
        "id": record.id,
        "user_email": email,
        "value_before": record.value_before,
        "value_after": record.value_after,
        "created_at": record.created_at,
    }
    return CounterRecordResponse(**fields)
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of counter audit records, most recent first.

    Guarded by the ``require_permission(Permission.VIEW_AUDIT)`` dependency.
    """
    page_records, record_count, page_count = await paginate_with_user_email(
        db,
        CounterRecord,
        page,
        per_page,
        _map_counter_record,
    )
    return PaginatedCounterRecords(
        records=page_records,
        total=record_count,
        page=page,
        per_page=per_page,
        total_pages=page_count,
    )
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
    """Convert a SumRecord row and its user's email into the API model."""
    fields = {
        "id": record.id,
        "user_email": email,
        "a": record.a,
        "b": record.b,
        "result": record.result,
        "created_at": record.created_at,
    }
    return SumRecordResponse(**fields)
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
):
    """Return one page of sum audit records, most recent first.

    Guarded by the ``require_permission(Permission.VIEW_AUDIT)`` dependency.
    """
    page_records, record_count, page_count = await paginate_with_user_email(
        db,
        SumRecord,
        page,
        per_page,
        _map_sum_record,
    )
    return PaginatedSumRecords(
        records=page_records,
        total=record_count,
        page=page,
        per_page=per_page,
        total_pages=page_count,
    )
# Profile endpoints
class ProfileResponse(BaseModel):
    """Profile data returned by the profile endpoints."""

    # Contact details copied from the user; each may be None.
    contact_email: str | None
    telegram: str | None
    signal: str | None
    nostr_npub: str | None
    # Email of the user referenced by ``godfather_id``; None when there is none.
    godfather_email: str | None = None
class ProfileUpdate(BaseModel):
    """Body of a profile update request.

    Every field defaults to None; the update endpoint writes all four values
    to the user as given.
    """

    contact_email: str | None = None
    telegram: str | None = None
    signal: str | None = None
    nostr_npub: str | None = None
async def require_regular_user(
    current_user: User = Depends(get_current_user),
) -> User:
    """Dependency that passes the current user through only if they hold the 'regular' role.

    Raises:
        HTTPException: 403 when ``ROLE_REGULAR`` is not among the user's role names.
    """
    if ROLE_REGULAR in current_user.role_names:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Profile access is only available to regular users",
    )
@app.get("/api/profile", response_model=ProfileResponse)
async def get_profile(
    current_user: User = Depends(require_regular_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the current user's contact details and their godfather's email."""
    # The godfather's email is looked up only when a godfather is set.
    sponsor_email = None
    if current_user.godfather_id:
        sponsor_lookup = select(User.email).where(User.id == current_user.godfather_id)
        sponsor_email = (await db.execute(sponsor_lookup)).scalar_one_or_none()

    return ProfileResponse(
        contact_email=current_user.contact_email,
        telegram=current_user.telegram,
        signal=current_user.signal,
        nostr_npub=current_user.nostr_npub,
        godfather_email=sponsor_email,
    )
@app.put("/api/profile", response_model=ProfileResponse)
async def update_profile(
    data: ProfileUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_regular_user),
):
    """Validate and store the current user's contact details.

    Raises:
        HTTPException: 422 with ``{"field_errors": ...}`` when
            ``validate_profile_fields`` reports errors.
    """
    submitted = {
        "contact_email": data.contact_email,
        "telegram": data.telegram,
        "signal": data.signal,
        "nostr_npub": data.nostr_npub,
    }

    # All fields are validated before anything is written.
    field_errors = validate_profile_fields(**submitted)
    if field_errors:
        raise HTTPException(
            status_code=422,
            detail={"field_errors": field_errors},
        )

    # Every field is overwritten, including with None.
    for attribute, value in submitted.items():
        setattr(current_user, attribute, value)
    await db.commit()
    await db.refresh(current_user)

    # The godfather's email is looked up only when a godfather is set.
    sponsor_email = None
    if current_user.godfather_id:
        sponsor_lookup = select(User.email).where(User.id == current_user.godfather_id)
        sponsor_email = (await db.execute(sponsor_lookup)).scalar_one_or_none()

    return ProfileResponse(
        contact_email=current_user.contact_email,
        telegram=current_user.telegram,
        signal=current_user.signal,
        nostr_npub=current_user.nostr_npub,
        godfather_email=sponsor_email,
    )
# Invite endpoints
class InviteCreate(BaseModel):
    """Body of an invite creation request."""

    # Id of the user who becomes the godfather of the new invite.
    godfather_id: int
class InviteResponse(BaseModel):
    """Invite data as returned by the admin invite endpoints."""

    id: int
    identifier: str
    # Owner of the invite and that user's email.
    godfather_id: int
    godfather_email: str
    # Value of the invite's status enum.
    status: str
    # User who redeemed the invite, if any, and that user's email.
    used_by_id: int | None
    used_by_email: str | None
    created_at: datetime
    # Timestamps set when the invite is spent or revoked; None otherwise.
    spent_at: datetime | None
    revoked_at: datetime | None
@app.post("/api/admin/invites", response_model=InviteResponse)
async def create_invite(
    data: InviteCreate,
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """Create a READY invite owned by the given godfather user.

    Guarded by the ``require_permission(Permission.MANAGE_INVITES)`` dependency.

    Raises:
        HTTPException: 400 when no user with ``data.godfather_id`` exists.
    """
    # The godfather must be an existing user.
    godfather = (
        await db.execute(select(User).where(User.id == data.godfather_id))
    ).scalar_one_or_none()
    if not godfather:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Godfather user not found",
        )

    # NOTE(review): no handling of an identifier collision is visible here;
    # presumably generate_invite_identifier makes one unlikely - confirm.
    new_invite = Invite(
        identifier=generate_invite_identifier(),
        godfather_id=godfather.id,
        status=InviteStatus.READY,
    )
    db.add(new_invite)
    await db.commit()
    await db.refresh(new_invite)

    return InviteResponse(
        id=new_invite.id,
        identifier=new_invite.identifier,
        godfather_id=new_invite.godfather_id,
        godfather_email=godfather.email,
        status=new_invite.status.value,
        used_by_id=new_invite.used_by_id,
        used_by_email=None,
        created_at=new_invite.created_at,
        spent_at=new_invite.spent_at,
        revoked_at=new_invite.revoked_at,
    )
class UserInviteResponse(BaseModel):
    """Invite as shown to its owner; a reduced form of the admin view."""

    id: int
    identifier: str
    # Value of the invite's status enum.
    status: str
    # Email of the user who redeemed the invite; None when there is none.
    used_by_email: str | None
    created_at: datetime
    spent_at: datetime | None
@app.get("/api/invites", response_model=list[UserInviteResponse])
async def get_my_invites(
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(require_permission(Permission.VIEW_OWN_INVITES)),
):
    """Get all invites owned by the current user, newest first.

    Guarded by the ``require_permission(Permission.VIEW_OWN_INVITES)``
    dependency. Each entry carries the email of the user who redeemed the
    invite, or None when it has not been used or that user no longer exists.
    """
    # A single LEFT OUTER JOIN fetches the redeeming user's email with each
    # invite, instead of one extra query per spent invite. Invites without a
    # matching user row come back with a NULL email.
    query = (
        select(Invite, User.email)
        .select_from(Invite)
        .outerjoin(User, Invite.used_by_id == User.id)
        .where(Invite.godfather_id == current_user.id)
        .order_by(desc(Invite.created_at))
    )
    rows = (await db.execute(query)).all()
    return [
        UserInviteResponse(
            id=invite.id,
            identifier=invite.identifier,
            status=invite.status.value,
            used_by_email=used_by_email,
            created_at=invite.created_at,
            spent_at=invite.spent_at,
        )
        for invite, used_by_email in rows
    ]
# Admin Invite Management
# Page type used as the response model of the admin invite listing.
PaginatedInviteRecords = PaginatedResponse[InviteResponse]


class AdminUserResponse(BaseModel):
    """Id and email of a user, as needed by admin dropdowns."""

    id: int
    email: str
@app.get("/api/admin/users", response_model=list[AdminUserResponse])
async def list_users_for_admin(
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """Return the id and email of every user, ordered by email.

    Guarded by the ``require_permission(Permission.MANAGE_INVITES)`` dependency.
    """
    user_query = select(User.id, User.email).order_by(User.email)
    rows = (await db.execute(user_query)).all()
    listing = []
    for row in rows:
        listing.append(AdminUserResponse(id=row.id, email=row.email))
    return listing
@app.get("/api/admin/invites", response_model=PaginatedInviteRecords)
async def list_all_invites(
    page: int = Query(1, ge=1),
    per_page: int = Query(10, ge=1, le=100),
    status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
    godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """List all invites with optional filtering and pagination.

    Guarded by the ``require_permission(Permission.MANAGE_INVITES)`` dependency.

    Args:
        page: 1-based page number.
        per_page: Page size (1-100).
        status_filter: Optional invite status value (query parameter ``status``).
        godfather_id: Optional id of the godfather whose invites are listed.

    Raises:
        HTTPException: 400 when ``status_filter`` is not a valid InviteStatus value.
    """
    query = select(Invite)
    count_query = select(func.count(Invite.id))

    # Apply filters to both the page query and the count query.
    if status_filter:
        # Only the enum conversion can raise ValueError, so only it is guarded.
        try:
            status_enum = InviteStatus(status_filter)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
            ) from None
        query = query.where(Invite.status == status_enum)
        count_query = count_query.where(Invite.status == status_enum)
    # Compare against None: a truthiness test would silently drop the filter
    # for godfather_id=0 and return every invite.
    if godfather_id is not None:
        query = query.where(Invite.godfather_id == godfather_id)
        count_query = count_query.where(Invite.godfather_id == godfather_id)

    count_result = await db.execute(count_query)
    total = count_result.scalar() or 0
    # Ceiling division; an empty result is still reported as a single page.
    total_pages = (total + per_page - 1) // per_page if total > 0 else 1

    offset = (page - 1) * per_page
    # The id tie-break keeps page contents deterministic when several invites
    # share a created_at value.
    query = (
        query.order_by(desc(Invite.created_at), desc(Invite.id))
        .offset(offset)
        .limit(per_page)
    )
    result = await db.execute(query)
    invites = result.scalars().all()

    # Resolve every email needed for this page with one query instead of up
    # to two queries per invite.
    user_ids = {invite.godfather_id for invite in invites}
    user_ids.update(
        invite.used_by_id for invite in invites if invite.used_by_id is not None
    )
    email_by_id: dict[int, str] = {}
    if user_ids:
        email_result = await db.execute(
            select(User.id, User.email).where(User.id.in_(sorted(user_ids)))
        )
        email_by_id = {user_id: email for user_id, email in email_result.all()}

    records = [
        InviteResponse(
            id=invite.id,
            identifier=invite.identifier,
            godfather_id=invite.godfather_id,
            # A missing godfather is a data-integrity error and still fails
            # the request, as the per-row scalar_one() lookup did.
            godfather_email=email_by_id[invite.godfather_id],
            status=invite.status.value,
            used_by_id=invite.used_by_id,
            # None when the invite is unused or the user row is gone.
            used_by_email=email_by_id.get(invite.used_by_id),
            created_at=invite.created_at,
            spent_at=invite.spent_at,
            revoked_at=invite.revoked_at,
        )
        for invite in invites
    ]
    return PaginatedInviteRecords(
        records=records,
        total=total,
        page=page,
        per_page=per_page,
        total_pages=total_pages,
    )
@app.post("/api/admin/invites/{invite_id}/revoke", response_model=InviteResponse)
async def revoke_invite(
    invite_id: int,
    db: AsyncSession = Depends(get_db),
    _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
    """Mark a READY invite as revoked and return its updated state.

    Guarded by the ``require_permission(Permission.MANAGE_INVITES)`` dependency.

    Raises:
        HTTPException: 404 when the invite does not exist; 400 when its
            status is anything other than READY.
    """
    target = (
        await db.execute(select(Invite).where(Invite.id == invite_id))
    ).scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Invite not found",
        )
    if invite_is_not_ready := target.status != InviteStatus.READY:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Cannot revoke invite with status '{target.status.value}'. Only READY invites can be revoked.",
        )

    # Record the revocation together with its timestamp.
    target.status = InviteStatus.REVOKED
    target.revoked_at = datetime.now(UTC)
    await db.commit()
    await db.refresh(target)

    # Email of the invite's owner for the response.
    owner_lookup = select(User.email).where(User.id == target.godfather_id)
    owner_email = (await db.execute(owner_lookup)).scalar_one()

    return InviteResponse(
        id=target.id,
        identifier=target.identifier,
        godfather_id=target.godfather_id,
        godfather_email=owner_email,
        status=target.status.value,
        used_by_id=target.used_by_id,
        used_by_email=None,
        created_at=target.created_at,
        spent_at=target.spent_at,
        revoked_at=target.revoked_at,
    )