Add ruff linter/formatter for Python

- Add ruff as dev dependency
- Configure ruff in pyproject.toml with strict 88-char line limit
- Ignore B008 (FastAPI Depends pattern is standard)
- Allow longer lines in tests for readability
- Fix all lint issues in source files
- Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
counterweight 2025-12-21 21:54:26 +01:00
parent 69bc8413e0
commit 6c218130e9
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
31 changed files with 1234 additions and 876 deletions

View file

@ -1,4 +1,4 @@
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend
-include .env
export
@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone
check-constants:
@cd backend && uv run python validate_constants.py
lint-backend:
cd backend && uv run ruff check .
format-backend:
cd backend && uv run ruff format .
fix-backend:
cd backend && uv run ruff check --fix . && uv run ruff format .

View file

@ -1,5 +1,5 @@
import os
from datetime import datetime, timedelta, timezone
from datetime import UTC, datetime, timedelta
import bcrypt
from fastapi import Depends, HTTPException, Request, status
@ -8,7 +8,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User, Permission
from models import Permission, User
from schemas import UserResponse
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
).decode("utf-8")
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
def create_access_token(
data: dict[str, str],
expires_delta: timedelta | None = None,
) -> str:
to_encode: dict[str, str | datetime] = dict(data)
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + delta
to_encode["exp"] = expire
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded
@ -72,7 +76,7 @@ async def get_current_user(
raise credentials_exception
user_id = int(user_id_str)
except (JWTError, ValueError):
raise credentials_exception
raise credentials_exception from None
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
@ -83,13 +87,16 @@ async def get_current_user(
def require_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ALL of the required permissions.
Dependency factory that checks if user has ALL required permissions.
Usage:
@app.get("/api/counter")
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
async def get_counter(
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
@ -99,11 +106,13 @@ def require_permission(*required_permissions: Permission):
missing = [p for p in required_permissions if p not in user_permissions]
if missing:
missing_str = ", ".join(p.value for p in missing)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
detail=f"Missing required permissions: {missing_str}",
)
return user
return permission_checker

View file

@ -1,8 +1,11 @@
import os
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
DATABASE_URL = os.getenv(
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
)
engine = create_async_engine(DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False)
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
async def get_db():
async with async_session() as session:
yield session

View file

@ -1,4 +1,5 @@
"""Utilities for invite code generation and validation."""
import random
from pathlib import Path
@ -53,8 +54,4 @@ def is_valid_identifier_format(identifier: str) -> bool:
return False
# Check number is two digits
if len(number) != 2 or not number.isdigit():
return False
return True
return len(number) == 2 and number.isdigit()

View file

@ -1,19 +1,20 @@
"""FastAPI application entry point."""
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from database import engine, Base
from routes import sum as sum_routes
from routes import counter as counter_routes
from database import Base, engine
from routes import audit as audit_routes
from routes import profile as profile_routes
from routes import invites as invites_routes
from routes import auth as auth_routes
from routes import meta as meta_routes
from routes import availability as availability_routes
from routes import booking as booking_routes
from routes import counter as counter_routes
from routes import invites as invites_routes
from routes import meta as meta_routes
from routes import profile as profile_routes
from routes import sum as sum_routes
from validate_constants import validate_shared_constants

View file

@ -1,9 +1,24 @@
from datetime import datetime, date, time, timezone
from datetime import UTC, date, datetime, time
from enum import Enum as PyEnum
from typing import TypedDict
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import (
Column,
Date,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Table,
Time,
UniqueConstraint,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship
from database import Base
@ -14,6 +29,7 @@ class RoleConfig(TypedDict):
class Permission(str, PyEnum):
"""All available permissions in the system."""
# Counter permissions
VIEW_COUNTER = "view_counter"
INCREMENT_COUNTER = "increment_counter"
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
class InviteStatus(str, PyEnum):
"""Status of an invite."""
READY = "ready"
SPENT = "spent"
REVOKED = "revoked"
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
class AppointmentStatus(str, PyEnum):
"""Status of an appointment."""
BOOKED = "booked"
CANCELLED_BY_USER = "cancelled_by_user"
CANCELLED_BY_ADMIN = "cancelled_by_admin"
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
# Role definitions with their permissions
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
ROLE_ADMIN: {
"description": "Administrator with audit, invite, and appointment management access",
"description": "Administrator with audit/invite/appointment access",
"permissions": [
Permission.VIEW_AUDIT,
Permission.MANAGE_INVITES,
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
role_permissions = Table(
"role_permissions",
Base.metadata,
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
Column("permission", Enum(Permission), primary_key=True),
)
@ -97,8 +120,18 @@ role_permissions = Table(
user_roles = Table(
"user_roles",
Base.metadata,
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
Column(
"user_id",
Integer,
ForeignKey("users.id", ondelete="CASCADE"),
primary_key=True,
),
Column(
"role_id",
Integer,
ForeignKey("roles.id", ondelete="CASCADE"),
primary_key=True,
),
)
@ -118,23 +151,34 @@ class Role(Base):
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
"""Get all permissions for this role."""
result = await db.execute(
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
query = select(role_permissions.c.permission).where(
role_permissions.c.role_id == self.id
)
result = await db.execute(query)
return {row[0] for row in result.fetchall()}
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
async def set_permissions(
self, db: AsyncSession, permissions: list[Permission]
) -> None:
"""Set all permissions for this role (replaces existing)."""
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
delete_query = role_permissions.delete().where(
role_permissions.c.role_id == self.id
)
await db.execute(delete_query)
for perm in permissions:
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm))
insert_query = role_permissions.insert().values(
role_id=self.id, permission=perm
)
await db.execute(insert_query)
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
email: Mapped[str] = mapped_column(
String(255), unique=True, nullable=False, index=True
)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
# Contact details (all optional)
@ -192,12 +236,14 @@ class SumRecord(Base):
__tablename__ = "sum_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
a: Mapped[float] = mapped_column(Float, nullable=False)
b: Mapped[float] = mapped_column(Float, nullable=False)
result: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
@ -205,11 +251,13 @@ class CounterRecord(Base):
__tablename__ = "counter_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
@ -217,7 +265,9 @@ class Invite(Base):
__tablename__ = "invites"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
identifier: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False, index=True
)
status: Mapped[InviteStatus] = mapped_column(
Enum(InviteStatus), nullable=False, default=InviteStatus.READY
)
@ -244,14 +294,19 @@ class Invite(Base):
# Timestamps
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
spent_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
revoked_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
class Availability(Base):
"""Admin availability slots for booking."""
__tablename__ = "availability"
__table_args__ = (
UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
@ -262,34 +317,37 @@ class Availability(Base):
start_time: Mapped[time] = mapped_column(Time, nullable=False)
end_time: Mapped[time] = mapped_column(Time, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc)
default=lambda: datetime.now(UTC),
onupdate=lambda: datetime.now(UTC),
)
class Appointment(Base):
"""User appointment bookings."""
__tablename__ = "appointments"
__table_args__ = (
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
)
__table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
slot_start: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, index=True
)
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
note: Mapped[str | None] = mapped_column(String(144), nullable=True)
status: Mapped[AppointmentStatus] = mapped_column(
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
cancelled_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

View file

@ -20,6 +20,7 @@ dev = [
"httpx>=0.28.1",
"aiosqlite>=0.20.0",
"mypy>=1.13.0",
"ruff>=0.14.10",
]
[tool.mypy]
@ -30,3 +31,27 @@ check_untyped_defs = true
ignore_missing_imports = true
exclude = ["tests/"]
[tool.ruff]
line-length = 88
target-version = "py311"
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"UP", # pyupgrade
"SIM", # flake8-simplify
"RUF", # ruff-specific rules
]
ignore = [
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
]
[tool.ruff.format]
quote-style = "double"
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["E501"] # Allow longer lines in tests for readability

View file

@ -1,22 +1,23 @@
"""Audit routes for viewing action records."""
from typing import Callable, TypeVar
from collections.abc import Callable
from typing import TypeVar
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy import select, func, desc
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import User, SumRecord, CounterRecord, Permission
from models import CounterRecord, Permission, SumRecord, User
from schemas import (
CounterRecordResponse,
SumRecordResponse,
PaginatedCounterRecords,
PaginatedSumRecords,
SumRecordResponse,
)
router = APIRouter(prefix="/api/audit", tags=["audit"])
R = TypeVar("R", bound=BaseModel)

View file

@ -1,5 +1,6 @@
"""Authentication routes for register, login, logout, and current user."""
from datetime import datetime, timezone
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy import select
@ -9,18 +10,17 @@ from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
COOKIE_SECURE,
get_password_hash,
get_user_by_email,
authenticate_user,
build_user_response,
create_access_token,
get_current_user,
build_user_response,
get_password_hash,
get_user_by_email,
)
from database import get_db
from invite_utils import normalize_identifier
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus
from schemas import UserLogin, UserResponse, RegisterWithInvite
from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
from schemas import RegisterWithInvite, UserLogin, UserResponse
router = APIRouter(prefix="/api/auth", tags=["auth"])
@ -52,9 +52,8 @@ async def register(
"""Register a new user using an invite code."""
# Validate invite
normalized_identifier = normalize_identifier(user_data.invite_identifier)
result = await db.execute(
select(Invite).where(Invite.identifier == normalized_identifier)
)
query = select(Invite).where(Invite.identifier == normalized_identifier)
result = await db.execute(query)
invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage
@ -90,7 +89,7 @@ async def register(
# Mark invite as spent
invite.status = InviteStatus.SPENT
invite.used_by_id = user.id
invite.spent_at = datetime.now(timezone.utc)
invite.spent_at = datetime.now(UTC)
await db.commit()
await db.refresh(user)

View file

@ -1,28 +1,28 @@
"""Availability routes for admin to manage booking availability."""
from datetime import date, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, delete, and_
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import User, Availability, Permission
from models import Availability, Permission, User
from schemas import (
TimeSlot,
AvailabilityDay,
AvailabilityResponse,
SetAvailabilityRequest,
CopyAvailabilityRequest,
SetAvailabilityRequest,
TimeSlot,
)
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
def _get_date_range_bounds() -> tuple[date, date]:
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
"""Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
today = date.today()
min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
if d < min_date:
raise HTTPException(
status_code=400,
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}",
detail=f"Cannot set availability for past dates. "
f"Earliest allowed: {min_date}",
)
if d > max_date:
raise HTTPException(
status_code=400,
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}",
detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest allowed: {max_date}",
)
@ -70,15 +72,16 @@ async def get_availability(
for slot in slots:
if slot.date not in days_dict:
days_dict[slot.date] = []
days_dict[slot.date].append(TimeSlot(
start_time=slot.start_time,
end_time=slot.end_time,
))
days_dict[slot.date].append(
TimeSlot(
start_time=slot.start_time,
end_time=slot.end_time,
)
)
# Convert to response format
days = [
AvailabilityDay(date=d, slots=days_dict[d])
for d in sorted(days_dict.keys())
AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
]
return AvailabilityResponse(days=days)
@ -98,9 +101,12 @@ async def set_availability(
sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
for i in range(len(sorted_slots) - 1):
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
end = sorted_slots[i].end_time
start = sorted_slots[i + 1].start_time
raise HTTPException(
status_code=400,
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.",
detail=f"Time slots overlap: slot ending at {end} "
f"overlaps with slot starting at {start}",
)
# Validate each slot's end_time > start_time
@ -108,13 +114,12 @@ async def set_availability(
if slot.end_time <= slot.start_time:
raise HTTPException(
status_code=400,
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.",
detail=f"Invalid time slot: end time {slot.end_time} "
f"must be after start time {slot.start_time}",
)
# Delete existing availability for this date
await db.execute(
delete(Availability).where(Availability.date == request.date)
)
await db.execute(delete(Availability).where(Availability.date == request.date))
# Create new availability slots
for slot in request.slots:
@ -139,7 +144,7 @@ async def copy_availability(
"""Copy availability from one day to multiple target days."""
min_date, max_date = _get_date_range_bounds()
# Validate source date is in range (for consistency, though DB query would fail anyway)
# Validate source date is in range
_validate_date_in_range(request.source_date, min_date, max_date)
# Validate target dates
@ -169,9 +174,8 @@ async def copy_availability(
continue # Skip copying to self
# Delete existing availability for target date
await db.execute(
delete(Availability).where(Availability.date == target_date)
)
del_query = delete(Availability).where(Availability.date == target_date)
await db.execute(del_query)
# Copy slots
target_slots: list[TimeSlot] = []
@ -182,10 +186,12 @@ async def copy_availability(
end_time=source_slot.end_time,
)
db.add(new_availability)
target_slots.append(TimeSlot(
start_time=source_slot.start_time,
end_time=source_slot.end_time,
))
target_slots.append(
TimeSlot(
start_time=source_slot.start_time,
end_time=source_slot.end_time,
)
)
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
@ -197,4 +203,3 @@ async def copy_availability(
raise
return AvailabilityResponse(days=copied_days)

View file

@ -1,24 +1,24 @@
"""Booking routes for users to book appointments."""
from datetime import date, datetime, time, timedelta, timezone
from datetime import UTC, date, datetime, time, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select, and_, func
from sqlalchemy import and_, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from auth import require_permission
from database import get_db
from models import User, Availability, Appointment, AppointmentStatus, Permission
from models import Appointment, AppointmentStatus, Availability, Permission, User
from schemas import (
BookableSlot,
AvailableSlotsResponse,
BookingRequest,
AppointmentResponse,
AvailableSlotsResponse,
BookableSlot,
BookingRequest,
PaginatedAppointments,
)
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
router = APIRouter(prefix="/api/booking", tags=["booking"])
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
if d < min_date:
raise HTTPException(
status_code=400,
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}",
detail=f"Cannot book for today or past dates. "
f"Earliest bookable: {min_date}",
)
if d > max_date:
raise HTTPException(
status_code=400,
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}",
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
f"Latest bookable: {max_date}",
)
@ -92,8 +94,8 @@ def _expand_availability_to_slots(
for avail in availability_slots:
# Create datetime objects for start and end
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc)
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc)
current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
# Generate 15-minute slots
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
@ -128,12 +130,11 @@ async def get_available_slots(
all_slots = _expand_availability_to_slots(availability_slots, target_date)
# Get existing booked appointments for this date
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc)
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc)
day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
result = await db.execute(
select(Appointment.slot_start)
.where(
select(Appointment.slot_start).where(
and_(
Appointment.slot_start >= day_start,
Appointment.slot_start <= day_end,
@ -145,8 +146,7 @@ async def get_available_slots(
# Filter out already booked slots
available_slots = [
slot for slot in all_slots
if slot.start_time not in booked_starts
slot for slot in all_slots if slot.start_time not in booked_starts
]
return AvailableSlotsResponse(date=target_date, slots=available_slots)
@ -162,12 +162,13 @@ async def create_booking(
slot_date = request.slot_start.date()
_validate_booking_date(slot_date)
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES)
# Validate slot is on the correct minute boundary
valid_minutes = _get_valid_minute_boundaries()
if request.slot_start.minute not in valid_minutes:
raise HTTPException(
status_code=400,
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})",
detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
f"(valid minutes: {valid_minutes})",
)
if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
raise HTTPException(
@ -177,11 +178,11 @@ async def create_booking(
# Verify slot falls within availability
slot_start_time = request.slot_start.time()
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time()
slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
slot_end_time = slot_end_dt.time()
result = await db.execute(
select(Availability)
.where(
select(Availability).where(
and_(
Availability.date == slot_date,
Availability.start_time <= slot_start_time,
@ -192,9 +193,11 @@ async def create_booking(
matching_availability = result.scalar_one_or_none()
if not matching_availability:
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException(
status_code=400,
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.",
detail=f"Selected slot at {slot_str} UTC is not within "
f"any available time ranges for {slot_date}",
)
# Create the appointment
@ -216,8 +219,8 @@ async def create_booking(
await db.rollback()
raise HTTPException(
status_code=409,
detail="This slot has already been booked. Please select another slot.",
)
detail="This slot has already been booked. Select another slot.",
) from None
return _to_appointment_response(appointment, current_user.email)
@ -242,20 +245,19 @@ async def get_my_appointments(
)
appointments = result.scalars().all()
return [
_to_appointment_response(apt, current_user.email)
for apt in appointments
]
return [_to_appointment_response(apt, current_user.email) for apt in appointments]
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
@appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def cancel_my_appointment(
appointment_id: int,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
) -> AppointmentResponse:
"""Cancel one of the current user's appointments."""
# Get the appointment with explicit eager loading of user relationship
# Get the appointment with eager loading of user relationship
result = await db.execute(
select(Appointment)
.options(joinedload(Appointment.user))
@ -266,31 +268,35 @@ async def cancel_my_appointment(
if not appointment:
raise HTTPException(
status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
detail=f"Appointment {appointment_id} not found",
)
# Verify ownership
if appointment.user_id != current_user.id:
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment")
raise HTTPException(
status_code=403,
detail="Cannot cancel another user's appointment",
)
# Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
detail=f"Cannot cancel: status is '{appointment.status.value}'",
)
# Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
if appointment.slot_start <= datetime.now(UTC):
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException(
status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
)
# Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_USER
appointment.cancelled_at = datetime.now(timezone.utc)
appointment.cancelled_at = datetime.now(UTC)
await db.commit()
await db.refresh(appointment)
@ -302,7 +308,9 @@ async def cancel_my_appointment(
# Admin Appointments Endpoints
# =============================================================================
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"])
admin_appointments_router = APIRouter(
prefix="/api/admin/appointments", tags=["admin-appointments"]
)
@admin_appointments_router.get("", response_model=PaginatedAppointments)
@ -344,14 +352,18 @@ async def get_all_appointments(
)
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
@admin_appointments_router.post(
"/{appointment_id}/cancel", response_model=AppointmentResponse
)
async def admin_cancel_appointment(
appointment_id: int,
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)),
_current_user: User = Depends(
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
),
) -> AppointmentResponse:
"""Cancel any appointment (admin only)."""
# Get the appointment with explicit eager loading of user relationship
# Get the appointment with eager loading of user relationship
result = await db.execute(
select(Appointment)
.options(joinedload(Appointment.user))
@ -362,27 +374,28 @@ async def admin_cancel_appointment(
if not appointment:
raise HTTPException(
status_code=404,
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
detail=f"Appointment {appointment_id} not found",
)
# Check if already cancelled
if appointment.status != AppointmentStatus.BOOKED:
raise HTTPException(
status_code=400,
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
detail=f"Cannot cancel: status is '{appointment.status.value}'",
)
# Check if appointment is in the past
if appointment.slot_start <= datetime.now(timezone.utc):
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
if appointment.slot_start <= datetime.now(UTC):
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
raise HTTPException(
status_code=400,
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
detail=f"Cannot cancel appointment at {apt_time} UTC: "
"already started or in the past",
)
# Cancel the appointment
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
appointment.cancelled_at = datetime.now(timezone.utc)
appointment.cancelled_at = datetime.now(UTC)
await db.commit()
await db.refresh(appointment)

View file

@ -1,12 +1,12 @@
"""Counter routes."""
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import Counter, User, CounterRecord, Permission
from models import Counter, CounterRecord, Permission, User
router = APIRouter(prefix="/api/counter", tags=["counter"])

View file

@ -1,25 +1,29 @@
"""Invite routes for public check, user invites, and admin management."""
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy import select, func, desc
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import desc, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
from models import User, Invite, InviteStatus, Permission
from invite_utils import (
generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
)
from models import Invite, InviteStatus, Permission, User
from schemas import (
AdminUserResponse,
InviteCheckResponse,
InviteCreate,
InviteResponse,
UserInviteResponse,
PaginatedInviteRecords,
AdminUserResponse,
UserInviteResponse,
)
router = APIRouter(prefix="/api/invites", tags=["invites"])
admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
@ -54,9 +58,7 @@ async def check_invite(
if not is_valid_identifier_format(normalized):
return InviteCheckResponse(valid=False, error="Invalid invite code format")
result = await db.execute(
select(Invite).where(Invite.identifier == normalized)
)
result = await db.execute(select(Invite).where(Invite.identifier == normalized))
invite = result.scalar_one_or_none()
# Return same error for not found, spent, and revoked to avoid information leakage
@ -112,9 +114,7 @@ async def create_invite(
) -> InviteResponse:
"""Create a new invite for a specified godfather user."""
# Validate godfather exists
result = await db.execute(
select(User.id).where(User.id == data.godfather_id)
)
result = await db.execute(select(User.id).where(User.id == data.godfather_id))
godfather_id = result.scalar_one_or_none()
if not godfather_id:
raise HTTPException(
@ -141,8 +141,8 @@ async def create_invite(
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to generate unique invite code. Please try again.",
)
detail="Failed to generate unique invite code. Try again.",
) from None
if invite is None:
raise HTTPException(
@ -156,7 +156,9 @@ async def create_invite(
async def list_all_invites(
page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100),
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
status_filter: str | None = Query(
None, alias="status", description="Filter by status: ready, spent, revoked"
),
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
@ -175,8 +177,9 @@ async def list_all_invites(
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
)
detail=f"Invalid status: {status_filter}. "
"Must be ready, spent, or revoked",
) from None
if godfather_id:
query = query.where(Invite.godfather_id == godfather_id)
@ -224,11 +227,12 @@ async def revoke_invite(
if invite.status != InviteStatus.READY:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
detail=f"Cannot revoke invite with status '{invite.status.value}'. "
"Only READY invites can be revoked.",
)
invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(timezone.utc)
invite.revoked_at = datetime.now(UTC)
await db.commit()
await db.refresh(invite)

View file

@ -1,7 +1,8 @@
"""Meta endpoints for shared constants."""
from fastapi import APIRouter
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
from schemas import ConstantsResponse
router = APIRouter(prefix="/api/meta", tags=["meta"])
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
roles=[ROLE_ADMIN, ROLE_REGULAR],
invite_statuses=[s.value for s in InviteStatus],
)

View file

@ -1,15 +1,15 @@
"""Profile routes for user contact details."""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import get_current_user
from database import get_db
from models import User, ROLE_REGULAR
from models import ROLE_REGULAR, User
from schemas import ProfileResponse, ProfileUpdate
from validation import validate_profile_fields
router = APIRouter(prefix="/api/profile", tags=["profile"])
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
"""Get the email of a godfather user by ID."""
if not godfather_id:
return None
result = await db.execute(
select(User.email).where(User.id == godfather_id)
)
result = await db.execute(select(User.email).where(User.id == godfather_id))
return result.scalar_one_or_none()

View file

@ -1,13 +1,13 @@
"""Sum calculation routes."""
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import User, SumRecord, Permission
from models import Permission, SumRecord, User
from schemas import SumRequest, SumResponse
router = APIRouter(prefix="/api/sum", tags=["sum"])

View file

@ -1,5 +1,6 @@
"""Pydantic schemas for API request/response models."""
from datetime import datetime, date, time
from datetime import date, datetime, time
from typing import Generic, TypeVar
from pydantic import BaseModel, EmailStr, field_validator
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
class UserCredentials(BaseModel):
"""Base model for user email/password."""
email: EmailStr
password: str
@ -19,6 +21,7 @@ UserLogin = UserCredentials
class UserResponse(BaseModel):
"""Response model for authenticated user info."""
id: int
email: str
roles: list[str]
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
class RegisterWithInvite(BaseModel):
"""Request model for registration with invite."""
email: EmailStr
password: str
invite_identifier: str
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
class SumRequest(BaseModel):
"""Request model for sum calculation."""
a: float
b: float
class SumResponse(BaseModel):
"""Response model for sum calculation."""
a: float
b: float
result: float
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
class CounterRecordResponse(BaseModel):
"""Response model for a counter audit record."""
id: int
user_email: str
value_before: int
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
class SumRecordResponse(BaseModel):
"""Response model for a sum audit record."""
id: int
user_email: str
a: float
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
class PaginatedResponse(BaseModel, Generic[RecordT]):
"""Generic paginated response wrapper."""
records: list[RecordT]
total: int
page: int
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
class ProfileResponse(BaseModel):
"""Response model for profile data."""
contact_email: str | None
telegram: str | None
signal: str | None
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
class ProfileUpdate(BaseModel):
"""Request model for updating profile."""
contact_email: str | None = None
telegram: str | None = None
signal: str | None = None
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
class InviteCheckResponse(BaseModel):
"""Response for invite check endpoint."""
valid: bool
status: str | None = None
error: str | None = None
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
class InviteCreate(BaseModel):
"""Request model for creating an invite."""
godfather_id: int
class InviteResponse(BaseModel):
"""Response model for invite data (admin view)."""
id: int
identifier: str
godfather_id: int
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
class UserInviteResponse(BaseModel):
"""Response model for a user's invite (simpler than admin view)."""
id: int
identifier: str
status: str
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
class AdminUserResponse(BaseModel):
"""Minimal user info for admin dropdowns."""
id: int
email: str
@ -146,8 +162,10 @@ class AdminUserResponse(BaseModel):
# Availability Schemas
# =============================================================================
class TimeSlot(BaseModel):
"""A single time slot (start and end time)."""
start_time: time
end_time: time
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
class AvailabilityDay(BaseModel):
"""Availability for a single day."""
date: date
slots: list[TimeSlot]
class AvailabilityResponse(BaseModel):
"""Response model for availability query."""
days: list[AvailabilityDay]
class SetAvailabilityRequest(BaseModel):
"""Request to set availability for a specific date."""
date: date
slots: list[TimeSlot]
class CopyAvailabilityRequest(BaseModel):
"""Request to copy availability from one day to others."""
source_date: date
target_dates: list[date]
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
# Booking Schemas
# =============================================================================
class BookableSlot(BaseModel):
"""A bookable 15-minute slot."""
start_time: datetime
end_time: datetime
class AvailableSlotsResponse(BaseModel):
"""Response for available slots on a given date."""
date: date
slots: list[BookableSlot]
class BookingRequest(BaseModel):
"""Request to book an appointment."""
slot_start: datetime
note: str | None = None
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
class AppointmentResponse(BaseModel):
"""Response model for an appointment."""
id: int
user_id: int
user_email: str
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
# Meta/Constants Schemas
# =============================================================================
class ConstantsResponse(BaseModel):
"""Response model for shared constants."""
permissions: list[str]
roles: list[str]
invite_statuses: list[str]

View file

@ -1,12 +1,21 @@
"""Seed the database with roles, permissions, and dev users."""
import asyncio
import os
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, async_session, Base
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash
from database import Base, async_session, engine
from models import (
ROLE_ADMIN,
ROLE_DEFINITIONS,
ROLE_REGULAR,
Permission,
Role,
User,
)
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
@ -14,7 +23,9 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
async def upsert_role(
db: AsyncSession, name: str, description: str, permissions: list[Permission]
) -> Role:
"""Create or update a role with the given permissions."""
result = await db.execute(select(Role).where(Role.name == name))
role = result.scalar_one_or_none()
@ -35,7 +46,9 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
return role
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
async def upsert_user(
db: AsyncSession, email: str, password: str, role_names: list[str]
) -> User:
"""Create or update a user with the given credentials and roles."""
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()

View file

@ -1,4 +1,5 @@
"""Load shared constants from shared/constants.json."""
import json
from pathlib import Path
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]

View file

@ -7,17 +7,17 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from auth import get_password_hash
from database import Base, get_db
from main import app
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash
from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
from tests.helpers import unique_email
TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
)
@ -91,7 +91,9 @@ async def create_user_with_roles(
result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none()
if not role:
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
raise ValueError(
f"Role '{role_name}' not found. Did you run setup_roles()?"
)
roles.append(role)
user = User(

View file

@ -3,8 +3,8 @@ import uuid
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models import User, Invite, InviteStatus
from invite_utils import generate_invite_identifier
from models import Invite, InviteStatus, User
def unique_email(prefix: str = "test") -> str:
@ -67,7 +67,9 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
godfather = result.scalar_one_or_none()
if not godfather:
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles().")
raise ValueError(
f"Godfather user with email '{godfather_email}' not found. "
"Create the user first using create_user_with_roles()."
)
return await create_invite_for_godfather(db, godfather.id)

View file

@ -3,12 +3,13 @@
Note: Registration now requires an invite code. Tests that need to register
users will create invites first via the helper function.
"""
import pytest
from auth import COOKIE_NAME
from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Registration tests (with invite)
@ -19,7 +20,9 @@ async def test_register_success(client_factory):
# Create godfather user and invite
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
@ -49,7 +52,9 @@ async def test_register_duplicate_email(client_factory):
# Create godfather and two invites
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
@ -80,7 +85,9 @@ async def test_register_duplicate_email(client_factory):
async def test_register_invalid_email(client_factory):
"""Cannot register with invalid email format."""
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
@ -138,7 +145,9 @@ async def test_login_success(client_factory):
email = unique_email("login")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
@ -167,7 +176,9 @@ async def test_login_wrong_password(client_factory):
email = unique_email("wrongpass")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
@ -221,7 +232,9 @@ async def test_get_me_success(client_factory):
email = unique_email("me")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
@pytest.mark.asyncio
async def test_get_me_invalid_cookie(client_factory):
"""Cannot get current user with invalid cookie."""
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
async with client_factory.create(
cookies={COOKIE_NAME: "invalidtoken123"}
) as authed:
response = await authed.get("/api/auth/me")
assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials"
@ -277,7 +292,9 @@ async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(
@ -303,7 +320,9 @@ async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
@ -335,7 +354,9 @@ async def test_multiple_users_isolated(client_factory):
email2 = unique_email("user2")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
@ -377,7 +398,9 @@ async def test_password_is_hashed(client_factory):
email = unique_email("hashtest")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
@ -401,7 +424,9 @@ async def test_case_sensitive_password(client_factory):
email = unique_email("casetest")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
await client_factory.post(
@ -426,7 +451,9 @@ async def test_logout_success(client_factory):
email = unique_email("logout")
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg_response = await client_factory.post(

View file

@ -3,7 +3,9 @@ Availability API Tests
Tests for the admin availability management endpoints.
"""
from datetime import date, time, timedelta
from datetime import date, timedelta
import pytest
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
# Permission Tests
# =============================================================================
class TestAvailabilityPermissions:
"""Test that only admins can access availability endpoints."""
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 200
@pytest.mark.asyncio
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user):
async def test_regular_user_cannot_get_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get(
"/api/admin/availability",
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user):
async def test_regular_user_cannot_set_availability(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put(
"/api/admin/availability",
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
# Set Availability Tests
# =============================================================================
class TestSetAvailability:
"""Test setting availability for a date."""
@ -128,7 +136,9 @@ class TestSetAvailability:
assert len(data["slots"]) == 2
@pytest.mark.asyncio
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user):
async def test_set_empty_slots_clears_availability(
self, client_factory, admin_user
):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
# First set some availability
await client.put(
@ -162,7 +172,7 @@ class TestSetAvailability:
)
# Replace with different slots
response = await client.put(
await client.put(
"/api/admin/availability",
json={
"date": str(tomorrow()),
@ -186,6 +196,7 @@ class TestSetAvailability:
# Validation Tests
# =============================================================================
class TestAvailabilityValidation:
"""Test validation rules for availability."""
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
# Get Availability Tests
# =============================================================================
class TestGetAvailability:
"""Test retrieving availability."""
@ -360,6 +372,7 @@ class TestGetAvailability:
# Copy Availability Tests
# =============================================================================
class TestCopyAvailability:
"""Test copying availability from one day to others."""

View file

@ -3,7 +3,9 @@ Booking API Tests
Tests for the user booking endpoints.
"""
from datetime import date, datetime, timedelta, timezone
from datetime import UTC, date, datetime, timedelta
import pytest
from models import Appointment, AppointmentStatus
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
# Permission Tests
# =============================================================================
class TestBookingPermissions:
"""Test that only regular users can book appointments."""
@pytest.mark.asyncio
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user):
async def test_regular_user_can_get_slots(
self, client_factory, regular_user, admin_user
):
"""Regular user can get available slots."""
# First, admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -39,12 +44,16 @@ class TestBookingPermissions:
# Regular user gets slots
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user):
async def test_regular_user_can_book(
self, client_factory, regular_user, admin_user
):
"""Regular user can book an appointment."""
# Admin sets up availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -69,7 +78,9 @@ class TestBookingPermissions:
async def test_admin_cannot_get_slots(self, client_factory, admin_user):
"""Admin cannot access booking slots endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 403
@ -96,7 +107,9 @@ class TestBookingPermissions:
@pytest.mark.asyncio
async def test_unauthenticated_cannot_get_slots(self, client):
"""Unauthenticated user cannot get slots."""
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 401
@pytest.mark.asyncio
@ -113,6 +126,7 @@ class TestBookingPermissions:
# Get Slots Tests
# =============================================================================
class TestGetSlots:
"""Test getting available booking slots."""
@ -120,7 +134,9 @@ class TestGetSlots:
async def test_get_slots_no_availability(self, client_factory, regular_user):
"""Returns empty slots when no availability set."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200
data = response.json()
@ -128,7 +144,9 @@ class TestGetSlots:
assert data["slots"] == []
@pytest.mark.asyncio
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user):
async def test_get_slots_expands_to_15min(
self, client_factory, regular_user, admin_user
):
"""Availability is expanded into 15-minute slots."""
# Admin sets 1-hour availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -142,7 +160,9 @@ class TestGetSlots:
# User gets slots - should be 4 x 15-minute slots
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200
data = response.json()
@ -156,7 +176,9 @@ class TestGetSlots:
assert "10:00:00" in data["slots"][3]["end_time"]
@pytest.mark.asyncio
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user):
async def test_get_slots_excludes_booked(
self, client_factory, regular_user, admin_user
):
"""Already booked slots are excluded from available slots."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -176,7 +198,9 @@ class TestGetSlots:
)
# Get slots again - should have 3 left
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
response = await client.get(
"/api/booking/slots", params={"date": str(tomorrow())}
)
assert response.status_code == 200
data = response.json()
@ -189,6 +213,7 @@ class TestGetSlots:
# Booking Tests
# =============================================================================
class TestCreateBooking:
"""Test creating bookings."""
@ -248,7 +273,9 @@ class TestCreateBooking:
assert data["note"] is None
@pytest.mark.asyncio
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user):
async def test_cannot_double_book_slot(
self, client_factory, regular_user, admin_user, alt_regular_user
):
"""Cannot book a slot that's already booked."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -279,7 +306,9 @@ class TestCreateBooking:
assert "already been booked" in response.json()["detail"]
@pytest.mark.asyncio
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user):
async def test_cannot_book_outside_availability(
self, client_factory, regular_user, admin_user
):
"""Cannot book a slot outside of availability."""
# Admin sets availability for morning only
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -306,6 +335,7 @@ class TestCreateBooking:
# Date Validation Tests
# =============================================================================
class TestBookingDateValidation:
"""Test date validation for bookings."""
@ -319,7 +349,10 @@ class TestBookingDateValidation:
)
assert response.status_code == 400
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower()
assert (
"past" in response.json()["detail"].lower()
or "today" in response.json()["detail"].lower()
)
@pytest.mark.asyncio
async def test_cannot_book_past_date(self, client_factory, regular_user):
@ -350,7 +383,9 @@ class TestBookingDateValidation:
async def test_cannot_get_slots_today(self, client_factory, regular_user):
"""Cannot get slots for today."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(date.today())})
response = await client.get(
"/api/booking/slots", params={"date": str(date.today())}
)
assert response.status_code == 400
@ -359,7 +394,9 @@ class TestBookingDateValidation:
"""Cannot get slots for past date."""
yesterday = date.today() - timedelta(days=1)
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/booking/slots", params={"date": str(yesterday)})
response = await client.get(
"/api/booking/slots", params={"date": str(yesterday)}
)
assert response.status_code == 400
@ -368,11 +405,14 @@ class TestBookingDateValidation:
# Time Validation Tests
# =============================================================================
class TestBookingTimeValidation:
"""Test time validation for bookings."""
@pytest.mark.asyncio
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user):
async def test_slot_must_be_15min_boundary(
self, client_factory, regular_user, admin_user
):
"""Slot start time must be on 15-minute boundary."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
# Note Validation Tests
# =============================================================================
class TestBookingNoteValidation:
"""Test note validation for bookings."""
@ -426,7 +467,9 @@ class TestBookingNoteValidation:
assert response.status_code == 422
@pytest.mark.asyncio
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user):
async def test_note_exactly_144_chars(
self, client_factory, regular_user, admin_user
):
"""Note of exactly 144 characters is allowed."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
# User Appointments Tests
# =============================================================================
class TestUserAppointments:
"""Test user appointments endpoints."""
@ -467,7 +511,9 @@ class TestUserAppointments:
assert response.json() == []
@pytest.mark.asyncio
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user):
async def test_get_my_appointments_with_bookings(
self, client_factory, regular_user, admin_user
):
"""Returns user's appointments."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -502,7 +548,9 @@ class TestUserAppointments:
assert "Second" in notes
@pytest.mark.asyncio
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user):
async def test_admin_cannot_view_user_appointments(
self, client_factory, admin_user
):
"""Admin cannot access user appointments endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/appointments")
@ -520,7 +568,9 @@ class TestCancelAppointment:
"""Test cancelling appointments."""
@pytest.mark.asyncio
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user):
async def test_cancel_own_appointment(
self, client_factory, regular_user, admin_user
):
"""User can cancel their own appointment."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -549,7 +599,9 @@ class TestCancelAppointment:
assert data["cancelled_at"] is not None
@pytest.mark.asyncio
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user):
async def test_cannot_cancel_others_appointment(
self, client_factory, regular_user, alt_regular_user, admin_user
):
"""User cannot cancel another user's appointment."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -577,7 +629,9 @@ class TestCancelAppointment:
assert "another user" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user):
async def test_cannot_cancel_nonexistent_appointment(
self, client_factory, regular_user
):
"""Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/appointments/99999/cancel")
@ -585,7 +639,9 @@ class TestCancelAppointment:
assert response.status_code == 404
@pytest.mark.asyncio
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
async def test_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Cannot cancel an already cancelled appointment."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -613,7 +669,9 @@ class TestCancelAppointment:
assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user):
async def test_admin_cannot_use_user_cancel_endpoint(
self, client_factory, admin_user
):
"""Admin cannot use user cancel endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/appointments/1/cancel")
@ -621,7 +679,9 @@ class TestCancelAppointment:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user):
async def test_cancelled_slot_becomes_available(
self, client_factory, regular_user, admin_user
):
"""After cancelling, the slot becomes available again."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -663,7 +723,7 @@ class TestCancelAppointment:
"""User cannot cancel a past appointment."""
# Create a past appointment directly in DB
async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment(
user_id=regular_user["user"]["id"],
slot_start=past_time,
@ -687,11 +747,14 @@ class TestCancelAppointment:
# Admin Appointments Tests
# =============================================================================
class TestAdminViewAppointments:
"""Test admin viewing all appointments."""
@pytest.mark.asyncio
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user):
async def test_admin_can_view_all_appointments(
self, client_factory, regular_user, admin_user
):
"""Admin can view all appointments."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -725,7 +788,9 @@ class TestAdminViewAppointments:
assert any(apt["note"] == "Test" for apt in data["records"])
@pytest.mark.asyncio
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user):
async def test_regular_user_cannot_view_all_appointments(
self, client_factory, regular_user
):
"""Regular user cannot access admin appointments endpoint."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/admin/appointments")
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
"""Test admin cancelling appointments."""
@pytest.mark.asyncio
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user):
async def test_admin_can_cancel_any_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin can cancel any user's appointment."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -765,7 +832,9 @@ class TestAdminCancelAppointment:
# Admin cancels
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 200
data = response.json()
@ -773,7 +842,9 @@ class TestAdminCancelAppointment:
assert data["cancelled_at"] is not None
@pytest.mark.asyncio
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user):
async def test_regular_user_cannot_use_admin_cancel(
self, client_factory, regular_user, admin_user
):
"""Regular user cannot use admin cancel endpoint."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -799,7 +870,9 @@ class TestAdminCancelAppointment:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user):
async def test_admin_cancel_nonexistent_appointment(
self, client_factory, admin_user
):
"""Returns 404 for non-existent appointment."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/admin/appointments/99999/cancel")
@ -807,7 +880,9 @@ class TestAdminCancelAppointment:
assert response.status_code == 404
@pytest.mark.asyncio
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
async def test_admin_cannot_cancel_already_cancelled(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel an already cancelled appointment."""
# Admin sets availability
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
@ -832,17 +907,21 @@ class TestAdminCancelAppointment:
# Admin tries to cancel again
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400
assert "cancelled_by_user" in response.json()["detail"]
@pytest.mark.asyncio
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user):
async def test_admin_cannot_cancel_past_appointment(
self, client_factory, regular_user, admin_user
):
"""Admin cannot cancel a past appointment."""
# Create a past appointment directly in DB
async with client_factory.get_db_session() as db:
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
past_time = datetime.now(UTC) - timedelta(hours=1)
appointment = Appointment(
user_id=regular_user["user"]["id"],
slot_start=past_time,
@ -856,8 +935,9 @@ class TestAdminCancelAppointment:
# Admin tries to cancel
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
response = await admin_client.post(
f"/api/admin/appointments/{apt_id}/cancel"
)
assert response.status_code == 400
assert "past" in response.json()["detail"].lower()

View file

@ -2,12 +2,13 @@
Note: Registration now requires an invite code.
"""
import pytest
from auth import COOKIE_NAME
from models import ROLE_REGULAR
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
# Protected endpoint tests - without auth
@ -41,7 +42,9 @@ async def test_increment_counter_invalid_cookie(client_factory):
@pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
@ -64,7 +67,9 @@ async def test_get_counter_authenticated(client_factory):
@pytest.mark.asyncio
async def test_increment_counter(client_factory):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
@ -91,7 +96,9 @@ async def test_increment_counter(client_factory):
@pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
@ -120,7 +127,9 @@ async def test_increment_counter_multiple(client_factory):
@pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
@ -149,7 +158,9 @@ async def test_get_counter_after_increment(client_factory):
async def test_counter_shared_between_users(client_factory):
# Create godfather and invites for two users
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)

View file

@ -1,22 +1,23 @@
"""Tests for invite functionality."""
import pytest
from sqlalchemy import select
from invite_utils import (
generate_invite_identifier,
normalize_identifier,
is_valid_identifier_format,
BIP39_WORDS,
generate_invite_identifier,
is_valid_identifier_format,
normalize_identifier,
)
from models import Invite, InviteStatus, User, ROLE_REGULAR
from tests.helpers import unique_email
from models import ROLE_REGULAR, Invite, InviteStatus, User
from tests.conftest import create_user_with_roles
from tests.helpers import unique_email
# ============================================================================
# Invite Utils Tests
# ============================================================================
def test_bip39_words_loaded():
"""BIP39 word list should have exactly 2048 words."""
assert len(BIP39_WORDS) == 2048
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
# Invite Model Tests
# ============================================================================
@pytest.mark.asyncio
async def test_create_invite(client_factory):
"""Can create an invite with godfather."""
@ -173,7 +175,7 @@ async def test_invite_unique_identifier(client_factory):
@pytest.mark.asyncio
async def test_invite_status_transitions(client_factory):
"""Invite status can be changed."""
from datetime import datetime, UTC
from datetime import UTC, datetime
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
@ -206,7 +208,7 @@ async def test_invite_status_transitions(client_factory):
@pytest.mark.asyncio
async def test_invite_revoke(client_factory):
"""Invite can be revoked."""
from datetime import datetime, UTC
from datetime import UTC, datetime
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
# User Godfather Tests
# ============================================================================
@pytest.mark.asyncio
async def test_user_godfather_relationship(client_factory):
"""User can have a godfather."""
@ -254,9 +257,7 @@ async def test_user_godfather_relationship(client_factory):
await db.commit()
# Query user fresh
result = await db.execute(
select(User).where(User.id == user.id)
)
result = await db.execute(select(User).where(User.id == user.id))
loaded_user = result.scalar_one()
assert loaded_user.godfather_id == godfather.id
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
# Admin Create Invite API Tests (Phase 2)
# ============================================================================
@pytest.mark.asyncio
async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
"""Admin can create an invite for a regular user."""
@ -387,9 +389,7 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
# Query from DB
async with client_factory.get_db_session() as db:
result = await db.execute(
select(Invite).where(Invite.id == invite_id)
)
result = await db.execute(select(Invite).where(Invite.id == invite_id))
invite = result.scalar_one()
assert invite.identifier == data["identifier"]
@ -398,7 +398,9 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
@pytest.mark.asyncio
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
async def test_create_invite_retries_on_collision(
client_factory, admin_user, regular_user
):
"""Create invite retries with new identifier on collision."""
from unittest.mock import patch
@ -419,6 +421,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
# Mock generator to first return the same identifier (collision), then a new one
call_count = 0
def mock_generator():
nonlocal call_count
call_count += 1
@ -426,7 +429,9 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
return identifier1 # Will collide
return f"unique-word-{call_count:02d}" # Won't collide
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator):
with patch(
"routes.invites.generate_invite_identifier", side_effect=mock_generator
):
response2 = await client.post(
"/api/admin/invites",
json={"godfather_id": godfather.id},
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
# Invite Check API Tests (Phase 3)
# ============================================================================
@pytest.mark.asyncio
async def test_check_invite_valid(client_factory, admin_user, regular_user):
"""Check endpoint returns valid=True for READY invite."""
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
@pytest.mark.asyncio
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user):
async def test_check_invite_spent_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
# Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -547,9 +555,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
@pytest.mark.asyncio
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user):
async def test_check_invite_revoked_returns_not_found(
client_factory, admin_user, regular_user
):
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
from datetime import datetime, UTC
from datetime import UTC, datetime
# Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
# Register with Invite Tests (Phase 3)
# ============================================================================
@pytest.mark.asyncio
async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
"""Can register with valid invite code."""
@ -681,9 +692,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
# Check invite status
async with client_factory.get_db_session() as db:
result = await db.execute(
select(Invite).where(Invite.id == invite_id)
)
result = await db.execute(select(Invite).where(Invite.id == invite_id))
invite = result.scalar_one()
assert invite.status == InviteStatus.SPENT
@ -723,9 +732,7 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
# Check user's godfather
async with client_factory.get_db_session() as db:
result = await db.execute(
select(User).where(User.email == new_email)
)
result = await db.execute(select(User).where(User.email == new_email))
new_user = result.scalar_one()
assert new_user.godfather_id == godfather_id
@ -794,7 +801,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
@pytest.mark.asyncio
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
"""Cannot register with revoked invite."""
from datetime import datetime, UTC
from datetime import UTC, datetime
# Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -814,9 +821,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
# Revoke invite directly in DB
async with client_factory.get_db_session() as db:
result = await db.execute(
select(Invite).where(Invite.id == invite_id)
)
result = await db.execute(select(Invite).where(Invite.id == invite_id))
invite = result.scalar_one()
invite.status = InviteStatus.REVOKED
invite.revoked_at = datetime.now(UTC)
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
# User Invites API Tests (Phase 4)
# ============================================================================
@pytest.mark.asyncio
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
"""Regular user can list their own invites."""
@ -941,7 +947,9 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
@pytest.mark.asyncio
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user):
async def test_spent_invite_shows_used_by_email(
client_factory, admin_user, regular_user
):
"""Spent invite shows who used it."""
# Create invite for regular user
async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
# Admin Invite Management Tests (Phase 5)
# ============================================================================
@pytest.mark.asyncio
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
"""Admin can list all invites."""
@ -1153,4 +1162,3 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
# Revoke
response = await client.post("/api/admin/invites/1/revoke")
assert response.status_code == 403

View file

@ -7,15 +7,16 @@ These tests verify that:
3. Unauthenticated users are denied access (401)
4. The permission system cannot be bypassed
"""
import pytest
from models import Permission
# =============================================================================
# Role Assignment Tests
# =============================================================================
class TestRoleAssignment:
"""Test that roles are properly assigned and returned."""
@ -40,7 +41,9 @@ class TestRoleAssignment:
assert "regular" not in data["roles"]
@pytest.mark.asyncio
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user):
async def test_regular_user_has_correct_permissions(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me")
@ -72,7 +75,9 @@ class TestRoleAssignment:
assert Permission.USE_SUM.value not in permissions
@pytest.mark.asyncio
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles):
async def test_user_with_no_roles_has_no_permissions(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/auth/me")
@ -85,6 +90,7 @@ class TestRoleAssignment:
# Counter Endpoint Access Tests
# =============================================================================
class TestCounterAccess:
"""Test access control for counter endpoints."""
@ -97,7 +103,9 @@ class TestCounterAccess:
assert "value" in response.json()
@pytest.mark.asyncio
async def test_regular_user_can_increment_counter(self, client_factory, regular_user):
async def test_regular_user_can_increment_counter(
self, client_factory, regular_user
):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/counter/increment")
@ -122,7 +130,9 @@ class TestCounterAccess:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles):
async def test_user_without_roles_cannot_view_counter(
self, client_factory, user_no_roles
):
"""Users with no roles should be forbidden."""
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/counter")
@ -146,6 +156,7 @@ class TestCounterAccess:
# Sum Endpoint Access Tests
# =============================================================================
class TestSumAccess:
"""Test access control for sum endpoint."""
@ -173,7 +184,9 @@ class TestSumAccess:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles):
async def test_user_without_roles_cannot_use_sum(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.post(
"/api/sum",
@ -195,6 +208,7 @@ class TestSumAccess:
# Audit Endpoint Access Tests
# =============================================================================
class TestAuditAccess:
"""Test access control for audit endpoints."""
@ -219,7 +233,9 @@ class TestAuditAccess:
assert "total" in data
@pytest.mark.asyncio
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user):
async def test_regular_user_cannot_view_counter_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
@ -228,7 +244,9 @@ class TestAuditAccess:
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user):
async def test_regular_user_cannot_view_sum_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
@ -236,7 +254,9 @@ class TestAuditAccess:
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles):
async def test_user_without_roles_cannot_view_audit(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/audit/counter")
@ -257,6 +277,7 @@ class TestAuditAccess:
# Offensive Security Tests - Bypass Attempts
# =============================================================================
class TestSecurityBypassAttempts:
"""
Offensive tests that attempt to bypass security controls.
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
"""
@pytest.mark.asyncio
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user):
async def test_cannot_access_audit_with_forged_role_claim(
self, client_factory, regular_user
):
"""
Attempt to access audit by somehow claiming admin role.
The server should verify roles from DB, not trust client claims.
@ -287,14 +310,18 @@ class TestSecurityBypassAttempts:
assert response.status_code == 401
@pytest.mark.asyncio
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user):
async def test_cannot_access_with_tampered_token(
self, client_factory, regular_user
):
"""Test that tokens signed with wrong key are rejected."""
# Take a valid token and modify it
original_token = regular_user["cookies"].get("auth_token", "")
if original_token:
tampered_token = original_token[:-5] + "XXXXX"
async with client_factory.create(cookies={"auth_token": tampered_token}) as client:
async with client_factory.create(
cookies={"auth_token": tampered_token}
) as client:
response = await client.get("/api/counter")
assert response.status_code == 401
@ -305,12 +332,14 @@ class TestSecurityBypassAttempts:
Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default.
"""
from tests.helpers import unique_email, create_invite_for_godfather
from tests.conftest import create_user_with_roles
from models import ROLE_REGULAR
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
response = await client_factory.post(
@ -341,15 +370,17 @@ class TestSecurityBypassAttempts:
If a user is deleted, their token should no longer work.
This tests that tokens are validated against current DB state.
"""
from tests.helpers import unique_email
from sqlalchemy import delete
from models import User
from tests.helpers import unique_email
email = unique_email("deleted")
# Create and login user
async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles
user = await create_user_with_roles(db, email, "password123", ["regular"])
user_id = user.id
@ -376,15 +407,17 @@ class TestSecurityBypassAttempts:
If a user's role is changed, the change should be reflected
in subsequent requests (no stale permission cache).
"""
from tests.helpers import unique_email
from sqlalchemy import select
from models import User, Role
from models import Role, User
from tests.helpers import unique_email
email = unique_email("rolechange")
# Create regular user
async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles
await create_user_with_roles(db, email, "password123", ["regular"])
login_response = await client_factory.post(
@ -406,10 +439,7 @@ class TestSecurityBypassAttempts:
result = await db.execute(select(Role).where(Role.name == "admin"))
admin_role = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "regular"))
regular_role = result.scalar_one()
user.roles = [admin_role] # Remove regular, add admin
user.roles = [admin_role] # Replace roles with admin only
await db.commit()
# Now should have audit access but not counter access
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
# Audit Record Tests
# =============================================================================
class TestAuditRecords:
"""Test that actions are properly recorded in audit logs."""
@ -466,7 +497,7 @@ class TestAuditRecords:
# Find record with our values
records = data["records"]
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30]
matching = [
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
]
assert len(matching) >= 1

View file

@ -1,9 +1,9 @@
"""Tests for user profile and contact details."""
import pytest
from sqlalchemy import select
from models import User, ROLE_REGULAR
from auth import get_password_hash
from models import User
from tests.helpers import unique_email
# Valid npub for testing (32 zero bytes encoded as bech32)
@ -328,7 +328,9 @@ class TestUpdateProfileEndpoint:
assert "field_errors" in data["detail"]
assert "nostr_npub" in data["detail"]["field_errors"]
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user):
async def test_multiple_invalid_fields_returns_all_errors(
self, client_factory, regular_user
):
"""Multiple invalid fields return all errors."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.put(
@ -344,7 +346,9 @@ class TestUpdateProfileEndpoint:
assert "contact_email" in data["detail"]["field_errors"]
assert "telegram" in data["detail"]["field_errors"]
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user):
async def test_partial_update_preserves_other_fields(
self, client_factory, regular_user
):
"""Updating one field doesn't affect others (they get set to the request values)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
# Set initial values
@ -402,11 +406,14 @@ class TestProfilePrivacy:
class TestProfileGodfather:
"""Tests for godfather information in profile."""
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user):
async def test_profile_shows_godfather_email(
self, client_factory, admin_user, regular_user
):
"""Profile shows godfather email for users who signed up with invite."""
from tests.helpers import unique_email
from sqlalchemy import select
from models import User
from tests.helpers import unique_email
# Create invite
async with client_factory.create(cookies=admin_user["cookies"]) as client:
@ -443,7 +450,9 @@ class TestProfileGodfather:
data = response.json()
assert data["godfather_email"] == regular_user["email"]
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user):
async def test_profile_godfather_null_for_seeded_users(
self, client_factory, regular_user
):
"""Profile shows null godfather for users without one (e.g., seeded users)."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/profile")

View file

@ -1,12 +1,11 @@
"""Tests for profile field validation."""
import pytest
from validation import (
validate_contact_email,
validate_telegram,
validate_signal,
validate_nostr_npub,
validate_profile_fields,
validate_signal,
validate_telegram,
)
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
assert validate_nostr_npub(self.VALID_NPUB) is None
def test_wrong_prefix(self):
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz")
result = validate_nostr_npub(
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
)
assert result is not None
assert "npub" in result.lower()
def test_invalid_checksum(self):
# Change last character to break checksum
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd")
result = validate_nostr_npub(
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
)
assert result is not None
assert "checksum" in result.lower()
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
assert result is not None
def test_not_starting_with_npub1(self):
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc")
result = validate_nostr_npub(
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
)
assert result is not None
assert "npub1" in result
@ -206,4 +211,3 @@ class TestValidateProfileFields:
nostr_npub="",
)
assert errors == {}

View file

@ -1,8 +1,9 @@
"""Validate shared constants match backend definitions."""
import json
from pathlib import Path
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus
from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
def validate_shared_constants() -> None:
@ -29,35 +30,42 @@ def validate_shared_constants() -> None:
# Validate invite statuses
expected_invite_statuses = {s.name: s.value for s in InviteStatus}
if constants.get("inviteStatuses") != expected_invite_statuses:
got = constants.get("inviteStatuses")
raise ValueError(
f"Invite status mismatch in shared/constants.json. "
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
)
# Validate appointment statuses
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
if constants.get("appointmentStatuses") != expected_appointment_statuses:
got = constants.get("appointmentStatuses")
raise ValueError(
f"Appointment status mismatch in shared/constants.json. "
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}"
f"Appointment status mismatch. "
f"Expected: {expected_appointment_statuses}, Got: {got}"
)
# Validate booking constants exist with required fields
booking = constants.get("booking", {})
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"]
required_booking_fields = [
"slotDurationMinutes",
"maxAdvanceDays",
"minAdvanceDays",
"noteMaxLength",
]
for field in required_booking_fields:
if field not in booking:
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json")
raise ValueError(f"Missing booking constant '{field}' in constants.json")
# Validate validation rules exist (structure check only)
validation = constants.get("validation", {})
required_fields = ["telegram", "signal", "nostrNpub"]
for field in required_fields:
if field not in validation:
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json")
raise ValueError(
f"Missing validation rules for '{field}' in constants.json"
)
if __name__ == "__main__":
validate_shared_constants()
print("✓ Shared constants are valid")

View file

@ -1,9 +1,10 @@
"""Validation utilities for user profile fields."""
import json
from pathlib import Path
from email_validator import validate_email, EmailNotValidError
from bech32 import bech32_decode
from email_validator import EmailNotValidError, validate_email
# Load validation rules from shared constants
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
@ -143,4 +144,3 @@ def validate_profile_fields(
errors["nostr_npub"] = err
return errors