Add ruff linter/formatter for Python
- Add ruff as dev dependency - Configure ruff in pyproject.toml with strict 88-char line limit - Ignore B008 (FastAPI Depends pattern is standard) - Allow longer lines in tests for readability - Fix all lint issues in source files - Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
parent
69bc8413e0
commit
6c218130e9
31 changed files with 1234 additions and 876 deletions
11
Makefile
11
Makefile
|
|
@ -1,4 +1,4 @@
|
|||
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants
|
||||
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend
|
||||
|
||||
-include .env
|
||||
export
|
||||
|
|
@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone
|
|||
|
||||
check-constants:
|
||||
@cd backend && uv run python validate_constants.py
|
||||
|
||||
lint-backend:
|
||||
cd backend && uv run ruff check .
|
||||
|
||||
format-backend:
|
||||
cd backend && uv run ruff format .
|
||||
|
||||
fix-backend:
|
||||
cd backend && uv run ruff check --fix . && uv run ruff format .
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
|
@ -8,7 +8,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_db
|
||||
from models import User, Permission
|
||||
from models import Permission, User
|
||||
from schemas import UserResponse
|
||||
|
||||
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
||||
|
|
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
|
|||
).decode("utf-8")
|
||||
|
||||
|
||||
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
|
||||
def create_access_token(
|
||||
data: dict[str, str],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
to_encode: dict[str, str | datetime] = dict(data)
|
||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now(UTC) + delta
|
||||
to_encode["exp"] = expire
|
||||
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded
|
||||
|
|
@ -72,7 +76,7 @@ async def get_current_user(
|
|||
raise credentials_exception
|
||||
user_id = int(user_id_str)
|
||||
except (JWTError, ValueError):
|
||||
raise credentials_exception
|
||||
raise credentials_exception from None
|
||||
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
|
|
@ -83,13 +87,16 @@ async def get_current_user(
|
|||
|
||||
def require_permission(*required_permissions: Permission):
|
||||
"""
|
||||
Dependency factory that checks if user has ALL of the required permissions.
|
||||
Dependency factory that checks if user has ALL required permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/counter")
|
||||
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
|
||||
async def get_counter(
|
||||
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
|
||||
):
|
||||
...
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
|
|
@ -99,11 +106,13 @@ def require_permission(*required_permissions: Permission):
|
|||
|
||||
missing = [p for p in required_permissions if p not in user_permissions]
|
||||
if missing:
|
||||
missing_str = ", ".join(p.value for p in missing)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
|
||||
detail=f"Missing required permissions: {missing_str}",
|
||||
)
|
||||
return user
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import os
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
|
||||
)
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
|
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
|
|||
async def get_db():
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Utilities for invite code generation and validation."""
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -53,8 +54,4 @@ def is_valid_identifier_format(identifier: str) -> bool:
|
|||
return False
|
||||
|
||||
# Check number is two digits
|
||||
if len(number) != 2 or not number.isdigit():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return len(number) == 2 and number.isdigit()
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
"""FastAPI application entry point."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from database import engine, Base
|
||||
from routes import sum as sum_routes
|
||||
from routes import counter as counter_routes
|
||||
from database import Base, engine
|
||||
from routes import audit as audit_routes
|
||||
from routes import profile as profile_routes
|
||||
from routes import invites as invites_routes
|
||||
from routes import auth as auth_routes
|
||||
from routes import meta as meta_routes
|
||||
from routes import availability as availability_routes
|
||||
from routes import booking as booking_routes
|
||||
from routes import counter as counter_routes
|
||||
from routes import invites as invites_routes
|
||||
from routes import meta as meta_routes
|
||||
from routes import profile as profile_routes
|
||||
from routes import sum as sum_routes
|
||||
from validate_constants import validate_shared_constants
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,24 @@
|
|||
from datetime import datetime, date, time, timezone
|
||||
from datetime import UTC, date, datetime, time
|
||||
from enum import Enum as PyEnum
|
||||
from typing import TypedDict
|
||||
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Time,
|
||||
UniqueConstraint,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
|
|
@ -14,6 +29,7 @@ class RoleConfig(TypedDict):
|
|||
|
||||
class Permission(str, PyEnum):
|
||||
"""All available permissions in the system."""
|
||||
|
||||
# Counter permissions
|
||||
VIEW_COUNTER = "view_counter"
|
||||
INCREMENT_COUNTER = "increment_counter"
|
||||
|
|
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
|
|||
|
||||
class InviteStatus(str, PyEnum):
|
||||
"""Status of an invite."""
|
||||
|
||||
READY = "ready"
|
||||
SPENT = "spent"
|
||||
REVOKED = "revoked"
|
||||
|
|
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
|
|||
|
||||
class AppointmentStatus(str, PyEnum):
|
||||
"""Status of an appointment."""
|
||||
|
||||
BOOKED = "booked"
|
||||
CANCELLED_BY_USER = "cancelled_by_user"
|
||||
CANCELLED_BY_ADMIN = "cancelled_by_admin"
|
||||
|
|
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
|
|||
# Role definitions with their permissions
|
||||
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
||||
ROLE_ADMIN: {
|
||||
"description": "Administrator with audit, invite, and appointment management access",
|
||||
"description": "Administrator with audit/invite/appointment access",
|
||||
"permissions": [
|
||||
Permission.VIEW_AUDIT,
|
||||
Permission.MANAGE_INVITES,
|
||||
|
|
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
|||
role_permissions = Table(
|
||||
"role_permissions",
|
||||
Base.metadata,
|
||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column(
|
||||
"role_id",
|
||||
Integer,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column("permission", Enum(Permission), primary_key=True),
|
||||
)
|
||||
|
||||
|
|
@ -97,8 +120,18 @@ role_permissions = Table(
|
|||
user_roles = Table(
|
||||
"user_roles",
|
||||
Base.metadata,
|
||||
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
||||
Column(
|
||||
"user_id",
|
||||
Integer,
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
Column(
|
||||
"role_id",
|
||||
Integer,
|
||||
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -118,23 +151,34 @@ class Role(Base):
|
|||
|
||||
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
||||
"""Get all permissions for this role."""
|
||||
result = await db.execute(
|
||||
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
|
||||
query = select(role_permissions.c.permission).where(
|
||||
role_permissions.c.role_id == self.id
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return {row[0] for row in result.fetchall()}
|
||||
|
||||
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
|
||||
async def set_permissions(
|
||||
self, db: AsyncSession, permissions: list[Permission]
|
||||
) -> None:
|
||||
"""Set all permissions for this role (replaces existing)."""
|
||||
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
|
||||
delete_query = role_permissions.delete().where(
|
||||
role_permissions.c.role_id == self.id
|
||||
)
|
||||
await db.execute(delete_query)
|
||||
for perm in permissions:
|
||||
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm))
|
||||
insert_query = role_permissions.insert().values(
|
||||
role_id=self.id, permission=perm
|
||||
)
|
||||
await db.execute(insert_query)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
email: Mapped[str] = mapped_column(
|
||||
String(255), unique=True, nullable=False, index=True
|
||||
)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
|
||||
# Contact details (all optional)
|
||||
|
|
@ -192,12 +236,14 @@ class SumRecord(Base):
|
|||
__tablename__ = "sum_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
a: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
b: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
result: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -205,11 +251,13 @@ class CounterRecord(Base):
|
|||
__tablename__ = "counter_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -217,7 +265,9 @@ class Invite(Base):
|
|||
__tablename__ = "invites"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
identifier: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False, index=True
|
||||
)
|
||||
status: Mapped[InviteStatus] = mapped_column(
|
||||
Enum(InviteStatus), nullable=False, default=InviteStatus.READY
|
||||
)
|
||||
|
|
@ -244,14 +294,19 @@ class Invite(Base):
|
|||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
spent_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
class Availability(Base):
|
||||
"""Admin availability slots for booking."""
|
||||
|
||||
__tablename__ = "availability"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
|
||||
|
|
@ -262,34 +317,37 @@ class Availability(Base):
|
|||
start_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||
end_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc)
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class Appointment(Base):
|
||||
"""User appointment bookings."""
|
||||
|
||||
__tablename__ = "appointments"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
|
||||
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
||||
slot_start: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True
|
||||
)
|
||||
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
note: Mapped[str | None] = mapped_column(String(144), nullable=True)
|
||||
status: Mapped[AppointmentStatus] = mapped_column(
|
||||
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
cancelled_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ dev = [
|
|||
"httpx>=0.28.1",
|
||||
"aiosqlite>=0.20.0",
|
||||
"mypy>=1.13.0",
|
||||
"ruff>=0.14.10",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
|
@ -30,3 +31,27 @@ check_untyped_defs = true
|
|||
ignore_missing_imports = true
|
||||
exclude = ["tests/"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"UP", # pyupgrade
|
||||
"SIM", # flake8-simplify
|
||||
"RUF", # ruff-specific rules
|
||||
]
|
||||
ignore = [
|
||||
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["E501"] # Allow longer lines in tests for readability
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
"""Audit routes for viewing action records."""
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, func, desc
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, SumRecord, CounterRecord, Permission
|
||||
from models import CounterRecord, Permission, SumRecord, User
|
||||
from schemas import (
|
||||
CounterRecordResponse,
|
||||
SumRecordResponse,
|
||||
PaginatedCounterRecords,
|
||||
PaginatedSumRecords,
|
||||
SumRecordResponse,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
||||
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Authentication routes for register, login, logout, and current user."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from sqlalchemy import select
|
||||
|
|
@ -9,18 +10,17 @@ from auth import (
|
|||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
COOKIE_NAME,
|
||||
COOKIE_SECURE,
|
||||
get_password_hash,
|
||||
get_user_by_email,
|
||||
authenticate_user,
|
||||
build_user_response,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
build_user_response,
|
||||
get_password_hash,
|
||||
get_user_by_email,
|
||||
)
|
||||
from database import get_db
|
||||
from invite_utils import normalize_identifier
|
||||
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus
|
||||
from schemas import UserLogin, UserResponse, RegisterWithInvite
|
||||
|
||||
from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
|
||||
from schemas import RegisterWithInvite, UserLogin, UserResponse
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
|
@ -52,9 +52,8 @@ async def register(
|
|||
"""Register a new user using an invite code."""
|
||||
# Validate invite
|
||||
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.identifier == normalized_identifier)
|
||||
)
|
||||
query = select(Invite).where(Invite.identifier == normalized_identifier)
|
||||
result = await db.execute(query)
|
||||
invite = result.scalar_one_or_none()
|
||||
|
||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||
|
|
@ -90,7 +89,7 @@ async def register(
|
|||
# Mark invite as spent
|
||||
invite.status = InviteStatus.SPENT
|
||||
invite.used_by_id = user.id
|
||||
invite.spent_at = datetime.now(timezone.utc)
|
||||
invite.spent_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
|
|
|||
|
|
@ -1,28 +1,28 @@
|
|||
"""Availability routes for admin to manage booking availability."""
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, delete, and_
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, Availability, Permission
|
||||
from models import Availability, Permission, User
|
||||
from schemas import (
|
||||
TimeSlot,
|
||||
AvailabilityDay,
|
||||
AvailabilityResponse,
|
||||
SetAvailabilityRequest,
|
||||
CopyAvailabilityRequest,
|
||||
SetAvailabilityRequest,
|
||||
TimeSlot,
|
||||
)
|
||||
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
||||
|
||||
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
|
||||
|
||||
router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
|
||||
|
||||
|
||||
def _get_date_range_bounds() -> tuple[date, date]:
|
||||
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
||||
"""Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
||||
today = date.today()
|
||||
min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
|
||||
max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
|
||||
|
|
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
|
|||
if d < min_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}",
|
||||
detail=f"Cannot set availability for past dates. "
|
||||
f"Earliest allowed: {min_date}",
|
||||
)
|
||||
if d > max_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}",
|
||||
detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||
f"Latest allowed: {max_date}",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -70,15 +72,16 @@ async def get_availability(
|
|||
for slot in slots:
|
||||
if slot.date not in days_dict:
|
||||
days_dict[slot.date] = []
|
||||
days_dict[slot.date].append(TimeSlot(
|
||||
start_time=slot.start_time,
|
||||
end_time=slot.end_time,
|
||||
))
|
||||
days_dict[slot.date].append(
|
||||
TimeSlot(
|
||||
start_time=slot.start_time,
|
||||
end_time=slot.end_time,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to response format
|
||||
days = [
|
||||
AvailabilityDay(date=d, slots=days_dict[d])
|
||||
for d in sorted(days_dict.keys())
|
||||
AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
|
||||
]
|
||||
|
||||
return AvailabilityResponse(days=days)
|
||||
|
|
@ -98,9 +101,12 @@ async def set_availability(
|
|||
sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
|
||||
for i in range(len(sorted_slots) - 1):
|
||||
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
|
||||
end = sorted_slots[i].end_time
|
||||
start = sorted_slots[i + 1].start_time
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.",
|
||||
detail=f"Time slots overlap: slot ending at {end} "
|
||||
f"overlaps with slot starting at {start}",
|
||||
)
|
||||
|
||||
# Validate each slot's end_time > start_time
|
||||
|
|
@ -108,13 +114,12 @@ async def set_availability(
|
|||
if slot.end_time <= slot.start_time:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.",
|
||||
detail=f"Invalid time slot: end time {slot.end_time} "
|
||||
f"must be after start time {slot.start_time}",
|
||||
)
|
||||
|
||||
# Delete existing availability for this date
|
||||
await db.execute(
|
||||
delete(Availability).where(Availability.date == request.date)
|
||||
)
|
||||
await db.execute(delete(Availability).where(Availability.date == request.date))
|
||||
|
||||
# Create new availability slots
|
||||
for slot in request.slots:
|
||||
|
|
@ -139,7 +144,7 @@ async def copy_availability(
|
|||
"""Copy availability from one day to multiple target days."""
|
||||
min_date, max_date = _get_date_range_bounds()
|
||||
|
||||
# Validate source date is in range (for consistency, though DB query would fail anyway)
|
||||
# Validate source date is in range
|
||||
_validate_date_in_range(request.source_date, min_date, max_date)
|
||||
|
||||
# Validate target dates
|
||||
|
|
@ -169,9 +174,8 @@ async def copy_availability(
|
|||
continue # Skip copying to self
|
||||
|
||||
# Delete existing availability for target date
|
||||
await db.execute(
|
||||
delete(Availability).where(Availability.date == target_date)
|
||||
)
|
||||
del_query = delete(Availability).where(Availability.date == target_date)
|
||||
await db.execute(del_query)
|
||||
|
||||
# Copy slots
|
||||
target_slots: list[TimeSlot] = []
|
||||
|
|
@ -182,10 +186,12 @@ async def copy_availability(
|
|||
end_time=source_slot.end_time,
|
||||
)
|
||||
db.add(new_availability)
|
||||
target_slots.append(TimeSlot(
|
||||
start_time=source_slot.start_time,
|
||||
end_time=source_slot.end_time,
|
||||
))
|
||||
target_slots.append(
|
||||
TimeSlot(
|
||||
start_time=source_slot.start_time,
|
||||
end_time=source_slot.end_time,
|
||||
)
|
||||
)
|
||||
|
||||
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
|
||||
|
||||
|
|
@ -197,4 +203,3 @@ async def copy_availability(
|
|||
raise
|
||||
|
||||
return AvailabilityResponse(days=copied_days)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,24 @@
|
|||
"""Booking routes for users to book appointments."""
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select, and_, func
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, Availability, Appointment, AppointmentStatus, Permission
|
||||
from models import Appointment, AppointmentStatus, Availability, Permission, User
|
||||
from schemas import (
|
||||
BookableSlot,
|
||||
AvailableSlotsResponse,
|
||||
BookingRequest,
|
||||
AppointmentResponse,
|
||||
AvailableSlotsResponse,
|
||||
BookableSlot,
|
||||
BookingRequest,
|
||||
PaginatedAppointments,
|
||||
)
|
||||
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
||||
|
||||
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
|
||||
|
||||
router = APIRouter(prefix="/api/booking", tags=["booking"])
|
||||
|
||||
|
|
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
|
|||
if d < min_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}",
|
||||
detail=f"Cannot book for today or past dates. "
|
||||
f"Earliest bookable: {min_date}",
|
||||
)
|
||||
if d > max_date:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}",
|
||||
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||
f"Latest bookable: {max_date}",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -92,8 +94,8 @@ def _expand_availability_to_slots(
|
|||
|
||||
for avail in availability_slots:
|
||||
# Create datetime objects for start and end
|
||||
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc)
|
||||
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc)
|
||||
current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
|
||||
end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
|
||||
|
||||
# Generate 15-minute slots
|
||||
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
|
||||
|
|
@ -128,12 +130,11 @@ async def get_available_slots(
|
|||
all_slots = _expand_availability_to_slots(availability_slots, target_date)
|
||||
|
||||
# Get existing booked appointments for this date
|
||||
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc)
|
||||
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc)
|
||||
day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
|
||||
day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
|
||||
|
||||
result = await db.execute(
|
||||
select(Appointment.slot_start)
|
||||
.where(
|
||||
select(Appointment.slot_start).where(
|
||||
and_(
|
||||
Appointment.slot_start >= day_start,
|
||||
Appointment.slot_start <= day_end,
|
||||
|
|
@ -145,8 +146,7 @@ async def get_available_slots(
|
|||
|
||||
# Filter out already booked slots
|
||||
available_slots = [
|
||||
slot for slot in all_slots
|
||||
if slot.start_time not in booked_starts
|
||||
slot for slot in all_slots if slot.start_time not in booked_starts
|
||||
]
|
||||
|
||||
return AvailableSlotsResponse(date=target_date, slots=available_slots)
|
||||
|
|
@ -162,12 +162,13 @@ async def create_booking(
|
|||
slot_date = request.slot_start.date()
|
||||
_validate_booking_date(slot_date)
|
||||
|
||||
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES)
|
||||
# Validate slot is on the correct minute boundary
|
||||
valid_minutes = _get_valid_minute_boundaries()
|
||||
if request.slot_start.minute not in valid_minutes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})",
|
||||
detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
|
||||
f"(valid minutes: {valid_minutes})",
|
||||
)
|
||||
if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
|
||||
raise HTTPException(
|
||||
|
|
@ -177,11 +178,11 @@ async def create_booking(
|
|||
|
||||
# Verify slot falls within availability
|
||||
slot_start_time = request.slot_start.time()
|
||||
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time()
|
||||
slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||
slot_end_time = slot_end_dt.time()
|
||||
|
||||
result = await db.execute(
|
||||
select(Availability)
|
||||
.where(
|
||||
select(Availability).where(
|
||||
and_(
|
||||
Availability.date == slot_date,
|
||||
Availability.start_time <= slot_start_time,
|
||||
|
|
@ -192,9 +193,11 @@ async def create_booking(
|
|||
matching_availability = result.scalar_one_or_none()
|
||||
|
||||
if not matching_availability:
|
||||
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.",
|
||||
detail=f"Selected slot at {slot_str} UTC is not within "
|
||||
f"any available time ranges for {slot_date}",
|
||||
)
|
||||
|
||||
# Create the appointment
|
||||
|
|
@ -216,8 +219,8 @@ async def create_booking(
|
|||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="This slot has already been booked. Please select another slot.",
|
||||
)
|
||||
detail="This slot has already been booked. Select another slot.",
|
||||
) from None
|
||||
|
||||
return _to_appointment_response(appointment, current_user.email)
|
||||
|
||||
|
|
@ -242,20 +245,19 @@ async def get_my_appointments(
|
|||
)
|
||||
appointments = result.scalars().all()
|
||||
|
||||
return [
|
||||
_to_appointment_response(apt, current_user.email)
|
||||
for apt in appointments
|
||||
]
|
||||
return [_to_appointment_response(apt, current_user.email) for apt in appointments]
|
||||
|
||||
|
||||
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
||||
@appointments_router.post(
|
||||
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||
)
|
||||
async def cancel_my_appointment(
|
||||
appointment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
|
||||
) -> AppointmentResponse:
|
||||
"""Cancel one of the current user's appointments."""
|
||||
# Get the appointment with explicit eager loading of user relationship
|
||||
# Get the appointment with eager loading of user relationship
|
||||
result = await db.execute(
|
||||
select(Appointment)
|
||||
.options(joinedload(Appointment.user))
|
||||
|
|
@ -266,31 +268,35 @@ async def cancel_my_appointment(
|
|||
if not appointment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
||||
detail=f"Appointment {appointment_id} not found",
|
||||
)
|
||||
|
||||
# Verify ownership
|
||||
if appointment.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment")
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot cancel another user's appointment",
|
||||
)
|
||||
|
||||
# Check if already cancelled
|
||||
if appointment.status != AppointmentStatus.BOOKED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
||||
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||
)
|
||||
|
||||
# Check if appointment is in the past
|
||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
||||
if appointment.slot_start <= datetime.now(UTC):
|
||||
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
||||
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||
"already started or in the past",
|
||||
)
|
||||
|
||||
# Cancel the appointment
|
||||
appointment.status = AppointmentStatus.CANCELLED_BY_USER
|
||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
||||
appointment.cancelled_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
|
|
@ -302,7 +308,9 @@ async def cancel_my_appointment(
|
|||
# Admin Appointments Endpoints
|
||||
# =============================================================================
|
||||
|
||||
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"])
|
||||
admin_appointments_router = APIRouter(
|
||||
prefix="/api/admin/appointments", tags=["admin-appointments"]
|
||||
)
|
||||
|
||||
|
||||
@admin_appointments_router.get("", response_model=PaginatedAppointments)
|
||||
|
|
@ -344,14 +352,18 @@ async def get_all_appointments(
|
|||
)
|
||||
|
||||
|
||||
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
||||
@admin_appointments_router.post(
|
||||
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||
)
|
||||
async def admin_cancel_appointment(
|
||||
appointment_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)),
|
||||
_current_user: User = Depends(
|
||||
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
|
||||
),
|
||||
) -> AppointmentResponse:
|
||||
"""Cancel any appointment (admin only)."""
|
||||
# Get the appointment with explicit eager loading of user relationship
|
||||
# Get the appointment with eager loading of user relationship
|
||||
result = await db.execute(
|
||||
select(Appointment)
|
||||
.options(joinedload(Appointment.user))
|
||||
|
|
@ -362,27 +374,28 @@ async def admin_cancel_appointment(
|
|||
if not appointment:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
||||
detail=f"Appointment {appointment_id} not found",
|
||||
)
|
||||
|
||||
# Check if already cancelled
|
||||
if appointment.status != AppointmentStatus.BOOKED:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
||||
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||
)
|
||||
|
||||
# Check if appointment is in the past
|
||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
||||
if appointment.slot_start <= datetime.now(UTC):
|
||||
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
||||
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||
"already started or in the past",
|
||||
)
|
||||
|
||||
# Cancel the appointment
|
||||
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
|
||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
||||
appointment.cancelled_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(appointment)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
"""Counter routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import Counter, User, CounterRecord, Permission
|
||||
|
||||
from models import Counter, CounterRecord, Permission, User
|
||||
|
||||
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,25 +1,29 @@
|
|||
"""Invite routes for public check, user invites, and admin management."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy import select, func, desc
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
|
||||
from models import User, Invite, InviteStatus, Permission
|
||||
from invite_utils import (
|
||||
generate_invite_identifier,
|
||||
is_valid_identifier_format,
|
||||
normalize_identifier,
|
||||
)
|
||||
from models import Invite, InviteStatus, Permission, User
|
||||
from schemas import (
|
||||
AdminUserResponse,
|
||||
InviteCheckResponse,
|
||||
InviteCreate,
|
||||
InviteResponse,
|
||||
UserInviteResponse,
|
||||
PaginatedInviteRecords,
|
||||
AdminUserResponse,
|
||||
UserInviteResponse,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/invites", tags=["invites"])
|
||||
admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
|
@ -54,9 +58,7 @@ async def check_invite(
|
|||
if not is_valid_identifier_format(normalized):
|
||||
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
||||
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.identifier == normalized)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.identifier == normalized))
|
||||
invite = result.scalar_one_or_none()
|
||||
|
||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||
|
|
@ -112,9 +114,7 @@ async def create_invite(
|
|||
) -> InviteResponse:
|
||||
"""Create a new invite for a specified godfather user."""
|
||||
# Validate godfather exists
|
||||
result = await db.execute(
|
||||
select(User.id).where(User.id == data.godfather_id)
|
||||
)
|
||||
result = await db.execute(select(User.id).where(User.id == data.godfather_id))
|
||||
godfather_id = result.scalar_one_or_none()
|
||||
if not godfather_id:
|
||||
raise HTTPException(
|
||||
|
|
@ -141,8 +141,8 @@ async def create_invite(
|
|||
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to generate unique invite code. Please try again.",
|
||||
)
|
||||
detail="Failed to generate unique invite code. Try again.",
|
||||
) from None
|
||||
|
||||
if invite is None:
|
||||
raise HTTPException(
|
||||
|
|
@ -156,7 +156,9 @@ async def create_invite(
|
|||
async def list_all_invites(
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
|
||||
status_filter: str | None = Query(
|
||||
None, alias="status", description="Filter by status: ready, spent, revoked"
|
||||
),
|
||||
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||
|
|
@ -175,8 +177,9 @@ async def list_all_invites(
|
|||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
|
||||
)
|
||||
detail=f"Invalid status: {status_filter}. "
|
||||
"Must be ready, spent, or revoked",
|
||||
) from None
|
||||
|
||||
if godfather_id:
|
||||
query = query.where(Invite.godfather_id == godfather_id)
|
||||
|
|
@ -224,11 +227,12 @@ async def revoke_invite(
|
|||
if invite.status != InviteStatus.READY:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
|
||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. "
|
||||
"Only READY invites can be revoked.",
|
||||
)
|
||||
|
||||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(timezone.utc)
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(invite)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Meta endpoints for shared constants."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
|
||||
from schemas import ConstantsResponse
|
||||
|
||||
router = APIRouter(prefix="/api/meta", tags=["meta"])
|
||||
|
|
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
|
|||
roles=[ROLE_ADMIN, ROLE_REGULAR],
|
||||
invite_statuses=[s.value for s in InviteStatus],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,15 @@
|
|||
"""Profile routes for user contact details."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import get_current_user
|
||||
from database import get_db
|
||||
from models import User, ROLE_REGULAR
|
||||
from models import ROLE_REGULAR, User
|
||||
from schemas import ProfileResponse, ProfileUpdate
|
||||
from validation import validate_profile_fields
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/profile", tags=["profile"])
|
||||
|
||||
|
||||
|
|
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
|
|||
"""Get the email of a godfather user by ID."""
|
||||
if not godfather_id:
|
||||
return None
|
||||
result = await db.execute(
|
||||
select(User.email).where(User.id == godfather_id)
|
||||
)
|
||||
result = await db.execute(select(User.email).where(User.id == godfather_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Sum calculation routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import User, SumRecord, Permission
|
||||
from models import Permission, SumRecord, User
|
||||
from schemas import SumRequest, SumResponse
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Pydantic schemas for API request/response models."""
|
||||
from datetime import datetime, date, time
|
||||
|
||||
from datetime import date, datetime, time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
|
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
|
|||
|
||||
class UserCredentials(BaseModel):
|
||||
"""Base model for user email/password."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
|
@ -19,6 +21,7 @@ UserLogin = UserCredentials
|
|||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for authenticated user info."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
roles: list[str]
|
||||
|
|
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
|
|||
|
||||
class RegisterWithInvite(BaseModel):
|
||||
"""Request model for registration with invite."""
|
||||
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite_identifier: str
|
||||
|
|
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
|
|||
|
||||
class SumRequest(BaseModel):
|
||||
"""Request model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
|
||||
|
||||
class SumResponse(BaseModel):
|
||||
"""Response model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
result: float
|
||||
|
|
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
|
|||
|
||||
class CounterRecordResponse(BaseModel):
|
||||
"""Response model for a counter audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
value_before: int
|
||||
|
|
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
|
|||
|
||||
class SumRecordResponse(BaseModel):
|
||||
"""Response model for a sum audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
a: float
|
||||
|
|
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
|
|||
|
||||
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
||||
"""Generic paginated response wrapper."""
|
||||
|
||||
records: list[RecordT]
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
|||
|
||||
class ProfileResponse(BaseModel):
|
||||
"""Response model for profile data."""
|
||||
|
||||
contact_email: str | None
|
||||
telegram: str | None
|
||||
signal: str | None
|
||||
|
|
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
|
|||
|
||||
class ProfileUpdate(BaseModel):
|
||||
"""Request model for updating profile."""
|
||||
|
||||
contact_email: str | None = None
|
||||
telegram: str | None = None
|
||||
signal: str | None = None
|
||||
|
|
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
|
|||
|
||||
class InviteCheckResponse(BaseModel):
|
||||
"""Response for invite check endpoint."""
|
||||
|
||||
valid: bool
|
||||
status: str | None = None
|
||||
error: str | None = None
|
||||
|
|
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
|
|||
|
||||
class InviteCreate(BaseModel):
|
||||
"""Request model for creating an invite."""
|
||||
|
||||
godfather_id: int
|
||||
|
||||
|
||||
class InviteResponse(BaseModel):
|
||||
"""Response model for invite data (admin view)."""
|
||||
|
||||
id: int
|
||||
identifier: str
|
||||
godfather_id: int
|
||||
|
|
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
|
|||
|
||||
class UserInviteResponse(BaseModel):
|
||||
"""Response model for a user's invite (simpler than admin view)."""
|
||||
|
||||
id: int
|
||||
identifier: str
|
||||
status: str
|
||||
|
|
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
|
|||
|
||||
class AdminUserResponse(BaseModel):
|
||||
"""Minimal user info for admin dropdowns."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
|
||||
|
|
@ -146,8 +162,10 @@ class AdminUserResponse(BaseModel):
|
|||
# Availability Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TimeSlot(BaseModel):
|
||||
"""A single time slot (start and end time)."""
|
||||
|
||||
start_time: time
|
||||
end_time: time
|
||||
|
||||
|
|
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
|
|||
|
||||
class AvailabilityDay(BaseModel):
|
||||
"""Availability for a single day."""
|
||||
|
||||
date: date
|
||||
slots: list[TimeSlot]
|
||||
|
||||
|
||||
class AvailabilityResponse(BaseModel):
|
||||
"""Response model for availability query."""
|
||||
|
||||
days: list[AvailabilityDay]
|
||||
|
||||
|
||||
class SetAvailabilityRequest(BaseModel):
|
||||
"""Request to set availability for a specific date."""
|
||||
|
||||
date: date
|
||||
slots: list[TimeSlot]
|
||||
|
||||
|
||||
class CopyAvailabilityRequest(BaseModel):
|
||||
"""Request to copy availability from one day to others."""
|
||||
|
||||
source_date: date
|
||||
target_dates: list[date]
|
||||
|
||||
|
|
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
|
|||
# Booking Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BookableSlot(BaseModel):
|
||||
"""A bookable 15-minute slot."""
|
||||
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
|
||||
|
||||
class AvailableSlotsResponse(BaseModel):
|
||||
"""Response for available slots on a given date."""
|
||||
|
||||
date: date
|
||||
slots: list[BookableSlot]
|
||||
|
||||
|
||||
class BookingRequest(BaseModel):
|
||||
"""Request to book an appointment."""
|
||||
|
||||
slot_start: datetime
|
||||
note: str | None = None
|
||||
|
||||
|
|
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
|
|||
|
||||
class AppointmentResponse(BaseModel):
|
||||
"""Response model for an appointment."""
|
||||
|
||||
id: int
|
||||
user_id: int
|
||||
user_email: str
|
||||
|
|
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
|
|||
# Meta/Constants Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ConstantsResponse(BaseModel):
|
||||
"""Response model for shared constants."""
|
||||
|
||||
permissions: list[str]
|
||||
roles: list[str]
|
||||
invite_statuses: list[str]
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
"""Seed the database with roles, permissions, and dev users."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import engine, async_session, Base
|
||||
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||
from auth import get_password_hash
|
||||
from database import Base, async_session, engine
|
||||
from models import (
|
||||
ROLE_ADMIN,
|
||||
ROLE_DEFINITIONS,
|
||||
ROLE_REGULAR,
|
||||
Permission,
|
||||
Role,
|
||||
User,
|
||||
)
|
||||
|
||||
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
||||
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
|
||||
|
|
@ -14,7 +23,9 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
|
|||
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
||||
|
||||
|
||||
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
|
||||
async def upsert_role(
|
||||
db: AsyncSession, name: str, description: str, permissions: list[Permission]
|
||||
) -> Role:
|
||||
"""Create or update a role with the given permissions."""
|
||||
result = await db.execute(select(Role).where(Role.name == name))
|
||||
role = result.scalar_one_or_none()
|
||||
|
|
@ -35,7 +46,9 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
|
|||
return role
|
||||
|
||||
|
||||
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
|
||||
async def upsert_user(
|
||||
db: AsyncSession, email: str, password: str, role_names: list[str]
|
||||
) -> User:
|
||||
"""Create or update a user with the given credentials and roles."""
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Load shared constants from shared/constants.json."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
|
|||
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
|
||||
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
|
||||
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,17 +7,17 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
|||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from auth import get_password_hash
|
||||
from database import Base, get_db
|
||||
from main import app
|
||||
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||
from auth import get_password_hash
|
||||
from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
TEST_DATABASE_URL = os.getenv(
|
||||
"TEST_DATABASE_URL",
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
|
||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -91,7 +91,9 @@ async def create_user_with_roles(
|
|||
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||
role = result.scalar_one_or_none()
|
||||
if not role:
|
||||
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
|
||||
raise ValueError(
|
||||
f"Role '{role_name}' not found. Did you run setup_roles()?"
|
||||
)
|
||||
roles.append(role)
|
||||
|
||||
user = User(
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@ import uuid
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import User, Invite, InviteStatus
|
||||
from invite_utils import generate_invite_identifier
|
||||
from models import Invite, InviteStatus, User
|
||||
|
||||
|
||||
def unique_email(prefix: str = "test") -> str:
|
||||
|
|
@ -67,7 +67,9 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
|
|||
godfather = result.scalar_one_or_none()
|
||||
|
||||
if not godfather:
|
||||
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
|
||||
"Create the user first using create_user_with_roles().")
|
||||
raise ValueError(
|
||||
f"Godfather user with email '{godfather_email}' not found. "
|
||||
"Create the user first using create_user_with_roles()."
|
||||
)
|
||||
|
||||
return await create_invite_for_godfather(db, godfather.id)
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@
|
|||
Note: Registration now requires an invite code. Tests that need to register
|
||||
users will create invites first via the helper function.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
|
||||
# Registration tests (with invite)
|
||||
|
|
@ -19,7 +20,9 @@ async def test_register_success(client_factory):
|
|||
|
||||
# Create godfather user and invite
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
|
|
@ -49,7 +52,9 @@ async def test_register_duplicate_email(client_factory):
|
|||
|
||||
# Create godfather and two invites
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
|
@ -80,7 +85,9 @@ async def test_register_duplicate_email(client_factory):
|
|||
async def test_register_invalid_email(client_factory):
|
||||
"""Cannot register with invalid email format."""
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
|
|
@ -138,7 +145,9 @@ async def test_login_success(client_factory):
|
|||
email = unique_email("login")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
|
|
@ -167,7 +176,9 @@ async def test_login_wrong_password(client_factory):
|
|||
email = unique_email("wrongpass")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
|
|
@ -221,7 +232,9 @@ async def test_get_me_success(client_factory):
|
|||
email = unique_email("me")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
|
|
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_me_invalid_cookie(client_factory):
|
||||
"""Cannot get current user with invalid cookie."""
|
||||
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
|
||||
async with client_factory.create(
|
||||
cookies={COOKIE_NAME: "invalidtoken123"}
|
||||
) as authed:
|
||||
response = await authed.get("/api/auth/me")
|
||||
assert response.status_code == 401
|
||||
assert response.json()["detail"] == "Invalid authentication credentials"
|
||||
|
|
@ -277,7 +292,9 @@ async def test_cookie_from_register_works_for_me(client_factory):
|
|||
email = unique_email("tokentest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
|
|
@ -303,7 +320,9 @@ async def test_cookie_from_login_works_for_me(client_factory):
|
|||
email = unique_email("logintoken")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
|
|
@ -335,7 +354,9 @@ async def test_multiple_users_isolated(client_factory):
|
|||
email2 = unique_email("user2")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
|
@ -377,7 +398,9 @@ async def test_password_is_hashed(client_factory):
|
|||
email = unique_email("hashtest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
|
|
@ -401,7 +424,9 @@ async def test_case_sensitive_password(client_factory):
|
|||
email = unique_email("casetest")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
await client_factory.post(
|
||||
|
|
@ -426,7 +451,9 @@ async def test_logout_success(client_factory):
|
|||
email = unique_email("logout")
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg_response = await client_factory.post(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ Availability API Tests
|
|||
|
||||
Tests for the admin availability management endpoints.
|
||||
"""
|
||||
from datetime import date, time, timedelta
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
|
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
|
|||
# Permission Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAvailabilityPermissions:
|
||||
"""Test that only admins can access availability endpoints."""
|
||||
|
||||
|
|
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
|
|||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_get_availability(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get(
|
||||
"/api/admin/availability",
|
||||
|
|
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_set_availability(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.put(
|
||||
"/api/admin/availability",
|
||||
|
|
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
|
|||
# Set Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSetAvailability:
|
||||
"""Test setting availability for a date."""
|
||||
|
||||
|
|
@ -128,7 +136,9 @@ class TestSetAvailability:
|
|||
assert len(data["slots"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user):
|
||||
async def test_set_empty_slots_clears_availability(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
# First set some availability
|
||||
await client.put(
|
||||
|
|
@ -162,7 +172,7 @@ class TestSetAvailability:
|
|||
)
|
||||
|
||||
# Replace with different slots
|
||||
response = await client.put(
|
||||
await client.put(
|
||||
"/api/admin/availability",
|
||||
json={
|
||||
"date": str(tomorrow()),
|
||||
|
|
@ -186,6 +196,7 @@ class TestSetAvailability:
|
|||
# Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAvailabilityValidation:
|
||||
"""Test validation rules for availability."""
|
||||
|
||||
|
|
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
|
|||
# Get Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetAvailability:
|
||||
"""Test retrieving availability."""
|
||||
|
||||
|
|
@ -360,6 +372,7 @@ class TestGetAvailability:
|
|||
# Copy Availability Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCopyAvailability:
|
||||
"""Test copying availability from one day to others."""
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ Booking API Tests
|
|||
|
||||
Tests for the user booking endpoints.
|
||||
"""
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Appointment, AppointmentStatus
|
||||
|
|
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
|
|||
# Permission Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingPermissions:
|
||||
"""Test that only regular users can book appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_can_get_slots(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user can get available slots."""
|
||||
# First, admin sets up availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -39,12 +44,16 @@ class TestBookingPermissions:
|
|||
|
||||
# Regular user gets slots
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_can_book(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user can book an appointment."""
|
||||
# Admin sets up availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -69,7 +78,9 @@ class TestBookingPermissions:
|
|||
async def test_admin_cannot_get_slots(self, client_factory, admin_user):
|
||||
"""Admin cannot access booking slots endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
|
|
@ -96,7 +107,9 @@ class TestBookingPermissions:
|
|||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_get_slots(self, client):
|
||||
"""Unauthenticated user cannot get slots."""
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -113,6 +126,7 @@ class TestBookingPermissions:
|
|||
# Get Slots Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetSlots:
|
||||
"""Test getting available booking slots."""
|
||||
|
||||
|
|
@ -120,7 +134,9 @@ class TestGetSlots:
|
|||
async def test_get_slots_no_availability(self, client_factory, regular_user):
|
||||
"""Returns empty slots when no availability set."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -128,7 +144,9 @@ class TestGetSlots:
|
|||
assert data["slots"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_slots_expands_to_15min(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Availability is expanded into 15-minute slots."""
|
||||
# Admin sets 1-hour availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -142,7 +160,9 @@ class TestGetSlots:
|
|||
|
||||
# User gets slots - should be 4 x 15-minute slots
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -156,7 +176,9 @@ class TestGetSlots:
|
|||
assert "10:00:00" in data["slots"][3]["end_time"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_slots_excludes_booked(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Already booked slots are excluded from available slots."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -176,7 +198,9 @@ class TestGetSlots:
|
|||
)
|
||||
|
||||
# Get slots again - should have 3 left
|
||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -189,6 +213,7 @@ class TestGetSlots:
|
|||
# Booking Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCreateBooking:
|
||||
"""Test creating bookings."""
|
||||
|
||||
|
|
@ -248,7 +273,9 @@ class TestCreateBooking:
|
|||
assert data["note"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user):
|
||||
async def test_cannot_double_book_slot(
|
||||
self, client_factory, regular_user, admin_user, alt_regular_user
|
||||
):
|
||||
"""Cannot book a slot that's already booked."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -279,7 +306,9 @@ class TestCreateBooking:
|
|||
assert "already been booked" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user):
|
||||
async def test_cannot_book_outside_availability(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Cannot book a slot outside of availability."""
|
||||
# Admin sets availability for morning only
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -306,6 +335,7 @@ class TestCreateBooking:
|
|||
# Date Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingDateValidation:
|
||||
"""Test date validation for bookings."""
|
||||
|
||||
|
|
@ -319,7 +349,10 @@ class TestBookingDateValidation:
|
|||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower()
|
||||
assert (
|
||||
"past" in response.json()["detail"].lower()
|
||||
or "today" in response.json()["detail"].lower()
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_book_past_date(self, client_factory, regular_user):
|
||||
|
|
@ -350,7 +383,9 @@ class TestBookingDateValidation:
|
|||
async def test_cannot_get_slots_today(self, client_factory, regular_user):
|
||||
"""Cannot get slots for today."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(date.today())})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(date.today())}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
|
@ -359,7 +394,9 @@ class TestBookingDateValidation:
|
|||
"""Cannot get slots for past date."""
|
||||
yesterday = date.today() - timedelta(days=1)
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/booking/slots", params={"date": str(yesterday)})
|
||||
response = await client.get(
|
||||
"/api/booking/slots", params={"date": str(yesterday)}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
|
@ -368,11 +405,14 @@ class TestBookingDateValidation:
|
|||
# Time Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingTimeValidation:
|
||||
"""Test time validation for bookings."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user):
|
||||
async def test_slot_must_be_15min_boundary(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Slot start time must be on 15-minute boundary."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
|
|||
# Note Validation Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBookingNoteValidation:
|
||||
"""Test note validation for bookings."""
|
||||
|
||||
|
|
@ -426,7 +467,9 @@ class TestBookingNoteValidation:
|
|||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user):
|
||||
async def test_note_exactly_144_chars(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Note of exactly 144 characters is allowed."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
|
|||
# User Appointments Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestUserAppointments:
|
||||
"""Test user appointments endpoints."""
|
||||
|
||||
|
|
@ -467,7 +511,9 @@ class TestUserAppointments:
|
|||
assert response.json() == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user):
|
||||
async def test_get_my_appointments_with_bookings(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Returns user's appointments."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -502,7 +548,9 @@ class TestUserAppointments:
|
|||
assert "Second" in notes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user):
|
||||
async def test_admin_cannot_view_user_appointments(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Admin cannot access user appointments endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/appointments")
|
||||
|
|
@ -520,7 +568,9 @@ class TestCancelAppointment:
|
|||
"""Test cancelling appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_cancel_own_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""User can cancel their own appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -549,7 +599,9 @@ class TestCancelAppointment:
|
|||
assert data["cancelled_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user):
|
||||
async def test_cannot_cancel_others_appointment(
|
||||
self, client_factory, regular_user, alt_regular_user, admin_user
|
||||
):
|
||||
"""User cannot cancel another user's appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -577,7 +629,9 @@ class TestCancelAppointment:
|
|||
assert "another user" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user):
|
||||
async def test_cannot_cancel_nonexistent_appointment(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Returns 404 for non-existent appointment."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post("/api/appointments/99999/cancel")
|
||||
|
|
@ -585,7 +639,9 @@ class TestCancelAppointment:
|
|||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
||||
async def test_cannot_cancel_already_cancelled(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Cannot cancel an already cancelled appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -613,7 +669,9 @@ class TestCancelAppointment:
|
|||
assert "cancelled_by_user" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user):
|
||||
async def test_admin_cannot_use_user_cancel_endpoint(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Admin cannot use user cancel endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/appointments/1/cancel")
|
||||
|
|
@ -621,7 +679,9 @@ class TestCancelAppointment:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user):
|
||||
async def test_cancelled_slot_becomes_available(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""After cancelling, the slot becomes available again."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -663,7 +723,7 @@ class TestCancelAppointment:
|
|||
"""User cannot cancel a past appointment."""
|
||||
# Create a past appointment directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
appointment = Appointment(
|
||||
user_id=regular_user["user"]["id"],
|
||||
slot_start=past_time,
|
||||
|
|
@ -687,11 +747,14 @@ class TestCancelAppointment:
|
|||
# Admin Appointments Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdminViewAppointments:
|
||||
"""Test admin viewing all appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_can_view_all_appointments(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin can view all appointments."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -725,7 +788,9 @@ class TestAdminViewAppointments:
|
|||
assert any(apt["note"] == "Test" for apt in data["records"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_all_appointments(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular user cannot access admin appointments endpoint."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/admin/appointments")
|
||||
|
|
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
|
|||
"""Test admin cancelling appointments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_can_cancel_any_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin can cancel any user's appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -765,7 +832,9 @@ class TestAdminCancelAppointment:
|
|||
|
||||
# Admin cancels
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -773,7 +842,9 @@ class TestAdminCancelAppointment:
|
|||
assert data["cancelled_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user):
|
||||
async def test_regular_user_cannot_use_admin_cancel(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Regular user cannot use admin cancel endpoint."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -799,7 +870,9 @@ class TestAdminCancelAppointment:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user):
|
||||
async def test_admin_cancel_nonexistent_appointment(
|
||||
self, client_factory, admin_user
|
||||
):
|
||||
"""Returns 404 for non-existent appointment."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/admin/appointments/99999/cancel")
|
||||
|
|
@ -807,7 +880,9 @@ class TestAdminCancelAppointment:
|
|||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_cannot_cancel_already_cancelled(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin cannot cancel an already cancelled appointment."""
|
||||
# Admin sets availability
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
|
|
@ -832,17 +907,21 @@ class TestAdminCancelAppointment:
|
|||
|
||||
# Admin tries to cancel again
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "cancelled_by_user" in response.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user):
|
||||
async def test_admin_cannot_cancel_past_appointment(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Admin cannot cancel a past appointment."""
|
||||
# Create a past appointment directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
appointment = Appointment(
|
||||
user_id=regular_user["user"]["id"],
|
||||
slot_start=past_time,
|
||||
|
|
@ -856,8 +935,9 @@ class TestAdminCancelAppointment:
|
|||
|
||||
# Admin tries to cancel
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
||||
response = await admin_client.post(
|
||||
f"/api/admin/appointments/{apt_id}/cancel"
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "past" in response.json()["detail"].lower()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,13 @@
|
|||
|
||||
Note: Registration now requires an invite code.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from models import ROLE_REGULAR
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
|
||||
# Protected endpoint tests - without auth
|
||||
|
|
@ -41,7 +42,9 @@ async def test_increment_counter_invalid_cookie(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_counter_authenticated(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
|
|
@ -64,7 +67,9 @@ async def test_get_counter_authenticated(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
|
|
@ -91,7 +96,9 @@ async def test_increment_counter(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_multiple(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
|
|
@ -120,7 +127,9 @@ async def test_increment_counter_multiple(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_counter_after_increment(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
|
|
@ -149,7 +158,9 @@ async def test_get_counter_after_increment(client_factory):
|
|||
async def test_counter_shared_between_users(client_factory):
|
||||
# Create godfather and invites for two users
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
"""Tests for invite functionality."""
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from invite_utils import (
|
||||
generate_invite_identifier,
|
||||
normalize_identifier,
|
||||
is_valid_identifier_format,
|
||||
BIP39_WORDS,
|
||||
generate_invite_identifier,
|
||||
is_valid_identifier_format,
|
||||
normalize_identifier,
|
||||
)
|
||||
from models import Invite, InviteStatus, User, ROLE_REGULAR
|
||||
from tests.helpers import unique_email
|
||||
from models import ROLE_REGULAR, Invite, InviteStatus, User
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# ============================================================================
|
||||
# Invite Utils Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_bip39_words_loaded():
|
||||
"""BIP39 word list should have exactly 2048 words."""
|
||||
assert len(BIP39_WORDS) == 2048
|
||||
|
|
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
|
|||
# Invite Model Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite(client_factory):
|
||||
"""Can create an invite with godfather."""
|
||||
|
|
@ -173,7 +175,7 @@ async def test_invite_unique_identifier(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_invite_status_transitions(client_factory):
|
||||
"""Invite status can be changed."""
|
||||
from datetime import datetime, UTC
|
||||
from datetime import UTC, datetime
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
|
|
@ -206,7 +208,7 @@ async def test_invite_status_transitions(client_factory):
|
|||
@pytest.mark.asyncio
|
||||
async def test_invite_revoke(client_factory):
|
||||
"""Invite can be revoked."""
|
||||
from datetime import datetime, UTC
|
||||
from datetime import UTC, datetime
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
|
|
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
|
|||
# User Godfather Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_godfather_relationship(client_factory):
|
||||
"""User can have a godfather."""
|
||||
|
|
@ -254,9 +257,7 @@ async def test_user_godfather_relationship(client_factory):
|
|||
await db.commit()
|
||||
|
||||
# Query user fresh
|
||||
result = await db.execute(
|
||||
select(User).where(User.id == user.id)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.id == user.id))
|
||||
loaded_user = result.scalar_one()
|
||||
|
||||
assert loaded_user.godfather_id == godfather.id
|
||||
|
|
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
|
|||
# Admin Create Invite API Tests (Phase 2)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
|
||||
"""Admin can create an invite for a regular user."""
|
||||
|
|
@ -387,9 +389,7 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
|||
|
||||
# Query from DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
|
||||
assert invite.identifier == data["identifier"]
|
||||
|
|
@ -398,7 +398,9 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
|
||||
async def test_create_invite_retries_on_collision(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Create invite retries with new identifier on collision."""
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -419,6 +421,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
|||
|
||||
# Mock generator to first return the same identifier (collision), then a new one
|
||||
call_count = 0
|
||||
|
||||
def mock_generator():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
|
|
@ -426,7 +429,9 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
|||
return identifier1 # Will collide
|
||||
return f"unique-word-{call_count:02d}" # Won't collide
|
||||
|
||||
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator):
|
||||
with patch(
|
||||
"routes.invites.generate_invite_identifier", side_effect=mock_generator
|
||||
):
|
||||
response2 = await client.post(
|
||||
"/api/admin/invites",
|
||||
json={"godfather_id": godfather.id},
|
||||
|
|
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
|||
# Invite Check API Tests (Phase 3)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_valid(client_factory, admin_user, regular_user):
|
||||
"""Check endpoint returns valid=True for READY invite."""
|
||||
|
|
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user):
|
||||
async def test_check_invite_spent_returns_not_found(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -547,9 +555,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user):
|
||||
async def test_check_invite_revoked_returns_not_found(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
|
||||
from datetime import datetime, UTC
|
||||
from datetime import UTC, datetime
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
|
|||
# Register with Invite Tests (Phase 3)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
|
||||
"""Can register with valid invite code."""
|
||||
|
|
@ -681,9 +692,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
|
|||
|
||||
# Check invite status
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
|
||||
assert invite.status == InviteStatus.SPENT
|
||||
|
|
@ -723,9 +732,7 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
|
|||
|
||||
# Check user's godfather
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == new_email)
|
||||
)
|
||||
result = await db.execute(select(User).where(User.email == new_email))
|
||||
new_user = result.scalar_one()
|
||||
|
||||
assert new_user.godfather_id == godfather_id
|
||||
|
|
@ -794,7 +801,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
|||
@pytest.mark.asyncio
|
||||
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
|
||||
"""Cannot register with revoked invite."""
|
||||
from datetime import datetime, UTC
|
||||
from datetime import UTC, datetime
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -814,9 +821,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
|
|||
|
||||
# Revoke invite directly in DB
|
||||
async with client_factory.get_db_session() as db:
|
||||
result = await db.execute(
|
||||
select(Invite).where(Invite.id == invite_id)
|
||||
)
|
||||
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
invite = result.scalar_one()
|
||||
invite.status = InviteStatus.REVOKED
|
||||
invite.revoked_at = datetime.now(UTC)
|
||||
|
|
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
|
|||
# User Invites API Tests (Phase 4)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
|
||||
"""Regular user can list their own invites."""
|
||||
|
|
@ -941,7 +947,9 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user):
|
||||
async def test_spent_invite_shows_used_by_email(
|
||||
client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Spent invite shows who used it."""
|
||||
# Create invite for regular user
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
|
|||
# Admin Invite Management Tests (Phase 5)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
|
||||
"""Admin can list all invites."""
|
||||
|
|
@ -1153,4 +1162,3 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
|
|||
# Revoke
|
||||
response = await client.post("/api/admin/invites/1/revoke")
|
||||
assert response.status_code == 403
|
||||
|
||||
|
|
|
|||
|
|
@ -7,15 +7,16 @@ These tests verify that:
|
|||
3. Unauthenticated users are denied access (401)
|
||||
4. The permission system cannot be bypassed
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from models import Permission
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Role Assignment Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoleAssignment:
|
||||
"""Test that roles are properly assigned and returned."""
|
||||
|
||||
|
|
@ -40,7 +41,9 @@ class TestRoleAssignment:
|
|||
assert "regular" not in data["roles"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user):
|
||||
async def test_regular_user_has_correct_permissions(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
|
@ -72,7 +75,9 @@ class TestRoleAssignment:
|
|||
assert Permission.USE_SUM.value not in permissions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles):
|
||||
async def test_user_with_no_roles_has_no_permissions(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/auth/me")
|
||||
|
||||
|
|
@ -85,6 +90,7 @@ class TestRoleAssignment:
|
|||
# Counter Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCounterAccess:
|
||||
"""Test access control for counter endpoints."""
|
||||
|
||||
|
|
@ -97,7 +103,9 @@ class TestCounterAccess:
|
|||
assert "value" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_increment_counter(self, client_factory, regular_user):
|
||||
async def test_regular_user_can_increment_counter(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post("/api/counter/increment")
|
||||
|
||||
|
|
@ -122,7 +130,9 @@ class TestCounterAccess:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_view_counter(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
"""Users with no roles should be forbidden."""
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
|
@ -146,6 +156,7 @@ class TestCounterAccess:
|
|||
# Sum Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSumAccess:
|
||||
"""Test access control for sum endpoint."""
|
||||
|
||||
|
|
@ -173,7 +184,9 @@ class TestSumAccess:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_use_sum(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
|
|
@ -195,6 +208,7 @@ class TestSumAccess:
|
|||
# Audit Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAuditAccess:
|
||||
"""Test access control for audit endpoints."""
|
||||
|
||||
|
|
@ -219,7 +233,9 @@ class TestAuditAccess:
|
|||
assert "total" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_counter_audit(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
|
@ -228,7 +244,9 @@ class TestAuditAccess:
|
|||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user):
|
||||
async def test_regular_user_cannot_view_sum_audit(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
|
@ -236,7 +254,9 @@ class TestAuditAccess:
|
|||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles):
|
||||
async def test_user_without_roles_cannot_view_audit(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
|
|
@ -257,6 +277,7 @@ class TestAuditAccess:
|
|||
# Offensive Security Tests - Bypass Attempts
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSecurityBypassAttempts:
|
||||
"""
|
||||
Offensive tests that attempt to bypass security controls.
|
||||
|
|
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
|
|||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user):
|
||||
async def test_cannot_access_audit_with_forged_role_claim(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""
|
||||
Attempt to access audit by somehow claiming admin role.
|
||||
The server should verify roles from DB, not trust client claims.
|
||||
|
|
@ -287,14 +310,18 @@ class TestSecurityBypassAttempts:
|
|||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user):
|
||||
async def test_cannot_access_with_tampered_token(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Test that tokens signed with wrong key are rejected."""
|
||||
# Take a valid token and modify it
|
||||
original_token = regular_user["cookies"].get("auth_token", "")
|
||||
if original_token:
|
||||
tampered_token = original_token[:-5] + "XXXXX"
|
||||
|
||||
async with client_factory.create(cookies={"auth_token": tampered_token}) as client:
|
||||
async with client_factory.create(
|
||||
cookies={"auth_token": tampered_token}
|
||||
) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
|
@ -305,12 +332,14 @@ class TestSecurityBypassAttempts:
|
|||
Test that new registrations cannot claim admin role.
|
||||
New users should only get 'regular' role by default.
|
||||
"""
|
||||
from tests.helpers import unique_email, create_invite_for_godfather
|
||||
from tests.conftest import create_user_with_roles
|
||||
from models import ROLE_REGULAR
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
response = await client_factory.post(
|
||||
|
|
@ -341,15 +370,17 @@ class TestSecurityBypassAttempts:
|
|||
If a user is deleted, their token should no longer work.
|
||||
This tests that tokens are validated against current DB state.
|
||||
"""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import delete
|
||||
|
||||
from models import User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("deleted")
|
||||
|
||||
# Create and login user
|
||||
async with client_factory.get_db_session() as db:
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
user = await create_user_with_roles(db, email, "password123", ["regular"])
|
||||
user_id = user.id
|
||||
|
||||
|
|
@ -376,15 +407,17 @@ class TestSecurityBypassAttempts:
|
|||
If a user's role is changed, the change should be reflected
|
||||
in subsequent requests (no stale permission cache).
|
||||
"""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import select
|
||||
from models import User, Role
|
||||
|
||||
from models import Role, User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("rolechange")
|
||||
|
||||
# Create regular user
|
||||
async with client_factory.get_db_session() as db:
|
||||
from tests.conftest import create_user_with_roles
|
||||
|
||||
await create_user_with_roles(db, email, "password123", ["regular"])
|
||||
|
||||
login_response = await client_factory.post(
|
||||
|
|
@ -406,10 +439,7 @@ class TestSecurityBypassAttempts:
|
|||
result = await db.execute(select(Role).where(Role.name == "admin"))
|
||||
admin_role = result.scalar_one()
|
||||
|
||||
result = await db.execute(select(Role).where(Role.name == "regular"))
|
||||
regular_role = result.scalar_one()
|
||||
|
||||
user.roles = [admin_role] # Remove regular, add admin
|
||||
user.roles = [admin_role] # Replace roles with admin only
|
||||
await db.commit()
|
||||
|
||||
# Now should have audit access but not counter access
|
||||
|
|
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
|
|||
# Audit Record Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAuditRecords:
|
||||
"""Test that actions are properly recorded in audit logs."""
|
||||
|
||||
|
|
@ -466,7 +497,7 @@ class TestAuditRecords:
|
|||
|
||||
# Find record with our values
|
||||
records = data["records"]
|
||||
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30]
|
||||
matching = [
|
||||
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
|
||||
]
|
||||
assert len(matching) >= 1
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""Tests for user profile and contact details."""
|
||||
import pytest
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from models import User, ROLE_REGULAR
|
||||
from auth import get_password_hash
|
||||
from models import User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# Valid npub for testing (32 zero bytes encoded as bech32)
|
||||
|
|
@ -328,7 +328,9 @@ class TestUpdateProfileEndpoint:
|
|||
assert "field_errors" in data["detail"]
|
||||
assert "nostr_npub" in data["detail"]["field_errors"]
|
||||
|
||||
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user):
|
||||
async def test_multiple_invalid_fields_returns_all_errors(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Multiple invalid fields return all errors."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.put(
|
||||
|
|
@ -344,7 +346,9 @@ class TestUpdateProfileEndpoint:
|
|||
assert "contact_email" in data["detail"]["field_errors"]
|
||||
assert "telegram" in data["detail"]["field_errors"]
|
||||
|
||||
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user):
|
||||
async def test_partial_update_preserves_other_fields(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Updating one field doesn't affect others (they get set to the request values)."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
# Set initial values
|
||||
|
|
@ -402,11 +406,14 @@ class TestProfilePrivacy:
|
|||
class TestProfileGodfather:
|
||||
"""Tests for godfather information in profile."""
|
||||
|
||||
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user):
|
||||
async def test_profile_shows_godfather_email(
|
||||
self, client_factory, admin_user, regular_user
|
||||
):
|
||||
"""Profile shows godfather email for users who signed up with invite."""
|
||||
from tests.helpers import unique_email
|
||||
from sqlalchemy import select
|
||||
|
||||
from models import User
|
||||
from tests.helpers import unique_email
|
||||
|
||||
# Create invite
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
|
|
@ -443,7 +450,9 @@ class TestProfileGodfather:
|
|||
data = response.json()
|
||||
assert data["godfather_email"] == regular_user["email"]
|
||||
|
||||
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user):
|
||||
async def test_profile_godfather_null_for_seeded_users(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Profile shows null godfather for users without one (e.g., seeded users)."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/profile")
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Tests for profile field validation."""
|
||||
import pytest
|
||||
|
||||
from validation import (
|
||||
validate_contact_email,
|
||||
validate_telegram,
|
||||
validate_signal,
|
||||
validate_nostr_npub,
|
||||
validate_profile_fields,
|
||||
validate_signal,
|
||||
validate_telegram,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
|
|||
assert validate_nostr_npub(self.VALID_NPUB) is None
|
||||
|
||||
def test_wrong_prefix(self):
|
||||
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz")
|
||||
result = validate_nostr_npub(
|
||||
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
|
||||
)
|
||||
assert result is not None
|
||||
assert "npub" in result.lower()
|
||||
|
||||
def test_invalid_checksum(self):
|
||||
# Change last character to break checksum
|
||||
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd")
|
||||
result = validate_nostr_npub(
|
||||
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
|
||||
)
|
||||
assert result is not None
|
||||
assert "checksum" in result.lower()
|
||||
|
||||
|
|
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
|
|||
assert result is not None
|
||||
|
||||
def test_not_starting_with_npub1(self):
|
||||
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc")
|
||||
result = validate_nostr_npub(
|
||||
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
|
||||
)
|
||||
assert result is not None
|
||||
assert "npub1" in result
|
||||
|
||||
|
|
@ -206,4 +211,3 @@ class TestValidateProfileFields:
|
|||
nostr_npub="",
|
||||
)
|
||||
assert errors == {}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
"""Validate shared constants match backend definitions."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus
|
||||
from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
|
||||
|
||||
|
||||
def validate_shared_constants() -> None:
|
||||
|
|
@ -29,35 +30,42 @@ def validate_shared_constants() -> None:
|
|||
# Validate invite statuses
|
||||
expected_invite_statuses = {s.name: s.value for s in InviteStatus}
|
||||
if constants.get("inviteStatuses") != expected_invite_statuses:
|
||||
got = constants.get("inviteStatuses")
|
||||
raise ValueError(
|
||||
f"Invite status mismatch in shared/constants.json. "
|
||||
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
|
||||
f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
|
||||
)
|
||||
|
||||
# Validate appointment statuses
|
||||
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
|
||||
if constants.get("appointmentStatuses") != expected_appointment_statuses:
|
||||
got = constants.get("appointmentStatuses")
|
||||
raise ValueError(
|
||||
f"Appointment status mismatch in shared/constants.json. "
|
||||
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}"
|
||||
f"Appointment status mismatch. "
|
||||
f"Expected: {expected_appointment_statuses}, Got: {got}"
|
||||
)
|
||||
|
||||
# Validate booking constants exist with required fields
|
||||
booking = constants.get("booking", {})
|
||||
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"]
|
||||
required_booking_fields = [
|
||||
"slotDurationMinutes",
|
||||
"maxAdvanceDays",
|
||||
"minAdvanceDays",
|
||||
"noteMaxLength",
|
||||
]
|
||||
for field in required_booking_fields:
|
||||
if field not in booking:
|
||||
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json")
|
||||
raise ValueError(f"Missing booking constant '{field}' in constants.json")
|
||||
|
||||
# Validate validation rules exist (structure check only)
|
||||
validation = constants.get("validation", {})
|
||||
required_fields = ["telegram", "signal", "nostrNpub"]
|
||||
for field in required_fields:
|
||||
if field not in validation:
|
||||
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json")
|
||||
raise ValueError(
|
||||
f"Missing validation rules for '{field}' in constants.json"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
validate_shared_constants()
|
||||
print("✓ Shared constants are valid")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
"""Validation utilities for user profile fields."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from email_validator import validate_email, EmailNotValidError
|
||||
from bech32 import bech32_decode
|
||||
from email_validator import EmailNotValidError, validate_email
|
||||
|
||||
# Load validation rules from shared constants
|
||||
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
|
||||
|
|
@ -143,4 +144,3 @@ def validate_profile_fields(
|
|||
errors["nostr_npub"] = err
|
||||
|
||||
return errors
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue