diff --git a/backend/main.py b/backend/main.py index bce9b3e..02b3278 100644 --- a/backend/main.py +++ b/backend/main.py @@ -125,23 +125,10 @@ async def check_invite( ) invite = result.scalar_one_or_none() - if not invite: + # Return same error for not found, spent, and revoked to avoid information leakage + if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED): return InviteCheckResponse(valid=False, error="Invite not found") - if invite.status == InviteStatus.SPENT: - return InviteCheckResponse( - valid=False, - status=invite.status.value, - error="This invite has already been used" - ) - - if invite.status == InviteStatus.REVOKED: - return InviteCheckResponse( - valid=False, - status=invite.status.value, - error="This invite has been revoked" - ) - return InviteCheckResponse(valid=True, status=invite.status.value) @@ -167,24 +154,13 @@ async def register( ) invite = result.scalar_one_or_none() - if not invite: + # Return same error for not found, spent, and revoked to avoid information leakage + if not invite or invite.status in (InviteStatus.SPENT, InviteStatus.REVOKED): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid invite code", ) - if invite.status == InviteStatus.SPENT: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="This invite has already been used", - ) - - if invite.status == InviteStatus.REVOKED: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="This invite has been revoked", - ) - # Check email not already taken existing_user = await get_user_by_email(db, user_data.email) if existing_user: diff --git a/backend/tests/test_invites.py b/backend/tests/test_invites.py index ad05850..5e9ce13 100644 --- a/backend/tests/test_invites.py +++ b/backend/tests/test_invites.py @@ -508,6 +508,82 @@ async def test_check_invite_invalid_format(client_factory): assert "format" in data["error"].lower() +@pytest.mark.asyncio +async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user): + """Check endpoint returns same error for spent invite as for non-existent (no info leakage).""" + # Create invite + async with client_factory.create(cookies=admin_user["cookies"]) as client: + async with client_factory.get_db_session() as db: + result = await db.execute( + select(User).where(User.email == regular_user["email"]) + ) + godfather = result.scalar_one() + + create_resp = await client.post( + "/api/admin/invites", + json={"godfather_id": godfather.id}, + ) + identifier = create_resp.json()["identifier"] + + # Use the invite + async with client_factory.create() as client: + await client.post( + "/api/auth/register", + json={ + "email": unique_email("spentcheck"), + "password": "password123", + "invite_identifier": identifier, + }, + ) + + # Check spent invite - should return same error as non-existent + async with client_factory.create() as client: + response = await client.get(f"/api/invites/{identifier}/check") + + assert response.status_code == 200 + data = response.json() + assert data["valid"] is False + assert "not found" in data["error"].lower() + + +@pytest.mark.asyncio +async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user): + """Check endpoint returns same error for revoked invite as for non-existent (no info leakage).""" + from datetime import datetime, UTC + + # Create invite + async with client_factory.create(cookies=admin_user["cookies"]) as client: + async with client_factory.get_db_session() as db: + result = await db.execute( + select(User).where(User.email == regular_user["email"]) + ) + godfather = result.scalar_one() + + create_resp = await client.post( + "/api/admin/invites", + json={"godfather_id": godfather.id}, + ) + identifier = create_resp.json()["identifier"] + invite_id = create_resp.json()["id"] + + # Revoke the invite + async with client_factory.get_db_session() as db: + result = await db.execute(select(Invite).where(Invite.id == invite_id)) + invite = result.scalar_one() + invite.status = InviteStatus.REVOKED + invite.revoked_at = datetime.now(UTC) + await db.commit() + + # Check revoked invite - should return same error as non-existent + async with client_factory.create() as client: + response = await client.get(f"/api/invites/{identifier}/check") + + assert response.status_code == 200 + data = response.json() + assert data["valid"] is False + assert "not found" in data["error"].lower() + + @pytest.mark.asyncio async def test_check_invite_case_insensitive(client_factory, admin_user, regular_user): """Check endpoint handles case-insensitive identifiers.""" @@ -712,7 +788,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us ) assert response.status_code == 400 - assert "already been used" in response.json()["detail"] + assert "invalid invite code" in response.json()["detail"].lower() @pytest.mark.asyncio @@ -758,7 +834,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_ ) assert response.status_code == 400 - assert "revoked" in response.json()["detail"].lower() + assert "invalid invite code" in response.json()["detail"].lower() @pytest.mark.asyncio