Add ruff linter/formatter for Python

- Add ruff as dev dependency
- Configure ruff in pyproject.toml with strict 88-char line limit
- Ignore B008 (FastAPI Depends pattern is standard)
- Allow longer lines in tests for readability
- Fix all lint issues in source files
- Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
counterweight 2025-12-21 21:54:26 +01:00
parent 69bc8413e0
commit 6c218130e9
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
31 changed files with 1234 additions and 876 deletions

View file

@ -1,4 +1,4 @@
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants .PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend
-include .env -include .env
export export
@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone
check-constants: check-constants:
@cd backend && uv run python validate_constants.py @cd backend && uv run python validate_constants.py
lint-backend:
cd backend && uv run ruff check .
format-backend:
cd backend && uv run ruff format .
fix-backend:
cd backend && uv run ruff check --fix . && uv run ruff format .

View file

@ -1,5 +1,5 @@
import os import os
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
import bcrypt import bcrypt
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db from database import get_db
from models import User, Permission from models import Permission, User
from schemas import UserResponse from schemas import UserResponse
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
).decode("utf-8") ).decode("utf-8")
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str: def create_access_token(
data: dict[str, str],
expires_delta: timedelta | None = None,
) -> str:
to_encode: dict[str, str | datetime] = dict(data) to_encode: dict[str, str | datetime] = dict(data)
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + delta
to_encode["exp"] = expire to_encode["exp"] = expire
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded return encoded
@ -60,11 +64,11 @@ async def get_current_user(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials", detail="Invalid authentication credentials",
) )
token = request.cookies.get(COOKIE_NAME) token = request.cookies.get(COOKIE_NAME)
if not token: if not token:
raise credentials_exception raise credentials_exception
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str = payload.get("sub") user_id_str = payload.get("sub")
@ -72,7 +76,7 @@ async def get_current_user(
raise credentials_exception raise credentials_exception
user_id = int(user_id_str) user_id = int(user_id_str)
except (JWTError, ValueError): except (JWTError, ValueError):
raise credentials_exception raise credentials_exception from None
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -83,27 +87,32 @@ async def get_current_user(
def require_permission(*required_permissions: Permission): def require_permission(*required_permissions: Permission):
""" """
Dependency factory that checks if user has ALL of the required permissions. Dependency factory that checks if user has ALL required permissions.
Usage: Usage:
@app.get("/api/counter") @app.get("/api/counter")
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))): async def get_counter(
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
):
... ...
""" """
async def permission_checker( async def permission_checker(
request: Request, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> User: ) -> User:
user = await get_current_user(request, db) user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db) user_permissions = await user.get_permissions(db)
missing = [p for p in required_permissions if p not in user_permissions] missing = [p for p in required_permissions if p not in user_permissions]
if missing: if missing:
missing_str = ", ".join(p.value for p in missing)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}", detail=f"Missing required permissions: {missing_str}",
) )
return user return user
return permission_checker return permission_checker

View file

@ -1,8 +1,11 @@
import os import os
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret") DATABASE_URL = os.getenv(
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
)
engine = create_async_engine(DATABASE_URL) engine = create_async_engine(DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False) async_session = async_sessionmaker(engine, expire_on_commit=False)
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
async def get_db(): async def get_db():
async with async_session() as session: async with async_session() as session:
yield session yield session

View file

@ -1,4 +1,5 @@
"""Utilities for invite code generation and validation.""" """Utilities for invite code generation and validation."""
import random import random
from pathlib import Path from pathlib import Path
@ -13,11 +14,11 @@ assert len(BIP39_WORDS) == 2048, f"Expected 2048 BIP39 words, got {len(BIP39_WOR
def generate_invite_identifier() -> str: def generate_invite_identifier() -> str:
""" """
Generate a unique invite identifier. Generate a unique invite identifier.
Format: word1-word2-NN where: Format: word1-word2-NN where:
- word1, word2 are random BIP39 words - word1, word2 are random BIP39 words
- NN is a two-digit number (00-99) - NN is a two-digit number (00-99)
Returns lowercase identifier. Returns lowercase identifier.
""" """
word1 = random.choice(BIP39_WORDS) word1 = random.choice(BIP39_WORDS)
@ -29,7 +30,7 @@ def generate_invite_identifier() -> str:
def normalize_identifier(identifier: str) -> str: def normalize_identifier(identifier: str) -> str:
""" """
Normalize an invite identifier for comparison/lookup. Normalize an invite identifier for comparison/lookup.
- Converts to lowercase - Converts to lowercase
- Strips whitespace - Strips whitespace
""" """
@ -39,22 +40,18 @@ def normalize_identifier(identifier: str) -> str:
def is_valid_identifier_format(identifier: str) -> bool: def is_valid_identifier_format(identifier: str) -> bool:
""" """
Check if an identifier has valid format (word-word-NN). Check if an identifier has valid format (word-word-NN).
Does NOT check if words are valid BIP39 words. Does NOT check if words are valid BIP39 words.
""" """
parts = identifier.split("-") parts = identifier.split("-")
if len(parts) != 3: if len(parts) != 3:
return False return False
word1, word2, number = parts word1, word2, number = parts
# Check words are non-empty # Check words are non-empty
if not word1 or not word2: if not word1 or not word2:
return False return False
# Check number is two digits
if len(number) != 2 or not number.isdigit():
return False
return True
# Check number is two digits
return len(number) == 2 and number.isdigit()

View file

@ -1,19 +1,20 @@
"""FastAPI application entry point.""" """FastAPI application entry point."""
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from database import engine, Base from database import Base, engine
from routes import sum as sum_routes
from routes import counter as counter_routes
from routes import audit as audit_routes from routes import audit as audit_routes
from routes import profile as profile_routes
from routes import invites as invites_routes
from routes import auth as auth_routes from routes import auth as auth_routes
from routes import meta as meta_routes
from routes import availability as availability_routes from routes import availability as availability_routes
from routes import booking as booking_routes from routes import booking as booking_routes
from routes import counter as counter_routes
from routes import invites as invites_routes
from routes import meta as meta_routes
from routes import profile as profile_routes
from routes import sum as sum_routes
from validate_constants import validate_shared_constants from validate_constants import validate_shared_constants
@ -22,7 +23,7 @@ async def lifespan(app: FastAPI):
"""Create database tables on startup and validate constants.""" """Create database tables on startup and validate constants."""
# Validate shared constants match backend definitions # Validate shared constants match backend definitions
validate_shared_constants() validate_shared_constants()
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
yield yield

View file

@ -1,9 +1,24 @@
from datetime import datetime, date, time, timezone from datetime import UTC, date, datetime, time
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import TypedDict from typing import TypedDict
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy import (
Column,
Date,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Table,
Time,
UniqueConstraint,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
from database import Base from database import Base
@ -14,25 +29,26 @@ class RoleConfig(TypedDict):
class Permission(str, PyEnum): class Permission(str, PyEnum):
"""All available permissions in the system.""" """All available permissions in the system."""
# Counter permissions # Counter permissions
VIEW_COUNTER = "view_counter" VIEW_COUNTER = "view_counter"
INCREMENT_COUNTER = "increment_counter" INCREMENT_COUNTER = "increment_counter"
# Sum permissions # Sum permissions
USE_SUM = "use_sum" USE_SUM = "use_sum"
# Audit permissions # Audit permissions
VIEW_AUDIT = "view_audit" VIEW_AUDIT = "view_audit"
# Invite permissions # Invite permissions
MANAGE_INVITES = "manage_invites" MANAGE_INVITES = "manage_invites"
VIEW_OWN_INVITES = "view_own_invites" VIEW_OWN_INVITES = "view_own_invites"
# Booking permissions (regular users) # Booking permissions (regular users)
BOOK_APPOINTMENT = "book_appointment" BOOK_APPOINTMENT = "book_appointment"
VIEW_OWN_APPOINTMENTS = "view_own_appointments" VIEW_OWN_APPOINTMENTS = "view_own_appointments"
CANCEL_OWN_APPOINTMENT = "cancel_own_appointment" CANCEL_OWN_APPOINTMENT = "cancel_own_appointment"
# Availability/Appointments permissions (admin) # Availability/Appointments permissions (admin)
MANAGE_AVAILABILITY = "manage_availability" MANAGE_AVAILABILITY = "manage_availability"
VIEW_ALL_APPOINTMENTS = "view_all_appointments" VIEW_ALL_APPOINTMENTS = "view_all_appointments"
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
class InviteStatus(str, PyEnum): class InviteStatus(str, PyEnum):
"""Status of an invite.""" """Status of an invite."""
READY = "ready" READY = "ready"
SPENT = "spent" SPENT = "spent"
REVOKED = "revoked" REVOKED = "revoked"
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
class AppointmentStatus(str, PyEnum): class AppointmentStatus(str, PyEnum):
"""Status of an appointment.""" """Status of an appointment."""
BOOKED = "booked" BOOKED = "booked"
CANCELLED_BY_USER = "cancelled_by_user" CANCELLED_BY_USER = "cancelled_by_user"
CANCELLED_BY_ADMIN = "cancelled_by_admin" CANCELLED_BY_ADMIN = "cancelled_by_admin"
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
# Role definitions with their permissions # Role definitions with their permissions
ROLE_DEFINITIONS: dict[str, RoleConfig] = { ROLE_DEFINITIONS: dict[str, RoleConfig] = {
ROLE_ADMIN: { ROLE_ADMIN: {
"description": "Administrator with audit, invite, and appointment management access", "description": "Administrator with audit/invite/appointment access",
"permissions": [ "permissions": [
Permission.VIEW_AUDIT, Permission.VIEW_AUDIT,
Permission.MANAGE_INVITES, Permission.MANAGE_INVITES,
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
role_permissions = Table( role_permissions = Table(
"role_permissions", "role_permissions",
Base.metadata, Base.metadata,
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
Column("permission", Enum(Permission), primary_key=True), Column("permission", Enum(Permission), primary_key=True),
) )
@ -97,8 +120,18 @@ role_permissions = Table(
user_roles = Table( user_roles = Table(
"user_roles", "user_roles",
Base.metadata, Base.metadata,
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True), Column(
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), "user_id",
Integer,
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
),
Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
) )
@ -108,7 +141,7 @@ class Role(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=True) description: Mapped[str] = mapped_column(String(255), nullable=True)
# Relationship to users # Relationship to users
users: Mapped[list["User"]] = relationship( users: Mapped[list["User"]] = relationship(
"User", "User",
@ -118,31 +151,42 @@ class Role(Base):
async def get_permissions(self, db: AsyncSession) -> set[Permission]: async def get_permissions(self, db: AsyncSession) -> set[Permission]:
"""Get all permissions for this role.""" """Get all permissions for this role."""
result = await db.execute( query = select(role_permissions.c.permission).where(
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id) role_permissions.c.role_id == self.id
) )
result = await db.execute(query)
return {row[0] for row in result.fetchall()} return {row[0] for row in result.fetchall()}
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None: async def set_permissions(
self, db: AsyncSession, permissions: list[Permission]
) -> None:
"""Set all permissions for this role (replaces existing).""" """Set all permissions for this role (replaces existing)."""
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id)) delete_query = role_permissions.delete().where(
role_permissions.c.role_id == self.id
)
await db.execute(delete_query)
for perm in permissions: for perm in permissions:
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm)) insert_query = role_permissions.insert().values(
role_id=self.id, permission=perm
)
await db.execute(insert_query)
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(
String(255), unique=True, nullable=False, index=True
)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
# Contact details (all optional) # Contact details (all optional)
contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True) contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
telegram: Mapped[str | None] = mapped_column(String(64), nullable=True) telegram: Mapped[str | None] = mapped_column(String(64), nullable=True)
signal: Mapped[str | None] = mapped_column(String(64), nullable=True) signal: Mapped[str | None] = mapped_column(String(64), nullable=True)
nostr_npub: Mapped[str | None] = mapped_column(String(63), nullable=True) nostr_npub: Mapped[str | None] = mapped_column(String(63), nullable=True)
# Godfather (who invited this user) - null for seeded/admin users # Godfather (who invited this user) - null for seeded/admin users
godfather_id: Mapped[int | None] = mapped_column( godfather_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("users.id"), nullable=True Integer, ForeignKey("users.id"), nullable=True
@ -152,7 +196,7 @@ class User(Base):
remote_side="User.id", remote_side="User.id",
foreign_keys=[godfather_id], foreign_keys=[godfather_id],
) )
# Relationship to roles # Relationship to roles
roles: Mapped[list[Role]] = relationship( roles: Mapped[list[Role]] = relationship(
"Role", "Role",
@ -192,12 +236,14 @@ class SumRecord(Base):
__tablename__ = "sum_records" __tablename__ = "sum_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
a: Mapped[float] = mapped_column(Float, nullable=False) a: Mapped[float] = mapped_column(Float, nullable=False)
b: Mapped[float] = mapped_column(Float, nullable=False) b: Mapped[float] = mapped_column(Float, nullable=False)
result: Mapped[float] = mapped_column(Float, nullable=False) result: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
@ -205,11 +251,13 @@ class CounterRecord(Base):
__tablename__ = "counter_records" __tablename__ = "counter_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
value_before: Mapped[int] = mapped_column(Integer, nullable=False) value_before: Mapped[int] = mapped_column(Integer, nullable=False)
value_after: Mapped[int] = mapped_column(Integer, nullable=False) value_after: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
@ -217,11 +265,13 @@ class Invite(Base):
__tablename__ = "invites" __tablename__ = "invites"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) identifier: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False, index=True
)
status: Mapped[InviteStatus] = mapped_column( status: Mapped[InviteStatus] = mapped_column(
Enum(InviteStatus), nullable=False, default=InviteStatus.READY Enum(InviteStatus), nullable=False, default=InviteStatus.READY
) )
# Godfather - the user who owns this invite # Godfather - the user who owns this invite
godfather_id: Mapped[int] = mapped_column( godfather_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True Integer, ForeignKey("users.id"), nullable=False, index=True
@ -231,7 +281,7 @@ class Invite(Base):
foreign_keys=[godfather_id], foreign_keys=[godfather_id],
lazy="joined", lazy="joined",
) )
# User who used this invite (null until spent) # User who used this invite (null until spent)
used_by_id: Mapped[int | None] = mapped_column( used_by_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("users.id"), nullable=True Integer, ForeignKey("users.id"), nullable=True
@ -241,17 +291,22 @@ class Invite(Base):
foreign_keys=[used_by_id], foreign_keys=[used_by_id],
lazy="joined", lazy="joined",
) )
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
spent_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
revoked_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
) )
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
class Availability(Base): class Availability(Base):
"""Admin availability slots for booking.""" """Admin availability slots for booking."""
__tablename__ = "availability" __tablename__ = "availability"
__table_args__ = ( __table_args__ = (
UniqueConstraint("date", "start_time", name="uq_availability_date_start"), UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
@ -262,34 +317,37 @@ class Availability(Base):
start_time: Mapped[time] = mapped_column(Time, nullable=False) start_time: Mapped[time] = mapped_column(Time, nullable=False)
end_time: Mapped[time] = mapped_column(Time, nullable=False) end_time: Mapped[time] = mapped_column(Time, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc), default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(timezone.utc) onupdate=lambda: datetime.now(UTC),
) )
class Appointment(Base): class Appointment(Base):
"""User appointment bookings.""" """User appointment bookings."""
__tablename__ = "appointments" __tablename__ = "appointments"
__table_args__ = ( __table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True Integer, ForeignKey("users.id"), nullable=False, index=True
) )
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined") user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) slot_start: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, index=True
)
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
note: Mapped[str | None] = mapped_column(String(144), nullable=True) note: Mapped[str | None] = mapped_column(String(144), nullable=True)
status: Mapped[AppointmentStatus] = mapped_column( status: Mapped[AppointmentStatus] = mapped_column(
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
cancelled_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
) )
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View file

@ -20,6 +20,7 @@ dev = [
"httpx>=0.28.1", "httpx>=0.28.1",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
"mypy>=1.13.0", "mypy>=1.13.0",
"ruff>=0.14.10",
] ]
[tool.mypy] [tool.mypy]
@ -30,3 +31,27 @@ check_untyped_defs = true
ignore_missing_imports = true ignore_missing_imports = true
exclude = ["tests/"] exclude = ["tests/"]
[tool.ruff]
line-length = 88
target-version = "py311"
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"UP", # pyupgrade
"SIM", # flake8-simplify
"RUF", # ruff-specific rules
]
ignore = [
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
]
[tool.ruff.format]
quote-style = "double"
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["E501"] # Allow longer lines in tests for readability

View file

@ -1,22 +1,23 @@
"""Audit routes for viewing action records.""" """Audit routes for viewing action records."""
from typing import Callable, TypeVar
from collections.abc import Callable
from typing import TypeVar
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select, func, desc from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, SumRecord, CounterRecord, Permission from models import CounterRecord, Permission, SumRecord, User
from schemas import ( from schemas import (
CounterRecordResponse, CounterRecordResponse,
SumRecordResponse,
PaginatedCounterRecords, PaginatedCounterRecords,
PaginatedSumRecords, PaginatedSumRecords,
SumRecordResponse,
) )
router = APIRouter(prefix="/api/audit", tags=["audit"]) router = APIRouter(prefix="/api/audit", tags=["audit"])
R = TypeVar("R", bound=BaseModel) R = TypeVar("R", bound=BaseModel)

View file

@ -1,5 +1,6 @@
"""Authentication routes for register, login, logout, and current user.""" """Authentication routes for register, login, logout, and current user."""
from datetime import datetime, timezone
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Response, status from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy import select from sqlalchemy import select
@ -9,18 +10,17 @@ from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME, COOKIE_NAME,
COOKIE_SECURE, COOKIE_SECURE,
get_password_hash,
get_user_by_email,
authenticate_user, authenticate_user,
build_user_response,
create_access_token, create_access_token,
get_current_user, get_current_user,
build_user_response, get_password_hash,
get_user_by_email,
) )
from database import get_db from database import get_db
from invite_utils import normalize_identifier from invite_utils import normalize_identifier
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
from schemas import UserLogin, UserResponse, RegisterWithInvite from schemas import RegisterWithInvite, UserLogin, UserResponse
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
@ -52,9 +52,8 @@ async def register(
"""Register a new user using an invite code.""" """Register a new user using an invite code."""
# Validate invite # Validate invite
normalized_identifier = normalize_identifier(user_data.invite_identifier) normalized_identifier = normalize_identifier(user_data.invite_identifier)
result = await db.execute( query = select(Invite).where(Invite.identifier == normalized_identifier)
select(Invite).where(Invite.identifier == normalized_identifier) result = await db.execute(query)
)
invite = result.scalar_one_or_none() invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage # Return same error for not found, spent, and revoked to avoid information leakage
@ -90,7 +89,7 @@ async def register(
# Mark invite as spent # Mark invite as spent
invite.status = InviteStatus.SPENT invite.status = InviteStatus.SPENT
invite.used_by_id = user.id invite.used_by_id = user.id
invite.spent_at = datetime.now(timezone.utc) invite.spent_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)

View file

@ -1,28 +1,28 @@
"""Availability routes for admin to manage booking availability.""" """Availability routes for admin to manage booking availability."""
from datetime import date, timedelta from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, delete, and_ from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, Availability, Permission from models import Availability, Permission, User
from schemas import ( from schemas import (
TimeSlot,
AvailabilityDay, AvailabilityDay,
AvailabilityResponse, AvailabilityResponse,
SetAvailabilityRequest,
CopyAvailabilityRequest, CopyAvailabilityRequest,
SetAvailabilityRequest,
TimeSlot,
) )
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
router = APIRouter(prefix="/api/admin/availability", tags=["availability"]) router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
def _get_date_range_bounds() -> tuple[date, date]: def _get_date_range_bounds() -> tuple[date, date]:
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS).""" """Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
today = date.today() today = date.today()
min_date = today + timedelta(days=MIN_ADVANCE_DAYS) min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
max_date = today + timedelta(days=MAX_ADVANCE_DAYS) max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
if d < min_date: if d < min_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}", detail=f"Cannot set availability for past dates. "
f"Earliest allowed: {min_date}",
) )
if d > max_date: if d > max_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}", detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest allowed: {max_date}",
) )
@ -56,7 +58,7 @@ async def get_availability(
status_code=400, status_code=400,
detail="'from' date must be before or equal to 'to' date", detail="'from' date must be before or equal to 'to' date",
) )
# Query availability in range # Query availability in range
result = await db.execute( result = await db.execute(
select(Availability) select(Availability)
@ -64,23 +66,24 @@ async def get_availability(
.order_by(Availability.date, Availability.start_time) .order_by(Availability.date, Availability.start_time)
) )
slots = result.scalars().all() slots = result.scalars().all()
# Group by date # Group by date
days_dict: dict[date, list[TimeSlot]] = {} days_dict: dict[date, list[TimeSlot]] = {}
for slot in slots: for slot in slots:
if slot.date not in days_dict: if slot.date not in days_dict:
days_dict[slot.date] = [] days_dict[slot.date] = []
days_dict[slot.date].append(TimeSlot( days_dict[slot.date].append(
start_time=slot.start_time, TimeSlot(
end_time=slot.end_time, start_time=slot.start_time,
)) end_time=slot.end_time,
)
)
# Convert to response format # Convert to response format
days = [ days = [
AvailabilityDay(date=d, slots=days_dict[d]) AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
for d in sorted(days_dict.keys())
] ]
return AvailabilityResponse(days=days) return AvailabilityResponse(days=days)
@ -93,29 +96,31 @@ async def set_availability(
"""Set availability for a specific date. Replaces any existing availability.""" """Set availability for a specific date. Replaces any existing availability."""
min_date, max_date = _get_date_range_bounds() min_date, max_date = _get_date_range_bounds()
_validate_date_in_range(request.date, min_date, max_date) _validate_date_in_range(request.date, min_date, max_date)
# Validate slots don't overlap # Validate slots don't overlap
sorted_slots = sorted(request.slots, key=lambda s: s.start_time) sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
for i in range(len(sorted_slots) - 1): for i in range(len(sorted_slots) - 1):
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time: if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
end = sorted_slots[i].end_time
start = sorted_slots[i + 1].start_time
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.", detail=f"Time slots overlap: slot ending at {end} "
f"overlaps with slot starting at {start}",
) )
# Validate each slot's end_time > start_time # Validate each slot's end_time > start_time
for slot in request.slots: for slot in request.slots:
if slot.end_time <= slot.start_time: if slot.end_time <= slot.start_time:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.", detail=f"Invalid time slot: end time {slot.end_time} "
f"must be after start time {slot.start_time}",
) )
# Delete existing availability for this date # Delete existing availability for this date
await db.execute( await db.execute(delete(Availability).where(Availability.date == request.date))
delete(Availability).where(Availability.date == request.date)
)
# Create new availability slots # Create new availability slots
for slot in request.slots: for slot in request.slots:
availability = Availability( availability = Availability(
@ -124,9 +129,9 @@ async def set_availability(
end_time=slot.end_time, end_time=slot.end_time,
) )
db.add(availability) db.add(availability)
await db.commit() await db.commit()
return AvailabilityDay(date=request.date, slots=request.slots) return AvailabilityDay(date=request.date, slots=request.slots)
@ -138,14 +143,14 @@ async def copy_availability(
) -> AvailabilityResponse: ) -> AvailabilityResponse:
"""Copy availability from one day to multiple target days.""" """Copy availability from one day to multiple target days."""
min_date, max_date = _get_date_range_bounds() min_date, max_date = _get_date_range_bounds()
# Validate source date is in range (for consistency, though DB query would fail anyway) # Validate source date is in range
_validate_date_in_range(request.source_date, min_date, max_date) _validate_date_in_range(request.source_date, min_date, max_date)
# Validate target dates # Validate target dates
for target_date in request.target_dates: for target_date in request.target_dates:
_validate_date_in_range(target_date, min_date, max_date) _validate_date_in_range(target_date, min_date, max_date)
# Get source availability # Get source availability
result = await db.execute( result = await db.execute(
select(Availability) select(Availability)
@ -153,13 +158,13 @@ async def copy_availability(
.order_by(Availability.start_time) .order_by(Availability.start_time)
) )
source_slots = result.scalars().all() source_slots = result.scalars().all()
if not source_slots: if not source_slots:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"No availability found for source date {request.source_date}", detail=f"No availability found for source date {request.source_date}",
) )
# Copy to each target date within a single atomic transaction # Copy to each target date within a single atomic transaction
# All deletes and inserts happen before commit, ensuring atomicity # All deletes and inserts happen before commit, ensuring atomicity
copied_days: list[AvailabilityDay] = [] copied_days: list[AvailabilityDay] = []
@ -167,12 +172,11 @@ async def copy_availability(
for target_date in request.target_dates: for target_date in request.target_dates:
if target_date == request.source_date: if target_date == request.source_date:
continue # Skip copying to self continue # Skip copying to self
# Delete existing availability for target date # Delete existing availability for target date
await db.execute( del_query = delete(Availability).where(Availability.date == target_date)
delete(Availability).where(Availability.date == target_date) await db.execute(del_query)
)
# Copy slots # Copy slots
target_slots: list[TimeSlot] = [] target_slots: list[TimeSlot] = []
for source_slot in source_slots: for source_slot in source_slots:
@ -182,19 +186,20 @@ async def copy_availability(
end_time=source_slot.end_time, end_time=source_slot.end_time,
) )
db.add(new_availability) db.add(new_availability)
target_slots.append(TimeSlot( target_slots.append(
start_time=source_slot.start_time, TimeSlot(
end_time=source_slot.end_time, start_time=source_slot.start_time,
)) end_time=source_slot.end_time,
)
)
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots)) copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
# Commit all changes atomically # Commit all changes atomically
await db.commit() await db.commit()
except Exception: except Exception:
# Rollback on any error to maintain atomicity # Rollback on any error to maintain atomicity
await db.rollback() await db.rollback()
raise raise
return AvailabilityResponse(days=copied_days)
return AvailabilityResponse(days=copied_days)

View file

@ -1,24 +1,24 @@
"""Booking routes for users to book appointments.""" """Booking routes for users to book appointments."""
from datetime import date, datetime, time, timedelta, timezone
from datetime import UTC, date, datetime, time, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, and_, func from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, Availability, Appointment, AppointmentStatus, Permission from models import Appointment, AppointmentStatus, Availability, Permission, User
from schemas import ( from schemas import (
BookableSlot,
AvailableSlotsResponse,
BookingRequest,
AppointmentResponse, AppointmentResponse,
AvailableSlotsResponse,
BookableSlot,
BookingRequest,
PaginatedAppointments, PaginatedAppointments,
) )
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
router = APIRouter(prefix="/api/booking", tags=["booking"]) router = APIRouter(prefix="/api/booking", tags=["booking"])
@ -28,7 +28,7 @@ def _to_appointment_response(
user_email: str | None = None, user_email: str | None = None,
) -> AppointmentResponse: ) -> AppointmentResponse:
"""Convert an Appointment model to AppointmentResponse schema. """Convert an Appointment model to AppointmentResponse schema.
Args: Args:
appointment: The appointment model instance appointment: The appointment model instance
user_email: Optional user email. If not provided, uses appointment.user.email user_email: Optional user email. If not provided, uses appointment.user.email
@ -49,7 +49,7 @@ def _to_appointment_response(
def _get_valid_minute_boundaries() -> tuple[int, ...]: def _get_valid_minute_boundaries() -> tuple[int, ...]:
"""Get valid minute boundaries based on SLOT_DURATION_MINUTES. """Get valid minute boundaries based on SLOT_DURATION_MINUTES.
Assumes SLOT_DURATION_MINUTES divides 60 evenly (e.g., 15 minutes = 0, 15, 30, 45). Assumes SLOT_DURATION_MINUTES divides 60 evenly (e.g., 15 minutes = 0, 15, 30, 45).
""" """
boundaries: list[int] = [] boundaries: list[int] = []
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
if d < min_date: if d < min_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}", detail=f"Cannot book for today or past dates. "
f"Earliest bookable: {min_date}",
) )
if d > max_date: if d > max_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}", detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest bookable: {max_date}",
) )
@ -89,18 +91,18 @@ def _expand_availability_to_slots(
) -> list[BookableSlot]: ) -> list[BookableSlot]:
"""Expand availability time ranges into 15-minute bookable slots.""" """Expand availability time ranges into 15-minute bookable slots."""
result: list[BookableSlot] = [] result: list[BookableSlot] = []
for avail in availability_slots: for avail in availability_slots:
# Create datetime objects for start and end # Create datetime objects for start and end
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc) current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc) end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
# Generate 15-minute slots # Generate 15-minute slots
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end: while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
slot_end = current + timedelta(minutes=SLOT_DURATION_MINUTES) slot_end = current + timedelta(minutes=SLOT_DURATION_MINUTES)
result.append(BookableSlot(start_time=current, end_time=slot_end)) result.append(BookableSlot(start_time=current, end_time=slot_end))
current = slot_end current = slot_end
return result return result
@ -112,7 +114,7 @@ async def get_available_slots(
) -> AvailableSlotsResponse: ) -> AvailableSlotsResponse:
"""Get available booking slots for a specific date.""" """Get available booking slots for a specific date."""
_validate_booking_date(target_date) _validate_booking_date(target_date)
# Get availability for this date # Get availability for this date
result = await db.execute( result = await db.execute(
select(Availability) select(Availability)
@ -120,20 +122,19 @@ async def get_available_slots(
.order_by(Availability.start_time) .order_by(Availability.start_time)
) )
availability_slots = result.scalars().all() availability_slots = result.scalars().all()
if not availability_slots: if not availability_slots:
return AvailableSlotsResponse(date=target_date, slots=[]) return AvailableSlotsResponse(date=target_date, slots=[])
# Expand to 15-minute slots # Expand to 15-minute slots
all_slots = _expand_availability_to_slots(availability_slots, target_date) all_slots = _expand_availability_to_slots(availability_slots, target_date)
# Get existing booked appointments for this date # Get existing booked appointments for this date
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc) day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc) day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
result = await db.execute( result = await db.execute(
select(Appointment.slot_start) select(Appointment.slot_start).where(
.where(
and_( and_(
Appointment.slot_start >= day_start, Appointment.slot_start >= day_start,
Appointment.slot_start <= day_end, Appointment.slot_start <= day_end,
@ -142,13 +143,12 @@ async def get_available_slots(
) )
) )
booked_starts = {row[0] for row in result.fetchall()} booked_starts = {row[0] for row in result.fetchall()}
# Filter out already booked slots # Filter out already booked slots
available_slots = [ available_slots = [
slot for slot in all_slots slot for slot in all_slots if slot.start_time not in booked_starts
if slot.start_time not in booked_starts
] ]
return AvailableSlotsResponse(date=target_date, slots=available_slots) return AvailableSlotsResponse(date=target_date, slots=available_slots)
@ -161,27 +161,28 @@ async def create_booking(
"""Book an appointment slot.""" """Book an appointment slot."""
slot_date = request.slot_start.date() slot_date = request.slot_start.date()
_validate_booking_date(slot_date) _validate_booking_date(slot_date)
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES) # Validate slot is on the correct minute boundary
valid_minutes = _get_valid_minute_boundaries() valid_minutes = _get_valid_minute_boundaries()
if request.slot_start.minute not in valid_minutes: if request.slot_start.minute not in valid_minutes:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})", detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
f"(valid minutes: {valid_minutes})",
) )
if request.slot_start.second != 0 or request.slot_start.microsecond != 0: if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail="Slot start time must not have seconds or microseconds", detail="Slot start time must not have seconds or microseconds",
) )
# Verify slot falls within availability # Verify slot falls within availability
slot_start_time = request.slot_start.time() slot_start_time = request.slot_start.time()
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time() slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
slot_end_time = slot_end_dt.time()
result = await db.execute( result = await db.execute(
select(Availability) select(Availability).where(
.where(
and_( and_(
Availability.date == slot_date, Availability.date == slot_date,
Availability.start_time <= slot_start_time, Availability.start_time <= slot_start_time,
@ -190,13 +191,15 @@ async def create_booking(
) )
) )
matching_availability = result.scalar_one_or_none() matching_availability = result.scalar_one_or_none()
if not matching_availability: if not matching_availability:
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.", detail=f"Selected slot at {slot_str} UTC is not within "
f"any available time ranges for {slot_date}",
) )
# Create the appointment # Create the appointment
slot_end = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES) slot_end = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
appointment = Appointment( appointment = Appointment(
@ -206,9 +209,9 @@ async def create_booking(
note=request.note, note=request.note,
status=AppointmentStatus.BOOKED, status=AppointmentStatus.BOOKED,
) )
db.add(appointment) db.add(appointment)
try: try:
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
@ -216,9 +219,9 @@ async def create_booking(
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail="This slot has already been booked. Please select another slot.", detail="This slot has already been booked. Select another slot.",
) ) from None
return _to_appointment_response(appointment, current_user.email) return _to_appointment_response(appointment, current_user.email)
@ -241,60 +244,63 @@ async def get_my_appointments(
.order_by(Appointment.slot_start.desc()) .order_by(Appointment.slot_start.desc())
) )
appointments = result.scalars().all() appointments = result.scalars().all()
return [ return [_to_appointment_response(apt, current_user.email) for apt in appointments]
_to_appointment_response(apt, current_user.email)
for apt in appointments
]
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) @appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def cancel_my_appointment( async def cancel_my_appointment(
appointment_id: int, appointment_id: int,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)), current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
) -> AppointmentResponse: ) -> AppointmentResponse:
"""Cancel one of the current user's appointments.""" """Cancel one of the current user's appointments."""
# Get the appointment with explicit eager loading of user relationship # Get the appointment with eager loading of user relationship
result = await db.execute( result = await db.execute(
select(Appointment) select(Appointment)
.options(joinedload(Appointment.user)) .options(joinedload(Appointment.user))
.where(Appointment.id == appointment_id) .where(Appointment.id == appointment_id)
) )
appointment = result.scalar_one_or_none() appointment = result.scalar_one_or_none()
if not appointment: if not appointment:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", detail=f"Appointment {appointment_id} not found",
) )
# Verify ownership # Verify ownership
if appointment.user_id != current_user.id: if appointment.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment") raise HTTPException(
status_code=403,
detail="Cannot cancel another user's appointment",
)
# Check if already cancelled # Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED: if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'" detail=f"Cannot cancel: status is '{appointment.status.value}'",
) )
# Check if appointment is in the past # Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc): if appointment.slot_start <= datetime.now(UTC):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
) )
# Cancel the appointment # Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_USER appointment.status = AppointmentStatus.CANCELLED_BY_USER
appointment.cancelled_at = datetime.now(timezone.utc) appointment.cancelled_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
return _to_appointment_response(appointment, current_user.email) return _to_appointment_response(appointment, current_user.email)
@ -302,7 +308,9 @@ async def cancel_my_appointment(
# Admin Appointments Endpoints # Admin Appointments Endpoints
# ============================================================================= # =============================================================================
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"]) admin_appointments_router = APIRouter(
prefix="/api/admin/appointments", tags=["admin-appointments"]
)
@admin_appointments_router.get("", response_model=PaginatedAppointments) @admin_appointments_router.get("", response_model=PaginatedAppointments)
@ -317,7 +325,7 @@ async def get_all_appointments(
count_result = await db.execute(select(func.count(Appointment.id))) count_result = await db.execute(select(func.count(Appointment.id)))
total = count_result.scalar() or 0 total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1 total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated appointments with explicit eager loading of user relationship # Get paginated appointments with explicit eager loading of user relationship
offset = (page - 1) * per_page offset = (page - 1) * per_page
result = await db.execute( result = await db.execute(
@ -328,13 +336,13 @@ async def get_all_appointments(
.limit(per_page) .limit(per_page)
) )
appointments = result.scalars().all() appointments = result.scalars().all()
# Build responses using the eager-loaded user relationship # Build responses using the eager-loaded user relationship
records = [ records = [
_to_appointment_response(apt) # Uses eager-loaded relationship _to_appointment_response(apt) # Uses eager-loaded relationship
for apt in appointments for apt in appointments
] ]
return PaginatedAppointments( return PaginatedAppointments(
records=records, records=records,
total=total, total=total,
@ -344,47 +352,52 @@ async def get_all_appointments(
) )
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) @admin_appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def admin_cancel_appointment( async def admin_cancel_appointment(
appointment_id: int, appointment_id: int,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)), _current_user: User = Depends(
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
),
) -> AppointmentResponse: ) -> AppointmentResponse:
"""Cancel any appointment (admin only).""" """Cancel any appointment (admin only)."""
# Get the appointment with explicit eager loading of user relationship # Get the appointment with eager loading of user relationship
result = await db.execute( result = await db.execute(
select(Appointment) select(Appointment)
.options(joinedload(Appointment.user)) .options(joinedload(Appointment.user))
.where(Appointment.id == appointment_id) .where(Appointment.id == appointment_id)
) )
appointment = result.scalar_one_or_none() appointment = result.scalar_one_or_none()
if not appointment: if not appointment:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", detail=f"Appointment {appointment_id} not found",
) )
# Check if already cancelled # Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED: if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'" detail=f"Cannot cancel: status is '{appointment.status.value}'",
) )
# Check if appointment is in the past # Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc): if appointment.slot_start <= datetime.now(UTC):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
) )
# Cancel the appointment # Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
appointment.cancelled_at = datetime.now(timezone.utc) appointment.cancelled_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
return _to_appointment_response(appointment) # Uses eager-loaded relationship return _to_appointment_response(appointment) # Uses eager-loaded relationship

View file

@ -1,12 +1,12 @@
"""Counter routes.""" """Counter routes."""
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import Counter, User, CounterRecord, Permission from models import Counter, CounterRecord, Permission, User
router = APIRouter(prefix="/api/counter", tags=["counter"]) router = APIRouter(prefix="/api/counter", tags=["counter"])

View file

@ -1,25 +1,29 @@
"""Invite routes for public check, user invites, and admin management.""" """Invite routes for public check, user invites, and admin management."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Query from datetime import UTC, datetime
from sqlalchemy import select, func, desc
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import desc, func, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format from invite_utils import (
from models import User, Invite, InviteStatus, Permission generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
)
from models import Invite, InviteStatus, Permission, User
from schemas import ( from schemas import (
AdminUserResponse,
InviteCheckResponse, InviteCheckResponse,
InviteCreate, InviteCreate,
InviteResponse, InviteResponse,
UserInviteResponse,
PaginatedInviteRecords, PaginatedInviteRecords,
AdminUserResponse, UserInviteResponse,
) )
router = APIRouter(prefix="/api/invites", tags=["invites"]) router = APIRouter(prefix="/api/invites", tags=["invites"])
admin_router = APIRouter(prefix="/api/admin", tags=["admin"]) admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
@ -54,9 +58,7 @@ async def check_invite(
if not is_valid_identifier_format(normalized): if not is_valid_identifier_format(normalized):
return InviteCheckResponse(valid=False, error="Invalid invite code format") return InviteCheckResponse(valid=False, error="Invalid invite code format")
result = await db.execute( result = await db.execute(select(Invite).where(Invite.identifier == normalized))
select(Invite).where(Invite.identifier == normalized)
)
invite = result.scalar_one_or_none() invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage # Return same error for not found, spent, and revoked to avoid information leakage
@ -112,9 +114,7 @@ async def create_invite(
) -> InviteResponse: ) -> InviteResponse:
"""Create a new invite for a specified godfather user.""" """Create a new invite for a specified godfather user."""
# Validate godfather exists # Validate godfather exists
result = await db.execute( result = await db.execute(select(User.id).where(User.id == data.godfather_id))
select(User.id).where(User.id == data.godfather_id)
)
godfather_id = result.scalar_one_or_none() godfather_id = result.scalar_one_or_none()
if not godfather_id: if not godfather_id:
raise HTTPException( raise HTTPException(
@ -141,8 +141,8 @@ async def create_invite(
if attempt == MAX_INVITE_COLLISION_RETRIES - 1: if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate unique invite code. Please try again.", detail="Failed to generate unique invite code. Try again.",
) ) from None
if invite is None: if invite is None:
raise HTTPException( raise HTTPException(
@ -156,7 +156,9 @@ async def create_invite(
async def list_all_invites( async def list_all_invites(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100), per_page: int = Query(10, ge=1, le=100),
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"), status_filter: str | None = Query(
None, alias="status", description="Filter by status: ready, spent, revoked"
),
godfather_id: int | None = Query(None, description="Filter by godfather user ID"), godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
@ -175,8 +177,9 @@ async def list_all_invites(
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked", detail=f"Invalid status: {status_filter}. "
) "Must be ready, spent, or revoked",
) from None
if godfather_id: if godfather_id:
query = query.where(Invite.godfather_id == godfather_id) query = query.where(Invite.godfather_id == godfather_id)
@ -224,11 +227,12 @@ async def revoke_invite(
if invite.status != InviteStatus.READY: if invite.status != InviteStatus.READY:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.", detail=f"Cannot revoke invite with status '{invite.status.value}'. "
"Only READY invites can be revoked.",
) )
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(timezone.utc) invite.revoked_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)

View file

@ -1,7 +1,8 @@
"""Meta endpoints for shared constants.""" """Meta endpoints for shared constants."""
from fastapi import APIRouter from fastapi import APIRouter
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
from schemas import ConstantsResponse from schemas import ConstantsResponse
router = APIRouter(prefix="/api/meta", tags=["meta"]) router = APIRouter(prefix="/api/meta", tags=["meta"])
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
roles=[ROLE_ADMIN, ROLE_REGULAR], roles=[ROLE_ADMIN, ROLE_REGULAR],
invite_statuses=[s.value for s in InviteStatus], invite_statuses=[s.value for s in InviteStatus],
) )

View file

@ -1,15 +1,15 @@
"""Profile routes for user contact details.""" """Profile routes for user contact details."""
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import get_current_user from auth import get_current_user
from database import get_db from database import get_db
from models import User, ROLE_REGULAR from models import ROLE_REGULAR, User
from schemas import ProfileResponse, ProfileUpdate from schemas import ProfileResponse, ProfileUpdate
from validation import validate_profile_fields from validation import validate_profile_fields
router = APIRouter(prefix="/api/profile", tags=["profile"]) router = APIRouter(prefix="/api/profile", tags=["profile"])
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
"""Get the email of a godfather user by ID.""" """Get the email of a godfather user by ID."""
if not godfather_id: if not godfather_id:
return None return None
result = await db.execute( result = await db.execute(select(User.email).where(User.id == godfather_id))
select(User.email).where(User.id == godfather_id)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()

View file

@ -1,13 +1,13 @@
"""Sum calculation routes.""" """Sum calculation routes."""
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, SumRecord, Permission from models import Permission, SumRecord, User
from schemas import SumRequest, SumResponse from schemas import SumRequest, SumResponse
router = APIRouter(prefix="/api/sum", tags=["sum"]) router = APIRouter(prefix="/api/sum", tags=["sum"])

View file

@ -1,5 +1,6 @@
"""Pydantic schemas for API request/response models.""" """Pydantic schemas for API request/response models."""
from datetime import datetime, date, time
from datetime import date, datetime, time
from typing import Generic, TypeVar from typing import Generic, TypeVar
from pydantic import BaseModel, EmailStr, field_validator from pydantic import BaseModel, EmailStr, field_validator
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
class UserCredentials(BaseModel): class UserCredentials(BaseModel):
"""Base model for user email/password.""" """Base model for user email/password."""
email: EmailStr email: EmailStr
password: str password: str
@ -19,6 +21,7 @@ UserLogin = UserCredentials
class UserResponse(BaseModel): class UserResponse(BaseModel):
"""Response model for authenticated user info.""" """Response model for authenticated user info."""
id: int id: int
email: str email: str
roles: list[str] roles: list[str]
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
class RegisterWithInvite(BaseModel): class RegisterWithInvite(BaseModel):
"""Request model for registration with invite.""" """Request model for registration with invite."""
email: EmailStr email: EmailStr
password: str password: str
invite_identifier: str invite_identifier: str
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
class SumRequest(BaseModel): class SumRequest(BaseModel):
"""Request model for sum calculation.""" """Request model for sum calculation."""
a: float a: float
b: float b: float
class SumResponse(BaseModel): class SumResponse(BaseModel):
"""Response model for sum calculation.""" """Response model for sum calculation."""
a: float a: float
b: float b: float
result: float result: float
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
class CounterRecordResponse(BaseModel): class CounterRecordResponse(BaseModel):
"""Response model for a counter audit record.""" """Response model for a counter audit record."""
id: int id: int
user_email: str user_email: str
value_before: int value_before: int
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
class SumRecordResponse(BaseModel): class SumRecordResponse(BaseModel):
"""Response model for a sum audit record.""" """Response model for a sum audit record."""
id: int id: int
user_email: str user_email: str
a: float a: float
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
class PaginatedResponse(BaseModel, Generic[RecordT]): class PaginatedResponse(BaseModel, Generic[RecordT]):
"""Generic paginated response wrapper.""" """Generic paginated response wrapper."""
records: list[RecordT] records: list[RecordT]
total: int total: int
page: int page: int
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
class ProfileResponse(BaseModel): class ProfileResponse(BaseModel):
"""Response model for profile data.""" """Response model for profile data."""
contact_email: str | None contact_email: str | None
telegram: str | None telegram: str | None
signal: str | None signal: str | None
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
class ProfileUpdate(BaseModel): class ProfileUpdate(BaseModel):
"""Request model for updating profile.""" """Request model for updating profile."""
contact_email: str | None = None contact_email: str | None = None
telegram: str | None = None telegram: str | None = None
signal: str | None = None signal: str | None = None
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
class InviteCheckResponse(BaseModel): class InviteCheckResponse(BaseModel):
"""Response for invite check endpoint.""" """Response for invite check endpoint."""
valid: bool valid: bool
status: str | None = None status: str | None = None
error: str | None = None error: str | None = None
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
class InviteCreate(BaseModel): class InviteCreate(BaseModel):
"""Request model for creating an invite.""" """Request model for creating an invite."""
godfather_id: int godfather_id: int
class InviteResponse(BaseModel): class InviteResponse(BaseModel):
"""Response model for invite data (admin view).""" """Response model for invite data (admin view)."""
id: int id: int
identifier: str identifier: str
godfather_id: int godfather_id: int
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
class UserInviteResponse(BaseModel): class UserInviteResponse(BaseModel):
"""Response model for a user's invite (simpler than admin view).""" """Response model for a user's invite (simpler than admin view)."""
id: int id: int
identifier: str identifier: str
status: str status: str
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
class AdminUserResponse(BaseModel): class AdminUserResponse(BaseModel):
"""Minimal user info for admin dropdowns.""" """Minimal user info for admin dropdowns."""
id: int id: int
email: str email: str
@ -146,11 +162,13 @@ class AdminUserResponse(BaseModel):
# Availability Schemas # Availability Schemas
# ============================================================================= # =============================================================================
class TimeSlot(BaseModel): class TimeSlot(BaseModel):
"""A single time slot (start and end time).""" """A single time slot (start and end time)."""
start_time: time start_time: time
end_time: time end_time: time
@field_validator("start_time", "end_time") @field_validator("start_time", "end_time")
@classmethod @classmethod
def validate_15min_boundary(cls, v: time) -> time: def validate_15min_boundary(cls, v: time) -> time:
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
class AvailabilityDay(BaseModel): class AvailabilityDay(BaseModel):
"""Availability for a single day.""" """Availability for a single day."""
date: date date: date
slots: list[TimeSlot] slots: list[TimeSlot]
class AvailabilityResponse(BaseModel): class AvailabilityResponse(BaseModel):
"""Response model for availability query.""" """Response model for availability query."""
days: list[AvailabilityDay] days: list[AvailabilityDay]
class SetAvailabilityRequest(BaseModel): class SetAvailabilityRequest(BaseModel):
"""Request to set availability for a specific date.""" """Request to set availability for a specific date."""
date: date date: date
slots: list[TimeSlot] slots: list[TimeSlot]
class CopyAvailabilityRequest(BaseModel): class CopyAvailabilityRequest(BaseModel):
"""Request to copy availability from one day to others.""" """Request to copy availability from one day to others."""
source_date: date source_date: date
target_dates: list[date] target_dates: list[date]
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
# Booking Schemas # Booking Schemas
# ============================================================================= # =============================================================================
class BookableSlot(BaseModel): class BookableSlot(BaseModel):
"""A bookable 15-minute slot.""" """A bookable 15-minute slot."""
start_time: datetime start_time: datetime
end_time: datetime end_time: datetime
class AvailableSlotsResponse(BaseModel): class AvailableSlotsResponse(BaseModel):
"""Response for available slots on a given date.""" """Response for available slots on a given date."""
date: date date: date
slots: list[BookableSlot] slots: list[BookableSlot]
class BookingRequest(BaseModel): class BookingRequest(BaseModel):
"""Request to book an appointment.""" """Request to book an appointment."""
slot_start: datetime slot_start: datetime
note: str | None = None note: str | None = None
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
class AppointmentResponse(BaseModel): class AppointmentResponse(BaseModel):
"""Response model for an appointment.""" """Response model for an appointment."""
id: int id: int
user_id: int user_id: int
user_email: str user_email: str
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
# Meta/Constants Schemas # Meta/Constants Schemas
# ============================================================================= # =============================================================================
class ConstantsResponse(BaseModel): class ConstantsResponse(BaseModel):
"""Response model for shared constants.""" """Response model for shared constants."""
permissions: list[str] permissions: list[str]
roles: list[str] roles: list[str]
invite_statuses: list[str] invite_statuses: list[str]

View file

@ -1,12 +1,21 @@
"""Seed the database with roles, permissions, and dev users.""" """Seed the database with roles, permissions, and dev users."""
import asyncio import asyncio
import os import os
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, async_session, Base
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash from auth import get_password_hash
from database import Base, async_session, engine
from models import (
ROLE_ADMIN,
ROLE_DEFINITIONS,
ROLE_REGULAR,
Permission,
Role,
User,
)
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"] DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"] DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
@ -14,11 +23,13 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"] DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role: async def upsert_role(
db: AsyncSession, name: str, description: str, permissions: list[Permission]
) -> Role:
"""Create or update a role with the given permissions.""" """Create or update a role with the given permissions."""
result = await db.execute(select(Role).where(Role.name == name)) result = await db.execute(select(Role).where(Role.name == name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if role: if role:
role.description = description role.description = description
print(f"Updated role: {name}") print(f"Updated role: {name}")
@ -27,19 +38,21 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
db.add(role) db.add(role)
await db.flush() # Get the role ID await db.flush() # Get the role ID
print(f"Created role: {name}") print(f"Created role: {name}")
# Set permissions for the role # Set permissions for the role
await role.set_permissions(db, permissions) await role.set_permissions(db, permissions)
print(f" Permissions: {', '.join(p.value for p in permissions)}") print(f" Permissions: {', '.join(p.value for p in permissions)}")
return role return role
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User: async def upsert_user(
db: AsyncSession, email: str, password: str, role_names: list[str]
) -> User:
"""Create or update a user with the given credentials and roles.""" """Create or update a user with the given credentials and roles."""
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
# Get roles # Get roles
roles = [] roles = []
for role_name in role_names: for role_name in role_names:
@ -48,7 +61,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l
if not role: if not role:
raise ValueError(f"Role '{role_name}' not found") raise ValueError(f"Role '{role_name}' not found")
roles.append(role) roles.append(role)
if user: if user:
user.hashed_password = get_password_hash(password) user.hashed_password = get_password_hash(password)
user.roles = roles # type: ignore[assignment] user.roles = roles # type: ignore[assignment]
@ -61,7 +74,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l
) )
db.add(user) db.add(user)
print(f"Created user: {email} with roles: {role_names}") print(f"Created user: {email} with roles: {role_names}")
return user return user
@ -78,14 +91,14 @@ async def seed() -> None:
role_config["description"], role_config["description"],
role_config["permissions"], role_config["permissions"],
) )
print("\n=== Seeding Users ===") print("\n=== Seeding Users ===")
# Create regular dev user # Create regular dev user
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR]) await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR])
# Create admin dev user # Create admin dev user
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN]) await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN])
await db.commit() await db.commit()
print("\n=== Seeding Complete ===\n") print("\n=== Seeding Complete ===\n")

View file

@ -1,4 +1,5 @@
"""Load shared constants from shared/constants.json.""" """Load shared constants from shared/constants.json."""
import json import json
from pathlib import Path from pathlib import Path
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"] MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"] MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"] NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]

View file

@ -7,28 +7,28 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from auth import get_password_hash
from database import Base, get_db from database import Base, get_db
from main import app from main import app
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
from auth import get_password_hash
from tests.helpers import unique_email from tests.helpers import unique_email
TEST_DATABASE_URL = os.getenv( TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL", "TEST_DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test" "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
) )
class ClientFactory: class ClientFactory:
"""Factory for creating httpx clients with optional cookies.""" """Factory for creating httpx clients with optional cookies."""
def __init__(self, transport, base_url, session_factory): def __init__(self, transport, base_url, session_factory):
self._transport = transport self._transport = transport
self._base_url = base_url self._base_url = base_url
self._session_factory = session_factory self._session_factory = session_factory
@asynccontextmanager @asynccontextmanager
async def create(self, cookies: dict | None = None): async def create(self, cookies: dict | None = None):
"""Create a new client, optionally with cookies set.""" """Create a new client, optionally with cookies set."""
@ -38,15 +38,15 @@ class ClientFactory:
cookies=cookies or {}, cookies=cookies or {},
) as client: ) as client:
yield client yield client
async def request(self, method: str, url: str, **kwargs): async def request(self, method: str, url: str, **kwargs):
"""Make a one-off request without cookies.""" """Make a one-off request without cookies."""
async with self.create() as client: async with self.create() as client:
return await client.request(method, url, **kwargs) return await client.request(method, url, **kwargs)
async def get(self, url: str, **kwargs): async def get(self, url: str, **kwargs):
return await self.request("GET", url, **kwargs) return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs): async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs) return await self.request("POST", url, **kwargs)
@ -64,16 +64,16 @@ async def setup_roles(db: AsyncSession) -> dict[str, Role]:
# Check if role exists # Check if role exists
result = await db.execute(select(Role).where(Role.name == role_name)) result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if not role: if not role:
role = Role(name=role_name, description=config["description"]) role = Role(name=role_name, description=config["description"])
db.add(role) db.add(role)
await db.flush() await db.flush()
# Set permissions # Set permissions
await role.set_permissions(db, config["permissions"]) await role.set_permissions(db, config["permissions"])
roles[role_name] = role roles[role_name] = role
await db.commit() await db.commit()
return roles return roles
@ -91,9 +91,11 @@ async def create_user_with_roles(
result = await db.execute(select(Role).where(Role.name == role_name)) result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if not role: if not role:
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?") raise ValueError(
f"Role '{role_name}' not found. Did you run setup_roles()?"
)
roles.append(role) roles.append(role)
user = User( user = User(
email=email, email=email,
hashed_password=get_password_hash(password), hashed_password=get_password_hash(password),
@ -110,27 +112,27 @@ async def client_factory():
"""Fixture that provides a factory for creating clients.""" """Fixture that provides a factory for creating clients."""
engine = create_async_engine(TEST_DATABASE_URL) engine = create_async_engine(TEST_DATABASE_URL)
session_factory = async_sessionmaker(engine, expire_on_commit=False) session_factory = async_sessionmaker(engine, expire_on_commit=False)
# Create tables # Create tables
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
# Setup roles # Setup roles
async with session_factory() as db: async with session_factory() as db:
await setup_roles(db) await setup_roles(db)
async def override_get_db(): async def override_get_db():
async with session_factory() as session: async with session_factory() as session:
yield session yield session
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
factory = ClientFactory(transport, "http://test", session_factory) factory = ClientFactory(transport, "http://test", session_factory)
yield factory yield factory
app.dependency_overrides.clear() app.dependency_overrides.clear()
await engine.dispose() await engine.dispose()
@ -147,17 +149,17 @@ async def regular_user(client_factory):
"""Create a regular user and return their credentials and cookies.""" """Create a regular user and return their credentials and cookies."""
email = unique_email("regular") email = unique_email("regular")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = await create_user_with_roles(db, email, password, [ROLE_REGULAR]) user = await create_user_with_roles(db, email, password, [ROLE_REGULAR])
user_id = user.id user_id = user.id
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": password}, json={"email": email, "password": password},
) )
return { return {
"email": email, "email": email,
"password": password, "password": password,
@ -172,17 +174,17 @@ async def alt_regular_user(client_factory):
"""Create a second regular user for tests needing multiple users.""" """Create a second regular user for tests needing multiple users."""
email = unique_email("alt_regular") email = unique_email("alt_regular")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = await create_user_with_roles(db, email, password, [ROLE_REGULAR]) user = await create_user_with_roles(db, email, password, [ROLE_REGULAR])
user_id = user.id user_id = user.id
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": password}, json={"email": email, "password": password},
) )
return { return {
"email": email, "email": email,
"password": password, "password": password,
@ -197,16 +199,16 @@ async def admin_user(client_factory):
"""Create an admin user and return their credentials and cookies.""" """Create an admin user and return their credentials and cookies."""
email = unique_email("admin") email = unique_email("admin")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, [ROLE_ADMIN]) await create_user_with_roles(db, email, password, [ROLE_ADMIN])
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": password}, json={"email": email, "password": password},
) )
return { return {
"email": email, "email": email,
"password": password, "password": password,
@ -220,16 +222,16 @@ async def user_no_roles(client_factory):
"""Create a user with NO roles and return their credentials and cookies.""" """Create a user with NO roles and return their credentials and cookies."""
email = unique_email("noroles") email = unique_email("noroles")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, []) await create_user_with_roles(db, email, password, [])
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": password}, json={"email": email, "password": password},
) )
return { return {
"email": email, "email": email,
"password": password, "password": password,

View file

@ -3,8 +3,8 @@ import uuid
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Invite, InviteStatus
from invite_utils import generate_invite_identifier from invite_utils import generate_invite_identifier
from models import Invite, InviteStatus, User
def unique_email(prefix: str = "test") -> str: def unique_email(prefix: str = "test") -> str:
@ -15,24 +15,24 @@ def unique_email(prefix: str = "test") -> str:
async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str: async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str:
""" """
Create an invite for an existing godfather user. Create an invite for an existing godfather user.
Args: Args:
db: Database session db: Database session
godfather_id: ID of the existing user who will be the godfather godfather_id: ID of the existing user who will be the godfather
Returns: Returns:
The invite identifier. The invite identifier.
Raises: Raises:
ValueError: If the godfather user doesn't exist. ValueError: If the godfather user doesn't exist.
""" """
# Verify godfather exists # Verify godfather exists
result = await db.execute(select(User).where(User.id == godfather_id)) result = await db.execute(select(User).where(User.id == godfather_id))
godfather = result.scalar_one_or_none() godfather = result.scalar_one_or_none()
if not godfather: if not godfather:
raise ValueError(f"Godfather user with ID {godfather_id} not found") raise ValueError(f"Godfather user with ID {godfather_id} not found")
# Create invite # Create invite
identifier = generate_invite_identifier() identifier = generate_invite_identifier()
invite = Invite( invite = Invite(
@ -42,7 +42,7 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st
) )
db.add(invite) db.add(invite)
await db.commit() await db.commit()
return identifier return identifier
@ -50,24 +50,26 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str: async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
""" """
Create an invite for an existing godfather user (looked up by email). Create an invite for an existing godfather user (looked up by email).
The godfather must already exist in the database. The godfather must already exist in the database.
Args: Args:
db: Database session db: Database session
godfather_email: Email of the existing user who will be the godfather godfather_email: Email of the existing user who will be the godfather
Returns: Returns:
The invite identifier. The invite identifier.
Raises: Raises:
ValueError: If the godfather user doesn't exist. ValueError: If the godfather user doesn't exist.
""" """
result = await db.execute(select(User).where(User.email == godfather_email)) result = await db.execute(select(User).where(User.email == godfather_email))
godfather = result.scalar_one_or_none() godfather = result.scalar_one_or_none()
if not godfather: if not godfather:
raise ValueError(f"Godfather user with email '{godfather_email}' not found. " raise ValueError(
"Create the user first using create_user_with_roles().") f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles()."
)
return await create_invite_for_godfather(db, godfather.id) return await create_invite_for_godfather(db, godfather.id)

View file

@ -3,12 +3,13 @@
Note: Registration now requires an invite code. Tests that need to register Note: Registration now requires an invite code. Tests that need to register
users will create invites first via the helper function. users will create invites first via the helper function.
""" """
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Registration tests (with invite) # Registration tests (with invite)
@ -16,12 +17,14 @@ from tests.conftest import create_user_with_roles
async def test_register_success(client_factory): async def test_register_success(client_factory):
"""Can register with valid invite code.""" """Can register with valid invite code."""
email = unique_email("register") email = unique_email("register")
# Create godfather user and invite # Create godfather user and invite
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -46,13 +49,15 @@ async def test_register_success(client_factory):
async def test_register_duplicate_email(client_factory): async def test_register_duplicate_email(client_factory):
"""Cannot register with already-used email.""" """Cannot register with already-used email."""
email = unique_email("duplicate") email = unique_email("duplicate")
# Create godfather and two invites # Create godfather and two invites
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)
# First registration # First registration
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -62,7 +67,7 @@ async def test_register_duplicate_email(client_factory):
"invite_identifier": invite1, "invite_identifier": invite1,
}, },
) )
# Second registration with same email # Second registration with same email
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -80,9 +85,11 @@ async def test_register_duplicate_email(client_factory):
async def test_register_invalid_email(client_factory): async def test_register_invalid_email(client_factory):
"""Cannot register with invalid email format.""" """Cannot register with invalid email format."""
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -136,11 +143,13 @@ async def test_register_empty_body(client):
async def test_login_success(client_factory): async def test_login_success(client_factory):
"""Can login with valid credentials.""" """Can login with valid credentials."""
email = unique_email("login") email = unique_email("login")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -165,11 +174,13 @@ async def test_login_success(client_factory):
async def test_login_wrong_password(client_factory): async def test_login_wrong_password(client_factory):
"""Cannot login with wrong password.""" """Cannot login with wrong password."""
email = unique_email("wrongpass") email = unique_email("wrongpass")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -219,11 +230,13 @@ async def test_login_missing_fields(client):
async def test_get_me_success(client_factory): async def test_get_me_success(client_factory):
"""Can get current user info when authenticated.""" """Can get current user info when authenticated."""
email = unique_email("me") email = unique_email("me")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -233,10 +246,10 @@ async def test_get_me_success(client_factory):
}, },
) )
cookies = dict(reg_response.cookies) cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/auth/me") response = await authed.get("/api/auth/me")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["email"] == email assert data["email"] == email
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_invalid_cookie(client_factory): async def test_get_me_invalid_cookie(client_factory):
"""Cannot get current user with invalid cookie.""" """Cannot get current user with invalid cookie."""
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed: async with client_factory.create(
cookies={COOKIE_NAME: "invalidtoken123"}
) as authed:
response = await authed.get("/api/auth/me") response = await authed.get("/api/auth/me")
assert response.status_code == 401 assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials" assert response.json()["detail"] == "Invalid authentication credentials"
@ -275,11 +290,13 @@ async def test_get_me_expired_token(client_factory):
async def test_cookie_from_register_works_for_me(client_factory): async def test_cookie_from_register_works_for_me(client_factory):
"""Auth cookie from registration works for subsequent requests.""" """Auth cookie from registration works for subsequent requests."""
email = unique_email("tokentest") email = unique_email("tokentest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -289,10 +306,10 @@ async def test_cookie_from_register_works_for_me(client_factory):
}, },
) )
cookies = dict(reg_response.cookies) cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me") me_response = await authed.get("/api/auth/me")
assert me_response.status_code == 200 assert me_response.status_code == 200
assert me_response.json()["email"] == email assert me_response.json()["email"] == email
@ -301,11 +318,13 @@ async def test_cookie_from_register_works_for_me(client_factory):
async def test_cookie_from_login_works_for_me(client_factory): async def test_cookie_from_login_works_for_me(client_factory):
"""Auth cookie from login works for subsequent requests.""" """Auth cookie from login works for subsequent requests."""
email = unique_email("logintoken") email = unique_email("logintoken")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -319,10 +338,10 @@ async def test_cookie_from_login_works_for_me(client_factory):
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
cookies = dict(login_response.cookies) cookies = dict(login_response.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me") me_response = await authed.get("/api/auth/me")
assert me_response.status_code == 200 assert me_response.status_code == 200
assert me_response.json()["email"] == email assert me_response.json()["email"] == email
@ -333,12 +352,14 @@ async def test_multiple_users_isolated(client_factory):
"""Multiple users have isolated sessions.""" """Multiple users have isolated sessions."""
email1 = unique_email("user1") email1 = unique_email("user1")
email2 = unique_email("user2") email2 = unique_email("user2")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)
resp1 = await client_factory.post( resp1 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -355,16 +376,16 @@ async def test_multiple_users_isolated(client_factory):
"invite_identifier": invite2, "invite_identifier": invite2,
}, },
) )
cookies1 = dict(resp1.cookies) cookies1 = dict(resp1.cookies)
cookies2 = dict(resp2.cookies) cookies2 = dict(resp2.cookies)
async with client_factory.create(cookies=cookies1) as user1: async with client_factory.create(cookies=cookies1) as user1:
me1 = await user1.get("/api/auth/me") me1 = await user1.get("/api/auth/me")
async with client_factory.create(cookies=cookies2) as user2: async with client_factory.create(cookies=cookies2) as user2:
me2 = await user2.get("/api/auth/me") me2 = await user2.get("/api/auth/me")
assert me1.json()["email"] == email1 assert me1.json()["email"] == email1
assert me2.json()["email"] == email2 assert me2.json()["email"] == email2
assert me1.json()["id"] != me2.json()["id"] assert me1.json()["id"] != me2.json()["id"]
@ -375,11 +396,13 @@ async def test_multiple_users_isolated(client_factory):
async def test_password_is_hashed(client_factory): async def test_password_is_hashed(client_factory):
"""Passwords are properly hashed (can login with correct password).""" """Passwords are properly hashed (can login with correct password)."""
email = unique_email("hashtest") email = unique_email("hashtest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -399,11 +422,13 @@ async def test_password_is_hashed(client_factory):
async def test_case_sensitive_password(client_factory): async def test_case_sensitive_password(client_factory):
"""Passwords are case-sensitive.""" """Passwords are case-sensitive."""
email = unique_email("casetest") email = unique_email("casetest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -424,11 +449,13 @@ async def test_case_sensitive_password(client_factory):
async def test_logout_success(client_factory): async def test_logout_success(client_factory):
"""Can logout successfully.""" """Can logout successfully."""
email = unique_email("logout") email = unique_email("logout")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -438,9 +465,9 @@ async def test_logout_success(client_factory):
}, },
) )
cookies = dict(reg_response.cookies) cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
logout_response = await authed.post("/api/auth/logout") logout_response = await authed.post("/api/auth/logout")
assert logout_response.status_code == 200 assert logout_response.status_code == 200
assert logout_response.json() == {"ok": True} assert logout_response.json() == {"ok": True}

View file

@ -3,7 +3,9 @@ Availability API Tests
Tests for the admin availability management endpoints. Tests for the admin availability management endpoints.
""" """
from datetime import date, time, timedelta
from datetime import date, timedelta
import pytest import pytest
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
# Permission Tests # Permission Tests
# ============================================================================= # =============================================================================
class TestAvailabilityPermissions: class TestAvailabilityPermissions:
"""Test that only admins can access availability endpoints.""" """Test that only admins can access availability endpoints."""
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user): async def test_regular_user_cannot_get_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get( response = await client.get(
"/api/admin/availability", "/api/admin/availability",
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user): async def test_regular_user_cannot_set_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put( response = await client.put(
"/api/admin/availability", "/api/admin/availability",
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
# Set Availability Tests # Set Availability Tests
# ============================================================================= # =============================================================================
class TestSetAvailability: class TestSetAvailability:
"""Test setting availability for a date.""" """Test setting availability for a date."""
@ -101,7 +109,7 @@ class TestSetAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["date"] == str(tomorrow()) assert data["date"] == str(tomorrow())
@ -122,13 +130,15 @@ class TestSetAvailability:
], ],
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["slots"]) == 2 assert len(data["slots"]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user): async def test_set_empty_slots_clears_availability(
self, client_factory, admin_user
):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
# First set some availability # First set some availability
await client.put( await client.put(
@ -138,13 +148,13 @@ class TestSetAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Then clear it # Then clear it
response = await client.put( response = await client.put(
"/api/admin/availability", "/api/admin/availability",
json={"date": str(tomorrow()), "slots": []}, json={"date": str(tomorrow()), "slots": []},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["slots"]) == 0 assert len(data["slots"]) == 0
@ -160,22 +170,22 @@ class TestSetAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Replace with different slots # Replace with different slots
response = await client.put( await client.put(
"/api/admin/availability", "/api/admin/availability",
json={ json={
"date": str(tomorrow()), "date": str(tomorrow()),
"slots": [{"start_time": "14:00:00", "end_time": "16:00:00"}], "slots": [{"start_time": "14:00:00", "end_time": "16:00:00"}],
}, },
) )
# Verify the replacement # Verify the replacement
get_response = await client.get( get_response = await client.get(
"/api/admin/availability", "/api/admin/availability",
params={"from": str(tomorrow()), "to": str(tomorrow())}, params={"from": str(tomorrow()), "to": str(tomorrow())},
) )
data = get_response.json() data = get_response.json()
assert len(data["days"]) == 1 assert len(data["days"]) == 1
assert len(data["days"][0]["slots"]) == 1 assert len(data["days"][0]["slots"]) == 1
@ -186,6 +196,7 @@ class TestSetAvailability:
# Validation Tests # Validation Tests
# ============================================================================= # =============================================================================
class TestAvailabilityValidation: class TestAvailabilityValidation:
"""Test validation rules for availability.""" """Test validation rules for availability."""
@ -200,7 +211,7 @@ class TestAvailabilityValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()
@ -214,7 +225,7 @@ class TestAvailabilityValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()
@ -229,7 +240,7 @@ class TestAvailabilityValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "30" in response.json()["detail"] assert "30" in response.json()["detail"]
@ -243,7 +254,7 @@ class TestAvailabilityValidation:
"slots": [{"start_time": "09:05:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:05:00", "end_time": "12:00:00"}],
}, },
) )
assert response.status_code == 422 # Pydantic validation error assert response.status_code == 422 # Pydantic validation error
assert "15-minute" in response.json()["detail"][0]["msg"] assert "15-minute" in response.json()["detail"][0]["msg"]
@ -257,7 +268,7 @@ class TestAvailabilityValidation:
"slots": [{"start_time": "12:00:00", "end_time": "09:00:00"}], "slots": [{"start_time": "12:00:00", "end_time": "09:00:00"}],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "after" in response.json()["detail"].lower() assert "after" in response.json()["detail"].lower()
@ -274,7 +285,7 @@ class TestAvailabilityValidation:
], ],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "overlap" in response.json()["detail"].lower() assert "overlap" in response.json()["detail"].lower()
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
# Get Availability Tests # Get Availability Tests
# ============================================================================= # =============================================================================
class TestGetAvailability: class TestGetAvailability:
"""Test retrieving availability.""" """Test retrieving availability."""
@ -293,7 +305,7 @@ class TestGetAvailability:
"/api/admin/availability", "/api/admin/availability",
params={"from": str(tomorrow()), "to": str(in_days(7))}, params={"from": str(tomorrow()), "to": str(in_days(7))},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["days"] == [] assert data["days"] == []
@ -310,13 +322,13 @@ class TestGetAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Get range that includes all # Get range that includes all
response = await client.get( response = await client.get(
"/api/admin/availability", "/api/admin/availability",
params={"from": str(in_days(1)), "to": str(in_days(3))}, params={"from": str(in_days(1)), "to": str(in_days(3))},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["days"]) == 3 assert len(data["days"]) == 3
@ -333,13 +345,13 @@ class TestGetAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Get only a subset # Get only a subset
response = await client.get( response = await client.get(
"/api/admin/availability", "/api/admin/availability",
params={"from": str(in_days(2)), "to": str(in_days(4))}, params={"from": str(in_days(2)), "to": str(in_days(4))},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["days"]) == 3 assert len(data["days"]) == 3
@ -351,7 +363,7 @@ class TestGetAvailability:
"/api/admin/availability", "/api/admin/availability",
params={"from": str(in_days(7)), "to": str(in_days(1))}, params={"from": str(in_days(7)), "to": str(in_days(1))},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "before" in response.json()["detail"].lower() assert "before" in response.json()["detail"].lower()
@ -360,6 +372,7 @@ class TestGetAvailability:
# Copy Availability Tests # Copy Availability Tests
# ============================================================================= # =============================================================================
class TestCopyAvailability: class TestCopyAvailability:
"""Test copying availability from one day to others.""" """Test copying availability from one day to others."""
@ -377,7 +390,7 @@ class TestCopyAvailability:
], ],
}, },
) )
# Copy to another day # Copy to another day
response = await client.post( response = await client.post(
"/api/admin/availability/copy", "/api/admin/availability/copy",
@ -386,7 +399,7 @@ class TestCopyAvailability:
"target_dates": [str(in_days(2))], "target_dates": [str(in_days(2))],
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["days"]) == 1 assert len(data["days"]) == 1
@ -404,7 +417,7 @@ class TestCopyAvailability:
"slots": [{"start_time": "10:00:00", "end_time": "11:00:00"}], "slots": [{"start_time": "10:00:00", "end_time": "11:00:00"}],
}, },
) )
# Copy to multiple days # Copy to multiple days
response = await client.post( response = await client.post(
"/api/admin/availability/copy", "/api/admin/availability/copy",
@ -413,7 +426,7 @@ class TestCopyAvailability:
"target_dates": [str(in_days(2)), str(in_days(3)), str(in_days(4))], "target_dates": [str(in_days(2)), str(in_days(3)), str(in_days(4))],
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["days"]) == 3 assert len(data["days"]) == 3
@ -429,7 +442,7 @@ class TestCopyAvailability:
"slots": [{"start_time": "08:00:00", "end_time": "09:00:00"}], "slots": [{"start_time": "08:00:00", "end_time": "09:00:00"}],
}, },
) )
# Set source availability # Set source availability
await client.put( await client.put(
"/api/admin/availability", "/api/admin/availability",
@ -438,7 +451,7 @@ class TestCopyAvailability:
"slots": [{"start_time": "14:00:00", "end_time": "15:00:00"}], "slots": [{"start_time": "14:00:00", "end_time": "15:00:00"}],
}, },
) )
# Copy (should replace) # Copy (should replace)
await client.post( await client.post(
"/api/admin/availability/copy", "/api/admin/availability/copy",
@ -447,13 +460,13 @@ class TestCopyAvailability:
"target_dates": [str(in_days(2))], "target_dates": [str(in_days(2))],
}, },
) )
# Verify target was replaced # Verify target was replaced
response = await client.get( response = await client.get(
"/api/admin/availability", "/api/admin/availability",
params={"from": str(in_days(2)), "to": str(in_days(2))}, params={"from": str(in_days(2)), "to": str(in_days(2))},
) )
data = response.json() data = response.json()
assert len(data["days"]) == 1 assert len(data["days"]) == 1
assert len(data["days"][0]["slots"]) == 1 assert len(data["days"][0]["slots"]) == 1
@ -469,7 +482,7 @@ class TestCopyAvailability:
"target_dates": [str(in_days(2))], "target_dates": [str(in_days(2))],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "no availability" in response.json()["detail"].lower() assert "no availability" in response.json()["detail"].lower()
@ -484,7 +497,7 @@ class TestCopyAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
}, },
) )
# Copy including self in targets # Copy including self in targets
response = await client.post( response = await client.post(
"/api/admin/availability/copy", "/api/admin/availability/copy",
@ -493,7 +506,7 @@ class TestCopyAvailability:
"target_dates": [str(in_days(1)), str(in_days(2))], "target_dates": [str(in_days(1)), str(in_days(2))],
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Should only have copied to day 2, not day 1 (self) # Should only have copied to day 2, not day 1 (self)
@ -511,7 +524,7 @@ class TestCopyAvailability:
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
}, },
) )
# Try to copy to a date beyond 30 days # Try to copy to a date beyond 30 days
response = await client.post( response = await client.post(
"/api/admin/availability/copy", "/api/admin/availability/copy",
@ -520,7 +533,7 @@ class TestCopyAvailability:
"target_dates": [str(in_days(31))], "target_dates": [str(in_days(31))],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "30" in response.json()["detail"] assert "30" in response.json()["detail"]
@ -535,6 +548,6 @@ class TestCopyAvailability:
"target_dates": [str(in_days(1))], "target_dates": [str(in_days(1))],
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()

View file

@ -3,7 +3,9 @@ Booking API Tests
Tests for the user booking endpoints. Tests for the user booking endpoints.
""" """
from datetime import date, datetime, timedelta, timezone
from datetime import UTC, date, datetime, timedelta
import pytest import pytest
from models import Appointment, AppointmentStatus from models import Appointment, AppointmentStatus
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
# Permission Tests # Permission Tests
# ============================================================================= # =============================================================================
class TestBookingPermissions: class TestBookingPermissions:
"""Test that only regular users can book appointments.""" """Test that only regular users can book appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user): async def test_regular_user_can_get_slots(
self, client_factory, regular_user, admin_user
):
"""Regular user can get available slots.""" """Regular user can get available slots."""
# First, admin sets up availability # First, admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -36,15 +41,19 @@ class TestBookingPermissions:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Regular user gets slots # Regular user gets slots
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user): async def test_regular_user_can_book(
self, client_factory, regular_user, admin_user
):
"""Regular user can book an appointment.""" """Regular user can book an appointment."""
# Admin sets up availability # Admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -55,22 +64,24 @@ class TestBookingPermissions:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# Regular user books # Regular user books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test booking"}, json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test booking"},
) )
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_get_slots(self, client_factory, admin_user): async def test_admin_cannot_get_slots(self, client_factory, admin_user):
"""Admin cannot access booking slots endpoint.""" """Admin cannot access booking slots endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -85,18 +96,20 @@ class TestBookingPermissions:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unauthenticated_cannot_get_slots(self, client): async def test_unauthenticated_cannot_get_slots(self, client):
"""Unauthenticated user cannot get slots.""" """Unauthenticated user cannot get slots."""
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@ -113,6 +126,7 @@ class TestBookingPermissions:
# Get Slots Tests # Get Slots Tests
# ============================================================================= # =============================================================================
class TestGetSlots: class TestGetSlots:
"""Test getting available booking slots.""" """Test getting available booking slots."""
@ -120,15 +134,19 @@ class TestGetSlots:
async def test_get_slots_no_availability(self, client_factory, regular_user): async def test_get_slots_no_availability(self, client_factory, regular_user):
"""Returns empty slots when no availability set.""" """Returns empty slots when no availability set."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["date"] == str(tomorrow()) assert data["date"] == str(tomorrow())
assert data["slots"] == [] assert data["slots"] == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user): async def test_get_slots_expands_to_15min(
self, client_factory, regular_user, admin_user
):
"""Availability is expanded into 15-minute slots.""" """Availability is expanded into 15-minute slots."""
# Admin sets 1-hour availability # Admin sets 1-hour availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -139,15 +157,17 @@ class TestGetSlots:
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
}, },
) )
# User gets slots - should be 4 x 15-minute slots # User gets slots - should be 4 x 15-minute slots
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["slots"]) == 4 assert len(data["slots"]) == 4
# Verify times # Verify times
assert "09:00:00" in data["slots"][0]["start_time"] assert "09:00:00" in data["slots"][0]["start_time"]
assert "09:15:00" in data["slots"][0]["end_time"] assert "09:15:00" in data["slots"][0]["end_time"]
@ -156,7 +176,9 @@ class TestGetSlots:
assert "10:00:00" in data["slots"][3]["end_time"] assert "10:00:00" in data["slots"][3]["end_time"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user): async def test_get_slots_excludes_booked(
self, client_factory, regular_user, admin_user
):
"""Already booked slots are excluded from available slots.""" """Already booked slots are excluded from available slots."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -167,17 +189,19 @@ class TestGetSlots:
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
}, },
) )
# User books first slot # User books first slot
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post( await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
# Get slots again - should have 3 left # Get slots again - should have 3 left
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["slots"]) == 3 assert len(data["slots"]) == 3
@ -189,6 +213,7 @@ class TestGetSlots:
# Booking Tests # Booking Tests
# ============================================================================= # =============================================================================
class TestCreateBooking: class TestCreateBooking:
"""Test creating bookings.""" """Test creating bookings."""
@ -204,7 +229,7 @@ class TestCreateBooking:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
@ -214,7 +239,7 @@ class TestCreateBooking:
"note": "Discussion about project", "note": "Discussion about project",
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["user_id"] == regular_user["user"]["id"] assert data["user_id"] == regular_user["user"]["id"]
@ -235,20 +260,22 @@ class TestCreateBooking:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books without note # User books without note
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["note"] is None assert data["note"] is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user): async def test_cannot_double_book_slot(
self, client_factory, regular_user, admin_user, alt_regular_user
):
"""Cannot book a slot that's already booked.""" """Cannot book a slot that's already booked."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -259,7 +286,7 @@ class TestCreateBooking:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# First user books # First user books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
@ -267,19 +294,21 @@ class TestCreateBooking:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
assert response.status_code == 200 assert response.status_code == 200
# Second user tries to book same slot # Second user tries to book same slot
async with client_factory.create(cookies=alt_regular_user["cookies"]) as client: async with client_factory.create(cookies=alt_regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
assert response.status_code == 409 assert response.status_code == 409
assert "already been booked" in response.json()["detail"] assert "already been booked" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user): async def test_cannot_book_outside_availability(
self, client_factory, regular_user, admin_user
):
"""Cannot book a slot outside of availability.""" """Cannot book a slot outside of availability."""
# Admin sets availability for morning only # Admin sets availability for morning only
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -290,14 +319,14 @@ class TestCreateBooking:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User tries to book afternoon slot # User tries to book afternoon slot
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T14:00:00Z"}, json={"slot_start": f"{tomorrow()}T14:00:00Z"},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "not within any available time ranges" in response.json()["detail"] assert "not within any available time ranges" in response.json()["detail"]
@ -306,6 +335,7 @@ class TestCreateBooking:
# Date Validation Tests # Date Validation Tests
# ============================================================================= # =============================================================================
class TestBookingDateValidation: class TestBookingDateValidation:
"""Test date validation for bookings.""" """Test date validation for bookings."""
@ -317,9 +347,12 @@ class TestBookingDateValidation:
"/api/booking", "/api/booking",
json={"slot_start": f"{date.today()}T09:00:00Z"}, json={"slot_start": f"{date.today()}T09:00:00Z"},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower() assert (
"past" in response.json()["detail"].lower()
or "today" in response.json()["detail"].lower()
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_book_past_date(self, client_factory, regular_user): async def test_cannot_book_past_date(self, client_factory, regular_user):
@ -330,7 +363,7 @@ class TestBookingDateValidation:
"/api/booking", "/api/booking",
json={"slot_start": f"{yesterday}T09:00:00Z"}, json={"slot_start": f"{yesterday}T09:00:00Z"},
) )
assert response.status_code == 400 assert response.status_code == 400
@pytest.mark.asyncio @pytest.mark.asyncio
@ -342,7 +375,7 @@ class TestBookingDateValidation:
"/api/booking", "/api/booking",
json={"slot_start": f"{too_far}T09:00:00Z"}, json={"slot_start": f"{too_far}T09:00:00Z"},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "30" in response.json()["detail"] assert "30" in response.json()["detail"]
@ -350,8 +383,10 @@ class TestBookingDateValidation:
async def test_cannot_get_slots_today(self, client_factory, regular_user): async def test_cannot_get_slots_today(self, client_factory, regular_user):
"""Cannot get slots for today.""" """Cannot get slots for today."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(date.today())}) response = await client.get(
"/api/booking/slots", params={"date": str(date.today())}
)
assert response.status_code == 400 assert response.status_code == 400
@pytest.mark.asyncio @pytest.mark.asyncio
@ -359,8 +394,10 @@ class TestBookingDateValidation:
"""Cannot get slots for past date.""" """Cannot get slots for past date."""
yesterday = date.today() - timedelta(days=1) yesterday = date.today() - timedelta(days=1)
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(yesterday)}) response = await client.get(
"/api/booking/slots", params={"date": str(yesterday)}
)
assert response.status_code == 400 assert response.status_code == 400
@ -368,11 +405,14 @@ class TestBookingDateValidation:
# Time Validation Tests # Time Validation Tests
# ============================================================================= # =============================================================================
class TestBookingTimeValidation: class TestBookingTimeValidation:
"""Test time validation for bookings.""" """Test time validation for bookings."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user): async def test_slot_must_be_15min_boundary(
self, client_factory, regular_user, admin_user
):
"""Slot start time must be on 15-minute boundary.""" """Slot start time must be on 15-minute boundary."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -383,14 +423,14 @@ class TestBookingTimeValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User tries to book at 09:05 # User tries to book at 09:05
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:05:00Z"}, json={"slot_start": f"{tomorrow()}T09:05:00Z"},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "15-minute" in response.json()["detail"] assert "15-minute" in response.json()["detail"]
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
# Note Validation Tests # Note Validation Tests
# ============================================================================= # =============================================================================
class TestBookingNoteValidation: class TestBookingNoteValidation:
"""Test note validation for bookings.""" """Test note validation for bookings."""
@ -414,7 +455,7 @@ class TestBookingNoteValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User tries to book with long note # User tries to book with long note
long_note = "x" * 145 long_note = "x" * 145
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
@ -422,11 +463,13 @@ class TestBookingNoteValidation:
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": long_note}, json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": long_note},
) )
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user): async def test_note_exactly_144_chars(
self, client_factory, regular_user, admin_user
):
"""Note of exactly 144 characters is allowed.""" """Note of exactly 144 characters is allowed."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -437,7 +480,7 @@ class TestBookingNoteValidation:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books with exactly 144 char note # User books with exactly 144 char note
note = "x" * 144 note = "x" * 144
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
@ -445,7 +488,7 @@ class TestBookingNoteValidation:
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": note}, json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": note},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["note"] == note assert response.json()["note"] == note
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
# User Appointments Tests # User Appointments Tests
# ============================================================================= # =============================================================================
class TestUserAppointments: class TestUserAppointments:
"""Test user appointments endpoints.""" """Test user appointments endpoints."""
@ -462,12 +506,14 @@ class TestUserAppointments:
"""Returns empty list when user has no appointments.""" """Returns empty list when user has no appointments."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/appointments") response = await client.get("/api/appointments")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == [] assert response.json() == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user): async def test_get_my_appointments_with_bookings(
self, client_factory, regular_user, admin_user
):
"""Returns user's appointments.""" """Returns user's appointments."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -478,7 +524,7 @@ class TestUserAppointments:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books two slots # User books two slots
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post( await client.post(
@ -489,10 +535,10 @@ class TestUserAppointments:
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:15:00Z", "note": "Second"}, json={"slot_start": f"{tomorrow()}T09:15:00Z", "note": "Second"},
) )
# Get appointments # Get appointments
response = await client.get("/api/appointments") response = await client.get("/api/appointments")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data) == 2 assert len(data) == 2
@ -502,11 +548,13 @@ class TestUserAppointments:
assert "Second" in notes assert "Second" in notes
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user): async def test_admin_cannot_view_user_appointments(
self, client_factory, admin_user
):
"""Admin cannot access user appointments endpoint.""" """Admin cannot access user appointments endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/appointments") response = await client.get("/api/appointments")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -520,7 +568,9 @@ class TestCancelAppointment:
"""Test cancelling appointments.""" """Test cancelling appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user): async def test_cancel_own_appointment(
self, client_factory, regular_user, admin_user
):
"""User can cancel their own appointment.""" """User can cancel their own appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -531,7 +581,7 @@ class TestCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -539,17 +589,19 @@ class TestCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# Cancel # Cancel
response = await client.post(f"/api/appointments/{apt_id}/cancel") response = await client.post(f"/api/appointments/{apt_id}/cancel")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "cancelled_by_user" assert data["status"] == "cancelled_by_user"
assert data["cancelled_at"] is not None assert data["cancelled_at"] is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user): async def test_cannot_cancel_others_appointment(
self, client_factory, regular_user, alt_regular_user, admin_user
):
"""User cannot cancel another user's appointment.""" """User cannot cancel another user's appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -560,7 +612,7 @@ class TestCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# First user books # First user books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -568,24 +620,28 @@ class TestCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# Second user tries to cancel # Second user tries to cancel
async with client_factory.create(cookies=alt_regular_user["cookies"]) as client: async with client_factory.create(cookies=alt_regular_user["cookies"]) as client:
response = await client.post(f"/api/appointments/{apt_id}/cancel") response = await client.post(f"/api/appointments/{apt_id}/cancel")
assert response.status_code == 403 assert response.status_code == 403
assert "another user" in response.json()["detail"].lower() assert "another user" in response.json()["detail"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user): async def test_cannot_cancel_nonexistent_appointment(
self, client_factory, regular_user
):
"""Returns 404 for non-existent appointment.""" """Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/appointments/99999/cancel") response = await client.post("/api/appointments/99999/cancel")
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): async def test_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Cannot cancel an already cancelled appointment.""" """Cannot cancel an already cancelled appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -596,7 +652,7 @@ class TestCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books and cancels # User books and cancels
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -605,23 +661,27 @@ class TestCancelAppointment:
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
await client.post(f"/api/appointments/{apt_id}/cancel") await client.post(f"/api/appointments/{apt_id}/cancel")
# Try to cancel again # Try to cancel again
response = await client.post(f"/api/appointments/{apt_id}/cancel") response = await client.post(f"/api/appointments/{apt_id}/cancel")
assert response.status_code == 400 assert response.status_code == 400
assert "cancelled_by_user" in response.json()["detail"] assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user): async def test_admin_cannot_use_user_cancel_endpoint(
self, client_factory, admin_user
):
"""Admin cannot use user cancel endpoint.""" """Admin cannot use user cancel endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/appointments/1/cancel") response = await client.post("/api/appointments/1/cancel")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user): async def test_cancelled_slot_becomes_available(
self, client_factory, regular_user, admin_user
):
"""After cancelling, the slot becomes available again.""" """After cancelling, the slot becomes available again."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -632,7 +692,7 @@ class TestCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "09:30:00"}], "slots": [{"start_time": "09:00:00", "end_time": "09:30:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -640,17 +700,17 @@ class TestCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# Check slots - should have 1 slot left (09:15) # Check slots - should have 1 slot left (09:15)
slots_response = await client.get( slots_response = await client.get(
"/api/booking/slots", "/api/booking/slots",
params={"date": str(tomorrow())}, params={"date": str(tomorrow())},
) )
assert len(slots_response.json()["slots"]) == 1 assert len(slots_response.json()["slots"]) == 1
# Cancel # Cancel
await client.post(f"/api/appointments/{apt_id}/cancel") await client.post(f"/api/appointments/{apt_id}/cancel")
# Check slots - should have 2 slots now # Check slots - should have 2 slots now
slots_response = await client.get( slots_response = await client.get(
"/api/booking/slots", "/api/booking/slots",
@ -663,7 +723,7 @@ class TestCancelAppointment:
"""User cannot cancel a past appointment.""" """User cannot cancel a past appointment."""
# Create a past appointment directly in DB # Create a past appointment directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1) past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment( appointment = Appointment(
user_id=regular_user["user"]["id"], user_id=regular_user["user"]["id"],
slot_start=past_time, slot_start=past_time,
@ -674,11 +734,11 @@ class TestCancelAppointment:
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
apt_id = appointment.id apt_id = appointment.id
# Try to cancel # Try to cancel
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post(f"/api/appointments/{apt_id}/cancel") response = await client.post(f"/api/appointments/{apt_id}/cancel")
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()
@ -687,11 +747,14 @@ class TestCancelAppointment:
# Admin Appointments Tests # Admin Appointments Tests
# ============================================================================= # =============================================================================
class TestAdminViewAppointments: class TestAdminViewAppointments:
"""Test admin viewing all appointments.""" """Test admin viewing all appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user): async def test_admin_can_view_all_appointments(
self, client_factory, regular_user, admin_user
):
"""Admin can view all appointments.""" """Admin can view all appointments."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -702,18 +765,18 @@ class TestAdminViewAppointments:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post( await client.post(
"/api/booking", "/api/booking",
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test"}, json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test"},
) )
# Admin views all appointments # Admin views all appointments
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.get("/api/admin/appointments") response = await admin_client.get("/api/admin/appointments")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Paginated response # Paginated response
@ -725,11 +788,13 @@ class TestAdminViewAppointments:
assert any(apt["note"] == "Test" for apt in data["records"]) assert any(apt["note"] == "Test" for apt in data["records"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user): async def test_regular_user_cannot_view_all_appointments(
self, client_factory, regular_user
):
"""Regular user cannot access admin appointments endpoint.""" """Regular user cannot access admin appointments endpoint."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/admin/appointments") response = await client.get("/api/admin/appointments")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
"""Test admin cancelling appointments.""" """Test admin cancelling appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user): async def test_admin_can_cancel_any_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin can cancel any user's appointment.""" """Admin can cancel any user's appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -754,7 +821,7 @@ class TestAdminCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -762,18 +829,22 @@ class TestAdminCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# Admin cancels # Admin cancels
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "cancelled_by_admin" assert data["status"] == "cancelled_by_admin"
assert data["cancelled_at"] is not None assert data["cancelled_at"] is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user): async def test_regular_user_cannot_use_admin_cancel(
self, client_factory, regular_user, admin_user
):
"""Regular user cannot use admin cancel endpoint.""" """Regular user cannot use admin cancel endpoint."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -784,7 +855,7 @@ class TestAdminCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -792,22 +863,26 @@ class TestAdminCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# User tries to use admin cancel endpoint # User tries to use admin cancel endpoint
response = await client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await client.post(f"/api/admin/appointments/{apt_id}/cancel")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user): async def test_admin_cancel_nonexistent_appointment(
self, client_factory, admin_user
):
"""Returns 404 for non-existent appointment.""" """Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/admin/appointments/99999/cancel") response = await client.post("/api/admin/appointments/99999/cancel")
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): async def test_admin_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel an already cancelled appointment.""" """Admin cannot cancel an already cancelled appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -818,7 +893,7 @@ class TestAdminCancelAppointment:
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
}, },
) )
# User books # User books
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
book_response = await client.post( book_response = await client.post(
@ -826,23 +901,27 @@ class TestAdminCancelAppointment:
json={"slot_start": f"{tomorrow()}T09:00:00Z"}, json={"slot_start": f"{tomorrow()}T09:00:00Z"},
) )
apt_id = book_response.json()["id"] apt_id = book_response.json()["id"]
# User cancels their own appointment # User cancels their own appointment
await client.post(f"/api/appointments/{apt_id}/cancel") await client.post(f"/api/appointments/{apt_id}/cancel")
# Admin tries to cancel again # Admin tries to cancel again
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400 assert response.status_code == 400
assert "cancelled_by_user" in response.json()["detail"] assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user): async def test_admin_cannot_cancel_past_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel a past appointment.""" """Admin cannot cancel a past appointment."""
# Create a past appointment directly in DB # Create a past appointment directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1) past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment( appointment = Appointment(
user_id=regular_user["user"]["id"], user_id=regular_user["user"]["id"],
slot_start=past_time, slot_start=past_time,
@ -853,11 +932,12 @@ class TestAdminCancelAppointment:
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
apt_id = appointment.id apt_id = appointment.id
# Admin tries to cancel # Admin tries to cancel
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()

View file

@ -2,12 +2,13 @@
Note: Registration now requires an invite code. Note: Registration now requires an invite code.
""" """
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Protected endpoint tests - without auth # Protected endpoint tests - without auth
@ -41,9 +42,11 @@ async def test_increment_counter_invalid_cookie(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory): async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -53,10 +56,10 @@ async def test_get_counter_authenticated(client_factory):
}, },
) )
cookies = dict(reg.cookies) cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/counter") response = await authed.get("/api/counter")
assert response.status_code == 200 assert response.status_code == 200
assert "value" in response.json() assert "value" in response.json()
@ -64,9 +67,11 @@ async def test_get_counter_authenticated(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter(client_factory): async def test_increment_counter(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -76,12 +81,12 @@ async def test_increment_counter(client_factory):
}, },
) )
cookies = dict(reg.cookies) cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
# Get current value # Get current value
before = await authed.get("/api/counter") before = await authed.get("/api/counter")
before_value = before.json()["value"] before_value = before.json()["value"]
# Increment # Increment
response = await authed.post("/api/counter/increment") response = await authed.post("/api/counter/increment")
assert response.status_code == 200 assert response.status_code == 200
@ -91,9 +96,11 @@ async def test_increment_counter(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory): async def test_increment_counter_multiple(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -103,26 +110,28 @@ async def test_increment_counter_multiple(client_factory):
}, },
) )
cookies = dict(reg.cookies) cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
# Get starting value # Get starting value
before = await authed.get("/api/counter") before = await authed.get("/api/counter")
start = before.json()["value"] start = before.json()["value"]
# Increment 3 times # Increment 3 times
await authed.post("/api/counter/increment") await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment") await authed.post("/api/counter/increment")
response = await authed.post("/api/counter/increment") response = await authed.post("/api/counter/increment")
assert response.json()["value"] == start + 3 assert response.json()["value"] == start + 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory): async def test_get_counter_after_increment(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -132,14 +141,14 @@ async def test_get_counter_after_increment(client_factory):
}, },
) )
cookies = dict(reg.cookies) cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed: async with client_factory.create(cookies=cookies) as authed:
before = await authed.get("/api/counter") before = await authed.get("/api/counter")
start = before.json()["value"] start = before.json()["value"]
await authed.post("/api/counter/increment") await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment") await authed.post("/api/counter/increment")
response = await authed.get("/api/counter") response = await authed.get("/api/counter")
assert response.json()["value"] == start + 2 assert response.json()["value"] == start + 2
@ -149,10 +158,12 @@ async def test_get_counter_after_increment(client_factory):
async def test_counter_shared_between_users(client_factory): async def test_counter_shared_between_users(client_factory):
# Create godfather and invites for two users # Create godfather and invites for two users
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)
# Create first user # Create first user
reg1 = await client_factory.post( reg1 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -163,15 +174,15 @@ async def test_counter_shared_between_users(client_factory):
}, },
) )
cookies1 = dict(reg1.cookies) cookies1 = dict(reg1.cookies)
async with client_factory.create(cookies=cookies1) as user1: async with client_factory.create(cookies=cookies1) as user1:
# Get starting value # Get starting value
before = await user1.get("/api/counter") before = await user1.get("/api/counter")
start = before.json()["value"] start = before.json()["value"]
await user1.post("/api/counter/increment") await user1.post("/api/counter/increment")
await user1.post("/api/counter/increment") await user1.post("/api/counter/increment")
# Create second user - should see the increments # Create second user - should see the increments
reg2 = await client_factory.post( reg2 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
@ -182,14 +193,14 @@ async def test_counter_shared_between_users(client_factory):
}, },
) )
cookies2 = dict(reg2.cookies) cookies2 = dict(reg2.cookies)
async with client_factory.create(cookies=cookies2) as user2: async with client_factory.create(cookies=cookies2) as user2:
response = await user2.get("/api/counter") response = await user2.get("/api/counter")
assert response.json()["value"] == start + 2 assert response.json()["value"] == start + 2
# Second user increments # Second user increments
await user2.post("/api/counter/increment") await user2.post("/api/counter/increment")
# First user sees the increment # First user sees the increment
async with client_factory.create(cookies=cookies1) as user1: async with client_factory.create(cookies=cookies1) as user1:
response = await user1.get("/api/counter") response = await user1.get("/api/counter")

View file

@ -1,22 +1,23 @@
"""Tests for invite functionality.""" """Tests for invite functionality."""
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select
from invite_utils import ( from invite_utils import (
generate_invite_identifier,
normalize_identifier,
is_valid_identifier_format,
BIP39_WORDS, BIP39_WORDS,
generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
) )
from models import Invite, InviteStatus, User, ROLE_REGULAR from models import ROLE_REGULAR, Invite, InviteStatus, User
from tests.helpers import unique_email
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import unique_email
# ============================================================================ # ============================================================================
# Invite Utils Tests # Invite Utils Tests
# ============================================================================ # ============================================================================
def test_bip39_words_loaded(): def test_bip39_words_loaded():
"""BIP39 word list should have exactly 2048 words.""" """BIP39 word list should have exactly 2048 words."""
assert len(BIP39_WORDS) == 2048 assert len(BIP39_WORDS) == 2048
@ -26,7 +27,7 @@ def test_generate_invite_identifier_format():
"""Generated identifier should have word-word-NN format.""" """Generated identifier should have word-word-NN format."""
identifier = generate_invite_identifier() identifier = generate_invite_identifier()
assert is_valid_identifier_format(identifier) assert is_valid_identifier_format(identifier)
parts = identifier.split("-") parts = identifier.split("-")
assert len(parts) == 3 assert len(parts) == 3
assert parts[0] in BIP39_WORDS assert parts[0] in BIP39_WORDS
@ -74,11 +75,11 @@ def test_is_valid_identifier_format_invalid():
assert is_valid_identifier_format("apple-banana") is False assert is_valid_identifier_format("apple-banana") is False
assert is_valid_identifier_format("apple-banana-42-extra") is False assert is_valid_identifier_format("apple-banana-42-extra") is False
assert is_valid_identifier_format("applebanan42") is False assert is_valid_identifier_format("applebanan42") is False
# Empty parts # Empty parts
assert is_valid_identifier_format("-banana-42") is False assert is_valid_identifier_format("-banana-42") is False
assert is_valid_identifier_format("apple--42") is False assert is_valid_identifier_format("apple--42") is False
# Invalid number format # Invalid number format
assert is_valid_identifier_format("apple-banana-4") is False # Single digit assert is_valid_identifier_format("apple-banana-4") is False # Single digit
assert is_valid_identifier_format("apple-banana-420") is False # Three digits assert is_valid_identifier_format("apple-banana-420") is False # Three digits
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
# Invite Model Tests # Invite Model Tests
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_invite(client_factory): async def test_create_invite(client_factory):
"""Can create an invite with godfather.""" """Can create an invite with godfather."""
@ -97,7 +99,7 @@ async def test_create_invite(client_factory):
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
) )
# Create invite # Create invite
invite = Invite( invite = Invite(
identifier="test-invite-01", identifier="test-invite-01",
@ -107,7 +109,7 @@ async def test_create_invite(client_factory):
db.add(invite) db.add(invite)
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)
assert invite.id is not None assert invite.id is not None
assert invite.identifier == "test-invite-01" assert invite.identifier == "test-invite-01"
assert invite.godfather_id == godfather.id assert invite.godfather_id == godfather.id
@ -125,20 +127,20 @@ async def test_invite_godfather_relationship(client_factory):
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
) )
invite = Invite( invite = Invite(
identifier="rel-test-01", identifier="rel-test-01",
godfather_id=godfather.id, godfather_id=godfather.id,
) )
db.add(invite) db.add(invite)
await db.commit() await db.commit()
# Query invite fresh # Query invite fresh
result = await db.execute( result = await db.execute(
select(Invite).where(Invite.identifier == "rel-test-01") select(Invite).where(Invite.identifier == "rel-test-01")
) )
loaded_invite = result.scalar_one() loaded_invite = result.scalar_one()
assert loaded_invite.godfather is not None assert loaded_invite.godfather is not None
assert loaded_invite.godfather.email == godfather.email assert loaded_invite.godfather.email == godfather.email
@ -147,25 +149,25 @@ async def test_invite_godfather_relationship(client_factory):
async def test_invite_unique_identifier(client_factory): async def test_invite_unique_identifier(client_factory):
"""Invite identifier must be unique.""" """Invite identifier must be unique."""
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
) )
invite1 = Invite( invite1 = Invite(
identifier="unique-test-01", identifier="unique-test-01",
godfather_id=godfather.id, godfather_id=godfather.id,
) )
db.add(invite1) db.add(invite1)
await db.commit() await db.commit()
invite2 = Invite( invite2 = Invite(
identifier="unique-test-01", # Same identifier identifier="unique-test-01", # Same identifier
godfather_id=godfather.id, godfather_id=godfather.id,
) )
db.add(invite2) db.add(invite2)
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
await db.commit() await db.commit()
@ -173,8 +175,8 @@ async def test_invite_unique_identifier(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invite_status_transitions(client_factory): async def test_invite_status_transitions(client_factory):
"""Invite status can be changed.""" """Invite status can be changed."""
from datetime import datetime, UTC from datetime import UTC, datetime
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
@ -182,7 +184,7 @@ async def test_invite_status_transitions(client_factory):
user = await create_user_with_roles( user = await create_user_with_roles(
db, unique_email("invitee"), "password123", [ROLE_REGULAR] db, unique_email("invitee"), "password123", [ROLE_REGULAR]
) )
invite = Invite( invite = Invite(
identifier="status-test-01", identifier="status-test-01",
godfather_id=godfather.id, godfather_id=godfather.id,
@ -190,14 +192,14 @@ async def test_invite_status_transitions(client_factory):
) )
db.add(invite) db.add(invite)
await db.commit() await db.commit()
# Transition to SPENT # Transition to SPENT
invite.status = InviteStatus.SPENT invite.status = InviteStatus.SPENT
invite.used_by_id = user.id invite.used_by_id = user.id
invite.spent_at = datetime.now(UTC) invite.spent_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)
assert invite.status == InviteStatus.SPENT assert invite.status == InviteStatus.SPENT
assert invite.used_by_id == user.id assert invite.used_by_id == user.id
assert invite.spent_at is not None assert invite.spent_at is not None
@ -206,13 +208,13 @@ async def test_invite_status_transitions(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invite_revoke(client_factory): async def test_invite_revoke(client_factory):
"""Invite can be revoked.""" """Invite can be revoked."""
from datetime import datetime, UTC from datetime import UTC, datetime
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
) )
invite = Invite( invite = Invite(
identifier="revoke-test-01", identifier="revoke-test-01",
godfather_id=godfather.id, godfather_id=godfather.id,
@ -220,13 +222,13 @@ async def test_invite_revoke(client_factory):
) )
db.add(invite) db.add(invite)
await db.commit() await db.commit()
# Revoke # Revoke
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(UTC) invite.revoked_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)
assert invite.status == InviteStatus.REVOKED assert invite.status == InviteStatus.REVOKED
assert invite.revoked_at is not None assert invite.revoked_at is not None
assert invite.used_by_id is None # Not used assert invite.used_by_id is None # Not used
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
# User Godfather Tests # User Godfather Tests
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_godfather_relationship(client_factory): async def test_user_godfather_relationship(client_factory):
"""User can have a godfather.""" """User can have a godfather."""
@ -243,7 +246,7 @@ async def test_user_godfather_relationship(client_factory):
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
db, unique_email("godfather"), "password123", [ROLE_REGULAR] db, unique_email("godfather"), "password123", [ROLE_REGULAR]
) )
# Create user with godfather # Create user with godfather
user = User( user = User(
email=unique_email("godchild"), email=unique_email("godchild"),
@ -252,13 +255,11 @@ async def test_user_godfather_relationship(client_factory):
) )
db.add(user) db.add(user)
await db.commit() await db.commit()
# Query user fresh # Query user fresh
result = await db.execute( result = await db.execute(select(User).where(User.id == user.id))
select(User).where(User.id == user.id)
)
loaded_user = result.scalar_one() loaded_user = result.scalar_one()
assert loaded_user.godfather_id == godfather.id assert loaded_user.godfather_id == godfather.id
assert loaded_user.godfather is not None assert loaded_user.godfather is not None
assert loaded_user.godfather.email == godfather.email assert loaded_user.godfather.email == godfather.email
@ -271,7 +272,7 @@ async def test_user_without_godfather(client_factory):
user = await create_user_with_roles( user = await create_user_with_roles(
db, unique_email("noparent"), "password123", [ROLE_REGULAR] db, unique_email("noparent"), "password123", [ROLE_REGULAR]
) )
assert user.godfather_id is None assert user.godfather_id is None
assert user.godfather is None assert user.godfather is None
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
# Admin Create Invite API Tests (Phase 2) # Admin Create Invite API Tests (Phase 2)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_create_invite(client_factory, admin_user, regular_user): async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
"""Admin can create an invite for a regular user.""" """Admin can create an invite for a regular user."""
@ -290,12 +292,12 @@ async def test_admin_can_create_invite(client_factory, admin_user, regular_user)
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
response = await client.post( response = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["godfather_id"] == godfather.id assert data["godfather_id"] == godfather.id
@ -318,12 +320,12 @@ async def test_admin_can_create_invite_for_self(client_factory, admin_user):
select(User).where(User.email == admin_user["email"]) select(User).where(User.email == admin_user["email"])
) )
admin = result.scalar_one() admin = result.scalar_one()
response = await client.post( response = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": admin.id}, json={"godfather_id": admin.id},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["godfather_id"] == admin.id assert data["godfather_id"] == admin.id
@ -338,7 +340,7 @@ async def test_regular_user_cannot_create_invite(client_factory, regular_user):
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": 1}, json={"godfather_id": 1},
) )
assert response.status_code == 403 assert response.status_code == 403
@ -350,7 +352,7 @@ async def test_unauthenticated_cannot_create_invite(client_factory):
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": 1}, json={"godfather_id": 1},
) )
assert response.status_code == 401 assert response.status_code == 401
@ -362,7 +364,7 @@ async def test_create_invite_invalid_godfather(client_factory, admin_user):
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": 99999}, json={"godfather_id": 99999},
) )
assert response.status_code == 400 assert response.status_code == 400
assert "not found" in response.json()["detail"].lower() assert "not found" in response.json()["detail"].lower()
@ -376,39 +378,39 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
response = await client.post( response = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
data = response.json() data = response.json()
invite_id = data["id"] invite_id = data["id"]
# Query from DB # Query from DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
assert invite.identifier == data["identifier"] assert invite.identifier == data["identifier"]
assert invite.godfather_id == godfather.id assert invite.godfather_id == godfather.id
assert invite.status == InviteStatus.READY assert invite.status == InviteStatus.READY
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user): async def test_create_invite_retries_on_collision(
client_factory, admin_user, regular_user
):
"""Create invite retries with new identifier on collision.""" """Create invite retries with new identifier on collision."""
from unittest.mock import patch from unittest.mock import patch
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
# Create first invite normally # Create first invite normally
response1 = await client.post( response1 = await client.post(
"/api/admin/invites", "/api/admin/invites",
@ -416,22 +418,25 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
) )
assert response1.status_code == 200 assert response1.status_code == 200
identifier1 = response1.json()["identifier"] identifier1 = response1.json()["identifier"]
# Mock generator to first return the same identifier (collision), then a new one # Mock generator to first return the same identifier (collision), then a new one
call_count = 0 call_count = 0
def mock_generator(): def mock_generator():
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
if call_count == 1: if call_count == 1:
return identifier1 # Will collide return identifier1 # Will collide
return f"unique-word-{call_count:02d}" # Won't collide return f"unique-word-{call_count:02d}" # Won't collide
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator): with patch(
"routes.invites.generate_invite_identifier", side_effect=mock_generator
):
response2 = await client.post( response2 = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
assert response2.status_code == 200 assert response2.status_code == 200
# Should have retried and gotten a new identifier # Should have retried and gotten a new identifier
assert response2.json()["identifier"] != identifier1 assert response2.json()["identifier"] != identifier1
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
# Invite Check API Tests (Phase 3) # Invite Check API Tests (Phase 3)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_valid(client_factory, admin_user, regular_user): async def test_check_invite_valid(client_factory, admin_user, regular_user):
"""Check endpoint returns valid=True for READY invite.""" """Check endpoint returns valid=True for READY invite."""
@ -452,17 +458,17 @@ async def test_check_invite_valid(client_factory, admin_user, regular_user):
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Check invite (no auth needed) # Check invite (no auth needed)
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get(f"/api/invites/{identifier}/check") response = await client.get(f"/api/invites/{identifier}/check")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["valid"] is True assert data["valid"] is True
@ -475,7 +481,7 @@ async def test_check_invite_not_found(client_factory):
"""Check endpoint returns valid=False for unknown invite.""" """Check endpoint returns valid=False for unknown invite."""
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get("/api/invites/fake-invite-99/check") response = await client.get("/api/invites/fake-invite-99/check")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["valid"] is False assert data["valid"] is False
@ -492,14 +498,14 @@ async def test_check_invite_invalid_format(client_factory):
data = response.json() data = response.json()
assert data["valid"] is False assert data["valid"] is False
assert "format" in data["error"].lower() assert "format" in data["error"].lower()
# Single digit number # Single digit number
response = await client.get("/api/invites/word-word-1/check") response = await client.get("/api/invites/word-word-1/check")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["valid"] is False assert data["valid"] is False
assert "format" in data["error"].lower() assert "format" in data["error"].lower()
# Too many parts # Too many parts
response = await client.get("/api/invites/word-word-word-00/check") response = await client.get("/api/invites/word-word-word-00/check")
assert response.status_code == 200 assert response.status_code == 200
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user): async def test_check_invite_spent_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage).""" """Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -518,13 +526,13 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Use the invite # Use the invite
async with client_factory.create() as client: async with client_factory.create() as client:
await client.post( await client.post(
@ -535,11 +543,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
# Check spent invite - should return same error as non-existent # Check spent invite - should return same error as non-existent
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get(f"/api/invites/{identifier}/check") response = await client.get(f"/api/invites/{identifier}/check")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["valid"] is False assert data["valid"] is False
@ -547,10 +555,12 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user): async def test_check_invite_revoked_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage).""" """Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
from datetime import datetime, UTC from datetime import UTC, datetime
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
@ -558,14 +568,14 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
invite_id = create_resp.json()["id"] invite_id = create_resp.json()["id"]
# Revoke the invite # Revoke the invite
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(Invite).where(Invite.id == invite_id)) result = await db.execute(select(Invite).where(Invite.id == invite_id))
@ -573,11 +583,11 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(UTC) invite.revoked_at = datetime.now(UTC)
await db.commit() await db.commit()
# Check revoked invite - should return same error as non-existent # Check revoked invite - should return same error as non-existent
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get(f"/api/invites/{identifier}/check") response = await client.get(f"/api/invites/{identifier}/check")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["valid"] is False assert data["valid"] is False
@ -594,17 +604,17 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Check with uppercase # Check with uppercase
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get(f"/api/invites/{identifier.upper()}/check") response = await client.get(f"/api/invites/{identifier.upper()}/check")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["valid"] is True assert response.json()["valid"] is True
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
# Register with Invite Tests (Phase 3) # Register with Invite Tests (Phase 3)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_with_valid_invite(client_factory, admin_user, regular_user): async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
"""Can register with valid invite code.""" """Can register with valid invite code."""
@ -624,13 +635,13 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us
) )
godfather = result.scalar_one() godfather = result.scalar_one()
godfather_id = godfather.id godfather_id = godfather.id
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather_id}, json={"godfather_id": godfather_id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Register with invite # Register with invite
new_email = unique_email("newuser") new_email = unique_email("newuser")
async with client_factory.create() as client: async with client_factory.create() as client:
@ -642,7 +653,7 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["email"] == new_email assert data["email"] == new_email
@ -659,7 +670,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
@ -667,7 +678,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
invite_data = create_resp.json() invite_data = create_resp.json()
identifier = invite_data["identifier"] identifier = invite_data["identifier"]
invite_id = invite_data["id"] invite_id = invite_data["id"]
# Register # Register
async with client_factory.create() as client: async with client_factory.create() as client:
await client.post( await client.post(
@ -678,14 +689,12 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
# Check invite status # Check invite status
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
assert invite.status == InviteStatus.SPENT assert invite.status == InviteStatus.SPENT
assert invite.used_by_id is not None assert invite.used_by_id is not None
assert invite.spent_at is not None assert invite.spent_at is not None
@ -702,13 +711,13 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
) )
godfather = result.scalar_one() godfather = result.scalar_one()
godfather_id = godfather.id godfather_id = godfather.id
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather_id}, json={"godfather_id": godfather_id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Register # Register
new_email = unique_email("godchildtest") new_email = unique_email("godchildtest")
async with client_factory.create() as client: async with client_factory.create() as client:
@ -720,14 +729,12 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
# Check user's godfather # Check user's godfather
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(User).where(User.email == new_email))
select(User).where(User.email == new_email)
)
new_user = result.scalar_one() new_user = result.scalar_one()
assert new_user.godfather_id == godfather_id assert new_user.godfather_id == godfather_id
@ -743,7 +750,7 @@ async def test_register_with_invalid_invite(client_factory):
"invite_identifier": "fake-invite-99", "invite_identifier": "fake-invite-99",
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "invalid" in response.json()["detail"].lower() assert "invalid" in response.json()["detail"].lower()
@ -758,13 +765,13 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# First registration # First registration
async with client_factory.create() as client: async with client_factory.create() as client:
await client.post( await client.post(
@ -775,7 +782,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
# Second registration with same invite # Second registration with same invite
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.post( response = await client.post(
@ -786,7 +793,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "invalid invite code" in response.json()["detail"].lower() assert "invalid invite code" in response.json()["detail"].lower()
@ -794,8 +801,8 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user): async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
"""Cannot register with revoked invite.""" """Cannot register with revoked invite."""
from datetime import datetime, UTC from datetime import UTC, datetime
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
@ -803,7 +810,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
@ -811,17 +818,15 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
invite_data = create_resp.json() invite_data = create_resp.json()
identifier = invite_data["identifier"] identifier = invite_data["identifier"]
invite_id = invite_data["id"] invite_id = invite_data["id"]
# Revoke invite directly in DB # Revoke invite directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(UTC) invite.revoked_at = datetime.now(UTC)
await db.commit() await db.commit()
# Try to register # Try to register
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.post( response = await client.post(
@ -832,7 +837,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "invalid invite code" in response.json()["detail"].lower() assert "invalid invite code" in response.json()["detail"].lower()
@ -847,13 +852,13 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Try to register with existing email # Try to register with existing email
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.post( response = await client.post(
@ -864,7 +869,7 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
assert response.status_code == 400 assert response.status_code == 400
assert "already registered" in response.json()["detail"].lower() assert "already registered" in response.json()["detail"].lower()
@ -879,13 +884,13 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Register # Register
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.post( response = await client.post(
@ -896,7 +901,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
assert "auth_token" in response.cookies assert "auth_token" in response.cookies
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
# User Invites API Tests (Phase 4) # User Invites API Tests (Phase 4)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user): async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
"""Regular user can list their own invites.""" """Regular user can list their own invites."""
@ -914,14 +920,14 @@ async def test_regular_user_can_list_invites(client_factory, admin_user, regular
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
# List invites as regular user # List invites as regular user
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/invites") response = await client.get("/api/invites")
assert response.status_code == 200 assert response.status_code == 200
invites = response.json() invites = response.json()
assert len(invites) == 2 assert len(invites) == 2
@ -935,13 +941,15 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
"""User with no invites gets empty list.""" """User with no invites gets empty list."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/invites") response = await client.get("/api/invites")
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == [] assert response.json() == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user): async def test_spent_invite_shows_used_by_email(
client_factory, admin_user, regular_user
):
"""Spent invite shows who used it.""" """Spent invite shows who used it."""
# Create invite for regular user # Create invite for regular user
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -950,13 +958,13 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Use the invite # Use the invite
invitee_email = unique_email("invitee") invitee_email = unique_email("invitee")
async with client_factory.create() as client: async with client_factory.create() as client:
@ -968,11 +976,11 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu
"invite_identifier": identifier, "invite_identifier": identifier,
}, },
) )
# Check that regular user sees the invitee email # Check that regular user sees the invitee email
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/invites") response = await client.get("/api/invites")
assert response.status_code == 200 assert response.status_code == 200
invites = response.json() invites = response.json()
assert len(invites) == 1 assert len(invites) == 1
@ -985,7 +993,7 @@ async def test_admin_cannot_list_own_invites(client_factory, admin_user):
"""Admin without VIEW_OWN_INVITES permission gets 403.""" """Admin without VIEW_OWN_INVITES permission gets 403."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/invites") response = await client.get("/api/invites")
assert response.status_code == 403 assert response.status_code == 403
@ -994,7 +1002,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
"""Unauthenticated user gets 401.""" """Unauthenticated user gets 401."""
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get("/api/invites") response = await client.get("/api/invites")
assert response.status_code == 401 assert response.status_code == 401
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
# Admin Invite Management Tests (Phase 5) # Admin Invite Management Tests (Phase 5)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user): async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
"""Admin can list all invites.""" """Admin can list all invites."""
@ -1012,13 +1021,13 @@ async def test_admin_can_list_all_invites(client_factory, admin_user, regular_us
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
# List all # List all
response = await client.get("/api/admin/invites") response = await client.get("/api/admin/invites")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] >= 2 assert data["total"] >= 2
@ -1034,14 +1043,14 @@ async def test_admin_list_pagination(client_factory, admin_user, regular_user):
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
# Create 5 invites # Create 5 invites
for _ in range(5): for _ in range(5):
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
# Get page 1 with 2 per page # Get page 1 with 2 per page
response = await client.get("/api/admin/invites?page=1&per_page=2") response = await client.get("/api/admin/invites?page=1&per_page=2")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert len(data["records"]) == 2 assert len(data["records"]) == 2
@ -1058,13 +1067,13 @@ async def test_admin_filter_by_status(client_factory, admin_user, regular_user):
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
# Create an invite # Create an invite
await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
# Filter by ready # Filter by ready
response = await client.get("/api/admin/invites?status=ready") response = await client.get("/api/admin/invites?status=ready")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
for record in data["records"]: for record in data["records"]:
@ -1080,17 +1089,17 @@ async def test_admin_can_revoke_invite(client_factory, admin_user, regular_user)
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
# Create invite # Create invite
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
invite_id = create_resp.json()["id"] invite_id = create_resp.json()["id"]
# Revoke it # Revoke it
response = await client.post(f"/api/admin/invites/{invite_id}/revoke") response = await client.post(f"/api/admin/invites/{invite_id}/revoke")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "revoked" assert data["status"] == "revoked"
@ -1106,14 +1115,14 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
# Create invite # Create invite
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
invite_data = create_resp.json() invite_data = create_resp.json()
# Use the invite # Use the invite
async with client_factory.create() as client: async with client_factory.create() as client:
await client.post( await client.post(
@ -1124,11 +1133,11 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us
"invite_identifier": invite_data["identifier"], "invite_identifier": invite_data["identifier"],
}, },
) )
# Try to revoke # Try to revoke
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post(f"/api/admin/invites/{invite_data['id']}/revoke") response = await client.post(f"/api/admin/invites/{invite_data['id']}/revoke")
assert response.status_code == 400 assert response.status_code == 400
assert "only ready" in response.json()["detail"].lower() assert "only ready" in response.json()["detail"].lower()
@ -1138,7 +1147,7 @@ async def test_revoke_nonexistent_invite(client_factory, admin_user):
"""Revoking non-existent invite returns 404.""" """Revoking non-existent invite returns 404."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/admin/invites/99999/revoke") response = await client.post("/api/admin/invites/99999/revoke")
assert response.status_code == 404 assert response.status_code == 404
@ -1149,8 +1158,7 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
# List # List
response = await client.get("/api/admin/invites") response = await client.get("/api/admin/invites")
assert response.status_code == 403 assert response.status_code == 403
# Revoke # Revoke
response = await client.post("/api/admin/invites/1/revoke") response = await client.post("/api/admin/invites/1/revoke")
assert response.status_code == 403 assert response.status_code == 403

View file

@ -7,15 +7,16 @@ These tests verify that:
3. Unauthenticated users are denied access (401) 3. Unauthenticated users are denied access (401)
4. The permission system cannot be bypassed 4. The permission system cannot be bypassed
""" """
import pytest import pytest
from models import Permission from models import Permission
# ============================================================================= # =============================================================================
# Role Assignment Tests # Role Assignment Tests
# ============================================================================= # =============================================================================
class TestRoleAssignment: class TestRoleAssignment:
"""Test that roles are properly assigned and returned.""" """Test that roles are properly assigned and returned."""
@ -23,7 +24,7 @@ class TestRoleAssignment:
async def test_regular_user_has_correct_roles(self, client_factory, regular_user): async def test_regular_user_has_correct_roles(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "regular" in data["roles"] assert "regular" in data["roles"]
@ -33,25 +34,27 @@ class TestRoleAssignment:
async def test_admin_user_has_correct_roles(self, client_factory, admin_user): async def test_admin_user_has_correct_roles(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "admin" in data["roles"] assert "admin" in data["roles"]
assert "regular" not in data["roles"] assert "regular" not in data["roles"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user): async def test_regular_user_has_correct_permissions(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
data = response.json() data = response.json()
permissions = data["permissions"] permissions = data["permissions"]
# Should have counter and sum permissions # Should have counter and sum permissions
assert Permission.VIEW_COUNTER.value in permissions assert Permission.VIEW_COUNTER.value in permissions
assert Permission.INCREMENT_COUNTER.value in permissions assert Permission.INCREMENT_COUNTER.value in permissions
assert Permission.USE_SUM.value in permissions assert Permission.USE_SUM.value in permissions
# Should NOT have audit permission # Should NOT have audit permission
assert Permission.VIEW_AUDIT.value not in permissions assert Permission.VIEW_AUDIT.value not in permissions
@ -59,23 +62,25 @@ class TestRoleAssignment:
async def test_admin_user_has_correct_permissions(self, client_factory, admin_user): async def test_admin_user_has_correct_permissions(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
data = response.json() data = response.json()
permissions = data["permissions"] permissions = data["permissions"]
# Should have audit permission # Should have audit permission
assert Permission.VIEW_AUDIT.value in permissions assert Permission.VIEW_AUDIT.value in permissions
# Should NOT have counter/sum permissions # Should NOT have counter/sum permissions
assert Permission.VIEW_COUNTER.value not in permissions assert Permission.VIEW_COUNTER.value not in permissions
assert Permission.INCREMENT_COUNTER.value not in permissions assert Permission.INCREMENT_COUNTER.value not in permissions
assert Permission.USE_SUM.value not in permissions assert Permission.USE_SUM.value not in permissions
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles): async def test_user_with_no_roles_has_no_permissions(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
data = response.json() data = response.json()
assert data["roles"] == [] assert data["roles"] == []
assert data["permissions"] == [] assert data["permissions"] == []
@ -85,6 +90,7 @@ class TestRoleAssignment:
# Counter Endpoint Access Tests # Counter Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestCounterAccess: class TestCounterAccess:
"""Test access control for counter endpoints.""" """Test access control for counter endpoints."""
@ -92,15 +98,17 @@ class TestCounterAccess:
async def test_regular_user_can_view_counter(self, client_factory, regular_user): async def test_regular_user_can_view_counter(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 200 assert response.status_code == 200
assert "value" in response.json() assert "value" in response.json()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_increment_counter(self, client_factory, regular_user): async def test_regular_user_can_increment_counter(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/counter/increment") response = await client.post("/api/counter/increment")
assert response.status_code == 200 assert response.status_code == 200
assert "value" in response.json() assert "value" in response.json()
@ -109,7 +117,7 @@ class TestCounterAccess:
"""Admin users should be forbidden from counter endpoints.""" """Admin users should be forbidden from counter endpoints."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 403 assert response.status_code == 403
assert "permission" in response.json()["detail"].lower() assert "permission" in response.json()["detail"].lower()
@ -118,15 +126,17 @@ class TestCounterAccess:
"""Admin users should be forbidden from incrementing counter.""" """Admin users should be forbidden from incrementing counter."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/counter/increment") response = await client.post("/api/counter/increment")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_view_counter(
self, client_factory, user_no_roles
):
"""Users with no roles should be forbidden.""" """Users with no roles should be forbidden."""
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -146,6 +156,7 @@ class TestCounterAccess:
# Sum Endpoint Access Tests # Sum Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestSumAccess: class TestSumAccess:
"""Test access control for sum endpoint.""" """Test access control for sum endpoint."""
@ -156,7 +167,7 @@ class TestSumAccess:
"/api/sum", "/api/sum",
json={"a": 5, "b": 3}, json={"a": 5, "b": 3},
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["result"] == 8 assert data["result"] == 8
@ -169,17 +180,19 @@ class TestSumAccess:
"/api/sum", "/api/sum",
json={"a": 5, "b": 3}, json={"a": 5, "b": 3},
) )
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_use_sum(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/sum", "/api/sum",
json={"a": 5, "b": 3}, json={"a": 5, "b": 3},
) )
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -195,6 +208,7 @@ class TestSumAccess:
# Audit Endpoint Access Tests # Audit Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestAuditAccess: class TestAuditAccess:
"""Test access control for audit endpoints.""" """Test access control for audit endpoints."""
@ -202,7 +216,7 @@ class TestAuditAccess:
async def test_admin_can_view_counter_audit(self, client_factory, admin_user): async def test_admin_can_view_counter_audit(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "records" in data assert "records" in data
@ -212,34 +226,40 @@ class TestAuditAccess:
async def test_admin_can_view_sum_audit(self, client_factory, admin_user): async def test_admin_can_view_sum_audit(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum") response = await client.get("/api/audit/sum")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "records" in data assert "records" in data
assert "total" in data assert "total" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user): async def test_regular_user_cannot_view_counter_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints.""" """Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
assert response.status_code == 403 assert response.status_code == 403
assert "permission" in response.json()["detail"].lower() assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user): async def test_regular_user_cannot_view_sum_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints.""" """Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/sum") response = await client.get("/api/audit/sum")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_view_audit(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -257,6 +277,7 @@ class TestAuditAccess:
# Offensive Security Tests - Bypass Attempts # Offensive Security Tests - Bypass Attempts
# ============================================================================= # =============================================================================
class TestSecurityBypassAttempts: class TestSecurityBypassAttempts:
""" """
Offensive tests that attempt to bypass security controls. Offensive tests that attempt to bypass security controls.
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user): async def test_cannot_access_audit_with_forged_role_claim(
self, client_factory, regular_user
):
""" """
Attempt to access audit by somehow claiming admin role. Attempt to access audit by somehow claiming admin role.
The server should verify roles from DB, not trust client claims. The server should verify roles from DB, not trust client claims.
@ -272,7 +295,7 @@ class TestSecurityBypassAttempts:
# Regular user tries to access audit endpoint # Regular user tries to access audit endpoint
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
# Should be denied regardless of any manipulation attempts # Should be denied regardless of any manipulation attempts
assert response.status_code == 403 assert response.status_code == 403
@ -280,23 +303,27 @@ class TestSecurityBypassAttempts:
async def test_cannot_access_counter_with_expired_session(self, client_factory): async def test_cannot_access_counter_with_expired_session(self, client_factory):
"""Test that invalid/expired tokens are rejected.""" """Test that invalid/expired tokens are rejected."""
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid" fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid"
async with client_factory.create(cookies={"auth_token": fake_token}) as client: async with client_factory.create(cookies={"auth_token": fake_token}) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user): async def test_cannot_access_with_tampered_token(
self, client_factory, regular_user
):
"""Test that tokens signed with wrong key are rejected.""" """Test that tokens signed with wrong key are rejected."""
# Take a valid token and modify it # Take a valid token and modify it
original_token = regular_user["cookies"].get("auth_token", "") original_token = regular_user["cookies"].get("auth_token", "")
if original_token: if original_token:
tampered_token = original_token[:-5] + "XXXXX" tampered_token = original_token[:-5] + "XXXXX"
async with client_factory.create(cookies={"auth_token": tampered_token}) as client: async with client_factory.create(
cookies={"auth_token": tampered_token}
) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@ -305,14 +332,16 @@ class TestSecurityBypassAttempts:
Test that new registrations cannot claim admin role. Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default. New users should only get 'regular' role by default.
""" """
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={ json={
@ -321,18 +350,18 @@ class TestSecurityBypassAttempts:
"invite_identifier": invite_code, "invite_identifier": invite_code,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Should only have regular role, not admin # Should only have regular role, not admin
assert "admin" not in data["roles"] assert "admin" not in data["roles"]
assert Permission.VIEW_AUDIT.value not in data["permissions"] assert Permission.VIEW_AUDIT.value not in data["permissions"]
# Try to access audit with this new user # Try to access audit with this new user
async with client_factory.create(cookies=dict(response.cookies)) as client: async with client_factory.create(cookies=dict(response.cookies)) as client:
audit_response = await client.get("/api/audit/counter") audit_response = await client.get("/api/audit/counter")
assert audit_response.status_code == 403 assert audit_response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
@ -341,33 +370,35 @@ class TestSecurityBypassAttempts:
If a user is deleted, their token should no longer work. If a user is deleted, their token should no longer work.
This tests that tokens are validated against current DB state. This tests that tokens are validated against current DB state.
""" """
from tests.helpers import unique_email
from sqlalchemy import delete from sqlalchemy import delete
from models import User from models import User
from tests.helpers import unique_email
email = unique_email("deleted") email = unique_email("deleted")
# Create and login user # Create and login user
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
user = await create_user_with_roles(db, email, "password123", ["regular"]) user = await create_user_with_roles(db, email, "password123", ["regular"])
user_id = user.id user_id = user.id
login_response = await client_factory.post( login_response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
cookies = dict(login_response.cookies) cookies = dict(login_response.cookies)
# Delete the user from DB # Delete the user from DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
await db.execute(delete(User).where(User.id == user_id)) await db.execute(delete(User).where(User.id == user_id))
await db.commit() await db.commit()
# Try to use the old token # Try to use the old token
async with client_factory.create(cookies=cookies) as client: async with client_factory.create(cookies=cookies) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@ -376,42 +407,41 @@ class TestSecurityBypassAttempts:
If a user's role is changed, the change should be reflected If a user's role is changed, the change should be reflected
in subsequent requests (no stale permission cache). in subsequent requests (no stale permission cache).
""" """
from tests.helpers import unique_email
from sqlalchemy import select from sqlalchemy import select
from models import User, Role
from models import Role, User
from tests.helpers import unique_email
email = unique_email("rolechange") email = unique_email("rolechange")
# Create regular user # Create regular user
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
await create_user_with_roles(db, email, "password123", ["regular"]) await create_user_with_roles(db, email, "password123", ["regular"])
login_response = await client_factory.post( login_response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
cookies = dict(login_response.cookies) cookies = dict(login_response.cookies)
# Verify can access counter but not audit # Verify can access counter but not audit
async with client_factory.create(cookies=cookies) as client: async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/counter")).status_code == 200 assert (await client.get("/api/counter")).status_code == 200
assert (await client.get("/api/audit/counter")).status_code == 403 assert (await client.get("/api/audit/counter")).status_code == 403
# Change user's role from regular to admin # Change user's role from regular to admin
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one() user = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "admin")) result = await db.execute(select(Role).where(Role.name == "admin"))
admin_role = result.scalar_one() admin_role = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "regular")) user.roles = [admin_role] # Replace roles with admin only
regular_role = result.scalar_one()
user.roles = [admin_role] # Remove regular, add admin
await db.commit() await db.commit()
# Now should have audit access but not counter access # Now should have audit access but not counter access
async with client_factory.create(cookies=cookies) as client: async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/audit/counter")).status_code == 200 assert (await client.get("/api/audit/counter")).status_code == 200
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
# Audit Record Tests # Audit Record Tests
# ============================================================================= # =============================================================================
class TestAuditRecords: class TestAuditRecords:
"""Test that actions are properly recorded in audit logs.""" """Test that actions are properly recorded in audit logs."""
@ -433,15 +464,15 @@ class TestAuditRecords:
# Regular user increments counter # Regular user increments counter
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/counter/increment") await client.post("/api/counter/increment")
# Admin checks audit # Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] >= 1 assert data["total"] >= 1
# Find record for our user # Find record for our user
records = data["records"] records = data["records"]
user_records = [r for r in records if r["user_email"] == regular_user["email"]] user_records = [r for r in records if r["user_email"] == regular_user["email"]]
@ -455,18 +486,18 @@ class TestAuditRecords:
# Regular user uses sum # Regular user uses sum
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/sum", json={"a": 10, "b": 20}) await client.post("/api/sum", json={"a": 10, "b": 20})
# Admin checks audit # Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum") response = await client.get("/api/audit/sum")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["total"] >= 1 assert data["total"] >= 1
# Find record with our values # Find record with our values
records = data["records"] records = data["records"]
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30] matching = [
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
]
assert len(matching) >= 1 assert len(matching) >= 1

View file

@ -1,9 +1,9 @@
"""Tests for user profile and contact details.""" """Tests for user profile and contact details."""
import pytest
from sqlalchemy import select from sqlalchemy import select
from models import User, ROLE_REGULAR
from auth import get_password_hash from auth import get_password_hash
from models import User
from tests.helpers import unique_email from tests.helpers import unique_email
# Valid npub for testing (32 zero bytes encoded as bech32) # Valid npub for testing (32 zero bytes encoded as bech32)
@ -16,7 +16,7 @@ class TestUserContactFields:
async def test_contact_fields_default_to_none(self, client_factory): async def test_contact_fields_default_to_none(self, client_factory):
"""New users should have all contact fields as None.""" """New users should have all contact fields as None."""
email = unique_email("test") email = unique_email("test")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = User( user = User(
email=email, email=email,
@ -25,7 +25,7 @@ class TestUserContactFields:
db.add(user) db.add(user)
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)
assert user.contact_email is None assert user.contact_email is None
assert user.telegram is None assert user.telegram is None
assert user.signal is None assert user.signal is None
@ -34,7 +34,7 @@ class TestUserContactFields:
async def test_contact_fields_can_be_set(self, client_factory): async def test_contact_fields_can_be_set(self, client_factory):
"""Contact fields can be set when creating a user.""" """Contact fields can be set when creating a user."""
email = unique_email("test") email = unique_email("test")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = User( user = User(
email=email, email=email,
@ -47,7 +47,7 @@ class TestUserContactFields:
db.add(user) db.add(user)
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)
assert user.contact_email == "contact@example.com" assert user.contact_email == "contact@example.com"
assert user.telegram == "@alice" assert user.telegram == "@alice"
assert user.signal == "alice.42" assert user.signal == "alice.42"
@ -56,7 +56,7 @@ class TestUserContactFields:
async def test_contact_fields_persist_after_reload(self, client_factory): async def test_contact_fields_persist_after_reload(self, client_factory):
"""Contact fields should persist in the database.""" """Contact fields should persist in the database."""
email = unique_email("test") email = unique_email("test")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = User( user = User(
email=email, email=email,
@ -69,12 +69,12 @@ class TestUserContactFields:
db.add(user) db.add(user)
await db.commit() await db.commit()
user_id = user.id user_id = user.id
# Reload from database in a new session # Reload from database in a new session
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
loaded_user = result.scalar_one() loaded_user = result.scalar_one()
assert loaded_user.contact_email == "contact@example.com" assert loaded_user.contact_email == "contact@example.com"
assert loaded_user.telegram == "@bob" assert loaded_user.telegram == "@bob"
assert loaded_user.signal == "bob.99" assert loaded_user.signal == "bob.99"
@ -83,7 +83,7 @@ class TestUserContactFields:
async def test_contact_fields_can_be_updated(self, client_factory): async def test_contact_fields_can_be_updated(self, client_factory):
"""Contact fields can be updated after user creation.""" """Contact fields can be updated after user creation."""
email = unique_email("test") email = unique_email("test")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = User( user = User(
email=email, email=email,
@ -92,21 +92,21 @@ class TestUserContactFields:
db.add(user) db.add(user)
await db.commit() await db.commit()
user_id = user.id user_id = user.id
# Update fields # Update fields
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one() user = result.scalar_one()
user.contact_email = "new@example.com" user.contact_email = "new@example.com"
user.telegram = "@updated" user.telegram = "@updated"
await db.commit() await db.commit()
# Verify update persisted # Verify update persisted
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one() user = result.scalar_one()
assert user.contact_email == "new@example.com" assert user.contact_email == "new@example.com"
assert user.telegram == "@updated" assert user.telegram == "@updated"
assert user.signal is None # Still None assert user.signal is None # Still None
@ -115,7 +115,7 @@ class TestUserContactFields:
async def test_contact_fields_can_be_cleared(self, client_factory): async def test_contact_fields_can_be_cleared(self, client_factory):
"""Contact fields can be set back to None.""" """Contact fields can be set back to None."""
email = unique_email("test") email = unique_email("test")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
user = User( user = User(
email=email, email=email,
@ -126,21 +126,21 @@ class TestUserContactFields:
db.add(user) db.add(user)
await db.commit() await db.commit()
user_id = user.id user_id = user.id
# Clear fields # Clear fields
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one() user = result.scalar_one()
user.contact_email = None user.contact_email = None
user.telegram = None user.telegram = None
await db.commit() await db.commit()
# Verify cleared # Verify cleared
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one() user = result.scalar_one()
assert user.contact_email is None assert user.contact_email is None
assert user.telegram is None assert user.telegram is None
@ -152,7 +152,7 @@ class TestGetProfileEndpoint:
"""Regular user can fetch their profile.""" """Regular user can fetch their profile."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "contact_email" in data assert "contact_email" in data
@ -169,7 +169,7 @@ class TestGetProfileEndpoint:
"""Admin user gets 403 when trying to access profile.""" """Admin user gets 403 when trying to access profile."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 403 assert response.status_code == 403
assert "regular users" in response.json()["detail"].lower() assert "regular users" in response.json()["detail"].lower()
@ -177,7 +177,7 @@ class TestGetProfileEndpoint:
"""Unauthenticated user gets 401.""" """Unauthenticated user gets 401."""
async with client_factory.create() as client: async with client_factory.create() as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 401 assert response.status_code == 401
async def test_profile_returns_existing_data(self, client_factory, regular_user): async def test_profile_returns_existing_data(self, client_factory, regular_user):
@ -191,11 +191,11 @@ class TestGetProfileEndpoint:
user.contact_email = "contact@test.com" user.contact_email = "contact@test.com"
user.telegram = "@testuser" user.telegram = "@testuser"
await db.commit() await db.commit()
# Fetch via API # Fetch via API
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["contact_email"] == "contact@test.com" assert data["contact_email"] == "contact@test.com"
@ -219,7 +219,7 @@ class TestUpdateProfileEndpoint:
"nostr_npub": VALID_NPUB, "nostr_npub": VALID_NPUB,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["contact_email"] == "new@example.com" assert data["contact_email"] == "new@example.com"
@ -234,10 +234,10 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"telegram": "@persisted"}, json={"telegram": "@persisted"},
) )
# Fetch again to verify # Fetch again to verify
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["telegram"] == "@persisted" assert response.json()["telegram"] == "@persisted"
@ -248,7 +248,7 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"telegram": "@admin"}, json={"telegram": "@admin"},
) )
assert response.status_code == 403 assert response.status_code == 403
async def test_unauthenticated_user_gets_401(self, client_factory): async def test_unauthenticated_user_gets_401(self, client_factory):
@ -258,7 +258,7 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"telegram": "@test"}, json={"telegram": "@test"},
) )
assert response.status_code == 401 assert response.status_code == 401
async def test_can_clear_fields(self, client_factory, regular_user): async def test_can_clear_fields(self, client_factory, regular_user):
@ -272,7 +272,7 @@ class TestUpdateProfileEndpoint:
"telegram": "@test", "telegram": "@test",
}, },
) )
# Then clear them # Then clear them
response = await client.put( response = await client.put(
"/api/profile", "/api/profile",
@ -283,7 +283,7 @@ class TestUpdateProfileEndpoint:
"nostr_npub": None, "nostr_npub": None,
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["contact_email"] is None assert data["contact_email"] is None
@ -296,7 +296,7 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"contact_email": "not-an-email"}, json={"contact_email": "not-an-email"},
) )
assert response.status_code == 422 assert response.status_code == 422
data = response.json() data = response.json()
assert "field_errors" in data["detail"] assert "field_errors" in data["detail"]
@ -309,7 +309,7 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"telegram": "missing_at_sign"}, json={"telegram": "missing_at_sign"},
) )
assert response.status_code == 422 assert response.status_code == 422
data = response.json() data = response.json()
assert "field_errors" in data["detail"] assert "field_errors" in data["detail"]
@ -322,13 +322,15 @@ class TestUpdateProfileEndpoint:
"/api/profile", "/api/profile",
json={"nostr_npub": "npub1invalid"}, json={"nostr_npub": "npub1invalid"},
) )
assert response.status_code == 422 assert response.status_code == 422
data = response.json() data = response.json()
assert "field_errors" in data["detail"] assert "field_errors" in data["detail"]
assert "nostr_npub" in data["detail"]["field_errors"] assert "nostr_npub" in data["detail"]["field_errors"]
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user): async def test_multiple_invalid_fields_returns_all_errors(
self, client_factory, regular_user
):
"""Multiple invalid fields return all errors.""" """Multiple invalid fields return all errors."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put( response = await client.put(
@ -338,13 +340,15 @@ class TestUpdateProfileEndpoint:
"telegram": "no-at", "telegram": "no-at",
}, },
) )
assert response.status_code == 422 assert response.status_code == 422
data = response.json() data = response.json()
assert "contact_email" in data["detail"]["field_errors"] assert "contact_email" in data["detail"]["field_errors"]
assert "telegram" in data["detail"]["field_errors"] assert "telegram" in data["detail"]["field_errors"]
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user): async def test_partial_update_preserves_other_fields(
self, client_factory, regular_user
):
"""Updating one field doesn't affect others (they get set to the request values).""" """Updating one field doesn't affect others (they get set to the request values)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
# Set initial values # Set initial values
@ -355,7 +359,7 @@ class TestUpdateProfileEndpoint:
"telegram": "@initial", "telegram": "@initial",
}, },
) )
# Update only telegram, but note: PUT replaces all fields # Update only telegram, but note: PUT replaces all fields
# So we need to include all fields we want to keep # So we need to include all fields we want to keep
response = await client.put( response = await client.put(
@ -365,7 +369,7 @@ class TestUpdateProfileEndpoint:
"telegram": "@updated", "telegram": "@updated",
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["contact_email"] == "initial@example.com" assert data["contact_email"] == "initial@example.com"
@ -386,10 +390,10 @@ class TestProfilePrivacy:
"telegram": "@secret", "telegram": "@secret",
}, },
) )
# Check /api/auth/me doesn't expose it # Check /api/auth/me doesn't expose it
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# These fields should NOT be in the response # These fields should NOT be in the response
@ -402,12 +406,15 @@ class TestProfilePrivacy:
class TestProfileGodfather: class TestProfileGodfather:
"""Tests for godfather information in profile.""" """Tests for godfather information in profile."""
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user): async def test_profile_shows_godfather_email(
self, client_factory, admin_user, regular_user
):
"""Profile shows godfather email for users who signed up with invite.""" """Profile shows godfather email for users who signed up with invite."""
from tests.helpers import unique_email
from sqlalchemy import select from sqlalchemy import select
from models import User from models import User
from tests.helpers import unique_email
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
@ -415,13 +422,13 @@ class TestProfileGodfather:
select(User).where(User.email == regular_user["email"]) select(User).where(User.email == regular_user["email"])
) )
godfather = result.scalar_one() godfather = result.scalar_one()
create_resp = await client.post( create_resp = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
) )
identifier = create_resp.json()["identifier"] identifier = create_resp.json()["identifier"]
# Register new user with invite # Register new user with invite
new_email = unique_email("godchild") new_email = unique_email("godchild")
async with client_factory.create() as client: async with client_factory.create() as client:
@ -434,20 +441,22 @@ class TestProfileGodfather:
}, },
) )
new_user_cookies = dict(reg_resp.cookies) new_user_cookies = dict(reg_resp.cookies)
# Check profile shows godfather # Check profile shows godfather
async with client_factory.create(cookies=new_user_cookies) as client: async with client_factory.create(cookies=new_user_cookies) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["godfather_email"] == regular_user["email"] assert data["godfather_email"] == regular_user["email"]
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user): async def test_profile_godfather_null_for_seeded_users(
self, client_factory, regular_user
):
"""Profile shows null godfather for users without one (e.g., seeded users).""" """Profile shows null godfather for users without one (e.g., seeded users)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["godfather_email"] is None assert data["godfather_email"] is None

View file

@ -1,12 +1,11 @@
"""Tests for profile field validation.""" """Tests for profile field validation."""
import pytest
from validation import ( from validation import (
validate_contact_email, validate_contact_email,
validate_telegram,
validate_signal,
validate_nostr_npub, validate_nostr_npub,
validate_profile_fields, validate_profile_fields,
validate_signal,
validate_telegram,
) )
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
assert validate_nostr_npub(self.VALID_NPUB) is None assert validate_nostr_npub(self.VALID_NPUB) is None
def test_wrong_prefix(self): def test_wrong_prefix(self):
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz") result = validate_nostr_npub(
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
)
assert result is not None assert result is not None
assert "npub" in result.lower() assert "npub" in result.lower()
def test_invalid_checksum(self): def test_invalid_checksum(self):
# Change last character to break checksum # Change last character to break checksum
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd") result = validate_nostr_npub(
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
)
assert result is not None assert result is not None
assert "checksum" in result.lower() assert "checksum" in result.lower()
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
assert result is not None assert result is not None
def test_not_starting_with_npub1(self): def test_not_starting_with_npub1(self):
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc") result = validate_nostr_npub(
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
)
assert result is not None assert result is not None
assert "npub1" in result assert "npub1" in result
@ -206,4 +211,3 @@ class TestValidateProfileFields:
nostr_npub="", nostr_npub="",
) )
assert errors == {} assert errors == {}

View file

@ -1,8 +1,9 @@
"""Validate shared constants match backend definitions.""" """Validate shared constants match backend definitions."""
import json import json
from pathlib import Path from pathlib import Path
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
def validate_shared_constants() -> None: def validate_shared_constants() -> None:
@ -11,13 +12,13 @@ def validate_shared_constants() -> None:
Raises ValueError if there's a mismatch. Raises ValueError if there's a mismatch.
""" """
constants_path = Path(__file__).parent.parent / "shared" / "constants.json" constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
if not constants_path.exists(): if not constants_path.exists():
raise ValueError(f"Shared constants file not found: {constants_path}") raise ValueError(f"Shared constants file not found: {constants_path}")
with open(constants_path) as f: with open(constants_path) as f:
constants = json.load(f) constants = json.load(f)
# Validate roles # Validate roles
expected_roles = {"ADMIN": ROLE_ADMIN, "REGULAR": ROLE_REGULAR} expected_roles = {"ADMIN": ROLE_ADMIN, "REGULAR": ROLE_REGULAR}
if constants.get("roles") != expected_roles: if constants.get("roles") != expected_roles:
@ -25,39 +26,46 @@ def validate_shared_constants() -> None:
f"Role mismatch in shared/constants.json. " f"Role mismatch in shared/constants.json. "
f"Expected: {expected_roles}, Got: {constants.get('roles')}" f"Expected: {expected_roles}, Got: {constants.get('roles')}"
) )
# Validate invite statuses # Validate invite statuses
expected_invite_statuses = {s.name: s.value for s in InviteStatus} expected_invite_statuses = {s.name: s.value for s in InviteStatus}
if constants.get("inviteStatuses") != expected_invite_statuses: if constants.get("inviteStatuses") != expected_invite_statuses:
got = constants.get("inviteStatuses")
raise ValueError( raise ValueError(
f"Invite status mismatch in shared/constants.json. " f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
) )
# Validate appointment statuses # Validate appointment statuses
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus} expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
if constants.get("appointmentStatuses") != expected_appointment_statuses: if constants.get("appointmentStatuses") != expected_appointment_statuses:
got = constants.get("appointmentStatuses")
raise ValueError( raise ValueError(
f"Appointment status mismatch in shared/constants.json. " f"Appointment status mismatch. "
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}" f"Expected: {expected_appointment_statuses}, Got: {got}"
) )
# Validate booking constants exist with required fields # Validate booking constants exist with required fields
booking = constants.get("booking", {}) booking = constants.get("booking", {})
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"] required_booking_fields = [
"slotDurationMinutes",
"maxAdvanceDays",
"minAdvanceDays",
"noteMaxLength",
]
for field in required_booking_fields: for field in required_booking_fields:
if field not in booking: if field not in booking:
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json") raise ValueError(f"Missing booking constant '{field}' in constants.json")
# Validate validation rules exist (structure check only) # Validate validation rules exist (structure check only)
validation = constants.get("validation", {}) validation = constants.get("validation", {})
required_fields = ["telegram", "signal", "nostrNpub"] required_fields = ["telegram", "signal", "nostrNpub"]
for field in required_fields: for field in required_fields:
if field not in validation: if field not in validation:
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json") raise ValueError(
f"Missing validation rules for '{field}' in constants.json"
)
if __name__ == "__main__": if __name__ == "__main__":
validate_shared_constants() validate_shared_constants()
print("✓ Shared constants are valid") print("✓ Shared constants are valid")

View file

@ -1,9 +1,10 @@
"""Validation utilities for user profile fields.""" """Validation utilities for user profile fields."""
import json import json
from pathlib import Path from pathlib import Path
from email_validator import validate_email, EmailNotValidError
from bech32 import bech32_decode from bech32 import bech32_decode
from email_validator import EmailNotValidError, validate_email
# Load validation rules from shared constants # Load validation rules from shared constants
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json" _constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
@ -18,13 +19,13 @@ NPUB_RULES = _constants["validation"]["nostrNpub"]
def validate_contact_email(value: str | None) -> str | None: def validate_contact_email(value: str | None) -> str | None:
""" """
Validate contact email format. Validate contact email format.
Returns None if valid, error message if invalid. Returns None if valid, error message if invalid.
Empty/None values are valid (field is optional). Empty/None values are valid (field is optional).
""" """
if not value: if not value:
return None return None
try: try:
validate_email(value, check_deliverability=False) validate_email(value, check_deliverability=False)
return None return None
@ -35,84 +36,84 @@ def validate_contact_email(value: str | None) -> str | None:
def validate_telegram(value: str | None) -> str | None: def validate_telegram(value: str | None) -> str | None:
""" """
Validate Telegram handle. Validate Telegram handle.
Must start with @ if provided, with characters after @ within max length. Must start with @ if provided, with characters after @ within max length.
Returns None if valid, error message if invalid. Returns None if valid, error message if invalid.
Empty/None values are valid (field is optional). Empty/None values are valid (field is optional).
""" """
if not value: if not value:
return None return None
prefix = TELEGRAM_RULES["mustStartWith"] prefix = TELEGRAM_RULES["mustStartWith"]
max_len = TELEGRAM_RULES["maxLengthAfterAt"] max_len = TELEGRAM_RULES["maxLengthAfterAt"]
if not value.startswith(prefix): if not value.startswith(prefix):
return f"Telegram handle must start with {prefix}" return f"Telegram handle must start with {prefix}"
handle = value[1:] handle = value[1:]
if not handle: if not handle:
return f"Telegram handle must have at least one character after {prefix}" return f"Telegram handle must have at least one character after {prefix}"
if len(handle) > max_len: if len(handle) > max_len:
return f"Telegram handle must be at most {max_len} characters (after {prefix})" return f"Telegram handle must be at most {max_len} characters (after {prefix})"
return None return None
def validate_signal(value: str | None) -> str | None: def validate_signal(value: str | None) -> str | None:
""" """
Validate Signal username. Validate Signal username.
Any non-empty string within max length is valid. Any non-empty string within max length is valid.
Returns None if valid, error message if invalid. Returns None if valid, error message if invalid.
Empty/None values are valid (field is optional). Empty/None values are valid (field is optional).
""" """
if not value: if not value:
return None return None
max_len = SIGNAL_RULES["maxLength"] max_len = SIGNAL_RULES["maxLength"]
# Signal usernames are fairly permissive, just check it's not empty # Signal usernames are fairly permissive, just check it's not empty
if len(value.strip()) == 0: if len(value.strip()) == 0:
return "Signal username cannot be empty" return "Signal username cannot be empty"
if len(value) > max_len: if len(value) > max_len:
return f"Signal username must be at most {max_len} characters" return f"Signal username must be at most {max_len} characters"
return None return None
def validate_nostr_npub(value: str | None) -> str | None: def validate_nostr_npub(value: str | None) -> str | None:
""" """
Validate Nostr npub (public key in bech32 format). Validate Nostr npub (public key in bech32 format).
Must be valid bech32 with 'npub' prefix. Must be valid bech32 with 'npub' prefix.
Returns None if valid, error message if invalid. Returns None if valid, error message if invalid.
Empty/None values are valid (field is optional). Empty/None values are valid (field is optional).
""" """
if not value: if not value:
return None return None
prefix = NPUB_RULES["prefix"] prefix = NPUB_RULES["prefix"]
expected_words = NPUB_RULES["bech32Words"] expected_words = NPUB_RULES["bech32Words"]
if not value.startswith(prefix): if not value.startswith(prefix):
return f"Nostr npub must start with '{prefix}'" return f"Nostr npub must start with '{prefix}'"
# Decode bech32 to validate checksum # Decode bech32 to validate checksum
hrp, data = bech32_decode(value) hrp, data = bech32_decode(value)
if hrp is None or data is None: if hrp is None or data is None:
return "Invalid Nostr npub: bech32 checksum failed" return "Invalid Nostr npub: bech32 checksum failed"
if hrp != "npub": if hrp != "npub":
return "Nostr npub must have 'npub' prefix" return "Nostr npub must have 'npub' prefix"
# npub should decode to 32 bytes (256 bits) for a public key # npub should decode to 32 bytes (256 bits) for a public key
# In bech32, each character encodes 5 bits, so 32 bytes = 52 characters of data # In bech32, each character encodes 5 bits, so 32 bytes = 52 characters of data
if len(data) != expected_words: if len(data) != expected_words:
return "Invalid Nostr npub: incorrect length" return "Invalid Nostr npub: incorrect length"
return None return None
@ -124,23 +125,22 @@ def validate_profile_fields(
) -> dict[str, str]: ) -> dict[str, str]:
""" """
Validate all profile fields at once. Validate all profile fields at once.
Returns a dict of field_name -> error_message for any invalid fields. Returns a dict of field_name -> error_message for any invalid fields.
Empty dict means all fields are valid. Empty dict means all fields are valid.
""" """
errors: dict[str, str] = {} errors: dict[str, str] = {}
if err := validate_contact_email(contact_email): if err := validate_contact_email(contact_email):
errors["contact_email"] = err errors["contact_email"] = err
if err := validate_telegram(telegram): if err := validate_telegram(telegram):
errors["telegram"] = err errors["telegram"] = err
if err := validate_signal(signal): if err := validate_signal(signal):
errors["signal"] = err errors["signal"] = err
if err := validate_nostr_npub(nostr_npub): if err := validate_nostr_npub(nostr_npub):
errors["nostr_npub"] = err errors["nostr_npub"] = err
return errors
return errors