Add ruff linter/formatter for Python
- Add ruff as dev dependency - Configure ruff in pyproject.toml with strict 88-char line limit - Ignore B008 (FastAPI Depends pattern is standard) - Allow longer lines in tests for readability - Fix all lint issues in source files - Add Makefile targets: lint-backend, format-backend, fix-backend
This commit is contained in:
parent
69bc8413e0
commit
6c218130e9
31 changed files with 1234 additions and 876 deletions
11
Makefile
11
Makefile
|
|
@ -1,4 +1,4 @@
|
||||||
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants
|
.PHONY: install-backend install-frontend install setup-hooks backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck generate-types generate-types-standalone check-types-fresh check-constants lint-backend format-backend fix-backend
|
||||||
|
|
||||||
-include .env
|
-include .env
|
||||||
export
|
export
|
||||||
|
|
@ -93,3 +93,12 @@ check-types-fresh: generate-types-standalone
|
||||||
|
|
||||||
check-constants:
|
check-constants:
|
||||||
@cd backend && uv run python validate_constants.py
|
@cd backend && uv run python validate_constants.py
|
||||||
|
|
||||||
|
lint-backend:
|
||||||
|
cd backend && uv run ruff check .
|
||||||
|
|
||||||
|
format-backend:
|
||||||
|
cd backend && uv run ruff format .
|
||||||
|
|
||||||
|
fix-backend:
|
||||||
|
cd backend && uv run ruff check --fix . && uv run ruff format .
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
|
@ -8,7 +8,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, Permission
|
from models import Permission, User
|
||||||
from schemas import UserResponse
|
from schemas import UserResponse
|
||||||
|
|
||||||
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
||||||
|
|
@ -32,9 +32,13 @@ def get_password_hash(password: str) -> str:
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
|
def create_access_token(
|
||||||
|
data: dict[str, str],
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
) -> str:
|
||||||
to_encode: dict[str, str | datetime] = dict(data)
|
to_encode: dict[str, str | datetime] = dict(data)
|
||||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
expire = datetime.now(UTC) + delta
|
||||||
to_encode["exp"] = expire
|
to_encode["exp"] = expire
|
||||||
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded
|
return encoded
|
||||||
|
|
@ -72,7 +76,7 @@ async def get_current_user(
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
user_id = int(user_id_str)
|
user_id = int(user_id_str)
|
||||||
except (JWTError, ValueError):
|
except (JWTError, ValueError):
|
||||||
raise credentials_exception
|
raise credentials_exception from None
|
||||||
|
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
result = await db.execute(select(User).where(User.id == user_id))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
@ -83,13 +87,16 @@ async def get_current_user(
|
||||||
|
|
||||||
def require_permission(*required_permissions: Permission):
|
def require_permission(*required_permissions: Permission):
|
||||||
"""
|
"""
|
||||||
Dependency factory that checks if user has ALL of the required permissions.
|
Dependency factory that checks if user has ALL required permissions.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@app.get("/api/counter")
|
@app.get("/api/counter")
|
||||||
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
|
async def get_counter(
|
||||||
|
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
|
||||||
|
):
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def permission_checker(
|
async def permission_checker(
|
||||||
request: Request,
|
request: Request,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
|
|
@ -99,11 +106,13 @@ def require_permission(*required_permissions: Permission):
|
||||||
|
|
||||||
missing = [p for p in required_permissions if p not in user_permissions]
|
missing = [p for p in required_permissions if p not in user_permissions]
|
||||||
if missing:
|
if missing:
|
||||||
|
missing_str = ", ".join(p.value for p in missing)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
|
detail=f"Missing required permissions: {missing_str}",
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
return permission_checker
|
return permission_checker
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import os
|
import os
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
|
DATABASE_URL = os.getenv(
|
||||||
|
"DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
|
||||||
|
)
|
||||||
|
|
||||||
engine = create_async_engine(DATABASE_URL)
|
engine = create_async_engine(DATABASE_URL)
|
||||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
@ -15,4 +18,3 @@ class Base(DeclarativeBase):
|
||||||
async def get_db():
|
async def get_db():
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
"""Utilities for invite code generation and validation."""
|
"""Utilities for invite code generation and validation."""
|
||||||
|
|
||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -53,8 +54,4 @@ def is_valid_identifier_format(identifier: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check number is two digits
|
# Check number is two digits
|
||||||
if len(number) != 2 or not number.isdigit():
|
return len(number) == 2 and number.isdigit()
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,20 @@
|
||||||
"""FastAPI application entry point."""
|
"""FastAPI application entry point."""
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from database import engine, Base
|
from database import Base, engine
|
||||||
from routes import sum as sum_routes
|
|
||||||
from routes import counter as counter_routes
|
|
||||||
from routes import audit as audit_routes
|
from routes import audit as audit_routes
|
||||||
from routes import profile as profile_routes
|
|
||||||
from routes import invites as invites_routes
|
|
||||||
from routes import auth as auth_routes
|
from routes import auth as auth_routes
|
||||||
from routes import meta as meta_routes
|
|
||||||
from routes import availability as availability_routes
|
from routes import availability as availability_routes
|
||||||
from routes import booking as booking_routes
|
from routes import booking as booking_routes
|
||||||
|
from routes import counter as counter_routes
|
||||||
|
from routes import invites as invites_routes
|
||||||
|
from routes import meta as meta_routes
|
||||||
|
from routes import profile as profile_routes
|
||||||
|
from routes import sum as sum_routes
|
||||||
from validate_constants import validate_shared_constants
|
from validate_constants import validate_shared_constants
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,24 @@
|
||||||
from datetime import datetime, date, time, timezone
|
from datetime import UTC, date, datetime, time
|
||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
from sqlalchemy import Integer, String, Float, DateTime, Date, Time, ForeignKey, Table, Column, Enum, UniqueConstraint, select
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy import (
|
||||||
|
Column,
|
||||||
|
Date,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Table,
|
||||||
|
Time,
|
||||||
|
UniqueConstraint,
|
||||||
|
select,
|
||||||
|
)
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from database import Base
|
from database import Base
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -14,6 +29,7 @@ class RoleConfig(TypedDict):
|
||||||
|
|
||||||
class Permission(str, PyEnum):
|
class Permission(str, PyEnum):
|
||||||
"""All available permissions in the system."""
|
"""All available permissions in the system."""
|
||||||
|
|
||||||
# Counter permissions
|
# Counter permissions
|
||||||
VIEW_COUNTER = "view_counter"
|
VIEW_COUNTER = "view_counter"
|
||||||
INCREMENT_COUNTER = "increment_counter"
|
INCREMENT_COUNTER = "increment_counter"
|
||||||
|
|
@ -41,6 +57,7 @@ class Permission(str, PyEnum):
|
||||||
|
|
||||||
class InviteStatus(str, PyEnum):
|
class InviteStatus(str, PyEnum):
|
||||||
"""Status of an invite."""
|
"""Status of an invite."""
|
||||||
|
|
||||||
READY = "ready"
|
READY = "ready"
|
||||||
SPENT = "spent"
|
SPENT = "spent"
|
||||||
REVOKED = "revoked"
|
REVOKED = "revoked"
|
||||||
|
|
@ -48,6 +65,7 @@ class InviteStatus(str, PyEnum):
|
||||||
|
|
||||||
class AppointmentStatus(str, PyEnum):
|
class AppointmentStatus(str, PyEnum):
|
||||||
"""Status of an appointment."""
|
"""Status of an appointment."""
|
||||||
|
|
||||||
BOOKED = "booked"
|
BOOKED = "booked"
|
||||||
CANCELLED_BY_USER = "cancelled_by_user"
|
CANCELLED_BY_USER = "cancelled_by_user"
|
||||||
CANCELLED_BY_ADMIN = "cancelled_by_admin"
|
CANCELLED_BY_ADMIN = "cancelled_by_admin"
|
||||||
|
|
@ -60,7 +78,7 @@ ROLE_REGULAR = "regular"
|
||||||
# Role definitions with their permissions
|
# Role definitions with their permissions
|
||||||
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
||||||
ROLE_ADMIN: {
|
ROLE_ADMIN: {
|
||||||
"description": "Administrator with audit, invite, and appointment management access",
|
"description": "Administrator with audit/invite/appointment access",
|
||||||
"permissions": [
|
"permissions": [
|
||||||
Permission.VIEW_AUDIT,
|
Permission.VIEW_AUDIT,
|
||||||
Permission.MANAGE_INVITES,
|
Permission.MANAGE_INVITES,
|
||||||
|
|
@ -88,7 +106,12 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
||||||
role_permissions = Table(
|
role_permissions = Table(
|
||||||
"role_permissions",
|
"role_permissions",
|
||||||
Base.metadata,
|
Base.metadata,
|
||||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
Column(
|
||||||
|
"role_id",
|
||||||
|
Integer,
|
||||||
|
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
),
|
||||||
Column("permission", Enum(Permission), primary_key=True),
|
Column("permission", Enum(Permission), primary_key=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -97,8 +120,18 @@ role_permissions = Table(
|
||||||
user_roles = Table(
|
user_roles = Table(
|
||||||
"user_roles",
|
"user_roles",
|
||||||
Base.metadata,
|
Base.metadata,
|
||||||
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
|
Column(
|
||||||
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
|
"user_id",
|
||||||
|
Integer,
|
||||||
|
ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
),
|
||||||
|
Column(
|
||||||
|
"role_id",
|
||||||
|
Integer,
|
||||||
|
ForeignKey("roles.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -118,23 +151,34 @@ class Role(Base):
|
||||||
|
|
||||||
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
||||||
"""Get all permissions for this role."""
|
"""Get all permissions for this role."""
|
||||||
result = await db.execute(
|
query = select(role_permissions.c.permission).where(
|
||||||
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
|
role_permissions.c.role_id == self.id
|
||||||
)
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
return {row[0] for row in result.fetchall()}
|
return {row[0] for row in result.fetchall()}
|
||||||
|
|
||||||
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
|
async def set_permissions(
|
||||||
|
self, db: AsyncSession, permissions: list[Permission]
|
||||||
|
) -> None:
|
||||||
"""Set all permissions for this role (replaces existing)."""
|
"""Set all permissions for this role (replaces existing)."""
|
||||||
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
|
delete_query = role_permissions.delete().where(
|
||||||
|
role_permissions.c.role_id == self.id
|
||||||
|
)
|
||||||
|
await db.execute(delete_query)
|
||||||
for perm in permissions:
|
for perm in permissions:
|
||||||
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm))
|
insert_query = role_permissions.insert().values(
|
||||||
|
role_id=self.id, permission=perm
|
||||||
|
)
|
||||||
|
await db.execute(insert_query)
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(
|
||||||
|
String(255), unique=True, nullable=False, index=True
|
||||||
|
)
|
||||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
# Contact details (all optional)
|
# Contact details (all optional)
|
||||||
|
|
@ -192,12 +236,14 @@ class SumRecord(Base):
|
||||||
__tablename__ = "sum_records"
|
__tablename__ = "sum_records"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
user_id: Mapped[int] = mapped_column(
|
||||||
|
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
a: Mapped[float] = mapped_column(Float, nullable=False)
|
a: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
b: Mapped[float] = mapped_column(Float, nullable=False)
|
b: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
result: Mapped[float] = mapped_column(Float, nullable=False)
|
result: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -205,11 +251,13 @@ class CounterRecord(Base):
|
||||||
__tablename__ = "counter_records"
|
__tablename__ = "counter_records"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
user_id: Mapped[int] = mapped_column(
|
||||||
|
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
|
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
|
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -217,7 +265,9 @@ class Invite(Base):
|
||||||
__tablename__ = "invites"
|
__tablename__ = "invites"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
identifier: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
identifier: Mapped[str] = mapped_column(
|
||||||
|
String(64), unique=True, nullable=False, index=True
|
||||||
|
)
|
||||||
status: Mapped[InviteStatus] = mapped_column(
|
status: Mapped[InviteStatus] = mapped_column(
|
||||||
Enum(InviteStatus), nullable=False, default=InviteStatus.READY
|
Enum(InviteStatus), nullable=False, default=InviteStatus.READY
|
||||||
)
|
)
|
||||||
|
|
@ -244,14 +294,19 @@ class Invite(Base):
|
||||||
|
|
||||||
# Timestamps
|
# Timestamps
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||||
|
)
|
||||||
|
spent_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
|
)
|
||||||
|
revoked_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
)
|
)
|
||||||
spent_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
revoked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Availability(Base):
|
class Availability(Base):
|
||||||
"""Admin availability slots for booking."""
|
"""Admin availability slots for booking."""
|
||||||
|
|
||||||
__tablename__ = "availability"
|
__tablename__ = "availability"
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
|
UniqueConstraint("date", "start_time", name="uq_availability_date_start"),
|
||||||
|
|
@ -262,34 +317,37 @@ class Availability(Base):
|
||||||
start_time: Mapped[time] = mapped_column(Time, nullable=False)
|
start_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||||
end_time: Mapped[time] = mapped_column(Time, nullable=False)
|
end_time: Mapped[time] = mapped_column(Time, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||||
)
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True),
|
DateTime(timezone=True),
|
||||||
default=lambda: datetime.now(timezone.utc),
|
default=lambda: datetime.now(UTC),
|
||||||
onupdate=lambda: datetime.now(timezone.utc)
|
onupdate=lambda: datetime.now(UTC),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Appointment(Base):
|
class Appointment(Base):
|
||||||
"""User appointment bookings."""
|
"""User appointment bookings."""
|
||||||
|
|
||||||
__tablename__ = "appointments"
|
__tablename__ = "appointments"
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint("slot_start", name="uq_appointment_slot_start"),)
|
||||||
UniqueConstraint("slot_start", name="uq_appointment_slot_start"),
|
|
||||||
)
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
user_id: Mapped[int] = mapped_column(
|
user_id: Mapped[int] = mapped_column(
|
||||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||||
)
|
)
|
||||||
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
|
user: Mapped[User] = relationship("User", foreign_keys=[user_id], lazy="joined")
|
||||||
slot_start: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
|
slot_start: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, index=True
|
||||||
|
)
|
||||||
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
slot_end: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
note: Mapped[str | None] = mapped_column(String(144), nullable=True)
|
note: Mapped[str | None] = mapped_column(String(144), nullable=True)
|
||||||
status: Mapped[AppointmentStatus] = mapped_column(
|
status: Mapped[AppointmentStatus] = mapped_column(
|
||||||
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
|
Enum(AppointmentStatus), nullable=False, default=AppointmentStatus.BOOKED
|
||||||
)
|
)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||||
|
)
|
||||||
|
cancelled_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True
|
||||||
)
|
)
|
||||||
cancelled_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ dev = [
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"aiosqlite>=0.20.0",
|
"aiosqlite>=0.20.0",
|
||||||
"mypy>=1.13.0",
|
"mypy>=1.13.0",
|
||||||
|
"ruff>=0.14.10",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
|
|
@ -30,3 +31,27 @@ check_untyped_defs = true
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
exclude = ["tests/"]
|
exclude = ["tests/"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 88
|
||||||
|
target-version = "py311"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"RUF", # ruff-specific rules
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"B008", # function-call-in-default-argument (standard FastAPI pattern with Depends)
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.format]
|
||||||
|
quote-style = "double"
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/*" = ["E501"] # Allow longer lines in tests for readability
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,23 @@
|
||||||
"""Audit routes for viewing action records."""
|
"""Audit routes for viewing action records."""
|
||||||
from typing import Callable, TypeVar
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select, func, desc
|
from sqlalchemy import desc, func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, SumRecord, CounterRecord, Permission
|
from models import CounterRecord, Permission, SumRecord, User
|
||||||
from schemas import (
|
from schemas import (
|
||||||
CounterRecordResponse,
|
CounterRecordResponse,
|
||||||
SumRecordResponse,
|
|
||||||
PaginatedCounterRecords,
|
PaginatedCounterRecords,
|
||||||
PaginatedSumRecords,
|
PaginatedSumRecords,
|
||||||
|
SumRecordResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
||||||
|
|
||||||
R = TypeVar("R", bound=BaseModel)
|
R = TypeVar("R", bound=BaseModel)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Authentication routes for register, login, logout, and current user."""
|
"""Authentication routes for register, login, logout, and current user."""
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -9,18 +10,17 @@ from auth import (
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
COOKIE_NAME,
|
COOKIE_NAME,
|
||||||
COOKIE_SECURE,
|
COOKIE_SECURE,
|
||||||
get_password_hash,
|
|
||||||
get_user_by_email,
|
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
|
build_user_response,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
get_current_user,
|
get_current_user,
|
||||||
build_user_response,
|
get_password_hash,
|
||||||
|
get_user_by_email,
|
||||||
)
|
)
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from invite_utils import normalize_identifier
|
from invite_utils import normalize_identifier
|
||||||
from models import User, Role, ROLE_REGULAR, Invite, InviteStatus
|
from models import ROLE_REGULAR, Invite, InviteStatus, Role, User
|
||||||
from schemas import UserLogin, UserResponse, RegisterWithInvite
|
from schemas import RegisterWithInvite, UserLogin, UserResponse
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
@ -52,9 +52,8 @@ async def register(
|
||||||
"""Register a new user using an invite code."""
|
"""Register a new user using an invite code."""
|
||||||
# Validate invite
|
# Validate invite
|
||||||
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
normalized_identifier = normalize_identifier(user_data.invite_identifier)
|
||||||
result = await db.execute(
|
query = select(Invite).where(Invite.identifier == normalized_identifier)
|
||||||
select(Invite).where(Invite.identifier == normalized_identifier)
|
result = await db.execute(query)
|
||||||
)
|
|
||||||
invite = result.scalar_one_or_none()
|
invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||||
|
|
@ -90,7 +89,7 @@ async def register(
|
||||||
# Mark invite as spent
|
# Mark invite as spent
|
||||||
invite.status = InviteStatus.SPENT
|
invite.status = InviteStatus.SPENT
|
||||||
invite.used_by_id = user.id
|
invite.used_by_id = user.id
|
||||||
invite.spent_at = datetime.now(timezone.utc)
|
invite.spent_at = datetime.now(UTC)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
|
||||||
|
|
@ -1,28 +1,28 @@
|
||||||
"""Availability routes for admin to manage booking availability."""
|
"""Availability routes for admin to manage booking availability."""
|
||||||
|
|
||||||
from datetime import date, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy import select, delete, and_
|
from sqlalchemy import and_, delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, Availability, Permission
|
from models import Availability, Permission, User
|
||||||
from schemas import (
|
from schemas import (
|
||||||
TimeSlot,
|
|
||||||
AvailabilityDay,
|
AvailabilityDay,
|
||||||
AvailabilityResponse,
|
AvailabilityResponse,
|
||||||
SetAvailabilityRequest,
|
|
||||||
CopyAvailabilityRequest,
|
CopyAvailabilityRequest,
|
||||||
|
SetAvailabilityRequest,
|
||||||
|
TimeSlot,
|
||||||
)
|
)
|
||||||
from shared_constants import MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
|
router = APIRouter(prefix="/api/admin/availability", tags=["availability"])
|
||||||
|
|
||||||
|
|
||||||
def _get_date_range_bounds() -> tuple[date, date]:
|
def _get_date_range_bounds() -> tuple[date, date]:
|
||||||
"""Get the valid date range for availability (using MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
"""Get valid date range (MIN_ADVANCE_DAYS to MAX_ADVANCE_DAYS)."""
|
||||||
today = date.today()
|
today = date.today()
|
||||||
min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
|
min_date = today + timedelta(days=MIN_ADVANCE_DAYS)
|
||||||
max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
|
max_date = today + timedelta(days=MAX_ADVANCE_DAYS)
|
||||||
|
|
@ -34,12 +34,14 @@ def _validate_date_in_range(d: date, min_date: date, max_date: date) -> None:
|
||||||
if d < min_date:
|
if d < min_date:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot set availability for past dates. Earliest allowed: {min_date}",
|
detail=f"Cannot set availability for past dates. "
|
||||||
|
f"Earliest allowed: {min_date}",
|
||||||
)
|
)
|
||||||
if d > max_date:
|
if d > max_date:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot set availability more than {MAX_ADVANCE_DAYS} days ahead. Latest allowed: {max_date}",
|
detail=f"Cannot set more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||||
|
f"Latest allowed: {max_date}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -70,15 +72,16 @@ async def get_availability(
|
||||||
for slot in slots:
|
for slot in slots:
|
||||||
if slot.date not in days_dict:
|
if slot.date not in days_dict:
|
||||||
days_dict[slot.date] = []
|
days_dict[slot.date] = []
|
||||||
days_dict[slot.date].append(TimeSlot(
|
days_dict[slot.date].append(
|
||||||
start_time=slot.start_time,
|
TimeSlot(
|
||||||
end_time=slot.end_time,
|
start_time=slot.start_time,
|
||||||
))
|
end_time=slot.end_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to response format
|
# Convert to response format
|
||||||
days = [
|
days = [
|
||||||
AvailabilityDay(date=d, slots=days_dict[d])
|
AvailabilityDay(date=d, slots=days_dict[d]) for d in sorted(days_dict.keys())
|
||||||
for d in sorted(days_dict.keys())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return AvailabilityResponse(days=days)
|
return AvailabilityResponse(days=days)
|
||||||
|
|
@ -98,9 +101,12 @@ async def set_availability(
|
||||||
sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
|
sorted_slots = sorted(request.slots, key=lambda s: s.start_time)
|
||||||
for i in range(len(sorted_slots) - 1):
|
for i in range(len(sorted_slots) - 1):
|
||||||
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
|
if sorted_slots[i].end_time > sorted_slots[i + 1].start_time:
|
||||||
|
end = sorted_slots[i].end_time
|
||||||
|
start = sorted_slots[i + 1].start_time
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Time slots overlap on {request.date}: slot ending at {sorted_slots[i].end_time} overlaps with slot starting at {sorted_slots[i + 1].start_time}. Please ensure all time slots are non-overlapping.",
|
detail=f"Time slots overlap: slot ending at {end} "
|
||||||
|
f"overlaps with slot starting at {start}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate each slot's end_time > start_time
|
# Validate each slot's end_time > start_time
|
||||||
|
|
@ -108,13 +114,12 @@ async def set_availability(
|
||||||
if slot.end_time <= slot.start_time:
|
if slot.end_time <= slot.start_time:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Invalid time slot on {request.date}: end time {slot.end_time} must be after start time {slot.start_time}. Please correct the time range.",
|
detail=f"Invalid time slot: end time {slot.end_time} "
|
||||||
|
f"must be after start time {slot.start_time}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete existing availability for this date
|
# Delete existing availability for this date
|
||||||
await db.execute(
|
await db.execute(delete(Availability).where(Availability.date == request.date))
|
||||||
delete(Availability).where(Availability.date == request.date)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create new availability slots
|
# Create new availability slots
|
||||||
for slot in request.slots:
|
for slot in request.slots:
|
||||||
|
|
@ -139,7 +144,7 @@ async def copy_availability(
|
||||||
"""Copy availability from one day to multiple target days."""
|
"""Copy availability from one day to multiple target days."""
|
||||||
min_date, max_date = _get_date_range_bounds()
|
min_date, max_date = _get_date_range_bounds()
|
||||||
|
|
||||||
# Validate source date is in range (for consistency, though DB query would fail anyway)
|
# Validate source date is in range
|
||||||
_validate_date_in_range(request.source_date, min_date, max_date)
|
_validate_date_in_range(request.source_date, min_date, max_date)
|
||||||
|
|
||||||
# Validate target dates
|
# Validate target dates
|
||||||
|
|
@ -169,9 +174,8 @@ async def copy_availability(
|
||||||
continue # Skip copying to self
|
continue # Skip copying to self
|
||||||
|
|
||||||
# Delete existing availability for target date
|
# Delete existing availability for target date
|
||||||
await db.execute(
|
del_query = delete(Availability).where(Availability.date == target_date)
|
||||||
delete(Availability).where(Availability.date == target_date)
|
await db.execute(del_query)
|
||||||
)
|
|
||||||
|
|
||||||
# Copy slots
|
# Copy slots
|
||||||
target_slots: list[TimeSlot] = []
|
target_slots: list[TimeSlot] = []
|
||||||
|
|
@ -182,10 +186,12 @@ async def copy_availability(
|
||||||
end_time=source_slot.end_time,
|
end_time=source_slot.end_time,
|
||||||
)
|
)
|
||||||
db.add(new_availability)
|
db.add(new_availability)
|
||||||
target_slots.append(TimeSlot(
|
target_slots.append(
|
||||||
start_time=source_slot.start_time,
|
TimeSlot(
|
||||||
end_time=source_slot.end_time,
|
start_time=source_slot.start_time,
|
||||||
))
|
end_time=source_slot.end_time,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
|
copied_days.append(AvailabilityDay(date=target_date, slots=target_slots))
|
||||||
|
|
||||||
|
|
@ -197,4 +203,3 @@ async def copy_availability(
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return AvailabilityResponse(days=copied_days)
|
return AvailabilityResponse(days=copied_days)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,24 +1,24 @@
|
||||||
"""Booking routes for users to book appointments."""
|
"""Booking routes for users to book appointments."""
|
||||||
from datetime import date, datetime, time, timedelta, timezone
|
|
||||||
|
from datetime import UTC, date, datetime, time, timedelta
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy import select, and_, func
|
from sqlalchemy import and_, func, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, Availability, Appointment, AppointmentStatus, Permission
|
from models import Appointment, AppointmentStatus, Availability, Permission, User
|
||||||
from schemas import (
|
from schemas import (
|
||||||
BookableSlot,
|
|
||||||
AvailableSlotsResponse,
|
|
||||||
BookingRequest,
|
|
||||||
AppointmentResponse,
|
AppointmentResponse,
|
||||||
|
AvailableSlotsResponse,
|
||||||
|
BookableSlot,
|
||||||
|
BookingRequest,
|
||||||
PaginatedAppointments,
|
PaginatedAppointments,
|
||||||
)
|
)
|
||||||
from shared_constants import SLOT_DURATION_MINUTES, MIN_ADVANCE_DAYS, MAX_ADVANCE_DAYS
|
from shared_constants import MAX_ADVANCE_DAYS, MIN_ADVANCE_DAYS, SLOT_DURATION_MINUTES
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/booking", tags=["booking"])
|
router = APIRouter(prefix="/api/booking", tags=["booking"])
|
||||||
|
|
||||||
|
|
@ -74,12 +74,14 @@ def _validate_booking_date(d: date) -> None:
|
||||||
if d < min_date:
|
if d < min_date:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot book for today or past dates. Earliest bookable date: {min_date}",
|
detail=f"Cannot book for today or past dates. "
|
||||||
|
f"Earliest bookable: {min_date}",
|
||||||
)
|
)
|
||||||
if d > max_date:
|
if d > max_date:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. Latest bookable: {max_date}",
|
detail=f"Cannot book more than {MAX_ADVANCE_DAYS} days ahead. "
|
||||||
|
f"Latest bookable: {max_date}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -92,8 +94,8 @@ def _expand_availability_to_slots(
|
||||||
|
|
||||||
for avail in availability_slots:
|
for avail in availability_slots:
|
||||||
# Create datetime objects for start and end
|
# Create datetime objects for start and end
|
||||||
current = datetime.combine(target_date, avail.start_time, tzinfo=timezone.utc)
|
current = datetime.combine(target_date, avail.start_time, tzinfo=UTC)
|
||||||
end = datetime.combine(target_date, avail.end_time, tzinfo=timezone.utc)
|
end = datetime.combine(target_date, avail.end_time, tzinfo=UTC)
|
||||||
|
|
||||||
# Generate 15-minute slots
|
# Generate 15-minute slots
|
||||||
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
|
while current + timedelta(minutes=SLOT_DURATION_MINUTES) <= end:
|
||||||
|
|
@ -128,12 +130,11 @@ async def get_available_slots(
|
||||||
all_slots = _expand_availability_to_slots(availability_slots, target_date)
|
all_slots = _expand_availability_to_slots(availability_slots, target_date)
|
||||||
|
|
||||||
# Get existing booked appointments for this date
|
# Get existing booked appointments for this date
|
||||||
day_start = datetime.combine(target_date, time.min, tzinfo=timezone.utc)
|
day_start = datetime.combine(target_date, time.min, tzinfo=UTC)
|
||||||
day_end = datetime.combine(target_date, time.max, tzinfo=timezone.utc)
|
day_end = datetime.combine(target_date, time.max, tzinfo=UTC)
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Appointment.slot_start)
|
select(Appointment.slot_start).where(
|
||||||
.where(
|
|
||||||
and_(
|
and_(
|
||||||
Appointment.slot_start >= day_start,
|
Appointment.slot_start >= day_start,
|
||||||
Appointment.slot_start <= day_end,
|
Appointment.slot_start <= day_end,
|
||||||
|
|
@ -145,8 +146,7 @@ async def get_available_slots(
|
||||||
|
|
||||||
# Filter out already booked slots
|
# Filter out already booked slots
|
||||||
available_slots = [
|
available_slots = [
|
||||||
slot for slot in all_slots
|
slot for slot in all_slots if slot.start_time not in booked_starts
|
||||||
if slot.start_time not in booked_starts
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return AvailableSlotsResponse(date=target_date, slots=available_slots)
|
return AvailableSlotsResponse(date=target_date, slots=available_slots)
|
||||||
|
|
@ -162,12 +162,13 @@ async def create_booking(
|
||||||
slot_date = request.slot_start.date()
|
slot_date = request.slot_start.date()
|
||||||
_validate_booking_date(slot_date)
|
_validate_booking_date(slot_date)
|
||||||
|
|
||||||
# Validate slot is on the correct minute boundary (derived from SLOT_DURATION_MINUTES)
|
# Validate slot is on the correct minute boundary
|
||||||
valid_minutes = _get_valid_minute_boundaries()
|
valid_minutes = _get_valid_minute_boundaries()
|
||||||
if request.slot_start.minute not in valid_minutes:
|
if request.slot_start.minute not in valid_minutes:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Slot start time must be on {SLOT_DURATION_MINUTES}-minute boundary (valid minutes: {valid_minutes})",
|
detail=f"Slot must be on {SLOT_DURATION_MINUTES}-minute boundary "
|
||||||
|
f"(valid minutes: {valid_minutes})",
|
||||||
)
|
)
|
||||||
if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
|
if request.slot_start.second != 0 or request.slot_start.microsecond != 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -177,11 +178,11 @@ async def create_booking(
|
||||||
|
|
||||||
# Verify slot falls within availability
|
# Verify slot falls within availability
|
||||||
slot_start_time = request.slot_start.time()
|
slot_start_time = request.slot_start.time()
|
||||||
slot_end_time = (request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)).time()
|
slot_end_dt = request.slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||||
|
slot_end_time = slot_end_dt.time()
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Availability)
|
select(Availability).where(
|
||||||
.where(
|
|
||||||
and_(
|
and_(
|
||||||
Availability.date == slot_date,
|
Availability.date == slot_date,
|
||||||
Availability.start_time <= slot_start_time,
|
Availability.start_time <= slot_start_time,
|
||||||
|
|
@ -192,9 +193,11 @@ async def create_booking(
|
||||||
matching_availability = result.scalar_one_or_none()
|
matching_availability = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not matching_availability:
|
if not matching_availability:
|
||||||
|
slot_str = request.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Selected slot at {request.slot_start.strftime('%Y-%m-%d %H:%M')} UTC is not within any available time ranges for {slot_date}. Please select a different time slot.",
|
detail=f"Selected slot at {slot_str} UTC is not within "
|
||||||
|
f"any available time ranges for {slot_date}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create the appointment
|
# Create the appointment
|
||||||
|
|
@ -216,8 +219,8 @@ async def create_booking(
|
||||||
await db.rollback()
|
await db.rollback()
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail="This slot has already been booked. Please select another slot.",
|
detail="This slot has already been booked. Select another slot.",
|
||||||
)
|
) from None
|
||||||
|
|
||||||
return _to_appointment_response(appointment, current_user.email)
|
return _to_appointment_response(appointment, current_user.email)
|
||||||
|
|
||||||
|
|
@ -242,20 +245,19 @@ async def get_my_appointments(
|
||||||
)
|
)
|
||||||
appointments = result.scalars().all()
|
appointments = result.scalars().all()
|
||||||
|
|
||||||
return [
|
return [_to_appointment_response(apt, current_user.email) for apt in appointments]
|
||||||
_to_appointment_response(apt, current_user.email)
|
|
||||||
for apt in appointments
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
@appointments_router.post(
|
||||||
|
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||||
|
)
|
||||||
async def cancel_my_appointment(
|
async def cancel_my_appointment(
|
||||||
appointment_id: int,
|
appointment_id: int,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
|
current_user: User = Depends(require_permission(Permission.CANCEL_OWN_APPOINTMENT)),
|
||||||
) -> AppointmentResponse:
|
) -> AppointmentResponse:
|
||||||
"""Cancel one of the current user's appointments."""
|
"""Cancel one of the current user's appointments."""
|
||||||
# Get the appointment with explicit eager loading of user relationship
|
# Get the appointment with eager loading of user relationship
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Appointment)
|
select(Appointment)
|
||||||
.options(joinedload(Appointment.user))
|
.options(joinedload(Appointment.user))
|
||||||
|
|
@ -266,31 +268,35 @@ async def cancel_my_appointment(
|
||||||
if not appointment:
|
if not appointment:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
detail=f"Appointment {appointment_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify ownership
|
# Verify ownership
|
||||||
if appointment.user_id != current_user.id:
|
if appointment.user_id != current_user.id:
|
||||||
raise HTTPException(status_code=403, detail="Cannot cancel another user's appointment")
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Cannot cancel another user's appointment",
|
||||||
|
)
|
||||||
|
|
||||||
# Check if already cancelled
|
# Check if already cancelled
|
||||||
if appointment.status != AppointmentStatus.BOOKED:
|
if appointment.status != AppointmentStatus.BOOKED:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if appointment is in the past
|
# Check if appointment is in the past
|
||||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
if appointment.slot_start <= datetime.now(UTC):
|
||||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||||
|
"already started or in the past",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cancel the appointment
|
# Cancel the appointment
|
||||||
appointment.status = AppointmentStatus.CANCELLED_BY_USER
|
appointment.status = AppointmentStatus.CANCELLED_BY_USER
|
||||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
appointment.cancelled_at = datetime.now(UTC)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(appointment)
|
await db.refresh(appointment)
|
||||||
|
|
@ -302,7 +308,9 @@ async def cancel_my_appointment(
|
||||||
# Admin Appointments Endpoints
|
# Admin Appointments Endpoints
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
admin_appointments_router = APIRouter(prefix="/api/admin/appointments", tags=["admin-appointments"])
|
admin_appointments_router = APIRouter(
|
||||||
|
prefix="/api/admin/appointments", tags=["admin-appointments"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_appointments_router.get("", response_model=PaginatedAppointments)
|
@admin_appointments_router.get("", response_model=PaginatedAppointments)
|
||||||
|
|
@ -344,14 +352,18 @@ async def get_all_appointments(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_appointments_router.post("/{appointment_id}/cancel", response_model=AppointmentResponse)
|
@admin_appointments_router.post(
|
||||||
|
"/{appointment_id}/cancel", response_model=AppointmentResponse
|
||||||
|
)
|
||||||
async def admin_cancel_appointment(
|
async def admin_cancel_appointment(
|
||||||
appointment_id: int,
|
appointment_id: int,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
_current_user: User = Depends(require_permission(Permission.CANCEL_ANY_APPOINTMENT)),
|
_current_user: User = Depends(
|
||||||
|
require_permission(Permission.CANCEL_ANY_APPOINTMENT)
|
||||||
|
),
|
||||||
) -> AppointmentResponse:
|
) -> AppointmentResponse:
|
||||||
"""Cancel any appointment (admin only)."""
|
"""Cancel any appointment (admin only)."""
|
||||||
# Get the appointment with explicit eager loading of user relationship
|
# Get the appointment with eager loading of user relationship
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Appointment)
|
select(Appointment)
|
||||||
.options(joinedload(Appointment.user))
|
.options(joinedload(Appointment.user))
|
||||||
|
|
@ -362,27 +374,28 @@ async def admin_cancel_appointment(
|
||||||
if not appointment:
|
if not appointment:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"Appointment with ID {appointment_id} not found. It may have been deleted or the ID is invalid.",
|
detail=f"Appointment {appointment_id} not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if already cancelled
|
# Check if already cancelled
|
||||||
if appointment.status != AppointmentStatus.BOOKED:
|
if appointment.status != AppointmentStatus.BOOKED:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot cancel appointment with status '{appointment.status.value}'"
|
detail=f"Cannot cancel: status is '{appointment.status.value}'",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if appointment is in the past
|
# Check if appointment is in the past
|
||||||
if appointment.slot_start <= datetime.now(timezone.utc):
|
if appointment.slot_start <= datetime.now(UTC):
|
||||||
appointment_time = appointment.slot_start.strftime('%Y-%m-%d %H:%M') + " UTC"
|
apt_time = appointment.slot_start.strftime("%Y-%m-%d %H:%M")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Cannot cancel appointment scheduled for {appointment_time} as it is in the past or has already started."
|
detail=f"Cannot cancel appointment at {apt_time} UTC: "
|
||||||
|
"already started or in the past",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Cancel the appointment
|
# Cancel the appointment
|
||||||
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
|
appointment.status = AppointmentStatus.CANCELLED_BY_ADMIN
|
||||||
appointment.cancelled_at = datetime.now(timezone.utc)
|
appointment.cancelled_at = datetime.now(UTC)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(appointment)
|
await db.refresh(appointment)
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
"""Counter routes."""
|
"""Counter routes."""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import Counter, User, CounterRecord, Permission
|
from models import Counter, CounterRecord, Permission, User
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,29 @@
|
||||||
"""Invite routes for public check, user invites, and admin management."""
|
"""Invite routes for public check, user invites, and admin management."""
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from datetime import UTC, datetime
|
||||||
from sqlalchemy import select, func, desc
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from sqlalchemy import desc, func, select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from invite_utils import generate_invite_identifier, normalize_identifier, is_valid_identifier_format
|
from invite_utils import (
|
||||||
from models import User, Invite, InviteStatus, Permission
|
generate_invite_identifier,
|
||||||
|
is_valid_identifier_format,
|
||||||
|
normalize_identifier,
|
||||||
|
)
|
||||||
|
from models import Invite, InviteStatus, Permission, User
|
||||||
from schemas import (
|
from schemas import (
|
||||||
|
AdminUserResponse,
|
||||||
InviteCheckResponse,
|
InviteCheckResponse,
|
||||||
InviteCreate,
|
InviteCreate,
|
||||||
InviteResponse,
|
InviteResponse,
|
||||||
UserInviteResponse,
|
|
||||||
PaginatedInviteRecords,
|
PaginatedInviteRecords,
|
||||||
AdminUserResponse,
|
UserInviteResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/invites", tags=["invites"])
|
router = APIRouter(prefix="/api/invites", tags=["invites"])
|
||||||
admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
|
admin_router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||||
|
|
||||||
|
|
@ -54,9 +58,7 @@ async def check_invite(
|
||||||
if not is_valid_identifier_format(normalized):
|
if not is_valid_identifier_format(normalized):
|
||||||
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
return InviteCheckResponse(valid=False, error="Invalid invite code format")
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(select(Invite).where(Invite.identifier == normalized))
|
||||||
select(Invite).where(Invite.identifier == normalized)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one_or_none()
|
invite = result.scalar_one_or_none()
|
||||||
|
|
||||||
# Return same error for not found, spent, and revoked to avoid information leakage
|
# Return same error for not found, spent, and revoked to avoid information leakage
|
||||||
|
|
@ -112,9 +114,7 @@ async def create_invite(
|
||||||
) -> InviteResponse:
|
) -> InviteResponse:
|
||||||
"""Create a new invite for a specified godfather user."""
|
"""Create a new invite for a specified godfather user."""
|
||||||
# Validate godfather exists
|
# Validate godfather exists
|
||||||
result = await db.execute(
|
result = await db.execute(select(User.id).where(User.id == data.godfather_id))
|
||||||
select(User.id).where(User.id == data.godfather_id)
|
|
||||||
)
|
|
||||||
godfather_id = result.scalar_one_or_none()
|
godfather_id = result.scalar_one_or_none()
|
||||||
if not godfather_id:
|
if not godfather_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -141,8 +141,8 @@ async def create_invite(
|
||||||
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
if attempt == MAX_INVITE_COLLISION_RETRIES - 1:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to generate unique invite code. Please try again.",
|
detail="Failed to generate unique invite code. Try again.",
|
||||||
)
|
) from None
|
||||||
|
|
||||||
if invite is None:
|
if invite is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -156,7 +156,9 @@ async def create_invite(
|
||||||
async def list_all_invites(
|
async def list_all_invites(
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
per_page: int = Query(10, ge=1, le=100),
|
per_page: int = Query(10, ge=1, le=100),
|
||||||
status_filter: str | None = Query(None, alias="status", description="Filter by status: ready, spent, revoked"),
|
status_filter: str | None = Query(
|
||||||
|
None, alias="status", description="Filter by status: ready, spent, revoked"
|
||||||
|
),
|
||||||
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
godfather_id: int | None = Query(None, description="Filter by godfather user ID"),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
_current_user: User = Depends(require_permission(Permission.MANAGE_INVITES)),
|
||||||
|
|
@ -175,8 +177,9 @@ async def list_all_invites(
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Invalid status: {status_filter}. Must be ready, spent, or revoked",
|
detail=f"Invalid status: {status_filter}. "
|
||||||
)
|
"Must be ready, spent, or revoked",
|
||||||
|
) from None
|
||||||
|
|
||||||
if godfather_id:
|
if godfather_id:
|
||||||
query = query.where(Invite.godfather_id == godfather_id)
|
query = query.where(Invite.godfather_id == godfather_id)
|
||||||
|
|
@ -224,11 +227,12 @@ async def revoke_invite(
|
||||||
if invite.status != InviteStatus.READY:
|
if invite.status != InviteStatus.READY:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Cannot revoke invite with status '{invite.status.value}'. Only READY invites can be revoked.",
|
detail=f"Cannot revoke invite with status '{invite.status.value}'. "
|
||||||
|
"Only READY invites can be revoked.",
|
||||||
)
|
)
|
||||||
|
|
||||||
invite.status = InviteStatus.REVOKED
|
invite.status = InviteStatus.REVOKED
|
||||||
invite.revoked_at = datetime.now(timezone.utc)
|
invite.revoked_at = datetime.now(UTC)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(invite)
|
await db.refresh(invite)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
"""Meta endpoints for shared constants."""
|
"""Meta endpoints for shared constants."""
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from models import Permission, InviteStatus, ROLE_ADMIN, ROLE_REGULAR
|
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, Permission
|
||||||
from schemas import ConstantsResponse
|
from schemas import ConstantsResponse
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/meta", tags=["meta"])
|
router = APIRouter(prefix="/api/meta", tags=["meta"])
|
||||||
|
|
@ -15,4 +16,3 @@ async def get_constants() -> ConstantsResponse:
|
||||||
roles=[ROLE_ADMIN, ROLE_REGULAR],
|
roles=[ROLE_ADMIN, ROLE_REGULAR],
|
||||||
invite_statuses=[s.value for s in InviteStatus],
|
invite_statuses=[s.value for s in InviteStatus],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,15 @@
|
||||||
"""Profile routes for user contact details."""
|
"""Profile routes for user contact details."""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import get_current_user
|
from auth import get_current_user
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, ROLE_REGULAR
|
from models import ROLE_REGULAR, User
|
||||||
from schemas import ProfileResponse, ProfileUpdate
|
from schemas import ProfileResponse, ProfileUpdate
|
||||||
from validation import validate_profile_fields
|
from validation import validate_profile_fields
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/profile", tags=["profile"])
|
router = APIRouter(prefix="/api/profile", tags=["profile"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -29,9 +29,7 @@ async def get_godfather_email(db: AsyncSession, godfather_id: int | None) -> str
|
||||||
"""Get the email of a godfather user by ID."""
|
"""Get the email of a godfather user by ID."""
|
||||||
if not godfather_id:
|
if not godfather_id:
|
||||||
return None
|
return None
|
||||||
result = await db.execute(
|
result = await db.execute(select(User.email).where(User.id == godfather_id))
|
||||||
select(User.email).where(User.id == godfather_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
"""Sum calculation routes."""
|
"""Sum calculation routes."""
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from auth import require_permission
|
from auth import require_permission
|
||||||
from database import get_db
|
from database import get_db
|
||||||
from models import User, SumRecord, Permission
|
from models import Permission, SumRecord, User
|
||||||
from schemas import SumRequest, SumResponse
|
from schemas import SumRequest, SumResponse
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Pydantic schemas for API request/response models."""
|
"""Pydantic schemas for API request/response models."""
|
||||||
from datetime import datetime, date, time
|
|
||||||
|
from datetime import date, datetime, time
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, EmailStr, field_validator
|
from pydantic import BaseModel, EmailStr, field_validator
|
||||||
|
|
@ -9,6 +10,7 @@ from shared_constants import NOTE_MAX_LENGTH
|
||||||
|
|
||||||
class UserCredentials(BaseModel):
|
class UserCredentials(BaseModel):
|
||||||
"""Base model for user email/password."""
|
"""Base model for user email/password."""
|
||||||
|
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
@ -19,6 +21,7 @@ UserLogin = UserCredentials
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
"""Response model for authenticated user info."""
|
"""Response model for authenticated user info."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
email: str
|
email: str
|
||||||
roles: list[str]
|
roles: list[str]
|
||||||
|
|
@ -27,6 +30,7 @@ class UserResponse(BaseModel):
|
||||||
|
|
||||||
class RegisterWithInvite(BaseModel):
|
class RegisterWithInvite(BaseModel):
|
||||||
"""Request model for registration with invite."""
|
"""Request model for registration with invite."""
|
||||||
|
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
invite_identifier: str
|
invite_identifier: str
|
||||||
|
|
@ -34,12 +38,14 @@ class RegisterWithInvite(BaseModel):
|
||||||
|
|
||||||
class SumRequest(BaseModel):
|
class SumRequest(BaseModel):
|
||||||
"""Request model for sum calculation."""
|
"""Request model for sum calculation."""
|
||||||
|
|
||||||
a: float
|
a: float
|
||||||
b: float
|
b: float
|
||||||
|
|
||||||
|
|
||||||
class SumResponse(BaseModel):
|
class SumResponse(BaseModel):
|
||||||
"""Response model for sum calculation."""
|
"""Response model for sum calculation."""
|
||||||
|
|
||||||
a: float
|
a: float
|
||||||
b: float
|
b: float
|
||||||
result: float
|
result: float
|
||||||
|
|
@ -47,6 +53,7 @@ class SumResponse(BaseModel):
|
||||||
|
|
||||||
class CounterRecordResponse(BaseModel):
|
class CounterRecordResponse(BaseModel):
|
||||||
"""Response model for a counter audit record."""
|
"""Response model for a counter audit record."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
user_email: str
|
user_email: str
|
||||||
value_before: int
|
value_before: int
|
||||||
|
|
@ -56,6 +63,7 @@ class CounterRecordResponse(BaseModel):
|
||||||
|
|
||||||
class SumRecordResponse(BaseModel):
|
class SumRecordResponse(BaseModel):
|
||||||
"""Response model for a sum audit record."""
|
"""Response model for a sum audit record."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
user_email: str
|
user_email: str
|
||||||
a: float
|
a: float
|
||||||
|
|
@ -69,6 +77,7 @@ RecordT = TypeVar("RecordT", bound=BaseModel)
|
||||||
|
|
||||||
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
||||||
"""Generic paginated response wrapper."""
|
"""Generic paginated response wrapper."""
|
||||||
|
|
||||||
records: list[RecordT]
|
records: list[RecordT]
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
|
|
@ -82,6 +91,7 @@ PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
||||||
|
|
||||||
class ProfileResponse(BaseModel):
|
class ProfileResponse(BaseModel):
|
||||||
"""Response model for profile data."""
|
"""Response model for profile data."""
|
||||||
|
|
||||||
contact_email: str | None
|
contact_email: str | None
|
||||||
telegram: str | None
|
telegram: str | None
|
||||||
signal: str | None
|
signal: str | None
|
||||||
|
|
@ -91,6 +101,7 @@ class ProfileResponse(BaseModel):
|
||||||
|
|
||||||
class ProfileUpdate(BaseModel):
|
class ProfileUpdate(BaseModel):
|
||||||
"""Request model for updating profile."""
|
"""Request model for updating profile."""
|
||||||
|
|
||||||
contact_email: str | None = None
|
contact_email: str | None = None
|
||||||
telegram: str | None = None
|
telegram: str | None = None
|
||||||
signal: str | None = None
|
signal: str | None = None
|
||||||
|
|
@ -99,6 +110,7 @@ class ProfileUpdate(BaseModel):
|
||||||
|
|
||||||
class InviteCheckResponse(BaseModel):
|
class InviteCheckResponse(BaseModel):
|
||||||
"""Response for invite check endpoint."""
|
"""Response for invite check endpoint."""
|
||||||
|
|
||||||
valid: bool
|
valid: bool
|
||||||
status: str | None = None
|
status: str | None = None
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
@ -106,11 +118,13 @@ class InviteCheckResponse(BaseModel):
|
||||||
|
|
||||||
class InviteCreate(BaseModel):
|
class InviteCreate(BaseModel):
|
||||||
"""Request model for creating an invite."""
|
"""Request model for creating an invite."""
|
||||||
|
|
||||||
godfather_id: int
|
godfather_id: int
|
||||||
|
|
||||||
|
|
||||||
class InviteResponse(BaseModel):
|
class InviteResponse(BaseModel):
|
||||||
"""Response model for invite data (admin view)."""
|
"""Response model for invite data (admin view)."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
identifier: str
|
identifier: str
|
||||||
godfather_id: int
|
godfather_id: int
|
||||||
|
|
@ -125,6 +139,7 @@ class InviteResponse(BaseModel):
|
||||||
|
|
||||||
class UserInviteResponse(BaseModel):
|
class UserInviteResponse(BaseModel):
|
||||||
"""Response model for a user's invite (simpler than admin view)."""
|
"""Response model for a user's invite (simpler than admin view)."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
identifier: str
|
identifier: str
|
||||||
status: str
|
status: str
|
||||||
|
|
@ -138,6 +153,7 @@ PaginatedInviteRecords = PaginatedResponse[InviteResponse]
|
||||||
|
|
||||||
class AdminUserResponse(BaseModel):
|
class AdminUserResponse(BaseModel):
|
||||||
"""Minimal user info for admin dropdowns."""
|
"""Minimal user info for admin dropdowns."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
email: str
|
email: str
|
||||||
|
|
||||||
|
|
@ -146,8 +162,10 @@ class AdminUserResponse(BaseModel):
|
||||||
# Availability Schemas
|
# Availability Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TimeSlot(BaseModel):
|
class TimeSlot(BaseModel):
|
||||||
"""A single time slot (start and end time)."""
|
"""A single time slot (start and end time)."""
|
||||||
|
|
||||||
start_time: time
|
start_time: time
|
||||||
end_time: time
|
end_time: time
|
||||||
|
|
||||||
|
|
@ -164,23 +182,27 @@ class TimeSlot(BaseModel):
|
||||||
|
|
||||||
class AvailabilityDay(BaseModel):
|
class AvailabilityDay(BaseModel):
|
||||||
"""Availability for a single day."""
|
"""Availability for a single day."""
|
||||||
|
|
||||||
date: date
|
date: date
|
||||||
slots: list[TimeSlot]
|
slots: list[TimeSlot]
|
||||||
|
|
||||||
|
|
||||||
class AvailabilityResponse(BaseModel):
|
class AvailabilityResponse(BaseModel):
|
||||||
"""Response model for availability query."""
|
"""Response model for availability query."""
|
||||||
|
|
||||||
days: list[AvailabilityDay]
|
days: list[AvailabilityDay]
|
||||||
|
|
||||||
|
|
||||||
class SetAvailabilityRequest(BaseModel):
|
class SetAvailabilityRequest(BaseModel):
|
||||||
"""Request to set availability for a specific date."""
|
"""Request to set availability for a specific date."""
|
||||||
|
|
||||||
date: date
|
date: date
|
||||||
slots: list[TimeSlot]
|
slots: list[TimeSlot]
|
||||||
|
|
||||||
|
|
||||||
class CopyAvailabilityRequest(BaseModel):
|
class CopyAvailabilityRequest(BaseModel):
|
||||||
"""Request to copy availability from one day to others."""
|
"""Request to copy availability from one day to others."""
|
||||||
|
|
||||||
source_date: date
|
source_date: date
|
||||||
target_dates: list[date]
|
target_dates: list[date]
|
||||||
|
|
||||||
|
|
@ -189,20 +211,24 @@ class CopyAvailabilityRequest(BaseModel):
|
||||||
# Booking Schemas
|
# Booking Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class BookableSlot(BaseModel):
|
class BookableSlot(BaseModel):
|
||||||
"""A bookable 15-minute slot."""
|
"""A bookable 15-minute slot."""
|
||||||
|
|
||||||
start_time: datetime
|
start_time: datetime
|
||||||
end_time: datetime
|
end_time: datetime
|
||||||
|
|
||||||
|
|
||||||
class AvailableSlotsResponse(BaseModel):
|
class AvailableSlotsResponse(BaseModel):
|
||||||
"""Response for available slots on a given date."""
|
"""Response for available slots on a given date."""
|
||||||
|
|
||||||
date: date
|
date: date
|
||||||
slots: list[BookableSlot]
|
slots: list[BookableSlot]
|
||||||
|
|
||||||
|
|
||||||
class BookingRequest(BaseModel):
|
class BookingRequest(BaseModel):
|
||||||
"""Request to book an appointment."""
|
"""Request to book an appointment."""
|
||||||
|
|
||||||
slot_start: datetime
|
slot_start: datetime
|
||||||
note: str | None = None
|
note: str | None = None
|
||||||
|
|
||||||
|
|
@ -216,6 +242,7 @@ class BookingRequest(BaseModel):
|
||||||
|
|
||||||
class AppointmentResponse(BaseModel):
|
class AppointmentResponse(BaseModel):
|
||||||
"""Response model for an appointment."""
|
"""Response model for an appointment."""
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
user_email: str
|
user_email: str
|
||||||
|
|
@ -234,8 +261,10 @@ PaginatedAppointments = PaginatedResponse[AppointmentResponse]
|
||||||
# Meta/Constants Schemas
|
# Meta/Constants Schemas
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ConstantsResponse(BaseModel):
|
class ConstantsResponse(BaseModel):
|
||||||
"""Response model for shared constants."""
|
"""Response model for shared constants."""
|
||||||
|
|
||||||
permissions: list[str]
|
permissions: list[str]
|
||||||
roles: list[str]
|
roles: list[str]
|
||||||
invite_statuses: list[str]
|
invite_statuses: list[str]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,21 @@
|
||||||
"""Seed the database with roles, permissions, and dev users."""
|
"""Seed the database with roles, permissions, and dev users."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database import engine, async_session, Base
|
|
||||||
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
|
||||||
from auth import get_password_hash
|
from auth import get_password_hash
|
||||||
|
from database import Base, async_session, engine
|
||||||
|
from models import (
|
||||||
|
ROLE_ADMIN,
|
||||||
|
ROLE_DEFINITIONS,
|
||||||
|
ROLE_REGULAR,
|
||||||
|
Permission,
|
||||||
|
Role,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
||||||
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
|
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
|
||||||
|
|
@ -14,7 +23,9 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
|
||||||
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
||||||
|
|
||||||
|
|
||||||
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
|
async def upsert_role(
|
||||||
|
db: AsyncSession, name: str, description: str, permissions: list[Permission]
|
||||||
|
) -> Role:
|
||||||
"""Create or update a role with the given permissions."""
|
"""Create or update a role with the given permissions."""
|
||||||
result = await db.execute(select(Role).where(Role.name == name))
|
result = await db.execute(select(Role).where(Role.name == name))
|
||||||
role = result.scalar_one_or_none()
|
role = result.scalar_one_or_none()
|
||||||
|
|
@ -35,7 +46,9 @@ async def upsert_role(db: AsyncSession, name: str, description: str, permissions
|
||||||
return role
|
return role
|
||||||
|
|
||||||
|
|
||||||
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
|
async def upsert_user(
|
||||||
|
db: AsyncSession, email: str, password: str, role_names: list[str]
|
||||||
|
) -> User:
|
||||||
"""Create or update a user with the given credentials and roles."""
|
"""Create or update a user with the given credentials and roles."""
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
"""Load shared constants from shared/constants.json."""
|
"""Load shared constants from shared/constants.json."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -10,4 +11,3 @@ SLOT_DURATION_MINUTES: int = _constants["booking"]["slotDurationMinutes"]
|
||||||
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
|
MIN_ADVANCE_DAYS: int = _constants["booking"]["minAdvanceDays"]
|
||||||
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
|
MAX_ADVANCE_DAYS: int = _constants["booking"]["maxAdvanceDays"]
|
||||||
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]
|
NOTE_MAX_LENGTH: int = _constants["booking"]["noteMaxLength"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,17 +7,17 @@ os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from auth import get_password_hash
|
||||||
from database import Base, get_db
|
from database import Base, get_db
|
||||||
from main import app
|
from main import app
|
||||||
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
|
||||||
from auth import get_password_hash
|
|
||||||
from tests.helpers import unique_email
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
TEST_DATABASE_URL = os.getenv(
|
TEST_DATABASE_URL = os.getenv(
|
||||||
"TEST_DATABASE_URL",
|
"TEST_DATABASE_URL",
|
||||||
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
|
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -91,7 +91,9 @@ async def create_user_with_roles(
|
||||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||||
role = result.scalar_one_or_none()
|
role = result.scalar_one_or_none()
|
||||||
if not role:
|
if not role:
|
||||||
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
|
raise ValueError(
|
||||||
|
f"Role '{role_name}' not found. Did you run setup_roles()?"
|
||||||
|
)
|
||||||
roles.append(role)
|
roles.append(role)
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ import uuid
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from models import User, Invite, InviteStatus
|
|
||||||
from invite_utils import generate_invite_identifier
|
from invite_utils import generate_invite_identifier
|
||||||
|
from models import Invite, InviteStatus, User
|
||||||
|
|
||||||
|
|
||||||
def unique_email(prefix: str = "test") -> str:
|
def unique_email(prefix: str = "test") -> str:
|
||||||
|
|
@ -67,7 +67,9 @@ async def create_invite_for_registration(db: AsyncSession, godfather_email: str)
|
||||||
godfather = result.scalar_one_or_none()
|
godfather = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not godfather:
|
if not godfather:
|
||||||
raise ValueError(f"Godfather user with email '{godfather_email}' not found. "
|
raise ValueError(
|
||||||
"Create the user first using create_user_with_roles().")
|
f"Godfather user with email '{godfather_email}' not found. "
|
||||||
|
"Create the user first using create_user_with_roles()."
|
||||||
|
)
|
||||||
|
|
||||||
return await create_invite_for_godfather(db, godfather.id)
|
return await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,13 @@
|
||||||
Note: Registration now requires an invite code. Tests that need to register
|
Note: Registration now requires an invite code. Tests that need to register
|
||||||
users will create invites first via the helper function.
|
users will create invites first via the helper function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from auth import COOKIE_NAME
|
from auth import COOKIE_NAME
|
||||||
from models import ROLE_REGULAR
|
from models import ROLE_REGULAR
|
||||||
from tests.helpers import unique_email, create_invite_for_godfather
|
|
||||||
from tests.conftest import create_user_with_roles
|
from tests.conftest import create_user_with_roles
|
||||||
|
from tests.helpers import create_invite_for_godfather, unique_email
|
||||||
|
|
||||||
|
|
||||||
# Registration tests (with invite)
|
# Registration tests (with invite)
|
||||||
|
|
@ -19,7 +20,9 @@ async def test_register_success(client_factory):
|
||||||
|
|
||||||
# Create godfather user and invite
|
# Create godfather user and invite
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("godfather"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("godfather"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
response = await client_factory.post(
|
response = await client_factory.post(
|
||||||
|
|
@ -49,7 +52,9 @@ async def test_register_duplicate_email(client_factory):
|
||||||
|
|
||||||
# Create godfather and two invites
|
# Create godfather and two invites
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
|
|
@ -80,7 +85,9 @@ async def test_register_duplicate_email(client_factory):
|
||||||
async def test_register_invalid_email(client_factory):
|
async def test_register_invalid_email(client_factory):
|
||||||
"""Cannot register with invalid email format."""
|
"""Cannot register with invalid email format."""
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
response = await client_factory.post(
|
response = await client_factory.post(
|
||||||
|
|
@ -138,7 +145,9 @@ async def test_login_success(client_factory):
|
||||||
email = unique_email("login")
|
email = unique_email("login")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
await client_factory.post(
|
await client_factory.post(
|
||||||
|
|
@ -167,7 +176,9 @@ async def test_login_wrong_password(client_factory):
|
||||||
email = unique_email("wrongpass")
|
email = unique_email("wrongpass")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
await client_factory.post(
|
await client_factory.post(
|
||||||
|
|
@ -221,7 +232,9 @@ async def test_get_me_success(client_factory):
|
||||||
email = unique_email("me")
|
email = unique_email("me")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg_response = await client_factory.post(
|
reg_response = await client_factory.post(
|
||||||
|
|
@ -255,7 +268,9 @@ async def test_get_me_no_cookie(client):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_me_invalid_cookie(client_factory):
|
async def test_get_me_invalid_cookie(client_factory):
|
||||||
"""Cannot get current user with invalid cookie."""
|
"""Cannot get current user with invalid cookie."""
|
||||||
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
|
async with client_factory.create(
|
||||||
|
cookies={COOKIE_NAME: "invalidtoken123"}
|
||||||
|
) as authed:
|
||||||
response = await authed.get("/api/auth/me")
|
response = await authed.get("/api/auth/me")
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
assert response.json()["detail"] == "Invalid authentication credentials"
|
assert response.json()["detail"] == "Invalid authentication credentials"
|
||||||
|
|
@ -277,7 +292,9 @@ async def test_cookie_from_register_works_for_me(client_factory):
|
||||||
email = unique_email("tokentest")
|
email = unique_email("tokentest")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg_response = await client_factory.post(
|
reg_response = await client_factory.post(
|
||||||
|
|
@ -303,7 +320,9 @@ async def test_cookie_from_login_works_for_me(client_factory):
|
||||||
email = unique_email("logintoken")
|
email = unique_email("logintoken")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
await client_factory.post(
|
await client_factory.post(
|
||||||
|
|
@ -335,7 +354,9 @@ async def test_multiple_users_isolated(client_factory):
|
||||||
email2 = unique_email("user2")
|
email2 = unique_email("user2")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
|
|
@ -377,7 +398,9 @@ async def test_password_is_hashed(client_factory):
|
||||||
email = unique_email("hashtest")
|
email = unique_email("hashtest")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
await client_factory.post(
|
await client_factory.post(
|
||||||
|
|
@ -401,7 +424,9 @@ async def test_case_sensitive_password(client_factory):
|
||||||
email = unique_email("casetest")
|
email = unique_email("casetest")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
await client_factory.post(
|
await client_factory.post(
|
||||||
|
|
@ -426,7 +451,9 @@ async def test_logout_success(client_factory):
|
||||||
email = unique_email("logout")
|
email = unique_email("logout")
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg_response = await client_factory.post(
|
reg_response = await client_factory.post(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ Availability API Tests
|
||||||
|
|
||||||
Tests for the admin availability management endpoints.
|
Tests for the admin availability management endpoints.
|
||||||
"""
|
"""
|
||||||
from datetime import date, time, timedelta
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,6 +21,7 @@ def in_days(n: int) -> date:
|
||||||
# Permission Tests
|
# Permission Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestAvailabilityPermissions:
|
class TestAvailabilityPermissions:
|
||||||
"""Test that only admins can access availability endpoints."""
|
"""Test that only admins can access availability endpoints."""
|
||||||
|
|
||||||
|
|
@ -44,7 +47,9 @@ class TestAvailabilityPermissions:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_get_availability(self, client_factory, regular_user):
|
async def test_regular_user_cannot_get_availability(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
"/api/admin/availability",
|
"/api/admin/availability",
|
||||||
|
|
@ -53,7 +58,9 @@ class TestAvailabilityPermissions:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_set_availability(self, client_factory, regular_user):
|
async def test_regular_user_cannot_set_availability(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.put(
|
response = await client.put(
|
||||||
"/api/admin/availability",
|
"/api/admin/availability",
|
||||||
|
|
@ -88,6 +95,7 @@ class TestAvailabilityPermissions:
|
||||||
# Set Availability Tests
|
# Set Availability Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestSetAvailability:
|
class TestSetAvailability:
|
||||||
"""Test setting availability for a date."""
|
"""Test setting availability for a date."""
|
||||||
|
|
||||||
|
|
@ -128,7 +136,9 @@ class TestSetAvailability:
|
||||||
assert len(data["slots"]) == 2
|
assert len(data["slots"]) == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_empty_slots_clears_availability(self, client_factory, admin_user):
|
async def test_set_empty_slots_clears_availability(
|
||||||
|
self, client_factory, admin_user
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
# First set some availability
|
# First set some availability
|
||||||
await client.put(
|
await client.put(
|
||||||
|
|
@ -162,7 +172,7 @@ class TestSetAvailability:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace with different slots
|
# Replace with different slots
|
||||||
response = await client.put(
|
await client.put(
|
||||||
"/api/admin/availability",
|
"/api/admin/availability",
|
||||||
json={
|
json={
|
||||||
"date": str(tomorrow()),
|
"date": str(tomorrow()),
|
||||||
|
|
@ -186,6 +196,7 @@ class TestSetAvailability:
|
||||||
# Validation Tests
|
# Validation Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestAvailabilityValidation:
|
class TestAvailabilityValidation:
|
||||||
"""Test validation rules for availability."""
|
"""Test validation rules for availability."""
|
||||||
|
|
||||||
|
|
@ -283,6 +294,7 @@ class TestAvailabilityValidation:
|
||||||
# Get Availability Tests
|
# Get Availability Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestGetAvailability:
|
class TestGetAvailability:
|
||||||
"""Test retrieving availability."""
|
"""Test retrieving availability."""
|
||||||
|
|
||||||
|
|
@ -360,6 +372,7 @@ class TestGetAvailability:
|
||||||
# Copy Availability Tests
|
# Copy Availability Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestCopyAvailability:
|
class TestCopyAvailability:
|
||||||
"""Test copying availability from one day to others."""
|
"""Test copying availability from one day to others."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ Booking API Tests
|
||||||
|
|
||||||
Tests for the user booking endpoints.
|
Tests for the user booking endpoints.
|
||||||
"""
|
"""
|
||||||
from datetime import date, datetime, timedelta, timezone
|
|
||||||
|
from datetime import UTC, date, datetime, timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models import Appointment, AppointmentStatus
|
from models import Appointment, AppointmentStatus
|
||||||
|
|
@ -21,11 +23,14 @@ def in_days(n: int) -> date:
|
||||||
# Permission Tests
|
# Permission Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestBookingPermissions:
|
class TestBookingPermissions:
|
||||||
"""Test that only regular users can book appointments."""
|
"""Test that only regular users can book appointments."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_can_get_slots(self, client_factory, regular_user, admin_user):
|
async def test_regular_user_can_get_slots(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Regular user can get available slots."""
|
"""Regular user can get available slots."""
|
||||||
# First, admin sets up availability
|
# First, admin sets up availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -39,12 +44,16 @@ class TestBookingPermissions:
|
||||||
|
|
||||||
# Regular user gets slots
|
# Regular user gets slots
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_can_book(self, client_factory, regular_user, admin_user):
|
async def test_regular_user_can_book(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Regular user can book an appointment."""
|
"""Regular user can book an appointment."""
|
||||||
# Admin sets up availability
|
# Admin sets up availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -69,7 +78,9 @@ class TestBookingPermissions:
|
||||||
async def test_admin_cannot_get_slots(self, client_factory, admin_user):
|
async def test_admin_cannot_get_slots(self, client_factory, admin_user):
|
||||||
"""Admin cannot access booking slots endpoint."""
|
"""Admin cannot access booking slots endpoint."""
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
@ -96,7 +107,9 @@ class TestBookingPermissions:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unauthenticated_cannot_get_slots(self, client):
|
async def test_unauthenticated_cannot_get_slots(self, client):
|
||||||
"""Unauthenticated user cannot get slots."""
|
"""Unauthenticated user cannot get slots."""
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -113,6 +126,7 @@ class TestBookingPermissions:
|
||||||
# Get Slots Tests
|
# Get Slots Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestGetSlots:
|
class TestGetSlots:
|
||||||
"""Test getting available booking slots."""
|
"""Test getting available booking slots."""
|
||||||
|
|
||||||
|
|
@ -120,7 +134,9 @@ class TestGetSlots:
|
||||||
async def test_get_slots_no_availability(self, client_factory, regular_user):
|
async def test_get_slots_no_availability(self, client_factory, regular_user):
|
||||||
"""Returns empty slots when no availability set."""
|
"""Returns empty slots when no availability set."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -128,7 +144,9 @@ class TestGetSlots:
|
||||||
assert data["slots"] == []
|
assert data["slots"] == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_slots_expands_to_15min(self, client_factory, regular_user, admin_user):
|
async def test_get_slots_expands_to_15min(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Availability is expanded into 15-minute slots."""
|
"""Availability is expanded into 15-minute slots."""
|
||||||
# Admin sets 1-hour availability
|
# Admin sets 1-hour availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -142,7 +160,9 @@ class TestGetSlots:
|
||||||
|
|
||||||
# User gets slots - should be 4 x 15-minute slots
|
# User gets slots - should be 4 x 15-minute slots
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -156,7 +176,9 @@ class TestGetSlots:
|
||||||
assert "10:00:00" in data["slots"][3]["end_time"]
|
assert "10:00:00" in data["slots"][3]["end_time"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_slots_excludes_booked(self, client_factory, regular_user, admin_user):
|
async def test_get_slots_excludes_booked(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Already booked slots are excluded from available slots."""
|
"""Already booked slots are excluded from available slots."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -176,7 +198,9 @@ class TestGetSlots:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get slots again - should have 3 left
|
# Get slots again - should have 3 left
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(tomorrow())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(tomorrow())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -189,6 +213,7 @@ class TestGetSlots:
|
||||||
# Booking Tests
|
# Booking Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestCreateBooking:
|
class TestCreateBooking:
|
||||||
"""Test creating bookings."""
|
"""Test creating bookings."""
|
||||||
|
|
||||||
|
|
@ -248,7 +273,9 @@ class TestCreateBooking:
|
||||||
assert data["note"] is None
|
assert data["note"] is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_double_book_slot(self, client_factory, regular_user, admin_user, alt_regular_user):
|
async def test_cannot_double_book_slot(
|
||||||
|
self, client_factory, regular_user, admin_user, alt_regular_user
|
||||||
|
):
|
||||||
"""Cannot book a slot that's already booked."""
|
"""Cannot book a slot that's already booked."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -279,7 +306,9 @@ class TestCreateBooking:
|
||||||
assert "already been booked" in response.json()["detail"]
|
assert "already been booked" in response.json()["detail"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_book_outside_availability(self, client_factory, regular_user, admin_user):
|
async def test_cannot_book_outside_availability(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Cannot book a slot outside of availability."""
|
"""Cannot book a slot outside of availability."""
|
||||||
# Admin sets availability for morning only
|
# Admin sets availability for morning only
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -306,6 +335,7 @@ class TestCreateBooking:
|
||||||
# Date Validation Tests
|
# Date Validation Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestBookingDateValidation:
|
class TestBookingDateValidation:
|
||||||
"""Test date validation for bookings."""
|
"""Test date validation for bookings."""
|
||||||
|
|
||||||
|
|
@ -319,7 +349,10 @@ class TestBookingDateValidation:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "past" in response.json()["detail"].lower() or "today" in response.json()["detail"].lower()
|
assert (
|
||||||
|
"past" in response.json()["detail"].lower()
|
||||||
|
or "today" in response.json()["detail"].lower()
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_book_past_date(self, client_factory, regular_user):
|
async def test_cannot_book_past_date(self, client_factory, regular_user):
|
||||||
|
|
@ -350,7 +383,9 @@ class TestBookingDateValidation:
|
||||||
async def test_cannot_get_slots_today(self, client_factory, regular_user):
|
async def test_cannot_get_slots_today(self, client_factory, regular_user):
|
||||||
"""Cannot get slots for today."""
|
"""Cannot get slots for today."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(date.today())})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(date.today())}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
@ -359,7 +394,9 @@ class TestBookingDateValidation:
|
||||||
"""Cannot get slots for past date."""
|
"""Cannot get slots for past date."""
|
||||||
yesterday = date.today() - timedelta(days=1)
|
yesterday = date.today() - timedelta(days=1)
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/booking/slots", params={"date": str(yesterday)})
|
response = await client.get(
|
||||||
|
"/api/booking/slots", params={"date": str(yesterday)}
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
@ -368,11 +405,14 @@ class TestBookingDateValidation:
|
||||||
# Time Validation Tests
|
# Time Validation Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestBookingTimeValidation:
|
class TestBookingTimeValidation:
|
||||||
"""Test time validation for bookings."""
|
"""Test time validation for bookings."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_slot_must_be_15min_boundary(self, client_factory, regular_user, admin_user):
|
async def test_slot_must_be_15min_boundary(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Slot start time must be on 15-minute boundary."""
|
"""Slot start time must be on 15-minute boundary."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -399,6 +439,7 @@ class TestBookingTimeValidation:
|
||||||
# Note Validation Tests
|
# Note Validation Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestBookingNoteValidation:
|
class TestBookingNoteValidation:
|
||||||
"""Test note validation for bookings."""
|
"""Test note validation for bookings."""
|
||||||
|
|
||||||
|
|
@ -426,7 +467,9 @@ class TestBookingNoteValidation:
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_note_exactly_144_chars(self, client_factory, regular_user, admin_user):
|
async def test_note_exactly_144_chars(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Note of exactly 144 characters is allowed."""
|
"""Note of exactly 144 characters is allowed."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -454,6 +497,7 @@ class TestBookingNoteValidation:
|
||||||
# User Appointments Tests
|
# User Appointments Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestUserAppointments:
|
class TestUserAppointments:
|
||||||
"""Test user appointments endpoints."""
|
"""Test user appointments endpoints."""
|
||||||
|
|
||||||
|
|
@ -467,7 +511,9 @@ class TestUserAppointments:
|
||||||
assert response.json() == []
|
assert response.json() == []
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_my_appointments_with_bookings(self, client_factory, regular_user, admin_user):
|
async def test_get_my_appointments_with_bookings(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Returns user's appointments."""
|
"""Returns user's appointments."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -502,7 +548,9 @@ class TestUserAppointments:
|
||||||
assert "Second" in notes
|
assert "Second" in notes
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_cannot_view_user_appointments(self, client_factory, admin_user):
|
async def test_admin_cannot_view_user_appointments(
|
||||||
|
self, client_factory, admin_user
|
||||||
|
):
|
||||||
"""Admin cannot access user appointments endpoint."""
|
"""Admin cannot access user appointments endpoint."""
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
response = await client.get("/api/appointments")
|
response = await client.get("/api/appointments")
|
||||||
|
|
@ -520,7 +568,9 @@ class TestCancelAppointment:
|
||||||
"""Test cancelling appointments."""
|
"""Test cancelling appointments."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancel_own_appointment(self, client_factory, regular_user, admin_user):
|
async def test_cancel_own_appointment(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""User can cancel their own appointment."""
|
"""User can cancel their own appointment."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -549,7 +599,9 @@ class TestCancelAppointment:
|
||||||
assert data["cancelled_at"] is not None
|
assert data["cancelled_at"] is not None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_cancel_others_appointment(self, client_factory, regular_user, alt_regular_user, admin_user):
|
async def test_cannot_cancel_others_appointment(
|
||||||
|
self, client_factory, regular_user, alt_regular_user, admin_user
|
||||||
|
):
|
||||||
"""User cannot cancel another user's appointment."""
|
"""User cannot cancel another user's appointment."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -577,7 +629,9 @@ class TestCancelAppointment:
|
||||||
assert "another user" in response.json()["detail"].lower()
|
assert "another user" in response.json()["detail"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_cancel_nonexistent_appointment(self, client_factory, regular_user):
|
async def test_cannot_cancel_nonexistent_appointment(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Returns 404 for non-existent appointment."""
|
"""Returns 404 for non-existent appointment."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.post("/api/appointments/99999/cancel")
|
response = await client.post("/api/appointments/99999/cancel")
|
||||||
|
|
@ -585,7 +639,9 @@ class TestCancelAppointment:
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
async def test_cannot_cancel_already_cancelled(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Cannot cancel an already cancelled appointment."""
|
"""Cannot cancel an already cancelled appointment."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -613,7 +669,9 @@ class TestCancelAppointment:
|
||||||
assert "cancelled_by_user" in response.json()["detail"]
|
assert "cancelled_by_user" in response.json()["detail"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_cannot_use_user_cancel_endpoint(self, client_factory, admin_user):
|
async def test_admin_cannot_use_user_cancel_endpoint(
|
||||||
|
self, client_factory, admin_user
|
||||||
|
):
|
||||||
"""Admin cannot use user cancel endpoint."""
|
"""Admin cannot use user cancel endpoint."""
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
response = await client.post("/api/appointments/1/cancel")
|
response = await client.post("/api/appointments/1/cancel")
|
||||||
|
|
@ -621,7 +679,9 @@ class TestCancelAppointment:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cancelled_slot_becomes_available(self, client_factory, regular_user, admin_user):
|
async def test_cancelled_slot_becomes_available(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""After cancelling, the slot becomes available again."""
|
"""After cancelling, the slot becomes available again."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -663,7 +723,7 @@ class TestCancelAppointment:
|
||||||
"""User cannot cancel a past appointment."""
|
"""User cannot cancel a past appointment."""
|
||||||
# Create a past appointment directly in DB
|
# Create a past appointment directly in DB
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||||
appointment = Appointment(
|
appointment = Appointment(
|
||||||
user_id=regular_user["user"]["id"],
|
user_id=regular_user["user"]["id"],
|
||||||
slot_start=past_time,
|
slot_start=past_time,
|
||||||
|
|
@ -687,11 +747,14 @@ class TestCancelAppointment:
|
||||||
# Admin Appointments Tests
|
# Admin Appointments Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestAdminViewAppointments:
|
class TestAdminViewAppointments:
|
||||||
"""Test admin viewing all appointments."""
|
"""Test admin viewing all appointments."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_can_view_all_appointments(self, client_factory, regular_user, admin_user):
|
async def test_admin_can_view_all_appointments(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Admin can view all appointments."""
|
"""Admin can view all appointments."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -725,7 +788,9 @@ class TestAdminViewAppointments:
|
||||||
assert any(apt["note"] == "Test" for apt in data["records"])
|
assert any(apt["note"] == "Test" for apt in data["records"])
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_view_all_appointments(self, client_factory, regular_user):
|
async def test_regular_user_cannot_view_all_appointments(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Regular user cannot access admin appointments endpoint."""
|
"""Regular user cannot access admin appointments endpoint."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/admin/appointments")
|
response = await client.get("/api/admin/appointments")
|
||||||
|
|
@ -743,7 +808,9 @@ class TestAdminCancelAppointment:
|
||||||
"""Test admin cancelling appointments."""
|
"""Test admin cancelling appointments."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_can_cancel_any_appointment(self, client_factory, regular_user, admin_user):
|
async def test_admin_can_cancel_any_appointment(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Admin can cancel any user's appointment."""
|
"""Admin can cancel any user's appointment."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -765,7 +832,9 @@ class TestAdminCancelAppointment:
|
||||||
|
|
||||||
# Admin cancels
|
# Admin cancels
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
response = await admin_client.post(
|
||||||
|
f"/api/admin/appointments/{apt_id}/cancel"
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -773,7 +842,9 @@ class TestAdminCancelAppointment:
|
||||||
assert data["cancelled_at"] is not None
|
assert data["cancelled_at"] is not None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_use_admin_cancel(self, client_factory, regular_user, admin_user):
|
async def test_regular_user_cannot_use_admin_cancel(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Regular user cannot use admin cancel endpoint."""
|
"""Regular user cannot use admin cancel endpoint."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -799,7 +870,9 @@ class TestAdminCancelAppointment:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_cancel_nonexistent_appointment(self, client_factory, admin_user):
|
async def test_admin_cancel_nonexistent_appointment(
|
||||||
|
self, client_factory, admin_user
|
||||||
|
):
|
||||||
"""Returns 404 for non-existent appointment."""
|
"""Returns 404 for non-existent appointment."""
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
response = await client.post("/api/admin/appointments/99999/cancel")
|
response = await client.post("/api/admin/appointments/99999/cancel")
|
||||||
|
|
@ -807,7 +880,9 @@ class TestAdminCancelAppointment:
|
||||||
assert response.status_code == 404
|
assert response.status_code == 404
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_cannot_cancel_already_cancelled(self, client_factory, regular_user, admin_user):
|
async def test_admin_cannot_cancel_already_cancelled(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Admin cannot cancel an already cancelled appointment."""
|
"""Admin cannot cancel an already cancelled appointment."""
|
||||||
# Admin sets availability
|
# Admin sets availability
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
|
|
@ -832,17 +907,21 @@ class TestAdminCancelAppointment:
|
||||||
|
|
||||||
# Admin tries to cancel again
|
# Admin tries to cancel again
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
response = await admin_client.post(
|
||||||
|
f"/api/admin/appointments/{apt_id}/cancel"
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "cancelled_by_user" in response.json()["detail"]
|
assert "cancelled_by_user" in response.json()["detail"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_cannot_cancel_past_appointment(self, client_factory, regular_user, admin_user):
|
async def test_admin_cannot_cancel_past_appointment(
|
||||||
|
self, client_factory, regular_user, admin_user
|
||||||
|
):
|
||||||
"""Admin cannot cancel a past appointment."""
|
"""Admin cannot cancel a past appointment."""
|
||||||
# Create a past appointment directly in DB
|
# Create a past appointment directly in DB
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||||
appointment = Appointment(
|
appointment = Appointment(
|
||||||
user_id=regular_user["user"]["id"],
|
user_id=regular_user["user"]["id"],
|
||||||
slot_start=past_time,
|
slot_start=past_time,
|
||||||
|
|
@ -856,8 +935,9 @@ class TestAdminCancelAppointment:
|
||||||
|
|
||||||
# Admin tries to cancel
|
# Admin tries to cancel
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as admin_client:
|
||||||
response = await admin_client.post(f"/api/admin/appointments/{apt_id}/cancel")
|
response = await admin_client.post(
|
||||||
|
f"/api/admin/appointments/{apt_id}/cancel"
|
||||||
|
)
|
||||||
|
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert "past" in response.json()["detail"].lower()
|
assert "past" in response.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,13 @@
|
||||||
|
|
||||||
Note: Registration now requires an invite code.
|
Note: Registration now requires an invite code.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from auth import COOKIE_NAME
|
from auth import COOKIE_NAME
|
||||||
from models import ROLE_REGULAR
|
from models import ROLE_REGULAR
|
||||||
from tests.helpers import unique_email, create_invite_for_godfather
|
|
||||||
from tests.conftest import create_user_with_roles
|
from tests.conftest import create_user_with_roles
|
||||||
|
from tests.helpers import create_invite_for_godfather, unique_email
|
||||||
|
|
||||||
|
|
||||||
# Protected endpoint tests - without auth
|
# Protected endpoint tests - without auth
|
||||||
|
|
@ -41,7 +42,9 @@ async def test_increment_counter_invalid_cookie(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_counter_authenticated(client_factory):
|
async def test_get_counter_authenticated(client_factory):
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg = await client_factory.post(
|
reg = await client_factory.post(
|
||||||
|
|
@ -64,7 +67,9 @@ async def test_get_counter_authenticated(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_increment_counter(client_factory):
|
async def test_increment_counter(client_factory):
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg = await client_factory.post(
|
reg = await client_factory.post(
|
||||||
|
|
@ -91,7 +96,9 @@ async def test_increment_counter(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_increment_counter_multiple(client_factory):
|
async def test_increment_counter_multiple(client_factory):
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg = await client_factory.post(
|
reg = await client_factory.post(
|
||||||
|
|
@ -120,7 +127,9 @@ async def test_increment_counter_multiple(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_counter_after_increment(client_factory):
|
async def test_get_counter_after_increment(client_factory):
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
reg = await client_factory.post(
|
reg = await client_factory.post(
|
||||||
|
|
@ -149,7 +158,9 @@ async def test_get_counter_after_increment(client_factory):
|
||||||
async def test_counter_shared_between_users(client_factory):
|
async def test_counter_shared_between_users(client_factory):
|
||||||
# Create godfather and invites for two users
|
# Create godfather and invites for two users
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,23 @@
|
||||||
"""Tests for invite functionality."""
|
"""Tests for invite functionality."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from invite_utils import (
|
from invite_utils import (
|
||||||
generate_invite_identifier,
|
|
||||||
normalize_identifier,
|
|
||||||
is_valid_identifier_format,
|
|
||||||
BIP39_WORDS,
|
BIP39_WORDS,
|
||||||
|
generate_invite_identifier,
|
||||||
|
is_valid_identifier_format,
|
||||||
|
normalize_identifier,
|
||||||
)
|
)
|
||||||
from models import Invite, InviteStatus, User, ROLE_REGULAR
|
from models import ROLE_REGULAR, Invite, InviteStatus, User
|
||||||
from tests.helpers import unique_email
|
|
||||||
from tests.conftest import create_user_with_roles
|
from tests.conftest import create_user_with_roles
|
||||||
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Invite Utils Tests
|
# Invite Utils Tests
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
def test_bip39_words_loaded():
|
def test_bip39_words_loaded():
|
||||||
"""BIP39 word list should have exactly 2048 words."""
|
"""BIP39 word list should have exactly 2048 words."""
|
||||||
assert len(BIP39_WORDS) == 2048
|
assert len(BIP39_WORDS) == 2048
|
||||||
|
|
@ -89,6 +90,7 @@ def test_is_valid_identifier_format_invalid():
|
||||||
# Invite Model Tests
|
# Invite Model Tests
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_invite(client_factory):
|
async def test_create_invite(client_factory):
|
||||||
"""Can create an invite with godfather."""
|
"""Can create an invite with godfather."""
|
||||||
|
|
@ -173,7 +175,7 @@ async def test_invite_unique_identifier(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invite_status_transitions(client_factory):
|
async def test_invite_status_transitions(client_factory):
|
||||||
"""Invite status can be changed."""
|
"""Invite status can be changed."""
|
||||||
from datetime import datetime, UTC
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(
|
godfather = await create_user_with_roles(
|
||||||
|
|
@ -206,7 +208,7 @@ async def test_invite_status_transitions(client_factory):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invite_revoke(client_factory):
|
async def test_invite_revoke(client_factory):
|
||||||
"""Invite can be revoked."""
|
"""Invite can be revoked."""
|
||||||
from datetime import datetime, UTC
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(
|
godfather = await create_user_with_roles(
|
||||||
|
|
@ -236,6 +238,7 @@ async def test_invite_revoke(client_factory):
|
||||||
# User Godfather Tests
|
# User Godfather Tests
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_godfather_relationship(client_factory):
|
async def test_user_godfather_relationship(client_factory):
|
||||||
"""User can have a godfather."""
|
"""User can have a godfather."""
|
||||||
|
|
@ -254,9 +257,7 @@ async def test_user_godfather_relationship(client_factory):
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Query user fresh
|
# Query user fresh
|
||||||
result = await db.execute(
|
result = await db.execute(select(User).where(User.id == user.id))
|
||||||
select(User).where(User.id == user.id)
|
|
||||||
)
|
|
||||||
loaded_user = result.scalar_one()
|
loaded_user = result.scalar_one()
|
||||||
|
|
||||||
assert loaded_user.godfather_id == godfather.id
|
assert loaded_user.godfather_id == godfather.id
|
||||||
|
|
@ -280,6 +281,7 @@ async def test_user_without_godfather(client_factory):
|
||||||
# Admin Create Invite API Tests (Phase 2)
|
# Admin Create Invite API Tests (Phase 2)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
|
async def test_admin_can_create_invite(client_factory, admin_user, regular_user):
|
||||||
"""Admin can create an invite for a regular user."""
|
"""Admin can create an invite for a regular user."""
|
||||||
|
|
@ -387,9 +389,7 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
||||||
|
|
||||||
# Query from DB
|
# Query from DB
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||||
select(Invite).where(Invite.id == invite_id)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one()
|
invite = result.scalar_one()
|
||||||
|
|
||||||
assert invite.identifier == data["identifier"]
|
assert invite.identifier == data["identifier"]
|
||||||
|
|
@ -398,7 +398,9 @@ async def test_created_invite_persisted_in_db(client_factory, admin_user, regula
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_invite_retries_on_collision(client_factory, admin_user, regular_user):
|
async def test_create_invite_retries_on_collision(
|
||||||
|
client_factory, admin_user, regular_user
|
||||||
|
):
|
||||||
"""Create invite retries with new identifier on collision."""
|
"""Create invite retries with new identifier on collision."""
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
@ -419,6 +421,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
||||||
|
|
||||||
# Mock generator to first return the same identifier (collision), then a new one
|
# Mock generator to first return the same identifier (collision), then a new one
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
def mock_generator():
|
def mock_generator():
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|
@ -426,7 +429,9 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
||||||
return identifier1 # Will collide
|
return identifier1 # Will collide
|
||||||
return f"unique-word-{call_count:02d}" # Won't collide
|
return f"unique-word-{call_count:02d}" # Won't collide
|
||||||
|
|
||||||
with patch("routes.invites.generate_invite_identifier", side_effect=mock_generator):
|
with patch(
|
||||||
|
"routes.invites.generate_invite_identifier", side_effect=mock_generator
|
||||||
|
):
|
||||||
response2 = await client.post(
|
response2 = await client.post(
|
||||||
"/api/admin/invites",
|
"/api/admin/invites",
|
||||||
json={"godfather_id": godfather.id},
|
json={"godfather_id": godfather.id},
|
||||||
|
|
@ -442,6 +447,7 @@ async def test_create_invite_retries_on_collision(client_factory, admin_user, re
|
||||||
# Invite Check API Tests (Phase 3)
|
# Invite Check API Tests (Phase 3)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_invite_valid(client_factory, admin_user, regular_user):
|
async def test_check_invite_valid(client_factory, admin_user, regular_user):
|
||||||
"""Check endpoint returns valid=True for READY invite."""
|
"""Check endpoint returns valid=True for READY invite."""
|
||||||
|
|
@ -509,7 +515,9 @@ async def test_check_invite_invalid_format(client_factory):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_invite_spent_returns_not_found(client_factory, admin_user, regular_user):
|
async def test_check_invite_spent_returns_not_found(
|
||||||
|
client_factory, admin_user, regular_user
|
||||||
|
):
|
||||||
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
|
"""Check endpoint returns same error for spent invite as for non-existent (no info leakage)."""
|
||||||
# Create invite
|
# Create invite
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
|
|
@ -547,9 +555,11 @@ async def test_check_invite_spent_returns_not_found(client_factory, admin_user,
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_check_invite_revoked_returns_not_found(client_factory, admin_user, regular_user):
|
async def test_check_invite_revoked_returns_not_found(
|
||||||
|
client_factory, admin_user, regular_user
|
||||||
|
):
|
||||||
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
|
"""Check endpoint returns same error for revoked invite as for non-existent (no info leakage)."""
|
||||||
from datetime import datetime, UTC
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
# Create invite
|
# Create invite
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
|
|
@ -613,6 +623,7 @@ async def test_check_invite_case_insensitive(client_factory, admin_user, regular
|
||||||
# Register with Invite Tests (Phase 3)
|
# Register with Invite Tests (Phase 3)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
|
async def test_register_with_valid_invite(client_factory, admin_user, regular_user):
|
||||||
"""Can register with valid invite code."""
|
"""Can register with valid invite code."""
|
||||||
|
|
@ -681,9 +692,7 @@ async def test_register_marks_invite_spent(client_factory, admin_user, regular_u
|
||||||
|
|
||||||
# Check invite status
|
# Check invite status
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||||
select(Invite).where(Invite.id == invite_id)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one()
|
invite = result.scalar_one()
|
||||||
|
|
||||||
assert invite.status == InviteStatus.SPENT
|
assert invite.status == InviteStatus.SPENT
|
||||||
|
|
@ -723,9 +732,7 @@ async def test_register_sets_godfather(client_factory, admin_user, regular_user)
|
||||||
|
|
||||||
# Check user's godfather
|
# Check user's godfather
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(select(User).where(User.email == new_email))
|
||||||
select(User).where(User.email == new_email)
|
|
||||||
)
|
|
||||||
new_user = result.scalar_one()
|
new_user = result.scalar_one()
|
||||||
|
|
||||||
assert new_user.godfather_id == godfather_id
|
assert new_user.godfather_id == godfather_id
|
||||||
|
|
@ -794,7 +801,7 @@ async def test_register_with_spent_invite(client_factory, admin_user, regular_us
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
|
async def test_register_with_revoked_invite(client_factory, admin_user, regular_user):
|
||||||
"""Cannot register with revoked invite."""
|
"""Cannot register with revoked invite."""
|
||||||
from datetime import datetime, UTC
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
# Create invite
|
# Create invite
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
|
|
@ -814,9 +821,7 @@ async def test_register_with_revoked_invite(client_factory, admin_user, regular_
|
||||||
|
|
||||||
# Revoke invite directly in DB
|
# Revoke invite directly in DB
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
result = await db.execute(
|
result = await db.execute(select(Invite).where(Invite.id == invite_id))
|
||||||
select(Invite).where(Invite.id == invite_id)
|
|
||||||
)
|
|
||||||
invite = result.scalar_one()
|
invite = result.scalar_one()
|
||||||
invite.status = InviteStatus.REVOKED
|
invite.status = InviteStatus.REVOKED
|
||||||
invite.revoked_at = datetime.now(UTC)
|
invite.revoked_at = datetime.now(UTC)
|
||||||
|
|
@ -904,6 +909,7 @@ async def test_register_sets_auth_cookie(client_factory, admin_user, regular_use
|
||||||
# User Invites API Tests (Phase 4)
|
# User Invites API Tests (Phase 4)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
|
async def test_regular_user_can_list_invites(client_factory, admin_user, regular_user):
|
||||||
"""Regular user can list their own invites."""
|
"""Regular user can list their own invites."""
|
||||||
|
|
@ -941,7 +947,9 @@ async def test_user_with_no_invites_gets_empty_list(client_factory, regular_user
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_spent_invite_shows_used_by_email(client_factory, admin_user, regular_user):
|
async def test_spent_invite_shows_used_by_email(
|
||||||
|
client_factory, admin_user, regular_user
|
||||||
|
):
|
||||||
"""Spent invite shows who used it."""
|
"""Spent invite shows who used it."""
|
||||||
# Create invite for regular user
|
# Create invite for regular user
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
|
|
@ -1002,6 +1010,7 @@ async def test_unauthenticated_cannot_list_invites(client_factory):
|
||||||
# Admin Invite Management Tests (Phase 5)
|
# Admin Invite Management Tests (Phase 5)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
|
async def test_admin_can_list_all_invites(client_factory, admin_user, regular_user):
|
||||||
"""Admin can list all invites."""
|
"""Admin can list all invites."""
|
||||||
|
|
@ -1153,4 +1162,3 @@ async def test_regular_user_cannot_access_admin_invites(client_factory, regular_
|
||||||
# Revoke
|
# Revoke
|
||||||
response = await client.post("/api/admin/invites/1/revoke")
|
response = await client.post("/api/admin/invites/1/revoke")
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,15 +7,16 @@ These tests verify that:
|
||||||
3. Unauthenticated users are denied access (401)
|
3. Unauthenticated users are denied access (401)
|
||||||
4. The permission system cannot be bypassed
|
4. The permission system cannot be bypassed
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from models import Permission
|
from models import Permission
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Role Assignment Tests
|
# Role Assignment Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestRoleAssignment:
|
class TestRoleAssignment:
|
||||||
"""Test that roles are properly assigned and returned."""
|
"""Test that roles are properly assigned and returned."""
|
||||||
|
|
||||||
|
|
@ -40,7 +41,9 @@ class TestRoleAssignment:
|
||||||
assert "regular" not in data["roles"]
|
assert "regular" not in data["roles"]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user):
|
async def test_regular_user_has_correct_permissions(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/auth/me")
|
response = await client.get("/api/auth/me")
|
||||||
|
|
||||||
|
|
@ -72,7 +75,9 @@ class TestRoleAssignment:
|
||||||
assert Permission.USE_SUM.value not in permissions
|
assert Permission.USE_SUM.value not in permissions
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles):
|
async def test_user_with_no_roles_has_no_permissions(
|
||||||
|
self, client_factory, user_no_roles
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||||
response = await client.get("/api/auth/me")
|
response = await client.get("/api/auth/me")
|
||||||
|
|
||||||
|
|
@ -85,6 +90,7 @@ class TestRoleAssignment:
|
||||||
# Counter Endpoint Access Tests
|
# Counter Endpoint Access Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestCounterAccess:
|
class TestCounterAccess:
|
||||||
"""Test access control for counter endpoints."""
|
"""Test access control for counter endpoints."""
|
||||||
|
|
||||||
|
|
@ -97,7 +103,9 @@ class TestCounterAccess:
|
||||||
assert "value" in response.json()
|
assert "value" in response.json()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_can_increment_counter(self, client_factory, regular_user):
|
async def test_regular_user_can_increment_counter(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.post("/api/counter/increment")
|
response = await client.post("/api/counter/increment")
|
||||||
|
|
||||||
|
|
@ -122,7 +130,9 @@ class TestCounterAccess:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles):
|
async def test_user_without_roles_cannot_view_counter(
|
||||||
|
self, client_factory, user_no_roles
|
||||||
|
):
|
||||||
"""Users with no roles should be forbidden."""
|
"""Users with no roles should be forbidden."""
|
||||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||||
response = await client.get("/api/counter")
|
response = await client.get("/api/counter")
|
||||||
|
|
@ -146,6 +156,7 @@ class TestCounterAccess:
|
||||||
# Sum Endpoint Access Tests
|
# Sum Endpoint Access Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestSumAccess:
|
class TestSumAccess:
|
||||||
"""Test access control for sum endpoint."""
|
"""Test access control for sum endpoint."""
|
||||||
|
|
||||||
|
|
@ -173,7 +184,9 @@ class TestSumAccess:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles):
|
async def test_user_without_roles_cannot_use_sum(
|
||||||
|
self, client_factory, user_no_roles
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/sum",
|
"/api/sum",
|
||||||
|
|
@ -195,6 +208,7 @@ class TestSumAccess:
|
||||||
# Audit Endpoint Access Tests
|
# Audit Endpoint Access Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestAuditAccess:
|
class TestAuditAccess:
|
||||||
"""Test access control for audit endpoints."""
|
"""Test access control for audit endpoints."""
|
||||||
|
|
||||||
|
|
@ -219,7 +233,9 @@ class TestAuditAccess:
|
||||||
assert "total" in data
|
assert "total" in data
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user):
|
async def test_regular_user_cannot_view_counter_audit(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Regular users should be forbidden from audit endpoints."""
|
"""Regular users should be forbidden from audit endpoints."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/audit/counter")
|
response = await client.get("/api/audit/counter")
|
||||||
|
|
@ -228,7 +244,9 @@ class TestAuditAccess:
|
||||||
assert "permission" in response.json()["detail"].lower()
|
assert "permission" in response.json()["detail"].lower()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user):
|
async def test_regular_user_cannot_view_sum_audit(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Regular users should be forbidden from audit endpoints."""
|
"""Regular users should be forbidden from audit endpoints."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/audit/sum")
|
response = await client.get("/api/audit/sum")
|
||||||
|
|
@ -236,7 +254,9 @@ class TestAuditAccess:
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles):
|
async def test_user_without_roles_cannot_view_audit(
|
||||||
|
self, client_factory, user_no_roles
|
||||||
|
):
|
||||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||||
response = await client.get("/api/audit/counter")
|
response = await client.get("/api/audit/counter")
|
||||||
|
|
||||||
|
|
@ -257,6 +277,7 @@ class TestAuditAccess:
|
||||||
# Offensive Security Tests - Bypass Attempts
|
# Offensive Security Tests - Bypass Attempts
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestSecurityBypassAttempts:
|
class TestSecurityBypassAttempts:
|
||||||
"""
|
"""
|
||||||
Offensive tests that attempt to bypass security controls.
|
Offensive tests that attempt to bypass security controls.
|
||||||
|
|
@ -264,7 +285,9 @@ class TestSecurityBypassAttempts:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user):
|
async def test_cannot_access_audit_with_forged_role_claim(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Attempt to access audit by somehow claiming admin role.
|
Attempt to access audit by somehow claiming admin role.
|
||||||
The server should verify roles from DB, not trust client claims.
|
The server should verify roles from DB, not trust client claims.
|
||||||
|
|
@ -287,14 +310,18 @@ class TestSecurityBypassAttempts:
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user):
|
async def test_cannot_access_with_tampered_token(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Test that tokens signed with wrong key are rejected."""
|
"""Test that tokens signed with wrong key are rejected."""
|
||||||
# Take a valid token and modify it
|
# Take a valid token and modify it
|
||||||
original_token = regular_user["cookies"].get("auth_token", "")
|
original_token = regular_user["cookies"].get("auth_token", "")
|
||||||
if original_token:
|
if original_token:
|
||||||
tampered_token = original_token[:-5] + "XXXXX"
|
tampered_token = original_token[:-5] + "XXXXX"
|
||||||
|
|
||||||
async with client_factory.create(cookies={"auth_token": tampered_token}) as client:
|
async with client_factory.create(
|
||||||
|
cookies={"auth_token": tampered_token}
|
||||||
|
) as client:
|
||||||
response = await client.get("/api/counter")
|
response = await client.get("/api/counter")
|
||||||
|
|
||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
@ -305,12 +332,14 @@ class TestSecurityBypassAttempts:
|
||||||
Test that new registrations cannot claim admin role.
|
Test that new registrations cannot claim admin role.
|
||||||
New users should only get 'regular' role by default.
|
New users should only get 'regular' role by default.
|
||||||
"""
|
"""
|
||||||
from tests.helpers import unique_email, create_invite_for_godfather
|
|
||||||
from tests.conftest import create_user_with_roles
|
|
||||||
from models import ROLE_REGULAR
|
from models import ROLE_REGULAR
|
||||||
|
from tests.conftest import create_user_with_roles
|
||||||
|
from tests.helpers import create_invite_for_godfather, unique_email
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
godfather = await create_user_with_roles(db, unique_email("gf"), "pass123", [ROLE_REGULAR])
|
godfather = await create_user_with_roles(
|
||||||
|
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||||
|
)
|
||||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||||
|
|
||||||
response = await client_factory.post(
|
response = await client_factory.post(
|
||||||
|
|
@ -341,15 +370,17 @@ class TestSecurityBypassAttempts:
|
||||||
If a user is deleted, their token should no longer work.
|
If a user is deleted, their token should no longer work.
|
||||||
This tests that tokens are validated against current DB state.
|
This tests that tokens are validated against current DB state.
|
||||||
"""
|
"""
|
||||||
from tests.helpers import unique_email
|
|
||||||
from sqlalchemy import delete
|
from sqlalchemy import delete
|
||||||
|
|
||||||
from models import User
|
from models import User
|
||||||
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
email = unique_email("deleted")
|
email = unique_email("deleted")
|
||||||
|
|
||||||
# Create and login user
|
# Create and login user
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
from tests.conftest import create_user_with_roles
|
from tests.conftest import create_user_with_roles
|
||||||
|
|
||||||
user = await create_user_with_roles(db, email, "password123", ["regular"])
|
user = await create_user_with_roles(db, email, "password123", ["regular"])
|
||||||
user_id = user.id
|
user_id = user.id
|
||||||
|
|
||||||
|
|
@ -376,15 +407,17 @@ class TestSecurityBypassAttempts:
|
||||||
If a user's role is changed, the change should be reflected
|
If a user's role is changed, the change should be reflected
|
||||||
in subsequent requests (no stale permission cache).
|
in subsequent requests (no stale permission cache).
|
||||||
"""
|
"""
|
||||||
from tests.helpers import unique_email
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from models import User, Role
|
|
||||||
|
from models import Role, User
|
||||||
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
email = unique_email("rolechange")
|
email = unique_email("rolechange")
|
||||||
|
|
||||||
# Create regular user
|
# Create regular user
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
from tests.conftest import create_user_with_roles
|
from tests.conftest import create_user_with_roles
|
||||||
|
|
||||||
await create_user_with_roles(db, email, "password123", ["regular"])
|
await create_user_with_roles(db, email, "password123", ["regular"])
|
||||||
|
|
||||||
login_response = await client_factory.post(
|
login_response = await client_factory.post(
|
||||||
|
|
@ -406,10 +439,7 @@ class TestSecurityBypassAttempts:
|
||||||
result = await db.execute(select(Role).where(Role.name == "admin"))
|
result = await db.execute(select(Role).where(Role.name == "admin"))
|
||||||
admin_role = result.scalar_one()
|
admin_role = result.scalar_one()
|
||||||
|
|
||||||
result = await db.execute(select(Role).where(Role.name == "regular"))
|
user.roles = [admin_role] # Replace roles with admin only
|
||||||
regular_role = result.scalar_one()
|
|
||||||
|
|
||||||
user.roles = [admin_role] # Remove regular, add admin
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
# Now should have audit access but not counter access
|
# Now should have audit access but not counter access
|
||||||
|
|
@ -422,6 +452,7 @@ class TestSecurityBypassAttempts:
|
||||||
# Audit Record Tests
|
# Audit Record Tests
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
class TestAuditRecords:
|
class TestAuditRecords:
|
||||||
"""Test that actions are properly recorded in audit logs."""
|
"""Test that actions are properly recorded in audit logs."""
|
||||||
|
|
||||||
|
|
@ -466,7 +497,7 @@ class TestAuditRecords:
|
||||||
|
|
||||||
# Find record with our values
|
# Find record with our values
|
||||||
records = data["records"]
|
records = data["records"]
|
||||||
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30]
|
matching = [
|
||||||
|
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
|
||||||
|
]
|
||||||
assert len(matching) >= 1
|
assert len(matching) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
"""Tests for user profile and contact details."""
|
"""Tests for user profile and contact details."""
|
||||||
import pytest
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from models import User, ROLE_REGULAR
|
|
||||||
from auth import get_password_hash
|
from auth import get_password_hash
|
||||||
|
from models import User
|
||||||
from tests.helpers import unique_email
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
# Valid npub for testing (32 zero bytes encoded as bech32)
|
# Valid npub for testing (32 zero bytes encoded as bech32)
|
||||||
|
|
@ -328,7 +328,9 @@ class TestUpdateProfileEndpoint:
|
||||||
assert "field_errors" in data["detail"]
|
assert "field_errors" in data["detail"]
|
||||||
assert "nostr_npub" in data["detail"]["field_errors"]
|
assert "nostr_npub" in data["detail"]["field_errors"]
|
||||||
|
|
||||||
async def test_multiple_invalid_fields_returns_all_errors(self, client_factory, regular_user):
|
async def test_multiple_invalid_fields_returns_all_errors(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Multiple invalid fields return all errors."""
|
"""Multiple invalid fields return all errors."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.put(
|
response = await client.put(
|
||||||
|
|
@ -344,7 +346,9 @@ class TestUpdateProfileEndpoint:
|
||||||
assert "contact_email" in data["detail"]["field_errors"]
|
assert "contact_email" in data["detail"]["field_errors"]
|
||||||
assert "telegram" in data["detail"]["field_errors"]
|
assert "telegram" in data["detail"]["field_errors"]
|
||||||
|
|
||||||
async def test_partial_update_preserves_other_fields(self, client_factory, regular_user):
|
async def test_partial_update_preserves_other_fields(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Updating one field doesn't affect others (they get set to the request values)."""
|
"""Updating one field doesn't affect others (they get set to the request values)."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
# Set initial values
|
# Set initial values
|
||||||
|
|
@ -402,11 +406,14 @@ class TestProfilePrivacy:
|
||||||
class TestProfileGodfather:
|
class TestProfileGodfather:
|
||||||
"""Tests for godfather information in profile."""
|
"""Tests for godfather information in profile."""
|
||||||
|
|
||||||
async def test_profile_shows_godfather_email(self, client_factory, admin_user, regular_user):
|
async def test_profile_shows_godfather_email(
|
||||||
|
self, client_factory, admin_user, regular_user
|
||||||
|
):
|
||||||
"""Profile shows godfather email for users who signed up with invite."""
|
"""Profile shows godfather email for users who signed up with invite."""
|
||||||
from tests.helpers import unique_email
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from models import User
|
from models import User
|
||||||
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
# Create invite
|
# Create invite
|
||||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||||
|
|
@ -443,7 +450,9 @@ class TestProfileGodfather:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["godfather_email"] == regular_user["email"]
|
assert data["godfather_email"] == regular_user["email"]
|
||||||
|
|
||||||
async def test_profile_godfather_null_for_seeded_users(self, client_factory, regular_user):
|
async def test_profile_godfather_null_for_seeded_users(
|
||||||
|
self, client_factory, regular_user
|
||||||
|
):
|
||||||
"""Profile shows null godfather for users without one (e.g., seeded users)."""
|
"""Profile shows null godfather for users without one (e.g., seeded users)."""
|
||||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||||
response = await client.get("/api/profile")
|
response = await client.get("/api/profile")
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
"""Tests for profile field validation."""
|
"""Tests for profile field validation."""
|
||||||
import pytest
|
|
||||||
|
|
||||||
from validation import (
|
from validation import (
|
||||||
validate_contact_email,
|
validate_contact_email,
|
||||||
validate_telegram,
|
|
||||||
validate_signal,
|
|
||||||
validate_nostr_npub,
|
validate_nostr_npub,
|
||||||
validate_profile_fields,
|
validate_profile_fields,
|
||||||
|
validate_signal,
|
||||||
|
validate_telegram,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,13 +139,17 @@ class TestValidateNostrNpub:
|
||||||
assert validate_nostr_npub(self.VALID_NPUB) is None
|
assert validate_nostr_npub(self.VALID_NPUB) is None
|
||||||
|
|
||||||
def test_wrong_prefix(self):
|
def test_wrong_prefix(self):
|
||||||
result = validate_nostr_npub("nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz")
|
result = validate_nostr_npub(
|
||||||
|
"nsec1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqwcv5dz"
|
||||||
|
)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "npub" in result.lower()
|
assert "npub" in result.lower()
|
||||||
|
|
||||||
def test_invalid_checksum(self):
|
def test_invalid_checksum(self):
|
||||||
# Change last character to break checksum
|
# Change last character to break checksum
|
||||||
result = validate_nostr_npub("npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd")
|
result = validate_nostr_npub(
|
||||||
|
"npub1qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpd"
|
||||||
|
)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "checksum" in result.lower()
|
assert "checksum" in result.lower()
|
||||||
|
|
||||||
|
|
@ -155,7 +158,9 @@ class TestValidateNostrNpub:
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
||||||
def test_not_starting_with_npub1(self):
|
def test_not_starting_with_npub1(self):
|
||||||
result = validate_nostr_npub("npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc")
|
result = validate_nostr_npub(
|
||||||
|
"npub2qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqsutgpc"
|
||||||
|
)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "npub1" in result
|
assert "npub1" in result
|
||||||
|
|
||||||
|
|
@ -206,4 +211,3 @@ class TestValidateProfileFields:
|
||||||
nostr_npub="",
|
nostr_npub="",
|
||||||
)
|
)
|
||||||
assert errors == {}
|
assert errors == {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
"""Validate shared constants match backend definitions."""
|
"""Validate shared constants match backend definitions."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from models import ROLE_ADMIN, ROLE_REGULAR, InviteStatus, AppointmentStatus
|
from models import ROLE_ADMIN, ROLE_REGULAR, AppointmentStatus, InviteStatus
|
||||||
|
|
||||||
|
|
||||||
def validate_shared_constants() -> None:
|
def validate_shared_constants() -> None:
|
||||||
|
|
@ -29,35 +30,42 @@ def validate_shared_constants() -> None:
|
||||||
# Validate invite statuses
|
# Validate invite statuses
|
||||||
expected_invite_statuses = {s.name: s.value for s in InviteStatus}
|
expected_invite_statuses = {s.name: s.value for s in InviteStatus}
|
||||||
if constants.get("inviteStatuses") != expected_invite_statuses:
|
if constants.get("inviteStatuses") != expected_invite_statuses:
|
||||||
|
got = constants.get("inviteStatuses")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invite status mismatch in shared/constants.json. "
|
f"Invite status mismatch. Expected: {expected_invite_statuses}, Got: {got}"
|
||||||
f"Expected: {expected_invite_statuses}, Got: {constants.get('inviteStatuses')}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate appointment statuses
|
# Validate appointment statuses
|
||||||
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
|
expected_appointment_statuses = {s.name: s.value for s in AppointmentStatus}
|
||||||
if constants.get("appointmentStatuses") != expected_appointment_statuses:
|
if constants.get("appointmentStatuses") != expected_appointment_statuses:
|
||||||
|
got = constants.get("appointmentStatuses")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Appointment status mismatch in shared/constants.json. "
|
f"Appointment status mismatch. "
|
||||||
f"Expected: {expected_appointment_statuses}, Got: {constants.get('appointmentStatuses')}"
|
f"Expected: {expected_appointment_statuses}, Got: {got}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate booking constants exist with required fields
|
# Validate booking constants exist with required fields
|
||||||
booking = constants.get("booking", {})
|
booking = constants.get("booking", {})
|
||||||
required_booking_fields = ["slotDurationMinutes", "maxAdvanceDays", "minAdvanceDays", "noteMaxLength"]
|
required_booking_fields = [
|
||||||
|
"slotDurationMinutes",
|
||||||
|
"maxAdvanceDays",
|
||||||
|
"minAdvanceDays",
|
||||||
|
"noteMaxLength",
|
||||||
|
]
|
||||||
for field in required_booking_fields:
|
for field in required_booking_fields:
|
||||||
if field not in booking:
|
if field not in booking:
|
||||||
raise ValueError(f"Missing booking constant '{field}' in shared/constants.json")
|
raise ValueError(f"Missing booking constant '{field}' in constants.json")
|
||||||
|
|
||||||
# Validate validation rules exist (structure check only)
|
# Validate validation rules exist (structure check only)
|
||||||
validation = constants.get("validation", {})
|
validation = constants.get("validation", {})
|
||||||
required_fields = ["telegram", "signal", "nostrNpub"]
|
required_fields = ["telegram", "signal", "nostrNpub"]
|
||||||
for field in required_fields:
|
for field in required_fields:
|
||||||
if field not in validation:
|
if field not in validation:
|
||||||
raise ValueError(f"Missing validation rules for '{field}' in shared/constants.json")
|
raise ValueError(
|
||||||
|
f"Missing validation rules for '{field}' in constants.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
validate_shared_constants()
|
validate_shared_constants()
|
||||||
print("✓ Shared constants are valid")
|
print("✓ Shared constants are valid")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
"""Validation utilities for user profile fields."""
|
"""Validation utilities for user profile fields."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from email_validator import validate_email, EmailNotValidError
|
|
||||||
from bech32 import bech32_decode
|
from bech32 import bech32_decode
|
||||||
|
from email_validator import EmailNotValidError, validate_email
|
||||||
|
|
||||||
# Load validation rules from shared constants
|
# Load validation rules from shared constants
|
||||||
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
|
_constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
|
||||||
|
|
@ -143,4 +144,3 @@ def validate_profile_fields(
|
||||||
errors["nostr_npub"] = err
|
errors["nostr_npub"] = err
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue