first round of review

This commit is contained in:
counterweight 2025-12-20 11:43:32 +01:00
parent 870804e7b9
commit 23049da55a
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
15 changed files with 325 additions and 182 deletions

View file

@ -23,6 +23,9 @@ db:
db-stop: db-stop:
docker compose down docker compose down
db-clean:
docker compose down -v
db-ready: db-ready:
@docker compose up -d db @docker compose up -d db
@echo "Waiting for PostgreSQL to be ready..." @echo "Waiting for PostgreSQL to be ready..."
@ -42,13 +45,13 @@ dev:
cd frontend && npm run dev & \ cd frontend && npm run dev & \
wait wait
test-backend: db-ready test-backend: db-clean db-ready
cd backend && uv run pytest -v cd backend && uv run pytest -v
test-frontend: test-frontend:
cd frontend && npm run test cd frontend && npm run test
test-e2e: test-e2e: db-clean db-ready
./scripts/e2e.sh ./scripts/e2e.sh
test: test-backend test-frontend test-e2e test: test-backend test-frontend test-e2e

View file

@ -6,6 +6,7 @@ from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from sqlalchemy import select, func, desc from sqlalchemy import select, func, desc
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import ( from auth import (
@ -25,7 +26,7 @@ from auth import (
from database import engine, get_db, Base from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR, Invite, InviteStatus from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR, Invite, InviteStatus
from validation import validate_profile_fields from validation import validate_profile_fields
from invite_utils import generate_invite_identifier, normalize_identifier from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
R = TypeVar("R", bound=BaseModel) R = TypeVar("R", bound=BaseModel)
@ -115,6 +116,10 @@ async def check_invite(
"""Check if an invite is valid and can be used for signup.""" """Check if an invite is valid and can be used for signup."""
normalized = normalize_identifier(identifier) normalized = normalize_identifier(identifier)
# Validate format before querying database
if not is_valid_identifier_format(normalized):
return InviteCheckResponse(valid=False, error="Invalid invite code format")
result = await db.execute( result = await db.execute(
select(Invite).where(Invite.identifier == normalized) select(Invite).where(Invite.identifier == normalized)
) )
@ -441,18 +446,23 @@ async def require_regular_user(
return current_user return current_user
async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str | None:
"""Get the email of a godfather user by ID."""
if not godfather_id:
return None
result = await db.execute(
select(User.email).where(User.id == godfather_id)
)
return result.scalar_one_or_none()
@app.get("/api/profile", response_model=ProfileResponse) @app.get("/api/profile", response_model=ProfileResponse)
async def get_profile( async def get_profile(
current_user: User = Depends(require_regular_user), current_user: User = Depends(require_regular_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
"""Get the current user's profile (contact details and godfather).""" """Get the current user's profile (contact details and godfather)."""
godfather_email = None godfather_email = await get_godfather_email(db, current_user.godfather_id)
if current_user.godfather_id:
result = await db.execute(
select(User.email).where(User.id == current_user.godfather_id)
)
godfather_email = result.scalar_one_or_none()
return ProfileResponse( return ProfileResponse(
contact_email=current_user.contact_email, contact_email=current_user.contact_email,
@ -493,13 +503,7 @@ async def update_profile(
await db.commit() await db.commit()
await db.refresh(current_user) await db.refresh(current_user)
# Get godfather email if set godfather_email = await get_godfather_email(db, current_user.godfather_id)
godfather_email = None
if current_user.godfather_id:
gf_result = await db.execute(
select(User.email).where(User.id == current_user.godfather_id)
)
godfather_email = gf_result.scalar_one_or_none()
return ProfileResponse( return ProfileResponse(
contact_email=current_user.contact_email, contact_email=current_user.contact_email,
@ -530,6 +534,9 @@ class InviteResponse(BaseModel):
revoked_at: datetime | None revoked_at: datetime | None
MAX_INVITE_COLLISION_RETRIES = 3
@app.post("/api/admin/invites", response_model=InviteResponse) @app.post("/api/admin/invites", response_model=InviteResponse)
async def create_invite( async def create_invite(
data: InviteCreate, data: InviteCreate,
@ -537,33 +544,46 @@ async def create_invite(
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
): ):
"""Create a new invite for a specified godfather user.""" """Create a new invite for a specified godfather user."""
# Validate godfather exists # Validate godfather exists and get their info
result = await db.execute(select(User).where(User.id == data.godfather_id)) result = await db.execute(
godfather = result.scalar_one_or_none() select(User.id, User.email).where(User.id == data.godfather_id)
if not godfather: )
godfather_row = result.one_or_none()
if not godfather_row:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Godfather user not found", detail="Godfather user not found",
) )
godfather_id, godfather_email = godfather_row
# Generate unique identifier # Try to create invite with retry on collision
invite: Invite | None = None
for attempt in range(MAX_INVITE_COLLISION_RETRIES):
identifier = generate_invite_identifier() identifier = generate_invite_identifier()
# Create invite
invite = Invite( invite = Invite(
identifier=identifier, identifier=identifier,
godfather_id=godfather.id, godfather_id=godfather_id,
status=InviteStatus.READY, status=InviteStatus.READY,
) )
db.add(invite) db.add(invite)
try:
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)
break
except IntegrityError:
await db.rollback()
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate unique invite code. Please try again.",
)
assert invite is not None # We either succeeded or raised an exception above
return InviteResponse( return InviteResponse(
id=invite.id, id=invite.id,
identifier=invite.identifier, identifier=invite.identifier,
godfather_id=invite.godfather_id, godfather_id=invite.godfather_id,
godfather_email=godfather.email, godfather_email=godfather_email,
status=invite.status.value, status=invite.status.value,
used_by_id=invite.used_by_id, used_by_id=invite.used_by_id,
used_by_email=None, used_by_email=None,
@ -596,26 +616,18 @@ async def get_my_invites(
) )
invites = result.scalars().all() invites = result.scalars().all()
responses = [] # Use preloaded used_by relationship (selectin loading)
for invite in invites: return [
used_by_email = None UserInviteResponse(
if invite.used_by_id:
# Fetch the user who used this invite
user_result = await db.execute(
select(User.email).where(User.id == invite.used_by_id)
)
used_by_email = user_result.scalar_one_or_none()
responses.append(UserInviteResponse(
id=invite.id, id=invite.id,
identifier=invite.identifier, identifier=invite.identifier,
status=invite.status.value, status=invite.status.value,
used_by_email=used_by_email, used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at, created_at=invite.created_at,
spent_at=invite.spent_at, spent_at=invite.spent_at,
)) )
for invite in invites
return responses ]
# Admin Invite Management # Admin Invite Management
@ -674,37 +686,23 @@ async def list_all_invites(
total = count_result.scalar() or 0 total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1 total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated invites # Get paginated invites (relationships loaded via selectin)
offset = (page - 1) * per_page offset = (page - 1) * per_page
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page) query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
result = await db.execute(query) result = await db.execute(query)
invites = result.scalars().all() invites = result.scalars().all()
# Build responses with user emails # Build responses using preloaded relationships
records = [] records = []
for invite in invites: for invite in invites:
# Get godfather email
gf_result = await db.execute(
select(User.email).where(User.id == invite.godfather_id)
)
godfather_email = gf_result.scalar_one()
# Get used_by email if applicable
used_by_email = None
if invite.used_by_id:
ub_result = await db.execute(
select(User.email).where(User.id == invite.used_by_id)
)
used_by_email = ub_result.scalar_one_or_none()
records.append(InviteResponse( records.append(InviteResponse(
id=invite.id, id=invite.id,
identifier=invite.identifier, identifier=invite.identifier,
godfather_id=invite.godfather_id, godfather_id=invite.godfather_id,
godfather_email=godfather_email, godfather_email=invite.godfather.email,
status=invite.status.value, status=invite.status.value,
used_by_id=invite.used_by_id, used_by_id=invite.used_by_id,
used_by_email=used_by_email, used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at, created_at=invite.created_at,
spent_at=invite.spent_at, spent_at=invite.spent_at,
revoked_at=invite.revoked_at, revoked_at=invite.revoked_at,
@ -746,20 +744,15 @@ async def revoke_invite(
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)
# Get godfather email # Use preloaded relationships (selectin loading)
gf_result = await db.execute(
select(User.email).where(User.id == invite.godfather_id)
)
godfather_email = gf_result.scalar_one()
return InviteResponse( return InviteResponse(
id=invite.id, id=invite.id,
identifier=invite.identifier, identifier=invite.identifier,
godfather_id=invite.godfather_id, godfather_id=invite.godfather_id,
godfather_email=godfather_email, godfather_email=invite.godfather.email,
status=invite.status.value, status=invite.status.value,
used_by_id=invite.used_by_id, used_by_id=invite.used_by_id,
used_by_email=None, used_by_email=invite.used_by.email if invite.used_by else None,
created_at=invite.created_at, created_at=invite.created_at,
spent_at=invite.spent_at, spent_at=invite.spent_at,
revoked_at=invite.revoked_at, revoked_at=invite.revoked_at,

View file

@ -3,7 +3,7 @@ import uuid
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Invite, InviteStatus, ROLE_ADMIN from models import User, Invite, InviteStatus
from invite_utils import generate_invite_identifier from invite_utils import generate_invite_identifier
@ -12,36 +12,32 @@ def unique_email(prefix: str = "test") -> str:
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com" return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str: async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str:
""" """
Create an invite that can be used for registration. Create an invite for an existing godfather user.
Returns the invite identifier.
Args:
db: Database session
godfather_id: ID of the existing user who will be the godfather
Returns:
The invite identifier.
Raises:
ValueError: If the godfather user doesn't exist.
""" """
# Find godfather # Verify godfather exists
result = await db.execute(select(User).where(User.email == godfather_email)) result = await db.execute(select(User).where(User.id == godfather_id))
godfather = result.scalar_one_or_none() godfather = result.scalar_one_or_none()
if not godfather: if not godfather:
# Create a godfather user (admin can create invites) raise ValueError(f"Godfather user with ID {godfather_id} not found")
from auth import get_password_hash
from models import Role
result = await db.execute(select(Role).where(Role.name == ROLE_ADMIN))
admin_role = result.scalar_one_or_none()
godfather = User(
email=godfather_email,
hashed_password=get_password_hash("password123"),
roles=[admin_role] if admin_role else [],
)
db.add(godfather)
await db.flush()
# Create invite # Create invite
identifier = generate_invite_identifier() identifier = generate_invite_identifier()
invite = Invite( invite = Invite(
identifier=identifier, identifier=identifier,
godfather_id=godfather.id, godfather_id=godfather_id,
status=InviteStatus.READY, status=InviteStatus.READY,
) )
db.add(invite) db.add(invite)
@ -49,3 +45,29 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
return identifier return identifier
# Backwards-compatible alias that gets a user by email
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
"""
Create an invite for an existing godfather user (looked up by email).
The godfather must already exist in the database.
Args:
db: Database session
godfather_email: Email of the existing user who will be the godfather
Returns:
The invite identifier.
Raises:
ValueError: If the godfather user doesn't exist.
"""
result = await db.execute(select(User).where(User.email == godfather_email))
godfather = result.scalar_one_or_none()
if not godfather:
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles().")
return await create_invite_for_godfather(db, godfather.id)

View file

@ -6,7 +6,9 @@ users will create invites first via the helper function.
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from tests.helpers import unique_email, create_invite_for_registration from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
# Registration tests (with invite) # Registration tests (with invite)
@ -15,9 +17,10 @@ async def test_register_success(client_factory):
"""Can register with valid invite code.""" """Can register with valid invite code."""
email = unique_email("register") email = unique_email("register")
# Create invite # Create godfather user and invite
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("godfather")) godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -44,10 +47,11 @@ async def test_register_duplicate_email(client_factory):
"""Cannot register with already-used email.""" """Cannot register with already-used email."""
email = unique_email("duplicate") email = unique_email("duplicate")
# Create two invites # Create godfather and two invites
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite2 = await create_invite_for_registration(db, unique_email("gf2")) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
# First registration # First registration
await client_factory.post( await client_factory.post(
@ -76,7 +80,8 @@ async def test_register_duplicate_email(client_factory):
async def test_register_invalid_email(client_factory): async def test_register_invalid_email(client_factory):
"""Cannot register with invalid email format.""" """Cannot register with invalid email format."""
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -133,7 +138,8 @@ async def test_login_success(client_factory):
email = unique_email("login") email = unique_email("login")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -161,7 +167,8 @@ async def test_login_wrong_password(client_factory):
email = unique_email("wrongpass") email = unique_email("wrongpass")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -214,7 +221,8 @@ async def test_get_me_success(client_factory):
email = unique_email("me") email = unique_email("me")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -269,7 +277,8 @@ async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest") email = unique_email("tokentest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -294,7 +303,8 @@ async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken") email = unique_email("logintoken")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -325,8 +335,9 @@ async def test_multiple_users_isolated(client_factory):
email2 = unique_email("user2") email2 = unique_email("user2")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite2 = await create_invite_for_registration(db, unique_email("gf2")) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
resp1 = await client_factory.post( resp1 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -366,7 +377,8 @@ async def test_password_is_hashed(client_factory):
email = unique_email("hashtest") email = unique_email("hashtest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -389,7 +401,8 @@ async def test_case_sensitive_password(client_factory):
email = unique_email("casetest") email = unique_email("casetest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -413,7 +426,8 @@ async def test_logout_success(client_factory):
email = unique_email("logout") email = unique_email("logout")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",

View file

@ -5,7 +5,9 @@ Note: Registration now requires an invite code.
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from tests.helpers import unique_email, create_invite_for_registration from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
# Protected endpoint tests - without auth # Protected endpoint tests - without auth
@ -39,7 +41,8 @@ async def test_increment_counter_invalid_cookie(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory): async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -61,7 +64,8 @@ async def test_get_counter_authenticated(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter(client_factory): async def test_increment_counter(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -87,7 +91,8 @@ async def test_increment_counter(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory): async def test_increment_counter_multiple(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -115,7 +120,8 @@ async def test_increment_counter_multiple(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory): async def test_get_counter_after_increment(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -141,10 +147,11 @@ async def test_get_counter_after_increment(client_factory):
# Counter is shared between users # Counter is shared between users
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_counter_shared_between_users(client_factory): async def test_counter_shared_between_users(client_factory):
# Create invites for two users # Create godfather and invites for two users
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite1 = await create_invite_for_registration(db, unique_email("gf1")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite2 = await create_invite_for_registration(db, unique_email("gf2")) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
# Create first user # Create first user
reg1 = await client_factory.post( reg1 = await client_factory.post(

View file

@ -397,6 +397,47 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
assert invite.status == InviteStatus.READY assert invite.status == InviteStatus.READY
@pytest.mark.asyncio
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
"""Create invite retries with new identifier on collision."""
from unittest.mock import patch
async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db:
result = await db.execute(
select(User).where(User.email == regular_user["email"])
)
godfather = result.scalar_one()
# Create first invite normally
response1 = await client.post(
"/api/admin/invites",
json={"godfather_id": godfather.id},
)
assert response1.status_code == 200
identifier1 = response1.json()["identifier"]
# Mock generator to first return the same identifier (collision), then a new one
call_count = 0
def mock_generator():
nonlocal call_count
call_count += 1
if call_count == 1:
return identifier1 # Will collide
return f"unique-word-{call_count:02d}" # Won't collide
with patch("main.generate_invite_identifier", side_effect=mock_generator):
response2 = await client.post(
"/api/admin/invites",
json={"godfather_id": godfather.id},
)
assert response2.status_code == 200
# Should have retried and gotten a new identifier
assert response2.json()["identifier"] != identifier1
assert call_count >= 2 # At least one retry
# ============================================================================ # ============================================================================
# Invite Check API Tests (Phase 3) # Invite Check API Tests (Phase 3)
# ============================================================================ # ============================================================================
@ -441,6 +482,32 @@ async def test_check_invite_not_found(client_factory):
assert "not found" in data["error"].lower() assert "not found" in data["error"].lower()
@pytest.mark.asyncio
async def test_check_invite_invalid_format(client_factory):
"""Check endpoint returns error for invalid format without querying DB."""
async with client_factory.create() as client:
# Missing number part
response = await client.get("/api/invites/word-word/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
# Single digit number
response = await client.get("/api/invites/word-word-1/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
# Too many parts
response = await client.get("/api/invites/word-word-word-00/check")
assert response.status_code == 200
data = response.json()
assert data["valid"] is False
assert "format" in data["error"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_case_insensitive(client_factory, admin_user, regular_user): async def test_check_invite_case_insensitive(client_factory, admin_user, regular_user):
"""Check endpoint handles case-insensitive identifiers.""" """Check endpoint handles case-insensitive identifiers."""

View file

@ -305,10 +305,13 @@ class TestSecurityBypassAttempts:
Test that new registrations cannot claim admin role. Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default. New users should only get 'regular' role by default.
""" """
from tests.helpers import unique_email, create_invite_for_registration from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from models import ROLE_REGULAR
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
invite_code = await create_invite_for_registration(db, unique_email("gf")) godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",

View file

@ -43,7 +43,7 @@ export default function AdminInvitesPage() {
const [createError, setCreateError] = useState<string | null>(null); const [createError, setCreateError] = useState<string | null>(null);
const [users, setUsers] = useState<UserOption[]>([]); const [users, setUsers] = useState<UserOption[]>([]);
const { user, isLoading, isAuthorized } = useRequireAuth({ const { user, isLoading, isAuthorized } = useRequireAuth({
requiredPermission: Permission.VIEW_AUDIT, // Admins have this requiredPermission: Permission.MANAGE_INVITES,
fallbackRedirect: "/", fallbackRedirect: "/",
}); });

View file

@ -43,37 +43,11 @@ export function Header({ currentPage }: HeaderProps) {
if (!user) return null; if (!user) return null;
// For admin pages, show admin navigation // Build nav items based on user role
if (isAdminUser && (currentPage === "audit" || currentPage === "admin-invites")) { // Admin users see admin nav items, regular users see regular nav items
return ( const navItems = isAdminUser ? ADMIN_NAV_ITEMS : REGULAR_NAV_ITEMS;
<div style={sharedStyles.header}> const visibleItems = navItems.filter(
<div style={sharedStyles.nav}> (item) => (!item.regularOnly || isRegularUser) && (!item.adminOnly || isAdminUser)
{ADMIN_NAV_ITEMS.map((item, index) => (
<span key={item.id}>
{index > 0 && <span style={sharedStyles.navDivider}></span>}
{item.id === currentPage ? (
<span style={sharedStyles.navCurrent}>{item.label}</span>
) : (
<a href={item.href} style={sharedStyles.navLink}>
{item.label}
</a>
)}
</span>
))}
</div>
<div style={sharedStyles.userInfo}>
<span style={sharedStyles.userEmail}>{user.email}</span>
<button onClick={handleLogout} style={sharedStyles.logoutBtn}>
Sign out
</button>
</div>
</div>
);
}
// For regular pages, build nav with links
const visibleItems = REGULAR_NAV_ITEMS.filter(
(item) => !item.regularOnly || isRegularUser
); );
return ( return (

View file

@ -258,9 +258,9 @@ describe("Home - Navigation", () => {
render(<Home />); render(<Home />);
// Wait for render, then check profile link is not present // Wait for render - admin sees admin nav (Audit, Invites) not regular nav
await waitFor(() => { await waitFor(() => {
expect(screen.getByText("Counter")).toBeDefined(); expect(screen.getByText("Audit")).toBeDefined();
}); });
expect(screen.queryByText("My Profile")).toBeNull(); expect(screen.queryByText("My Profile")).toBeNull();
}); });

View file

@ -7,10 +7,13 @@ import { useAuth } from "../../auth-context";
export default function SignupWithCodePage() { export default function SignupWithCodePage() {
const params = useParams(); const params = useParams();
const router = useRouter(); const router = useRouter();
const { user } = useAuth(); const { user, isLoading } = useAuth();
const code = params.code as string; const code = params.code as string;
useEffect(() => { useEffect(() => {
// Wait for auth check to complete before redirecting
if (isLoading) return;
if (user) { if (user) {
// Already logged in, redirect to home // Already logged in, redirect to home
router.replace("/"); router.replace("/");
@ -18,7 +21,7 @@ export default function SignupWithCodePage() {
// Redirect to signup with code as query param // Redirect to signup with code as query param
router.replace(`/signup?code=${encodeURIComponent(code)}`); router.replace(`/signup?code=${encodeURIComponent(code)}`);
} }
}, [user, code, router]); }, [user, isLoading, code, router]);
return ( return (
<main style={{ <main style={{

View file

@ -5,10 +5,11 @@ import SignupPage from "./page";
const mockPush = vi.fn(); const mockPush = vi.fn();
vi.mock("next/navigation", () => ({ vi.mock("next/navigation", () => ({
useRouter: () => ({ push: mockPush }), useRouter: () => ({ push: mockPush }),
useSearchParams: () => ({ get: () => null }),
})); }));
vi.mock("../auth-context", () => ({ vi.mock("../auth-context", () => ({
useAuth: () => ({ register: vi.fn() }), useAuth: () => ({ user: null, register: vi.fn() }),
})); }));
beforeEach(() => vi.clearAllMocks()); beforeEach(() => vi.clearAllMocks());
@ -16,19 +17,18 @@ afterEach(() => cleanup());
test("renders signup form with title", () => { test("renders signup form with title", () => {
render(<SignupPage />); render(<SignupPage />);
expect(screen.getByRole("heading", { name: "Create account" })).toBeDefined(); // Step 1 shows "Join with Invite" title (invite code entry)
expect(screen.getByRole("heading", { name: "Join with Invite" })).toBeDefined();
}); });
test("renders email and password inputs", () => { test("renders invite code input", () => {
render(<SignupPage />); render(<SignupPage />);
expect(screen.getByLabelText("Email")).toBeDefined(); expect(screen.getByLabelText("Invite Code")).toBeDefined();
expect(screen.getByLabelText("Password")).toBeDefined();
expect(screen.getByLabelText("Confirm Password")).toBeDefined();
}); });
test("renders create account button", () => { test("renders continue button", () => {
render(<SignupPage />); render(<SignupPage />);
expect(screen.getByRole("button", { name: "Create account" })).toBeDefined(); expect(screen.getByRole("button", { name: "Continue" })).toBeDefined();
}); });
test("renders link to login", () => { test("renders link to login", () => {

View file

@ -1,6 +1,6 @@
"use client"; "use client";
import { useState, useEffect, Suspense } from "react"; import { useState, useEffect, useCallback, Suspense } from "react";
import { useRouter, useSearchParams } from "next/navigation"; import { useRouter, useSearchParams } from "next/navigation";
import { useAuth } from "../auth-context"; import { useAuth } from "../auth-context";
import { api } from "../api"; import { api } from "../api";
@ -20,6 +20,7 @@ function SignupContent() {
const [inviteValid, setInviteValid] = useState<boolean | null>(null); const [inviteValid, setInviteValid] = useState<boolean | null>(null);
const [inviteError, setInviteError] = useState(""); const [inviteError, setInviteError] = useState("");
const [isCheckingInvite, setIsCheckingInvite] = useState(false); const [isCheckingInvite, setIsCheckingInvite] = useState(false);
const [isCheckingInitialCode, setIsCheckingInitialCode] = useState(!!initialCode);
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
@ -37,14 +38,7 @@ function SignupContent() {
} }
}, [user, router]); }, [user, router]);
// Check invite code on mount if provided in URL const checkInvite = useCallback(async (code: string) => {
useEffect(() => {
if (initialCode) {
checkInvite(initialCode);
}
}, [initialCode]);
const checkInvite = async (code: string) => {
if (!code.trim()) { if (!code.trim()) {
setInviteValid(null); setInviteValid(null);
setInviteError(""); setInviteError("");
@ -72,7 +66,14 @@ function SignupContent() {
} finally { } finally {
setIsCheckingInvite(false); setIsCheckingInvite(false);
} }
}; }, []);
// Check invite code on mount if provided in URL
useEffect(() => {
if (initialCode) {
checkInvite(initialCode).finally(() => setIsCheckingInitialCode(false));
}
}, [initialCode, checkInvite]);
const handleInviteSubmit = (e: React.FormEvent) => { const handleInviteSubmit = (e: React.FormEvent) => {
e.preventDefault(); e.preventDefault();
@ -110,6 +111,21 @@ function SignupContent() {
return null; return null;
} }
// Show loading state while checking initial code from URL
if (isCheckingInitialCode) {
return (
<main style={styles.main}>
<div style={styles.container}>
<div style={styles.card}>
<div style={{ textAlign: "center", color: "rgba(255,255,255,0.6)" }}>
Checking invite code...
</div>
</div>
</div>
</main>
);
}
// Step 1: Enter invite code // Step 1: Enter invite code
if (!inviteValid) { if (!inviteValid) {
return ( return (

View file

@ -79,6 +79,53 @@ test.describe("Authentication Flow", () => {
}); });
}); });
test.describe("Logged-in User Visiting Invite URL", () => {
test("redirects to home when logged-in user visits direct invite URL", async ({ page, request }) => {
const email = uniqueEmail();
const inviteCode = await createInvite(request);
// First sign up to create a user
await page.goto("/signup");
await page.fill('input#inviteCode', inviteCode);
await page.click('button[type="submit"]');
await expect(page.locator("h1")).toHaveText("Create account");
await page.fill('input#email', email);
await page.fill('input#password', "password123");
await page.fill('input#confirmPassword', "password123");
await page.click('button[type="submit"]');
await expect(page).toHaveURL("/");
// Create another invite
const anotherInvite = await createInvite(request);
// Visit invite URL while logged in - should redirect to home
await page.goto(`/signup/${anotherInvite}`);
await expect(page).toHaveURL("/");
});
test("redirects to home when logged-in user visits signup page", async ({ page, request }) => {
const email = uniqueEmail();
const inviteCode = await createInvite(request);
// Sign up and stay logged in
await page.goto("/signup");
await page.fill('input#inviteCode', inviteCode);
await page.click('button[type="submit"]');
await expect(page.locator("h1")).toHaveText("Create account");
await page.fill('input#email', email);
await page.fill('input#password', "password123");
await page.fill('input#confirmPassword', "password123");
await page.click('button[type="submit"]');
await expect(page).toHaveURL("/");
// Try to visit signup page while logged in - should redirect to home
await page.goto("/signup");
await expect(page).toHaveURL("/");
});
});
test.describe("Signup with Invite", () => { test.describe("Signup with Invite", () => {
test.beforeEach(async ({ page }) => { test.beforeEach(async ({ page }) => {
await clearAuth(page); await clearAuth(page);

View file

@ -14,12 +14,6 @@ fi
pkill -f "uvicorn main:app" 2>/dev/null || true pkill -f "uvicorn main:app" 2>/dev/null || true
sleep 1 sleep 1
# Start db
docker compose up -d db
# Wait for db to be ready
sleep 2
# Seed the database with roles and test users # Seed the database with roles and test users
cd backend cd backend
echo "Seeding database..." echo "Seeding database..."