first round of review

This commit is contained in:
counterweight 2025-12-20 11:43:32 +01:00
parent 870804e7b9
commit 23049da55a
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
15 changed files with 325 additions and 182 deletions

View file

@ -6,6 +6,7 @@ from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr
from sqlalchemy import select, func, desc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from auth import (
@ -25,7 +26,7 @@ from auth import (
from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR, Invite, InviteStatus
from validation import validate_profile_fields
from invite_utils import generate_invite_identifier, normalize_identifier
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
R = TypeVar("R", bound=BaseModel)
@ -115,6 +116,10 @@ async def check_invite(
"""Check if an invite is valid and can be used for signup."""
normalized = normalize_identifier(identifier)
# Validate format before querying database
if not is_valid_identifier_format(normalized):
return InviteCheckResponse(valid=False, error="Invalid invite code format")
result = await db.execute(
select(Invite).where(Invite.identifier == normalized)
)
@ -441,18 +446,23 @@ async def require_regular_user(
return current_user
async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str | None:
"""Get the email of a godfather user by ID."""
if not godfather_id:
return None
result = await db.execute(
select(User.email).where(User.id == godfather_id)
)
return result.scalar_one_or_none()
@app.get("/api/profile", response_model=ProfileResponse)
async def get_profile(
current_user: User = Depends(require_regular_user),
db: AsyncSession = Depends(get_db),
):
"""Get the current user's profile (contact details and godfather)."""
godfather_email = None
if current_user.godfather_id:
result = await db.execute(
select(User.email).where(User.id == current_user.godfather_id)
)
godfather_email = result.scalar_one_or_none()
godfather_email = await get_godfather_email(db, current_user.godfather_id)
return ProfileResponse(
contact_email=current_user.contact_email,
@ -493,13 +503,7 @@ async def update_profile(
await db.commit()
await db.refresh(current_user)
# Get godfather email if set
godfather_email = None
if current_user.godfather_id:
gf_result = await db.execute(
select(User.email).where(User.id == current_user.godfather_id)
)
godfather_email = gf_result.scalar_one_or_none()
godfather_email = await get_godfather_email(db, current_user.godfather_id)
return ProfileResponse(
contact_email=current_user.contact_email,
@ -530,6 +534,9 @@ class InviteResponse(BaseModel):
revoked_at: datetime | None
MAX_INVITE_COLLISION_RETRIES = 3
@app.post("/api/admin/invites", response_model=InviteResponse)
async def create_invite(
data: InviteCreate,
@ -537,33 +544,46 @@ async def create_invite(
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
):
"""Create a new invite for a specified godfather user."""
# Validate godfather exists
result = await db.execute(select(User).where(User.id == data.godfather_id))
godfather = result.scalar_one_or_none()
if not godfather:
# Validate godfather exists and get their info
result = await db.execute(
select(User.id, User.email).where(User.id == data.godfather_id)
)
godfather_row = result.one_or_none()
if not godfather_row:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Godfather user not found",
)
godfather_id, godfather_email = godfather_row
# Generate unique identifier
identifier = generate_invite_identifier()
# Create invite
invite = Invite(
identifier=identifier,
godfather_id=godfather.id,
status=InviteStatus.READY,
)
db.add(invite)
await db.commit()
await db.refresh(invite)
# Try to create invite with retry on collision
invite: Invite | None = None
for attempt in range(MAX_INVITE_COLLISION_RETRIES):
identifier = generate_invite_identifier()
invite = Invite(
identifier=identifier,
godfather_id=godfather_id,
status=InviteStatus.READY,
)
db.add(invite)
try:
await db.commit()
await db.refresh(invite)
break
except IntegrityError:
await db.rollback()
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate unique invite code. Please try again.",
)
assert invite is not None # We either succeeded or raised an exception above
return InviteResponse(
id=invite.id,
identifier=invite.identifier,
godfather_id=invite.godfather_id,
godfather_email=godfather.email,
godfather_email=godfather_email,
status=invite.status.value,
used_by_id=invite.used_by_id,
used_by_email=None,
@ -596,26 +616,18 @@ async def get_my_invites(
)
invites = result.scalars().all()
responses = []
for invite in invites:
used_by_email = None
if invite.used_by_id:
# Fetch the user who used this invite
user_result = await db.execute(
select(User.email).where(User.id == invite.used_by_id)
)
used_by_email = user_result.scalar_one_or_none()
responses.append(UserInviteResponse(
# Use preloaded used_by relationship (selectin loading)
return [
UserInviteResponse(
id=invite.id,
identifier=invite.identifier,
status=invite.status.value,
used_by_email=used_by_email,
used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at,
spent_at=invite.spent_at,
))
return responses
)
for invite in invites
]
# Admin Invite Management
@ -674,37 +686,23 @@ async def list_all_invites(
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated invites
# Get paginated invites (relationships loaded via selectin)
offset = (page - 1) * per_page
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
result = await db.execute(query)
invites = result.scalars().all()
# Build responses with user emails
# Build responses using preloaded relationships
records = []
for invite in invites:
# Get godfather email
gf_result = await db.execute(
select(User.email).where(User.id == invite.godfather_id)
)
godfather_email = gf_result.scalar_one()
# Get used_by email if applicable
used_by_email = None
if invite.used_by_id:
ub_result = await db.execute(
select(User.email).where(User.id == invite.used_by_id)
)
used_by_email = ub_result.scalar_one_or_none()
records.append(InviteResponse(
id=invite.id,
identifier=invite.identifier,
godfather_id=invite.godfather_id,
godfather_email=godfather_email,
godfather_email=invite.godfather.email,
status=invite.status.value,
used_by_id=invite.used_by_id,
used_by_email=used_by_email,
used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at,
spent_at=invite.spent_at,
revoked_at=invite.revoked_at,
@ -746,20 +744,15 @@ async def revoke_invite(
await db.commit()
await db.refresh(invite)
# Get godfather email
gf_result = await db.execute(
select(User.email).where(User.id == invite.godfather_id)
)
godfather_email = gf_result.scalar_one()
# Use preloaded relationships (selectin loading)
return InviteResponse(
id=invite.id,
identifier=invite.identifier,
godfather_id=invite.godfather_id,
godfather_email=godfather_email,
godfather_email=invite.godfather.email,
status=invite.status.value,
used_by_id=invite.used_by_id,
used_by_email=None,
used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at,
spent_at=invite.spent_at,
revoked_at=invite.revoked_at,

View file

@ -3,7 +3,7 @@ import uuid
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Invite, InviteStatus, ROLE_ADMIN
from models import User, Invite, InviteStatus
from invite_utils import generate_invite_identifier
@ -12,36 +12,32 @@ def unique_email(prefix: str = "test") -> str:
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str:
"""
Create an invite that can be used for registration.
Returns the invite identifier.
Create an invite for an existing godfather user.
Args:
db: Database session
godfather_id: ID of the existing user who will be the godfather
Returns:
The invite identifier.
Raises:
ValueError: If the godfather user doesn't exist.
"""
# Find godfather
result = await db.execute(select(User).where(User.email == godfather_email))
# Verify godfather exists
result = await db.execute(select(User).where(User.id == godfather_id))
godfather = result.scalar_one_or_none()
if not godfather:
# Create a godfather user (admin can create invites)
from auth import get_password_hash
from models import Role
result = await db.execute(select(Role).where(Role.name == ROLE_ADMIN))
admin_role = result.scalar_one_or_none()
godfather = User(
email=godfather_email,
hashed_password=get_password_hash("password123"),
roles=[admin_role] if admin_role else [],
)
db.add(godfather)
await db.flush()
raise ValueError(f"Godfather user with ID {godfather_id} not found")
# Create invite
identifier = generate_invite_identifier()
invite = Invite(
identifier=identifier,
godfather_id=godfather.id,
godfather_id=godfather_id,
status=InviteStatus.READY,
)
db.add(invite)
@ -49,3 +45,29 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
return identifier
# Backwards-compatible alias that gets a user by email
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
"""
Create an invite for an existing godfather user (looked up by email).
The godfather must already exist in the database.
Args:
db: Database session
godfather_email: Email of the existing user who will be the godfather
Returns:
The invite identifier.
Raises:
ValueError: If the godfather user doesn't exist.
"""
result = await db.execute(select(User).where(User.email == godfather_email))
godfather = result.scalar_one_or_none()
if not godfather:
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles().")
return await create_invite_for_godfather(db, godfather.id)

View file

@ -6,7 +6,9 @@ users will create invites first via the helper function.
import pytest
from auth import COOKIE_NAME
from tests.helpers import unique_email, create_invite_for_registration
from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
# Registration tests (with invite)
@ -15,9 +17,10 @@ async def test_register_success(client_factory):
"""Can register with valid invite code."""
email = unique_email("register")
# Create invite
# Create godfather user and invite
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("godfather"))
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
"/api/auth/register",
@ -44,10 +47,11 @@ async def test_register_duplicate_email(client_factory):
"""Cannot register with already-used email."""
email = unique_email("duplicate")
# Create two invites
# Create godfather and two invites
async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
# First registration
await client_factory.post(
@ -76,7 +80,8 @@ async def test_register_duplicate_email(client_factory):
async def test_register_invalid_email(client_factory):
"""Cannot register with invalid email format."""
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
"/api/auth/register",
@ -133,7 +138,8 @@ async def test_login_success(client_factory):
email = unique_email("login")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
"/api/auth/register",
@ -161,7 +167,8 @@ async def test_login_wrong_password(client_factory):
email = unique_email("wrongpass")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
"/api/auth/register",
@ -214,7 +221,8 @@ async def test_get_me_success(client_factory):
email = unique_email("me")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(
"/api/auth/register",
@ -269,7 +277,8 @@ async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(
"/api/auth/register",
@ -294,7 +303,8 @@ async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
"/api/auth/register",
@ -325,8 +335,9 @@ async def test_multiple_users_isolated(client_factory):
email2 = unique_email("user2")
async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
resp1 = await client_factory.post(
"/api/auth/register",
@ -366,7 +377,8 @@ async def test_password_is_hashed(client_factory):
email = unique_email("hashtest")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
"/api/auth/register",
@ -389,7 +401,8 @@ async def test_case_sensitive_password(client_factory):
email = unique_email("casetest")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
"/api/auth/register",
@ -413,7 +426,8 @@ async def test_logout_success(client_factory):
email = unique_email("logout")
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(
"/api/auth/register",

View file

@ -5,7 +5,9 @@ Note: Registration now requires an invite code.
import pytest
from auth import COOKIE_NAME
from tests.helpers import unique_email, create_invite_for_registration
from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
# Protected endpoint tests - without auth
@ -39,7 +41,8 @@ async def test_increment_counter_invalid_cookie(client_factory):
@pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
@ -61,7 +64,8 @@ async def test_get_counter_authenticated(client_factory):
@pytest.mark.asyncio
async def test_increment_counter(client_factory):
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
@ -87,7 +91,8 @@ async def test_increment_counter(client_factory):
@pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory):
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
@ -115,7 +120,8 @@ async def test_increment_counter_multiple(client_factory):
@pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory):
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
@ -141,10 +147,11 @@ async def test_get_counter_after_increment(client_factory):
# Counter is shared between users
@pytest.mark.asyncio
async def test_counter_shared_between_users(client_factory):
# Create invites for two users
# Create godfather and invites for two users
async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
# Create first user
reg1 = await client_factory.post(

View file

@ -397,6 +397,47 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
assert invite.status == InviteStatus.READY
@pytest.mark.asyncio
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
"""Create invite retries with new identifier on collision."""
from unittest.mock import patch
async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db:
result = await db.execute(
select(User).where(User.email == regular_user["email"])
)
godfather = result.scalar_one()
# Create first invite normally
response1 = await client.post(
"/api/admin/invites",
json={"godfather_id": godfather.id},
)
assert response1.status_code == 200
identifier1 = response1.json()["identifier"]
# Mock generator to first return the same identifier (collision), then a new one
call_count = 0
def mock_generator():
nonlocal call_count
call_count += 1
if call_count == 1:
return identifier1 # Will collide
return f"unique-word-{call_count:02d}" # Won't collide
with patch("main.generate_invite_identifier", side_effect=mock_generator):
response2 = await client.post(
"/api/admin/invites",
json={"godfather_id": godfather.id},
)
assert response2.status_code == 200
# Should have retried and gotten a new identifier
assert response2.json()["identifier"] != identifier1
assert call_count >= 2 # At least one retry
# ============================================================================
# Invite Check API Tests (Phase 3)
# ============================================================================
@ -441,6 +482,32 @@ async def test_check_invite_not_found(client_factory):
assert "not found" in data["error"].lower()
@pytest.mark.asyncio
async def test_check_invite_invalid_format(client_factory):
"""Check endpoint returns error for invalid format without querying DB."""
async with client_factory.create() as client:
# Missing number part
response = await client.get("/api/invites/word-word/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
# Single digit number
response = await client.get("/api/invites/word-word-1/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
# Too many parts
response = await client.get("/api/invites/word-word-word-00/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
@pytest.mark.asyncio
async def test_check_invite_case_insensitive(client_factory, admin_user, regular_user):
"""Check endpoint handles case-insensitive identifiers."""

View file

@ -305,10 +305,13 @@ class TestSecurityBypassAttempts:
Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default.
"""
from tests.helpers import unique_email, create_invite_for_registration
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from models import ROLE_REGULAR
async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf"))
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
"/api/auth/register",