diff --git a/Makefile b/Makefile index 178f042..b6fb1cb 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants +.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend -include .env export @@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone check-constants: @cd backend && uv run python validate_constants.py + +lint-backend: + cd backend && uv run ruff check . + +format-backend: + cd backend && uv run ruff format . + +fix-backend: + cd backend && uv run ruff check --fix . && uv run ruff format . diff --git a/backend/auth.py b/backend/auth.py index 22fbaf8..d338b2c 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -1,5 +1,5 @@ import os -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import bcrypt from fastapi import Depends, HTTPException, Request, status @@ -8,7 +8,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_db -from models import User, Permission +from models import Permission, User from schemas import UserResponse SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example @@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str: ).decode("utf-8") -def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str: +def create_access_token( + data: dict[str, str], + expires_delta: timedelta | None = None, +) -> str: to_encode: dict[str, str | datetime] = dict(data) - expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) + delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(UTC) + delta to_encode["exp"] = expire encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return encoded @@ -60,11 +64,11 @@ async def get_current_user( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", ) - + token = request.cookies.get(COOKIE_NAME) if not token: raise credentials_exception - + try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) user_id_str = payload.get("sub") @@ -72,7 +76,7 @@ async def get_current_user( raise credentials_exception user_id = int(user_id_str) except (JWTError, ValueError): - raise credentials_exception + raise credentials_exception from None result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() @@ -83,27 +87,32 @@ async def get_current_user( def require_permission(*required_permissions: Permission): """ - Dependency factory that checks if user has ALL of the required permissions. - + Dependency factory that checks if user has ALL required permissions. + Usage: @app.get("/api/counter") - async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))): + async def get_counter( + user: User = Depends(require_permission(Permission.VIEW_COUNTER)) + ): ... """ + async def permission_checker( request: Request, db: AsyncSession = Depends(get_db), ) -> User: user = await get_current_user(request, db) user_permissions = await user.get_permissions(db) - + missing = [p for p in required_permissions if p not in user_permissions] if missing: + missing_str = ", ".join(p.value for p in missing) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Missing required permissions: {', '.join(p.value for p in missing)}", + detail=f"Missing required permissions: {missing_str}", ) return user + return permission_checker diff --git a/backend/database.py b/backend/database.py index 4e28df8..160863c 100644 --- a/backend/database.py +++ b/backend/database.py @@ -1,8 +1,11 @@ import os -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker + +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.orm import DeclarativeBase -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret") +DATABASE_URL = os.getenv( + "DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret" +) engine = create_async_engine(DATABASE_URL) async_session = async_sessionmaker(engine, expire_on_commit=False) @@ -15,4 +18,3 @@ class Base(DeclarativeBase): async def get_db(): async with async_session() as session: yield session - diff --git a/backend/invite_utils.py b/backend/invite_utils.py index af19d7a..dceb826 100644 --- a/backend/invite_utils.py +++ b/backend/invite_utils.py @@ -1,4 +1,5 @@ """Utilities for invite code generation and validation.""" + import random from pathlib import Path @@ -13,11 +14,11 @@ assert len(BIP39_WORDS) == 2048, f"Expected 2048 BIP39 words, got {len(BIP39_WOR def generate_invite_identifier() -> str: """ Generate a unique invite identifier. - + Format: word1-word2-NN where: - word1, word2 are random BIP39 words - NN is a two-digit number (00-99) - + Returns lowercase identifier. """ word1 = random.choice(BIP39_WORDS) @@ -29,7 +30,7 @@ def generate_invite_identifier() -> str: def normalize_identifier(identifier: str) -> str: """ Normalize an invite identifier for comparison/lookup. - + - Converts to lowercase - Strips whitespace """ @@ -39,22 +40,18 @@ def normalize_identifier(identifier: str) -> str: def is_valid_identifier_format(identifier: str) -> bool: """ Check if an identifier has valid format (word-word-NN). - + Does NOT check if words are valid BIP39 words. """ parts = identifier.split("-") if len(parts) != 3: return False - + word1, word2, number = parts - + # Check words are non-empty if not word1 or not word2: return False - - # Check number is two digits - if len(number) != 2 or not number.isdigit(): - return False - - return True + # Check number is two digits + return len(number) == 2 and number.isdigit() diff --git a/backend/main.py b/backend/main.py index 245956c..989f32e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,19 +1,20 @@ """FastAPI application entry point.""" + from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from database import engine, Base -from routes import sum as sum_routes -from routes import counter as counter_routes +from database import Base, engine from routes import audit as audit_routes -from routes import profile as profile_routes -from routes import invites as invites_routes from routes import auth as auth_routes -from routes import meta as meta_routes from routes import availability as availability_routes from routes import booking as booking_routes +from routes import counter as counter_routes +from routes import invites as invites_routes +from routes import meta as meta_routes +from routes import profile as profile_routes +from routes import sum as sum_routes from validate_constants import validate_shared_constants @@ -22,7 +23,7 @@ async def lifespan(app: FastAPI): """Create database tables on startup and validate constants.""" # Validate shared constants match backend definitions validate_shared_constants() - + async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield diff --git a/backend/models.py b/backend/models.py index 87bfe8e..a00f22f 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,9 +1,24 @@ -from datetime import datetime, date, time, timezone +from datetime import UTC, date, datetime, time from enum import Enum as PyEnum from typing import TypedDict -from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select -from sqlalchemy.orm import Mapped, mapped_column, relationship + +from sqlalchemy import ( + Column, + Date, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + String, + Table, + Time, + UniqueConstraint, + select, +) from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Mapped, mapped_column, relationship + from database import Base @@ -14,25 +29,26 @@ class RoleConfig(TypedDict): class Permission(str, PyEnum): """All available permissions in the system.""" + # Counter permissions VIEW_COUNTER = "view_counter" INCREMENT_COUNTER = "increment_counter" - + # Sum permissions USE_SUM = "use_sum" - + # Audit permissions VIEW_AUDIT = "view_audit" - + # Invite permissions MANAGE_INVITES = "manage_invites" VIEW_OWN_INVITES = "view_own_invites" - + # Booking permissions (regular users) BOOK_APPOINTMENT = "book_appointment" VIEW_OWN_APPOINTMENTS = "view_own_appointments" CANCEL_OWN_APPOINTMENT = "cancel_own_appointment" - + # Availability/Appointments permissions (admin) MANAGE_AVAILABILITY = "manage_availability" VIEW_ALL_APPOINTMENTS = "view_all_appointments" @@ -41,6 +57,7 @@ class Permission(str, PyEnum): class InviteStatus(str, PyEnum): """Status of an invite.""" + READY = "ready" SPENT = "spent" REVOKED = "revoked" @@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum): class AppointmentStatus(str, PyEnum): """Status of an appointment.""" + BOOKED = "booked" CANCELLED_BY_USER = "cancelled_by_user" CANCELLED_BY_ADMIN = "cancelled_by_admin" @@ -60,7 +78,7 @@ ROLE_REGULAR = "regular" # Role definitions with their permissions ROLE_DEFINITIONS: dict[str, RoleConfig] = { ROLE_ADMIN: { - "description": "Administrator with audit, invite, and appointment management access", + "description": "Administrator with audit/invite/appointment access", "permissions": [ Permission.VIEW_AUDIT, Permission.MANAGE_INVITES, @@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = { role_permissions = Table( "role_permissions", Base.metadata, - Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), + Column( + "role_id", + Integer, + ForeignKey("roles.id", ondelete="CASCADE"), + primary_key=True, + ), Column("permission", Enum(Permission), primary_key=True), ) @@ -97,8 +120,18 @@ role_permissions = Table( user_roles = Table( "user_roles", Base.metadata, - Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True), - Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), + Column( + "user_id", + Integer, + ForeignKey("users.id", ondelete="CASCADE"), + primary_key=True, + ), + Column( + "role_id", + Integer, + ForeignKey("roles.id", ondelete="CASCADE"), + primary_key=True, + ), ) @@ -108,7 +141,7 @@ class Role(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=True) - + # Relationship to users users: Mapped[list["User"]] = relationship( "User", @@ -118,31 +151,42 @@ class Role(Base): async def get_permissions(self, db: AsyncSession) -> set[Permission]: """Get all permissions for this role.""" - result = await db.execute( - select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id) + query = select(role_permissions.c.permission).where( + role_permissions.c.role_id == self.id ) + result = await db.execute(query) return {row[0] for row in result.fetchall()} - async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None: + async def set_permissions( + self, db: AsyncSession, permissions: list[Permission] + ) -> None: """Set all permissions for this role (replaces existing).""" - await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id)) + delete_query = role_permissions.delete().where( + role_permissions.c.role_id == self.id + ) + await db.execute(delete_query) for perm in permissions: - await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm)) + insert_query = role_permissions.insert().values( + role_id=self.id, permission=perm + ) + await db.execute(insert_query) class User(Base): __tablename__ = "users" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) + email: Mapped[str] = mapped_column( + String(255), unique=True, nullable=False, index=True + ) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) - + # Contact details (all optional) contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True) telegram: Mapped[str | None] = mapped_column(String(64), nullable=True) signal: Mapped[str | None] = mapped_column(String(64), nullable=True) nostr_npub: Mapped[str | None] = mapped_column(String(63), nullable=True) - + # Godfather (who invited this user) - null for seeded/admin users godfather_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("users.id"), nullable=True @@ -152,7 +196,7 @@ class User(Base): remote_side="User.id", foreign_keys=[godfather_id], ) - + # Relationship to roles roles: Mapped[list[Role]] = relationship( "Role", @@ -192,12 +236,14 @@ class SumRecord(Base): __tablename__ = "sum_records" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) + user_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.id"), nullable=False, index=True + ) a: Mapped[float] = mapped_column(Float, nullable=False) b: Mapped[float] = mapped_column(Float, nullable=False) result: Mapped[float] = mapped_column(Float, nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), default=lambda: datetime.now(UTC) ) @@ -205,11 +251,13 @@ class CounterRecord(Base): __tablename__ = "counter_records" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) + user_id: Mapped[int] = mapped_column( + Integer, ForeignKey("users.id"), nullable=False, index=True + ) value_before: Mapped[int] = mapped_column(Integer, nullable=False) value_after: Mapped[int] = mapped_column(Integer, nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), default=lambda: datetime.now(UTC) ) @@ -217,11 +265,13 @@ class Invite(Base): __tablename__ = "invites" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) + identifier: Mapped[str] = mapped_column( + String(64), unique=True, nullable=False, index=True + ) status: Mapped[InviteStatus] = mapped_column( Enum(InviteStatus), nullable=False, default=InviteStatus.READY ) - + # Godfather - the user who owns this invite godfather_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.id"), nullable=False, index=True @@ -231,7 +281,7 @@ class Invite(Base): foreign_keys=[godfather_id], lazy="joined", ) - + # User who used this invite (null until spent) used_by_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("users.id"), nullable=True @@ -241,17 +291,22 @@ class Invite(Base): foreign_keys=[used_by_id], lazy="joined", ) - + # Timestamps created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), default=lambda: datetime.now(UTC) + ) + spent_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + revoked_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True ) - spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) - revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) class Availability(Base): """Admin availability slots for booking.""" + __tablename__ = "availability" __table_args__ = ( UniqueConstraint("date", "start_time", name="uq_availability_date_start"), @@ -262,34 +317,37 @@ class Availability(Base): start_time: Mapped[time] = mapped_column(Time, nullable=False) end_time: Mapped[time] = mapped_column(Time, nullable=False) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), default=lambda: datetime.now(UTC) ) updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - default=lambda: datetime.now(timezone.utc), - onupdate=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), ) class Appointment(Base): """User appointment bookings.""" + __tablename__ = "appointments" - __table_args__ = ( - UniqueConstraint("slot_start", name="uq_appointment_slot_start"), - ) + __table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) user_id: Mapped[int] = mapped_column( Integer, ForeignKey("users.id"), nullable=False, index=True ) user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined") - slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) + slot_start: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, index=True + ) slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) note: Mapped[str | None] = mapped_column(String(144), nullable=True) status: Mapped[AppointmentStatus] = mapped_column( Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED ) created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) + DateTime(timezone=True), default=lambda: datetime.now(UTC) + ) + cancelled_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True ) - cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 21359ac..2427cdc 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,6 +20,7 @@ dev = [ "httpx>=0.28.1", "aiosqlite>=0.20.0", "mypy>=1.13.0", + "ruff>=0.14.10", ] [tool.mypy] @@ -30,3 +31,27 @@ check_untyped_defs = true ignore_missing_imports = true exclude = ["tests/"] +[tool.ruff] +line-length = 88 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort + "B", # flake8-bugbear + "UP", # pyupgrade + "SIM", # flake8-simplify + "RUF", # ruff-specific rules +] +ignore = [ + "B008", # function-call-in-default-argument (standard FastAPI pattern with Depends) +] + +[tool.ruff.format] +quote-style = "double" + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["E501"] # Allow longer lines in tests for readability + diff --git a/backend/routes/audit.py b/backend/routes/audit.py index 8144497..baee8e2 100644 --- a/backend/routes/audit.py +++ b/backend/routes/audit.py @@ -1,22 +1,23 @@ """Audit routes for viewing action records.""" -from typing import Callable, TypeVar + +from collections.abc import Callable +from typing import TypeVar from fastapi import APIRouter, Depends, Query from pydantic import BaseModel -from sqlalchemy import select, func, desc +from sqlalchemy import desc, func, select from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db -from models import User, SumRecord, CounterRecord, Permission +from models import CounterRecord, Permission, SumRecord, User from schemas import ( CounterRecordResponse, - SumRecordResponse, PaginatedCounterRecords, PaginatedSumRecords, + SumRecordResponse, ) - router = APIRouter(prefix="/api/audit", tags=["audit"]) R = TypeVar("R", bound=BaseModel) diff --git a/backend/routes/auth.py b/backend/routes/auth.py index be8c138..604b50d 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -1,5 +1,6 @@ """Authentication routes for register, login, logout, and current user.""" -from datetime import datetime, timezone + +from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException, Response, status from sqlalchemy import select @@ -9,18 +10,17 @@ from auth import ( ACCESS_TOKEN_EXPIRE_MINUTES, COOKIE_NAME, COOKIE_SECURE, - get_password_hash, - get_user_by_email, authenticate_user, + build_user_response, create_access_token, get_current_user, - build_user_response, + get_password_hash, + get_user_by_email, ) from database import get_db from invite_utils import normalize_identifier -from models import User, Role, ROLE_REGULAR, Invite, InviteStatus -from schemas import UserLogin, UserResponse, RegisterWithInvite - +from models import ROLE_REGULAR, Invite, InviteStatus, Role, User +from schemas import RegisterWithInvite, UserLogin, UserResponse router = APIRouter(prefix="/api/auth", tags=["auth"]) @@ -52,9 +52,8 @@ async def register( """Register a new user using an invite code.""" # Validate invite normalized_identifier = normalize_identifier(user_data.invite_identifier) - result = await db.execute( - select(Invite).where(Invite.identifier == normalized_identifier) - ) + query = select(Invite).where(Invite.identifier == normalized_identifier) + result = await db.execute(query) invite = result.scalar_one_or_none() # Return same error for not found, spent, and revoked to avoid information leakage @@ -90,7 +89,7 @@ async def register( # Mark invite as spent invite.status = InviteStatus.SPENT invite.used_by_id = user.id - invite.spent_at = datetime.now(timezone.utc) + invite.spent_at = datetime.now(UTC) await db.commit() await db.refresh(user) diff --git a/backend/routes/availability.py b/backend/routes/availability.py index 570c097..a55ae46 100644 --- a/backend/routes/availability.py +++ b/backend/routes/availability.py @@ -1,28 +1,28 @@ """Availability routes for admin to manage booking availability.""" + from datetime import date, timedelta from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import select, delete, and_ +from sqlalchemy import and_, delete, select from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db -from models import User, Availability, Permission +from models import Availability, Permission, User from schemas import ( - TimeSlot, AvailabilityDay, AvailabilityResponse, - SetAvailabilityRequest, CopyAvailabilityRequest, + SetAvailabilityRequest, + TimeSlot, ) -from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS - +from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS router = APIRouter(prefix="/api/admin/availability", tags=["availability"]) def _get_date_range_bounds() -> tuple[date, date]: - """Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS).""" + """Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS).""" today = date.today() min_date = today + timedelta(days=MIN_ADVANCE_DAYS) max_date = today + timedelta(days=MAX_ADVANCE_DAYS) @@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None: if d < min_date: raise HTTPException( status_code=400, - detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}", + detail=f"Cannot set availability for past dates. " + f"Earliest allowed: {min_date}", ) if d > max_date: raise HTTPException( status_code=400, - detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}", + detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. " + f"Latest allowed: {max_date}", ) @@ -56,7 +58,7 @@ async def get_availability( status_code=400, detail="'from' date must be before or equal to 'to' date", ) - + # Query availability in range result = await db.execute( select(Availability) @@ -64,23 +66,24 @@ async def get_availability( .order_by(Availability.date, Availability.start_time) ) slots = result.scalars().all() - + # Group by date days_dict: dict[date, list[TimeSlot]] = {} for slot in slots: if slot.date not in days_dict: days_dict[slot.date] = [] - days_dict[slot.date].append(TimeSlot( - start_time=slot.start_time, - end_time=slot.end_time, - )) - + days_dict[slot.date].append( + TimeSlot( + start_time=slot.start_time, + end_time=slot.end_time, + ) + ) + # Convert to response format days = [ - AvailabilityDay(date=d, slots=days_dict[d]) - for d in sorted(days_dict.keys()) + AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys()) ] - + return AvailabilityResponse(days=days) @@ -93,29 +96,31 @@ async def set_availability( """Set availability for a specific date. Replaces any existing availability.""" min_date, max_date = _get_date_range_bounds() _validate_date_in_range(request.date, min_date, max_date) - + # Validate slots don't overlap sorted_slots = sorted(request.slots, key=lambda s: s.start_time) for i in range(len(sorted_slots) - 1): if sorted_slots[i].end_time > sorted_slots[i + 1].start_time: + end = sorted_slots[i].end_time + start = sorted_slots[i + 1].start_time raise HTTPException( status_code=400, - detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.", + detail=f"Time slots overlap: slot ending at {end} " + f"overlaps with slot starting at {start}", ) - + # Validate each slot's end_time > start_time for slot in request.slots: if slot.end_time <= slot.start_time: raise HTTPException( status_code=400, - detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.", + detail=f"Invalid time slot: end time {slot.end_time} " + f"must be after start time {slot.start_time}", ) - + # Delete existing availability for this date - await db.execute( - delete(Availability).where(Availability.date == request.date) - ) - + await db.execute(delete(Availability).where(Availability.date == request.date)) + # Create new availability slots for slot in request.slots: availability = Availability( @@ -124,9 +129,9 @@ async def set_availability( end_time=slot.end_time, ) db.add(availability) - + await db.commit() - + return AvailabilityDay(date=request.date, slots=request.slots) @@ -138,14 +143,14 @@ async def copy_availability( ) -> AvailabilityResponse: """Copy availability from one day to multiple target days.""" min_date, max_date = _get_date_range_bounds() - - # Validate source date is in range (for consistency, though DB query would fail anyway) + + # Validate source date is in range _validate_date_in_range(request.source_date, min_date, max_date) - + # Validate target dates for target_date in request.target_dates: _validate_date_in_range(target_date, min_date, max_date) - + # Get source availability result = await db.execute( select(Availability) @@ -153,13 +158,13 @@ async def copy_availability( .order_by(Availability.start_time) ) source_slots = result.scalars().all() - + if not source_slots: raise HTTPException( status_code=400, detail=f"No availability found for source date {request.source_date}", ) - + # Copy to each target date within a single atomic transaction # All deletes and inserts happen before commit, ensuring atomicity copied_days: list[AvailabilityDay] = [] @@ -167,12 +172,11 @@ async def copy_availability( for target_date in request.target_dates: if target_date == request.source_date: continue # Skip copying to self - + # Delete existing availability for target date - await db.execute( - delete(Availability).where(Availability.date == target_date) - ) - + del_query = delete(Availability).where(Availability.date == target_date) + await db.execute(del_query) + # Copy slots target_slots: list[TimeSlot] = [] for source_slot in source_slots: @@ -182,19 +186,20 @@ async def copy_availability( end_time=source_slot.end_time, ) db.add(new_availability) - target_slots.append(TimeSlot( - start_time=source_slot.start_time, - end_time=source_slot.end_time, - )) - + target_slots.append( + TimeSlot( + start_time=source_slot.start_time, + end_time=source_slot.end_time, + ) + ) + copied_days.append(AvailabilityDay(date=target_date, slots=target_slots)) - + # Commit all changes atomically await db.commit() except Exception: # Rollback on any error to maintain atomicity await db.rollback() raise - - return AvailabilityResponse(days=copied_days) + return AvailabilityResponse(days=copied_days) diff --git a/backend/routes/booking.py b/backend/routes/booking.py index 4ca135f..6dddf96 100644 --- a/backend/routes/booking.py +++ b/backend/routes/booking.py @@ -1,24 +1,24 @@ """Booking routes for users to book appointments.""" -from datetime import date, datetime, time, timedelta, timezone + +from datetime import UTC, date, datetime, time, timedelta from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy import select, and_, func +from sqlalchemy import and_, func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from auth import require_permission from database import get_db -from models import User, Availability, Appointment, AppointmentStatus, Permission +from models import Appointment, AppointmentStatus, Availability, Permission, User from schemas import ( - BookableSlot, - AvailableSlotsResponse, - BookingRequest, AppointmentResponse, + AvailableSlotsResponse, + BookableSlot, + BookingRequest, PaginatedAppointments, ) -from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS - +from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES router = APIRouter(prefix="/api/booking", tags=["booking"]) @@ -28,7 +28,7 @@ def _to_appointment_response( user_email: str | None = None, ) -> AppointmentResponse: """Convert an Appointment model to AppointmentResponse schema. - + Args: appointment: The appointment model instance user_email: Optional user email. If not provided, uses appointment.user.email @@ -49,7 +49,7 @@ def _to_appointment_response( def _get_valid_minute_boundaries() -> tuple[int, ...]: """Get valid minute boundaries based on SLOT_DURATION_MINUTES. - + Assumes SLOT_DURATION_MINUTES divides 60 evenly (e.g., 15 minutes = 0, 15, 30, 45). """ boundaries: list[int] = [] @@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None: if d < min_date: raise HTTPException( status_code=400, - detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}", + detail=f"Cannot book for today or past dates. " + f"Earliest bookable: {min_date}", ) if d > max_date: raise HTTPException( status_code=400, - detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}", + detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. " + f"Latest bookable: {max_date}", ) @@ -89,18 +91,18 @@ def _expand_availability_to_slots( ) -> list[BookableSlot]: """Expand availability time ranges into 15-minute bookable slots.""" result: list[BookableSlot] = [] - + for avail in availability_slots: # Create datetime objects for start and end - current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc) - end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc) - + current = datetime.combine(target_date, avail.start_time, tzinfo=UTC) + end = datetime.combine(target_date, avail.end_time, tzinfo=UTC) + # Generate 15-minute slots while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end: slot_end = current + timedelta(minutes=SLOT_DURATION_MINUTES) result.append(BookableSlot(start_time=current, end_time=slot_end)) current = slot_end - + return result @@ -112,7 +114,7 @@ async def get_available_slots( ) -> AvailableSlotsResponse: """Get available booking slots for a specific date.""" _validate_booking_date(target_date) - + # Get availability for this date result = await db.execute( select(Availability) @@ -120,20 +122,19 @@ async def get_available_slots( .order_by(Availability.start_time) ) availability_slots = result.scalars().all() - + if not availability_slots: return AvailableSlotsResponse(date=target_date, slots=[]) - + # Expand to 15-minute slots all_slots = _expand_availability_to_slots(availability_slots, target_date) - + # Get existing booked appointments for this date - day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc) - day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc) - + day_start = datetime.combine(target_date, time.min, tzinfo=UTC) + day_end = datetime.combine(target_date, time.max, tzinfo=UTC) + result = await db.execute( - select(Appointment.slot_start) - .where( + select(Appointment.slot_start).where( and_( Appointment.slot_start >= day_start, Appointment.slot_start <= day_end, @@ -142,13 +143,12 @@ async def get_available_slots( ) ) booked_starts = {row[0] for row in result.fetchall()} - + # Filter out already booked slots available_slots = [ - slot for slot in all_slots - if slot.start_time not in booked_starts + slot for slot in all_slots if slot.start_time not in booked_starts ] - + return AvailableSlotsResponse(date=target_date, slots=available_slots) @@ -161,27 +161,28 @@ async def create_booking( """Book an appointment slot.""" slot_date = request.slot_start.date() _validate_booking_date(slot_date) - - # Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES) + + # Validate slot is on the correct minute boundary valid_minutes = _get_valid_minute_boundaries() if request.slot_start.minute not in valid_minutes: raise HTTPException( status_code=400, - detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})", + detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary " + f"(valid minutes: {valid_minutes})", ) if request.slot_start.second != 0 or request.slot_start.microsecond != 0: raise HTTPException( status_code=400, detail="Slot start time must not have seconds or microseconds", ) - + # Verify slot falls within availability slot_start_time = request.slot_start.time() - slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time() - + slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES) + slot_end_time = slot_end_dt.time() + result = await db.execute( - select(Availability) - .where( + select(Availability).where( and_( Availability.date == slot_date, Availability.start_time <= slot_start_time, @@ -190,13 +191,15 @@ async def create_booking( ) ) matching_availability = result.scalar_one_or_none() - + if not matching_availability: + slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M") raise HTTPException( status_code=400, - detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.", + detail=f"Selected slot at {slot_str} UTC is not within " + f"any available time ranges for {slot_date}", ) - + # Create the appointment slot_end = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES) appointment = Appointment( @@ -206,9 +209,9 @@ async def create_booking( note=request.note, status=AppointmentStatus.BOOKED, ) - + db.add(appointment) - + try: await db.commit() await db.refresh(appointment) @@ -216,9 +219,9 @@ async def create_booking( await db.rollback() raise HTTPException( status_code=409, - detail="This slot has already been booked. Please select another slot.", - ) - + detail="This slot has already been booked. Select another slot.", + ) from None + return _to_appointment_response(appointment, current_user.email) @@ -241,60 +244,63 @@ async def get_my_appointments( .order_by(Appointment.slot_start.desc()) ) appointments = result.scalars().all() - - return [ - _to_appointment_response(apt, current_user.email) - for apt in appointments - ] + + return [_to_appointment_response(apt, current_user.email) for apt in appointments] -@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) +@appointments_router.post( + "/{appointment_id}/cancel", response_model=AppointmentResponse +) async def cancel_my_appointment( appointment_id: int, db: AsyncSession = Depends(get_db), current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)), ) -> AppointmentResponse: """Cancel one of the current user's appointments.""" - # Get the appointment with explicit eager loading of user relationship + # Get the appointment with eager loading of user relationship result = await db.execute( select(Appointment) .options(joinedload(Appointment.user)) .where(Appointment.id == appointment_id) ) appointment = result.scalar_one_or_none() - + if not appointment: raise HTTPException( status_code=404, - detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", + detail=f"Appointment {appointment_id} not found", ) - + # Verify ownership if appointment.user_id != current_user.id: - raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment") - + raise HTTPException( + status_code=403, + detail="Cannot cancel another user's appointment", + ) + # Check if already cancelled if appointment.status != AppointmentStatus.BOOKED: raise HTTPException( status_code=400, - detail=f"Cannot cancel appointment with status '{appointment.status.value}'" + detail=f"Cannot cancel: status is '{appointment.status.value}'", ) - + # Check if appointment is in the past - if appointment.slot_start <= datetime.now(timezone.utc): - appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" + if appointment.slot_start <= datetime.now(UTC): + apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M") raise HTTPException( status_code=400, - detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." + detail=f"Cannot cancel appointment at {apt_time} UTC: " + "already started or in the past", ) - + # Cancel the appointment appointment.status = AppointmentStatus.CANCELLED_BY_USER - appointment.cancelled_at = datetime.now(timezone.utc) - + appointment.cancelled_at = datetime.now(UTC) + await db.commit() await db.refresh(appointment) - + return _to_appointment_response(appointment, current_user.email) @@ -302,7 +308,9 @@ async def cancel_my_appointment( # Admin Appointments Endpoints # ============================================================================= -admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"]) +admin_appointments_router = APIRouter( + prefix="/api/admin/appointments", tags=["admin-appointments"] +) @admin_appointments_router.get("", response_model=PaginatedAppointments) @@ -317,7 +325,7 @@ async def get_all_appointments( count_result = await db.execute(select(func.count(Appointment.id))) total = count_result.scalar() or 0 total_pages = (total + per_page - 1) // per_page if total > 0 else 1 - + # Get paginated appointments with explicit eager loading of user relationship offset = (page - 1) * per_page result = await db.execute( @@ -328,13 +336,13 @@ async def get_all_appointments( .limit(per_page) ) appointments = result.scalars().all() - + # Build responses using the eager-loaded user relationship records = [ _to_appointment_response(apt) # Uses eager-loaded relationship for apt in appointments ] - + return PaginatedAppointments( records=records, total=total, @@ -344,47 +352,52 @@ async def get_all_appointments( ) -@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) +@admin_appointments_router.post( + "/{appointment_id}/cancel", response_model=AppointmentResponse +) async def admin_cancel_appointment( appointment_id: int, db: AsyncSession = Depends(get_db), - _current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)), + _current_user: User = Depends( + require_permission(Permission.CANCEL_ANY_APPOINTMENT) + ), ) -> AppointmentResponse: """Cancel any appointment (admin only).""" - # Get the appointment with explicit eager loading of user relationship + # Get the appointment with eager loading of user relationship result = await db.execute( select(Appointment) .options(joinedload(Appointment.user)) .where(Appointment.id == appointment_id) ) appointment = result.scalar_one_or_none() - + if not appointment: raise HTTPException( status_code=404, - detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", + detail=f"Appointment {appointment_id} not found", ) - + # Check if already cancelled if appointment.status != AppointmentStatus.BOOKED: raise HTTPException( status_code=400, - detail=f"Cannot cancel appointment with status '{appointment.status.value}'" + detail=f"Cannot cancel: status is '{appointment.status.value}'", ) - + # Check if appointment is in the past - if appointment.slot_start <= datetime.now(timezone.utc): - appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" + if appointment.slot_start <= datetime.now(UTC): + apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M") raise HTTPException( status_code=400, - detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." + detail=f"Cannot cancel appointment at {apt_time} UTC: " + "already started or in the past", ) - + # Cancel the appointment appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN - appointment.cancelled_at = datetime.now(timezone.utc) - + appointment.cancelled_at = datetime.now(UTC) + await db.commit() await db.refresh(appointment) - + return _to_appointment_response(appointment) # Uses eager-loaded relationship diff --git a/backend/routes/counter.py b/backend/routes/counter.py index 034766e..1f31acc 100644 --- a/backend/routes/counter.py +++ b/backend/routes/counter.py @@ -1,12 +1,12 @@ """Counter routes.""" + from fastapi import APIRouter, Depends from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db -from models import Counter, User, CounterRecord, Permission - +from models import Counter, CounterRecord, Permission, User router = APIRouter(prefix="/api/counter", tags=["counter"]) diff --git a/backend/routes/invites.py b/backend/routes/invites.py index 4fde96e..741c280 100644 --- a/backend/routes/invites.py +++ b/backend/routes/invites.py @@ -1,25 +1,29 @@ """Invite routes for public check, user invites, and admin management.""" -from datetime import datetime, timezone -from fastapi import APIRouter, Depends, HTTPException, status, Query -from sqlalchemy import select, func, desc +from datetime import UTC, datetime + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import desc, func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db -from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format -from models import User, Invite, InviteStatus, Permission +from invite_utils import ( + generate_invite_identifier, + is_valid_identifier_format, + normalize_identifier, +) +from models import Invite, InviteStatus, Permission, User from schemas import ( + AdminUserResponse, InviteCheckResponse, InviteCreate, InviteResponse, - UserInviteResponse, PaginatedInviteRecords, - AdminUserResponse, + UserInviteResponse, ) - router = APIRouter(prefix="/api/invites", tags=["invites"]) admin_router = APIRouter(prefix="/api/admin", tags=["admin"]) @@ -54,9 +58,7 @@ async def check_invite( if not is_valid_identifier_format(normalized): return InviteCheckResponse(valid=False, error="Invalid invite code format") - result = await db.execute( - select(Invite).where(Invite.identifier == normalized) - ) + result = await db.execute(select(Invite).where(Invite.identifier == normalized)) invite = result.scalar_one_or_none() # Return same error for not found, spent, and revoked to avoid information leakage @@ -112,9 +114,7 @@ async def create_invite( ) -> InviteResponse: """Create a new invite for a specified godfather user.""" # Validate godfather exists - result = await db.execute( - select(User.id).where(User.id == data.godfather_id) - ) + result = await db.execute(select(User.id).where(User.id == data.godfather_id)) godfather_id = result.scalar_one_or_none() if not godfather_id: raise HTTPException( @@ -141,8 +141,8 @@ async def create_invite( if attempt == MAX_INVITE_COLLISION_RETRIES - 1: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to generate unique invite code. Please try again.", - ) + detail="Failed to generate unique invite code. Try again.", + ) from None if invite is None: raise HTTPException( @@ -156,7 +156,9 @@ async def create_invite( async def list_all_invites( page: int = Query(1, ge=1), per_page: int = Query(10, ge=1, le=100), - status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"), + status_filter: str | None = Query( + None, alias="status", description="Filter by status: ready, spent, revoked" + ), godfather_id: int | None = Query(None, description="Filter by godfather user ID"), db: AsyncSession = Depends(get_db), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), @@ -175,8 +177,9 @@ async def list_all_invites( except ValueError: raise HTTPException( status_code=400, - detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked", - ) + detail=f"Invalid status: {status_filter}. " + "Must be ready, spent, or revoked", + ) from None if godfather_id: query = query.where(Invite.godfather_id == godfather_id) @@ -224,11 +227,12 @@ async def revoke_invite( if invite.status != InviteStatus.READY: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.", + detail=f"Cannot revoke invite with status '{invite.status.value}'. " + "Only READY invites can be revoked.", ) invite.status = InviteStatus.REVOKED - invite.revoked_at = datetime.now(timezone.utc) + invite.revoked_at = datetime.now(UTC) await db.commit() await db.refresh(invite) diff --git a/backend/routes/meta.py b/backend/routes/meta.py index c984ab3..7eb7643 100644 --- a/backend/routes/meta.py +++ b/backend/routes/meta.py @@ -1,7 +1,8 @@ """Meta endpoints for shared constants.""" + from fastapi import APIRouter -from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR +from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission from schemas import ConstantsResponse router = APIRouter(prefix="/api/meta", tags=["meta"]) @@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse: roles=[ROLE_ADMIN, ROLE_REGULAR], invite_statuses=[s.value for s in InviteStatus], ) - diff --git a/backend/routes/profile.py b/backend/routes/profile.py index aba4cc4..99f570b 100644 --- a/backend/routes/profile.py +++ b/backend/routes/profile.py @@ -1,15 +1,15 @@ """Profile routes for user contact details.""" + from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from auth import get_current_user from database import get_db -from models import User, ROLE_REGULAR +from models import ROLE_REGULAR, User from schemas import ProfileResponse, ProfileUpdate from validation import validate_profile_fields - router = APIRouter(prefix="/api/profile", tags=["profile"]) @@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str """Get the email of a godfather user by ID.""" if not godfather_id: return None - result = await db.execute( - select(User.email).where(User.id == godfather_id) - ) + result = await db.execute(select(User.email).where(User.id == godfather_id)) return result.scalar_one_or_none() diff --git a/backend/routes/sum.py b/backend/routes/sum.py index ab0bbfa..6db890e 100644 --- a/backend/routes/sum.py +++ b/backend/routes/sum.py @@ -1,13 +1,13 @@ """Sum calculation routes.""" + from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from auth import require_permission from database import get_db -from models import User, SumRecord, Permission +from models import Permission, SumRecord, User from schemas import SumRequest, SumResponse - router = APIRouter(prefix="/api/sum", tags=["sum"]) diff --git a/backend/schemas.py b/backend/schemas.py index 1993ebe..f98ab5e 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -1,5 +1,6 @@ """Pydantic schemas for API request/response models.""" -from datetime import datetime, date, time + +from datetime import date, datetime, time from typing import Generic, TypeVar from pydantic import BaseModel, EmailStr, field_validator @@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH class UserCredentials(BaseModel): """Base model for user email/password.""" + email: EmailStr password: str @@ -19,6 +21,7 @@ UserLogin = UserCredentials class UserResponse(BaseModel): """Response model for authenticated user info.""" + id: int email: str roles: list[str] @@ -27,6 +30,7 @@ class UserResponse(BaseModel): class RegisterWithInvite(BaseModel): """Request model for registration with invite.""" + email: EmailStr password: str invite_identifier: str @@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel): class SumRequest(BaseModel): """Request model for sum calculation.""" + a: float b: float class SumResponse(BaseModel): """Response model for sum calculation.""" + a: float b: float result: float @@ -47,6 +53,7 @@ class SumResponse(BaseModel): class CounterRecordResponse(BaseModel): """Response model for a counter audit record.""" + id: int user_email: str value_before: int @@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel): class SumRecordResponse(BaseModel): """Response model for a sum audit record.""" + id: int user_email: str a: float @@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel) class PaginatedResponse(BaseModel, Generic[RecordT]): """Generic paginated response wrapper.""" + records: list[RecordT] total: int page: int @@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse] class ProfileResponse(BaseModel): """Response model for profile data.""" + contact_email: str | None telegram: str | None signal: str | None @@ -91,6 +101,7 @@ class ProfileResponse(BaseModel): class ProfileUpdate(BaseModel): """Request model for updating profile.""" + contact_email: str | None = None telegram: str | None = None signal: str | None = None @@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel): class InviteCheckResponse(BaseModel): """Response for invite check endpoint.""" + valid: bool status: str | None = None error: str | None = None @@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel): class InviteCreate(BaseModel): """Request model for creating an invite.""" + godfather_id: int class InviteResponse(BaseModel): """Response model for invite data (admin view).""" + id: int identifier: str godfather_id: int @@ -125,6 +139,7 @@ class InviteResponse(BaseModel): class UserInviteResponse(BaseModel): """Response model for a user's invite (simpler than admin view).""" + id: int identifier: str status: str @@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse] class AdminUserResponse(BaseModel): """Minimal user info for admin dropdowns.""" + id: int email: str @@ -146,11 +162,13 @@ class AdminUserResponse(BaseModel): # Availability Schemas # ============================================================================= + class TimeSlot(BaseModel): """A single time slot (start and end time).""" + start_time: time end_time: time - + @field_validator("start_time", "end_time") @classmethod def validate_15min_boundary(cls, v: time) -> time: @@ -164,23 +182,27 @@ class TimeSlot(BaseModel): class AvailabilityDay(BaseModel): """Availability for a single day.""" + date: date slots: list[TimeSlot] class AvailabilityResponse(BaseModel): """Response model for availability query.""" + days: list[AvailabilityDay] class SetAvailabilityRequest(BaseModel): """Request to set availability for a specific date.""" + date: date slots: list[TimeSlot] class CopyAvailabilityRequest(BaseModel): """Request to copy availability from one day to others.""" + source_date: date target_dates: list[date] @@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel): # Booking Schemas # ============================================================================= + class BookableSlot(BaseModel): """A bookable 15-minute slot.""" + start_time: datetime end_time: datetime class AvailableSlotsResponse(BaseModel): """Response for available slots on a given date.""" + date: date slots: list[BookableSlot] class BookingRequest(BaseModel): """Request to book an appointment.""" + slot_start: datetime note: str | None = None @@ -216,6 +242,7 @@ class BookingRequest(BaseModel): class AppointmentResponse(BaseModel): """Response model for an appointment.""" + id: int user_id: int user_email: str @@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse] # Meta/Constants Schemas # ============================================================================= + class ConstantsResponse(BaseModel): """Response model for shared constants.""" + permissions: list[str] roles: list[str] invite_statuses: list[str] diff --git a/backend/seed.py b/backend/seed.py index 185c2a6..dbb79a7 100644 --- a/backend/seed.py +++ b/backend/seed.py @@ -1,12 +1,21 @@ """Seed the database with roles, permissions, and dev users.""" + import asyncio import os + from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database import engine, async_session, Base -from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN from auth import get_password_hash +from database import Base, async_session, engine +from models import ( + ROLE_ADMIN, + ROLE_DEFINITIONS, + ROLE_REGULAR, + Permission, + Role, + User, +) DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"] DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"] @@ -14,11 +23,13 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"] DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"] -async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role: +async def upsert_role( + db: AsyncSession, name: str, description: str, permissions: list[Permission] +) -> Role: """Create or update a role with the given permissions.""" result = await db.execute(select(Role).where(Role.name == name)) role = result.scalar_one_or_none() - + if role: role.description = description print(f"Updated role: {name}") @@ -27,19 +38,21 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions db.add(role) await db.flush() # Get the role ID print(f"Created role: {name}") - + # Set permissions for the role await role.set_permissions(db, permissions) print(f" Permissions: {', '.join(p.value for p in permissions)}") - + return role -async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User: +async def upsert_user( + db: AsyncSession, email: str, password: str, role_names: list[str] +) -> User: """Create or update a user with the given credentials and roles.""" result = await db.execute(select(User).where(User.email == email)) user = result.scalar_one_or_none() - + # Get roles roles = [] for role_name in role_names: @@ -48,7 +61,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l if not role: raise ValueError(f"Role '{role_name}' not found") roles.append(role) - + if user: user.hashed_password = get_password_hash(password) user.roles = roles # type: ignore[assignment] @@ -61,7 +74,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l ) db.add(user) print(f"Created user: {email} with roles: {role_names}") - + return user @@ -78,14 +91,14 @@ async def seed() -> None: role_config["description"], role_config["permissions"], ) - + print("\n=== Seeding Users ===") # Create regular dev user await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR]) - + # Create admin dev user await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN]) - + await db.commit() print("\n=== Seeding Complete ===\n") diff --git a/backend/shared_constants.py b/backend/shared_constants.py index 3a48651..b3151fc 100644 --- a/backend/shared_constants.py +++ b/backend/shared_constants.py @@ -1,4 +1,5 @@ """Load shared constants from shared/constants.json.""" + import json from pathlib import Path @@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"] MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"] MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"] NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"] - diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 140a97c..bfe1516 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -7,28 +7,28 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only") import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy import select -from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from auth import get_password_hash from database import Base, get_db from main import app -from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN -from auth import get_password_hash +from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User from tests.helpers import unique_email TEST_DATABASE_URL = os.getenv( "TEST_DATABASE_URL", - "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test" + "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test", ) class ClientFactory: """Factory for creating httpx clients with optional cookies.""" - + def __init__(self, transport, base_url, session_factory): self._transport = transport self._base_url = base_url self._session_factory = session_factory - + @asynccontextmanager async def create(self, cookies: dict | None = None): """Create a new client, optionally with cookies set.""" @@ -38,15 +38,15 @@ class ClientFactory: cookies=cookies or {}, ) as client: yield client - + async def request(self, method: str, url: str, **kwargs): """Make a one-off request without cookies.""" async with self.create() as client: return await client.request(method, url, **kwargs) - + async def get(self, url: str, **kwargs): return await self.request("GET", url, **kwargs) - + async def post(self, url: str, **kwargs): return await self.request("POST", url, **kwargs) @@ -64,16 +64,16 @@ async def setup_roles(db: AsyncSession) -> dict[str, Role]: # Check if role exists result = await db.execute(select(Role).where(Role.name == role_name)) role = result.scalar_one_or_none() - + if not role: role = Role(name=role_name, description=config["description"]) db.add(role) await db.flush() - + # Set permissions await role.set_permissions(db, config["permissions"]) roles[role_name] = role - + await db.commit() return roles @@ -91,9 +91,11 @@ async def create_user_with_roles( result = await db.execute(select(Role).where(Role.name == role_name)) role = result.scalar_one_or_none() if not role: - raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?") + raise ValueError( + f"Role '{role_name}' not found. Did you run setup_roles()?" + ) roles.append(role) - + user = User( email=email, hashed_password=get_password_hash(password), @@ -110,27 +112,27 @@ async def client_factory(): """Fixture that provides a factory for creating clients.""" engine = create_async_engine(TEST_DATABASE_URL) session_factory = async_sessionmaker(engine, expire_on_commit=False) - + # Create tables async with engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) - + # Setup roles async with session_factory() as db: await setup_roles(db) - + async def override_get_db(): async with session_factory() as session: yield session - + app.dependency_overrides[get_db] = override_get_db - + transport = ASGITransport(app=app) factory = ClientFactory(transport, "http://test", session_factory) - + yield factory - + app.dependency_overrides.clear() await engine.dispose() @@ -147,17 +149,17 @@ async def regular_user(client_factory): """Create a regular user and return their credentials and cookies.""" email = unique_email("regular") password = "password123" - + async with client_factory.get_db_session() as db: user = await create_user_with_roles(db, email, password, [ROLE_REGULAR]) user_id = user.id - + # Login to get cookies response = await client_factory.post( "/api/auth/login", json={"email": email, "password": password}, ) - + return { "email": email, "password": password, @@ -172,17 +174,17 @@ async def alt_regular_user(client_factory): """Create a second regular user for tests needing multiple users.""" email = unique_email("alt_regular") password = "password123" - + async with client_factory.get_db_session() as db: user = await create_user_with_roles(db, email, password, [ROLE_REGULAR]) user_id = user.id - + # Login to get cookies response = await client_factory.post( "/api/auth/login", json={"email": email, "password": password}, ) - + return { "email": email, "password": password, @@ -197,16 +199,16 @@ async def admin_user(client_factory): """Create an admin user and return their credentials and cookies.""" email = unique_email("admin") password = "password123" - + async with client_factory.get_db_session() as db: await create_user_with_roles(db, email, password, [ROLE_ADMIN]) - + # Login to get cookies response = await client_factory.post( "/api/auth/login", json={"email": email, "password": password}, ) - + return { "email": email, "password": password, @@ -220,16 +222,16 @@ async def user_no_roles(client_factory): """Create a user with NO roles and return their credentials and cookies.""" email = unique_email("noroles") password = "password123" - + async with client_factory.get_db_session() as db: await create_user_with_roles(db, email, password, []) - + # Login to get cookies response = await client_factory.post( "/api/auth/login", json={"email": email, "password": password}, ) - + return { "email": email, "password": password, diff --git a/backend/tests/helpers.py b/backend/tests/helpers.py index e1fa202..be1f410 100644 --- a/backend/tests/helpers.py +++ b/backend/tests/helpers.py @@ -3,8 +3,8 @@ import uuid from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from models import User, Invite, InviteStatus from invite_utils import generate_invite_identifier +from models import Invite, InviteStatus, User def unique_email(prefix: str = "test") -> str: @@ -15,24 +15,24 @@ def unique_email(prefix: str = "test") -> str: async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str: """ Create an invite for an existing godfather user. - + Args: db: Database session godfather_id: ID of the existing user who will be the godfather - + Returns: The invite identifier. - + Raises: ValueError: If the godfather user doesn't exist. """ # Verify godfather exists result = await db.execute(select(User).where(User.id == godfather_id)) godfather = result.scalar_one_or_none() - + if not godfather: raise ValueError(f"Godfather user with ID {godfather_id} not found") - + # Create invite identifier = generate_invite_identifier() invite = Invite( @@ -42,7 +42,7 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st ) db.add(invite) await db.commit() - + return identifier @@ -50,24 +50,26 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str: """ Create an invite for an existing godfather user (looked up by email). - + The godfather must already exist in the database. - + Args: db: Database session godfather_email: Email of the existing user who will be the godfather - + Returns: The invite identifier. - + Raises: ValueError: If the godfather user doesn't exist. """ result = await db.execute(select(User).where(User.email == godfather_email)) godfather = result.scalar_one_or_none() - + if not godfather: - raise ValueError(f"Godfather user with email '{godfather_email}' not found. " - "Create the user first using create_user_with_roles().") - + raise ValueError( + f"Godfather user with email '{godfather_email}' not found. " + "Create the user first using create_user_with_roles()." + ) + return await create_invite_for_godfather(db, godfather.id) diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 026ca66..018e80d 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -3,12 +3,13 @@ Note: Registration now requires an invite code. Tests that need to register users will create invites first via the helper function. """ + import pytest from auth import COOKIE_NAME from models import ROLE_REGULAR -from tests.helpers import unique_email, create_invite_for_godfather from tests.conftest import create_user_with_roles +from tests.helpers import create_invite_for_godfather, unique_email # Registration tests (with invite) @@ -16,12 +17,14 @@ from tests.conftest import create_user_with_roles async def test_register_success(client_factory): """Can register with valid invite code.""" email = unique_email("register") - + # Create godfather user and invite async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("godfather"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + response = await client_factory.post( "/api/auth/register", json={ @@ -46,13 +49,15 @@ async def test_register_success(client_factory): async def test_register_duplicate_email(client_factory): """Cannot register with already-used email.""" email = unique_email("duplicate") - + # Create godfather and two invites async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite1 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id) - + # First registration await client_factory.post( "/api/auth/register", @@ -62,7 +67,7 @@ async def test_register_duplicate_email(client_factory): "invite_identifier": invite1, }, ) - + # Second registration with same email response = await client_factory.post( "/api/auth/register", @@ -80,9 +85,11 @@ async def test_register_duplicate_email(client_factory): async def test_register_invalid_email(client_factory): """Cannot register with invalid email format.""" async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + response = await client_factory.post( "/api/auth/register", json={ @@ -136,11 +143,13 @@ async def test_register_empty_body(client): async def test_login_success(client_factory): """Can login with valid credentials.""" email = unique_email("login") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + await client_factory.post( "/api/auth/register", json={ @@ -165,11 +174,13 @@ async def test_login_success(client_factory): async def test_login_wrong_password(client_factory): """Cannot login with wrong password.""" email = unique_email("wrongpass") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + await client_factory.post( "/api/auth/register", json={ @@ -219,11 +230,13 @@ async def test_login_missing_fields(client): async def test_get_me_success(client_factory): """Can get current user info when authenticated.""" email = unique_email("me") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg_response = await client_factory.post( "/api/auth/register", json={ @@ -233,10 +246,10 @@ async def test_get_me_success(client_factory): }, ) cookies = dict(reg_response.cookies) - + async with client_factory.create(cookies=cookies) as authed: response = await authed.get("/api/auth/me") - + assert response.status_code == 200 data = response.json() assert data["email"] == email @@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client): @pytest.mark.asyncio async def test_get_me_invalid_cookie(client_factory): """Cannot get current user with invalid cookie.""" - async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed: + async with client_factory.create( + cookies={COOKIE_NAME: "invalidtoken123"} + ) as authed: response = await authed.get("/api/auth/me") assert response.status_code == 401 assert response.json()["detail"] == "Invalid authentication credentials" @@ -275,11 +290,13 @@ async def test_get_me_expired_token(client_factory): async def test_cookie_from_register_works_for_me(client_factory): """Auth cookie from registration works for subsequent requests.""" email = unique_email("tokentest") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg_response = await client_factory.post( "/api/auth/register", json={ @@ -289,10 +306,10 @@ async def test_cookie_from_register_works_for_me(client_factory): }, ) cookies = dict(reg_response.cookies) - + async with client_factory.create(cookies=cookies) as authed: me_response = await authed.get("/api/auth/me") - + assert me_response.status_code == 200 assert me_response.json()["email"] == email @@ -301,11 +318,13 @@ async def test_cookie_from_register_works_for_me(client_factory): async def test_cookie_from_login_works_for_me(client_factory): """Auth cookie from login works for subsequent requests.""" email = unique_email("logintoken") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + await client_factory.post( "/api/auth/register", json={ @@ -319,10 +338,10 @@ async def test_cookie_from_login_works_for_me(client_factory): json={"email": email, "password": "password123"}, ) cookies = dict(login_response.cookies) - + async with client_factory.create(cookies=cookies) as authed: me_response = await authed.get("/api/auth/me") - + assert me_response.status_code == 200 assert me_response.json()["email"] == email @@ -333,12 +352,14 @@ async def test_multiple_users_isolated(client_factory): """Multiple users have isolated sessions.""" email1 = unique_email("user1") email2 = unique_email("user2") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite1 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id) - + resp1 = await client_factory.post( "/api/auth/register", json={ @@ -355,16 +376,16 @@ async def test_multiple_users_isolated(client_factory): "invite_identifier": invite2, }, ) - + cookies1 = dict(resp1.cookies) cookies2 = dict(resp2.cookies) - + async with client_factory.create(cookies=cookies1) as user1: me1 = await user1.get("/api/auth/me") - + async with client_factory.create(cookies=cookies2) as user2: me2 = await user2.get("/api/auth/me") - + assert me1.json()["email"] == email1 assert me2.json()["email"] == email2 assert me1.json()["id"] != me2.json()["id"] @@ -375,11 +396,13 @@ async def test_multiple_users_isolated(client_factory): async def test_password_is_hashed(client_factory): """Passwords are properly hashed (can login with correct password).""" email = unique_email("hashtest") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + await client_factory.post( "/api/auth/register", json={ @@ -399,11 +422,13 @@ async def test_password_is_hashed(client_factory): async def test_case_sensitive_password(client_factory): """Passwords are case-sensitive.""" email = unique_email("casetest") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + await client_factory.post( "/api/auth/register", json={ @@ -424,11 +449,13 @@ async def test_case_sensitive_password(client_factory): async def test_logout_success(client_factory): """Can logout successfully.""" email = unique_email("logout") - + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg_response = await client_factory.post( "/api/auth/register", json={ @@ -438,9 +465,9 @@ async def test_logout_success(client_factory): }, ) cookies = dict(reg_response.cookies) - + async with client_factory.create(cookies=cookies) as authed: logout_response = await authed.post("/api/auth/logout") - + assert logout_response.status_code == 200 assert logout_response.json() == {"ok": True} diff --git a/backend/tests/test_availability.py b/backend/tests/test_availability.py index 4e08ab8..c982681 100644 --- a/backend/tests/test_availability.py +++ b/backend/tests/test_availability.py @@ -3,7 +3,9 @@ Availability API Tests Tests for the admin availability management endpoints. """ -from datetime import date, time, timedelta + +from datetime import date, timedelta + import pytest @@ -19,6 +21,7 @@ def in_days(n: int) -> date: # Permission Tests # ============================================================================= + class TestAvailabilityPermissions: """Test that only admins can access availability endpoints.""" @@ -44,7 +47,9 @@ class TestAvailabilityPermissions: assert response.status_code == 200 @pytest.mark.asyncio - async def test_regular_user_cannot_get_availability(self, client_factory, regular_user): + async def test_regular_user_cannot_get_availability( + self, client_factory, regular_user + ): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get( "/api/admin/availability", @@ -53,7 +58,9 @@ class TestAvailabilityPermissions: assert response.status_code == 403 @pytest.mark.asyncio - async def test_regular_user_cannot_set_availability(self, client_factory, regular_user): + async def test_regular_user_cannot_set_availability( + self, client_factory, regular_user + ): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.put( "/api/admin/availability", @@ -88,6 +95,7 @@ class TestAvailabilityPermissions: # Set Availability Tests # ============================================================================= + class TestSetAvailability: """Test setting availability for a date.""" @@ -101,7 +109,7 @@ class TestSetAvailability: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + assert response.status_code == 200 data = response.json() assert data["date"] == str(tomorrow()) @@ -122,13 +130,15 @@ class TestSetAvailability: ], }, ) - + assert response.status_code == 200 data = response.json() assert len(data["slots"]) == 2 @pytest.mark.asyncio - async def test_set_empty_slots_clears_availability(self, client_factory, admin_user): + async def test_set_empty_slots_clears_availability( + self, client_factory, admin_user + ): async with client_factory.create(cookies=admin_user["cookies"]) as client: # First set some availability await client.put( @@ -138,13 +148,13 @@ class TestSetAvailability: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Then clear it response = await client.put( "/api/admin/availability", json={"date": str(tomorrow()), "slots": []}, ) - + assert response.status_code == 200 data = response.json() assert len(data["slots"]) == 0 @@ -160,22 +170,22 @@ class TestSetAvailability: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Replace with different slots - response = await client.put( + await client.put( "/api/admin/availability", json={ "date": str(tomorrow()), "slots": [{"start_time": "14:00:00", "end_time": "16:00:00"}], }, ) - + # Verify the replacement get_response = await client.get( "/api/admin/availability", params={"from": str(tomorrow()), "to": str(tomorrow())}, ) - + data = get_response.json() assert len(data["days"]) == 1 assert len(data["days"][0]["slots"]) == 1 @@ -186,6 +196,7 @@ class TestSetAvailability: # Validation Tests # ============================================================================= + class TestAvailabilityValidation: """Test validation rules for availability.""" @@ -200,7 +211,7 @@ class TestAvailabilityValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + assert response.status_code == 400 assert "past" in response.json()["detail"].lower() @@ -214,7 +225,7 @@ class TestAvailabilityValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + assert response.status_code == 400 assert "past" in response.json()["detail"].lower() @@ -229,7 +240,7 @@ class TestAvailabilityValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + assert response.status_code == 400 assert "30" in response.json()["detail"] @@ -243,7 +254,7 @@ class TestAvailabilityValidation: "slots": [{"start_time": "09:05:00", "end_time": "12:00:00"}], }, ) - + assert response.status_code == 422 # Pydantic validation error assert "15-minute" in response.json()["detail"][0]["msg"] @@ -257,7 +268,7 @@ class TestAvailabilityValidation: "slots": [{"start_time": "12:00:00", "end_time": "09:00:00"}], }, ) - + assert response.status_code == 400 assert "after" in response.json()["detail"].lower() @@ -274,7 +285,7 @@ class TestAvailabilityValidation: ], }, ) - + assert response.status_code == 400 assert "overlap" in response.json()["detail"].lower() @@ -283,6 +294,7 @@ class TestAvailabilityValidation: # Get Availability Tests # ============================================================================= + class TestGetAvailability: """Test retrieving availability.""" @@ -293,7 +305,7 @@ class TestGetAvailability: "/api/admin/availability", params={"from": str(tomorrow()), "to": str(in_days(7))}, ) - + assert response.status_code == 200 data = response.json() assert data["days"] == [] @@ -310,13 +322,13 @@ class TestGetAvailability: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Get range that includes all response = await client.get( "/api/admin/availability", params={"from": str(in_days(1)), "to": str(in_days(3))}, ) - + assert response.status_code == 200 data = response.json() assert len(data["days"]) == 3 @@ -333,13 +345,13 @@ class TestGetAvailability: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Get only a subset response = await client.get( "/api/admin/availability", params={"from": str(in_days(2)), "to": str(in_days(4))}, ) - + assert response.status_code == 200 data = response.json() assert len(data["days"]) == 3 @@ -351,7 +363,7 @@ class TestGetAvailability: "/api/admin/availability", params={"from": str(in_days(7)), "to": str(in_days(1))}, ) - + assert response.status_code == 400 assert "before" in response.json()["detail"].lower() @@ -360,6 +372,7 @@ class TestGetAvailability: # Copy Availability Tests # ============================================================================= + class TestCopyAvailability: """Test copying availability from one day to others.""" @@ -377,7 +390,7 @@ class TestCopyAvailability: ], }, ) - + # Copy to another day response = await client.post( "/api/admin/availability/copy", @@ -386,7 +399,7 @@ class TestCopyAvailability: "target_dates": [str(in_days(2))], }, ) - + assert response.status_code == 200 data = response.json() assert len(data["days"]) == 1 @@ -404,7 +417,7 @@ class TestCopyAvailability: "slots": [{"start_time": "10:00:00", "end_time": "11:00:00"}], }, ) - + # Copy to multiple days response = await client.post( "/api/admin/availability/copy", @@ -413,7 +426,7 @@ class TestCopyAvailability: "target_dates": [str(in_days(2)), str(in_days(3)), str(in_days(4))], }, ) - + assert response.status_code == 200 data = response.json() assert len(data["days"]) == 3 @@ -429,7 +442,7 @@ class TestCopyAvailability: "slots": [{"start_time": "08:00:00", "end_time": "09:00:00"}], }, ) - + # Set source availability await client.put( "/api/admin/availability", @@ -438,7 +451,7 @@ class TestCopyAvailability: "slots": [{"start_time": "14:00:00", "end_time": "15:00:00"}], }, ) - + # Copy (should replace) await client.post( "/api/admin/availability/copy", @@ -447,13 +460,13 @@ class TestCopyAvailability: "target_dates": [str(in_days(2))], }, ) - + # Verify target was replaced response = await client.get( "/api/admin/availability", params={"from": str(in_days(2)), "to": str(in_days(2))}, ) - + data = response.json() assert len(data["days"]) == 1 assert len(data["days"][0]["slots"]) == 1 @@ -469,7 +482,7 @@ class TestCopyAvailability: "target_dates": [str(in_days(2))], }, ) - + assert response.status_code == 400 assert "no availability" in response.json()["detail"].lower() @@ -484,7 +497,7 @@ class TestCopyAvailability: "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], }, ) - + # Copy including self in targets response = await client.post( "/api/admin/availability/copy", @@ -493,7 +506,7 @@ class TestCopyAvailability: "target_dates": [str(in_days(1)), str(in_days(2))], }, ) - + assert response.status_code == 200 data = response.json() # Should only have copied to day 2, not day 1 (self) @@ -511,7 +524,7 @@ class TestCopyAvailability: "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], }, ) - + # Try to copy to a date beyond 30 days response = await client.post( "/api/admin/availability/copy", @@ -520,7 +533,7 @@ class TestCopyAvailability: "target_dates": [str(in_days(31))], }, ) - + assert response.status_code == 400 assert "30" in response.json()["detail"] @@ -535,6 +548,6 @@ class TestCopyAvailability: "target_dates": [str(in_days(1))], }, ) - + assert response.status_code == 400 assert "past" in response.json()["detail"].lower() diff --git a/backend/tests/test_booking.py b/backend/tests/test_booking.py index e77b62a..ef58872 100644 --- a/backend/tests/test_booking.py +++ b/backend/tests/test_booking.py @@ -3,7 +3,9 @@ Booking API Tests Tests for the user booking endpoints. """ -from datetime import date, datetime, timedelta, timezone + +from datetime import UTC, date, datetime, timedelta + import pytest from models import Appointment, AppointmentStatus @@ -21,11 +23,14 @@ def in_days(n: int) -> date: # Permission Tests # ============================================================================= + class TestBookingPermissions: """Test that only regular users can book appointments.""" @pytest.mark.asyncio - async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user): + async def test_regular_user_can_get_slots( + self, client_factory, regular_user, admin_user + ): """Regular user can get available slots.""" # First, admin sets up availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -36,15 +41,19 @@ class TestBookingPermissions: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Regular user gets slots async with client_factory.create(cookies=regular_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) + assert response.status_code == 200 @pytest.mark.asyncio - async def test_regular_user_can_book(self, client_factory, regular_user, admin_user): + async def test_regular_user_can_book( + self, client_factory, regular_user, admin_user + ): """Regular user can book an appointment.""" # Admin sets up availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -55,22 +64,24 @@ class TestBookingPermissions: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # Regular user books async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test booking"}, ) - + assert response.status_code == 200 @pytest.mark.asyncio async def test_admin_cannot_get_slots(self, client_factory, admin_user): """Admin cannot access booking slots endpoint.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) + assert response.status_code == 403 @pytest.mark.asyncio @@ -85,18 +96,20 @@ class TestBookingPermissions: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) - + assert response.status_code == 403 @pytest.mark.asyncio async def test_unauthenticated_cannot_get_slots(self, client): """Unauthenticated user cannot get slots.""" - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) assert response.status_code == 401 @pytest.mark.asyncio @@ -113,6 +126,7 @@ class TestBookingPermissions: # Get Slots Tests # ============================================================================= + class TestGetSlots: """Test getting available booking slots.""" @@ -120,15 +134,19 @@ class TestGetSlots: async def test_get_slots_no_availability(self, client_factory, regular_user): """Returns empty slots when no availability set.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) + assert response.status_code == 200 data = response.json() assert data["date"] == str(tomorrow()) assert data["slots"] == [] @pytest.mark.asyncio - async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user): + async def test_get_slots_expands_to_15min( + self, client_factory, regular_user, admin_user + ): """Availability is expanded into 15-minute slots.""" # Admin sets 1-hour availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -139,15 +157,17 @@ class TestGetSlots: "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], }, ) - + # User gets slots - should be 4 x 15-minute slots async with client_factory.create(cookies=regular_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) + assert response.status_code == 200 data = response.json() assert len(data["slots"]) == 4 - + # Verify times assert "09:00:00" in data["slots"][0]["start_time"] assert "09:15:00" in data["slots"][0]["end_time"] @@ -156,7 +176,9 @@ class TestGetSlots: assert "10:00:00" in data["slots"][3]["end_time"] @pytest.mark.asyncio - async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user): + async def test_get_slots_excludes_booked( + self, client_factory, regular_user, admin_user + ): """Already booked slots are excluded from available slots.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -167,17 +189,19 @@ class TestGetSlots: "slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}], }, ) - + # User books first slot async with client_factory.create(cookies=regular_user["cookies"]) as client: await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) - + # Get slots again - should have 3 left - response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(tomorrow())} + ) + assert response.status_code == 200 data = response.json() assert len(data["slots"]) == 3 @@ -189,6 +213,7 @@ class TestGetSlots: # Booking Tests # ============================================================================= + class TestCreateBooking: """Test creating bookings.""" @@ -204,7 +229,7 @@ class TestCreateBooking: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( @@ -214,7 +239,7 @@ class TestCreateBooking: "note": "Discussion about project", }, ) - + assert response.status_code == 200 data = response.json() assert data["user_id"] == regular_user["user"]["id"] @@ -235,20 +260,22 @@ class TestCreateBooking: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books without note async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) - + assert response.status_code == 200 data = response.json() assert data["note"] is None @pytest.mark.asyncio - async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user): + async def test_cannot_double_book_slot( + self, client_factory, regular_user, admin_user, alt_regular_user + ): """Cannot book a slot that's already booked.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -259,7 +286,7 @@ class TestCreateBooking: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # First user books async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( @@ -267,19 +294,21 @@ class TestCreateBooking: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) assert response.status_code == 200 - + # Second user tries to book same slot async with client_factory.create(cookies=alt_regular_user["cookies"]) as client: response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) - + assert response.status_code == 409 assert "already been booked" in response.json()["detail"] @pytest.mark.asyncio - async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user): + async def test_cannot_book_outside_availability( + self, client_factory, regular_user, admin_user + ): """Cannot book a slot outside of availability.""" # Admin sets availability for morning only async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -290,14 +319,14 @@ class TestCreateBooking: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User tries to book afternoon slot async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T14:00:00Z"}, ) - + assert response.status_code == 400 assert "not within any available time ranges" in response.json()["detail"] @@ -306,6 +335,7 @@ class TestCreateBooking: # Date Validation Tests # ============================================================================= + class TestBookingDateValidation: """Test date validation for bookings.""" @@ -317,9 +347,12 @@ class TestBookingDateValidation: "/api/booking", json={"slot_start": f"{date.today()}T09:00:00Z"}, ) - + assert response.status_code == 400 - assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower() + assert ( + "past" in response.json()["detail"].lower() + or "today" in response.json()["detail"].lower() + ) @pytest.mark.asyncio async def test_cannot_book_past_date(self, client_factory, regular_user): @@ -330,7 +363,7 @@ class TestBookingDateValidation: "/api/booking", json={"slot_start": f"{yesterday}T09:00:00Z"}, ) - + assert response.status_code == 400 @pytest.mark.asyncio @@ -342,7 +375,7 @@ class TestBookingDateValidation: "/api/booking", json={"slot_start": f"{too_far}T09:00:00Z"}, ) - + assert response.status_code == 400 assert "30" in response.json()["detail"] @@ -350,8 +383,10 @@ class TestBookingDateValidation: async def test_cannot_get_slots_today(self, client_factory, regular_user): """Cannot get slots for today.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(date.today())}) - + response = await client.get( + "/api/booking/slots", params={"date": str(date.today())} + ) + assert response.status_code == 400 @pytest.mark.asyncio @@ -359,8 +394,10 @@ class TestBookingDateValidation: """Cannot get slots for past date.""" yesterday = date.today() - timedelta(days=1) async with client_factory.create(cookies=regular_user["cookies"]) as client: - response = await client.get("/api/booking/slots", params={"date": str(yesterday)}) - + response = await client.get( + "/api/booking/slots", params={"date": str(yesterday)} + ) + assert response.status_code == 400 @@ -368,11 +405,14 @@ class TestBookingDateValidation: # Time Validation Tests # ============================================================================= + class TestBookingTimeValidation: """Test time validation for bookings.""" @pytest.mark.asyncio - async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user): + async def test_slot_must_be_15min_boundary( + self, client_factory, regular_user, admin_user + ): """Slot start time must be on 15-minute boundary.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -383,14 +423,14 @@ class TestBookingTimeValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User tries to book at 09:05 async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:05:00Z"}, ) - + assert response.status_code == 400 assert "15-minute" in response.json()["detail"] @@ -399,6 +439,7 @@ class TestBookingTimeValidation: # Note Validation Tests # ============================================================================= + class TestBookingNoteValidation: """Test note validation for bookings.""" @@ -414,7 +455,7 @@ class TestBookingNoteValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User tries to book with long note long_note = "x" * 145 async with client_factory.create(cookies=regular_user["cookies"]) as client: @@ -422,11 +463,13 @@ class TestBookingNoteValidation: "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": long_note}, ) - + assert response.status_code == 422 @pytest.mark.asyncio - async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user): + async def test_note_exactly_144_chars( + self, client_factory, regular_user, admin_user + ): """Note of exactly 144 characters is allowed.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -437,7 +480,7 @@ class TestBookingNoteValidation: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books with exactly 144 char note note = "x" * 144 async with client_factory.create(cookies=regular_user["cookies"]) as client: @@ -445,7 +488,7 @@ class TestBookingNoteValidation: "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": note}, ) - + assert response.status_code == 200 assert response.json()["note"] == note @@ -454,6 +497,7 @@ class TestBookingNoteValidation: # User Appointments Tests # ============================================================================= + class TestUserAppointments: """Test user appointments endpoints.""" @@ -462,12 +506,14 @@ class TestUserAppointments: """Returns empty list when user has no appointments.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/appointments") - + assert response.status_code == 200 assert response.json() == [] @pytest.mark.asyncio - async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user): + async def test_get_my_appointments_with_bookings( + self, client_factory, regular_user, admin_user + ): """Returns user's appointments.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -478,7 +524,7 @@ class TestUserAppointments: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books two slots async with client_factory.create(cookies=regular_user["cookies"]) as client: await client.post( @@ -489,10 +535,10 @@ class TestUserAppointments: "/api/booking", json={"slot_start": f"{tomorrow()}T09:15:00Z", "note": "Second"}, ) - + # Get appointments response = await client.get("/api/appointments") - + assert response.status_code == 200 data = response.json() assert len(data) == 2 @@ -502,11 +548,13 @@ class TestUserAppointments: assert "Second" in notes @pytest.mark.asyncio - async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user): + async def test_admin_cannot_view_user_appointments( + self, client_factory, admin_user + ): """Admin cannot access user appointments endpoint.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/appointments") - + assert response.status_code == 403 @pytest.mark.asyncio @@ -520,7 +568,9 @@ class TestCancelAppointment: """Test cancelling appointments.""" @pytest.mark.asyncio - async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user): + async def test_cancel_own_appointment( + self, client_factory, regular_user, admin_user + ): """User can cancel their own appointment.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -531,7 +581,7 @@ class TestCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -539,17 +589,19 @@ class TestCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # Cancel response = await client.post(f"/api/appointments/{apt_id}/cancel") - + assert response.status_code == 200 data = response.json() assert data["status"] == "cancelled_by_user" assert data["cancelled_at"] is not None @pytest.mark.asyncio - async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user): + async def test_cannot_cancel_others_appointment( + self, client_factory, regular_user, alt_regular_user, admin_user + ): """User cannot cancel another user's appointment.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -560,7 +612,7 @@ class TestCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # First user books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -568,24 +620,28 @@ class TestCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # Second user tries to cancel async with client_factory.create(cookies=alt_regular_user["cookies"]) as client: response = await client.post(f"/api/appointments/{apt_id}/cancel") - + assert response.status_code == 403 assert "another user" in response.json()["detail"].lower() @pytest.mark.asyncio - async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user): + async def test_cannot_cancel_nonexistent_appointment( + self, client_factory, regular_user + ): """Returns 404 for non-existent appointment.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post("/api/appointments/99999/cancel") - + assert response.status_code == 404 @pytest.mark.asyncio - async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): + async def test_cannot_cancel_already_cancelled( + self, client_factory, regular_user, admin_user + ): """Cannot cancel an already cancelled appointment.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -596,7 +652,7 @@ class TestCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books and cancels async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -605,23 +661,27 @@ class TestCancelAppointment: ) apt_id = book_response.json()["id"] await client.post(f"/api/appointments/{apt_id}/cancel") - + # Try to cancel again response = await client.post(f"/api/appointments/{apt_id}/cancel") - + assert response.status_code == 400 assert "cancelled_by_user" in response.json()["detail"] @pytest.mark.asyncio - async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user): + async def test_admin_cannot_use_user_cancel_endpoint( + self, client_factory, admin_user + ): """Admin cannot use user cancel endpoint.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.post("/api/appointments/1/cancel") - + assert response.status_code == 403 @pytest.mark.asyncio - async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user): + async def test_cancelled_slot_becomes_available( + self, client_factory, regular_user, admin_user + ): """After cancelling, the slot becomes available again.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -632,7 +692,7 @@ class TestCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "09:30:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -640,17 +700,17 @@ class TestCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # Check slots - should have 1 slot left (09:15) slots_response = await client.get( "/api/booking/slots", params={"date": str(tomorrow())}, ) assert len(slots_response.json()["slots"]) == 1 - + # Cancel await client.post(f"/api/appointments/{apt_id}/cancel") - + # Check slots - should have 2 slots now slots_response = await client.get( "/api/booking/slots", @@ -663,7 +723,7 @@ class TestCancelAppointment: """User cannot cancel a past appointment.""" # Create a past appointment directly in DB async with client_factory.get_db_session() as db: - past_time = datetime.now(timezone.utc) - timedelta(hours=1) + past_time = datetime.now(UTC) - timedelta(hours=1) appointment = Appointment( user_id=regular_user["user"]["id"], slot_start=past_time, @@ -674,11 +734,11 @@ class TestCancelAppointment: await db.commit() await db.refresh(appointment) apt_id = appointment.id - + # Try to cancel async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post(f"/api/appointments/{apt_id}/cancel") - + assert response.status_code == 400 assert "past" in response.json()["detail"].lower() @@ -687,11 +747,14 @@ class TestCancelAppointment: # Admin Appointments Tests # ============================================================================= + class TestAdminViewAppointments: """Test admin viewing all appointments.""" @pytest.mark.asyncio - async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user): + async def test_admin_can_view_all_appointments( + self, client_factory, regular_user, admin_user + ): """Admin can view all appointments.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -702,18 +765,18 @@ class TestAdminViewAppointments: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: await client.post( "/api/booking", json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test"}, ) - + # Admin views all appointments async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: response = await admin_client.get("/api/admin/appointments") - + assert response.status_code == 200 data = response.json() # Paginated response @@ -725,11 +788,13 @@ class TestAdminViewAppointments: assert any(apt["note"] == "Test" for apt in data["records"]) @pytest.mark.asyncio - async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user): + async def test_regular_user_cannot_view_all_appointments( + self, client_factory, regular_user + ): """Regular user cannot access admin appointments endpoint.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/admin/appointments") - + assert response.status_code == 403 @pytest.mark.asyncio @@ -743,7 +808,9 @@ class TestAdminCancelAppointment: """Test admin cancelling appointments.""" @pytest.mark.asyncio - async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user): + async def test_admin_can_cancel_any_appointment( + self, client_factory, regular_user, admin_user + ): """Admin can cancel any user's appointment.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -754,7 +821,7 @@ class TestAdminCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -762,18 +829,22 @@ class TestAdminCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # Admin cancels async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: - response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") - + response = await admin_client.post( + f"/api/admin/appointments/{apt_id}/cancel" + ) + assert response.status_code == 200 data = response.json() assert data["status"] == "cancelled_by_admin" assert data["cancelled_at"] is not None @pytest.mark.asyncio - async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user): + async def test_regular_user_cannot_use_admin_cancel( + self, client_factory, regular_user, admin_user + ): """Regular user cannot use admin cancel endpoint.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -784,7 +855,7 @@ class TestAdminCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -792,22 +863,26 @@ class TestAdminCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # User tries to use admin cancel endpoint response = await client.post(f"/api/admin/appointments/{apt_id}/cancel") - + assert response.status_code == 403 @pytest.mark.asyncio - async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user): + async def test_admin_cancel_nonexistent_appointment( + self, client_factory, admin_user + ): """Returns 404 for non-existent appointment.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.post("/api/admin/appointments/99999/cancel") - + assert response.status_code == 404 @pytest.mark.asyncio - async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): + async def test_admin_cannot_cancel_already_cancelled( + self, client_factory, regular_user, admin_user + ): """Admin cannot cancel an already cancelled appointment.""" # Admin sets availability async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: @@ -818,7 +893,7 @@ class TestAdminCancelAppointment: "slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}], }, ) - + # User books async with client_factory.create(cookies=regular_user["cookies"]) as client: book_response = await client.post( @@ -826,23 +901,27 @@ class TestAdminCancelAppointment: json={"slot_start": f"{tomorrow()}T09:00:00Z"}, ) apt_id = book_response.json()["id"] - + # User cancels their own appointment await client.post(f"/api/appointments/{apt_id}/cancel") - + # Admin tries to cancel again async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: - response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") - + response = await admin_client.post( + f"/api/admin/appointments/{apt_id}/cancel" + ) + assert response.status_code == 400 assert "cancelled_by_user" in response.json()["detail"] @pytest.mark.asyncio - async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user): + async def test_admin_cannot_cancel_past_appointment( + self, client_factory, regular_user, admin_user + ): """Admin cannot cancel a past appointment.""" # Create a past appointment directly in DB async with client_factory.get_db_session() as db: - past_time = datetime.now(timezone.utc) - timedelta(hours=1) + past_time = datetime.now(UTC) - timedelta(hours=1) appointment = Appointment( user_id=regular_user["user"]["id"], slot_start=past_time, @@ -853,11 +932,12 @@ class TestAdminCancelAppointment: await db.commit() await db.refresh(appointment) apt_id = appointment.id - + # Admin tries to cancel async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: - response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") - + response = await admin_client.post( + f"/api/admin/appointments/{apt_id}/cancel" + ) + assert response.status_code == 400 assert "past" in response.json()["detail"].lower() - diff --git a/backend/tests/test_counter.py b/backend/tests/test_counter.py index 2bbc55e..7dceb88 100644 --- a/backend/tests/test_counter.py +++ b/backend/tests/test_counter.py @@ -2,12 +2,13 @@ Note: Registration now requires an invite code. """ + import pytest from auth import COOKIE_NAME from models import ROLE_REGULAR -from tests.helpers import unique_email, create_invite_for_godfather from tests.conftest import create_user_with_roles +from tests.helpers import create_invite_for_godfather, unique_email # Protected endpoint tests - without auth @@ -41,9 +42,11 @@ async def test_increment_counter_invalid_cookie(client_factory): @pytest.mark.asyncio async def test_get_counter_authenticated(client_factory): async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg = await client_factory.post( "/api/auth/register", json={ @@ -53,10 +56,10 @@ async def test_get_counter_authenticated(client_factory): }, ) cookies = dict(reg.cookies) - + async with client_factory.create(cookies=cookies) as authed: response = await authed.get("/api/counter") - + assert response.status_code == 200 assert "value" in response.json() @@ -64,9 +67,11 @@ async def test_get_counter_authenticated(client_factory): @pytest.mark.asyncio async def test_increment_counter(client_factory): async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg = await client_factory.post( "/api/auth/register", json={ @@ -76,12 +81,12 @@ async def test_increment_counter(client_factory): }, ) cookies = dict(reg.cookies) - + async with client_factory.create(cookies=cookies) as authed: # Get current value before = await authed.get("/api/counter") before_value = before.json()["value"] - + # Increment response = await authed.post("/api/counter/increment") assert response.status_code == 200 @@ -91,9 +96,11 @@ async def test_increment_counter(client_factory): @pytest.mark.asyncio async def test_increment_counter_multiple(client_factory): async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg = await client_factory.post( "/api/auth/register", json={ @@ -103,26 +110,28 @@ async def test_increment_counter_multiple(client_factory): }, ) cookies = dict(reg.cookies) - + async with client_factory.create(cookies=cookies) as authed: # Get starting value before = await authed.get("/api/counter") start = before.json()["value"] - + # Increment 3 times await authed.post("/api/counter/increment") await authed.post("/api/counter/increment") response = await authed.post("/api/counter/increment") - + assert response.json()["value"] == start + 3 @pytest.mark.asyncio async def test_get_counter_after_increment(client_factory): async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + reg = await client_factory.post( "/api/auth/register", json={ @@ -132,14 +141,14 @@ async def test_get_counter_after_increment(client_factory): }, ) cookies = dict(reg.cookies) - + async with client_factory.create(cookies=cookies) as authed: before = await authed.get("/api/counter") start = before.json()["value"] - + await authed.post("/api/counter/increment") await authed.post("/api/counter/increment") - + response = await authed.get("/api/counter") assert response.json()["value"] == start + 2 @@ -149,10 +158,12 @@ async def test_get_counter_after_increment(client_factory): async def test_counter_shared_between_users(client_factory): # Create godfather and invites for two users async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite1 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id) - + # Create first user reg1 = await client_factory.post( "/api/auth/register", @@ -163,15 +174,15 @@ async def test_counter_shared_between_users(client_factory): }, ) cookies1 = dict(reg1.cookies) - + async with client_factory.create(cookies=cookies1) as user1: # Get starting value before = await user1.get("/api/counter") start = before.json()["value"] - + await user1.post("/api/counter/increment") await user1.post("/api/counter/increment") - + # Create second user - should see the increments reg2 = await client_factory.post( "/api/auth/register", @@ -182,14 +193,14 @@ async def test_counter_shared_between_users(client_factory): }, ) cookies2 = dict(reg2.cookies) - + async with client_factory.create(cookies=cookies2) as user2: response = await user2.get("/api/counter") assert response.json()["value"] == start + 2 - + # Second user increments await user2.post("/api/counter/increment") - + # First user sees the increment async with client_factory.create(cookies=cookies1) as user1: response = await user1.get("/api/counter") diff --git a/backend/tests/test_invites.py b/backend/tests/test_invites.py index 9312825..20042ea 100644 --- a/backend/tests/test_invites.py +++ b/backend/tests/test_invites.py @@ -1,22 +1,23 @@ """Tests for invite functionality.""" + import pytest from sqlalchemy import select from invite_utils import ( - generate_invite_identifier, - normalize_identifier, - is_valid_identifier_format, BIP39_WORDS, + generate_invite_identifier, + is_valid_identifier_format, + normalize_identifier, ) -from models import Invite, InviteStatus, User, ROLE_REGULAR -from tests.helpers import unique_email +from models import ROLE_REGULAR, Invite, InviteStatus, User from tests.conftest import create_user_with_roles - +from tests.helpers import unique_email # ============================================================================ # Invite Utils Tests # ============================================================================ + def test_bip39_words_loaded(): """BIP39 word list should have exactly 2048 words.""" assert len(BIP39_WORDS) == 2048 @@ -26,7 +27,7 @@ def test_generate_invite_identifier_format(): """Generated identifier should have word-word-NN format.""" identifier = generate_invite_identifier() assert is_valid_identifier_format(identifier) - + parts = identifier.split("-") assert len(parts) == 3 assert parts[0] in BIP39_WORDS @@ -74,11 +75,11 @@ def test_is_valid_identifier_format_invalid(): assert is_valid_identifier_format("apple-banana") is False assert is_valid_identifier_format("apple-banana-42-extra") is False assert is_valid_identifier_format("applebanan42") is False - + # Empty parts assert is_valid_identifier_format("-banana-42") is False assert is_valid_identifier_format("apple--42") is False - + # Invalid number format assert is_valid_identifier_format("apple-banana-4") is False # Single digit assert is_valid_identifier_format("apple-banana-420") is False # Three digits @@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid(): # Invite Model Tests # ============================================================================ + @pytest.mark.asyncio async def test_create_invite(client_factory): """Can create an invite with godfather.""" @@ -97,7 +99,7 @@ async def test_create_invite(client_factory): godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] ) - + # Create invite invite = Invite( identifier="test-invite-01", @@ -107,7 +109,7 @@ async def test_create_invite(client_factory): db.add(invite) await db.commit() await db.refresh(invite) - + assert invite.id is not None assert invite.identifier == "test-invite-01" assert invite.godfather_id == godfather.id @@ -125,20 +127,20 @@ async def test_invite_godfather_relationship(client_factory): godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] ) - + invite = Invite( identifier="rel-test-01", godfather_id=godfather.id, ) db.add(invite) await db.commit() - + # Query invite fresh result = await db.execute( select(Invite).where(Invite.identifier == "rel-test-01") ) loaded_invite = result.scalar_one() - + assert loaded_invite.godfather is not None assert loaded_invite.godfather.email == godfather.email @@ -147,25 +149,25 @@ async def test_invite_godfather_relationship(client_factory): async def test_invite_unique_identifier(client_factory): """Invite identifier must be unique.""" from sqlalchemy.exc import IntegrityError - + async with client_factory.get_db_session() as db: godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] ) - + invite1 = Invite( identifier="unique-test-01", godfather_id=godfather.id, ) db.add(invite1) await db.commit() - + invite2 = Invite( identifier="unique-test-01", # Same identifier godfather_id=godfather.id, ) db.add(invite2) - + with pytest.raises(IntegrityError): await db.commit() @@ -173,8 +175,8 @@ async def test_invite_unique_identifier(client_factory): @pytest.mark.asyncio async def test_invite_status_transitions(client_factory): """Invite status can be changed.""" - from datetime import datetime, UTC - + from datetime import UTC, datetime + async with client_factory.get_db_session() as db: godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] @@ -182,7 +184,7 @@ async def test_invite_status_transitions(client_factory): user = await create_user_with_roles( db, unique_email("invitee"), "password123", [ROLE_REGULAR] ) - + invite = Invite( identifier="status-test-01", godfather_id=godfather.id, @@ -190,14 +192,14 @@ async def test_invite_status_transitions(client_factory): ) db.add(invite) await db.commit() - + # Transition to SPENT invite.status = InviteStatus.SPENT invite.used_by_id = user.id invite.spent_at = datetime.now(UTC) await db.commit() await db.refresh(invite) - + assert invite.status == InviteStatus.SPENT assert invite.used_by_id == user.id assert invite.spent_at is not None @@ -206,13 +208,13 @@ async def test_invite_status_transitions(client_factory): @pytest.mark.asyncio async def test_invite_revoke(client_factory): """Invite can be revoked.""" - from datetime import datetime, UTC - + from datetime import UTC, datetime + async with client_factory.get_db_session() as db: godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] ) - + invite = Invite( identifier="revoke-test-01", godfather_id=godfather.id, @@ -220,13 +222,13 @@ async def test_invite_revoke(client_factory): ) db.add(invite) await db.commit() - + # Revoke invite.status = InviteStatus.REVOKED invite.revoked_at = datetime.now(UTC) await db.commit() await db.refresh(invite) - + assert invite.status == InviteStatus.REVOKED assert invite.revoked_at is not None assert invite.used_by_id is None # Not used @@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory): # User Godfather Tests # ============================================================================ + @pytest.mark.asyncio async def test_user_godfather_relationship(client_factory): """User can have a godfather.""" @@ -243,7 +246,7 @@ async def test_user_godfather_relationship(client_factory): godfather = await create_user_with_roles( db, unique_email("godfather"), "password123", [ROLE_REGULAR] ) - + # Create user with godfather user = User( email=unique_email("godchild"), @@ -252,13 +255,11 @@ async def test_user_godfather_relationship(client_factory): ) db.add(user) await db.commit() - + # Query user fresh - result = await db.execute( - select(User).where(User.id == user.id) - ) + result = await db.execute(select(User).where(User.id == user.id)) loaded_user = result.scalar_one() - + assert loaded_user.godfather_id == godfather.id assert loaded_user.godfather is not None assert loaded_user.godfather.email == godfather.email @@ -271,7 +272,7 @@ async def test_user_without_godfather(client_factory): user = await create_user_with_roles( db, unique_email("noparent"), "password123", [ROLE_REGULAR] ) - + assert user.godfather_id is None assert user.godfather is None @@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory): # Admin Create Invite API Tests (Phase 2) # ============================================================================ + @pytest.mark.asyncio async def test_admin_can_create_invite(client_factory, admin_user, regular_user): """Admin can create an invite for a regular user.""" @@ -290,12 +292,12 @@ async def test_admin_can_create_invite(client_factory, admin_user, regular_user) select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + response = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) - + assert response.status_code == 200 data = response.json() assert data["godfather_id"] == godfather.id @@ -318,12 +320,12 @@ async def test_admin_can_create_invite_for_self(client_factory, admin_user): select(User).where(User.email == admin_user["email"]) ) admin = result.scalar_one() - + response = await client.post( "/api/admin/invites", json={"godfather_id": admin.id}, ) - + assert response.status_code == 200 data = response.json() assert data["godfather_id"] == admin.id @@ -338,7 +340,7 @@ async def test_regular_user_cannot_create_invite(client_factory, regular_user): "/api/admin/invites", json={"godfather_id": 1}, ) - + assert response.status_code == 403 @@ -350,7 +352,7 @@ async def test_unauthenticated_cannot_create_invite(client_factory): "/api/admin/invites", json={"godfather_id": 1}, ) - + assert response.status_code == 401 @@ -362,7 +364,7 @@ async def test_create_invite_invalid_godfather(client_factory, admin_user): "/api/admin/invites", json={"godfather_id": 99999}, ) - + assert response.status_code == 400 assert "not found" in response.json()["detail"].lower() @@ -376,39 +378,39 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + response = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) - + data = response.json() invite_id = data["id"] - + # Query from DB async with client_factory.get_db_session() as db: - result = await db.execute( - select(Invite).where(Invite.id == invite_id) - ) + result = await db.execute(select(Invite).where(Invite.id == invite_id)) invite = result.scalar_one() - + assert invite.identifier == data["identifier"] assert invite.godfather_id == godfather.id assert invite.status == InviteStatus.READY @pytest.mark.asyncio -async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user): +async def test_create_invite_retries_on_collision( + client_factory, admin_user, regular_user +): """Create invite retries with new identifier on collision.""" from unittest.mock import patch - + async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.get_db_session() as db: result = await db.execute( select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + # Create first invite normally response1 = await client.post( "/api/admin/invites", @@ -416,22 +418,25 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re ) assert response1.status_code == 200 identifier1 = response1.json()["identifier"] - + # Mock generator to first return the same identifier (collision), then a new one call_count = 0 + def mock_generator(): nonlocal call_count call_count += 1 if call_count == 1: return identifier1 # Will collide return f"unique-word-{call_count:02d}" # Won't collide - - with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator): + + with patch( + "routes.invites.generate_invite_identifier", side_effect=mock_generator + ): response2 = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) - + assert response2.status_code == 200 # Should have retried and gotten a new identifier assert response2.json()["identifier"] != identifier1 @@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re # Invite Check API Tests (Phase 3) # ============================================================================ + @pytest.mark.asyncio async def test_check_invite_valid(client_factory, admin_user, regular_user): """Check endpoint returns valid=True for READY invite.""" @@ -452,17 +458,17 @@ async def test_check_invite_valid(client_factory, admin_user, regular_user): select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Check invite (no auth needed) async with client_factory.create() as client: response = await client.get(f"/api/invites/{identifier}/check") - + assert response.status_code == 200 data = response.json() assert data["valid"] is True @@ -475,7 +481,7 @@ async def test_check_invite_not_found(client_factory): """Check endpoint returns valid=False for unknown invite.""" async with client_factory.create() as client: response = await client.get("/api/invites/fake-invite-99/check") - + assert response.status_code == 200 data = response.json() assert data["valid"] is False @@ -492,14 +498,14 @@ async def test_check_invite_invalid_format(client_factory): data = response.json() assert data["valid"] is False assert "format" in data["error"].lower() - + # Single digit number response = await client.get("/api/invites/word-word-1/check") assert response.status_code == 200 data = response.json() assert data["valid"] is False assert "format" in data["error"].lower() - + # Too many parts response = await client.get("/api/invites/word-word-word-00/check") assert response.status_code == 200 @@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory): @pytest.mark.asyncio -async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user): +async def test_check_invite_spent_returns_not_found( + client_factory, admin_user, regular_user +): """Check endpoint returns same error for spent invite as for non-existent (no info leakage).""" # Create invite async with client_factory.create(cookies=admin_user["cookies"]) as client: @@ -518,13 +526,13 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user, select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Use the invite async with client_factory.create() as client: await client.post( @@ -535,11 +543,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user, "invite_identifier": identifier, }, ) - + # Check spent invite - should return same error as non-existent async with client_factory.create() as client: response = await client.get(f"/api/invites/{identifier}/check") - + assert response.status_code == 200 data = response.json() assert data["valid"] is False @@ -547,10 +555,12 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user, @pytest.mark.asyncio -async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user): +async def test_check_invite_revoked_returns_not_found( + client_factory, admin_user, regular_user +): """Check endpoint returns same error for revoked invite as for non-existent (no info leakage).""" - from datetime import datetime, UTC - + from datetime import UTC, datetime + # Create invite async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.get_db_session() as db: @@ -558,14 +568,14 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] invite_id = create_resp.json()["id"] - + # Revoke the invite async with client_factory.get_db_session() as db: result = await db.execute(select(Invite).where(Invite.id == invite_id)) @@ -573,11 +583,11 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user invite.status = InviteStatus.REVOKED invite.revoked_at = datetime.now(UTC) await db.commit() - + # Check revoked invite - should return same error as non-existent async with client_factory.create() as client: response = await client.get(f"/api/invites/{identifier}/check") - + assert response.status_code == 200 data = response.json() assert data["valid"] is False @@ -594,17 +604,17 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Check with uppercase async with client_factory.create() as client: response = await client.get(f"/api/invites/{identifier.upper()}/check") - + assert response.status_code == 200 assert response.json()["valid"] is True @@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular # Register with Invite Tests (Phase 3) # ============================================================================ + @pytest.mark.asyncio async def test_register_with_valid_invite(client_factory, admin_user, regular_user): """Can register with valid invite code.""" @@ -624,13 +635,13 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us ) godfather = result.scalar_one() godfather_id = godfather.id - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather_id}, ) identifier = create_resp.json()["identifier"] - + # Register with invite new_email = unique_email("newuser") async with client_factory.create() as client: @@ -642,7 +653,7 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us "invite_identifier": identifier, }, ) - + assert response.status_code == 200 data = response.json() assert data["email"] == new_email @@ -659,7 +670,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, @@ -667,7 +678,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u invite_data = create_resp.json() identifier = invite_data["identifier"] invite_id = invite_data["id"] - + # Register async with client_factory.create() as client: await client.post( @@ -678,14 +689,12 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u "invite_identifier": identifier, }, ) - + # Check invite status async with client_factory.get_db_session() as db: - result = await db.execute( - select(Invite).where(Invite.id == invite_id) - ) + result = await db.execute(select(Invite).where(Invite.id == invite_id)) invite = result.scalar_one() - + assert invite.status == InviteStatus.SPENT assert invite.used_by_id is not None assert invite.spent_at is not None @@ -702,13 +711,13 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user) ) godfather = result.scalar_one() godfather_id = godfather.id - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather_id}, ) identifier = create_resp.json()["identifier"] - + # Register new_email = unique_email("godchildtest") async with client_factory.create() as client: @@ -720,14 +729,12 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user) "invite_identifier": identifier, }, ) - + # Check user's godfather async with client_factory.get_db_session() as db: - result = await db.execute( - select(User).where(User.email == new_email) - ) + result = await db.execute(select(User).where(User.email == new_email)) new_user = result.scalar_one() - + assert new_user.godfather_id == godfather_id @@ -743,7 +750,7 @@ async def test_register_with_invalid_invite(client_factory): "invite_identifier": "fake-invite-99", }, ) - + assert response.status_code == 400 assert "invalid" in response.json()["detail"].lower() @@ -758,13 +765,13 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # First registration async with client_factory.create() as client: await client.post( @@ -775,7 +782,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us "invite_identifier": identifier, }, ) - + # Second registration with same invite async with client_factory.create() as client: response = await client.post( @@ -786,7 +793,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us "invite_identifier": identifier, }, ) - + assert response.status_code == 400 assert "invalid invite code" in response.json()["detail"].lower() @@ -794,8 +801,8 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us @pytest.mark.asyncio async def test_register_with_revoked_invite(client_factory, admin_user, regular_user): """Cannot register with revoked invite.""" - from datetime import datetime, UTC - + from datetime import UTC, datetime + # Create invite async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.get_db_session() as db: @@ -803,7 +810,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_ select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, @@ -811,17 +818,15 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_ invite_data = create_resp.json() identifier = invite_data["identifier"] invite_id = invite_data["id"] - + # Revoke invite directly in DB async with client_factory.get_db_session() as db: - result = await db.execute( - select(Invite).where(Invite.id == invite_id) - ) + result = await db.execute(select(Invite).where(Invite.id == invite_id)) invite = result.scalar_one() invite.status = InviteStatus.REVOKED invite.revoked_at = datetime.now(UTC) await db.commit() - + # Try to register async with client_factory.create() as client: response = await client.post( @@ -832,7 +837,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_ "invite_identifier": identifier, }, ) - + assert response.status_code == 400 assert "invalid invite code" in response.json()["detail"].lower() @@ -847,13 +852,13 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Try to register with existing email async with client_factory.create() as client: response = await client.post( @@ -864,7 +869,7 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user "invite_identifier": identifier, }, ) - + assert response.status_code == 400 assert "already registered" in response.json()["detail"].lower() @@ -879,13 +884,13 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Register async with client_factory.create() as client: response = await client.post( @@ -896,7 +901,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use "invite_identifier": identifier, }, ) - + assert "auth_token" in response.cookies @@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use # User Invites API Tests (Phase 4) # ============================================================================ + @pytest.mark.asyncio async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user): """Regular user can list their own invites.""" @@ -914,14 +920,14 @@ async def test_regular_user_can_list_invites(client_factory, admin_user, regular select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) - + # List invites as regular user async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/invites") - + assert response.status_code == 200 invites = response.json() assert len(invites) == 2 @@ -935,13 +941,15 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user """User with no invites gets empty list.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/invites") - + assert response.status_code == 200 assert response.json() == [] @pytest.mark.asyncio -async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user): +async def test_spent_invite_shows_used_by_email( + client_factory, admin_user, regular_user +): """Spent invite shows who used it.""" # Create invite for regular user async with client_factory.create(cookies=admin_user["cookies"]) as client: @@ -950,13 +958,13 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Use the invite invitee_email = unique_email("invitee") async with client_factory.create() as client: @@ -968,11 +976,11 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu "invite_identifier": identifier, }, ) - + # Check that regular user sees the invitee email async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/invites") - + assert response.status_code == 200 invites = response.json() assert len(invites) == 1 @@ -985,7 +993,7 @@ async def test_admin_cannot_list_own_invites(client_factory, admin_user): """Admin without VIEW_OWN_INVITES permission gets 403.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/invites") - + assert response.status_code == 403 @@ -994,7 +1002,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory): """Unauthenticated user gets 401.""" async with client_factory.create() as client: response = await client.get("/api/invites") - + assert response.status_code == 401 @@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory): # Admin Invite Management Tests (Phase 5) # ============================================================================ + @pytest.mark.asyncio async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user): """Admin can list all invites.""" @@ -1012,13 +1021,13 @@ async def test_admin_can_list_all_invites(client_factory, admin_user, regular_us select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) - + # List all response = await client.get("/api/admin/invites") - + assert response.status_code == 200 data = response.json() assert data["total"] >= 2 @@ -1034,14 +1043,14 @@ async def test_admin_list_pagination(client_factory, admin_user, regular_user): select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + # Create 5 invites for _ in range(5): await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) - + # Get page 1 with 2 per page response = await client.get("/api/admin/invites?page=1&per_page=2") - + assert response.status_code == 200 data = response.json() assert len(data["records"]) == 2 @@ -1058,13 +1067,13 @@ async def test_admin_filter_by_status(client_factory, admin_user, regular_user): select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + # Create an invite await client.post("/api/admin/invites", json={"godfather_id": godfather.id}) - + # Filter by ready response = await client.get("/api/admin/invites?status=ready") - + assert response.status_code == 200 data = response.json() for record in data["records"]: @@ -1080,17 +1089,17 @@ async def test_admin_can_revoke_invite(client_factory, admin_user, regular_user) select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + # Create invite create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) invite_id = create_resp.json()["id"] - + # Revoke it response = await client.post(f"/api/admin/invites/{invite_id}/revoke") - + assert response.status_code == 200 data = response.json() assert data["status"] == "revoked" @@ -1106,14 +1115,14 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + # Create invite create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) invite_data = create_resp.json() - + # Use the invite async with client_factory.create() as client: await client.post( @@ -1124,11 +1133,11 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us "invite_identifier": invite_data["identifier"], }, ) - + # Try to revoke async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.post(f"/api/admin/invites/{invite_data['id']}/revoke") - + assert response.status_code == 400 assert "only ready" in response.json()["detail"].lower() @@ -1138,7 +1147,7 @@ async def test_revoke_nonexistent_invite(client_factory, admin_user): """Revoking non-existent invite returns 404.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.post("/api/admin/invites/99999/revoke") - + assert response.status_code == 404 @@ -1149,8 +1158,7 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_ # List response = await client.get("/api/admin/invites") assert response.status_code == 403 - + # Revoke response = await client.post("/api/admin/invites/1/revoke") assert response.status_code == 403 - diff --git a/backend/tests/test_permissions.py b/backend/tests/test_permissions.py index efcfa67..af1fc1c 100644 --- a/backend/tests/test_permissions.py +++ b/backend/tests/test_permissions.py @@ -7,15 +7,16 @@ These tests verify that: 3. Unauthenticated users are denied access (401) 4. The permission system cannot be bypassed """ + import pytest from models import Permission - # ============================================================================= # Role Assignment Tests # ============================================================================= + class TestRoleAssignment: """Test that roles are properly assigned and returned.""" @@ -23,7 +24,7 @@ class TestRoleAssignment: async def test_regular_user_has_correct_roles(self, client_factory, regular_user): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/auth/me") - + assert response.status_code == 200 data = response.json() assert "regular" in data["roles"] @@ -33,25 +34,27 @@ class TestRoleAssignment: async def test_admin_user_has_correct_roles(self, client_factory, admin_user): async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/auth/me") - + assert response.status_code == 200 data = response.json() assert "admin" in data["roles"] assert "regular" not in data["roles"] @pytest.mark.asyncio - async def test_regular_user_has_correct_permissions(self, client_factory, regular_user): + async def test_regular_user_has_correct_permissions( + self, client_factory, regular_user + ): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/auth/me") - + data = response.json() permissions = data["permissions"] - + # Should have counter and sum permissions assert Permission.VIEW_COUNTER.value in permissions assert Permission.INCREMENT_COUNTER.value in permissions assert Permission.USE_SUM.value in permissions - + # Should NOT have audit permission assert Permission.VIEW_AUDIT.value not in permissions @@ -59,23 +62,25 @@ class TestRoleAssignment: async def test_admin_user_has_correct_permissions(self, client_factory, admin_user): async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/auth/me") - + data = response.json() permissions = data["permissions"] - + # Should have audit permission assert Permission.VIEW_AUDIT.value in permissions - + # Should NOT have counter/sum permissions assert Permission.VIEW_COUNTER.value not in permissions assert Permission.INCREMENT_COUNTER.value not in permissions assert Permission.USE_SUM.value not in permissions @pytest.mark.asyncio - async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles): + async def test_user_with_no_roles_has_no_permissions( + self, client_factory, user_no_roles + ): async with client_factory.create(cookies=user_no_roles["cookies"]) as client: response = await client.get("/api/auth/me") - + data = response.json() assert data["roles"] == [] assert data["permissions"] == [] @@ -85,6 +90,7 @@ class TestRoleAssignment: # Counter Endpoint Access Tests # ============================================================================= + class TestCounterAccess: """Test access control for counter endpoints.""" @@ -92,15 +98,17 @@ class TestCounterAccess: async def test_regular_user_can_view_counter(self, client_factory, regular_user): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/counter") - + assert response.status_code == 200 assert "value" in response.json() @pytest.mark.asyncio - async def test_regular_user_can_increment_counter(self, client_factory, regular_user): + async def test_regular_user_can_increment_counter( + self, client_factory, regular_user + ): async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.post("/api/counter/increment") - + assert response.status_code == 200 assert "value" in response.json() @@ -109,7 +117,7 @@ class TestCounterAccess: """Admin users should be forbidden from counter endpoints.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/counter") - + assert response.status_code == 403 assert "permission" in response.json()["detail"].lower() @@ -118,15 +126,17 @@ class TestCounterAccess: """Admin users should be forbidden from incrementing counter.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.post("/api/counter/increment") - + assert response.status_code == 403 @pytest.mark.asyncio - async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles): + async def test_user_without_roles_cannot_view_counter( + self, client_factory, user_no_roles + ): """Users with no roles should be forbidden.""" async with client_factory.create(cookies=user_no_roles["cookies"]) as client: response = await client.get("/api/counter") - + assert response.status_code == 403 @pytest.mark.asyncio @@ -146,6 +156,7 @@ class TestCounterAccess: # Sum Endpoint Access Tests # ============================================================================= + class TestSumAccess: """Test access control for sum endpoint.""" @@ -156,7 +167,7 @@ class TestSumAccess: "/api/sum", json={"a": 5, "b": 3}, ) - + assert response.status_code == 200 data = response.json() assert data["result"] == 8 @@ -169,17 +180,19 @@ class TestSumAccess: "/api/sum", json={"a": 5, "b": 3}, ) - + assert response.status_code == 403 @pytest.mark.asyncio - async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles): + async def test_user_without_roles_cannot_use_sum( + self, client_factory, user_no_roles + ): async with client_factory.create(cookies=user_no_roles["cookies"]) as client: response = await client.post( "/api/sum", json={"a": 5, "b": 3}, ) - + assert response.status_code == 403 @pytest.mark.asyncio @@ -195,6 +208,7 @@ class TestSumAccess: # Audit Endpoint Access Tests # ============================================================================= + class TestAuditAccess: """Test access control for audit endpoints.""" @@ -202,7 +216,7 @@ class TestAuditAccess: async def test_admin_can_view_counter_audit(self, client_factory, admin_user): async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/audit/counter") - + assert response.status_code == 200 data = response.json() assert "records" in data @@ -212,34 +226,40 @@ class TestAuditAccess: async def test_admin_can_view_sum_audit(self, client_factory, admin_user): async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/audit/sum") - + assert response.status_code == 200 data = response.json() assert "records" in data assert "total" in data @pytest.mark.asyncio - async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user): + async def test_regular_user_cannot_view_counter_audit( + self, client_factory, regular_user + ): """Regular users should be forbidden from audit endpoints.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/audit/counter") - + assert response.status_code == 403 assert "permission" in response.json()["detail"].lower() @pytest.mark.asyncio - async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user): + async def test_regular_user_cannot_view_sum_audit( + self, client_factory, regular_user + ): """Regular users should be forbidden from audit endpoints.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/audit/sum") - + assert response.status_code == 403 @pytest.mark.asyncio - async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles): + async def test_user_without_roles_cannot_view_audit( + self, client_factory, user_no_roles + ): async with client_factory.create(cookies=user_no_roles["cookies"]) as client: response = await client.get("/api/audit/counter") - + assert response.status_code == 403 @pytest.mark.asyncio @@ -257,6 +277,7 @@ class TestAuditAccess: # Offensive Security Tests - Bypass Attempts # ============================================================================= + class TestSecurityBypassAttempts: """ Offensive tests that attempt to bypass security controls. @@ -264,7 +285,9 @@ class TestSecurityBypassAttempts: """ @pytest.mark.asyncio - async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user): + async def test_cannot_access_audit_with_forged_role_claim( + self, client_factory, regular_user + ): """ Attempt to access audit by somehow claiming admin role. The server should verify roles from DB, not trust client claims. @@ -272,7 +295,7 @@ class TestSecurityBypassAttempts: # Regular user tries to access audit endpoint async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/audit/counter") - + # Should be denied regardless of any manipulation attempts assert response.status_code == 403 @@ -280,23 +303,27 @@ class TestSecurityBypassAttempts: async def test_cannot_access_counter_with_expired_session(self, client_factory): """Test that invalid/expired tokens are rejected.""" fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid" - + async with client_factory.create(cookies={"auth_token": fake_token}) as client: response = await client.get("/api/counter") - + assert response.status_code == 401 @pytest.mark.asyncio - async def test_cannot_access_with_tampered_token(self, client_factory, regular_user): + async def test_cannot_access_with_tampered_token( + self, client_factory, regular_user + ): """Test that tokens signed with wrong key are rejected.""" # Take a valid token and modify it original_token = regular_user["cookies"].get("auth_token", "") if original_token: tampered_token = original_token[:-5] + "XXXXX" - - async with client_factory.create(cookies={"auth_token": tampered_token}) as client: + + async with client_factory.create( + cookies={"auth_token": tampered_token} + ) as client: response = await client.get("/api/counter") - + assert response.status_code == 401 @pytest.mark.asyncio @@ -305,14 +332,16 @@ class TestSecurityBypassAttempts: Test that new registrations cannot claim admin role. New users should only get 'regular' role by default. """ - from tests.helpers import unique_email, create_invite_for_godfather - from tests.conftest import create_user_with_roles from models import ROLE_REGULAR - + from tests.conftest import create_user_with_roles + from tests.helpers import create_invite_for_godfather, unique_email + async with client_factory.get_db_session() as db: - godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) + godfather = await create_user_with_roles( + db, unique_email("gf"), "pass123", [ROLE_REGULAR] + ) invite_code = await create_invite_for_godfather(db, godfather.id) - + response = await client_factory.post( "/api/auth/register", json={ @@ -321,18 +350,18 @@ class TestSecurityBypassAttempts: "invite_identifier": invite_code, }, ) - + assert response.status_code == 200 data = response.json() - + # Should only have regular role, not admin assert "admin" not in data["roles"] assert Permission.VIEW_AUDIT.value not in data["permissions"] - + # Try to access audit with this new user async with client_factory.create(cookies=dict(response.cookies)) as client: audit_response = await client.get("/api/audit/counter") - + assert audit_response.status_code == 403 @pytest.mark.asyncio @@ -341,33 +370,35 @@ class TestSecurityBypassAttempts: If a user is deleted, their token should no longer work. This tests that tokens are validated against current DB state. """ - from tests.helpers import unique_email from sqlalchemy import delete + from models import User - + from tests.helpers import unique_email + email = unique_email("deleted") - + # Create and login user async with client_factory.get_db_session() as db: from tests.conftest import create_user_with_roles + user = await create_user_with_roles(db, email, "password123", ["regular"]) user_id = user.id - + login_response = await client_factory.post( "/api/auth/login", json={"email": email, "password": "password123"}, ) cookies = dict(login_response.cookies) - + # Delete the user from DB async with client_factory.get_db_session() as db: await db.execute(delete(User).where(User.id == user_id)) await db.commit() - + # Try to use the old token async with client_factory.create(cookies=cookies) as client: response = await client.get("/api/auth/me") - + assert response.status_code == 401 @pytest.mark.asyncio @@ -376,42 +407,41 @@ class TestSecurityBypassAttempts: If a user's role is changed, the change should be reflected in subsequent requests (no stale permission cache). """ - from tests.helpers import unique_email from sqlalchemy import select - from models import User, Role - + + from models import Role, User + from tests.helpers import unique_email + email = unique_email("rolechange") - + # Create regular user async with client_factory.get_db_session() as db: from tests.conftest import create_user_with_roles + await create_user_with_roles(db, email, "password123", ["regular"]) - + login_response = await client_factory.post( "/api/auth/login", json={"email": email, "password": "password123"}, ) cookies = dict(login_response.cookies) - + # Verify can access counter but not audit async with client_factory.create(cookies=cookies) as client: assert (await client.get("/api/counter")).status_code == 200 assert (await client.get("/api/audit/counter")).status_code == 403 - + # Change user's role from regular to admin async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.email == email)) user = result.scalar_one() - + result = await db.execute(select(Role).where(Role.name == "admin")) admin_role = result.scalar_one() - - result = await db.execute(select(Role).where(Role.name == "regular")) - regular_role = result.scalar_one() - - user.roles = [admin_role] # Remove regular, add admin + + user.roles = [admin_role] # Replace roles with admin only await db.commit() - + # Now should have audit access but not counter access async with client_factory.create(cookies=cookies) as client: assert (await client.get("/api/audit/counter")).status_code == 200 @@ -422,6 +452,7 @@ class TestSecurityBypassAttempts: # Audit Record Tests # ============================================================================= + class TestAuditRecords: """Test that actions are properly recorded in audit logs.""" @@ -433,15 +464,15 @@ class TestAuditRecords: # Regular user increments counter async with client_factory.create(cookies=regular_user["cookies"]) as client: await client.post("/api/counter/increment") - + # Admin checks audit async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/audit/counter") - + assert response.status_code == 200 data = response.json() assert data["total"] >= 1 - + # Find record for our user records = data["records"] user_records = [r for r in records if r["user_email"] == regular_user["email"]] @@ -455,18 +486,18 @@ class TestAuditRecords: # Regular user uses sum async with client_factory.create(cookies=regular_user["cookies"]) as client: await client.post("/api/sum", json={"a": 10, "b": 20}) - + # Admin checks audit async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/audit/sum") - + assert response.status_code == 200 data = response.json() assert data["total"] >= 1 - + # Find record with our values records = data["records"] - matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30] + matching = [ + r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30 + ] assert len(matching) >= 1 - - diff --git a/backend/tests/test_profile.py b/backend/tests/test_profile.py index 3b69a17..5accec7 100644 --- a/backend/tests/test_profile.py +++ b/backend/tests/test_profile.py @@ -1,9 +1,9 @@ """Tests for user profile and contact details.""" -import pytest + from sqlalchemy import select -from models import User, ROLE_REGULAR from auth import get_password_hash +from models import User from tests.helpers import unique_email # Valid npub for testing (32 zero bytes encoded as bech32) @@ -16,7 +16,7 @@ class TestUserContactFields: async def test_contact_fields_default_to_none(self, client_factory): """New users should have all contact fields as None.""" email = unique_email("test") - + async with client_factory.get_db_session() as db: user = User( email=email, @@ -25,7 +25,7 @@ class TestUserContactFields: db.add(user) await db.commit() await db.refresh(user) - + assert user.contact_email is None assert user.telegram is None assert user.signal is None @@ -34,7 +34,7 @@ class TestUserContactFields: async def test_contact_fields_can_be_set(self, client_factory): """Contact fields can be set when creating a user.""" email = unique_email("test") - + async with client_factory.get_db_session() as db: user = User( email=email, @@ -47,7 +47,7 @@ class TestUserContactFields: db.add(user) await db.commit() await db.refresh(user) - + assert user.contact_email == "contact@example.com" assert user.telegram == "@alice" assert user.signal == "alice.42" @@ -56,7 +56,7 @@ class TestUserContactFields: async def test_contact_fields_persist_after_reload(self, client_factory): """Contact fields should persist in the database.""" email = unique_email("test") - + async with client_factory.get_db_session() as db: user = User( email=email, @@ -69,12 +69,12 @@ class TestUserContactFields: db.add(user) await db.commit() user_id = user.id - + # Reload from database in a new session async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.id == user_id)) loaded_user = result.scalar_one() - + assert loaded_user.contact_email == "contact@example.com" assert loaded_user.telegram == "@bob" assert loaded_user.signal == "bob.99" @@ -83,7 +83,7 @@ class TestUserContactFields: async def test_contact_fields_can_be_updated(self, client_factory): """Contact fields can be updated after user creation.""" email = unique_email("test") - + async with client_factory.get_db_session() as db: user = User( email=email, @@ -92,21 +92,21 @@ class TestUserContactFields: db.add(user) await db.commit() user_id = user.id - + # Update fields async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() - + user.contact_email = "new@example.com" user.telegram = "@updated" await db.commit() - + # Verify update persisted async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() - + assert user.contact_email == "new@example.com" assert user.telegram == "@updated" assert user.signal is None # Still None @@ -115,7 +115,7 @@ class TestUserContactFields: async def test_contact_fields_can_be_cleared(self, client_factory): """Contact fields can be set back to None.""" email = unique_email("test") - + async with client_factory.get_db_session() as db: user = User( email=email, @@ -126,21 +126,21 @@ class TestUserContactFields: db.add(user) await db.commit() user_id = user.id - + # Clear fields async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() - + user.contact_email = None user.telegram = None await db.commit() - + # Verify cleared async with client_factory.get_db_session() as db: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one() - + assert user.contact_email is None assert user.telegram is None @@ -152,7 +152,7 @@ class TestGetProfileEndpoint: """Regular user can fetch their profile.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/profile") - + assert response.status_code == 200 data = response.json() assert "contact_email" in data @@ -169,7 +169,7 @@ class TestGetProfileEndpoint: """Admin user gets 403 when trying to access profile.""" async with client_factory.create(cookies=admin_user["cookies"]) as client: response = await client.get("/api/profile") - + assert response.status_code == 403 assert "regular users" in response.json()["detail"].lower() @@ -177,7 +177,7 @@ class TestGetProfileEndpoint: """Unauthenticated user gets 401.""" async with client_factory.create() as client: response = await client.get("/api/profile") - + assert response.status_code == 401 async def test_profile_returns_existing_data(self, client_factory, regular_user): @@ -191,11 +191,11 @@ class TestGetProfileEndpoint: user.contact_email = "contact@test.com" user.telegram = "@testuser" await db.commit() - + # Fetch via API async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/profile") - + assert response.status_code == 200 data = response.json() assert data["contact_email"] == "contact@test.com" @@ -219,7 +219,7 @@ class TestUpdateProfileEndpoint: "nostr_npub": VALID_NPUB, }, ) - + assert response.status_code == 200 data = response.json() assert data["contact_email"] == "new@example.com" @@ -234,10 +234,10 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"telegram": "@persisted"}, ) - + # Fetch again to verify response = await client.get("/api/profile") - + assert response.status_code == 200 assert response.json()["telegram"] == "@persisted" @@ -248,7 +248,7 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"telegram": "@admin"}, ) - + assert response.status_code == 403 async def test_unauthenticated_user_gets_401(self, client_factory): @@ -258,7 +258,7 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"telegram": "@test"}, ) - + assert response.status_code == 401 async def test_can_clear_fields(self, client_factory, regular_user): @@ -272,7 +272,7 @@ class TestUpdateProfileEndpoint: "telegram": "@test", }, ) - + # Then clear them response = await client.put( "/api/profile", @@ -283,7 +283,7 @@ class TestUpdateProfileEndpoint: "nostr_npub": None, }, ) - + assert response.status_code == 200 data = response.json() assert data["contact_email"] is None @@ -296,7 +296,7 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"contact_email": "not-an-email"}, ) - + assert response.status_code == 422 data = response.json() assert "field_errors" in data["detail"] @@ -309,7 +309,7 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"telegram": "missing_at_sign"}, ) - + assert response.status_code == 422 data = response.json() assert "field_errors" in data["detail"] @@ -322,13 +322,15 @@ class TestUpdateProfileEndpoint: "/api/profile", json={"nostr_npub": "npub1invalid"}, ) - + assert response.status_code == 422 data = response.json() assert "field_errors" in data["detail"] assert "nostr_npub" in data["detail"]["field_errors"] - async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user): + async def test_multiple_invalid_fields_returns_all_errors( + self, client_factory, regular_user + ): """Multiple invalid fields return all errors.""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.put( @@ -338,13 +340,15 @@ class TestUpdateProfileEndpoint: "telegram": "no-at", }, ) - + assert response.status_code == 422 data = response.json() assert "contact_email" in data["detail"]["field_errors"] assert "telegram" in data["detail"]["field_errors"] - async def test_partial_update_preserves_other_fields(self, client_factory, regular_user): + async def test_partial_update_preserves_other_fields( + self, client_factory, regular_user + ): """Updating one field doesn't affect others (they get set to the request values).""" async with client_factory.create(cookies=regular_user["cookies"]) as client: # Set initial values @@ -355,7 +359,7 @@ class TestUpdateProfileEndpoint: "telegram": "@initial", }, ) - + # Update only telegram, but note: PUT replaces all fields # So we need to include all fields we want to keep response = await client.put( @@ -365,7 +369,7 @@ class TestUpdateProfileEndpoint: "telegram": "@updated", }, ) - + assert response.status_code == 200 data = response.json() assert data["contact_email"] == "initial@example.com" @@ -386,10 +390,10 @@ class TestProfilePrivacy: "telegram": "@secret", }, ) - + # Check /api/auth/me doesn't expose it response = await client.get("/api/auth/me") - + assert response.status_code == 200 data = response.json() # These fields should NOT be in the response @@ -402,12 +406,15 @@ class TestProfilePrivacy: class TestProfileGodfather: """Tests for godfather information in profile.""" - async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user): + async def test_profile_shows_godfather_email( + self, client_factory, admin_user, regular_user + ): """Profile shows godfather email for users who signed up with invite.""" - from tests.helpers import unique_email from sqlalchemy import select + from models import User - + from tests.helpers import unique_email + # Create invite async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.get_db_session() as db: @@ -415,13 +422,13 @@ class TestProfileGodfather: select(User).where(User.email == regular_user["email"]) ) godfather = result.scalar_one() - + create_resp = await client.post( "/api/admin/invites", json={"godfather_id": godfather.id}, ) identifier = create_resp.json()["identifier"] - + # Register new user with invite new_email = unique_email("godchild") async with client_factory.create() as client: @@ -434,20 +441,22 @@ class TestProfileGodfather: }, ) new_user_cookies = dict(reg_resp.cookies) - + # Check profile shows godfather async with client_factory.create(cookies=new_user_cookies) as client: response = await client.get("/api/profile") - + assert response.status_code == 200 data = response.json() assert data["godfather_email"] == regular_user["email"] - async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user): + async def test_profile_godfather_null_for_seeded_users( + self, client_factory, regular_user + ): """Profile shows null godfather for users without one (e.g., seeded users).""" async with client_factory.create(cookies=regular_user["cookies"]) as client: response = await client.get("/api/profile") - + assert response.status_code == 200 data = response.json() assert data["godfather_email"] is None diff --git a/backend/tests/test_validation.py b/backend/tests/test_validation.py index 0f64ab1..742c3f3 100644 --- a/backend/tests/test_validation.py +++ b/backend/tests/test_validation.py @@ -1,12 +1,11 @@ """Tests for profile field validation.""" -import pytest from validation import ( validate_contact_email, - validate_telegram, - validate_signal, validate_nostr_npub, validate_profile_fields, + validate_signal, + validate_telegram, ) @@ -140,13 +139,17 @@ class TestValidateNostrNpub: assert validate_nostr_npub(self.VALID_NPUB) is None def test_wrong_prefix(self): - result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz") + result = validate_nostr_npub( + "nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz" + ) assert result is not None assert "npub" in result.lower() def test_invalid_checksum(self): # Change last character to break checksum - result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd") + result = validate_nostr_npub( + "npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd" + ) assert result is not None assert "checksum" in result.lower() @@ -155,7 +158,9 @@ class TestValidateNostrNpub: assert result is not None def test_not_starting_with_npub1(self): - result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc") + result = validate_nostr_npub( + "npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc" + ) assert result is not None assert "npub1" in result @@ -206,4 +211,3 @@ class TestValidateProfileFields: nostr_npub="", ) assert errors == {} - diff --git a/backend/validate_constants.py b/backend/validate_constants.py index e485399..1ce7f0a 100644 --- a/backend/validate_constants.py +++ b/backend/validate_constants.py @@ -1,8 +1,9 @@ """Validate shared constants match backend definitions.""" + import json from pathlib import Path -from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus +from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus def validate_shared_constants() -> None: @@ -11,13 +12,13 @@ def validate_shared_constants() -> None: Raises ValueError if there's a mismatch. """ constants_path = Path(__file__).parent.parent / "shared" / "constants.json" - + if not constants_path.exists(): raise ValueError(f"Shared constants file not found: {constants_path}") - + with open(constants_path) as f: constants = json.load(f) - + # Validate roles expected_roles = {"ADMIN": ROLE_ADMIN, "REGULAR": ROLE_REGULAR} if constants.get("roles") != expected_roles: @@ -25,39 +26,46 @@ def validate_shared_constants() -> None: f"Role mismatch in shared/constants.json. " f"Expected: {expected_roles}, Got: {constants.get('roles')}" ) - + # Validate invite statuses expected_invite_statuses = {s.name: s.value for s in InviteStatus} if constants.get("inviteStatuses") != expected_invite_statuses: + got = constants.get("inviteStatuses") raise ValueError( - f"Invite status mismatch in shared/constants.json. " - f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}" + f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}" ) - + # Validate appointment statuses expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus} if constants.get("appointmentStatuses") != expected_appointment_statuses: + got = constants.get("appointmentStatuses") raise ValueError( - f"Appointment status mismatch in shared/constants.json. " - f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}" + f"Appointment status mismatch. " + f"Expected: {expected_appointment_statuses}, Got: {got}" ) - + # Validate booking constants exist with required fields booking = constants.get("booking", {}) - required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"] + required_booking_fields = [ + "slotDurationMinutes", + "maxAdvanceDays", + "minAdvanceDays", + "noteMaxLength", + ] for field in required_booking_fields: if field not in booking: - raise ValueError(f"Missing booking constant '{field}' in shared/constants.json") - + raise ValueError(f"Missing booking constant '{field}' in constants.json") + # Validate validation rules exist (structure check only) validation = constants.get("validation", {}) required_fields = ["telegram", "signal", "nostrNpub"] for field in required_fields: if field not in validation: - raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json") + raise ValueError( + f"Missing validation rules for '{field}' in constants.json" + ) if __name__ == "__main__": validate_shared_constants() print("✓ Shared constants are valid") - diff --git a/backend/validation.py b/backend/validation.py index 51a4496..d933709 100644 --- a/backend/validation.py +++ b/backend/validation.py @@ -1,9 +1,10 @@ """Validation utilities for user profile fields.""" + import json from pathlib import Path -from email_validator import validate_email, EmailNotValidError from bech32 import bech32_decode +from email_validator import EmailNotValidError, validate_email # Load validation rules from shared constants _constants_path = Path(__file__).parent.parent / "shared" / "constants.json" @@ -18,13 +19,13 @@ NPUB_RULES = _constants["validation"]["nostrNpub"] def validate_contact_email(value: str | None) -> str | None: """ Validate contact email format. - + Returns None if valid, error message if invalid. Empty/None values are valid (field is optional). """ if not value: return None - + try: validate_email(value, check_deliverability=False) return None @@ -35,84 +36,84 @@ def validate_contact_email(value: str | None) -> str | None: def validate_telegram(value: str | None) -> str | None: """ Validate Telegram handle. - + Must start with @ if provided, with characters after @ within max length. Returns None if valid, error message if invalid. Empty/None values are valid (field is optional). """ if not value: return None - + prefix = TELEGRAM_RULES["mustStartWith"] max_len = TELEGRAM_RULES["maxLengthAfterAt"] - + if not value.startswith(prefix): return f"Telegram handle must start with {prefix}" - + handle = value[1:] if not handle: return f"Telegram handle must have at least one character after {prefix}" - + if len(handle) > max_len: return f"Telegram handle must be at most {max_len} characters (after {prefix})" - + return None def validate_signal(value: str | None) -> str | None: """ Validate Signal username. - + Any non-empty string within max length is valid. Returns None if valid, error message if invalid. Empty/None values are valid (field is optional). """ if not value: return None - + max_len = SIGNAL_RULES["maxLength"] - + # Signal usernames are fairly permissive, just check it's not empty if len(value.strip()) == 0: return "Signal username cannot be empty" - + if len(value) > max_len: return f"Signal username must be at most {max_len} characters" - + return None def validate_nostr_npub(value: str | None) -> str | None: """ Validate Nostr npub (public key in bech32 format). - + Must be valid bech32 with 'npub' prefix. Returns None if valid, error message if invalid. Empty/None values are valid (field is optional). """ if not value: return None - + prefix = NPUB_RULES["prefix"] expected_words = NPUB_RULES["bech32Words"] - + if not value.startswith(prefix): return f"Nostr npub must start with '{prefix}'" - + # Decode bech32 to validate checksum hrp, data = bech32_decode(value) - + if hrp is None or data is None: return "Invalid Nostr npub: bech32 checksum failed" - + if hrp != "npub": return "Nostr npub must have 'npub' prefix" - + # npub should decode to 32 bytes (256 bits) for a public key # In bech32, each character encodes 5 bits, so 32 bytes = 52 characters of data if len(data) != expected_words: return "Invalid Nostr npub: incorrect length" - + return None @@ -124,23 +125,22 @@ def validate_profile_fields( ) -> dict[str, str]: """ Validate all profile fields at once. - + Returns a dict of field_name -> error_message for any invalid fields. Empty dict means all fields are valid. """ errors: dict[str, str] = {} - + if err := validate_contact_email(contact_email): errors["contact_email"] = err - + if err := validate_telegram(telegram): errors["telegram"] = err - + if err := validate_signal(signal): errors["signal"] = err - + if err := validate_nostr_npub(nostr_npub): errors["nostr_npub"] = err - - return errors + return errors