first round of review
This commit is contained in:
parent
870804e7b9
commit
23049da55a
15 changed files with 325 additions and 182 deletions
|
|
@ -3,7 +3,7 @@ import uuid
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import User, Invite, InviteStatus, ROLE_ADMIN
|
||||
from models import User, Invite, InviteStatus
|
||||
from invite_utils import generate_invite_identifier
|
||||
|
||||
|
||||
|
|
@ -12,36 +12,32 @@ def unique_email(prefix: str = "test") -> str:
|
|||
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
|
||||
|
||||
|
||||
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
|
||||
async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str:
|
||||
"""
|
||||
Create an invite that can be used for registration.
|
||||
Returns the invite identifier.
|
||||
Create an invite for an existing godfather user.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
godfather_id: ID of the existing user who will be the godfather
|
||||
|
||||
Returns:
|
||||
The invite identifier.
|
||||
|
||||
Raises:
|
||||
ValueError: If the godfather user doesn't exist.
|
||||
"""
|
||||
# Find godfather
|
||||
result = await db.execute(select(User).where(User.email == godfather_email))
|
||||
# Verify godfather exists
|
||||
result = await db.execute(select(User).where(User.id == godfather_id))
|
||||
godfather = result.scalar_one_or_none()
|
||||
|
||||
if not godfather:
|
||||
# Create a godfather user (admin can create invites)
|
||||
from auth import get_password_hash
|
||||
from models import Role
|
||||
|
||||
result = await db.execute(select(Role).where(Role.name == ROLE_ADMIN))
|
||||
admin_role = result.scalar_one_or_none()
|
||||
|
||||
godfather = User(
|
||||
email=godfather_email,
|
||||
hashed_password=get_password_hash("password123"),
|
||||
roles=[admin_role] if admin_role else [],
|
||||
)
|
||||
db.add(godfather)
|
||||
await db.flush()
|
||||
raise ValueError(f"Godfather user with ID {godfather_id} not found")
|
||||
|
||||
# Create invite
|
||||
identifier = generate_invite_identifier()
|
||||
invite = Invite(
|
||||
identifier=identifier,
|
||||
godfather_id=godfather.id,
|
||||
godfather_id=godfather_id,
|
||||
status=InviteStatus.READY,
|
||||
)
|
||||
db.add(invite)
|
||||
|
|
@ -49,3 +45,29 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
|
|||
|
||||
return identifier
|
||||
|
||||
|
||||
# Backwards-compatible alias that gets a user by email
|
||||
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
|
||||
"""
|
||||
Create an invite for an existing godfather user (looked up by email).
|
||||
|
||||
The godfather must already exist in the database.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
godfather_email: Email of the existing user who will be the godfather
|
||||
|
||||
Returns:
|
||||
The invite identifier.
|
||||
|
||||
Raises:
|
||||
ValueError: If the godfather user doesn't exist.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.email == godfather_email))
|
||||
godfather = result.scalar_one_or_none()
|
||||
|
||||
if not godfather:
|
||||
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
|
||||
"Create the user first using create_user_with_roles().")
|
||||
|
||||
return await create_invite_for_godfather(db, godfather.id)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ users will create invites first via the helper function.
|
|||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from tests.helpers import unique_email, create_invite_for_registration
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
|
||||
# Registration tests (with invite)
|
||||
|
|
@ -15,9 +17,10 @@ async def test_register_success(client_factory):
|
|||
"""Can register with valid invite code."""
|
||||
email = unique_email("register")
|
||||
|
||||
# Create invite
|
||||
# Create godfather user and invite
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("godfather"))
|
||||
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -44,10 +47,11 @@ async def test_register_duplicate_email(client_factory):
|
|||
"""Cannot register with already-used email."""
|
||||
email = unique_email("duplicate")
|
||||
|
||||
# Create two invites
|
||||
# Create godfather and two invites
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
|
||||
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
# First registration
|
||||
await client_factory.post(
|
||||
|
|
@ -76,7 +80,8 @@ async def test_register_duplicate_email(client_factory):
|
|||
async def test_register_invalid_email(client_factory):
|
||||
"""Cannot register with invalid email format."""
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -133,7 +138,8 @@ async def test_login_success(client_factory):
|
|||
email = unique_email("login")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -161,7 +167,8 @@ async def test_login_wrong_password(client_factory):
|
|||
email = unique_email("wrongpass")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -214,7 +221,8 @@ async def test_get_me_success(client_factory):
|
|||
email = unique_email("me")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -269,7 +277,8 @@ async def test_cookie_from_register_works_for_me(client_factory):
|
|||
email = unique_email("tokentest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -294,7 +303,8 @@ async def test_cookie_from_login_works_for_me(client_factory):
|
|||
email = unique_email("logintoken")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -325,8 +335,9 @@ async def test_multiple_users_isolated(client_factory):
|
|||
email2 = unique_email("user2")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
|
||||
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
resp1 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -366,7 +377,8 @@ async def test_password_is_hashed(client_factory):
|
|||
email = unique_email("hashtest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -389,7 +401,8 @@ async def test_case_sensitive_password(client_factory):
|
|||
email = unique_email("casetest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -413,7 +426,8 @@ async def test_logout_success(client_factory):
|
|||
email = unique_email("logout")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
|
|||
|
|
@ -5,7 +5,9 @@ Note: Registration now requires an invite code.
|
|||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from tests.helpers import unique_email, create_invite_for_registration
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
|
||||
# Protected endpoint tests - without auth
|
||||
|
|
@ -39,7 +41,8 @@ async def test_increment_counter_invalid_cookie(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_counter_authenticated(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -61,7 +64,8 @@ async def test_get_counter_authenticated(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -87,7 +91,8 @@ async def test_increment_counter(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_multiple(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -115,7 +120,8 @@ async def test_increment_counter_multiple(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_counter_after_increment(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -141,10 +147,11 @@ async def test_get_counter_after_increment(client_factory):
|
|||
# Counter is shared between users
|
||||
@pytest.mark.asyncio
|
||||
async def test_counter_shared_between_users(client_factory):
|
||||
# Create invites for two users
|
||||
# Create godfather and invites for two users
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite1 = await create_invite_for_registration(db, unique_email("gf1"))
|
||||
invite2 = await create_invite_for_registration(db, unique_email("gf2"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
# Create first user
|
||||
reg1 = await client_factory.post(
|
||||
|
|
|
|||
|
|
@ -397,6 +397,47 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
|||
assert invite.status == InviteStatus.READY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
|
||||
"""Create invite retries with new identifier on collision."""
|
||||
from unittest.mock import patch
|
||||
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
# Create first invite normally
|
||||
response1 = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
assert response1.status_code == 200
|
||||
identifier1 = response1.json()["identifier"]
|
||||
|
||||
# Mock generator to first return the same identifier (collision), then a new one
|
||||
call_count = 0
|
||||
def mock_generator():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return identifier1 # Will collide
|
||||
return f"unique-word-{call_count:02d}" # Won't collide
|
||||
|
||||
with patch("main.generate_invite_identifier", side_effect=mock_generator):
|
||||
response2 = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
|
||||
assert response2.status_code == 200
|
||||
# Should have retried and gotten a new identifier
|
||||
assert response2.json()["identifier"] != identifier1
|
||||
assert call_count >= 2 # At least one retry
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Invite Check API Tests (Phase 3)
|
||||
# ============================================================================
|
||||
|
|
@ -441,6 +482,32 @@ async def test_check_invite_not_found(client_factory):
|
|||
assert "not found" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_invalid_format(client_factory):
|
||||
"""Check endpoint returns error for invalid format without querying DB."""
|
||||
async with client_factory.create() as client:
|
||||
# Missing number part
|
||||
response = await client.get("/api/invites/word-word/check")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
assert "format" in data["error"].lower()
|
||||
|
||||
# Single digit number
|
||||
response = await client.get("/api/invites/word-word-1/check")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
assert "format" in data["error"].lower()
|
||||
|
||||
# Too many parts
|
||||
response = await client.get("/api/invites/word-word-word-00/check")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
assert "format" in data["error"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_case_insensitive(client_factory, admin_user, regular_user):
|
||||
"""Check endpoint handles case-insensitive identifiers."""
|
||||
|
|
|
|||
|
|
@ -305,10 +305,13 @@ class TestSecurityBypassAttempts:
|
|||
Test that new registrations cannot claim admin role.
|
||||
New users should only get 'regular' role by default.
|
||||
"""
|
||||
from tests.helpers import unique_email, create_invite_for_registration
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from models import ROLE_REGULAR
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
invite_code = await create_invite_for_registration(db, unique_email("gf"))
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue