Add ruff linter/formatter for Python
- Add ruff as dev dependency - Configure ruff in pyproject.toml with strict 88-char line limit - Ignore B008 (FastAPI Depends pattern is standard) - Allow longer lines in tests for readability - Fix all lint issues in source files - Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
parent
69bc8413e0
commit
6c218130e9
31 changed files with 1234 additions and 876 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
|
@ -8,7 +8,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_db
|
||||
from models import User, Permission
|
||||
from models import Permission, User
|
||||
from schemas import UserResponse
|
||||
|
||||
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
||||
|
|
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
|
|||
).decode("utf-8")
|
||||
|
||||
|
||||
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
|
||||
def create_access_token(
|
||||
data: dict[str, str],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
to_encode: dict[str, str | datetime] = dict(data)
|
||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + delta
|
||||
to_encode["exp"] = expire
|
||||
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded
|
||||
|
|
@ -60,11 +64,11 @@ async def get_current_user(
|
|||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
|
||||
|
||||
token = request.cookies.get(COOKIE_NAME)
|
||||
if not token:
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id_str = payload.get("sub")
|
||||
|
|
@ -72,7 +76,7 @@ async def get_current_user(
|
|||
raise credentials_exception
|
||||
user_id = int(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
raise credentials_exception from None
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
|
@ -83,27 +87,32 @@ async def get_current_user(
|
|||
|
||||
def require_permission(*required_permissions: Permission):
|
||||
"""
|
||||
Dependency factory that checks if user has ALL of the required permissions.
|
||||
|
||||
Dependency factory that checks if user has ALL required permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/counter")
|
||||
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
|
||||
async def get_counter(
|
||||
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
|
||||
):
|
||||
...
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
user = await get_current_user(request, db)
|
||||
user_permissions = await user.get_permissions(db)
|
||||
|
||||
|
||||
missing = [p for p in required_permissions if p not in user_permissions]
|
||||
if missing:
|
||||
missing_str = ", ".join(p.value for p in missing)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
|
||||
detail=f"Missing required permissions: {missing_str}",
|
||||
)
|
||||
return user
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
|
||||
)
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
|
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
|
|||
async def get_db():
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Utilities for invite code generation and validation."""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -13,11 +14,11 @@ assert len(BIP39_WORDS) == 2048, f"Expected 2048 BIP39 words, got {len(BIP39_WOR
|
|||
def generate_invite_identifier() -> str:
|
||||
"""
|
||||
Generate a unique invite identifier.
|
||||
|
||||
|
||||
Format: word1-word2-NN where:
|
||||
- word1, word2 are random BIP39 words
|
||||
- NN is a two-digit number (00-99)
|
||||
|
||||
|
||||
Returns lowercase identifier.
|
||||
"""
|
||||
word1 = random.choice(BIP39_WORDS)
|
||||
|
|
@ -29,7 +30,7 @@ def generate_invite_identifier() -> str:
|
|||
def normalize_identifier(identifier: str) -> str:
|
||||
"""
|
||||
Normalize an invite identifier for comparison/lookup.
|
||||
|
||||
|
||||
- Converts to lowercase
|
||||
- Strips whitespace
|
||||
"""
|
||||
|
|
@ -39,22 +40,18 @@ def normalize_identifier(identifier: str) -> str:
|
|||
def is_valid_identifier_format(identifier: str) -> bool:
|
||||
"""
|
||||
Check if an identifier has valid format (word-word-NN).
|
||||
|
||||
|
||||
Does NOT check if words are valid BIP39 words.
|
||||
"""
|
||||
parts = identifier.split("-")
|
||||
if len(parts) != 3:
|
||||
return False
|
||||
|
||||
|
||||
word1, word2, number = parts
|
||||
|
||||
|
||||
# Check words are non-empty
|
||||
if not word1 or not word2:
|
||||
return False
|
||||
|
||||
# Check number is two digits
|
||||
if len(number) != 2 or not number.isdigit():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Check number is two digits
|
||||
return len(number) == 2 and number.isdigit()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
"""FastAPI application entry point."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from database import engine, Base
|
||||
from routes import sum as sum_routes
|
||||
from routes import counter as counter_routes
|
||||
from database import Base, engine
|
||||
from routes import audit as audit_routes
|
||||
from routes import profile as profile_routes
|
||||
from routes import invites as invites_routes
|
||||
from routes import auth as auth_routes
|
||||
from routes import meta as meta_routes
|
||||
from routes import availability as availability_routes
|
||||
from routes import booking as booking_routes
|
||||
from routes import counter as counter_routes
|
||||
from routes import invites as invites_routes
|
||||
from routes import meta as meta_routes
|
||||
from routes import profile as profile_routes
|
||||
from routes import sum as sum_routes
|
||||
from validate_constants import validate_shared_constants
|
||||
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ async def lifespan(app: FastAPI):
|
|||
"""Create database tables on startup and validate constants."""
|
||||
# Validate shared constants match backend definitions
|
||||
validate_shared_constants()
|
||||
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -1,9 +1,24 @@
|
|||
from datetime import datetime, date, time, timezone
|
||||
from datetime import UTC, date, datetime, time
|
||||
from enum import Enum as PyEnum
|
||||
from typing import TypedDict
|
||||
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Time,
|
||||
UniqueConstraint,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
|
|
@ -14,25 +29,26 @@ class RoleConfig(TypedDict):
|
|||
|
||||
class Permission(str, PyEnum):
|
||||
"""All available permissions in the system."""
|
||||
|
||||
# Counter permissions
|
||||
VIEW_COUNTER = "view_counter"
|
||||
INCREMENT_COUNTER = "increment_counter"
|
||||
|
||||
|
||||
# Sum permissions
|
||||
USE_SUM = "use_sum"
|
||||
|
||||
|
||||
# Audit permissions
|
||||
VIEW_AUDIT = "view_audit"
|
||||
|
||||
|
||||
# Invite permissions
|
||||
MANAGE_INVITES = "manage_invites"
|
||||
VIEW_OWN_INVITES = "view_own_invites"
|
||||
|
||||
|
||||
# Booking permissions (regular users)
|
||||
BOOK_APPOINTMENT = "book_appointment"
|
||||
VIEW_OWN_APPOINTMENTS = "view_own_appointments"
|
||||
CANCEL_OWN_APPOINTMENT = "cancel_own_appointment"
|
||||
|
||||
|
||||
# Availability/Appointments permissions (admin)
|
||||
MANAGE_AVAILABILITY = "manage_availability"
|
||||
VIEW_ALL_APPOINTMENTS = "view_all_appointments"
|
||||
|
|
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
|
|||
|
||||
class InviteStatus(str, PyEnum):
|
||||
"""Status of an invite."""
|
||||
|
||||
READY = "ready"
|
||||
SPENT = "spent"
|
||||
REVOKED = "revoked"
|
||||
|
|
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
|
|||
|
||||
class AppointmentStatus(str, PyEnum):
|
||||
"""Status of an appointment."""
|
||||
|
||||
BOOKED = "booked"
|
||||
CANCELLED_BY_USER = "cancelled_by_user"
|
||||
CANCELLED_BY_ADMIN = "cancelled_by_admin"
|
||||
|
|
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
|
|||
# Role definitions with their permissions
|
||||
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
||||
ROLE_ADMIN: {
|
||||
"description": "Administrator with audit, invite, and appointment management access",
|
||||
"description": "Administrator with audit/invite/appointment access",
|
||||
"permissions": [
|
||||
Permission.VIEW_AUDIT,
|
||||
Permission.MANAGE_INVITES,
|
||||
|
|
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
|||
role_permissions = Table(
|
||||
"role_permissions",
|
||||
Base.metadata,
|
||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column(
|
||||
"role_id",
|
||||
Integer,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("permission", Enum(Permission), primary_key=True),
|
||||
)
|
||||
|
||||
|
|
@ -97,8 +120,18 @@ role_permissions = Table(
|
|||
user_roles = Table(
|
||||
"user_roles",
|
||||
Base.metadata,
|
||||
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column(
|
||||
"user_id",
|
||||
Integer,
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column(
|
||||
"role_id",
|
||||
Integer,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,7 +141,7 @@ class Role(Base):
|
|||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
||||
description: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||
|
||||
|
||||
# Relationship to users
|
||||
users: Mapped[list["User"]] = relationship(
|
||||
"User",
|
||||
|
|
@ -118,31 +151,42 @@ class Role(Base):
|
|||
|
||||
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
||||
"""Get all permissions for this role."""
|
||||
result = await db.execute(
|
||||
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
|
||||
query = select(role_permissions.c.permission).where(
|
||||
role_permissions.c.role_id == self.id
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return {row[0] for row in result.fetchall()}
|
||||
|
||||
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
|
||||
async def set_permissions(
|
||||
self, db: AsyncSession, permissions: list[Permission]
|
||||
) -> None:
|
||||
"""Set all permissions for this role (replaces existing)."""
|
||||
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
|
||||
delete_query = role_permissions.delete().where(
|
||||
role_permissions.c.role_id == self.id
|
||||
)
|
||||
await db.execute(delete_query)
|
||||
for perm in permissions:
|
||||
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm))
|
||||
insert_query = role_permissions.insert().values(
|
||||
role_id=self.id, permission=perm
|
||||
)
|
||||
await db.execute(insert_query)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(
|
||||
String(255), unique=True, nullable=False, index=True
|
||||
)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
|
||||
# Contact details (all optional)
|
||||
contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
telegram: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
signal: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
nostr_npub: Mapped[str | None] = mapped_column(String(63), nullable=True)
|
||||
|
||||
|
||||
# Godfather (who invited this user) - null for seeded/admin users
|
||||
godfather_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=True
|
||||
|
|
@ -152,7 +196,7 @@ class User(Base):
|
|||
remote_side="User.id",
|
||||
foreign_keys=[godfather_id],
|
||||
)
|
||||
|
||||
|
||||
# Relationship to roles
|
||||
roles: Mapped[list[Role]] = relationship(
|
||||
"Role",
|
||||
|
|
@ -192,12 +236,14 @@ class SumRecord(Base):
|
|||
__tablename__ = "sum_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
a: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
b: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
result: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -205,11 +251,13 @@ class CounterRecord(Base):
|
|||
__tablename__ = "counter_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -217,11 +265,13 @@ class Invite(Base):
|
|||
__tablename__ = "invites"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
identifier: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False, index=True
|
||||
)
|
||||
status: Mapped[InviteStatus] = mapped_column(
|
||||
Enum(InviteStatus), nullable=False, default=InviteStatus.READY
|
||||
)
|
||||
|
||||
|
||||
# Godfather - the user who owns this invite
|
||||
godfather_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
|
|
@ -231,7 +281,7 @@ class Invite(Base):
|
|||
foreign_keys=[godfather_id],
|
||||
lazy="joined",
|
||||
)
|
||||
|
||||
|
||||
# User who used this invite (null until spent)
|
||||
used_by_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=True
|
||||
|
|
@ -241,17 +291,22 @@ class Invite(Base):
|
|||
foreign_keys=[used_by_id],
|
||||
lazy="joined",
|
||||
)
|
||||
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
spent_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
class Availability(Base):
|
||||
"""Admin availability slots for booking."""
|
||||
|
||||
__tablename__ = "availability"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
|
||||
|
|
@ -262,34 +317,37 @@ class Availability(Base):
|
|||
start_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||
end_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class Appointment(Base):
|
||||
"""User appointment bookings."""
|
||||
|
||||
__tablename__ = "appointments"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
|
||||
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
slot_start: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True
|
||||
)
|
||||
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
note: Mapped[str | None] = mapped_column(String(144), nullable=True)
|
||||
status: Mapped[AppointmentStatus] = mapped_column(
|
||||
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
cancelled_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ dev = [
|
|||
"httpx>=0.28.1",
|
||||
"aiosqlite>=0.20.0",
|
||||
"mypy>=1.13.0",
|
||||
"ruff>=0.14.10",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
|
@ -30,3 +31,27 @@ check_untyped_defs = true
|
|||
ignore_missing_imports = true
|
||||
exclude = ["tests/"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"UP", # pyupgrade
|
||||
"SIM", # flake8-simplify
|
||||
"RUF", # ruff-specific rules
|
||||
]
|
||||
ignore = [
|
||||
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["E501"] # Allow longer lines in tests for readability
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
"""Audit routes for viewing action records."""
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, SumRecord, CounterRecord, Permission
|
||||
from models import CounterRecord, Permission, SumRecord, User
|
||||
from schemas import (
|
||||
CounterRecordResponse,
|
||||
SumRecordResponse,
|
||||
PaginatedCounterRecords,
|
||||
PaginatedSumRecords,
|
||||
SumRecordResponse,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
||||
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Authentication routes for register, login, logout, and current user."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from sqlalchemy import select
|
||||
|
|
@ -9,18 +10,17 @@ from auth import (
|
|||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
COOKIE_NAME,
|
||||
COOKIE_SECURE,
|
||||
get_password_hash,
|
||||
get_user_by_email,
|
||||
authenticate_user,
|
||||
build_user_response,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
build_user_response,
|
||||
get_password_hash,
|
||||
get_user_by_email,
|
||||
)
|
||||
from database import get_db
|
||||
from invite_utils import normalize_identifier
|
||||
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus
|
||||
from schemas import UserLogin, UserResponse, RegisterWithInvite
|
||||
|
||||
from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
|
||||
from schemas import RegisterWithInvite, UserLogin, UserResponse
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
|
@ -52,9 +52,8 @@ async def register(
|
|||
"""Register a new user using an invite code."""
|
||||
# Validate invite
|
||||
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.identifier == normalized_identifier)
|
||||
)
|
||||
query = select(Invite).where(Invite.identifier == normalized_identifier)
|
||||
result = await db.execute(query)
|
||||
invite = result.scalar_one_or_none()
|
||||
|
||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||
|
|
@ -90,7 +89,7 @@ async def register(
|
|||
# Mark invite as spent
|
||||
invite.status = InviteStatus.SPENT
|
||||
invite.used_by_id = user.id
|
||||
invite.spent_at = datetime.now(timezone.utc)
|
||||
invite.spent_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
|
|
|||
|
|
@ -1,28 +1,28 @@
|
|||
"""Availability routes for admin to manage booking availability."""
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, delete, and_
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, Availability, Permission
|
||||
from models import Availability, Permission, User
|
||||
from schemas import (
|
||||
TimeSlot,
|
||||
AvailabilityDay,
|
||||
AvailabilityResponse,
|
||||
SetAvailabilityRequest,
|
||||
CopyAvailabilityRequest,
|
||||
SetAvailabilityRequest,
|
||||
TimeSlot,
|
||||
)
|
||||
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
||||
|
||||
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
|
||||
|
||||
router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
|
||||
|
||||
|
||||
def _get_date_range_bounds() -> tuple[date, date]:
|
||||
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
||||
"""Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
||||
today = date.today()
|
||||
min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
|
||||
max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
|
||||
|
|
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
|
|||
if d < min_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}",
|
||||
detail=f"Cannot set availability for past dates. "
|
||||
f"Earliest allowed: {min_date}",
|
||||
)
|
||||
if d > max_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}",
|
||||
detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||
f"Latest allowed: {max_date}",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -56,7 +58,7 @@ async def get_availability(
|
|||
status_code=400,
|
||||
detail="'from' date must be before or equal to 'to' date",
|
||||
)
|
||||
|
||||
|
||||
# Query availability in range
|
||||
result = await db.execute(
|
||||
select(Availability)
|
||||
|
|
@ -64,23 +66,24 @@ async def get_availability(
|
|||
.order_by(Availability.date, Availability.start_time)
|
||||
)
|
||||
slots = result.scalars().all()
|
||||
|
||||
|
||||
# Group by date
|
||||
days_dict: dict[date, list[TimeSlot]] = {}
|
||||
for slot in slots:
|
||||
if slot.date not in days_dict:
|
||||
days_dict[slot.date] = []
|
||||
days_dict[slot.date].append(TimeSlot(
|
||||
start_time=slot.start_time,
|
||||
end_time=slot.end_time,
|
||||
))
|
||||
|
||||
days_dict[slot.date].append(
|
||||
TimeSlot(
|
||||
start_time=slot.start_time,
|
||||
end_time=slot.end_time,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
days = [
|
||||
AvailabilityDay(date=d, slots=days_dict[d])
|
||||
for d in sorted(days_dict.keys())
|
||||
AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
|
||||
]
|
||||
|
||||
|
||||
return AvailabilityResponse(days=days)
|
||||
|
||||
|
||||
|
|
@ -93,29 +96,31 @@ async def set_availability(
|
|||
"""Set availability for a specific date. Replaces any existing availability."""
|
||||
min_date, max_date = _get_date_range_bounds()
|
||||
_validate_date_in_range(request.date, min_date, max_date)
|
||||
|
||||
|
||||
# Validate slots don't overlap
|
||||
sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
|
||||
for i in range(len(sorted_slots) - 1):
|
||||
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
|
||||
end = sorted_slots[i].end_time
|
||||
start = sorted_slots[i + 1].start_time
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.",
|
||||
detail=f"Time slots overlap: slot ending at {end} "
|
||||
f"overlaps with slot starting at {start}",
|
||||
)
|
||||
|
||||
|
||||
# Validate each slot's end_time > start_time
|
||||
for slot in request.slots:
|
||||
if slot.end_time <= slot.start_time:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.",
|
||||
detail=f"Invalid time slot: end time {slot.end_time} "
|
||||
f"must be after start time {slot.start_time}",
|
||||
)
|
||||
|
||||
|
||||
# Delete existing availability for this date
|
||||
await db.execute(
|
||||
delete(Availability).where(Availability.date == request.date)
|
||||
)
|
||||
|
||||
await db.execute(delete(Availability).where(Availability.date == request.date))
|
||||
|
||||
# Create new availability slots
|
||||
for slot in request.slots:
|
||||
availability = Availability(
|
||||
|
|
@ -124,9 +129,9 @@ async def set_availability(
|
|||
end_time=slot.end_time,
|
||||
)
|
||||
db.add(availability)
|
||||
|
||||
|
||||
await db.commit()
|
||||
|
||||
|
||||
return AvailabilityDay(date=request.date, slots=request.slots)
|
||||
|
||||
|
||||
|
|
@ -138,14 +143,14 @@ async def copy_availability(
|
|||
) -> AvailabilityResponse:
|
||||
"""Copy availability from one day to multiple target days."""
|
||||
min_date, max_date = _get_date_range_bounds()
|
||||
|
||||
# Validate source date is in range (for consistency, though DB query would fail anyway)
|
||||
|
||||
# Validate source date is in range
|
||||
_validate_date_in_range(request.source_date, min_date, max_date)
|
||||
|
||||
|
||||
# Validate target dates
|
||||
for target_date in request.target_dates:
|
||||
_validate_date_in_range(target_date, min_date, max_date)
|
||||
|
||||
|
||||
# Get source availability
|
||||
result = await db.execute(
|
||||
select(Availability)
|
||||
|
|
@ -153,13 +158,13 @@ async def copy_availability(
|
|||
.order_by(Availability.start_time)
|
||||
)
|
||||
source_slots = result.scalars().all()
|
||||
|
||||
|
||||
if not source_slots:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No availability found for source date {request.source_date}",
|
||||
)
|
||||
|
||||
|
||||
# Copy to each target date within a single atomic transaction
|
||||
# All deletes and inserts happen before commit, ensuring atomicity
|
||||
copied_days: list[AvailabilityDay] = []
|
||||
|
|
@ -167,12 +172,11 @@ async def copy_availability(
|
|||
for target_date in request.target_dates:
|
||||
if target_date == request.source_date:
|
||||
continue # Skip copying to self
|
||||
|
||||
|
||||
# Delete existing availability for target date
|
||||
await db.execute(
|
||||
delete(Availability).where(Availability.date == target_date)
|
||||
)
|
||||
|
||||
del_query = delete(Availability).where(Availability.date == target_date)
|
||||
await db.execute(del_query)
|
||||
|
||||
# Copy slots
|
||||
target_slots: list[TimeSlot] = []
|
||||
for source_slot in source_slots:
|
||||
|
|
@ -182,19 +186,20 @@ async def copy_availability(
|
|||
end_time=source_slot.end_time,
|
||||
)
|
||||
db.add(new_availability)
|
||||
target_slots.append(TimeSlot(
|
||||
start_time=source_slot.start_time,
|
||||
end_time=source_slot.end_time,
|
||||
))
|
||||
|
||||
target_slots.append(
|
||||
TimeSlot(
|
||||
start_time=source_slot.start_time,
|
||||
end_time=source_slot.end_time,
|
||||
)
|
||||
)
|
||||
|
||||
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
|
||||
|
||||
|
||||
# Commit all changes atomically
|
||||
await db.commit()
|
||||
except Exception:
|
||||
# Rollback on any error to maintain atomicity
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
return AvailabilityResponse(days=copied_days)
|
||||
|
||||
return AvailabilityResponse(days=copied_days)
|
||||
|
|
|
|||
|
|
@ -1,24 +1,24 @@
|
|||
"""Booking routes for users to book appointments."""
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, and_, func
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, Availability, Appointment, AppointmentStatus, Permission
|
||||
from models import Appointment, AppointmentStatus, Availability, Permission, User
|
||||
from schemas import (
|
||||
BookableSlot,
|
||||
AvailableSlotsResponse,
|
||||
BookingRequest,
|
||||
AppointmentResponse,
|
||||
AvailableSlotsResponse,
|
||||
BookableSlot,
|
||||
BookingRequest,
|
||||
PaginatedAppointments,
|
||||
)
|
||||
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
||||
|
||||
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
|
||||
|
||||
router = APIRouter(prefix="/api/booking", tags=["booking"])
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ def _to_appointment_response(
|
|||
user_email: str | None = None,
|
||||
) -> AppointmentResponse:
|
||||
"""Convert an Appointment model to AppointmentResponse schema.
|
||||
|
||||
|
||||
Args:
|
||||
appointment: The appointment model instance
|
||||
user_email: Optional user email. If not provided, uses appointment.user.email
|
||||
|
|
@ -49,7 +49,7 @@ def _to_appointment_response(
|
|||
|
||||
def _get_valid_minute_boundaries() -> tuple[int, ...]:
|
||||
"""Get valid minute boundaries based on SLOT_DURATION_MINUTES.
|
||||
|
||||
|
||||
Assumes SLOT_DURATION_MINUTES divides 60 evenly (e.g., 15 minutes = 0, 15, 30, 45).
|
||||
"""
|
||||
boundaries: list[int] = []
|
||||
|
|
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
|
|||
if d < min_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}",
|
||||
detail=f"Cannot book for today or past dates. "
|
||||
f"Earliest bookable: {min_date}",
|
||||
)
|
||||
if d > max_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}",
|
||||
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||
f"Latest bookable: {max_date}",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -89,18 +91,18 @@ def _expand_availability_to_slots(
|
|||
) -> list[BookableSlot]:
|
||||
"""Expand availability time ranges into 15-minute bookable slots."""
|
||||
result: list[BookableSlot] = []
|
||||
|
||||
|
||||
for avail in availability_slots:
|
||||
# Create datetime objects for start and end
|
||||
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc)
|
||||
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc)
|
||||
|
||||
current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
|
||||
end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
|
||||
|
||||
# Generate 15-minute slots
|
||||
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
|
||||
slot_end = current + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||
result.append(BookableSlot(start_time=current, end_time=slot_end))
|
||||
current = slot_end
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -112,7 +114,7 @@ async def get_available_slots(
|
|||
) -> AvailableSlotsResponse:
|
||||
"""Get available booking slots for a specific date."""
|
||||
_validate_booking_date(target_date)
|
||||
|
||||
|
||||
# Get availability for this date
|
||||
result = await db.execute(
|
||||
select(Availability)
|
||||
|
|
@ -120,20 +122,19 @@ async def get_available_slots(
|
|||
.order_by(Availability.start_time)
|
||||
)
|
||||
availability_slots = result.scalars().all()
|
||||
|
||||
|
||||
if not availability_slots:
|
||||
return AvailableSlotsResponse(date=target_date, slots=[])
|
||||
|
||||
|
||||
# Expand to 15-minute slots
|
||||
all_slots = _expand_availability_to_slots(availability_slots, target_date)
|
||||
|
||||
|
||||
# Get existing booked appointments for this date
|
||||
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc)
|
||||
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc)
|
||||
|
||||
day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
|
||||
day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
|
||||
|
||||
result = await db.execute(
|
||||
select(Appointment.slot_start)
|
||||
.where(
|
||||
select(Appointment.slot_start).where(
|
||||
and_(
|
||||
Appointment.slot_start >= day_start,
|
||||
Appointment.slot_start <= day_end,
|
||||
|
|
@ -142,13 +143,12 @@ async def get_available_slots(
|
|||
)
|
||||
)
|
||||
booked_starts = {row[0] for row in result.fetchall()}
|
||||
|
||||
|
||||
# Filter out already booked slots
|
||||
available_slots = [
|
||||
slot for slot in all_slots
|
||||
if slot.start_time not in booked_starts
|
||||
slot for slot in all_slots if slot.start_time not in booked_starts
|
||||
]
|
||||
|
||||
|
||||
return AvailableSlotsResponse(date=target_date, slots=available_slots)
|
||||
|
||||
|
||||
|
|
@ -161,27 +161,28 @@ async def create_booking(
|
|||
"""Book an appointment slot."""
|
||||
slot_date = request.slot_start.date()
|
||||
_validate_booking_date(slot_date)
|
||||
|
||||
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES)
|
||||
|
||||
# Validate slot is on the correct minute boundary
|
||||
valid_minutes = _get_valid_minute_boundaries()
|
||||
if request.slot_start.minute not in valid_minutes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})",
|
||||
detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
|
||||
f"(valid minutes: {valid_minutes})",
|
||||
)
|
||||
if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Slot start time must not have seconds or microseconds",
|
||||
)
|
||||
|
||||
|
||||
# Verify slot falls within availability
|
||||
slot_start_time = request.slot_start.time()
|
||||
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time()
|
||||
|
||||
slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||
slot_end_time = slot_end_dt.time()
|
||||
|
||||
result = await db.execute(
|
||||
select(Availability)
|
||||
.where(
|
||||
select(Availability).where(
|
||||
and_(
|
||||
Availability.date == slot_date,
|
||||
Availability.start_time <= slot_start_time,
|
||||
|
|
@ -190,13 +191,15 @@ async def create_booking(
|
|||
)
|
||||
)
|
||||
matching_availability = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not matching_availability:
|
||||
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.",
|
||||
detail=f"Selected slot at {slot_str} UTC is not within "
|
||||
f"any available time ranges for {slot_date}",
|
||||
)
|
||||
|
||||
|
||||
# Create the appointment
|
||||
slot_end = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||
appointment = Appointment(
|
||||
|
|
@ -206,9 +209,9 @@ async def create_booking(
|
|||
note=request.note,
|
||||
status=AppointmentStatus.BOOKED,
|
||||
)
|
||||
|
||||
|
||||
db.add(appointment)
|
||||
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
|
|
@ -216,9 +219,9 @@ async def create_booking(
|
|||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This slot has already been booked. Please select another slot.",
|
||||
)
|
||||
|
||||
detail="This slot has already been booked. Select another slot.",
|
||||
) from None
|
||||
|
||||
return _to_appointment_response(appointment, current_user.email)
|
||||
|
||||
|
||||
|
|
@ -241,60 +244,63 @@ async def get_my_appointments(
|
|||
.order_by(Appointment.slot_start.desc())
|
||||
)
|
||||
appointments = result.scalars().all()
|
||||
|
||||
return [
|
||||
_to_appointment_response(apt, current_user.email)
|
||||
for apt in appointments
|
||||
]
|
||||
|
||||
return [_to_appointment_response(apt, current_user.email) for apt in appointments]
|
||||
|
||||
|
||||
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
||||
@appointments_router.post(
|
||||
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||
)
|
||||
async def cancel_my_appointment(
|
||||
appointment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
|
||||
) -> AppointmentResponse:
|
||||
"""Cancel one of the current user's appointments."""
|
||||
# Get the appointment with explicit eager loading of user relationship
|
||||
# Get the appointment with eager loading of user relationship
|
||||
result = await db.execute(
|
||||
select(Appointment)
|
||||
.options(joinedload(Appointment.user))
|
||||
.where(Appointment.id == appointment_id)
|
||||
)
|
||||
appointment = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not appointment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
||||
detail=f"Appointment {appointment_id} not found",
|
||||
)
|
||||
|
||||
|
||||
# Verify ownership
|
||||
if appointment.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment")
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot cancel another user's appointment",
|
||||
)
|
||||
|
||||
# Check if already cancelled
|
||||
if appointment.status != AppointmentStatus.BOOKED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
||||
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||
)
|
||||
|
||||
|
||||
# Check if appointment is in the past
|
||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
||||
if appointment.slot_start <= datetime.now(UTC):
|
||||
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
||||
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||
"already started or in the past",
|
||||
)
|
||||
|
||||
|
||||
# Cancel the appointment
|
||||
appointment.status = AppointmentStatus.CANCELLED_BY_USER
|
||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
||||
|
||||
appointment.cancelled_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
|
||||
|
||||
return _to_appointment_response(appointment, current_user.email)
|
||||
|
||||
|
||||
|
|
@ -302,7 +308,9 @@ async def cancel_my_appointment(
|
|||
# Admin Appointments Endpoints
|
||||
# =============================================================================
|
||||
|
||||
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"])
|
||||
admin_appointments_router = APIRouter(
|
||||
prefix="/api/admin/appointments", tags=["admin-appointments"]
|
||||
)
|
||||
|
||||
|
||||
@admin_appointments_router.get("", response_model=PaginatedAppointments)
|
||||
|
|
@ -317,7 +325,7 @@ async def get_all_appointments(
|
|||
count_result = await db.execute(select(func.count(Appointment.id)))
|
||||
total = count_result.scalar() or 0
|
||||
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
||||
|
||||
|
||||
# Get paginated appointments with explicit eager loading of user relationship
|
||||
offset = (page - 1) * per_page
|
||||
result = await db.execute(
|
||||
|
|
@ -328,13 +336,13 @@ async def get_all_appointments(
|
|||
.limit(per_page)
|
||||
)
|
||||
appointments = result.scalars().all()
|
||||
|
||||
|
||||
# Build responses using the eager-loaded user relationship
|
||||
records = [
|
||||
_to_appointment_response(apt) # Uses eager-loaded relationship
|
||||
for apt in appointments
|
||||
]
|
||||
|
||||
|
||||
return PaginatedAppointments(
|
||||
records=records,
|
||||
total=total,
|
||||
|
|
@ -344,47 +352,52 @@ async def get_all_appointments(
|
|||
)
|
||||
|
||||
|
||||
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
||||
@admin_appointments_router.post(
|
||||
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||
)
|
||||
async def admin_cancel_appointment(
|
||||
appointment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)),
|
||||
_current_user: User = Depends(
|
||||
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
|
||||
),
|
||||
) -> AppointmentResponse:
|
||||
"""Cancel any appointment (admin only)."""
|
||||
# Get the appointment with explicit eager loading of user relationship
|
||||
# Get the appointment with eager loading of user relationship
|
||||
result = await db.execute(
|
||||
select(Appointment)
|
||||
.options(joinedload(Appointment.user))
|
||||
.where(Appointment.id == appointment_id)
|
||||
)
|
||||
appointment = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not appointment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
||||
detail=f"Appointment {appointment_id} not found",
|
||||
)
|
||||
|
||||
|
||||
# Check if already cancelled
|
||||
if appointment.status != AppointmentStatus.BOOKED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
||||
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||
)
|
||||
|
||||
|
||||
# Check if appointment is in the past
|
||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
||||
if appointment.slot_start <= datetime.now(UTC):
|
||||
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
||||
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||
"already started or in the past",
|
||||
)
|
||||
|
||||
|
||||
# Cancel the appointment
|
||||
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
|
||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
||||
|
||||
appointment.cancelled_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
|
||||
|
||||
return _to_appointment_response(appointment) # Uses eager-loaded relationship
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
"""Counter routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import Counter, User, CounterRecord, Permission
|
||||
|
||||
from models import Counter, CounterRecord, Permission, User
|
||||
|
||||
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,29 @@
|
|||
"""Invite routes for public check, user invites, and admin management."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy import select, func, desc
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
|
||||
from models import User, Invite, InviteStatus, Permission
|
||||
from invite_utils import (
|
||||
generate_invite_identifier,
|
||||
is_valid_identifier_format,
|
||||
normalize_identifier,
|
||||
)
|
||||
from models import Invite, InviteStatus, Permission, User
|
||||
from schemas import (
|
||||
AdminUserResponse,
|
||||
InviteCheckResponse,
|
||||
InviteCreate,
|
||||
InviteResponse,
|
||||
UserInviteResponse,
|
||||
PaginatedInviteRecords,
|
||||
AdminUserResponse,
|
||||
UserInviteResponse,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/invites", tags=["invites"])
|
||||
admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
|
@ -54,9 +58,7 @@ async def check_invite(
|
|||
if not is_valid_identifier_format(normalized):
|
||||
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
||||
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.identifier == normalized)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.identifier == normalized))
|
||||
invite = result.scalar_one_or_none()
|
||||
|
||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||
|
|
@ -112,9 +114,7 @@ async def create_invite(
|
|||
) -> InviteResponse:
|
||||
"""Create a new invite for a specified godfather user."""
|
||||
# Validate godfather exists
|
||||
result = await db.execute(
|
||||
select(User.id).where(User.id == data.godfather_id)
|
||||
)
|
||||
result = await db.execute(select(User.id).where(User.id == data.godfather_id))
|
||||
godfather_id = result.scalar_one_or_none()
|
||||
if not godfather_id:
|
||||
raise HTTPException(
|
||||
|
|
@ -141,8 +141,8 @@ async def create_invite(
|
|||
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate unique invite code. Please try again.",
|
||||
)
|
||||
detail="Failed to generate unique invite code. Try again.",
|
||||
) from None
|
||||
|
||||
if invite is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -156,7 +156,9 @@ async def create_invite(
|
|||
async def list_all_invites(
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
|
||||
status_filter: str | None = Query(
|
||||
None, alias="status", description="Filter by status: ready, spent, revoked"
|
||||
),
|
||||
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||
|
|
@ -175,8 +177,9 @@ async def list_all_invites(
|
|||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
|
||||
)
|
||||
detail=f"Invalid status: {status_filter}. "
|
||||
"Must be ready, spent, or revoked",
|
||||
) from None
|
||||
|
||||
if godfather_id:
|
||||
query = query.where(Invite.godfather_id == godfather_id)
|
||||
|
|
@ -224,11 +227,12 @@ async def revoke_invite(
|
|||
if invite.status != InviteStatus.READY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
|
||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. "
|
||||
"Only READY invites can be revoked.",
|
||||
)
|
||||
|
||||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(timezone.utc)
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Meta endpoints for shared constants."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
|
||||
from schemas import ConstantsResponse
|
||||
|
||||
router = APIRouter(prefix="/api/meta", tags=["meta"])
|
||||
|
|
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
|
|||
roles=[ROLE_ADMIN, ROLE_REGULAR],
|
||||
invite_statuses=[s.value for s in InviteStatus],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
"""Profile routes for user contact details."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import get_current_user
|
||||
from database import get_db
|
||||
from models import User, ROLE_REGULAR
|
||||
from models import ROLE_REGULAR, User
|
||||
from schemas import ProfileResponse, ProfileUpdate
|
||||
from validation import validate_profile_fields
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/profile", tags=["profile"])
|
||||
|
||||
|
||||
|
|
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
|
|||
"""Get the email of a godfather user by ID."""
|
||||
if not godfather_id:
|
||||
return None
|
||||
result = await db.execute(
|
||||
select(User.email).where(User.id == godfather_id)
|
||||
)
|
||||
result = await db.execute(select(User.email).where(User.id == godfather_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Sum calculation routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, SumRecord, Permission
|
||||
from models import Permission, SumRecord, User
|
||||
from schemas import SumRequest, SumResponse
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Pydantic schemas for API request/response models."""
|
||||
from datetime import datetime, date, time
|
||||
|
||||
from datetime import date, datetime, time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
|
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
|
|||
|
||||
class UserCredentials(BaseModel):
|
||||
"""Base model for user email/password."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
|
@ -19,6 +21,7 @@ UserLogin = UserCredentials
|
|||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for authenticated user info."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
roles: list[str]
|
||||
|
|
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
|
|||
|
||||
class RegisterWithInvite(BaseModel):
|
||||
"""Request model for registration with invite."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite_identifier: str
|
||||
|
|
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
|
|||
|
||||
class SumRequest(BaseModel):
|
||||
"""Request model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
|
||||
|
||||
class SumResponse(BaseModel):
|
||||
"""Response model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
result: float
|
||||
|
|
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
|
|||
|
||||
class CounterRecordResponse(BaseModel):
|
||||
"""Response model for a counter audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
value_before: int
|
||||
|
|
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
|
|||
|
||||
class SumRecordResponse(BaseModel):
|
||||
"""Response model for a sum audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
a: float
|
||||
|
|
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
|
|||
|
||||
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
records: list[RecordT]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
|||
|
||||
class ProfileResponse(BaseModel):
|
||||
"""Response model for profile data."""
|
||||
|
||||
contact_email: str | None
|
||||
telegram: str | None
|
||||
signal: str | None
|
||||
|
|
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
|
|||
|
||||
class ProfileUpdate(BaseModel):
|
||||
"""Request model for updating profile."""
|
||||
|
||||
contact_email: str | None = None
|
||||
telegram: str | None = None
|
||||
signal: str | None = None
|
||||
|
|
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
|
|||
|
||||
class InviteCheckResponse(BaseModel):
|
||||
"""Response for invite check endpoint."""
|
||||
|
||||
valid: bool
|
||||
status: str | None = None
|
||||
error: str | None = None
|
||||
|
|
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
|
|||
|
||||
class InviteCreate(BaseModel):
|
||||
"""Request model for creating an invite."""
|
||||
|
||||
godfather_id: int
|
||||
|
||||
|
||||
class InviteResponse(BaseModel):
|
||||
"""Response model for invite data (admin view)."""
|
||||
|
||||
id: int
|
||||
identifier: str
|
||||
godfather_id: int
|
||||
|
|
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
|
|||
|
||||
class UserInviteResponse(BaseModel):
|
||||
"""Response model for a user's invite (simpler than admin view)."""
|
||||
|
||||
id: int
|
||||
identifier: str
|
||||
status: str
|
||||
|
|
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
|
|||
|
||||
class AdminUserResponse(BaseModel):
|
||||
"""Minimal user info for admin dropdowns."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
|
||||
|
|
@ -146,11 +162,13 @@ class AdminUserResponse(BaseModel):
|
|||
# Availability Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TimeSlot(BaseModel):
|
||||
"""A single time slot (start and end time)."""
|
||||
|
||||
start_time: time
|
||||
end_time: time
|
||||
|
||||
|
||||
@field_validator("start_time", "end_time")
|
||||
@classmethod
|
||||
def validate_15min_boundary(cls, v: time) -> time:
|
||||
|
|
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
|
|||
|
||||
class AvailabilityDay(BaseModel):
|
||||
"""Availability for a single day."""
|
||||
|
||||
date: date
|
||||
slots: list[TimeSlot]
|
||||
|
||||
|
||||
class AvailabilityResponse(BaseModel):
|
||||
"""Response model for availability query."""
|
||||
|
||||
days: list[AvailabilityDay]
|
||||
|
||||
|
||||
class SetAvailabilityRequest(BaseModel):
|
||||
"""Request to set availability for a specific date."""
|
||||
|
||||
date: date
|
||||
slots: list[TimeSlot]
|
||||
|
||||
|
||||
class CopyAvailabilityRequest(BaseModel):
|
||||
"""Request to copy availability from one day to others."""
|
||||
|
||||
source_date: date
|
||||
target_dates: list[date]
|
||||
|
||||
|
|
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
|
|||
# Booking Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BookableSlot(BaseModel):
|
||||
"""A bookable 15-minute slot."""
|
||||
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
|
||||
|
||||
class AvailableSlotsResponse(BaseModel):
|
||||
"""Response for available slots on a given date."""
|
||||
|
||||
date: date
|
||||
slots: list[BookableSlot]
|
||||
|
||||
|
||||
class BookingRequest(BaseModel):
|
||||
"""Request to book an appointment."""
|
||||
|
||||
slot_start: datetime
|
||||
note: str | None = None
|
||||
|
||||
|
|
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
|
|||
|
||||
class AppointmentResponse(BaseModel):
|
||||
"""Response model for an appointment."""
|
||||
|
||||
id: int
|
||||
user_id: int
|
||||
user_email: str
|
||||
|
|
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
|
|||
# Meta/Constants Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ConstantsResponse(BaseModel):
|
||||
"""Response model for shared constants."""
|
||||
|
||||
permissions: list[str]
|
||||
roles: list[str]
|
||||
invite_statuses: list[str]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
"""Seed the database with roles, permissions, and dev users."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import engine, async_session, Base
|
||||
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||
from auth import get_password_hash
|
||||
from database import Base, async_session, engine
|
||||
from models import (
|
||||
ROLE_ADMIN,
|
||||
ROLE_DEFINITIONS,
|
||||
ROLE_REGULAR,
|
||||
Permission,
|
||||
Role,
|
||||
User,
|
||||
)
|
||||
|
||||
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
||||
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
|
||||
|
|
@ -14,11 +23,13 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
|
|||
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
||||
|
||||
|
||||
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
|
||||
async def upsert_role(
|
||||
db: AsyncSession, name: str, description: str, permissions: list[Permission]
|
||||
) -> Role:
|
||||
"""Create or update a role with the given permissions."""
|
||||
result = await db.execute(select(Role).where(Role.name == name))
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if role:
|
||||
role.description = description
|
||||
print(f"Updated role: {name}")
|
||||
|
|
@ -27,19 +38,21 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
|
|||
db.add(role)
|
||||
await db.flush() # Get the role ID
|
||||
print(f"Created role: {name}")
|
||||
|
||||
|
||||
# Set permissions for the role
|
||||
await role.set_permissions(db, permissions)
|
||||
print(f" Permissions: {', '.join(p.value for p in permissions)}")
|
||||
|
||||
|
||||
return role
|
||||
|
||||
|
||||
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
|
||||
async def upsert_user(
|
||||
db: AsyncSession, email: str, password: str, role_names: list[str]
|
||||
) -> User:
|
||||
"""Create or update a user with the given credentials and roles."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
|
||||
# Get roles
|
||||
roles = []
|
||||
for role_name in role_names:
|
||||
|
|
@ -48,7 +61,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l
|
|||
if not role:
|
||||
raise ValueError(f"Role '{role_name}' not found")
|
||||
roles.append(role)
|
||||
|
||||
|
||||
if user:
|
||||
user.hashed_password = get_password_hash(password)
|
||||
user.roles = roles # type: ignore[assignment]
|
||||
|
|
@ -61,7 +74,7 @@ async def upsert_user(db: AsyncSession, email: str, password: str, role_names: l
|
|||
)
|
||||
db.add(user)
|
||||
print(f"Created user: {email} with roles: {role_names}")
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
|
@ -78,14 +91,14 @@ async def seed() -> None:
|
|||
role_config["description"],
|
||||
role_config["permissions"],
|
||||
)
|
||||
|
||||
|
||||
print("\n=== Seeding Users ===")
|
||||
# Create regular dev user
|
||||
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR])
|
||||
|
||||
|
||||
# Create admin dev user
|
||||
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN])
|
||||
|
||||
|
||||
await db.commit()
|
||||
print("\n=== Seeding Complete ===\n")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Load shared constants from shared/constants.json."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
|
|||
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
|
||||
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
|
||||
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,28 +7,28 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
|||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from auth import get_password_hash
|
||||
from database import Base, get_db
|
||||
from main import app
|
||||
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||
from auth import get_password_hash
|
||||
from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
TEST_DATABASE_URL = os.getenv(
|
||||
"TEST_DATABASE_URL",
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
|
||||
)
|
||||
|
||||
|
||||
class ClientFactory:
|
||||
"""Factory for creating httpx clients with optional cookies."""
|
||||
|
||||
|
||||
def __init__(self, transport, base_url, session_factory):
|
||||
self._transport = transport
|
||||
self._base_url = base_url
|
||||
self._session_factory = session_factory
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create(self, cookies: dict | None = None):
|
||||
"""Create a new client, optionally with cookies set."""
|
||||
|
|
@ -38,15 +38,15 @@ class ClientFactory:
|
|||
cookies=cookies or {},
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
async def request(self, method: str, url: str, **kwargs):
|
||||
"""Make a one-off request without cookies."""
|
||||
async with self.create() as client:
|
||||
return await client.request(method, url, **kwargs)
|
||||
|
||||
|
||||
async def get(self, url: str, **kwargs):
|
||||
return await self.request("GET", url, **kwargs)
|
||||
|
||||
|
||||
async def post(self, url: str, **kwargs):
|
||||
return await self.request("POST", url, **kwargs)
|
||||
|
||||
|
|
@ -64,16 +64,16 @@ async def setup_roles(db: AsyncSession) -> dict[str, Role]:
|
|||
# Check if role exists
|
||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not role:
|
||||
role = Role(name=role_name, description=config["description"])
|
||||
db.add(role)
|
||||
await db.flush()
|
||||
|
||||
|
||||
# Set permissions
|
||||
await role.set_permissions(db, config["permissions"])
|
||||
roles[role_name] = role
|
||||
|
||||
|
||||
await db.commit()
|
||||
return roles
|
||||
|
||||
|
|
@ -91,9 +91,11 @@ async def create_user_with_roles(
|
|||
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||
role = result.scalar_one_or_none()
|
||||
if not role:
|
||||
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
|
||||
raise ValueError(
|
||||
f"Role '{role_name}' not found. Did you run setup_roles()?"
|
||||
)
|
||||
roles.append(role)
|
||||
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=get_password_hash(password),
|
||||
|
|
@ -110,27 +112,27 @@ async def client_factory():
|
|||
"""Fixture that provides a factory for creating clients."""
|
||||
engine = create_async_engine(TEST_DATABASE_URL)
|
||||
session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
# Setup roles
|
||||
async with session_factory() as db:
|
||||
await setup_roles(db)
|
||||
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
factory = ClientFactory(transport, "http://test", session_factory)
|
||||
|
||||
|
||||
yield factory
|
||||
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
|
@ -147,17 +149,17 @@ async def regular_user(client_factory):
|
|||
"""Create a regular user and return their credentials and cookies."""
|
||||
email = unique_email("regular")
|
||||
password = "password123"
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = await create_user_with_roles(db, email, password, [ROLE_REGULAR])
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
|
|
@ -172,17 +174,17 @@ async def alt_regular_user(client_factory):
|
|||
"""Create a second regular user for tests needing multiple users."""
|
||||
email = unique_email("alt_regular")
|
||||
password = "password123"
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = await create_user_with_roles(db, email, password, [ROLE_REGULAR])
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
|
|
@ -197,16 +199,16 @@ async def admin_user(client_factory):
|
|||
"""Create an admin user and return their credentials and cookies."""
|
||||
email = unique_email("admin")
|
||||
password = "password123"
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
await create_user_with_roles(db, email, password, [ROLE_ADMIN])
|
||||
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
|
|
@ -220,16 +222,16 @@ async def user_no_roles(client_factory):
|
|||
"""Create a user with NO roles and return their credentials and cookies."""
|
||||
email = unique_email("noroles")
|
||||
password = "password123"
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
await create_user_with_roles(db, email, password, [])
|
||||
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import uuid
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import User, Invite, InviteStatus
|
||||
from invite_utils import generate_invite_identifier
|
||||
from models import Invite, InviteStatus, User
|
||||
|
||||
|
||||
def unique_email(prefix: str = "test") -> str:
|
||||
|
|
@ -15,24 +15,24 @@ def unique_email(prefix: str = "test") -> str:
|
|||
async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> str:
|
||||
"""
|
||||
Create an invite for an existing godfather user.
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
godfather_id: ID of the existing user who will be the godfather
|
||||
|
||||
|
||||
Returns:
|
||||
The invite identifier.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If the godfather user doesn't exist.
|
||||
"""
|
||||
# Verify godfather exists
|
||||
result = await db.execute(select(User).where(User.id == godfather_id))
|
||||
godfather = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not godfather:
|
||||
raise ValueError(f"Godfather user with ID {godfather_id} not found")
|
||||
|
||||
|
||||
# Create invite
|
||||
identifier = generate_invite_identifier()
|
||||
invite = Invite(
|
||||
|
|
@ -42,7 +42,7 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st
|
|||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
|
||||
|
||||
return identifier
|
||||
|
||||
|
||||
|
|
@ -50,24 +50,26 @@ async def create_invite_for_godfather(db: AsyncSession, godfather_id: int) -> st
|
|||
async def create_invite_for_registration(db: AsyncSession, godfather_email: str) -> str:
|
||||
"""
|
||||
Create an invite for an existing godfather user (looked up by email).
|
||||
|
||||
|
||||
The godfather must already exist in the database.
|
||||
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
godfather_email: Email of the existing user who will be the godfather
|
||||
|
||||
|
||||
Returns:
|
||||
The invite identifier.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If the godfather user doesn't exist.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.email == godfather_email))
|
||||
godfather = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not godfather:
|
||||
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
|
||||
"Create the user first using create_user_with_roles().")
|
||||
|
||||
raise ValueError(
|
||||
f"Godfather user with email '{godfather_email}' not found. "
|
||||
"Create the user first using create_user_with_roles()."
|
||||
)
|
||||
|
||||
return await create_invite_for_godfather(db, godfather.id)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@
|
|||
Note: Registration now requires an invite code. Tests that need to register
|
||||
users will create invites first via the helper function.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
|
||||
# Registration tests (with invite)
|
||||
|
|
@ -16,12 +17,14 @@ from tests.conftest import create_user_with_roles
|
|||
async def test_register_success(client_factory):
|
||||
"""Can register with valid invite code."""
|
||||
email = unique_email("register")
|
||||
|
||||
|
||||
# Create godfather user and invite
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -46,13 +49,15 @@ async def test_register_success(client_factory):
|
|||
async def test_register_duplicate_email(client_factory):
|
||||
"""Cannot register with already-used email."""
|
||||
email = unique_email("duplicate")
|
||||
|
||||
|
||||
# Create godfather and two invites
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
# First registration
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -62,7 +67,7 @@ async def test_register_duplicate_email(client_factory):
|
|||
"invite_identifier": invite1,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Second registration with same email
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -80,9 +85,11 @@ async def test_register_duplicate_email(client_factory):
|
|||
async def test_register_invalid_email(client_factory):
|
||||
"""Cannot register with invalid email format."""
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -136,11 +143,13 @@ async def test_register_empty_body(client):
|
|||
async def test_login_success(client_factory):
|
||||
"""Can login with valid credentials."""
|
||||
email = unique_email("login")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -165,11 +174,13 @@ async def test_login_success(client_factory):
|
|||
async def test_login_wrong_password(client_factory):
|
||||
"""Cannot login with wrong password."""
|
||||
email = unique_email("wrongpass")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -219,11 +230,13 @@ async def test_login_missing_fields(client):
|
|||
async def test_get_me_success(client_factory):
|
||||
"""Can get current user info when authenticated."""
|
||||
email = unique_email("me")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -233,10 +246,10 @@ async def test_get_me_success(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg_response.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
response = await authed.get("/api/auth/me")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == email
|
||||
|
|
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_me_invalid_cookie(client_factory):
|
||||
"""Cannot get current user with invalid cookie."""
|
||||
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
|
||||
async with client_factory.create(
|
||||
cookies={COOKIE_NAME: "invalidtoken123"}
|
||||
) as authed:
|
||||
response = await authed.get("/api/auth/me")
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid authentication credentials"
|
||||
|
|
@ -275,11 +290,13 @@ async def test_get_me_expired_token(client_factory):
|
|||
async def test_cookie_from_register_works_for_me(client_factory):
|
||||
"""Auth cookie from registration works for subsequent requests."""
|
||||
email = unique_email("tokentest")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -289,10 +306,10 @@ async def test_cookie_from_register_works_for_me(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg_response.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
me_response = await authed.get("/api/auth/me")
|
||||
|
||||
|
||||
assert me_response.status_code == 200
|
||||
assert me_response.json()["email"] == email
|
||||
|
||||
|
|
@ -301,11 +318,13 @@ async def test_cookie_from_register_works_for_me(client_factory):
|
|||
async def test_cookie_from_login_works_for_me(client_factory):
|
||||
"""Auth cookie from login works for subsequent requests."""
|
||||
email = unique_email("logintoken")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -319,10 +338,10 @@ async def test_cookie_from_login_works_for_me(client_factory):
|
|||
json={"email": email, "password": "password123"},
|
||||
)
|
||||
cookies = dict(login_response.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
me_response = await authed.get("/api/auth/me")
|
||||
|
||||
|
||||
assert me_response.status_code == 200
|
||||
assert me_response.json()["email"] == email
|
||||
|
||||
|
|
@ -333,12 +352,14 @@ async def test_multiple_users_isolated(client_factory):
|
|||
"""Multiple users have isolated sessions."""
|
||||
email1 = unique_email("user1")
|
||||
email2 = unique_email("user2")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
resp1 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -355,16 +376,16 @@ async def test_multiple_users_isolated(client_factory):
|
|||
"invite_identifier": invite2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
cookies1 = dict(resp1.cookies)
|
||||
cookies2 = dict(resp2.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies1) as user1:
|
||||
me1 = await user1.get("/api/auth/me")
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies2) as user2:
|
||||
me2 = await user2.get("/api/auth/me")
|
||||
|
||||
|
||||
assert me1.json()["email"] == email1
|
||||
assert me2.json()["email"] == email2
|
||||
assert me1.json()["id"] != me2.json()["id"]
|
||||
|
|
@ -375,11 +396,13 @@ async def test_multiple_users_isolated(client_factory):
|
|||
async def test_password_is_hashed(client_factory):
|
||||
"""Passwords are properly hashed (can login with correct password)."""
|
||||
email = unique_email("hashtest")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -399,11 +422,13 @@ async def test_password_is_hashed(client_factory):
|
|||
async def test_case_sensitive_password(client_factory):
|
||||
"""Passwords are case-sensitive."""
|
||||
email = unique_email("casetest")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -424,11 +449,13 @@ async def test_case_sensitive_password(client_factory):
|
|||
async def test_logout_success(client_factory):
|
||||
"""Can logout successfully."""
|
||||
email = unique_email("logout")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -438,9 +465,9 @@ async def test_logout_success(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg_response.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
logout_response = await authed.post("/api/auth/logout")
|
||||
|
||||
|
||||
assert logout_response.status_code == 200
|
||||
assert logout_response.json() == {"ok": True}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ Availability API Tests
|
|||
|
||||
Tests for the admin availability management endpoints.
|
||||
"""
|
||||
from datetime import date, time, timedelta
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
|
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
|
|||
# Permission Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAvailabilityPermissions:
|
||||
"""Test that only admins can access availability endpoints."""
|
||||
|
||||
|
|
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
|
|||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_get_availability(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get(
|
||||
"/api/admin/availability",
|
||||
|
|
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_set_availability(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.put(
|
||||
"/api/admin/availability",
|
||||
|
|
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
|
|||
# Set Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSetAvailability:
|
||||
"""Test setting availability for a date."""
|
||||
|
||||
|
|
@ -101,7 +109,7 @@ class TestSetAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["date"] == str(tomorrow())
|
||||
|
|
@ -122,13 +130,15 @@ class TestSetAvailability:
|
|||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["slots"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user):
|
||||
async def test_set_empty_slots_clears_availability(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
# First set some availability
|
||||
await client.put(
|
||||
|
|
@ -138,13 +148,13 @@ class TestSetAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Then clear it
|
||||
response = await client.put(
|
||||
"/api/admin/availability",
|
||||
json={"date": str(tomorrow()), "slots": []},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["slots"]) == 0
|
||||
|
|
@ -160,22 +170,22 @@ class TestSetAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Replace with different slots
|
||||
response = await client.put(
|
||||
await client.put(
|
||||
"/api/admin/availability",
|
||||
json={
|
||||
"date": str(tomorrow()),
|
||||
"slots": [{"start_time": "14:00:00", "end_time": "16:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Verify the replacement
|
||||
get_response = await client.get(
|
||||
"/api/admin/availability",
|
||||
params={"from": str(tomorrow()), "to": str(tomorrow())},
|
||||
)
|
||||
|
||||
|
||||
data = get_response.json()
|
||||
assert len(data["days"]) == 1
|
||||
assert len(data["days"][0]["slots"]) == 1
|
||||
|
|
@ -186,6 +196,7 @@ class TestSetAvailability:
|
|||
# Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAvailabilityValidation:
|
||||
"""Test validation rules for availability."""
|
||||
|
||||
|
|
@ -200,7 +211,7 @@ class TestAvailabilityValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -214,7 +225,7 @@ class TestAvailabilityValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -229,7 +240,7 @@ class TestAvailabilityValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "30" in response.json()["detail"]
|
||||
|
||||
|
|
@ -243,7 +254,7 @@ class TestAvailabilityValidation:
|
|||
"slots": [{"start_time": "09:05:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422 # Pydantic validation error
|
||||
assert "15-minute" in response.json()["detail"][0]["msg"]
|
||||
|
||||
|
|
@ -257,7 +268,7 @@ class TestAvailabilityValidation:
|
|||
"slots": [{"start_time": "12:00:00", "end_time": "09:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "after" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -274,7 +285,7 @@ class TestAvailabilityValidation:
|
|||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "overlap" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
|
|||
# Get Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetAvailability:
|
||||
"""Test retrieving availability."""
|
||||
|
||||
|
|
@ -293,7 +305,7 @@ class TestGetAvailability:
|
|||
"/api/admin/availability",
|
||||
params={"from": str(tomorrow()), "to": str(in_days(7))},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["days"] == []
|
||||
|
|
@ -310,13 +322,13 @@ class TestGetAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Get range that includes all
|
||||
response = await client.get(
|
||||
"/api/admin/availability",
|
||||
params={"from": str(in_days(1)), "to": str(in_days(3))},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["days"]) == 3
|
||||
|
|
@ -333,13 +345,13 @@ class TestGetAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Get only a subset
|
||||
response = await client.get(
|
||||
"/api/admin/availability",
|
||||
params={"from": str(in_days(2)), "to": str(in_days(4))},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["days"]) == 3
|
||||
|
|
@ -351,7 +363,7 @@ class TestGetAvailability:
|
|||
"/api/admin/availability",
|
||||
params={"from": str(in_days(7)), "to": str(in_days(1))},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "before" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -360,6 +372,7 @@ class TestGetAvailability:
|
|||
# Copy Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCopyAvailability:
|
||||
"""Test copying availability from one day to others."""
|
||||
|
||||
|
|
@ -377,7 +390,7 @@ class TestCopyAvailability:
|
|||
],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Copy to another day
|
||||
response = await client.post(
|
||||
"/api/admin/availability/copy",
|
||||
|
|
@ -386,7 +399,7 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(2))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["days"]) == 1
|
||||
|
|
@ -404,7 +417,7 @@ class TestCopyAvailability:
|
|||
"slots": [{"start_time": "10:00:00", "end_time": "11:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Copy to multiple days
|
||||
response = await client.post(
|
||||
"/api/admin/availability/copy",
|
||||
|
|
@ -413,7 +426,7 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(2)), str(in_days(3)), str(in_days(4))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["days"]) == 3
|
||||
|
|
@ -429,7 +442,7 @@ class TestCopyAvailability:
|
|||
"slots": [{"start_time": "08:00:00", "end_time": "09:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Set source availability
|
||||
await client.put(
|
||||
"/api/admin/availability",
|
||||
|
|
@ -438,7 +451,7 @@ class TestCopyAvailability:
|
|||
"slots": [{"start_time": "14:00:00", "end_time": "15:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Copy (should replace)
|
||||
await client.post(
|
||||
"/api/admin/availability/copy",
|
||||
|
|
@ -447,13 +460,13 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(2))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Verify target was replaced
|
||||
response = await client.get(
|
||||
"/api/admin/availability",
|
||||
params={"from": str(in_days(2)), "to": str(in_days(2))},
|
||||
)
|
||||
|
||||
|
||||
data = response.json()
|
||||
assert len(data["days"]) == 1
|
||||
assert len(data["days"][0]["slots"]) == 1
|
||||
|
|
@ -469,7 +482,7 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(2))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "no availability" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -484,7 +497,7 @@ class TestCopyAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Copy including self in targets
|
||||
response = await client.post(
|
||||
"/api/admin/availability/copy",
|
||||
|
|
@ -493,7 +506,7 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(1)), str(in_days(2))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should only have copied to day 2, not day 1 (self)
|
||||
|
|
@ -511,7 +524,7 @@ class TestCopyAvailability:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Try to copy to a date beyond 30 days
|
||||
response = await client.post(
|
||||
"/api/admin/availability/copy",
|
||||
|
|
@ -520,7 +533,7 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(31))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "30" in response.json()["detail"]
|
||||
|
||||
|
|
@ -535,6 +548,6 @@ class TestCopyAvailability:
|
|||
"target_dates": [str(in_days(1))],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ Booking API Tests
|
|||
|
||||
Tests for the user booking endpoints.
|
||||
"""
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Appointment, AppointmentStatus
|
||||
|
|
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
|
|||
# Permission Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingPermissions:
|
||||
"""Test that only regular users can book appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_can_get_slots(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user can get available slots."""
|
||||
# First, admin sets up availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -36,15 +41,19 @@ class TestBookingPermissions:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Regular user gets slots
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_can_book(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user can book an appointment."""
|
||||
# Admin sets up availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -55,22 +64,24 @@ class TestBookingPermissions:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Regular user books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test booking"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_get_slots(self, client_factory, admin_user):
|
||||
"""Admin cannot access booking slots endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -85,18 +96,20 @@ class TestBookingPermissions:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_get_slots(self, client):
|
||||
"""Unauthenticated user cannot get slots."""
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -113,6 +126,7 @@ class TestBookingPermissions:
|
|||
# Get Slots Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetSlots:
|
||||
"""Test getting available booking slots."""
|
||||
|
||||
|
|
@ -120,15 +134,19 @@ class TestGetSlots:
|
|||
async def test_get_slots_no_availability(self, client_factory, regular_user):
|
||||
"""Returns empty slots when no availability set."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["date"] == str(tomorrow())
|
||||
assert data["slots"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_slots_expands_to_15min(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Availability is expanded into 15-minute slots."""
|
||||
# Admin sets 1-hour availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -139,15 +157,17 @@ class TestGetSlots:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User gets slots - should be 4 x 15-minute slots
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["slots"]) == 4
|
||||
|
||||
|
||||
# Verify times
|
||||
assert "09:00:00" in data["slots"][0]["start_time"]
|
||||
assert "09:15:00" in data["slots"][0]["end_time"]
|
||||
|
|
@ -156,7 +176,9 @@ class TestGetSlots:
|
|||
assert "10:00:00" in data["slots"][3]["end_time"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_slots_excludes_booked(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Already booked slots are excluded from available slots."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -167,17 +189,19 @@ class TestGetSlots:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "10:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books first slot
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
# Get slots again - should have 3 left
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["slots"]) == 3
|
||||
|
|
@ -189,6 +213,7 @@ class TestGetSlots:
|
|||
# Booking Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCreateBooking:
|
||||
"""Test creating bookings."""
|
||||
|
||||
|
|
@ -204,7 +229,7 @@ class TestCreateBooking:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
|
|
@ -214,7 +239,7 @@ class TestCreateBooking:
|
|||
"note": "Discussion about project",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["user_id"] == regular_user["user"]["id"]
|
||||
|
|
@ -235,20 +260,22 @@ class TestCreateBooking:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books without note
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["note"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user):
|
||||
async def test_cannot_double_book_slot(
|
||||
self, client_factory, regular_user, admin_user, alt_regular_user
|
||||
):
|
||||
"""Cannot book a slot that's already booked."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -259,7 +286,7 @@ class TestCreateBooking:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# First user books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
|
|
@ -267,19 +294,21 @@ class TestCreateBooking:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# Second user tries to book same slot
|
||||
async with client_factory.create(cookies=alt_regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 409
|
||||
assert "already been booked" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user):
|
||||
async def test_cannot_book_outside_availability(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Cannot book a slot outside of availability."""
|
||||
# Admin sets availability for morning only
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -290,14 +319,14 @@ class TestCreateBooking:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User tries to book afternoon slot
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T14:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "not within any available time ranges" in response.json()["detail"]
|
||||
|
||||
|
|
@ -306,6 +335,7 @@ class TestCreateBooking:
|
|||
# Date Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingDateValidation:
|
||||
"""Test date validation for bookings."""
|
||||
|
||||
|
|
@ -317,9 +347,12 @@ class TestBookingDateValidation:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{date.today()}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower()
|
||||
assert (
|
||||
"past" in response.json()["detail"].lower()
|
||||
or "today" in response.json()["detail"].lower()
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_book_past_date(self, client_factory, regular_user):
|
||||
|
|
@ -330,7 +363,7 @@ class TestBookingDateValidation:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{yesterday}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -342,7 +375,7 @@ class TestBookingDateValidation:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{too_far}T09:00:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "30" in response.json()["detail"]
|
||||
|
||||
|
|
@ -350,8 +383,10 @@ class TestBookingDateValidation:
|
|||
async def test_cannot_get_slots_today(self, client_factory, regular_user):
|
||||
"""Cannot get slots for today."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(date.today())})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(date.today())}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -359,8 +394,10 @@ class TestBookingDateValidation:
|
|||
"""Cannot get slots for past date."""
|
||||
yesterday = date.today() - timedelta(days=1)
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(yesterday)})
|
||||
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(yesterday)}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
|
|
@ -368,11 +405,14 @@ class TestBookingDateValidation:
|
|||
# Time Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingTimeValidation:
|
||||
"""Test time validation for bookings."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user):
|
||||
async def test_slot_must_be_15min_boundary(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Slot start time must be on 15-minute boundary."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -383,14 +423,14 @@ class TestBookingTimeValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User tries to book at 09:05
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:05:00Z"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "15-minute" in response.json()["detail"]
|
||||
|
||||
|
|
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
|
|||
# Note Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingNoteValidation:
|
||||
"""Test note validation for bookings."""
|
||||
|
||||
|
|
@ -414,7 +455,7 @@ class TestBookingNoteValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User tries to book with long note
|
||||
long_note = "x" * 145
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
|
|
@ -422,11 +463,13 @@ class TestBookingNoteValidation:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": long_note},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user):
|
||||
async def test_note_exactly_144_chars(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Note of exactly 144 characters is allowed."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -437,7 +480,7 @@ class TestBookingNoteValidation:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books with exactly 144 char note
|
||||
note = "x" * 144
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
|
|
@ -445,7 +488,7 @@ class TestBookingNoteValidation:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": note},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["note"] == note
|
||||
|
||||
|
|
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
|
|||
# User Appointments Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestUserAppointments:
|
||||
"""Test user appointments endpoints."""
|
||||
|
||||
|
|
@ -462,12 +506,14 @@ class TestUserAppointments:
|
|||
"""Returns empty list when user has no appointments."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/appointments")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_my_appointments_with_bookings(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Returns user's appointments."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -478,7 +524,7 @@ class TestUserAppointments:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books two slots
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post(
|
||||
|
|
@ -489,10 +535,10 @@ class TestUserAppointments:
|
|||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:15:00Z", "note": "Second"},
|
||||
)
|
||||
|
||||
|
||||
# Get appointments
|
||||
response = await client.get("/api/appointments")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
|
@ -502,11 +548,13 @@ class TestUserAppointments:
|
|||
assert "Second" in notes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user):
|
||||
async def test_admin_cannot_view_user_appointments(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Admin cannot access user appointments endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/appointments")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -520,7 +568,9 @@ class TestCancelAppointment:
|
|||
"""Test cancelling appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_cancel_own_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""User can cancel their own appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -531,7 +581,7 @@ class TestCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -539,17 +589,19 @@ class TestCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# Cancel
|
||||
response = await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "cancelled_by_user"
|
||||
assert data["cancelled_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user):
|
||||
async def test_cannot_cancel_others_appointment(
|
||||
self, client_factory, regular_user, alt_regular_user, admin_user
|
||||
):
|
||||
"""User cannot cancel another user's appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -560,7 +612,7 @@ class TestCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# First user books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -568,24 +620,28 @@ class TestCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# Second user tries to cancel
|
||||
async with client_factory.create(cookies=alt_regular_user["cookies"]) as client:
|
||||
response = await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "another user" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user):
|
||||
async def test_cannot_cancel_nonexistent_appointment(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Returns 404 for non-existent appointment."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post("/api/appointments/99999/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
||||
async def test_cannot_cancel_already_cancelled(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Cannot cancel an already cancelled appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -596,7 +652,7 @@ class TestCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books and cancels
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -605,23 +661,27 @@ class TestCancelAppointment:
|
|||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
# Try to cancel again
|
||||
response = await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "cancelled_by_user" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user):
|
||||
async def test_admin_cannot_use_user_cancel_endpoint(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Admin cannot use user cancel endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/appointments/1/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user):
|
||||
async def test_cancelled_slot_becomes_available(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""After cancelling, the slot becomes available again."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -632,7 +692,7 @@ class TestCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "09:30:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -640,17 +700,17 @@ class TestCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# Check slots - should have 1 slot left (09:15)
|
||||
slots_response = await client.get(
|
||||
"/api/booking/slots",
|
||||
params={"date": str(tomorrow())},
|
||||
)
|
||||
assert len(slots_response.json()["slots"]) == 1
|
||||
|
||||
|
||||
# Cancel
|
||||
await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
# Check slots - should have 2 slots now
|
||||
slots_response = await client.get(
|
||||
"/api/booking/slots",
|
||||
|
|
@ -663,7 +723,7 @@ class TestCancelAppointment:
|
|||
"""User cannot cancel a past appointment."""
|
||||
# Create a past appointment directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
appointment = Appointment(
|
||||
user_id=regular_user["user"]["id"],
|
||||
slot_start=past_time,
|
||||
|
|
@ -674,11 +734,11 @@ class TestCancelAppointment:
|
|||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
apt_id = appointment.id
|
||||
|
||||
|
||||
# Try to cancel
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -687,11 +747,14 @@ class TestCancelAppointment:
|
|||
# Admin Appointments Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdminViewAppointments:
|
||||
"""Test admin viewing all appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_can_view_all_appointments(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin can view all appointments."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -702,18 +765,18 @@ class TestAdminViewAppointments:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post(
|
||||
"/api/booking",
|
||||
json={"slot_start": f"{tomorrow()}T09:00:00Z", "note": "Test"},
|
||||
)
|
||||
|
||||
|
||||
# Admin views all appointments
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.get("/api/admin/appointments")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Paginated response
|
||||
|
|
@ -725,11 +788,13 @@ class TestAdminViewAppointments:
|
|||
assert any(apt["note"] == "Test" for apt in data["records"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_all_appointments(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular user cannot access admin appointments endpoint."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/admin/appointments")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
|
|||
"""Test admin cancelling appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_can_cancel_any_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin can cancel any user's appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -754,7 +821,7 @@ class TestAdminCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -762,18 +829,22 @@ class TestAdminCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# Admin cancels
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "cancelled_by_admin"
|
||||
assert data["cancelled_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_cannot_use_admin_cancel(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user cannot use admin cancel endpoint."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -784,7 +855,7 @@ class TestAdminCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -792,22 +863,26 @@ class TestAdminCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# User tries to use admin cancel endpoint
|
||||
response = await client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user):
|
||||
async def test_admin_cancel_nonexistent_appointment(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Returns 404 for non-existent appointment."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/admin/appointments/99999/cancel")
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_cannot_cancel_already_cancelled(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin cannot cancel an already cancelled appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -818,7 +893,7 @@ class TestAdminCancelAppointment:
|
|||
"slots": [{"start_time": "09:00:00", "end_time": "12:00:00"}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# User books
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
book_response = await client.post(
|
||||
|
|
@ -826,23 +901,27 @@ class TestAdminCancelAppointment:
|
|||
json={"slot_start": f"{tomorrow()}T09:00:00Z"},
|
||||
)
|
||||
apt_id = book_response.json()["id"]
|
||||
|
||||
|
||||
# User cancels their own appointment
|
||||
await client.post(f"/api/appointments/{apt_id}/cancel")
|
||||
|
||||
|
||||
# Admin tries to cancel again
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "cancelled_by_user" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_cannot_cancel_past_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin cannot cancel a past appointment."""
|
||||
# Create a past appointment directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
appointment = Appointment(
|
||||
user_id=regular_user["user"]["id"],
|
||||
slot_start=past_time,
|
||||
|
|
@ -853,11 +932,12 @@ class TestAdminCancelAppointment:
|
|||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
apt_id = appointment.id
|
||||
|
||||
|
||||
# Admin tries to cancel
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@
|
|||
|
||||
Note: Registration now requires an invite code.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
|
||||
# Protected endpoint tests - without auth
|
||||
|
|
@ -41,9 +42,11 @@ async def test_increment_counter_invalid_cookie(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_counter_authenticated(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -53,10 +56,10 @@ async def test_get_counter_authenticated(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
response = await authed.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
|
|
@ -64,9 +67,11 @@ async def test_get_counter_authenticated(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -76,12 +81,12 @@ async def test_increment_counter(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
# Get current value
|
||||
before = await authed.get("/api/counter")
|
||||
before_value = before.json()["value"]
|
||||
|
||||
|
||||
# Increment
|
||||
response = await authed.post("/api/counter/increment")
|
||||
assert response.status_code == 200
|
||||
|
|
@ -91,9 +96,11 @@ async def test_increment_counter(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_multiple(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -103,26 +110,28 @@ async def test_increment_counter_multiple(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
# Get starting value
|
||||
before = await authed.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
|
||||
# Increment 3 times
|
||||
await authed.post("/api/counter/increment")
|
||||
await authed.post("/api/counter/increment")
|
||||
response = await authed.post("/api/counter/increment")
|
||||
|
||||
|
||||
assert response.json()["value"] == start + 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_after_increment(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -132,14 +141,14 @@ async def test_get_counter_after_increment(client_factory):
|
|||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
before = await authed.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
|
||||
await authed.post("/api/counter/increment")
|
||||
await authed.post("/api/counter/increment")
|
||||
|
||||
|
||||
response = await authed.get("/api/counter")
|
||||
assert response.json()["value"] == start + 2
|
||||
|
||||
|
|
@ -149,10 +158,12 @@ async def test_get_counter_after_increment(client_factory):
|
|||
async def test_counter_shared_between_users(client_factory):
|
||||
# Create godfather and invites for two users
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
# Create first user
|
||||
reg1 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -163,15 +174,15 @@ async def test_counter_shared_between_users(client_factory):
|
|||
},
|
||||
)
|
||||
cookies1 = dict(reg1.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies1) as user1:
|
||||
# Get starting value
|
||||
before = await user1.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
|
||||
await user1.post("/api/counter/increment")
|
||||
await user1.post("/api/counter/increment")
|
||||
|
||||
|
||||
# Create second user - should see the increments
|
||||
reg2 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
|
|
@ -182,14 +193,14 @@ async def test_counter_shared_between_users(client_factory):
|
|||
},
|
||||
)
|
||||
cookies2 = dict(reg2.cookies)
|
||||
|
||||
|
||||
async with client_factory.create(cookies=cookies2) as user2:
|
||||
response = await user2.get("/api/counter")
|
||||
assert response.json()["value"] == start + 2
|
||||
|
||||
|
||||
# Second user increments
|
||||
await user2.post("/api/counter/increment")
|
||||
|
||||
|
||||
# First user sees the increment
|
||||
async with client_factory.create(cookies=cookies1) as user1:
|
||||
response = await user1.get("/api/counter")
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
"""Tests for invite functionality."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from invite_utils import (
|
||||
generate_invite_identifier,
|
||||
normalize_identifier,
|
||||
is_valid_identifier_format,
|
||||
BIP39_WORDS,
|
||||
generate_invite_identifier,
|
||||
is_valid_identifier_format,
|
||||
normalize_identifier,
|
||||
)
|
||||
from models import Invite, InviteStatus, User, ROLE_REGULAR
|
||||
from tests.helpers import unique_email
|
||||
from models import ROLE_REGULAR, Invite, InviteStatus, User
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# ============================================================================
|
||||
# Invite Utils Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_bip39_words_loaded():
|
||||
"""BIP39 word list should have exactly 2048 words."""
|
||||
assert len(BIP39_WORDS) == 2048
|
||||
|
|
@ -26,7 +27,7 @@ def test_generate_invite_identifier_format():
|
|||
"""Generated identifier should have word-word-NN format."""
|
||||
identifier = generate_invite_identifier()
|
||||
assert is_valid_identifier_format(identifier)
|
||||
|
||||
|
||||
parts = identifier.split("-")
|
||||
assert len(parts) == 3
|
||||
assert parts[0] in BIP39_WORDS
|
||||
|
|
@ -74,11 +75,11 @@ def test_is_valid_identifier_format_invalid():
|
|||
assert is_valid_identifier_format("apple-banana") is False
|
||||
assert is_valid_identifier_format("apple-banana-42-extra") is False
|
||||
assert is_valid_identifier_format("applebanan42") is False
|
||||
|
||||
|
||||
# Empty parts
|
||||
assert is_valid_identifier_format("-banana-42") is False
|
||||
assert is_valid_identifier_format("apple--42") is False
|
||||
|
||||
|
||||
# Invalid number format
|
||||
assert is_valid_identifier_format("apple-banana-4") is False # Single digit
|
||||
assert is_valid_identifier_format("apple-banana-420") is False # Three digits
|
||||
|
|
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
|
|||
# Invite Model Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite(client_factory):
|
||||
"""Can create an invite with godfather."""
|
||||
|
|
@ -97,7 +99,7 @@ async def test_create_invite(client_factory):
|
|||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
# Create invite
|
||||
invite = Invite(
|
||||
identifier="test-invite-01",
|
||||
|
|
@ -107,7 +109,7 @@ async def test_create_invite(client_factory):
|
|||
db.add(invite)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
|
||||
|
||||
assert invite.id is not None
|
||||
assert invite.identifier == "test-invite-01"
|
||||
assert invite.godfather_id == godfather.id
|
||||
|
|
@ -125,20 +127,20 @@ async def test_invite_godfather_relationship(client_factory):
|
|||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
invite = Invite(
|
||||
identifier="rel-test-01",
|
||||
godfather_id=godfather.id,
|
||||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Query invite fresh
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.identifier == "rel-test-01")
|
||||
)
|
||||
loaded_invite = result.scalar_one()
|
||||
|
||||
|
||||
assert loaded_invite.godfather is not None
|
||||
assert loaded_invite.godfather.email == godfather.email
|
||||
|
||||
|
|
@ -147,25 +149,25 @@ async def test_invite_godfather_relationship(client_factory):
|
|||
async def test_invite_unique_identifier(client_factory):
|
||||
"""Invite identifier must be unique."""
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
invite1 = Invite(
|
||||
identifier="unique-test-01",
|
||||
godfather_id=godfather.id,
|
||||
)
|
||||
db.add(invite1)
|
||||
await db.commit()
|
||||
|
||||
|
||||
invite2 = Invite(
|
||||
identifier="unique-test-01", # Same identifier
|
||||
godfather_id=godfather.id,
|
||||
)
|
||||
db.add(invite2)
|
||||
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await db.commit()
|
||||
|
||||
|
|
@ -173,8 +175,8 @@ async def test_invite_unique_identifier(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_invite_status_transitions(client_factory):
|
||||
"""Invite status can be changed."""
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
|
|
@ -182,7 +184,7 @@ async def test_invite_status_transitions(client_factory):
|
|||
user = await create_user_with_roles(
|
||||
db, unique_email("invitee"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
invite = Invite(
|
||||
identifier="status-test-01",
|
||||
godfather_id=godfather.id,
|
||||
|
|
@ -190,14 +192,14 @@ async def test_invite_status_transitions(client_factory):
|
|||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Transition to SPENT
|
||||
invite.status = InviteStatus.SPENT
|
||||
invite.used_by_id = user.id
|
||||
invite.spent_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
|
||||
|
||||
assert invite.status == InviteStatus.SPENT
|
||||
assert invite.used_by_id == user.id
|
||||
assert invite.spent_at is not None
|
||||
|
|
@ -206,13 +208,13 @@ async def test_invite_status_transitions(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_invite_revoke(client_factory):
|
||||
"""Invite can be revoked."""
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
invite = Invite(
|
||||
identifier="revoke-test-01",
|
||||
godfather_id=godfather.id,
|
||||
|
|
@ -220,13 +222,13 @@ async def test_invite_revoke(client_factory):
|
|||
)
|
||||
db.add(invite)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Revoke
|
||||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
|
||||
|
||||
assert invite.status == InviteStatus.REVOKED
|
||||
assert invite.revoked_at is not None
|
||||
assert invite.used_by_id is None # Not used
|
||||
|
|
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
|
|||
# User Godfather Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_godfather_relationship(client_factory):
|
||||
"""User can have a godfather."""
|
||||
|
|
@ -243,7 +246,7 @@ async def test_user_godfather_relationship(client_factory):
|
|||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
# Create user with godfather
|
||||
user = User(
|
||||
email=unique_email("godchild"),
|
||||
|
|
@ -252,13 +255,11 @@ async def test_user_godfather_relationship(client_factory):
|
|||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Query user fresh
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user.id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == user.id))
|
||||
loaded_user = result.scalar_one()
|
||||
|
||||
|
||||
assert loaded_user.godfather_id == godfather.id
|
||||
assert loaded_user.godfather is not None
|
||||
assert loaded_user.godfather.email == godfather.email
|
||||
|
|
@ -271,7 +272,7 @@ async def test_user_without_godfather(client_factory):
|
|||
user = await create_user_with_roles(
|
||||
db, unique_email("noparent"), "password123", [ROLE_REGULAR]
|
||||
)
|
||||
|
||||
|
||||
assert user.godfather_id is None
|
||||
assert user.godfather is None
|
||||
|
||||
|
|
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
|
|||
# Admin Create Invite API Tests (Phase 2)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
|
||||
"""Admin can create an invite for a regular user."""
|
||||
|
|
@ -290,12 +292,12 @@ async def test_admin_can_create_invite(client_factory, admin_user, regular_user)
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
response = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["godfather_id"] == godfather.id
|
||||
|
|
@ -318,12 +320,12 @@ async def test_admin_can_create_invite_for_self(client_factory, admin_user):
|
|||
select(User).where(User.email == admin_user["email"])
|
||||
)
|
||||
admin = result.scalar_one()
|
||||
|
||||
|
||||
response = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": admin.id},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["godfather_id"] == admin.id
|
||||
|
|
@ -338,7 +340,7 @@ async def test_regular_user_cannot_create_invite(client_factory, regular_user):
|
|||
"/api/admin/invites",
|
||||
json={"godfather_id": 1},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
|
|
@ -350,7 +352,7 @@ async def test_unauthenticated_cannot_create_invite(client_factory):
|
|||
"/api/admin/invites",
|
||||
json={"godfather_id": 1},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
|
|
@ -362,7 +364,7 @@ async def test_create_invite_invalid_godfather(client_factory, admin_user):
|
|||
"/api/admin/invites",
|
||||
json={"godfather_id": 99999},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "not found" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -376,39 +378,39 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
response = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
|
||||
|
||||
data = response.json()
|
||||
invite_id = data["id"]
|
||||
|
||||
|
||||
# Query from DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
|
||||
|
||||
assert invite.identifier == data["identifier"]
|
||||
assert invite.godfather_id == godfather.id
|
||||
assert invite.status == InviteStatus.READY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
|
||||
async def test_create_invite_retries_on_collision(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Create invite retries with new identifier on collision."""
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
# Create first invite normally
|
||||
response1 = await client.post(
|
||||
"/api/admin/invites",
|
||||
|
|
@ -416,22 +418,25 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
|||
)
|
||||
assert response1.status_code == 200
|
||||
identifier1 = response1.json()["identifier"]
|
||||
|
||||
|
||||
# Mock generator to first return the same identifier (collision), then a new one
|
||||
call_count = 0
|
||||
|
||||
def mock_generator():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return identifier1 # Will collide
|
||||
return f"unique-word-{call_count:02d}" # Won't collide
|
||||
|
||||
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator):
|
||||
|
||||
with patch(
|
||||
"routes.invites.generate_invite_identifier", side_effect=mock_generator
|
||||
):
|
||||
response2 = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
|
||||
|
||||
assert response2.status_code == 200
|
||||
# Should have retried and gotten a new identifier
|
||||
assert response2.json()["identifier"] != identifier1
|
||||
|
|
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
|||
# Invite Check API Tests (Phase 3)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_valid(client_factory, admin_user, regular_user):
|
||||
"""Check endpoint returns valid=True for READY invite."""
|
||||
|
|
@ -452,17 +458,17 @@ async def test_check_invite_valid(client_factory, admin_user, regular_user):
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Check invite (no auth needed)
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get(f"/api/invites/{identifier}/check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is True
|
||||
|
|
@ -475,7 +481,7 @@ async def test_check_invite_not_found(client_factory):
|
|||
"""Check endpoint returns valid=False for unknown invite."""
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get("/api/invites/fake-invite-99/check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
|
|
@ -492,14 +498,14 @@ async def test_check_invite_invalid_format(client_factory):
|
|||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
assert "format" in data["error"].lower()
|
||||
|
||||
|
||||
# Single digit number
|
||||
response = await client.get("/api/invites/word-word-1/check")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
assert "format" in data["error"].lower()
|
||||
|
||||
|
||||
# Too many parts
|
||||
response = await client.get("/api/invites/word-word-word-00/check")
|
||||
assert response.status_code == 200
|
||||
|
|
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user):
|
||||
async def test_check_invite_spent_returns_not_found(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -518,13 +526,13 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Use the invite
|
||||
async with client_factory.create() as client:
|
||||
await client.post(
|
||||
|
|
@ -535,11 +543,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Check spent invite - should return same error as non-existent
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get(f"/api/invites/{identifier}/check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
|
|
@ -547,10 +555,12 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user):
|
||||
async def test_check_invite_revoked_returns_not_found(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
async with client_factory.get_db_session() as db:
|
||||
|
|
@ -558,14 +568,14 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
invite_id = create_resp.json()["id"]
|
||||
|
||||
|
||||
# Revoke the invite
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
|
|
@ -573,11 +583,11 @@ async def test_check_invite_revoked_returns_not_found(client_factory, admin_user
|
|||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Check revoked invite - should return same error as non-existent
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get(f"/api/invites/{identifier}/check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["valid"] is False
|
||||
|
|
@ -594,17 +604,17 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Check with uppercase
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get(f"/api/invites/{identifier.upper()}/check")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["valid"] is True
|
||||
|
||||
|
|
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
|
|||
# Register with Invite Tests (Phase 3)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
|
||||
"""Can register with valid invite code."""
|
||||
|
|
@ -624,13 +635,13 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us
|
|||
)
|
||||
godfather = result.scalar_one()
|
||||
godfather_id = godfather.id
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather_id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Register with invite
|
||||
new_email = unique_email("newuser")
|
||||
async with client_factory.create() as client:
|
||||
|
|
@ -642,7 +653,7 @@ async def test_register_with_valid_invite(client_factory, admin_user, regular_us
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["email"] == new_email
|
||||
|
|
@ -659,7 +670,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
|
|
@ -667,7 +678,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
|
|||
invite_data = create_resp.json()
|
||||
identifier = invite_data["identifier"]
|
||||
invite_id = invite_data["id"]
|
||||
|
||||
|
||||
# Register
|
||||
async with client_factory.create() as client:
|
||||
await client.post(
|
||||
|
|
@ -678,14 +689,12 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Check invite status
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
|
||||
|
||||
assert invite.status == InviteStatus.SPENT
|
||||
assert invite.used_by_id is not None
|
||||
assert invite.spent_at is not None
|
||||
|
|
@ -702,13 +711,13 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
|
|||
)
|
||||
godfather = result.scalar_one()
|
||||
godfather_id = godfather.id
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather_id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Register
|
||||
new_email = unique_email("godchildtest")
|
||||
async with client_factory.create() as client:
|
||||
|
|
@ -720,14 +729,12 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Check user's godfather
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == new_email)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.email == new_email))
|
||||
new_user = result.scalar_one()
|
||||
|
||||
|
||||
assert new_user.godfather_id == godfather_id
|
||||
|
||||
|
||||
|
|
@ -743,7 +750,7 @@ async def test_register_with_invalid_invite(client_factory):
|
|||
"invite_identifier": "fake-invite-99",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "invalid" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -758,13 +765,13 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# First registration
|
||||
async with client_factory.create() as client:
|
||||
await client.post(
|
||||
|
|
@ -775,7 +782,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Second registration with same invite
|
||||
async with client_factory.create() as client:
|
||||
response = await client.post(
|
||||
|
|
@ -786,7 +793,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "invalid invite code" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -794,8 +801,8 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
|||
@pytest.mark.asyncio
|
||||
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
|
||||
"""Cannot register with revoked invite."""
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
async with client_factory.get_db_session() as db:
|
||||
|
|
@ -803,7 +810,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
|
|
@ -811,17 +818,15 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
|
|||
invite_data = create_resp.json()
|
||||
identifier = invite_data["identifier"]
|
||||
invite_id = invite_data["id"]
|
||||
|
||||
|
||||
# Revoke invite directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Try to register
|
||||
async with client_factory.create() as client:
|
||||
response = await client.post(
|
||||
|
|
@ -832,7 +837,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "invalid invite code" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -847,13 +852,13 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Try to register with existing email
|
||||
async with client_factory.create() as client:
|
||||
response = await client.post(
|
||||
|
|
@ -864,7 +869,7 @@ async def test_register_duplicate_email(client_factory, admin_user, regular_user
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "already registered" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -879,13 +884,13 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Register
|
||||
async with client_factory.create() as client:
|
||||
response = await client.post(
|
||||
|
|
@ -896,7 +901,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert "auth_token" in response.cookies
|
||||
|
||||
|
||||
|
|
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
|
|||
# User Invites API Tests (Phase 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
|
||||
"""Regular user can list their own invites."""
|
||||
|
|
@ -914,14 +920,14 @@ async def test_regular_user_can_list_invites(client_factory, admin_user, regular
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
|
||||
|
||||
# List invites as regular user
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/invites")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
invites = response.json()
|
||||
assert len(invites) == 2
|
||||
|
|
@ -935,13 +941,15 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
|
|||
"""User with no invites gets empty list."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/invites")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user):
|
||||
async def test_spent_invite_shows_used_by_email(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Spent invite shows who used it."""
|
||||
# Create invite for regular user
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -950,13 +958,13 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Use the invite
|
||||
invitee_email = unique_email("invitee")
|
||||
async with client_factory.create() as client:
|
||||
|
|
@ -968,11 +976,11 @@ async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regu
|
|||
"invite_identifier": identifier,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Check that regular user sees the invitee email
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/invites")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
invites = response.json()
|
||||
assert len(invites) == 1
|
||||
|
|
@ -985,7 +993,7 @@ async def test_admin_cannot_list_own_invites(client_factory, admin_user):
|
|||
"""Admin without VIEW_OWN_INVITES permission gets 403."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/invites")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
|
|
@ -994,7 +1002,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
|
|||
"""Unauthenticated user gets 401."""
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get("/api/invites")
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
|
|
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
|
|||
# Admin Invite Management Tests (Phase 5)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
|
||||
"""Admin can list all invites."""
|
||||
|
|
@ -1012,13 +1021,13 @@ async def test_admin_can_list_all_invites(client_factory, admin_user, regular_us
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
|
||||
|
||||
# List all
|
||||
response = await client.get("/api/admin/invites")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 2
|
||||
|
|
@ -1034,14 +1043,14 @@ async def test_admin_list_pagination(client_factory, admin_user, regular_user):
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
# Create 5 invites
|
||||
for _ in range(5):
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
|
||||
|
||||
# Get page 1 with 2 per page
|
||||
response = await client.get("/api/admin/invites?page=1&per_page=2")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["records"]) == 2
|
||||
|
|
@ -1058,13 +1067,13 @@ async def test_admin_filter_by_status(client_factory, admin_user, regular_user):
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
# Create an invite
|
||||
await client.post("/api/admin/invites", json={"godfather_id": godfather.id})
|
||||
|
||||
|
||||
# Filter by ready
|
||||
response = await client.get("/api/admin/invites?status=ready")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for record in data["records"]:
|
||||
|
|
@ -1080,17 +1089,17 @@ async def test_admin_can_revoke_invite(client_factory, admin_user, regular_user)
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
# Create invite
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
invite_id = create_resp.json()["id"]
|
||||
|
||||
|
||||
# Revoke it
|
||||
response = await client.post(f"/api/admin/invites/{invite_id}/revoke")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "revoked"
|
||||
|
|
@ -1106,14 +1115,14 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
# Create invite
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
invite_data = create_resp.json()
|
||||
|
||||
|
||||
# Use the invite
|
||||
async with client_factory.create() as client:
|
||||
await client.post(
|
||||
|
|
@ -1124,11 +1133,11 @@ async def test_cannot_revoke_spent_invite(client_factory, admin_user, regular_us
|
|||
"invite_identifier": invite_data["identifier"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Try to revoke
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post(f"/api/admin/invites/{invite_data['id']}/revoke")
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "only ready" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -1138,7 +1147,7 @@ async def test_revoke_nonexistent_invite(client_factory, admin_user):
|
|||
"""Revoking non-existent invite returns 404."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/admin/invites/99999/revoke")
|
||||
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
|
|
@ -1149,8 +1158,7 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
|
|||
# List
|
||||
response = await client.get("/api/admin/invites")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
# Revoke
|
||||
response = await client.post("/api/admin/invites/1/revoke")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
|
|
|||
|
|
@ -7,15 +7,16 @@ These tests verify that:
|
|||
3. Unauthenticated users are denied access (401)
|
||||
4. The permission system cannot be bypassed
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Permission
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Role Assignment Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoleAssignment:
|
||||
"""Test that roles are properly assigned and returned."""
|
||||
|
||||
|
|
@ -23,7 +24,7 @@ class TestRoleAssignment:
|
|||
async def test_regular_user_has_correct_roles(self, client_factory, regular_user):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "regular" in data["roles"]
|
||||
|
|
@ -33,25 +34,27 @@ class TestRoleAssignment:
|
|||
async def test_admin_user_has_correct_roles(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "admin" in data["roles"]
|
||||
assert "regular" not in data["roles"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user):
|
||||
async def test_regular_user_has_correct_permissions(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
data = response.json()
|
||||
permissions = data["permissions"]
|
||||
|
||||
|
||||
# Should have counter and sum permissions
|
||||
assert Permission.VIEW_COUNTER.value in permissions
|
||||
assert Permission.INCREMENT_COUNTER.value in permissions
|
||||
assert Permission.USE_SUM.value in permissions
|
||||
|
||||
|
||||
# Should NOT have audit permission
|
||||
assert Permission.VIEW_AUDIT.value not in permissions
|
||||
|
||||
|
|
@ -59,23 +62,25 @@ class TestRoleAssignment:
|
|||
async def test_admin_user_has_correct_permissions(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
data = response.json()
|
||||
permissions = data["permissions"]
|
||||
|
||||
|
||||
# Should have audit permission
|
||||
assert Permission.VIEW_AUDIT.value in permissions
|
||||
|
||||
|
||||
# Should NOT have counter/sum permissions
|
||||
assert Permission.VIEW_COUNTER.value not in permissions
|
||||
assert Permission.INCREMENT_COUNTER.value not in permissions
|
||||
assert Permission.USE_SUM.value not in permissions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles):
|
||||
async def test_user_with_no_roles_has_no_permissions(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
data = response.json()
|
||||
assert data["roles"] == []
|
||||
assert data["permissions"] == []
|
||||
|
|
@ -85,6 +90,7 @@ class TestRoleAssignment:
|
|||
# Counter Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCounterAccess:
|
||||
"""Test access control for counter endpoints."""
|
||||
|
||||
|
|
@ -92,15 +98,17 @@ class TestCounterAccess:
|
|||
async def test_regular_user_can_view_counter(self, client_factory, regular_user):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_increment_counter(self, client_factory, regular_user):
|
||||
async def test_regular_user_can_increment_counter(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post("/api/counter/increment")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
|
|
@ -109,7 +117,7 @@ class TestCounterAccess:
|
|||
"""Admin users should be forbidden from counter endpoints."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -118,15 +126,17 @@ class TestCounterAccess:
|
|||
"""Admin users should be forbidden from incrementing counter."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/counter/increment")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_view_counter(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
"""Users with no roles should be forbidden."""
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -146,6 +156,7 @@ class TestCounterAccess:
|
|||
# Sum Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSumAccess:
|
||||
"""Test access control for sum endpoint."""
|
||||
|
||||
|
|
@ -156,7 +167,7 @@ class TestSumAccess:
|
|||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["result"] == 8
|
||||
|
|
@ -169,17 +180,19 @@ class TestSumAccess:
|
|||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_use_sum(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -195,6 +208,7 @@ class TestSumAccess:
|
|||
# Audit Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAuditAccess:
|
||||
"""Test access control for audit endpoints."""
|
||||
|
||||
|
|
@ -202,7 +216,7 @@ class TestAuditAccess:
|
|||
async def test_admin_can_view_counter_audit(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "records" in data
|
||||
|
|
@ -212,34 +226,40 @@ class TestAuditAccess:
|
|||
async def test_admin_can_view_sum_audit(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "records" in data
|
||||
assert "total" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_counter_audit(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_sum_audit(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_view_audit(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -257,6 +277,7 @@ class TestAuditAccess:
|
|||
# Offensive Security Tests - Bypass Attempts
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSecurityBypassAttempts:
|
||||
"""
|
||||
Offensive tests that attempt to bypass security controls.
|
||||
|
|
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
|
|||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user):
|
||||
async def test_cannot_access_audit_with_forged_role_claim(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""
|
||||
Attempt to access audit by somehow claiming admin role.
|
||||
The server should verify roles from DB, not trust client claims.
|
||||
|
|
@ -272,7 +295,7 @@ class TestSecurityBypassAttempts:
|
|||
# Regular user tries to access audit endpoint
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
# Should be denied regardless of any manipulation attempts
|
||||
assert response.status_code == 403
|
||||
|
||||
|
|
@ -280,23 +303,27 @@ class TestSecurityBypassAttempts:
|
|||
async def test_cannot_access_counter_with_expired_session(self, client_factory):
|
||||
"""Test that invalid/expired tokens are rejected."""
|
||||
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid"
|
||||
|
||||
|
||||
async with client_factory.create(cookies={"auth_token": fake_token}) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user):
|
||||
async def test_cannot_access_with_tampered_token(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Test that tokens signed with wrong key are rejected."""
|
||||
# Take a valid token and modify it
|
||||
original_token = regular_user["cookies"].get("auth_token", "")
|
||||
if original_token:
|
||||
tampered_token = original_token[:-5] + "XXXXX"
|
||||
|
||||
async with client_factory.create(cookies={"auth_token": tampered_token}) as client:
|
||||
|
||||
async with client_factory.create(
|
||||
cookies={"auth_token": tampered_token}
|
||||
) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -305,14 +332,16 @@ class TestSecurityBypassAttempts:
|
|||
Test that new registrations cannot claim admin role.
|
||||
New users should only get 'regular' role by default.
|
||||
"""
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from models import ROLE_REGULAR
|
||||
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
||||
response = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
|
|
@ -321,18 +350,18 @@ class TestSecurityBypassAttempts:
|
|||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
|
||||
# Should only have regular role, not admin
|
||||
assert "admin" not in data["roles"]
|
||||
assert Permission.VIEW_AUDIT.value not in data["permissions"]
|
||||
|
||||
|
||||
# Try to access audit with this new user
|
||||
async with client_factory.create(cookies=dict(response.cookies)) as client:
|
||||
audit_response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
assert audit_response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -341,33 +370,35 @@ class TestSecurityBypassAttempts:
|
|||
If a user is deleted, their token should no longer work.
|
||||
This tests that tokens are validated against current DB state.
|
||||
"""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import delete
|
||||
|
||||
from models import User
|
||||
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("deleted")
|
||||
|
||||
|
||||
# Create and login user
|
||||
async with client_factory.get_db_session() as db:
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
user = await create_user_with_roles(db, email, "password123", ["regular"])
|
||||
user_id = user.id
|
||||
|
||||
|
||||
login_response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": "password123"},
|
||||
)
|
||||
cookies = dict(login_response.cookies)
|
||||
|
||||
|
||||
# Delete the user from DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
await db.execute(delete(User).where(User.id == user_id))
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Try to use the old token
|
||||
async with client_factory.create(cookies=cookies) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -376,42 +407,41 @@ class TestSecurityBypassAttempts:
|
|||
If a user's role is changed, the change should be reflected
|
||||
in subsequent requests (no stale permission cache).
|
||||
"""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import select
|
||||
from models import User, Role
|
||||
|
||||
|
||||
from models import Role, User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("rolechange")
|
||||
|
||||
|
||||
# Create regular user
|
||||
async with client_factory.get_db_session() as db:
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
await create_user_with_roles(db, email, "password123", ["regular"])
|
||||
|
||||
|
||||
login_response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": "password123"},
|
||||
)
|
||||
cookies = dict(login_response.cookies)
|
||||
|
||||
|
||||
# Verify can access counter but not audit
|
||||
async with client_factory.create(cookies=cookies) as client:
|
||||
assert (await client.get("/api/counter")).status_code == 200
|
||||
assert (await client.get("/api/audit/counter")).status_code == 403
|
||||
|
||||
|
||||
# Change user's role from regular to admin
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one()
|
||||
|
||||
|
||||
result = await db.execute(select(Role).where(Role.name == "admin"))
|
||||
admin_role = result.scalar_one()
|
||||
|
||||
result = await db.execute(select(Role).where(Role.name == "regular"))
|
||||
regular_role = result.scalar_one()
|
||||
|
||||
user.roles = [admin_role] # Remove regular, add admin
|
||||
|
||||
user.roles = [admin_role] # Replace roles with admin only
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Now should have audit access but not counter access
|
||||
async with client_factory.create(cookies=cookies) as client:
|
||||
assert (await client.get("/api/audit/counter")).status_code == 200
|
||||
|
|
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
|
|||
# Audit Record Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAuditRecords:
|
||||
"""Test that actions are properly recorded in audit logs."""
|
||||
|
||||
|
|
@ -433,15 +464,15 @@ class TestAuditRecords:
|
|||
# Regular user increments counter
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post("/api/counter/increment")
|
||||
|
||||
|
||||
# Admin checks audit
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
# Find record for our user
|
||||
records = data["records"]
|
||||
user_records = [r for r in records if r["user_email"] == regular_user["email"]]
|
||||
|
|
@ -455,18 +486,18 @@ class TestAuditRecords:
|
|||
# Regular user uses sum
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post("/api/sum", json={"a": 10, "b": 20})
|
||||
|
||||
|
||||
# Admin checks audit
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
|
||||
# Find record with our values
|
||||
records = data["records"]
|
||||
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30]
|
||||
matching = [
|
||||
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
|
||||
]
|
||||
assert len(matching) >= 1
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Tests for user profile and contact details."""
|
||||
import pytest
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from models import User, ROLE_REGULAR
|
||||
from auth import get_password_hash
|
||||
from models import User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# Valid npub for testing (32 zero bytes encoded as bech32)
|
||||
|
|
@ -16,7 +16,7 @@ class TestUserContactFields:
|
|||
async def test_contact_fields_default_to_none(self, client_factory):
|
||||
"""New users should have all contact fields as None."""
|
||||
email = unique_email("test")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = User(
|
||||
email=email,
|
||||
|
|
@ -25,7 +25,7 @@ class TestUserContactFields:
|
|||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
assert user.contact_email is None
|
||||
assert user.telegram is None
|
||||
assert user.signal is None
|
||||
|
|
@ -34,7 +34,7 @@ class TestUserContactFields:
|
|||
async def test_contact_fields_can_be_set(self, client_factory):
|
||||
"""Contact fields can be set when creating a user."""
|
||||
email = unique_email("test")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = User(
|
||||
email=email,
|
||||
|
|
@ -47,7 +47,7 @@ class TestUserContactFields:
|
|||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
|
||||
assert user.contact_email == "contact@example.com"
|
||||
assert user.telegram == "@alice"
|
||||
assert user.signal == "alice.42"
|
||||
|
|
@ -56,7 +56,7 @@ class TestUserContactFields:
|
|||
async def test_contact_fields_persist_after_reload(self, client_factory):
|
||||
"""Contact fields should persist in the database."""
|
||||
email = unique_email("test")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = User(
|
||||
email=email,
|
||||
|
|
@ -69,12 +69,12 @@ class TestUserContactFields:
|
|||
db.add(user)
|
||||
await db.commit()
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# Reload from database in a new session
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
loaded_user = result.scalar_one()
|
||||
|
||||
|
||||
assert loaded_user.contact_email == "contact@example.com"
|
||||
assert loaded_user.telegram == "@bob"
|
||||
assert loaded_user.signal == "bob.99"
|
||||
|
|
@ -83,7 +83,7 @@ class TestUserContactFields:
|
|||
async def test_contact_fields_can_be_updated(self, client_factory):
|
||||
"""Contact fields can be updated after user creation."""
|
||||
email = unique_email("test")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = User(
|
||||
email=email,
|
||||
|
|
@ -92,21 +92,21 @@ class TestUserContactFields:
|
|||
db.add(user)
|
||||
await db.commit()
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# Update fields
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
|
||||
|
||||
user.contact_email = "new@example.com"
|
||||
user.telegram = "@updated"
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Verify update persisted
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
|
||||
|
||||
assert user.contact_email == "new@example.com"
|
||||
assert user.telegram == "@updated"
|
||||
assert user.signal is None # Still None
|
||||
|
|
@ -115,7 +115,7 @@ class TestUserContactFields:
|
|||
async def test_contact_fields_can_be_cleared(self, client_factory):
|
||||
"""Contact fields can be set back to None."""
|
||||
email = unique_email("test")
|
||||
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
user = User(
|
||||
email=email,
|
||||
|
|
@ -126,21 +126,21 @@ class TestUserContactFields:
|
|||
db.add(user)
|
||||
await db.commit()
|
||||
user_id = user.id
|
||||
|
||||
|
||||
# Clear fields
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
|
||||
|
||||
user.contact_email = None
|
||||
user.telegram = None
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Verify cleared
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
|
||||
|
||||
assert user.contact_email is None
|
||||
assert user.telegram is None
|
||||
|
||||
|
|
@ -152,7 +152,7 @@ class TestGetProfileEndpoint:
|
|||
"""Regular user can fetch their profile."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "contact_email" in data
|
||||
|
|
@ -169,7 +169,7 @@ class TestGetProfileEndpoint:
|
|||
"""Admin user gets 403 when trying to access profile."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "regular users" in response.json()["detail"].lower()
|
||||
|
||||
|
|
@ -177,7 +177,7 @@ class TestGetProfileEndpoint:
|
|||
"""Unauthenticated user gets 401."""
|
||||
async with client_factory.create() as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_profile_returns_existing_data(self, client_factory, regular_user):
|
||||
|
|
@ -191,11 +191,11 @@ class TestGetProfileEndpoint:
|
|||
user.contact_email = "contact@test.com"
|
||||
user.telegram = "@testuser"
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Fetch via API
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["contact_email"] == "contact@test.com"
|
||||
|
|
@ -219,7 +219,7 @@ class TestUpdateProfileEndpoint:
|
|||
"nostr_npub": VALID_NPUB,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["contact_email"] == "new@example.com"
|
||||
|
|
@ -234,10 +234,10 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"telegram": "@persisted"},
|
||||
)
|
||||
|
||||
|
||||
# Fetch again to verify
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["telegram"] == "@persisted"
|
||||
|
||||
|
|
@ -248,7 +248,7 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"telegram": "@admin"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_unauthenticated_user_gets_401(self, client_factory):
|
||||
|
|
@ -258,7 +258,7 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"telegram": "@test"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_can_clear_fields(self, client_factory, regular_user):
|
||||
|
|
@ -272,7 +272,7 @@ class TestUpdateProfileEndpoint:
|
|||
"telegram": "@test",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Then clear them
|
||||
response = await client.put(
|
||||
"/api/profile",
|
||||
|
|
@ -283,7 +283,7 @@ class TestUpdateProfileEndpoint:
|
|||
"nostr_npub": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["contact_email"] is None
|
||||
|
|
@ -296,7 +296,7 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"contact_email": "not-an-email"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert "field_errors" in data["detail"]
|
||||
|
|
@ -309,7 +309,7 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"telegram": "missing_at_sign"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert "field_errors" in data["detail"]
|
||||
|
|
@ -322,13 +322,15 @@ class TestUpdateProfileEndpoint:
|
|||
"/api/profile",
|
||||
json={"nostr_npub": "npub1invalid"},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert "field_errors" in data["detail"]
|
||||
assert "nostr_npub" in data["detail"]["field_errors"]
|
||||
|
||||
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user):
|
||||
async def test_multiple_invalid_fields_returns_all_errors(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Multiple invalid fields return all errors."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.put(
|
||||
|
|
@ -338,13 +340,15 @@ class TestUpdateProfileEndpoint:
|
|||
"telegram": "no-at",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 422
|
||||
data = response.json()
|
||||
assert "contact_email" in data["detail"]["field_errors"]
|
||||
assert "telegram" in data["detail"]["field_errors"]
|
||||
|
||||
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user):
|
||||
async def test_partial_update_preserves_other_fields(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Updating one field doesn't affect others (they get set to the request values)."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
# Set initial values
|
||||
|
|
@ -355,7 +359,7 @@ class TestUpdateProfileEndpoint:
|
|||
"telegram": "@initial",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Update only telegram, but note: PUT replaces all fields
|
||||
# So we need to include all fields we want to keep
|
||||
response = await client.put(
|
||||
|
|
@ -365,7 +369,7 @@ class TestUpdateProfileEndpoint:
|
|||
"telegram": "@updated",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["contact_email"] == "initial@example.com"
|
||||
|
|
@ -386,10 +390,10 @@ class TestProfilePrivacy:
|
|||
"telegram": "@secret",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Check /api/auth/me doesn't expose it
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# These fields should NOT be in the response
|
||||
|
|
@ -402,12 +406,15 @@ class TestProfilePrivacy:
|
|||
class TestProfileGodfather:
|
||||
"""Tests for godfather information in profile."""
|
||||
|
||||
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user):
|
||||
async def test_profile_shows_godfather_email(
|
||||
self, client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Profile shows godfather email for users who signed up with invite."""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import select
|
||||
|
||||
from models import User
|
||||
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
async with client_factory.get_db_session() as db:
|
||||
|
|
@ -415,13 +422,13 @@ class TestProfileGodfather:
|
|||
select(User).where(User.email == regular_user["email"])
|
||||
)
|
||||
godfather = result.scalar_one()
|
||||
|
||||
|
||||
create_resp = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
)
|
||||
identifier = create_resp.json()["identifier"]
|
||||
|
||||
|
||||
# Register new user with invite
|
||||
new_email = unique_email("godchild")
|
||||
async with client_factory.create() as client:
|
||||
|
|
@ -434,20 +441,22 @@ class TestProfileGodfather:
|
|||
},
|
||||
)
|
||||
new_user_cookies = dict(reg_resp.cookies)
|
||||
|
||||
|
||||
# Check profile shows godfather
|
||||
async with client_factory.create(cookies=new_user_cookies) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["godfather_email"] == regular_user["email"]
|
||||
|
||||
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user):
|
||||
async def test_profile_godfather_null_for_seeded_users(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Profile shows null godfather for users without one (e.g., seeded users)."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["godfather_email"] is None
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Tests for profile field validation."""
|
||||
import pytest
|
||||
|
||||
from validation import (
|
||||
validate_contact_email,
|
||||
validate_telegram,
|
||||
validate_signal,
|
||||
validate_nostr_npub,
|
||||
validate_profile_fields,
|
||||
validate_signal,
|
||||
validate_telegram,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
|
|||
assert validate_nostr_npub(self.VALID_NPUB) is None
|
||||
|
||||
def test_wrong_prefix(self):
|
||||
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz")
|
||||
result = validate_nostr_npub(
|
||||
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
|
||||
)
|
||||
assert result is not None
|
||||
assert "npub" in result.lower()
|
||||
|
||||
def test_invalid_checksum(self):
|
||||
# Change last character to break checksum
|
||||
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd")
|
||||
result = validate_nostr_npub(
|
||||
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
|
||||
)
|
||||
assert result is not None
|
||||
assert "checksum" in result.lower()
|
||||
|
||||
|
|
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
|
|||
assert result is not None
|
||||
|
||||
def test_not_starting_with_npub1(self):
|
||||
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc")
|
||||
result = validate_nostr_npub(
|
||||
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
|
||||
)
|
||||
assert result is not None
|
||||
assert "npub1" in result
|
||||
|
||||
|
|
@ -206,4 +211,3 @@ class TestValidateProfileFields:
|
|||
nostr_npub="",
|
||||
)
|
||||
assert errors == {}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Validate shared constants match backend definitions."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
|
||||
|
||||
|
||||
def validate_shared_constants() -> None:
|
||||
|
|
@ -11,13 +12,13 @@ def validate_shared_constants() -> None:
|
|||
Raises ValueError if there's a mismatch.
|
||||
"""
|
||||
constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
|
||||
|
||||
|
||||
if not constants_path.exists():
|
||||
raise ValueError(f"Shared constants file not found: {constants_path}")
|
||||
|
||||
|
||||
with open(constants_path) as f:
|
||||
constants = json.load(f)
|
||||
|
||||
|
||||
# Validate roles
|
||||
expected_roles = {"ADMIN": ROLE_ADMIN, "REGULAR": ROLE_REGULAR}
|
||||
if constants.get("roles") != expected_roles:
|
||||
|
|
@ -25,39 +26,46 @@ def validate_shared_constants() -> None:
|
|||
f"Role mismatch in shared/constants.json. "
|
||||
f"Expected: {expected_roles}, Got: {constants.get('roles')}"
|
||||
)
|
||||
|
||||
|
||||
# Validate invite statuses
|
||||
expected_invite_statuses = {s.name: s.value for s in InviteStatus}
|
||||
if constants.get("inviteStatuses") != expected_invite_statuses:
|
||||
got = constants.get("inviteStatuses")
|
||||
raise ValueError(
|
||||
f"Invite status mismatch in shared/constants.json. "
|
||||
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
|
||||
f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
|
||||
)
|
||||
|
||||
|
||||
# Validate appointment statuses
|
||||
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
|
||||
if constants.get("appointmentStatuses") != expected_appointment_statuses:
|
||||
got = constants.get("appointmentStatuses")
|
||||
raise ValueError(
|
||||
f"Appointment status mismatch in shared/constants.json. "
|
||||
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}"
|
||||
f"Appointment status mismatch. "
|
||||
f"Expected: {expected_appointment_statuses}, Got: {got}"
|
||||
)
|
||||
|
||||
|
||||
# Validate booking constants exist with required fields
|
||||
booking = constants.get("booking", {})
|
||||
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"]
|
||||
required_booking_fields = [
|
||||
"slotDurationMinutes",
|
||||
"maxAdvanceDays",
|
||||
"minAdvanceDays",
|
||||
"noteMaxLength",
|
||||
]
|
||||
for field in required_booking_fields:
|
||||
if field not in booking:
|
||||
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json")
|
||||
|
||||
raise ValueError(f"Missing booking constant '{field}' in constants.json")
|
||||
|
||||
# Validate validation rules exist (structure check only)
|
||||
validation = constants.get("validation", {})
|
||||
required_fields = ["telegram", "signal", "nostrNpub"]
|
||||
for field in required_fields:
|
||||
if field not in validation:
|
||||
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json")
|
||||
raise ValueError(
|
||||
f"Missing validation rules for '{field}' in constants.json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
validate_shared_constants()
|
||||
print("✓ Shared constants are valid")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
"""Validation utilities for user profile fields."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from email_validator import validate_email, EmailNotValidError
|
||||
from bech32 import bech32_decode
|
||||
from email_validator import EmailNotValidError, validate_email
|
||||
|
||||
# Load validation rules from shared constants
|
||||
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
|
||||
|
|
@ -18,13 +19,13 @@ NPUB_RULES = _constants["validation"]["nostrNpub"]
|
|||
def validate_contact_email(value: str | None) -> str | None:
|
||||
"""
|
||||
Validate contact email format.
|
||||
|
||||
|
||||
Returns None if valid, error message if invalid.
|
||||
Empty/None values are valid (field is optional).
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
validate_email(value, check_deliverability=False)
|
||||
return None
|
||||
|
|
@ -35,84 +36,84 @@ def validate_contact_email(value: str | None) -> str | None:
|
|||
def validate_telegram(value: str | None) -> str | None:
|
||||
"""
|
||||
Validate Telegram handle.
|
||||
|
||||
|
||||
Must start with @ if provided, with characters after @ within max length.
|
||||
Returns None if valid, error message if invalid.
|
||||
Empty/None values are valid (field is optional).
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
|
||||
prefix = TELEGRAM_RULES["mustStartWith"]
|
||||
max_len = TELEGRAM_RULES["maxLengthAfterAt"]
|
||||
|
||||
|
||||
if not value.startswith(prefix):
|
||||
return f"Telegram handle must start with {prefix}"
|
||||
|
||||
|
||||
handle = value[1:]
|
||||
if not handle:
|
||||
return f"Telegram handle must have at least one character after {prefix}"
|
||||
|
||||
|
||||
if len(handle) > max_len:
|
||||
return f"Telegram handle must be at most {max_len} characters (after {prefix})"
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_signal(value: str | None) -> str | None:
|
||||
"""
|
||||
Validate Signal username.
|
||||
|
||||
|
||||
Any non-empty string within max length is valid.
|
||||
Returns None if valid, error message if invalid.
|
||||
Empty/None values are valid (field is optional).
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
|
||||
max_len = SIGNAL_RULES["maxLength"]
|
||||
|
||||
|
||||
# Signal usernames are fairly permissive, just check it's not empty
|
||||
if len(value.strip()) == 0:
|
||||
return "Signal username cannot be empty"
|
||||
|
||||
|
||||
if len(value) > max_len:
|
||||
return f"Signal username must be at most {max_len} characters"
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_nostr_npub(value: str | None) -> str | None:
|
||||
"""
|
||||
Validate Nostr npub (public key in bech32 format).
|
||||
|
||||
|
||||
Must be valid bech32 with 'npub' prefix.
|
||||
Returns None if valid, error message if invalid.
|
||||
Empty/None values are valid (field is optional).
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
|
||||
prefix = NPUB_RULES["prefix"]
|
||||
expected_words = NPUB_RULES["bech32Words"]
|
||||
|
||||
|
||||
if not value.startswith(prefix):
|
||||
return f"Nostr npub must start with '{prefix}'"
|
||||
|
||||
|
||||
# Decode bech32 to validate checksum
|
||||
hrp, data = bech32_decode(value)
|
||||
|
||||
|
||||
if hrp is None or data is None:
|
||||
return "Invalid Nostr npub: bech32 checksum failed"
|
||||
|
||||
|
||||
if hrp != "npub":
|
||||
return "Nostr npub must have 'npub' prefix"
|
||||
|
||||
|
||||
# npub should decode to 32 bytes (256 bits) for a public key
|
||||
# In bech32, each character encodes 5 bits, so 32 bytes = 52 characters of data
|
||||
if len(data) != expected_words:
|
||||
return "Invalid Nostr npub: incorrect length"
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -124,23 +125,22 @@ def validate_profile_fields(
|
|||
) -> dict[str, str]:
|
||||
"""
|
||||
Validate all profile fields at once.
|
||||
|
||||
|
||||
Returns a dict of field_name -> error_message for any invalid fields.
|
||||
Empty dict means all fields are valid.
|
||||
"""
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
|
||||
if err := validate_contact_email(contact_email):
|
||||
errors["contact_email"] = err
|
||||
|
||||
|
||||
if err := validate_telegram(telegram):
|
||||
errors["telegram"] = err
|
||||
|
||||
|
||||
if err := validate_signal(signal):
|
||||
errors["signal"] = err
|
||||
|
||||
|
||||
if err := validate_nostr_npub(nostr_npub):
|
||||
errors["nostr_npub"] = err
|
||||
|
||||
return errors
|
||||
|
||||
return errors
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue