Add ruff linter/formatter for Python

- Add ruff as dev dependency
- Configure ruff in pyproject.toml with strict 88-char line limit
- Ignore B008 (FastAPI Depends pattern is standard)
- Allow longer lines in tests for readability
- Fix all lint issues in source files
- Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
counterweight 2025-12-21 21:54:26 +01:00
parent 69bc8413e0
commit 6c218130e9
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
31 changed files with 1234 additions and 876 deletions

View file

@ -1,4 +1,4 @@
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants .PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend
-include .env -include .env
export export
@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone
check-constants: check-constants:
@cd backend && uv run python validate_constants.py @cd backend && uv run python validate_constants.py
lint-backend:
cd backend && uv run ruff check .
format-backend:
cd backend && uv run ruff format .
fix-backend:
cd backend && uv run ruff check --fix . && uv run ruff format .

View file

@ -1,5 +1,5 @@
import os import os
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
import bcrypt import bcrypt
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db from database import get_db
from models import User, Permission from models import Permission, User
from schemas import UserResponse from schemas import UserResponse
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
).decode("utf-8") ).decode("utf-8")
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str: def create_access_token(
data: dict[str, str],
expires_delta: timedelta | None = None,
) -> str:
to_encode: dict[str, str | datetime] = dict(data) to_encode: dict[str, str | datetime] = dict(data)
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + delta
to_encode["exp"] = expire to_encode["exp"] = expire
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded return encoded
@ -72,7 +76,7 @@ async def get_current_user(
raise credentials_exception raise credentials_exception
user_id = int(user_id_str) user_id = int(user_id_str)
except (JWTError, ValueError): except (JWTError, ValueError):
raise credentials_exception raise credentials_exception from None
result = await db.execute(select(User).where(User.id == user_id)) result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -83,13 +87,16 @@ async def get_current_user(
def require_permission(*required_permissions: Permission): def require_permission(*required_permissions: Permission):
""" """
Dependency factory that checks if user has ALL of the required permissions. Dependency factory that checks if user has ALL required permissions.
Usage: Usage:
@app.get("/api/counter") @app.get("/api/counter")
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))): async def get_counter(
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
):
... ...
""" """
async def permission_checker( async def permission_checker(
request: Request, request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
@ -99,11 +106,13 @@ def require_permission(*required_permissions: Permission):
missing = [p for p in required_permissions if p not in user_permissions] missing = [p for p in required_permissions if p not in user_permissions]
if missing: if missing:
missing_str = ", ".join(p.value for p in missing)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}", detail=f"Missing required permissions: {missing_str}",
) )
return user return user
return permission_checker return permission_checker

View file

@ -1,8 +1,11 @@
import os import os
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret") DATABASE_URL = os.getenv(
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
)
engine = create_async_engine(DATABASE_URL) engine = create_async_engine(DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False) async_session = async_sessionmaker(engine, expire_on_commit=False)
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
async def get_db(): async def get_db():
async with async_session() as session: async with async_session() as session:
yield session yield session

View file

@ -1,4 +1,5 @@
"""Utilities for invite code generation and validation.""" """Utilities for invite code generation and validation."""
import random import random
from pathlib import Path from pathlib import Path
@ -53,8 +54,4 @@ def is_valid_identifier_format(identifier: str) -> bool:
return False return False
# Check number is two digits # Check number is two digits
if len(number) != 2 or not number.isdigit(): return len(number) == 2 and number.isdigit()
return False
return True

View file

@ -1,19 +1,20 @@
"""FastAPI application entry point.""" """FastAPI application entry point."""
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from database import engine, Base from database import Base, engine
from routes import sum as sum_routes
from routes import counter as counter_routes
from routes import audit as audit_routes from routes import audit as audit_routes
from routes import profile as profile_routes
from routes import invites as invites_routes
from routes import auth as auth_routes from routes import auth as auth_routes
from routes import meta as meta_routes
from routes import availability as availability_routes from routes import availability as availability_routes
from routes import booking as booking_routes from routes import booking as booking_routes
from routes import counter as counter_routes
from routes import invites as invites_routes
from routes import meta as meta_routes
from routes import profile as profile_routes
from routes import sum as sum_routes
from validate_constants import validate_shared_constants from validate_constants import validate_shared_constants

View file

@ -1,9 +1,24 @@
from datetime import datetime, date, time, timezone from datetime import UTC, date, datetime, time
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import TypedDict from typing import TypedDict
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy import (
Column,
Date,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Table,
Time,
UniqueConstraint,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
from database import Base from database import Base
@ -14,6 +29,7 @@ class RoleConfig(TypedDict):
class Permission(str, PyEnum): class Permission(str, PyEnum):
"""All available permissions in the system.""" """All available permissions in the system."""
# Counter permissions # Counter permissions
VIEW_COUNTER = "view_counter" VIEW_COUNTER = "view_counter"
INCREMENT_COUNTER = "increment_counter" INCREMENT_COUNTER = "increment_counter"
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
class InviteStatus(str, PyEnum): class InviteStatus(str, PyEnum):
"""Status of an invite.""" """Status of an invite."""
READY = "ready" READY = "ready"
SPENT = "spent" SPENT = "spent"
REVOKED = "revoked" REVOKED = "revoked"
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
class AppointmentStatus(str, PyEnum): class AppointmentStatus(str, PyEnum):
"""Status of an appointment.""" """Status of an appointment."""
BOOKED = "booked" BOOKED = "booked"
CANCELLED_BY_USER = "cancelled_by_user" CANCELLED_BY_USER = "cancelled_by_user"
CANCELLED_BY_ADMIN = "cancelled_by_admin" CANCELLED_BY_ADMIN = "cancelled_by_admin"
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
# Role definitions with their permissions # Role definitions with their permissions
ROLE_DEFINITIONS: dict[str, RoleConfig] = { ROLE_DEFINITIONS: dict[str, RoleConfig] = {
ROLE_ADMIN: { ROLE_ADMIN: {
"description": "Administrator with audit, invite, and appointment management access", "description": "Administrator with audit/invite/appointment access",
"permissions": [ "permissions": [
Permission.VIEW_AUDIT, Permission.VIEW_AUDIT,
Permission.MANAGE_INVITES, Permission.MANAGE_INVITES,
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
role_permissions = Table( role_permissions = Table(
"role_permissions", "role_permissions",
Base.metadata, Base.metadata,
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
Column("permission", Enum(Permission), primary_key=True), Column("permission", Enum(Permission), primary_key=True),
) )
@ -97,8 +120,18 @@ role_permissions = Table(
user_roles = Table( user_roles = Table(
"user_roles", "user_roles",
Base.metadata, Base.metadata,
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True), Column(
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True), "user_id",
Integer,
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
),
Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
) )
@ -118,23 +151,34 @@ class Role(Base):
async def get_permissions(self, db: AsyncSession) -> set[Permission]: async def get_permissions(self, db: AsyncSession) -> set[Permission]:
"""Get all permissions for this role.""" """Get all permissions for this role."""
result = await db.execute( query = select(role_permissions.c.permission).where(
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id) role_permissions.c.role_id == self.id
) )
result = await db.execute(query)
return {row[0] for row in result.fetchall()} return {row[0] for row in result.fetchall()}
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None: async def set_permissions(
self, db: AsyncSession, permissions: list[Permission]
) -> None:
"""Set all permissions for this role (replaces existing).""" """Set all permissions for this role (replaces existing)."""
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id)) delete_query = role_permissions.delete().where(
role_permissions.c.role_id == self.id
)
await db.execute(delete_query)
for perm in permissions: for perm in permissions:
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm)) insert_query = role_permissions.insert().values(
role_id=self.id, permission=perm
)
await db.execute(insert_query)
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(
String(255), unique=True, nullable=False, index=True
)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
# Contact details (all optional) # Contact details (all optional)
@ -192,12 +236,14 @@ class SumRecord(Base):
__tablename__ = "sum_records" __tablename__ = "sum_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
a: Mapped[float] = mapped_column(Float, nullable=False) a: Mapped[float] = mapped_column(Float, nullable=False)
b: Mapped[float] = mapped_column(Float, nullable=False) b: Mapped[float] = mapped_column(Float, nullable=False)
result: Mapped[float] = mapped_column(Float, nullable=False) result: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
@ -205,11 +251,13 @@ class CounterRecord(Base):
__tablename__ = "counter_records" __tablename__ = "counter_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
value_before: Mapped[int] = mapped_column(Integer, nullable=False) value_before: Mapped[int] = mapped_column(Integer, nullable=False)
value_after: Mapped[int] = mapped_column(Integer, nullable=False) value_after: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
@ -217,7 +265,9 @@ class Invite(Base):
__tablename__ = "invites" __tablename__ = "invites"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) identifier: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False, index=True
)
status: Mapped[InviteStatus] = mapped_column( status: Mapped[InviteStatus] = mapped_column(
Enum(InviteStatus), nullable=False, default=InviteStatus.READY Enum(InviteStatus), nullable=False, default=InviteStatus.READY
) )
@ -244,14 +294,19 @@ class Invite(Base):
# Timestamps # Timestamps
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
spent_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
revoked_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
) )
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
class Availability(Base): class Availability(Base):
"""Admin availability slots for booking.""" """Admin availability slots for booking."""
__tablename__ = "availability" __tablename__ = "availability"
__table_args__ = ( __table_args__ = (
UniqueConstraint("date", "start_time", name="uq_availability_date_start"), UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
@ -262,34 +317,37 @@ class Availability(Base):
start_time: Mapped[time] = mapped_column(Time, nullable=False) start_time: Mapped[time] = mapped_column(Time, nullable=False)
end_time: Mapped[time] = mapped_column(Time, nullable=False) end_time: Mapped[time] = mapped_column(Time, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc), default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(timezone.utc) onupdate=lambda: datetime.now(UTC),
) )
class Appointment(Base): class Appointment(Base):
"""User appointment bookings.""" """User appointment bookings."""
__tablename__ = "appointments" __tablename__ = "appointments"
__table_args__ = ( __table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True Integer, ForeignKey("users.id"), nullable=False, index=True
) )
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined") user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True) slot_start: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, index=True
)
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
note: Mapped[str | None] = mapped_column(String(144), nullable=True) note: Mapped[str | None] = mapped_column(String(144), nullable=True)
status: Mapped[AppointmentStatus] = mapped_column( status: Mapped[AppointmentStatus] = mapped_column(
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
) )
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc) DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
cancelled_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
) )
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View file

@ -20,6 +20,7 @@ dev = [
"httpx>=0.28.1", "httpx>=0.28.1",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
"mypy>=1.13.0", "mypy>=1.13.0",
"ruff>=0.14.10",
] ]
[tool.mypy] [tool.mypy]
@ -30,3 +31,27 @@ check_untyped_defs = true
ignore_missing_imports = true ignore_missing_imports = true
exclude = ["tests/"] exclude = ["tests/"]
[tool.ruff]
line-length = 88
target-version = "py311"
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"UP", # pyupgrade
"SIM", # flake8-simplify
"RUF", # ruff-specific rules
]
ignore = [
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
]
[tool.ruff.format]
quote-style = "double"
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["E501"] # Allow longer lines in tests for readability

View file

@ -1,22 +1,23 @@
"""Audit routes for viewing action records.""" """Audit routes for viewing action records."""
from typing import Callable, TypeVar
from collections.abc import Callable
from typing import TypeVar
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import select, func, desc from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, SumRecord, CounterRecord, Permission from models import CounterRecord, Permission, SumRecord, User
from schemas import ( from schemas import (
CounterRecordResponse, CounterRecordResponse,
SumRecordResponse,
PaginatedCounterRecords, PaginatedCounterRecords,
PaginatedSumRecords, PaginatedSumRecords,
SumRecordResponse,
) )
router = APIRouter(prefix="/api/audit", tags=["audit"]) router = APIRouter(prefix="/api/audit", tags=["audit"])
R = TypeVar("R", bound=BaseModel) R = TypeVar("R", bound=BaseModel)

View file

@ -1,5 +1,6 @@
"""Authentication routes for register, login, logout, and current user.""" """Authentication routes for register, login, logout, and current user."""
from datetime import datetime, timezone
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Response, status from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy import select from sqlalchemy import select
@ -9,18 +10,17 @@ from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME, COOKIE_NAME,
COOKIE_SECURE, COOKIE_SECURE,
get_password_hash,
get_user_by_email,
authenticate_user, authenticate_user,
build_user_response,
create_access_token, create_access_token,
get_current_user, get_current_user,
build_user_response, get_password_hash,
get_user_by_email,
) )
from database import get_db from database import get_db
from invite_utils import normalize_identifier from invite_utils import normalize_identifier
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
from schemas import UserLogin, UserResponse, RegisterWithInvite from schemas import RegisterWithInvite, UserLogin, UserResponse
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
@ -52,9 +52,8 @@ async def register(
"""Register a new user using an invite code.""" """Register a new user using an invite code."""
# Validate invite # Validate invite
normalized_identifier = normalize_identifier(user_data.invite_identifier) normalized_identifier = normalize_identifier(user_data.invite_identifier)
result = await db.execute( query = select(Invite).where(Invite.identifier == normalized_identifier)
select(Invite).where(Invite.identifier == normalized_identifier) result = await db.execute(query)
)
invite = result.scalar_one_or_none() invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage # Return same error for not found, spent, and revoked to avoid information leakage
@ -90,7 +89,7 @@ async def register(
# Mark invite as spent # Mark invite as spent
invite.status = InviteStatus.SPENT invite.status = InviteStatus.SPENT
invite.used_by_id = user.id invite.used_by_id = user.id
invite.spent_at = datetime.now(timezone.utc) invite.spent_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)

View file

@ -1,28 +1,28 @@
"""Availability routes for admin to manage booking availability.""" """Availability routes for admin to manage booking availability."""
from datetime import date, timedelta from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, delete, and_ from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, Availability, Permission from models import Availability, Permission, User
from schemas import ( from schemas import (
TimeSlot,
AvailabilityDay, AvailabilityDay,
AvailabilityResponse, AvailabilityResponse,
SetAvailabilityRequest,
CopyAvailabilityRequest, CopyAvailabilityRequest,
SetAvailabilityRequest,
TimeSlot,
) )
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
router = APIRouter(prefix="/api/admin/availability", tags=["availability"]) router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
def _get_date_range_bounds() -> tuple[date, date]: def _get_date_range_bounds() -> tuple[date, date]:
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS).""" """Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
today = date.today() today = date.today()
min_date = today + timedelta(days=MIN_ADVANCE_DAYS) min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
max_date = today + timedelta(days=MAX_ADVANCE_DAYS) max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
if d < min_date: if d < min_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}", detail=f"Cannot set availability for past dates. "
f"Earliest allowed: {min_date}",
) )
if d > max_date: if d > max_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}", detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest allowed: {max_date}",
) )
@ -70,15 +72,16 @@ async def get_availability(
for slot in slots: for slot in slots:
if slot.date not in days_dict: if slot.date not in days_dict:
days_dict[slot.date] = [] days_dict[slot.date] = []
days_dict[slot.date].append(TimeSlot( days_dict[slot.date].append(
start_time=slot.start_time, TimeSlot(
end_time=slot.end_time, start_time=slot.start_time,
)) end_time=slot.end_time,
)
)
# Convert to response format # Convert to response format
days = [ days = [
AvailabilityDay(date=d, slots=days_dict[d]) AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
for d in sorted(days_dict.keys())
] ]
return AvailabilityResponse(days=days) return AvailabilityResponse(days=days)
@ -98,9 +101,12 @@ async def set_availability(
sorted_slots = sorted(request.slots, key=lambda s: s.start_time) sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
for i in range(len(sorted_slots) - 1): for i in range(len(sorted_slots) - 1):
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time: if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
end = sorted_slots[i].end_time
start = sorted_slots[i + 1].start_time
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.", detail=f"Time slots overlap: slot ending at {end} "
f"overlaps with slot starting at {start}",
) )
# Validate each slot's end_time > start_time # Validate each slot's end_time > start_time
@ -108,13 +114,12 @@ async def set_availability(
if slot.end_time <= slot.start_time: if slot.end_time <= slot.start_time:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.", detail=f"Invalid time slot: end time {slot.end_time} "
f"must be after start time {slot.start_time}",
) )
# Delete existing availability for this date # Delete existing availability for this date
await db.execute( await db.execute(delete(Availability).where(Availability.date == request.date))
delete(Availability).where(Availability.date == request.date)
)
# Create new availability slots # Create new availability slots
for slot in request.slots: for slot in request.slots:
@ -139,7 +144,7 @@ async def copy_availability(
"""Copy availability from one day to multiple target days.""" """Copy availability from one day to multiple target days."""
min_date, max_date = _get_date_range_bounds() min_date, max_date = _get_date_range_bounds()
# Validate source date is in range (for consistency, though DB query would fail anyway) # Validate source date is in range
_validate_date_in_range(request.source_date, min_date, max_date) _validate_date_in_range(request.source_date, min_date, max_date)
# Validate target dates # Validate target dates
@ -169,9 +174,8 @@ async def copy_availability(
continue # Skip copying to self continue # Skip copying to self
# Delete existing availability for target date # Delete existing availability for target date
await db.execute( del_query = delete(Availability).where(Availability.date == target_date)
delete(Availability).where(Availability.date == target_date) await db.execute(del_query)
)
# Copy slots # Copy slots
target_slots: list[TimeSlot] = [] target_slots: list[TimeSlot] = []
@ -182,10 +186,12 @@ async def copy_availability(
end_time=source_slot.end_time, end_time=source_slot.end_time,
) )
db.add(new_availability) db.add(new_availability)
target_slots.append(TimeSlot( target_slots.append(
start_time=source_slot.start_time, TimeSlot(
end_time=source_slot.end_time, start_time=source_slot.start_time,
)) end_time=source_slot.end_time,
)
)
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots)) copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
@ -197,4 +203,3 @@ async def copy_availability(
raise raise
return AvailabilityResponse(days=copied_days) return AvailabilityResponse(days=copied_days)

View file

@ -1,24 +1,24 @@
"""Booking routes for users to book appointments.""" """Booking routes for users to book appointments."""
from datetime import date, datetime, time, timedelta, timezone
from datetime import UTC, date, datetime, time, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, and_, func from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, Availability, Appointment, AppointmentStatus, Permission from models import Appointment, AppointmentStatus, Availability, Permission, User
from schemas import ( from schemas import (
BookableSlot,
AvailableSlotsResponse,
BookingRequest,
AppointmentResponse, AppointmentResponse,
AvailableSlotsResponse,
BookableSlot,
BookingRequest,
PaginatedAppointments, PaginatedAppointments,
) )
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
router = APIRouter(prefix="/api/booking", tags=["booking"]) router = APIRouter(prefix="/api/booking", tags=["booking"])
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
if d < min_date: if d < min_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}", detail=f"Cannot book for today or past dates. "
f"Earliest bookable: {min_date}",
) )
if d > max_date: if d > max_date:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}", detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest bookable: {max_date}",
) )
@ -92,8 +94,8 @@ def _expand_availability_to_slots(
for avail in availability_slots: for avail in availability_slots:
# Create datetime objects for start and end # Create datetime objects for start and end
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc) current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc) end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
# Generate 15-minute slots # Generate 15-minute slots
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end: while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
@ -128,12 +130,11 @@ async def get_available_slots(
all_slots = _expand_availability_to_slots(availability_slots, target_date) all_slots = _expand_availability_to_slots(availability_slots, target_date)
# Get existing booked appointments for this date # Get existing booked appointments for this date
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc) day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc) day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
result = await db.execute( result = await db.execute(
select(Appointment.slot_start) select(Appointment.slot_start).where(
.where(
and_( and_(
Appointment.slot_start >= day_start, Appointment.slot_start >= day_start,
Appointment.slot_start <= day_end, Appointment.slot_start <= day_end,
@ -145,8 +146,7 @@ async def get_available_slots(
# Filter out already booked slots # Filter out already booked slots
available_slots = [ available_slots = [
slot for slot in all_slots slot for slot in all_slots if slot.start_time not in booked_starts
if slot.start_time not in booked_starts
] ]
return AvailableSlotsResponse(date=target_date, slots=available_slots) return AvailableSlotsResponse(date=target_date, slots=available_slots)
@ -162,12 +162,13 @@ async def create_booking(
slot_date = request.slot_start.date() slot_date = request.slot_start.date()
_validate_booking_date(slot_date) _validate_booking_date(slot_date)
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES) # Validate slot is on the correct minute boundary
valid_minutes = _get_valid_minute_boundaries() valid_minutes = _get_valid_minute_boundaries()
if request.slot_start.minute not in valid_minutes: if request.slot_start.minute not in valid_minutes:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})", detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
f"(valid minutes: {valid_minutes})",
) )
if request.slot_start.second != 0 or request.slot_start.microsecond != 0: if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
raise HTTPException( raise HTTPException(
@ -177,11 +178,11 @@ async def create_booking(
# Verify slot falls within availability # Verify slot falls within availability
slot_start_time = request.slot_start.time() slot_start_time = request.slot_start.time()
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time() slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
slot_end_time = slot_end_dt.time()
result = await db.execute( result = await db.execute(
select(Availability) select(Availability).where(
.where(
and_( and_(
Availability.date == slot_date, Availability.date == slot_date,
Availability.start_time <= slot_start_time, Availability.start_time <= slot_start_time,
@ -192,9 +193,11 @@ async def create_booking(
matching_availability = result.scalar_one_or_none() matching_availability = result.scalar_one_or_none()
if not matching_availability: if not matching_availability:
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.", detail=f"Selected slot at {slot_str} UTC is not within "
f"any available time ranges for {slot_date}",
) )
# Create the appointment # Create the appointment
@ -216,8 +219,8 @@ async def create_booking(
await db.rollback() await db.rollback()
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail="This slot has already been booked. Please select another slot.", detail="This slot has already been booked. Select another slot.",
) ) from None
return _to_appointment_response(appointment, current_user.email) return _to_appointment_response(appointment, current_user.email)
@ -242,20 +245,19 @@ async def get_my_appointments(
) )
appointments = result.scalars().all() appointments = result.scalars().all()
return [ return [_to_appointment_response(apt, current_user.email) for apt in appointments]
_to_appointment_response(apt, current_user.email)
for apt in appointments
]
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) @appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def cancel_my_appointment( async def cancel_my_appointment(
appointment_id: int, appointment_id: int,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)), current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
) -> AppointmentResponse: ) -> AppointmentResponse:
"""Cancel one of the current user's appointments.""" """Cancel one of the current user's appointments."""
# Get the appointment with explicit eager loading of user relationship # Get the appointment with eager loading of user relationship
result = await db.execute( result = await db.execute(
select(Appointment) select(Appointment)
.options(joinedload(Appointment.user)) .options(joinedload(Appointment.user))
@ -266,31 +268,35 @@ async def cancel_my_appointment(
if not appointment: if not appointment:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", detail=f"Appointment {appointment_id} not found",
) )
# Verify ownership # Verify ownership
if appointment.user_id != current_user.id: if appointment.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment") raise HTTPException(
status_code=403,
detail="Cannot cancel another user's appointment",
)
# Check if already cancelled # Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED: if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'" detail=f"Cannot cancel: status is '{appointment.status.value}'",
) )
# Check if appointment is in the past # Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc): if appointment.slot_start <= datetime.now(UTC):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
) )
# Cancel the appointment # Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_USER appointment.status = AppointmentStatus.CANCELLED_BY_USER
appointment.cancelled_at = datetime.now(timezone.utc) appointment.cancelled_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)
@ -302,7 +308,9 @@ async def cancel_my_appointment(
# Admin Appointments Endpoints # Admin Appointments Endpoints
# ============================================================================= # =============================================================================
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"]) admin_appointments_router = APIRouter(
prefix="/api/admin/appointments", tags=["admin-appointments"]
)
@admin_appointments_router.get("", response_model=PaginatedAppointments) @admin_appointments_router.get("", response_model=PaginatedAppointments)
@ -344,14 +352,18 @@ async def get_all_appointments(
) )
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse) @admin_appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def admin_cancel_appointment( async def admin_cancel_appointment(
appointment_id: int, appointment_id: int,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)), _current_user: User = Depends(
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
),
) -> AppointmentResponse: ) -> AppointmentResponse:
"""Cancel any appointment (admin only).""" """Cancel any appointment (admin only)."""
# Get the appointment with explicit eager loading of user relationship # Get the appointment with eager loading of user relationship
result = await db.execute( result = await db.execute(
select(Appointment) select(Appointment)
.options(joinedload(Appointment.user)) .options(joinedload(Appointment.user))
@ -362,27 +374,28 @@ async def admin_cancel_appointment(
if not appointment: if not appointment:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.", detail=f"Appointment {appointment_id} not found",
) )
# Check if already cancelled # Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED: if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'" detail=f"Cannot cancel: status is '{appointment.status.value}'",
) )
# Check if appointment is in the past # Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc): if appointment.slot_start <= datetime.now(UTC):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC" apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started." detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
) )
# Cancel the appointment # Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
appointment.cancelled_at = datetime.now(timezone.utc) appointment.cancelled_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(appointment) await db.refresh(appointment)

View file

@ -1,12 +1,12 @@
"""Counter routes.""" """Counter routes."""
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import Counter, User, CounterRecord, Permission from models import Counter, CounterRecord, Permission, User
router = APIRouter(prefix="/api/counter", tags=["counter"]) router = APIRouter(prefix="/api/counter", tags=["counter"])

View file

@ -1,25 +1,29 @@
"""Invite routes for public check, user invites, and admin management.""" """Invite routes for public check, user invites, and admin management."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Query from datetime import UTC, datetime
from sqlalchemy import select, func, desc
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import desc, func, select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format from invite_utils import (
from models import User, Invite, InviteStatus, Permission generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
)
from models import Invite, InviteStatus, Permission, User
from schemas import ( from schemas import (
AdminUserResponse,
InviteCheckResponse, InviteCheckResponse,
InviteCreate, InviteCreate,
InviteResponse, InviteResponse,
UserInviteResponse,
PaginatedInviteRecords, PaginatedInviteRecords,
AdminUserResponse, UserInviteResponse,
) )
router = APIRouter(prefix="/api/invites", tags=["invites"]) router = APIRouter(prefix="/api/invites", tags=["invites"])
admin_router = APIRouter(prefix="/api/admin", tags=["admin"]) admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
@ -54,9 +58,7 @@ async def check_invite(
if not is_valid_identifier_format(normalized): if not is_valid_identifier_format(normalized):
return InviteCheckResponse(valid=False, error="Invalid invite code format") return InviteCheckResponse(valid=False, error="Invalid invite code format")
result = await db.execute( result = await db.execute(select(Invite).where(Invite.identifier == normalized))
select(Invite).where(Invite.identifier == normalized)
)
invite = result.scalar_one_or_none() invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage # Return same error for not found, spent, and revoked to avoid information leakage
@ -112,9 +114,7 @@ async def create_invite(
) -> InviteResponse: ) -> InviteResponse:
"""Create a new invite for a specified godfather user.""" """Create a new invite for a specified godfather user."""
# Validate godfather exists # Validate godfather exists
result = await db.execute( result = await db.execute(select(User.id).where(User.id == data.godfather_id))
select(User.id).where(User.id == data.godfather_id)
)
godfather_id = result.scalar_one_or_none() godfather_id = result.scalar_one_or_none()
if not godfather_id: if not godfather_id:
raise HTTPException( raise HTTPException(
@ -141,8 +141,8 @@ async def create_invite(
if attempt == MAX_INVITE_COLLISION_RETRIES - 1: if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate unique invite code. Please try again.", detail="Failed to generate unique invite code. Try again.",
) ) from None
if invite is None: if invite is None:
raise HTTPException( raise HTTPException(
@ -156,7 +156,9 @@ async def create_invite(
async def list_all_invites( async def list_all_invites(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100), per_page: int = Query(10, ge=1, le=100),
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"), status_filter: str | None = Query(
None, alias="status", description="Filter by status: ready, spent, revoked"
),
godfather_id: int | None = Query(None, description="Filter by godfather user ID"), godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)), _current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
@ -175,8 +177,9 @@ async def list_all_invites(
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked", detail=f"Invalid status: {status_filter}. "
) "Must be ready, spent, or revoked",
) from None
if godfather_id: if godfather_id:
query = query.where(Invite.godfather_id == godfather_id) query = query.where(Invite.godfather_id == godfather_id)
@ -224,11 +227,12 @@ async def revoke_invite(
if invite.status != InviteStatus.READY: if invite.status != InviteStatus.READY:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.", detail=f"Cannot revoke invite with status '{invite.status.value}'. "
"Only READY invites can be revoked.",
) )
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(timezone.utc) invite.revoked_at = datetime.now(UTC)
await db.commit() await db.commit()
await db.refresh(invite) await db.refresh(invite)

View file

@ -1,7 +1,8 @@
"""Meta endpoints for shared constants.""" """Meta endpoints for shared constants."""
from fastapi import APIRouter from fastapi import APIRouter
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
from schemas import ConstantsResponse from schemas import ConstantsResponse
router = APIRouter(prefix="/api/meta", tags=["meta"]) router = APIRouter(prefix="/api/meta", tags=["meta"])
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
roles=[ROLE_ADMIN, ROLE_REGULAR], roles=[ROLE_ADMIN, ROLE_REGULAR],
invite_statuses=[s.value for s in InviteStatus], invite_statuses=[s.value for s in InviteStatus],
) )

View file

@ -1,15 +1,15 @@
"""Profile routes for user contact details.""" """Profile routes for user contact details."""
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import get_current_user from auth import get_current_user
from database import get_db from database import get_db
from models import User, ROLE_REGULAR from models import ROLE_REGULAR, User
from schemas import ProfileResponse, ProfileUpdate from schemas import ProfileResponse, ProfileUpdate
from validation import validate_profile_fields from validation import validate_profile_fields
router = APIRouter(prefix="/api/profile", tags=["profile"]) router = APIRouter(prefix="/api/profile", tags=["profile"])
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
"""Get the email of a godfather user by ID.""" """Get the email of a godfather user by ID."""
if not godfather_id: if not godfather_id:
return None return None
result = await db.execute( result = await db.execute(select(User.email).where(User.id == godfather_id))
select(User.email).where(User.id == godfather_id)
)
return result.scalar_one_or_none() return result.scalar_one_or_none()

View file

@ -1,13 +1,13 @@
"""Sum calculation routes.""" """Sum calculation routes."""
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission from auth import require_permission
from database import get_db from database import get_db
from models import User, SumRecord, Permission from models import Permission, SumRecord, User
from schemas import SumRequest, SumResponse from schemas import SumRequest, SumResponse
router = APIRouter(prefix="/api/sum", tags=["sum"]) router = APIRouter(prefix="/api/sum", tags=["sum"])

View file

@ -1,5 +1,6 @@
"""Pydantic schemas for API request/response models.""" """Pydantic schemas for API request/response models."""
from datetime import datetime, date, time
from datetime import date, datetime, time
from typing import Generic, TypeVar from typing import Generic, TypeVar
from pydantic import BaseModel, EmailStr, field_validator from pydantic import BaseModel, EmailStr, field_validator
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
class UserCredentials(BaseModel): class UserCredentials(BaseModel):
"""Base model for user email/password.""" """Base model for user email/password."""
email: EmailStr email: EmailStr
password: str password: str
@ -19,6 +21,7 @@ UserLogin = UserCredentials
class UserResponse(BaseModel): class UserResponse(BaseModel):
"""Response model for authenticated user info.""" """Response model for authenticated user info."""
id: int id: int
email: str email: str
roles: list[str] roles: list[str]
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
class RegisterWithInvite(BaseModel): class RegisterWithInvite(BaseModel):
"""Request model for registration with invite.""" """Request model for registration with invite."""
email: EmailStr email: EmailStr
password: str password: str
invite_identifier: str invite_identifier: str
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
class SumRequest(BaseModel): class SumRequest(BaseModel):
"""Request model for sum calculation.""" """Request model for sum calculation."""
a: float a: float
b: float b: float
class SumResponse(BaseModel): class SumResponse(BaseModel):
"""Response model for sum calculation.""" """Response model for sum calculation."""
a: float a: float
b: float b: float
result: float result: float
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
class CounterRecordResponse(BaseModel): class CounterRecordResponse(BaseModel):
"""Response model for a counter audit record.""" """Response model for a counter audit record."""
id: int id: int
user_email: str user_email: str
value_before: int value_before: int
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
class SumRecordResponse(BaseModel): class SumRecordResponse(BaseModel):
"""Response model for a sum audit record.""" """Response model for a sum audit record."""
id: int id: int
user_email: str user_email: str
a: float a: float
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
class PaginatedResponse(BaseModel, Generic[RecordT]): class PaginatedResponse(BaseModel, Generic[RecordT]):
"""Generic paginated response wrapper.""" """Generic paginated response wrapper."""
records: list[RecordT] records: list[RecordT]
total: int total: int
page: int page: int
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
class ProfileResponse(BaseModel): class ProfileResponse(BaseModel):
"""Response model for profile data.""" """Response model for profile data."""
contact_email: str | None contact_email: str | None
telegram: str | None telegram: str | None
signal: str | None signal: str | None
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
class ProfileUpdate(BaseModel): class ProfileUpdate(BaseModel):
"""Request model for updating profile.""" """Request model for updating profile."""
contact_email: str | None = None contact_email: str | None = None
telegram: str | None = None telegram: str | None = None
signal: str | None = None signal: str | None = None
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
class InviteCheckResponse(BaseModel): class InviteCheckResponse(BaseModel):
"""Response for invite check endpoint.""" """Response for invite check endpoint."""
valid: bool valid: bool
status: str | None = None status: str | None = None
error: str | None = None error: str | None = None
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
class InviteCreate(BaseModel): class InviteCreate(BaseModel):
"""Request model for creating an invite.""" """Request model for creating an invite."""
godfather_id: int godfather_id: int
class InviteResponse(BaseModel): class InviteResponse(BaseModel):
"""Response model for invite data (admin view).""" """Response model for invite data (admin view)."""
id: int id: int
identifier: str identifier: str
godfather_id: int godfather_id: int
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
class UserInviteResponse(BaseModel): class UserInviteResponse(BaseModel):
"""Response model for a user's invite (simpler than admin view).""" """Response model for a user's invite (simpler than admin view)."""
id: int id: int
identifier: str identifier: str
status: str status: str
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
class AdminUserResponse(BaseModel): class AdminUserResponse(BaseModel):
"""Minimal user info for admin dropdowns.""" """Minimal user info for admin dropdowns."""
id: int id: int
email: str email: str
@ -146,8 +162,10 @@ class AdminUserResponse(BaseModel):
# Availability Schemas # Availability Schemas
# ============================================================================= # =============================================================================
class TimeSlot(BaseModel): class TimeSlot(BaseModel):
"""A single time slot (start and end time).""" """A single time slot (start and end time)."""
start_time: time start_time: time
end_time: time end_time: time
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
class AvailabilityDay(BaseModel): class AvailabilityDay(BaseModel):
"""Availability for a single day.""" """Availability for a single day."""
date: date date: date
slots: list[TimeSlot] slots: list[TimeSlot]
class AvailabilityResponse(BaseModel): class AvailabilityResponse(BaseModel):
"""Response model for availability query.""" """Response model for availability query."""
days: list[AvailabilityDay] days: list[AvailabilityDay]
class SetAvailabilityRequest(BaseModel): class SetAvailabilityRequest(BaseModel):
"""Request to set availability for a specific date.""" """Request to set availability for a specific date."""
date: date date: date
slots: list[TimeSlot] slots: list[TimeSlot]
class CopyAvailabilityRequest(BaseModel): class CopyAvailabilityRequest(BaseModel):
"""Request to copy availability from one day to others.""" """Request to copy availability from one day to others."""
source_date: date source_date: date
target_dates: list[date] target_dates: list[date]
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
# Booking Schemas # Booking Schemas
# ============================================================================= # =============================================================================
class BookableSlot(BaseModel): class BookableSlot(BaseModel):
"""A bookable 15-minute slot.""" """A bookable 15-minute slot."""
start_time: datetime start_time: datetime
end_time: datetime end_time: datetime
class AvailableSlotsResponse(BaseModel): class AvailableSlotsResponse(BaseModel):
"""Response for available slots on a given date.""" """Response for available slots on a given date."""
date: date date: date
slots: list[BookableSlot] slots: list[BookableSlot]
class BookingRequest(BaseModel): class BookingRequest(BaseModel):
"""Request to book an appointment.""" """Request to book an appointment."""
slot_start: datetime slot_start: datetime
note: str | None = None note: str | None = None
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
class AppointmentResponse(BaseModel): class AppointmentResponse(BaseModel):
"""Response model for an appointment.""" """Response model for an appointment."""
id: int id: int
user_id: int user_id: int
user_email: str user_email: str
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
# Meta/Constants Schemas # Meta/Constants Schemas
# ============================================================================= # =============================================================================
class ConstantsResponse(BaseModel): class ConstantsResponse(BaseModel):
"""Response model for shared constants.""" """Response model for shared constants."""
permissions: list[str] permissions: list[str]
roles: list[str] roles: list[str]
invite_statuses: list[str] invite_statuses: list[str]

View file

@ -1,12 +1,21 @@
"""Seed the database with roles, permissions, and dev users.""" """Seed the database with roles, permissions, and dev users."""
import asyncio import asyncio
import os import os
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, async_session, Base
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash from auth import get_password_hash
from database import Base, async_session, engine
from models import (
ROLE_ADMIN,
ROLE_DEFINITIONS,
ROLE_REGULAR,
Permission,
Role,
User,
)
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"] DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"] DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
@ -14,7 +23,9 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"] DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role: async def upsert_role(
db: AsyncSession, name: str, description: str, permissions: list[Permission]
) -> Role:
"""Create or update a role with the given permissions.""" """Create or update a role with the given permissions."""
result = await db.execute(select(Role).where(Role.name == name)) result = await db.execute(select(Role).where(Role.name == name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
@ -35,7 +46,9 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
return role return role
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User: async def upsert_user(
db: AsyncSession, email: str, password: str, role_names: list[str]
) -> User:
"""Create or update a user with the given credentials and roles.""" """Create or update a user with the given credentials and roles."""
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()

View file

@ -1,4 +1,5 @@
"""Load shared constants from shared/constants.json.""" """Load shared constants from shared/constants.json."""
import json import json
from pathlib import Path from pathlib import Path
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"] MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"] MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"] NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]

View file

@ -7,17 +7,17 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from auth import get_password_hash
from database import Base, get_db from database import Base, get_db
from main import app from main import app
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
from auth import get_password_hash
from tests.helpers import unique_email from tests.helpers import unique_email
TEST_DATABASE_URL = os.getenv( TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL", "TEST_DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test" "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
) )
@ -91,7 +91,9 @@ async def create_user_with_roles(
result = await db.execute(select(Role).where(Role.name == role_name)) result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if not role: if not role:
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?") raise ValueError(
f"Role '{role_name}' not found. Did you run setup_roles()?"
)
roles.append(role) roles.append(role)
user = User( user = User(

View file

@ -3,8 +3,8 @@ import uuid
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Invite, InviteStatus
from invite_utils import generate_invite_identifier from invite_utils import generate_invite_identifier
from models import Invite, InviteStatus, User
def unique_email(prefix: str = "test") -> str: def unique_email(prefix: str = "test") -> str:
@ -67,7 +67,9 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
godfather = result.scalar_one_or_none() godfather = result.scalar_one_or_none()
if not godfather: if not godfather:
raise ValueError(f"Godfather user with email '{godfather_email}' not found. " raise ValueError(
"Create the user first using create_user_with_roles().") f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles()."
)
return await create_invite_for_godfather(db, godfather.id) return await create_invite_for_godfather(db, godfather.id)

View file

@ -3,12 +3,13 @@
Note: Registration now requires an invite code. Tests that need to register Note: Registration now requires an invite code. Tests that need to register
users will create invites first via the helper function. users will create invites first via the helper function.
""" """
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Registration tests (with invite) # Registration tests (with invite)
@ -19,7 +20,9 @@ async def test_register_success(client_factory):
# Create godfather user and invite # Create godfather user and invite
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
@ -49,7 +52,9 @@ async def test_register_duplicate_email(client_factory):
# Create godfather and two invites # Create godfather and two invites
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)
@ -80,7 +85,9 @@ async def test_register_duplicate_email(client_factory):
async def test_register_invalid_email(client_factory): async def test_register_invalid_email(client_factory):
"""Cannot register with invalid email format.""" """Cannot register with invalid email format."""
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
@ -138,7 +145,9 @@ async def test_login_success(client_factory):
email = unique_email("login") email = unique_email("login")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
@ -167,7 +176,9 @@ async def test_login_wrong_password(client_factory):
email = unique_email("wrongpass") email = unique_email("wrongpass")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
@ -221,7 +232,9 @@ async def test_get_me_success(client_factory):
email = unique_email("me") email = unique_email("me")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_invalid_cookie(client_factory): async def test_get_me_invalid_cookie(client_factory):
"""Cannot get current user with invalid cookie.""" """Cannot get current user with invalid cookie."""
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed: async with client_factory.create(
cookies={COOKIE_NAME: "invalidtoken123"}
) as authed:
response = await authed.get("/api/auth/me") response = await authed.get("/api/auth/me")
assert response.status_code == 401 assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials" assert response.json()["detail"] == "Invalid authentication credentials"
@ -277,7 +292,9 @@ async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest") email = unique_email("tokentest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(
@ -303,7 +320,9 @@ async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken") email = unique_email("logintoken")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
@ -335,7 +354,9 @@ async def test_multiple_users_isolated(client_factory):
email2 = unique_email("user2") email2 = unique_email("user2")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)
@ -377,7 +398,9 @@ async def test_password_is_hashed(client_factory):
email = unique_email("hashtest") email = unique_email("hashtest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
@ -401,7 +424,9 @@ async def test_case_sensitive_password(client_factory):
email = unique_email("casetest") email = unique_email("casetest")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post( await client_factory.post(
@ -426,7 +451,9 @@ async def test_logout_success(client_factory):
email = unique_email("logout") email = unique_email("logout")
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post( reg_response = await client_factory.post(

View file

@ -3,7 +3,9 @@ Availability API Tests
Tests for the admin availability management endpoints. Tests for the admin availability management endpoints.
""" """
from datetime import date, time, timedelta
from datetime import date, timedelta
import pytest import pytest
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
# Permission Tests # Permission Tests
# ============================================================================= # =============================================================================
class TestAvailabilityPermissions: class TestAvailabilityPermissions:
"""Test that only admins can access availability endpoints.""" """Test that only admins can access availability endpoints."""
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user): async def test_regular_user_cannot_get_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get( response = await client.get(
"/api/admin/availability", "/api/admin/availability",
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user): async def test_regular_user_cannot_set_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put( response = await client.put(
"/api/admin/availability", "/api/admin/availability",
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
# Set Availability Tests # Set Availability Tests
# ============================================================================= # =============================================================================
class TestSetAvailability: class TestSetAvailability:
"""Test setting availability for a date.""" """Test setting availability for a date."""
@ -128,7 +136,9 @@ class TestSetAvailability:
assert len(data["slots"]) == 2 assert len(data["slots"]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user): async def test_set_empty_slots_clears_availability(
self, client_factory, admin_user
):
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
# First set some availability # First set some availability
await client.put( await client.put(
@ -162,7 +172,7 @@ class TestSetAvailability:
) )
# Replace with different slots # Replace with different slots
response = await client.put( await client.put(
"/api/admin/availability", "/api/admin/availability",
json={ json={
"date": str(tomorrow()), "date": str(tomorrow()),
@ -186,6 +196,7 @@ class TestSetAvailability:
# Validation Tests # Validation Tests
# ============================================================================= # =============================================================================
class TestAvailabilityValidation: class TestAvailabilityValidation:
"""Test validation rules for availability.""" """Test validation rules for availability."""
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
# Get Availability Tests # Get Availability Tests
# ============================================================================= # =============================================================================
class TestGetAvailability: class TestGetAvailability:
"""Test retrieving availability.""" """Test retrieving availability."""
@ -360,6 +372,7 @@ class TestGetAvailability:
# Copy Availability Tests # Copy Availability Tests
# ============================================================================= # =============================================================================
class TestCopyAvailability: class TestCopyAvailability:
"""Test copying availability from one day to others.""" """Test copying availability from one day to others."""

View file

@ -3,7 +3,9 @@ Booking API Tests
Tests for the user booking endpoints. Tests for the user booking endpoints.
""" """
from datetime import date, datetime, timedelta, timezone
from datetime import UTC, date, datetime, timedelta
import pytest import pytest
from models import Appointment, AppointmentStatus from models import Appointment, AppointmentStatus
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
# Permission Tests # Permission Tests
# ============================================================================= # =============================================================================
class TestBookingPermissions: class TestBookingPermissions:
"""Test that only regular users can book appointments.""" """Test that only regular users can book appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user): async def test_regular_user_can_get_slots(
self, client_factory, regular_user, admin_user
):
"""Regular user can get available slots.""" """Regular user can get available slots."""
# First, admin sets up availability # First, admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -39,12 +44,16 @@ class TestBookingPermissions:
# Regular user gets slots # Regular user gets slots
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user): async def test_regular_user_can_book(
self, client_factory, regular_user, admin_user
):
"""Regular user can book an appointment.""" """Regular user can book an appointment."""
# Admin sets up availability # Admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -69,7 +78,9 @@ class TestBookingPermissions:
async def test_admin_cannot_get_slots(self, client_factory, admin_user): async def test_admin_cannot_get_slots(self, client_factory, admin_user):
"""Admin cannot access booking slots endpoint.""" """Admin cannot access booking slots endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 403 assert response.status_code == 403
@ -96,7 +107,9 @@ class TestBookingPermissions:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unauthenticated_cannot_get_slots(self, client): async def test_unauthenticated_cannot_get_slots(self, client):
"""Unauthenticated user cannot get slots.""" """Unauthenticated user cannot get slots."""
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
@ -113,6 +126,7 @@ class TestBookingPermissions:
# Get Slots Tests # Get Slots Tests
# ============================================================================= # =============================================================================
class TestGetSlots: class TestGetSlots:
"""Test getting available booking slots.""" """Test getting available booking slots."""
@ -120,7 +134,9 @@ class TestGetSlots:
async def test_get_slots_no_availability(self, client_factory, regular_user): async def test_get_slots_no_availability(self, client_factory, regular_user):
"""Returns empty slots when no availability set.""" """Returns empty slots when no availability set."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -128,7 +144,9 @@ class TestGetSlots:
assert data["slots"] == [] assert data["slots"] == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user): async def test_get_slots_expands_to_15min(
self, client_factory, regular_user, admin_user
):
"""Availability is expanded into 15-minute slots.""" """Availability is expanded into 15-minute slots."""
# Admin sets 1-hour availability # Admin sets 1-hour availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -142,7 +160,9 @@ class TestGetSlots:
# User gets slots - should be 4 x 15-minute slots # User gets slots - should be 4 x 15-minute slots
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -156,7 +176,9 @@ class TestGetSlots:
assert "10:00:00" in data["slots"][3]["end_time"] assert "10:00:00" in data["slots"][3]["end_time"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user): async def test_get_slots_excludes_booked(
self, client_factory, regular_user, admin_user
):
"""Already booked slots are excluded from available slots.""" """Already booked slots are excluded from available slots."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -176,7 +198,9 @@ class TestGetSlots:
) )
# Get slots again - should have 3 left # Get slots again - should have 3 left
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())}) response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -189,6 +213,7 @@ class TestGetSlots:
# Booking Tests # Booking Tests
# ============================================================================= # =============================================================================
class TestCreateBooking: class TestCreateBooking:
"""Test creating bookings.""" """Test creating bookings."""
@ -248,7 +273,9 @@ class TestCreateBooking:
assert data["note"] is None assert data["note"] is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user): async def test_cannot_double_book_slot(
self, client_factory, regular_user, admin_user, alt_regular_user
):
"""Cannot book a slot that's already booked.""" """Cannot book a slot that's already booked."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -279,7 +306,9 @@ class TestCreateBooking:
assert "already been booked" in response.json()["detail"] assert "already been booked" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user): async def test_cannot_book_outside_availability(
self, client_factory, regular_user, admin_user
):
"""Cannot book a slot outside of availability.""" """Cannot book a slot outside of availability."""
# Admin sets availability for morning only # Admin sets availability for morning only
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -306,6 +335,7 @@ class TestCreateBooking:
# Date Validation Tests # Date Validation Tests
# ============================================================================= # =============================================================================
class TestBookingDateValidation: class TestBookingDateValidation:
"""Test date validation for bookings.""" """Test date validation for bookings."""
@ -319,7 +349,10 @@ class TestBookingDateValidation:
) )
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower() assert (
"past" in response.json()["detail"].lower()
or "today" in response.json()["detail"].lower()
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_book_past_date(self, client_factory, regular_user): async def test_cannot_book_past_date(self, client_factory, regular_user):
@ -350,7 +383,9 @@ class TestBookingDateValidation:
async def test_cannot_get_slots_today(self, client_factory, regular_user): async def test_cannot_get_slots_today(self, client_factory, regular_user):
"""Cannot get slots for today.""" """Cannot get slots for today."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(date.today())}) response = await client.get(
"/api/booking/slots", params={"date": str(date.today())}
)
assert response.status_code == 400 assert response.status_code == 400
@ -359,7 +394,9 @@ class TestBookingDateValidation:
"""Cannot get slots for past date.""" """Cannot get slots for past date."""
yesterday = date.today() - timedelta(days=1) yesterday = date.today() - timedelta(days=1)
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(yesterday)}) response = await client.get(
"/api/booking/slots", params={"date": str(yesterday)}
)
assert response.status_code == 400 assert response.status_code == 400
@ -368,11 +405,14 @@ class TestBookingDateValidation:
# Time Validation Tests # Time Validation Tests
# ============================================================================= # =============================================================================
class TestBookingTimeValidation: class TestBookingTimeValidation:
"""Test time validation for bookings.""" """Test time validation for bookings."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user): async def test_slot_must_be_15min_boundary(
self, client_factory, regular_user, admin_user
):
"""Slot start time must be on 15-minute boundary.""" """Slot start time must be on 15-minute boundary."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
# Note Validation Tests # Note Validation Tests
# ============================================================================= # =============================================================================
class TestBookingNoteValidation: class TestBookingNoteValidation:
"""Test note validation for bookings.""" """Test note validation for bookings."""
@ -426,7 +467,9 @@ class TestBookingNoteValidation:
assert response.status_code == 422 assert response.status_code == 422
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user): async def test_note_exactly_144_chars(
self, client_factory, regular_user, admin_user
):
"""Note of exactly 144 characters is allowed.""" """Note of exactly 144 characters is allowed."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
# User Appointments Tests # User Appointments Tests
# ============================================================================= # =============================================================================
class TestUserAppointments: class TestUserAppointments:
"""Test user appointments endpoints.""" """Test user appointments endpoints."""
@ -467,7 +511,9 @@ class TestUserAppointments:
assert response.json() == [] assert response.json() == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user): async def test_get_my_appointments_with_bookings(
self, client_factory, regular_user, admin_user
):
"""Returns user's appointments.""" """Returns user's appointments."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -502,7 +548,9 @@ class TestUserAppointments:
assert "Second" in notes assert "Second" in notes
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user): async def test_admin_cannot_view_user_appointments(
self, client_factory, admin_user
):
"""Admin cannot access user appointments endpoint.""" """Admin cannot access user appointments endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/appointments") response = await client.get("/api/appointments")
@ -520,7 +568,9 @@ class TestCancelAppointment:
"""Test cancelling appointments.""" """Test cancelling appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user): async def test_cancel_own_appointment(
self, client_factory, regular_user, admin_user
):
"""User can cancel their own appointment.""" """User can cancel their own appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -549,7 +599,9 @@ class TestCancelAppointment:
assert data["cancelled_at"] is not None assert data["cancelled_at"] is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user): async def test_cannot_cancel_others_appointment(
self, client_factory, regular_user, alt_regular_user, admin_user
):
"""User cannot cancel another user's appointment.""" """User cannot cancel another user's appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -577,7 +629,9 @@ class TestCancelAppointment:
assert "another user" in response.json()["detail"].lower() assert "another user" in response.json()["detail"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user): async def test_cannot_cancel_nonexistent_appointment(
self, client_factory, regular_user
):
"""Returns 404 for non-existent appointment.""" """Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/appointments/99999/cancel") response = await client.post("/api/appointments/99999/cancel")
@ -585,7 +639,9 @@ class TestCancelAppointment:
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): async def test_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Cannot cancel an already cancelled appointment.""" """Cannot cancel an already cancelled appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -613,7 +669,9 @@ class TestCancelAppointment:
assert "cancelled_by_user" in response.json()["detail"] assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user): async def test_admin_cannot_use_user_cancel_endpoint(
self, client_factory, admin_user
):
"""Admin cannot use user cancel endpoint.""" """Admin cannot use user cancel endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/appointments/1/cancel") response = await client.post("/api/appointments/1/cancel")
@ -621,7 +679,9 @@ class TestCancelAppointment:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user): async def test_cancelled_slot_becomes_available(
self, client_factory, regular_user, admin_user
):
"""After cancelling, the slot becomes available again.""" """After cancelling, the slot becomes available again."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -663,7 +723,7 @@ class TestCancelAppointment:
"""User cannot cancel a past appointment.""" """User cannot cancel a past appointment."""
# Create a past appointment directly in DB # Create a past appointment directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1) past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment( appointment = Appointment(
user_id=regular_user["user"]["id"], user_id=regular_user["user"]["id"],
slot_start=past_time, slot_start=past_time,
@ -687,11 +747,14 @@ class TestCancelAppointment:
# Admin Appointments Tests # Admin Appointments Tests
# ============================================================================= # =============================================================================
class TestAdminViewAppointments: class TestAdminViewAppointments:
"""Test admin viewing all appointments.""" """Test admin viewing all appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user): async def test_admin_can_view_all_appointments(
self, client_factory, regular_user, admin_user
):
"""Admin can view all appointments.""" """Admin can view all appointments."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -725,7 +788,9 @@ class TestAdminViewAppointments:
assert any(apt["note"] == "Test" for apt in data["records"]) assert any(apt["note"] == "Test" for apt in data["records"])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user): async def test_regular_user_cannot_view_all_appointments(
self, client_factory, regular_user
):
"""Regular user cannot access admin appointments endpoint.""" """Regular user cannot access admin appointments endpoint."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/admin/appointments") response = await client.get("/api/admin/appointments")
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
"""Test admin cancelling appointments.""" """Test admin cancelling appointments."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user): async def test_admin_can_cancel_any_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin can cancel any user's appointment.""" """Admin can cancel any user's appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -765,7 +832,9 @@ class TestAdminCancelAppointment:
# Admin cancels # Admin cancels
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -773,7 +842,9 @@ class TestAdminCancelAppointment:
assert data["cancelled_at"] is not None assert data["cancelled_at"] is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user): async def test_regular_user_cannot_use_admin_cancel(
self, client_factory, regular_user, admin_user
):
"""Regular user cannot use admin cancel endpoint.""" """Regular user cannot use admin cancel endpoint."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -799,7 +870,9 @@ class TestAdminCancelAppointment:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user): async def test_admin_cancel_nonexistent_appointment(
self, client_factory, admin_user
):
"""Returns 404 for non-existent appointment.""" """Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/admin/appointments/99999/cancel") response = await client.post("/api/admin/appointments/99999/cancel")
@ -807,7 +880,9 @@ class TestAdminCancelAppointment:
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user): async def test_admin_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel an already cancelled appointment.""" """Admin cannot cancel an already cancelled appointment."""
# Admin sets availability # Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -832,17 +907,21 @@ class TestAdminCancelAppointment:
# Admin tries to cancel again # Admin tries to cancel again
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400 assert response.status_code == 400
assert "cancelled_by_user" in response.json()["detail"] assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user): async def test_admin_cannot_cancel_past_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel a past appointment.""" """Admin cannot cancel a past appointment."""
# Create a past appointment directly in DB # Create a past appointment directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1) past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment( appointment = Appointment(
user_id=regular_user["user"]["id"], user_id=regular_user["user"]["id"],
slot_start=past_time, slot_start=past_time,
@ -856,8 +935,9 @@ class TestAdminCancelAppointment:
# Admin tries to cancel # Admin tries to cancel
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client: async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel") response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400 assert response.status_code == 400
assert "past" in response.json()["detail"].lower() assert "past" in response.json()["detail"].lower()

View file

@ -2,12 +2,13 @@
Note: Registration now requires an invite code. Note: Registration now requires an invite code.
""" """
import pytest import pytest
from auth import COOKIE_NAME from auth import COOKIE_NAME
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Protected endpoint tests - without auth # Protected endpoint tests - without auth
@ -41,7 +42,9 @@ async def test_increment_counter_invalid_cookie(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory): async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
@ -64,7 +67,9 @@ async def test_get_counter_authenticated(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter(client_factory): async def test_increment_counter(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
@ -91,7 +96,9 @@ async def test_increment_counter(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory): async def test_increment_counter_multiple(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
@ -120,7 +127,9 @@ async def test_increment_counter_multiple(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory): async def test_get_counter_after_increment(client_factory):
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post( reg = await client_factory.post(
@ -149,7 +158,9 @@ async def test_get_counter_after_increment(client_factory):
async def test_counter_shared_between_users(client_factory): async def test_counter_shared_between_users(client_factory):
# Create godfather and invites for two users # Create godfather and invites for two users
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id) invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id) invite2 = await create_invite_for_godfather(db, godfather.id)

View file

@ -1,22 +1,23 @@
"""Tests for invite functionality.""" """Tests for invite functionality."""
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select
from invite_utils import ( from invite_utils import (
generate_invite_identifier,
normalize_identifier,
is_valid_identifier_format,
BIP39_WORDS, BIP39_WORDS,
generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
) )
from models import Invite, InviteStatus, User, ROLE_REGULAR from models import ROLE_REGULAR, Invite, InviteStatus, User
from tests.helpers import unique_email
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
from tests.helpers import unique_email
# ============================================================================ # ============================================================================
# Invite Utils Tests # Invite Utils Tests
# ============================================================================ # ============================================================================
def test_bip39_words_loaded(): def test_bip39_words_loaded():
"""BIP39 word list should have exactly 2048 words.""" """BIP39 word list should have exactly 2048 words."""
assert len(BIP39_WORDS) == 2048 assert len(BIP39_WORDS) == 2048
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
# Invite Model Tests # Invite Model Tests
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_invite(client_factory): async def test_create_invite(client_factory):
"""Can create an invite with godfather.""" """Can create an invite with godfather."""
@ -173,7 +175,7 @@ async def test_invite_unique_identifier(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invite_status_transitions(client_factory): async def test_invite_status_transitions(client_factory):
"""Invite status can be changed.""" """Invite status can be changed."""
from datetime import datetime, UTC from datetime import UTC, datetime
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
@ -206,7 +208,7 @@ async def test_invite_status_transitions(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invite_revoke(client_factory): async def test_invite_revoke(client_factory):
"""Invite can be revoked.""" """Invite can be revoked."""
from datetime import datetime, UTC from datetime import UTC, datetime
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles( godfather = await create_user_with_roles(
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
# User Godfather Tests # User Godfather Tests
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_godfather_relationship(client_factory): async def test_user_godfather_relationship(client_factory):
"""User can have a godfather.""" """User can have a godfather."""
@ -254,9 +257,7 @@ async def test_user_godfather_relationship(client_factory):
await db.commit() await db.commit()
# Query user fresh # Query user fresh
result = await db.execute( result = await db.execute(select(User).where(User.id == user.id))
select(User).where(User.id == user.id)
)
loaded_user = result.scalar_one() loaded_user = result.scalar_one()
assert loaded_user.godfather_id == godfather.id assert loaded_user.godfather_id == godfather.id
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
# Admin Create Invite API Tests (Phase 2) # Admin Create Invite API Tests (Phase 2)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_create_invite(client_factory, admin_user, regular_user): async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
"""Admin can create an invite for a regular user.""" """Admin can create an invite for a regular user."""
@ -387,9 +389,7 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
# Query from DB # Query from DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
assert invite.identifier == data["identifier"] assert invite.identifier == data["identifier"]
@ -398,7 +398,9 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user): async def test_create_invite_retries_on_collision(
client_factory, admin_user, regular_user
):
"""Create invite retries with new identifier on collision.""" """Create invite retries with new identifier on collision."""
from unittest.mock import patch from unittest.mock import patch
@ -419,6 +421,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
# Mock generator to first return the same identifier (collision), then a new one # Mock generator to first return the same identifier (collision), then a new one
call_count = 0 call_count = 0
def mock_generator(): def mock_generator():
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
@ -426,7 +429,9 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
return identifier1 # Will collide return identifier1 # Will collide
return f"unique-word-{call_count:02d}" # Won't collide return f"unique-word-{call_count:02d}" # Won't collide
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator): with patch(
"routes.invites.generate_invite_identifier", side_effect=mock_generator
):
response2 = await client.post( response2 = await client.post(
"/api/admin/invites", "/api/admin/invites",
json={"godfather_id": godfather.id}, json={"godfather_id": godfather.id},
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
# Invite Check API Tests (Phase 3) # Invite Check API Tests (Phase 3)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_valid(client_factory, admin_user, regular_user): async def test_check_invite_valid(client_factory, admin_user, regular_user):
"""Check endpoint returns valid=True for READY invite.""" """Check endpoint returns valid=True for READY invite."""
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user): async def test_check_invite_spent_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage).""" """Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -547,9 +555,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user): async def test_check_invite_revoked_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage).""" """Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
from datetime import datetime, UTC from datetime import UTC, datetime
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
# Register with Invite Tests (Phase 3) # Register with Invite Tests (Phase 3)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_with_valid_invite(client_factory, admin_user, regular_user): async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
"""Can register with valid invite code.""" """Can register with valid invite code."""
@ -681,9 +692,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
# Check invite status # Check invite status
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
assert invite.status == InviteStatus.SPENT assert invite.status == InviteStatus.SPENT
@ -723,9 +732,7 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
# Check user's godfather # Check user's godfather
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(User).where(User.email == new_email))
select(User).where(User.email == new_email)
)
new_user = result.scalar_one() new_user = result.scalar_one()
assert new_user.godfather_id == godfather_id assert new_user.godfather_id == godfather_id
@ -794,7 +801,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user): async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
"""Cannot register with revoked invite.""" """Cannot register with revoked invite."""
from datetime import datetime, UTC from datetime import UTC, datetime
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -814,9 +821,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
# Revoke invite directly in DB # Revoke invite directly in DB
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
result = await db.execute( result = await db.execute(select(Invite).where(Invite.id == invite_id))
select(Invite).where(Invite.id == invite_id)
)
invite = result.scalar_one() invite = result.scalar_one()
invite.status = InviteStatus.REVOKED invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(UTC) invite.revoked_at = datetime.now(UTC)
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
# User Invites API Tests (Phase 4) # User Invites API Tests (Phase 4)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user): async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
"""Regular user can list their own invites.""" """Regular user can list their own invites."""
@ -941,7 +947,9 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user): async def test_spent_invite_shows_used_by_email(
client_factory, admin_user, regular_user
):
"""Spent invite shows who used it.""" """Spent invite shows who used it."""
# Create invite for regular user # Create invite for regular user
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
# Admin Invite Management Tests (Phase 5) # Admin Invite Management Tests (Phase 5)
# ============================================================================ # ============================================================================
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user): async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
"""Admin can list all invites.""" """Admin can list all invites."""
@ -1153,4 +1162,3 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
# Revoke # Revoke
response = await client.post("/api/admin/invites/1/revoke") response = await client.post("/api/admin/invites/1/revoke")
assert response.status_code == 403 assert response.status_code == 403

View file

@ -7,15 +7,16 @@ These tests verify that:
3. Unauthenticated users are denied access (401) 3. Unauthenticated users are denied access (401)
4. The permission system cannot be bypassed 4. The permission system cannot be bypassed
""" """
import pytest import pytest
from models import Permission from models import Permission
# ============================================================================= # =============================================================================
# Role Assignment Tests # Role Assignment Tests
# ============================================================================= # =============================================================================
class TestRoleAssignment: class TestRoleAssignment:
"""Test that roles are properly assigned and returned.""" """Test that roles are properly assigned and returned."""
@ -40,7 +41,9 @@ class TestRoleAssignment:
assert "regular" not in data["roles"] assert "regular" not in data["roles"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user): async def test_regular_user_has_correct_permissions(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
@ -72,7 +75,9 @@ class TestRoleAssignment:
assert Permission.USE_SUM.value not in permissions assert Permission.USE_SUM.value not in permissions
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles): async def test_user_with_no_roles_has_no_permissions(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
@ -85,6 +90,7 @@ class TestRoleAssignment:
# Counter Endpoint Access Tests # Counter Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestCounterAccess: class TestCounterAccess:
"""Test access control for counter endpoints.""" """Test access control for counter endpoints."""
@ -97,7 +103,9 @@ class TestCounterAccess:
assert "value" in response.json() assert "value" in response.json()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_can_increment_counter(self, client_factory, regular_user): async def test_regular_user_can_increment_counter(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/counter/increment") response = await client.post("/api/counter/increment")
@ -122,7 +130,9 @@ class TestCounterAccess:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_view_counter(
self, client_factory, user_no_roles
):
"""Users with no roles should be forbidden.""" """Users with no roles should be forbidden."""
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
@ -146,6 +156,7 @@ class TestCounterAccess:
# Sum Endpoint Access Tests # Sum Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestSumAccess: class TestSumAccess:
"""Test access control for sum endpoint.""" """Test access control for sum endpoint."""
@ -173,7 +184,9 @@ class TestSumAccess:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_use_sum(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.post( response = await client.post(
"/api/sum", "/api/sum",
@ -195,6 +208,7 @@ class TestSumAccess:
# Audit Endpoint Access Tests # Audit Endpoint Access Tests
# ============================================================================= # =============================================================================
class TestAuditAccess: class TestAuditAccess:
"""Test access control for audit endpoints.""" """Test access control for audit endpoints."""
@ -219,7 +233,9 @@ class TestAuditAccess:
assert "total" in data assert "total" in data
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user): async def test_regular_user_cannot_view_counter_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints.""" """Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
@ -228,7 +244,9 @@ class TestAuditAccess:
assert "permission" in response.json()["detail"].lower() assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user): async def test_regular_user_cannot_view_sum_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints.""" """Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/sum") response = await client.get("/api/audit/sum")
@ -236,7 +254,9 @@ class TestAuditAccess:
assert response.status_code == 403 assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles): async def test_user_without_roles_cannot_view_audit(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client: async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/audit/counter") response = await client.get("/api/audit/counter")
@ -257,6 +277,7 @@ class TestAuditAccess:
# Offensive Security Tests - Bypass Attempts # Offensive Security Tests - Bypass Attempts
# ============================================================================= # =============================================================================
class TestSecurityBypassAttempts: class TestSecurityBypassAttempts:
""" """
Offensive tests that attempt to bypass security controls. Offensive tests that attempt to bypass security controls.
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
""" """
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user): async def test_cannot_access_audit_with_forged_role_claim(
self, client_factory, regular_user
):
""" """
Attempt to access audit by somehow claiming admin role. Attempt to access audit by somehow claiming admin role.
The server should verify roles from DB, not trust client claims. The server should verify roles from DB, not trust client claims.
@ -287,14 +310,18 @@ class TestSecurityBypassAttempts:
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user): async def test_cannot_access_with_tampered_token(
self, client_factory, regular_user
):
"""Test that tokens signed with wrong key are rejected.""" """Test that tokens signed with wrong key are rejected."""
# Take a valid token and modify it # Take a valid token and modify it
original_token = regular_user["cookies"].get("auth_token", "") original_token = regular_user["cookies"].get("auth_token", "")
if original_token: if original_token:
tampered_token = original_token[:-5] + "XXXXX" tampered_token = original_token[:-5] + "XXXXX"
async with client_factory.create(cookies={"auth_token": tampered_token}) as client: async with client_factory.create(
cookies={"auth_token": tampered_token}
) as client:
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code == 401 assert response.status_code == 401
@ -305,12 +332,14 @@ class TestSecurityBypassAttempts:
Test that new registrations cannot claim admin role. Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default. New users should only get 'regular' role by default.
""" """
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from models import ROLE_REGULAR from models import ROLE_REGULAR
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR]) godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id) invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post( response = await client_factory.post(
@ -341,15 +370,17 @@ class TestSecurityBypassAttempts:
If a user is deleted, their token should no longer work. If a user is deleted, their token should no longer work.
This tests that tokens are validated against current DB state. This tests that tokens are validated against current DB state.
""" """
from tests.helpers import unique_email
from sqlalchemy import delete from sqlalchemy import delete
from models import User from models import User
from tests.helpers import unique_email
email = unique_email("deleted") email = unique_email("deleted")
# Create and login user # Create and login user
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
user = await create_user_with_roles(db, email, "password123", ["regular"]) user = await create_user_with_roles(db, email, "password123", ["regular"])
user_id = user.id user_id = user.id
@ -376,15 +407,17 @@ class TestSecurityBypassAttempts:
If a user's role is changed, the change should be reflected If a user's role is changed, the change should be reflected
in subsequent requests (no stale permission cache). in subsequent requests (no stale permission cache).
""" """
from tests.helpers import unique_email
from sqlalchemy import select from sqlalchemy import select
from models import User, Role
from models import Role, User
from tests.helpers import unique_email
email = unique_email("rolechange") email = unique_email("rolechange")
# Create regular user # Create regular user
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles from tests.conftest import create_user_with_roles
await create_user_with_roles(db, email, "password123", ["regular"]) await create_user_with_roles(db, email, "password123", ["regular"])
login_response = await client_factory.post( login_response = await client_factory.post(
@ -406,10 +439,7 @@ class TestSecurityBypassAttempts:
result = await db.execute(select(Role).where(Role.name == "admin")) result = await db.execute(select(Role).where(Role.name == "admin"))
admin_role = result.scalar_one() admin_role = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "regular")) user.roles = [admin_role] # Replace roles with admin only
regular_role = result.scalar_one()
user.roles = [admin_role] # Remove regular, add admin
await db.commit() await db.commit()
# Now should have audit access but not counter access # Now should have audit access but not counter access
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
# Audit Record Tests # Audit Record Tests
# ============================================================================= # =============================================================================
class TestAuditRecords: class TestAuditRecords:
"""Test that actions are properly recorded in audit logs.""" """Test that actions are properly recorded in audit logs."""
@ -466,7 +497,7 @@ class TestAuditRecords:
# Find record with our values # Find record with our values
records = data["records"] records = data["records"]
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30] matching = [
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
]
assert len(matching) >= 1 assert len(matching) >= 1

View file

@ -1,9 +1,9 @@
"""Tests for user profile and contact details.""" """Tests for user profile and contact details."""
import pytest
from sqlalchemy import select from sqlalchemy import select
from models import User, ROLE_REGULAR
from auth import get_password_hash from auth import get_password_hash
from models import User
from tests.helpers import unique_email from tests.helpers import unique_email
# Valid npub for testing (32 zero bytes encoded as bech32) # Valid npub for testing (32 zero bytes encoded as bech32)
@ -328,7 +328,9 @@ class TestUpdateProfileEndpoint:
assert "field_errors" in data["detail"] assert "field_errors" in data["detail"]
assert "nostr_npub" in data["detail"]["field_errors"] assert "nostr_npub" in data["detail"]["field_errors"]
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user): async def test_multiple_invalid_fields_returns_all_errors(
self, client_factory, regular_user
):
"""Multiple invalid fields return all errors.""" """Multiple invalid fields return all errors."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put( response = await client.put(
@ -344,7 +346,9 @@ class TestUpdateProfileEndpoint:
assert "contact_email" in data["detail"]["field_errors"] assert "contact_email" in data["detail"]["field_errors"]
assert "telegram" in data["detail"]["field_errors"] assert "telegram" in data["detail"]["field_errors"]
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user): async def test_partial_update_preserves_other_fields(
self, client_factory, regular_user
):
"""Updating one field doesn't affect others (they get set to the request values).""" """Updating one field doesn't affect others (they get set to the request values)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
# Set initial values # Set initial values
@ -402,11 +406,14 @@ class TestProfilePrivacy:
class TestProfileGodfather: class TestProfileGodfather:
"""Tests for godfather information in profile.""" """Tests for godfather information in profile."""
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user): async def test_profile_shows_godfather_email(
self, client_factory, admin_user, regular_user
):
"""Profile shows godfather email for users who signed up with invite.""" """Profile shows godfather email for users who signed up with invite."""
from tests.helpers import unique_email
from sqlalchemy import select from sqlalchemy import select
from models import User from models import User
from tests.helpers import unique_email
# Create invite # Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client: async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -443,7 +450,9 @@ class TestProfileGodfather:
data = response.json() data = response.json()
assert data["godfather_email"] == regular_user["email"] assert data["godfather_email"] == regular_user["email"]
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user): async def test_profile_godfather_null_for_seeded_users(
self, client_factory, regular_user
):
"""Profile shows null godfather for users without one (e.g., seeded users).""" """Profile shows null godfather for users without one (e.g., seeded users)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client: async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/profile") response = await client.get("/api/profile")

View file

@ -1,12 +1,11 @@
"""Tests for profile field validation.""" """Tests for profile field validation."""
import pytest
from validation import ( from validation import (
validate_contact_email, validate_contact_email,
validate_telegram,
validate_signal,
validate_nostr_npub, validate_nostr_npub,
validate_profile_fields, validate_profile_fields,
validate_signal,
validate_telegram,
) )
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
assert validate_nostr_npub(self.VALID_NPUB) is None assert validate_nostr_npub(self.VALID_NPUB) is None
def test_wrong_prefix(self): def test_wrong_prefix(self):
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz") result = validate_nostr_npub(
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
)
assert result is not None assert result is not None
assert "npub" in result.lower() assert "npub" in result.lower()
def test_invalid_checksum(self): def test_invalid_checksum(self):
# Change last character to break checksum # Change last character to break checksum
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd") result = validate_nostr_npub(
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
)
assert result is not None assert result is not None
assert "checksum" in result.lower() assert "checksum" in result.lower()
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
assert result is not None assert result is not None
def test_not_starting_with_npub1(self): def test_not_starting_with_npub1(self):
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc") result = validate_nostr_npub(
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
)
assert result is not None assert result is not None
assert "npub1" in result assert "npub1" in result
@ -206,4 +211,3 @@ class TestValidateProfileFields:
nostr_npub="", nostr_npub="",
) )
assert errors == {} assert errors == {}

View file

@ -1,8 +1,9 @@
"""Validate shared constants match backend definitions.""" """Validate shared constants match backend definitions."""
import json import json
from pathlib import Path from pathlib import Path
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
def validate_shared_constants() -> None: def validate_shared_constants() -> None:
@ -29,35 +30,42 @@ def validate_shared_constants() -> None:
# Validate invite statuses # Validate invite statuses
expected_invite_statuses = {s.name: s.value for s in InviteStatus} expected_invite_statuses = {s.name: s.value for s in InviteStatus}
if constants.get("inviteStatuses") != expected_invite_statuses: if constants.get("inviteStatuses") != expected_invite_statuses:
got = constants.get("inviteStatuses")
raise ValueError( raise ValueError(
f"Invite status mismatch in shared/constants.json. " f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
) )
# Validate appointment statuses # Validate appointment statuses
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus} expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
if constants.get("appointmentStatuses") != expected_appointment_statuses: if constants.get("appointmentStatuses") != expected_appointment_statuses:
got = constants.get("appointmentStatuses")
raise ValueError( raise ValueError(
f"Appointment status mismatch in shared/constants.json. " f"Appointment status mismatch. "
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}" f"Expected: {expected_appointment_statuses}, Got: {got}"
) )
# Validate booking constants exist with required fields # Validate booking constants exist with required fields
booking = constants.get("booking", {}) booking = constants.get("booking", {})
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"] required_booking_fields = [
"slotDurationMinutes",
"maxAdvanceDays",
"minAdvanceDays",
"noteMaxLength",
]
for field in required_booking_fields: for field in required_booking_fields:
if field not in booking: if field not in booking:
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json") raise ValueError(f"Missing booking constant '{field}' in constants.json")
# Validate validation rules exist (structure check only) # Validate validation rules exist (structure check only)
validation = constants.get("validation", {}) validation = constants.get("validation", {})
required_fields = ["telegram", "signal", "nostrNpub"] required_fields = ["telegram", "signal", "nostrNpub"]
for field in required_fields: for field in required_fields:
if field not in validation: if field not in validation:
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json") raise ValueError(
f"Missing validation rules for '{field}' in constants.json"
)
if __name__ == "__main__": if __name__ == "__main__":
validate_shared_constants() validate_shared_constants()
print("✓ Shared constants are valid") print("✓ Shared constants are valid")

View file

@ -1,9 +1,10 @@
"""Validation utilities for user profile fields.""" """Validation utilities for user profile fields."""
import json import json
from pathlib import Path from pathlib import Path
from email_validator import validate_email, EmailNotValidError
from bech32 import bech32_decode from bech32 import bech32_decode
from email_validator import EmailNotValidError, validate_email
# Load validation rules from shared constants # Load validation rules from shared constants
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json" _constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
@ -143,4 +144,3 @@ def validate_profile_fields(
errors["nostr_npub"] = err errors["nostr_npub"] = err
return errors return errors