first implementation
This commit is contained in:
parent
1eb4641ed9
commit
a56a4c076a
14 changed files with 898 additions and 729 deletions
|
|
@ -7,3 +7,4 @@ alwaysApply: false
|
||||||
- Use docstrings following the style of the existing code.
|
- Use docstrings following the style of the existing code.
|
||||||
- Avoid bloaty comments. Instead, favour using descriptive functions and variable names to make what's happening obvious.
|
- Avoid bloaty comments. Instead, favour using descriptive functions and variable names to make what's happening obvious.
|
||||||
- Use tests frequently to detect errors early. Make sure that tests are kept up to date. Remove dead code tests when removing code.
|
- Use tests frequently to detect errors early. Make sure that tests are kept up to date. Remove dead code tests when removing code.
|
||||||
|
- Avoid using comments to divide sections of code. If you feel that is necessary, you probably should break into multiple files instead.
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,12 @@ from datetime import datetime, timedelta, timezone
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, Permission
|
from models import User, Permission
|
||||||
|
from schemas import UserResponse
|
||||||
|
|
||||||
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
|
@ -17,28 +17,6 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
||||||
COOKIE_NAME = "auth_token"
|
COOKIE_NAME = "auth_token"
|
||||||
|
|
||||||
|
|
||||||
class UserCredentials(BaseModel):
|
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
UserCreate = UserCredentials
|
|
||||||
UserLogin = UserCredentials
|
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
|
||||||
id: int
|
|
||||||
email: str
|
|
||||||
roles: list[str]
|
|
||||||
permissions: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
user: UserResponse
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
plain_password.encode("utf-8"),
|
plain_password.encode("utf-8"),
|
||||||
|
|
|
||||||
712
backend/main.py
712
backend/main.py
|
|
@ -1,72 +1,21 @@
|
||||||
|
"""FastAPI application entry point."""
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime, UTC
|
|
||||||
from typing import Callable, Generic, TypeVar
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, EmailStr
|
|
||||||
from sqlalchemy import select, func, desc
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from auth import (
|
from database import engine, Base
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
from routes import sum as sum_routes
|
||||||
COOKIE_NAME,
|
from routes import counter as counter_routes
|
||||||
UserCreate,
|
from routes import audit as audit_routes
|
||||||
UserLogin,
|
from routes import profile as profile_routes
|
||||||
UserResponse,
|
from routes import invites as invites_routes
|
||||||
get_password_hash,
|
from routes import auth as auth_routes
|
||||||
get_user_by_email,
|
|
||||||
authenticate_user,
|
|
||||||
create_access_token,
|
|
||||||
get_current_user,
|
|
||||||
require_permission,
|
|
||||||
build_user_response,
|
|
||||||
)
|
|
||||||
from database import engine, get_db, Base
|
|
||||||
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR, Invite, InviteStatus
|
|
||||||
from validation import validate_profile_fields
|
|
||||||
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
|
|
||||||
|
|
||||||
|
|
||||||
R = TypeVar("R", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
async def paginate_with_user_email(
|
|
||||||
db: AsyncSession,
|
|
||||||
model: type[SumRecord] | type[CounterRecord],
|
|
||||||
page: int,
|
|
||||||
per_page: int,
|
|
||||||
row_mapper: Callable[..., R],
|
|
||||||
) -> tuple[list[R], int, int]:
|
|
||||||
"""
|
|
||||||
Generic pagination helper for audit records that need user email.
|
|
||||||
|
|
||||||
Returns: (records, total, total_pages)
|
|
||||||
"""
|
|
||||||
# Get total count
|
|
||||||
count_result = await db.execute(select(func.count(model.id)))
|
|
||||||
total = count_result.scalar() or 0
|
|
||||||
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
|
||||||
|
|
||||||
# Get paginated records with user email
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
query = (
|
|
||||||
select(model, User.email)
|
|
||||||
.join(User, model.user_id == User.id)
|
|
||||||
.order_by(desc(model.created_at))
|
|
||||||
.offset(offset)
|
|
||||||
.limit(per_page)
|
|
||||||
)
|
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
records: list[R] = [row_mapper(record, email) for record, email in rows]
|
|
||||||
return records, total, total_pages
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
"""Create database tables on startup."""
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
yield
|
yield
|
||||||
|
|
@ -82,637 +31,10 @@ app.add_middleware(
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Include routers
|
||||||
def set_auth_cookie(response: Response, token: str) -> None:
|
app.include_router(auth_routes.router)
|
||||||
response.set_cookie(
|
app.include_router(sum_routes.router)
|
||||||
key=COOKIE_NAME,
|
app.include_router(counter_routes.router)
|
||||||
value=token,
|
app.include_router(audit_routes.router)
|
||||||
httponly=True,
|
app.include_router(profile_routes.router)
|
||||||
secure=False, # Set to True in production with HTTPS
|
app.include_router(invites_routes.router)
|
||||||
samesite="lax",
|
|
||||||
max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_default_role(db: AsyncSession) -> Role | None:
|
|
||||||
"""Get the default 'regular' role for new users."""
|
|
||||||
result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
|
|
||||||
# Invite check endpoint (public)
|
|
||||||
class InviteCheckResponse(BaseModel):
|
|
||||||
"""Response for invite check endpoint."""
|
|
||||||
valid: bool
|
|
||||||
status: str | None = None
|
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/invites/{identifier}/check", response_model=InviteCheckResponse)
|
|
||||||
async def check_invite(
|
|
||||||
identifier: str,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Check if an invite is valid and can be used for signup."""
|
|
||||||
normalized = normalize_identifier(identifier)
|
|
||||||
|
|
||||||
# Validate format before querying database
|
|
||||||
if not is_valid_identifier_format(normalized):
|
|
||||||
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(Invite).where(Invite.identifier == normalized)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
|
||||||
if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
|
|
||||||
return InviteCheckResponse(valid=False, error="Invite not found")
|
|
||||||
|
|
||||||
return InviteCheckResponse(valid=True, status=invite.status.value)
|
|
||||||
|
|
||||||
|
|
||||||
# Auth endpoints
|
|
||||||
class RegisterWithInvite(BaseModel):
|
|
||||||
"""Request model for registration with invite."""
|
|
||||||
email: EmailStr
|
|
||||||
password: str
|
|
||||||
invite_identifier: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/auth/register", response_model=UserResponse)
|
|
||||||
async def register(
|
|
||||||
user_data: RegisterWithInvite,
|
|
||||||
response: Response,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Register a new user using an invite code."""
|
|
||||||
# Validate invite
|
|
||||||
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
|
||||||
result = await db.execute(
|
|
||||||
select(Invite).where(Invite.identifier == normalized_identifier)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
|
||||||
if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Invalid invite code",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check email not already taken
|
|
||||||
existing_user = await get_user_by_email(db, user_data.email)
|
|
||||||
if existing_user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Email already registered",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create user with godfather
|
|
||||||
user = User(
|
|
||||||
email=user_data.email,
|
|
||||||
hashed_password=get_password_hash(user_data.password),
|
|
||||||
godfather_id=invite.godfather_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Assign default role
|
|
||||||
default_role = await get_default_role(db)
|
|
||||||
if default_role:
|
|
||||||
user.roles.append(default_role)
|
|
||||||
|
|
||||||
db.add(user)
|
|
||||||
await db.flush() # Get user ID
|
|
||||||
|
|
||||||
# Mark invite as spent
|
|
||||||
invite.status = InviteStatus.SPENT
|
|
||||||
invite.used_by_id = user.id
|
|
||||||
invite.spent_at = datetime.now(UTC)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": str(user.id)})
|
|
||||||
set_auth_cookie(response, access_token)
|
|
||||||
return await build_user_response(user, db)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/auth/login", response_model=UserResponse)
|
|
||||||
async def login(
|
|
||||||
user_data: UserLogin,
|
|
||||||
response: Response,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
user = await authenticate_user(db, user_data.email, user_data.password)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Incorrect email or password",
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token = create_access_token(data={"sub": str(user.id)})
|
|
||||||
set_auth_cookie(response, access_token)
|
|
||||||
return await build_user_response(user, db)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/auth/logout")
|
|
||||||
async def logout(response: Response):
|
|
||||||
response.delete_cookie(key=COOKIE_NAME)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/auth/me", response_model=UserResponse)
|
|
||||||
async def get_me(
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
return await build_user_response(current_user, db)
|
|
||||||
|
|
||||||
|
|
||||||
# Counter endpoints
|
|
||||||
async def get_or_create_counter(db: AsyncSession) -> Counter:
|
|
||||||
result = await db.execute(select(Counter).where(Counter.id == 1))
|
|
||||||
counter = result.scalar_one_or_none()
|
|
||||||
if not counter:
|
|
||||||
counter = Counter(id=1, value=0)
|
|
||||||
db.add(counter)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(counter)
|
|
||||||
return counter
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/counter")
|
|
||||||
async def get_counter(
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
|
|
||||||
):
|
|
||||||
counter = await get_or_create_counter(db)
|
|
||||||
return {"value": counter.value}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/counter/increment")
|
|
||||||
async def increment_counter(
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
|
|
||||||
):
|
|
||||||
counter = await get_or_create_counter(db)
|
|
||||||
value_before = counter.value
|
|
||||||
counter.value += 1
|
|
||||||
|
|
||||||
record = CounterRecord(
|
|
||||||
user_id=current_user.id,
|
|
||||||
value_before=value_before,
|
|
||||||
value_after=counter.value,
|
|
||||||
)
|
|
||||||
db.add(record)
|
|
||||||
await db.commit()
|
|
||||||
return {"value": counter.value}
|
|
||||||
|
|
||||||
|
|
||||||
# Sum endpoints
|
|
||||||
class SumRequest(BaseModel):
|
|
||||||
a: float
|
|
||||||
b: float
|
|
||||||
|
|
||||||
|
|
||||||
class SumResponse(BaseModel):
|
|
||||||
a: float
|
|
||||||
b: float
|
|
||||||
result: float
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/sum", response_model=SumResponse)
|
|
||||||
async def calculate_sum(
|
|
||||||
data: SumRequest,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(require_permission(Permission.USE_SUM)),
|
|
||||||
):
|
|
||||||
result = data.a + data.b
|
|
||||||
record = SumRecord(
|
|
||||||
user_id=current_user.id,
|
|
||||||
a=data.a,
|
|
||||||
b=data.b,
|
|
||||||
result=result,
|
|
||||||
)
|
|
||||||
db.add(record)
|
|
||||||
await db.commit()
|
|
||||||
return SumResponse(a=data.a, b=data.b, result=result)
|
|
||||||
|
|
||||||
|
|
||||||
# Audit endpoints
|
|
||||||
class CounterRecordResponse(BaseModel):
|
|
||||||
id: int
|
|
||||||
user_email: str
|
|
||||||
value_before: int
|
|
||||||
value_after: int
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
class SumRecordResponse(BaseModel):
|
|
||||||
id: int
|
|
||||||
user_email: str
|
|
||||||
a: float
|
|
||||||
b: float
|
|
||||||
result: float
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
|
|
||||||
RecordT = TypeVar("RecordT", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
|
||||||
"""Generic paginated response wrapper."""
|
|
||||||
records: list[RecordT]
|
|
||||||
total: int
|
|
||||||
page: int
|
|
||||||
per_page: int
|
|
||||||
total_pages: int
|
|
||||||
|
|
||||||
|
|
||||||
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
|
|
||||||
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
|
||||||
|
|
||||||
|
|
||||||
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
|
|
||||||
return CounterRecordResponse(
|
|
||||||
id=record.id,
|
|
||||||
user_email=email,
|
|
||||||
value_before=record.value_before,
|
|
||||||
value_after=record.value_after,
|
|
||||||
created_at=record.created_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
|
|
||||||
async def get_counter_records(
|
|
||||||
page: int = Query(1, ge=1),
|
|
||||||
per_page: int = Query(10, ge=1, le=100),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
|
||||||
):
|
|
||||||
records, total, total_pages = await paginate_with_user_email(
|
|
||||||
db, CounterRecord, page, per_page, _map_counter_record
|
|
||||||
)
|
|
||||||
return PaginatedCounterRecords(
|
|
||||||
records=records,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
per_page=per_page,
|
|
||||||
total_pages=total_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
|
|
||||||
return SumRecordResponse(
|
|
||||||
id=record.id,
|
|
||||||
user_email=email,
|
|
||||||
a=record.a,
|
|
||||||
b=record.b,
|
|
||||||
result=record.result,
|
|
||||||
created_at=record.created_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
|
|
||||||
async def get_sum_records(
|
|
||||||
page: int = Query(1, ge=1),
|
|
||||||
per_page: int = Query(10, ge=1, le=100),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
|
||||||
):
|
|
||||||
records, total, total_pages = await paginate_with_user_email(
|
|
||||||
db, SumRecord, page, per_page, _map_sum_record
|
|
||||||
)
|
|
||||||
return PaginatedSumRecords(
|
|
||||||
records=records,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
per_page=per_page,
|
|
||||||
total_pages=total_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Profile endpoints
|
|
||||||
class ProfileResponse(BaseModel):
|
|
||||||
"""Response model for profile data."""
|
|
||||||
contact_email: str | None
|
|
||||||
telegram: str | None
|
|
||||||
signal: str | None
|
|
||||||
nostr_npub: str | None
|
|
||||||
godfather_email: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ProfileUpdate(BaseModel):
|
|
||||||
"""Request model for updating profile."""
|
|
||||||
contact_email: str | None = None
|
|
||||||
telegram: str | None = None
|
|
||||||
signal: str | None = None
|
|
||||||
nostr_npub: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
async def require_regular_user(
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
) -> User:
|
|
||||||
"""Dependency that requires the user to have the 'regular' role."""
|
|
||||||
if ROLE_REGULAR not in current_user.role_names:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Profile access is only available to regular users",
|
|
||||||
)
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str | None:
|
|
||||||
"""Get the email of a godfather user by ID."""
|
|
||||||
if not godfather_id:
|
|
||||||
return None
|
|
||||||
result = await db.execute(
|
|
||||||
select(User.email).where(User.id == godfather_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/profile", response_model=ProfileResponse)
|
|
||||||
async def get_profile(
|
|
||||||
current_user: User = Depends(require_regular_user),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
):
|
|
||||||
"""Get the current user's profile (contact details and godfather)."""
|
|
||||||
godfather_email = await get_godfather_email(db, current_user.godfather_id)
|
|
||||||
|
|
||||||
return ProfileResponse(
|
|
||||||
contact_email=current_user.contact_email,
|
|
||||||
telegram=current_user.telegram,
|
|
||||||
signal=current_user.signal,
|
|
||||||
nostr_npub=current_user.nostr_npub,
|
|
||||||
godfather_email=godfather_email,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.put("/api/profile", response_model=ProfileResponse)
|
|
||||||
async def update_profile(
|
|
||||||
data: ProfileUpdate,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(require_regular_user),
|
|
||||||
):
|
|
||||||
"""Update the current user's profile (contact details)."""
|
|
||||||
# Validate all fields
|
|
||||||
errors = validate_profile_fields(
|
|
||||||
contact_email=data.contact_email,
|
|
||||||
telegram=data.telegram,
|
|
||||||
signal=data.signal,
|
|
||||||
nostr_npub=data.nostr_npub,
|
|
||||||
)
|
|
||||||
|
|
||||||
if errors:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=422,
|
|
||||||
detail={"field_errors": errors},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update fields
|
|
||||||
current_user.contact_email = data.contact_email
|
|
||||||
current_user.telegram = data.telegram
|
|
||||||
current_user.signal = data.signal
|
|
||||||
current_user.nostr_npub = data.nostr_npub
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(current_user)
|
|
||||||
|
|
||||||
godfather_email = await get_godfather_email(db, current_user.godfather_id)
|
|
||||||
|
|
||||||
return ProfileResponse(
|
|
||||||
contact_email=current_user.contact_email,
|
|
||||||
telegram=current_user.telegram,
|
|
||||||
signal=current_user.signal,
|
|
||||||
nostr_npub=current_user.nostr_npub,
|
|
||||||
godfather_email=godfather_email,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Invite endpoints
|
|
||||||
class InviteCreate(BaseModel):
|
|
||||||
"""Request model for creating an invite."""
|
|
||||||
godfather_id: int
|
|
||||||
|
|
||||||
|
|
||||||
class InviteResponse(BaseModel):
|
|
||||||
"""Response model for invite data."""
|
|
||||||
id: int
|
|
||||||
identifier: str
|
|
||||||
godfather_id: int
|
|
||||||
godfather_email: str
|
|
||||||
status: str
|
|
||||||
used_by_id: int | None
|
|
||||||
used_by_email: str | None
|
|
||||||
created_at: datetime
|
|
||||||
spent_at: datetime | None
|
|
||||||
revoked_at: datetime | None
|
|
||||||
|
|
||||||
|
|
||||||
def build_invite_response(invite: Invite) -> InviteResponse:
|
|
||||||
"""Build an InviteResponse from an Invite with loaded relationships."""
|
|
||||||
return InviteResponse(
|
|
||||||
id=invite.id,
|
|
||||||
identifier=invite.identifier,
|
|
||||||
godfather_id=invite.godfather_id,
|
|
||||||
godfather_email=invite.godfather.email,
|
|
||||||
status=invite.status.value,
|
|
||||||
used_by_id=invite.used_by_id,
|
|
||||||
used_by_email=invite.used_by.email if invite.used_by else None,
|
|
||||||
created_at=invite.created_at,
|
|
||||||
spent_at=invite.spent_at,
|
|
||||||
revoked_at=invite.revoked_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
MAX_INVITE_COLLISION_RETRIES = 3
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/admin/invites", response_model=InviteResponse)
|
|
||||||
async def create_invite(
|
|
||||||
data: InviteCreate,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
|
||||||
):
|
|
||||||
"""Create a new invite for a specified godfather user."""
|
|
||||||
# Validate godfather exists
|
|
||||||
result = await db.execute(
|
|
||||||
select(User.id).where(User.id == data.godfather_id)
|
|
||||||
)
|
|
||||||
godfather_id = result.scalar_one_or_none()
|
|
||||||
if not godfather_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Godfather user not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Try to create invite with retry on collision
|
|
||||||
invite: Invite | None = None
|
|
||||||
for attempt in range(MAX_INVITE_COLLISION_RETRIES):
|
|
||||||
identifier = generate_invite_identifier()
|
|
||||||
invite = Invite(
|
|
||||||
identifier=identifier,
|
|
||||||
godfather_id=godfather_id,
|
|
||||||
status=InviteStatus.READY,
|
|
||||||
)
|
|
||||||
db.add(invite)
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(invite, ["godfather"])
|
|
||||||
break
|
|
||||||
except IntegrityError:
|
|
||||||
await db.rollback()
|
|
||||||
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to generate unique invite code. Please try again.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if invite is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to create invite",
|
|
||||||
)
|
|
||||||
return build_invite_response(invite)
|
|
||||||
|
|
||||||
|
|
||||||
class UserInviteResponse(BaseModel):
|
|
||||||
"""Response model for a user's invite (simpler than admin view)."""
|
|
||||||
id: int
|
|
||||||
identifier: str
|
|
||||||
status: str
|
|
||||||
used_by_email: str | None
|
|
||||||
created_at: datetime
|
|
||||||
spent_at: datetime | None
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/invites", response_model=list[UserInviteResponse])
|
|
||||||
async def get_my_invites(
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
current_user: User = Depends(require_permission(Permission.VIEW_OWN_INVITES)),
|
|
||||||
):
|
|
||||||
"""Get all invites owned by the current user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Invite)
|
|
||||||
.where(Invite.godfather_id == current_user.id)
|
|
||||||
.order_by(desc(Invite.created_at))
|
|
||||||
)
|
|
||||||
invites = result.scalars().all()
|
|
||||||
|
|
||||||
# Use preloaded used_by relationship (selectin loading)
|
|
||||||
return [
|
|
||||||
UserInviteResponse(
|
|
||||||
id=invite.id,
|
|
||||||
identifier=invite.identifier,
|
|
||||||
status=invite.status.value,
|
|
||||||
used_by_email=invite.used_by.email if invite.used_by else None,
|
|
||||||
created_at=invite.created_at,
|
|
||||||
spent_at=invite.spent_at,
|
|
||||||
)
|
|
||||||
for invite in invites
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Admin Invite Management
|
|
||||||
PaginatedInviteRecords = PaginatedResponse[InviteResponse]
|
|
||||||
|
|
||||||
|
|
||||||
class AdminUserResponse(BaseModel):
|
|
||||||
"""Minimal user info for admin dropdowns."""
|
|
||||||
id: int
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/admin/users", response_model=list[AdminUserResponse])
|
|
||||||
async def list_users_for_admin(
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
|
||||||
):
|
|
||||||
"""List all users for admin dropdowns (invite creation, etc.)."""
|
|
||||||
result = await db.execute(select(User.id, User.email).order_by(User.email))
|
|
||||||
users = result.all()
|
|
||||||
return [AdminUserResponse(id=u.id, email=u.email) for u in users]
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/admin/invites", response_model=PaginatedInviteRecords)
|
|
||||||
async def list_all_invites(
|
|
||||||
page: int = Query(1, ge=1),
|
|
||||||
per_page: int = Query(10, ge=1, le=100),
|
|
||||||
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
|
|
||||||
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
|
||||||
):
|
|
||||||
"""List all invites with optional filtering and pagination."""
|
|
||||||
# Build query
|
|
||||||
query = select(Invite)
|
|
||||||
count_query = select(func.count(Invite.id))
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
if status_filter:
|
|
||||||
try:
|
|
||||||
status_enum = InviteStatus(status_filter)
|
|
||||||
query = query.where(Invite.status == status_enum)
|
|
||||||
count_query = count_query.where(Invite.status == status_enum)
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
|
|
||||||
)
|
|
||||||
|
|
||||||
if godfather_id:
|
|
||||||
query = query.where(Invite.godfather_id == godfather_id)
|
|
||||||
count_query = count_query.where(Invite.godfather_id == godfather_id)
|
|
||||||
|
|
||||||
# Get total count
|
|
||||||
count_result = await db.execute(count_query)
|
|
||||||
total = count_result.scalar() or 0
|
|
||||||
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
|
||||||
|
|
||||||
# Get paginated invites (relationships loaded via selectin)
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
|
|
||||||
result = await db.execute(query)
|
|
||||||
invites = result.scalars().all()
|
|
||||||
|
|
||||||
# Build responses using preloaded relationships
|
|
||||||
records = [build_invite_response(invite) for invite in invites]
|
|
||||||
|
|
||||||
return PaginatedInviteRecords(
|
|
||||||
records=records,
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
per_page=per_page,
|
|
||||||
total_pages=total_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/admin/invites/{invite_id}/revoke", response_model=InviteResponse)
|
|
||||||
async def revoke_invite(
|
|
||||||
invite_id: int,
|
|
||||||
db: AsyncSession = Depends(get_db),
|
|
||||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
|
||||||
):
|
|
||||||
"""Revoke an invite. Only READY invites can be revoked."""
|
|
||||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
|
||||||
invite = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not invite:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Invite not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
if invite.status != InviteStatus.READY:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
|
|
||||||
)
|
|
||||||
|
|
||||||
invite.status = InviteStatus.REVOKED
|
|
||||||
invite.revoked_at = datetime.now(UTC)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(invite)
|
|
||||||
|
|
||||||
return build_invite_response(invite)
|
|
||||||
|
|
|
||||||
0
backend/routes/__init__.py
Normal file
0
backend/routes/__init__.py
Normal file
117
backend/routes/audit.py
Normal file
117
backend/routes/audit.py
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Audit routes for viewing action records."""
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select, func, desc
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import require_permission
|
||||||
|
from database import get_db
|
||||||
|
from models import User, SumRecord, CounterRecord, Permission
|
||||||
|
from schemas import (
|
||||||
|
CounterRecordResponse,
|
||||||
|
SumRecordResponse,
|
||||||
|
PaginatedCounterRecords,
|
||||||
|
PaginatedSumRecords,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
||||||
|
|
||||||
|
R = TypeVar("R", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
async def paginate_with_user_email(
|
||||||
|
db: AsyncSession,
|
||||||
|
model: type[SumRecord] | type[CounterRecord],
|
||||||
|
page: int,
|
||||||
|
per_page: int,
|
||||||
|
row_mapper: Callable[..., R],
|
||||||
|
) -> tuple[list[R], int, int]:
|
||||||
|
"""
|
||||||
|
Generic pagination helper for audit records that need user email.
|
||||||
|
|
||||||
|
Returns: (records, total, total_pages)
|
||||||
|
"""
|
||||||
|
# Get total count
|
||||||
|
count_result = await db.execute(select(func.count(model.id)))
|
||||||
|
total = count_result.scalar() or 0
|
||||||
|
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
||||||
|
|
||||||
|
# Get paginated records with user email
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = (
|
||||||
|
select(model, User.email)
|
||||||
|
.join(User, model.user_id == User.id)
|
||||||
|
.order_by(desc(model.created_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(per_page)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
records: list[R] = [row_mapper(record, email) for record, email in rows]
|
||||||
|
return records, total, total_pages
|
||||||
|
|
||||||
|
|
||||||
|
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
|
||||||
|
return CounterRecordResponse(
|
||||||
|
id=record.id,
|
||||||
|
user_email=email,
|
||||||
|
value_before=record.value_before,
|
||||||
|
value_after=record.value_after,
|
||||||
|
created_at=record.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
|
||||||
|
return SumRecordResponse(
|
||||||
|
id=record.id,
|
||||||
|
user_email=email,
|
||||||
|
a=record.a,
|
||||||
|
b=record.b,
|
||||||
|
result=record.result,
|
||||||
|
created_at=record.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/counter", response_model=PaginatedCounterRecords)
|
||||||
|
async def get_counter_records(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
per_page: int = Query(10, ge=1, le=100),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||||
|
) -> PaginatedCounterRecords:
|
||||||
|
"""Get paginated counter action records."""
|
||||||
|
records, total, total_pages = await paginate_with_user_email(
|
||||||
|
db, CounterRecord, page, per_page, _map_counter_record
|
||||||
|
)
|
||||||
|
return PaginatedCounterRecords(
|
||||||
|
records=records,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
total_pages=total_pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/sum", response_model=PaginatedSumRecords)
|
||||||
|
async def get_sum_records(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
per_page: int = Query(10, ge=1, le=100),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||||
|
) -> PaginatedSumRecords:
|
||||||
|
"""Get paginated sum action records."""
|
||||||
|
records, total, total_pages = await paginate_with_user_email(
|
||||||
|
db, SumRecord, page, per_page, _map_sum_record
|
||||||
|
)
|
||||||
|
return PaginatedSumRecords(
|
||||||
|
records=records,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
total_pages=total_pages,
|
||||||
|
)
|
||||||
|
|
||||||
135
backend/routes/auth.py
Normal file
135
backend/routes/auth.py
Normal file
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""Authentication routes for register, login, logout, and current user."""
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import (
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
|
COOKIE_NAME,
|
||||||
|
get_password_hash,
|
||||||
|
get_user_by_email,
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
get_current_user,
|
||||||
|
build_user_response,
|
||||||
|
)
|
||||||
|
from database import get_db
|
||||||
|
from invite_utils import normalize_identifier
|
||||||
|
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus
|
||||||
|
from schemas import UserLogin, UserResponse, RegisterWithInvite
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
def set_auth_cookie(response: Response, token: str) -> None:
|
||||||
|
"""Set the authentication cookie on the response."""
|
||||||
|
response.set_cookie(
|
||||||
|
key=COOKIE_NAME,
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False, # Set to True in production with HTTPS
|
||||||
|
samesite="lax",
|
||||||
|
max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_default_role(db: AsyncSession) -> Role | None:
|
||||||
|
"""Get the default 'regular' role for new users."""
|
||||||
|
result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse)
|
||||||
|
async def register(
|
||||||
|
user_data: RegisterWithInvite,
|
||||||
|
response: Response,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Register a new user using an invite code."""
|
||||||
|
# Validate invite
|
||||||
|
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
||||||
|
result = await db.execute(
|
||||||
|
select(Invite).where(Invite.identifier == normalized_identifier)
|
||||||
|
)
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||||
|
if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid invite code",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check email not already taken
|
||||||
|
existing_user = await get_user_by_email(db, user_data.email)
|
||||||
|
if existing_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Email already registered",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create user with godfather
|
||||||
|
user = User(
|
||||||
|
email=user_data.email,
|
||||||
|
hashed_password=get_password_hash(user_data.password),
|
||||||
|
godfather_id=invite.godfather_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assign default role
|
||||||
|
default_role = await get_default_role(db)
|
||||||
|
if default_role:
|
||||||
|
user.roles.append(default_role)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
await db.flush() # Get user ID
|
||||||
|
|
||||||
|
# Mark invite as spent
|
||||||
|
invite.status = InviteStatus.SPENT
|
||||||
|
invite.used_by_id = user.id
|
||||||
|
invite.spent_at = datetime.now(UTC)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
access_token = create_access_token(data={"sub": str(user.id)})
|
||||||
|
set_auth_cookie(response, access_token)
|
||||||
|
return await build_user_response(user, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=UserResponse)
|
||||||
|
async def login(
|
||||||
|
user_data: UserLogin,
|
||||||
|
response: Response,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Authenticate a user and return their info with an auth cookie."""
|
||||||
|
user = await authenticate_user(db, user_data.email, user_data.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect email or password",
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token = create_access_token(data={"sub": str(user.id)})
|
||||||
|
set_auth_cookie(response, access_token)
|
||||||
|
return await build_user_response(user, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
async def logout(response: Response) -> dict[str, bool]:
|
||||||
|
"""Log out the current user by clearing their auth cookie."""
|
||||||
|
response.delete_cookie(key=COOKIE_NAME)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_me(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> UserResponse:
|
||||||
|
"""Get the current authenticated user's info."""
|
||||||
|
return await build_user_response(current_user, db)
|
||||||
|
|
||||||
54
backend/routes/counter.py
Normal file
54
backend/routes/counter.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""Counter routes."""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import require_permission
|
||||||
|
from database import get_db
|
||||||
|
from models import Counter, User, CounterRecord, Permission
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
||||||
|
|
||||||
|
|
||||||
|
async def get_or_create_counter(db: AsyncSession) -> Counter:
|
||||||
|
"""Get the singleton counter, creating it if it doesn't exist."""
|
||||||
|
result = await db.execute(select(Counter).where(Counter.id == 1))
|
||||||
|
counter = result.scalar_one_or_none()
|
||||||
|
if not counter:
|
||||||
|
counter = Counter(id=1, value=0)
|
||||||
|
db.add(counter)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(counter)
|
||||||
|
return counter
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def get_counter(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Get the current counter value."""
|
||||||
|
counter = await get_or_create_counter(db)
|
||||||
|
return {"value": counter.value}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/increment")
|
||||||
|
async def increment_counter(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Increment the counter and record the action."""
|
||||||
|
counter = await get_or_create_counter(db)
|
||||||
|
value_before = counter.value
|
||||||
|
counter.value += 1
|
||||||
|
|
||||||
|
record = CounterRecord(
|
||||||
|
user_id=current_user.id,
|
||||||
|
value_before=value_before,
|
||||||
|
value_after=counter.value,
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
return {"value": counter.value}
|
||||||
|
|
||||||
247
backend/routes/invites.py
Normal file
247
backend/routes/invites.py
Normal file
|
|
@ -0,0 +1,247 @@
|
||||||
|
"""Invite routes for public check, user invites, and admin management."""
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from sqlalchemy import select, func, desc
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import require_permission
|
||||||
|
from database import get_db
|
||||||
|
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
|
||||||
|
from models import User, Invite, InviteStatus, Permission
|
||||||
|
from schemas import (
|
||||||
|
InviteCheckResponse,
|
||||||
|
InviteCreate,
|
||||||
|
InviteResponse,
|
||||||
|
UserInviteResponse,
|
||||||
|
PaginatedInviteRecords,
|
||||||
|
AdminUserResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(tags=["invites"])
|
||||||
|
|
||||||
|
MAX_INVITE_COLLISION_RETRIES = 3
|
||||||
|
|
||||||
|
|
||||||
|
def build_invite_response(invite: Invite) -> InviteResponse:
|
||||||
|
"""Build an InviteResponse from an Invite with loaded relationships."""
|
||||||
|
return InviteResponse(
|
||||||
|
id=invite.id,
|
||||||
|
identifier=invite.identifier,
|
||||||
|
godfather_id=invite.godfather_id,
|
||||||
|
godfather_email=invite.godfather.email,
|
||||||
|
status=invite.status.value,
|
||||||
|
used_by_id=invite.used_by_id,
|
||||||
|
used_by_email=invite.used_by.email if invite.used_by else None,
|
||||||
|
created_at=invite.created_at,
|
||||||
|
spent_at=invite.spent_at,
|
||||||
|
revoked_at=invite.revoked_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Public Endpoints
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/api/invites/{identifier}/check", response_model=InviteCheckResponse)
|
||||||
|
async def check_invite(
|
||||||
|
identifier: str,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> InviteCheckResponse:
|
||||||
|
"""Check if an invite is valid and can be used for signup."""
|
||||||
|
normalized = normalize_identifier(identifier)
|
||||||
|
|
||||||
|
# Validate format before querying database
|
||||||
|
if not is_valid_identifier_format(normalized):
|
||||||
|
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Invite).where(Invite.identifier == normalized)
|
||||||
|
)
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||||
|
if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED):
|
||||||
|
return InviteCheckResponse(valid=False, error="Invite not found")
|
||||||
|
|
||||||
|
return InviteCheckResponse(valid=True, status=invite.status.value)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# User Endpoints (requires VIEW_OWN_INVITES permission)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/api/invites", response_model=list[UserInviteResponse])
|
||||||
|
async def get_my_invites(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_permission(Permission.VIEW_OWN_INVITES)),
|
||||||
|
) -> list[UserInviteResponse]:
|
||||||
|
"""Get all invites owned by the current user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Invite)
|
||||||
|
.where(Invite.godfather_id == current_user.id)
|
||||||
|
.order_by(desc(Invite.created_at))
|
||||||
|
)
|
||||||
|
invites = result.scalars().all()
|
||||||
|
|
||||||
|
# Use preloaded used_by relationship (selectin loading)
|
||||||
|
return [
|
||||||
|
UserInviteResponse(
|
||||||
|
id=invite.id,
|
||||||
|
identifier=invite.identifier,
|
||||||
|
status=invite.status.value,
|
||||||
|
used_by_email=invite.used_by.email if invite.used_by else None,
|
||||||
|
created_at=invite.created_at,
|
||||||
|
spent_at=invite.spent_at,
|
||||||
|
)
|
||||||
|
for invite in invites
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Admin Endpoints (requires MANAGE_INVITES permission)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@router.get("/api/admin/users", response_model=list[AdminUserResponse])
|
||||||
|
async def list_users_for_admin(
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||||
|
) -> list[AdminUserResponse]:
|
||||||
|
"""List all users for admin dropdowns (invite creation, etc.)."""
|
||||||
|
result = await db.execute(select(User.id, User.email).order_by(User.email))
|
||||||
|
users = result.all()
|
||||||
|
return [AdminUserResponse(id=u.id, email=u.email) for u in users]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/admin/invites", response_model=InviteResponse)
|
||||||
|
async def create_invite(
|
||||||
|
data: InviteCreate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||||
|
) -> InviteResponse:
|
||||||
|
"""Create a new invite for a specified godfather user."""
|
||||||
|
# Validate godfather exists
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.id).where(User.id == data.godfather_id)
|
||||||
|
)
|
||||||
|
godfather_id = result.scalar_one_or_none()
|
||||||
|
if not godfather_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Godfather user not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to create invite with retry on collision
|
||||||
|
invite: Invite | None = None
|
||||||
|
for attempt in range(MAX_INVITE_COLLISION_RETRIES):
|
||||||
|
identifier = generate_invite_identifier()
|
||||||
|
invite = Invite(
|
||||||
|
identifier=identifier,
|
||||||
|
godfather_id=godfather_id,
|
||||||
|
status=InviteStatus.READY,
|
||||||
|
)
|
||||||
|
db.add(invite)
|
||||||
|
try:
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(invite, ["godfather"])
|
||||||
|
break
|
||||||
|
except IntegrityError:
|
||||||
|
await db.rollback()
|
||||||
|
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to generate unique invite code. Please try again.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if invite is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="Failed to create invite",
|
||||||
|
)
|
||||||
|
return build_invite_response(invite)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/admin/invites", response_model=PaginatedInviteRecords)
|
||||||
|
async def list_all_invites(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
per_page: int = Query(10, ge=1, le=100),
|
||||||
|
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
|
||||||
|
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||||
|
) -> PaginatedInviteRecords:
|
||||||
|
"""List all invites with optional filtering and pagination."""
|
||||||
|
# Build query
|
||||||
|
query = select(Invite)
|
||||||
|
count_query = select(func.count(Invite.id))
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if status_filter:
|
||||||
|
try:
|
||||||
|
status_enum = InviteStatus(status_filter)
|
||||||
|
query = query.where(Invite.status == status_enum)
|
||||||
|
count_query = count_query.where(Invite.status == status_enum)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
|
||||||
|
)
|
||||||
|
|
||||||
|
if godfather_id:
|
||||||
|
query = query.where(Invite.godfather_id == godfather_id)
|
||||||
|
count_query = count_query.where(Invite.godfather_id == godfather_id)
|
||||||
|
|
||||||
|
# Get total count
|
||||||
|
count_result = await db.execute(count_query)
|
||||||
|
total = count_result.scalar() or 0
|
||||||
|
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
||||||
|
|
||||||
|
# Get paginated invites (relationships loaded via selectin)
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
|
||||||
|
result = await db.execute(query)
|
||||||
|
invites = result.scalars().all()
|
||||||
|
|
||||||
|
# Build responses using preloaded relationships
|
||||||
|
records = [build_invite_response(invite) for invite in invites]
|
||||||
|
|
||||||
|
return PaginatedInviteRecords(
|
||||||
|
records=records,
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
per_page=per_page,
|
||||||
|
total_pages=total_pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/admin/invites/{invite_id}/revoke", response_model=InviteResponse)
|
||||||
|
async def revoke_invite(
|
||||||
|
invite_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||||
|
) -> InviteResponse:
|
||||||
|
"""Revoke an invite. Only READY invites can be revoked."""
|
||||||
|
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||||
|
invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not invite:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Invite not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if invite.status != InviteStatus.READY:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
|
||||||
|
)
|
||||||
|
|
||||||
|
invite.status = InviteStatus.REVOKED
|
||||||
|
invite.revoked_at = datetime.now(UTC)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(invite)
|
||||||
|
|
||||||
|
return build_invite_response(invite)
|
||||||
|
|
||||||
94
backend/routes/profile.py
Normal file
94
backend/routes/profile.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Profile routes for user contact details."""
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import get_current_user
|
||||||
|
from database import get_db
|
||||||
|
from models import User, ROLE_REGULAR
|
||||||
|
from schemas import ProfileResponse, ProfileUpdate
|
||||||
|
from validation import validate_profile_fields
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/profile", tags=["profile"])
|
||||||
|
|
||||||
|
|
||||||
|
async def require_regular_user(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
) -> User:
|
||||||
|
"""Dependency that requires the user to have the 'regular' role."""
|
||||||
|
if ROLE_REGULAR not in current_user.role_names:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Profile access is only available to regular users",
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str | None:
|
||||||
|
"""Get the email of a godfather user by ID."""
|
||||||
|
if not godfather_id:
|
||||||
|
return None
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.email).where(User.id == godfather_id)
|
||||||
|
)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=ProfileResponse)
|
||||||
|
async def get_profile(
|
||||||
|
current_user: User = Depends(require_regular_user),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
) -> ProfileResponse:
|
||||||
|
"""Get the current user's profile (contact details and godfather)."""
|
||||||
|
godfather_email = await get_godfather_email(db, current_user.godfather_id)
|
||||||
|
|
||||||
|
return ProfileResponse(
|
||||||
|
contact_email=current_user.contact_email,
|
||||||
|
telegram=current_user.telegram,
|
||||||
|
signal=current_user.signal,
|
||||||
|
nostr_npub=current_user.nostr_npub,
|
||||||
|
godfather_email=godfather_email,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("", response_model=ProfileResponse)
|
||||||
|
async def update_profile(
|
||||||
|
data: ProfileUpdate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_regular_user),
|
||||||
|
) -> ProfileResponse:
|
||||||
|
"""Update the current user's profile (contact details)."""
|
||||||
|
# Validate all fields
|
||||||
|
errors = validate_profile_fields(
|
||||||
|
contact_email=data.contact_email,
|
||||||
|
telegram=data.telegram,
|
||||||
|
signal=data.signal,
|
||||||
|
nostr_npub=data.nostr_npub,
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=422,
|
||||||
|
detail={"field_errors": errors},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
current_user.contact_email = data.contact_email
|
||||||
|
current_user.telegram = data.telegram
|
||||||
|
current_user.signal = data.signal
|
||||||
|
current_user.nostr_npub = data.nostr_npub
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(current_user)
|
||||||
|
|
||||||
|
godfather_email = await get_godfather_email(db, current_user.godfather_id)
|
||||||
|
|
||||||
|
return ProfileResponse(
|
||||||
|
contact_email=current_user.contact_email,
|
||||||
|
telegram=current_user.telegram,
|
||||||
|
signal=current_user.signal,
|
||||||
|
nostr_npub=current_user.nostr_npub,
|
||||||
|
godfather_email=godfather_email,
|
||||||
|
)
|
||||||
|
|
||||||
31
backend/routes/sum.py
Normal file
31
backend/routes/sum.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Sum calculation routes."""
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from auth import require_permission
|
||||||
|
from database import get_db
|
||||||
|
from models import User, SumRecord, Permission
|
||||||
|
from schemas import SumRequest, SumResponse
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=SumResponse)
|
||||||
|
async def calculate_sum(
|
||||||
|
data: SumRequest,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(require_permission(Permission.USE_SUM)),
|
||||||
|
) -> SumResponse:
|
||||||
|
"""Calculate the sum of two numbers and record it."""
|
||||||
|
result = data.a + data.b
|
||||||
|
record = SumRecord(
|
||||||
|
user_id=current_user.id,
|
||||||
|
a=data.a,
|
||||||
|
b=data.b,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
return SumResponse(a=data.a, b=data.b, result=result)
|
||||||
|
|
||||||
185
backend/schemas.py
Normal file
185
backend/schemas.py
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
"""Pydantic schemas for API request/response models."""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Auth Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class UserCredentials(BaseModel):
|
||||||
|
"""Base model for user email/password."""
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
UserCreate = UserCredentials
|
||||||
|
UserLogin = UserCredentials
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
"""Response model for authenticated user info."""
|
||||||
|
id: int
|
||||||
|
email: str
|
||||||
|
roles: list[str]
|
||||||
|
permissions: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""Response model for token-based auth (unused but kept for API completeness)."""
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
user: UserResponse
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterWithInvite(BaseModel):
|
||||||
|
"""Request model for registration with invite."""
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
invite_identifier: str
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Counter Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class CounterValue(BaseModel):
|
||||||
|
"""Response model for counter value."""
|
||||||
|
value: int
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Sum Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class SumRequest(BaseModel):
|
||||||
|
"""Request model for sum calculation."""
|
||||||
|
a: float
|
||||||
|
b: float
|
||||||
|
|
||||||
|
|
||||||
|
class SumResponse(BaseModel):
|
||||||
|
"""Response model for sum calculation."""
|
||||||
|
a: float
|
||||||
|
b: float
|
||||||
|
result: float
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Audit Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class CounterRecordResponse(BaseModel):
|
||||||
|
"""Response model for a counter audit record."""
|
||||||
|
id: int
|
||||||
|
user_email: str
|
||||||
|
value_before: int
|
||||||
|
value_after: int
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class SumRecordResponse(BaseModel):
|
||||||
|
"""Response model for a sum audit record."""
|
||||||
|
id: int
|
||||||
|
user_email: str
|
||||||
|
a: float
|
||||||
|
b: float
|
||||||
|
result: float
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Pagination (Generic)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
RecordT = TypeVar("RecordT", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
||||||
|
"""Generic paginated response wrapper."""
|
||||||
|
records: list[RecordT]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
per_page: int
|
||||||
|
total_pages: int
|
||||||
|
|
||||||
|
|
||||||
|
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
|
||||||
|
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Profile Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class ProfileResponse(BaseModel):
|
||||||
|
"""Response model for profile data."""
|
||||||
|
contact_email: str | None
|
||||||
|
telegram: str | None
|
||||||
|
signal: str | None
|
||||||
|
nostr_npub: str | None
|
||||||
|
godfather_email: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileUpdate(BaseModel):
|
||||||
|
"""Request model for updating profile."""
|
||||||
|
contact_email: str | None = None
|
||||||
|
telegram: str | None = None
|
||||||
|
signal: str | None = None
|
||||||
|
nostr_npub: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Invite Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class InviteCheckResponse(BaseModel):
|
||||||
|
"""Response for invite check endpoint."""
|
||||||
|
valid: bool
|
||||||
|
status: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InviteCreate(BaseModel):
|
||||||
|
"""Request model for creating an invite."""
|
||||||
|
godfather_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class InviteResponse(BaseModel):
|
||||||
|
"""Response model for invite data (admin view)."""
|
||||||
|
id: int
|
||||||
|
identifier: str
|
||||||
|
godfather_id: int
|
||||||
|
godfather_email: str
|
||||||
|
status: str
|
||||||
|
used_by_id: int | None
|
||||||
|
used_by_email: str | None
|
||||||
|
created_at: datetime
|
||||||
|
spent_at: datetime | None
|
||||||
|
revoked_at: datetime | None
|
||||||
|
|
||||||
|
|
||||||
|
class UserInviteResponse(BaseModel):
|
||||||
|
"""Response model for a user's invite (simpler than admin view)."""
|
||||||
|
id: int
|
||||||
|
identifier: str
|
||||||
|
status: str
|
||||||
|
used_by_email: str | None
|
||||||
|
created_at: datetime
|
||||||
|
spent_at: datetime | None
|
||||||
|
|
||||||
|
|
||||||
|
PaginatedInviteRecords = PaginatedResponse[InviteResponse]
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Admin Schemas
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
class AdminUserResponse(BaseModel):
|
||||||
|
"""Minimal user info for admin dropdowns."""
|
||||||
|
id: int
|
||||||
|
email: str
|
||||||
|
|
||||||
|
|
@ -426,7 +426,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
||||||
return identifier1 # Will collide
|
return identifier1 # Will collide
|
||||||
return f"unique-word-{call_count:02d}" # Won't collide
|
return f"unique-word-{call_count:02d}" # Won't collide
|
||||||
|
|
||||||
with patch("main.generate_invite_identifier", side_effect=mock_generator):
|
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator):
|
||||||
response2 = await client.post(
|
response2 = await client.post(
|
||||||
"/api/admin/invites",
|
"/api/admin/invites",
|
||||||
json={"godfather_id": godfather.id},
|
json={"godfather_id": godfather.id},
|
||||||
|
|
|
||||||
|
|
@ -97,15 +97,20 @@ test.describe("Admin Invites Page", () => {
|
||||||
await godfatherSelect.selectOption({ label: REGULAR_USER_EMAIL });
|
await godfatherSelect.selectOption({ label: REGULAR_USER_EMAIL });
|
||||||
await page.click('button:has-text("Create Invite")');
|
await page.click('button:has-text("Create Invite")');
|
||||||
|
|
||||||
// Wait for the invite to appear
|
// Wait for the new invite to appear and capture its code
|
||||||
await expect(page.locator("table")).toContainText("ready");
|
// The new invite should be the first row with godfather = REGULAR_USER_EMAIL and status = ready
|
||||||
|
const newInviteRow = page.locator("tr").filter({ hasText: REGULAR_USER_EMAIL }).filter({ hasText: "ready" }).first();
|
||||||
|
await expect(newInviteRow).toBeVisible();
|
||||||
|
|
||||||
// Click revoke on the first ready invite
|
// Get the invite code from this row (first cell)
|
||||||
const revokeButton = page.locator('button:has-text("Revoke")').first();
|
const inviteCode = await newInviteRow.locator("td").first().textContent();
|
||||||
await revokeButton.click();
|
|
||||||
|
|
||||||
// Verify the status changed to revoked
|
// Click revoke on this specific row
|
||||||
await expect(page.locator("table")).toContainText("revoked");
|
await newInviteRow.locator('button:has-text("Revoke")').click();
|
||||||
|
|
||||||
|
// Verify this specific invite now shows "revoked"
|
||||||
|
const revokedRow = page.locator("tr").filter({ hasText: inviteCode! });
|
||||||
|
await expect(revokedRow).toContainText("revoked");
|
||||||
});
|
});
|
||||||
|
|
||||||
test("status filter works", async ({ page }) => {
|
test("status filter works", async ({ page }) => {
|
||||||
|
|
|
||||||
|
|
@ -22,16 +22,16 @@ cd ..
|
||||||
|
|
||||||
# Start backend (SECRET_KEY should be set via .envrc or environment)
|
# Start backend (SECRET_KEY should be set via .envrc or environment)
|
||||||
cd backend
|
cd backend
|
||||||
uv run uvicorn main:app --port 8000 &
|
uv run uvicorn main:app --port 8000 --log-level warning &
|
||||||
PID=$!
|
PID=$!
|
||||||
cd ..
|
cd ..
|
||||||
|
|
||||||
# Wait for backend
|
# Wait for backend
|
||||||
sleep 2
|
sleep 2
|
||||||
|
|
||||||
# Run tests
|
# Run tests (suppress Node.js color warnings)
|
||||||
cd frontend
|
cd frontend
|
||||||
npm run test:e2e
|
NODE_NO_WARNINGS=1 npm run test:e2e
|
||||||
EXIT_CODE=$?
|
EXIT_CODE=$?
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue