tests passing

This commit is contained in:
counterweight 2025-12-18 23:33:32 +01:00
parent 322bdd3e6e
commit b173b47925
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
18 changed files with 1414 additions and 93 deletions

29
.env.example Normal file
View file

@ -0,0 +1,29 @@
# Local development environment variables
# Copy this file to .env and fill in the values
# To use: install direnv (https://direnv.net), then run `direnv allow`
# =============================================================================
# Backend
# =============================================================================
# Required: Secret key for JWT token signing
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
SECRET_KEY=
# Database URL
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret
# Dev user credentials (regular user)
DEV_USER_EMAIL=dev@example.com
DEV_USER_PASSWORD=devpass123
# Dev admin credentials
DEV_ADMIN_EMAIL=admin@example.com
DEV_ADMIN_PASSWORD=admin123
# =============================================================================
# Frontend
# =============================================================================
# API URL for the backend
NEXT_PUBLIC_API_URL=http://localhost:8000

2
.gitignore vendored
View file

@ -12,7 +12,7 @@ node_modules/
# Env # Env
.env .env
.env.* .env
# IDE # IDE
.idea/ .idea/

View file

@ -1,6 +1,6 @@
import os import os
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Optional from typing import List, Optional
import bcrypt import bcrypt
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
@ -10,7 +10,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db from database import get_db
from models import User from models import User, Permission
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
ALGORITHM = "HS256" ALGORITHM = "HS256"
@ -30,6 +30,8 @@ UserLogin = UserCredentials
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: int id: int
email: str email: str
roles: List[str]
permissions: List[str]
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
@ -99,3 +101,64 @@ async def get_current_user(
raise credentials_exception raise credentials_exception
return user return user
def require_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ALL of the required permissions.
Usage:
@app.get("/api/counter")
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db)
missing = [p for p in required_permissions if p not in user_permissions]
if missing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
)
return user
return permission_checker
def require_any_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ANY of the required permissions.
Usage:
@app.get("/api/resource")
async def get_resource(user: User = Depends(require_any_permission(Permission.VIEW, Permission.ADMIN))):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db)
if not any(p in user_permissions for p in required_permissions):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Requires one of: {', '.join(p.value for p in required_permissions)}",
)
return user
return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
"""Build a UserResponse with roles and permissions."""
permissions = await user.get_permissions(db)
return UserResponse(
id=user.id,
email=user.email,
roles=user.role_names,
permissions=[p.value for p in permissions],
)

View file

@ -9,3 +9,10 @@ SECRET_KEY=
# Database URL # Database URL
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret
# Dev user credentials (regular user)
DEV_USER_EMAIL=
DEV_USER_PASSWORD=
# Dev admin credentials
DEV_ADMIN_EMAIL=
DEV_ADMIN_PASSWORD=

View file

@ -1,12 +1,15 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime
from typing import List
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sqlalchemy import select, func, desc from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from database import engine, get_db, Base from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord from models import Counter, User, SumRecord, CounterRecord, Permission, Role
from auth import ( from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME, COOKIE_NAME,
@ -18,6 +21,8 @@ from auth import (
authenticate_user, authenticate_user,
create_access_token, create_access_token,
get_current_user, get_current_user,
require_permission,
build_user_response,
) )
@ -50,6 +55,12 @@ def set_auth_cookie(response: Response, token: str) -> None:
) )
async def get_default_role(db: AsyncSession) -> Role | None:
"""Get the default 'regular' role for new users."""
result = await db.execute(select(Role).where(Role.name == "regular"))
return result.scalar_one_or_none()
# Auth endpoints # Auth endpoints
@app.post("/api/auth/register", response_model=UserResponse) @app.post("/api/auth/register", response_model=UserResponse)
async def register( async def register(
@ -68,13 +79,19 @@ async def register(
email=user_data.email, email=user_data.email,
hashed_password=get_password_hash(user_data.password), hashed_password=get_password_hash(user_data.password),
) )
# Assign default role if it exists
default_role = await get_default_role(db)
if default_role:
user.roles.append(default_role)
db.add(user) db.add(user)
await db.commit() await db.commit()
await db.refresh(user) await db.refresh(user)
access_token = create_access_token(data={"sub": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
set_auth_cookie(response, access_token) set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email) return await build_user_response(user, db)
@app.post("/api/auth/login", response_model=UserResponse) @app.post("/api/auth/login", response_model=UserResponse)
@ -92,7 +109,7 @@ async def login(
access_token = create_access_token(data={"sub": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
set_auth_cookie(response, access_token) set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email) return await build_user_response(user, db)
@app.post("/api/auth/logout") @app.post("/api/auth/logout")
@ -102,8 +119,11 @@ async def logout(response: Response):
@app.get("/api/auth/me", response_model=UserResponse) @app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)): async def get_me(
return UserResponse(id=current_user.id, email=current_user.email) current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
return await build_user_response(current_user, db)
# Counter endpoints # Counter endpoints
@ -121,7 +141,7 @@ async def get_or_create_counter(db: AsyncSession) -> Counter:
@app.get("/api/counter") @app.get("/api/counter")
async def get_counter( async def get_counter(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user), _current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
): ):
counter = await get_or_create_counter(db) counter = await get_or_create_counter(db)
return {"value": counter.value} return {"value": counter.value}
@ -130,7 +150,7 @@ async def get_counter(
@app.post("/api/counter/increment") @app.post("/api/counter/increment")
async def increment_counter( async def increment_counter(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
): ):
counter = await get_or_create_counter(db) counter = await get_or_create_counter(db)
value_before = counter.value value_before = counter.value
@ -162,7 +182,7 @@ class SumResponse(BaseModel):
async def calculate_sum( async def calculate_sum(
data: SumRequest, data: SumRequest,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(require_permission(Permission.USE_SUM)),
): ):
result = data.a + data.b result = data.a + data.b
record = SumRecord( record = SumRecord(
@ -177,10 +197,6 @@ async def calculate_sum(
# Audit endpoints # Audit endpoints
from datetime import datetime
from typing import List
class CounterRecordResponse(BaseModel): class CounterRecordResponse(BaseModel):
id: int id: int
user_email: str user_email: str
@ -219,7 +235,7 @@ async def get_counter_records(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100), per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user), _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
): ):
# Get total count # Get total count
count_result = await db.execute(select(func.count(CounterRecord.id))) count_result = await db.execute(select(func.count(CounterRecord.id)))
@ -263,7 +279,7 @@ async def get_sum_records(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100), per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user), _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
): ):
# Get total count # Get total count
count_result = await db.execute(select(func.count(SumRecord.id))) count_result = await db.execute(select(func.count(SumRecord.id)))
@ -301,4 +317,3 @@ async def get_sum_records(
per_page=per_page, per_page=per_page,
total_pages=total_pages, total_pages=total_pages,
) )

View file

@ -1,14 +1,92 @@
from datetime import datetime from datetime import datetime, UTC
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey from enum import Enum as PyEnum
from sqlalchemy.orm import Mapped, mapped_column from typing import List, Set
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import AsyncSession
from database import Base from database import Base
class Counter(Base): class Permission(str, PyEnum):
__tablename__ = "counter" """All available permissions in the system."""
# Counter permissions
VIEW_COUNTER = "view_counter"
INCREMENT_COUNTER = "increment_counter"
# Sum permissions
USE_SUM = "use_sum"
# Audit permissions
VIEW_AUDIT = "view_audit"
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
value: Mapped[int] = mapped_column(Integer, default=0) # Role definitions with their permissions
ROLE_DEFINITIONS = {
"admin": {
"description": "Administrator with audit access",
"permissions": [
Permission.VIEW_AUDIT,
],
},
"regular": {
"description": "Regular user with counter and sum access",
"permissions": [
Permission.VIEW_COUNTER,
Permission.INCREMENT_COUNTER,
Permission.USE_SUM,
],
},
}
# Association table: Role <-> Permission (many-to-many)
role_permissions = Table(
"role_permissions",
Base.metadata,
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
Column("permission", Enum(Permission), primary_key=True),
)
# Association table: User <-> Role (many-to-many)
user_roles = Table(
"user_roles",
Base.metadata,
Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
)
class Role(Base):
__tablename__ = "roles"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
description: Mapped[str] = mapped_column(String(255), nullable=True)
# Relationship to users
users: Mapped[List["User"]] = relationship(
"User",
secondary=user_roles,
back_populates="roles",
)
async def get_permissions(self, db: AsyncSession) -> Set[Permission]:
"""Get all permissions for this role."""
result = await db.execute(
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
)
return {row[0] for row in result.fetchall()}
async def add_permission(self, db: AsyncSession, permission: Permission) -> None:
"""Add a permission to this role."""
await db.execute(role_permissions.insert().values(role_id=self.id, permission=permission))
async def set_permissions(self, db: AsyncSession, permissions: List[Permission]) -> None:
"""Set all permissions for this role (replaces existing)."""
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
for perm in permissions:
await db.execute(role_permissions.insert().values(role_id=self.id, permission=perm))
class User(Base): class User(Base):
@ -17,6 +95,39 @@ class User(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
# Relationship to roles
roles: Mapped[List[Role]] = relationship(
"Role",
secondary=user_roles,
back_populates="users",
lazy="selectin",
)
async def get_permissions(self, db: AsyncSession) -> Set[Permission]:
"""Get all permissions from all roles."""
permissions: Set[Permission] = set()
for role in self.roles:
role_perms = await role.get_permissions(db)
permissions.update(role_perms)
return permissions
async def has_permission(self, db: AsyncSession, permission: Permission) -> bool:
"""Check if user has a specific permission through any of their roles."""
permissions = await self.get_permissions(db)
return permission in permissions
@property
def role_names(self) -> List[str]:
"""Get list of role names for API responses."""
return [role.name for role in self.roles]
class Counter(Base):
__tablename__ = "counter"
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
value: Mapped[int] = mapped_column(Integer, default=0)
class SumRecord(Base): class SumRecord(Base):
@ -27,7 +138,9 @@ class SumRecord(Base):
a: Mapped[float] = mapped_column(Float, nullable=False) a: Mapped[float] = mapped_column(Float, nullable=False)
b: Mapped[float] = mapped_column(Float, nullable=False) b: Mapped[float] = mapped_column(Float, nullable=False)
result: Mapped[float] = mapped_column(Float, nullable=False) result: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
class CounterRecord(Base): class CounterRecord(Base):
@ -37,5 +150,6 @@ class CounterRecord(Base):
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True) user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
value_before: Mapped[int] = mapped_column(Integer, nullable=False) value_before: Mapped[int] = mapped_column(Integer, nullable=False)
value_after: Mapped[int] = mapped_column(Integer, nullable=False) value_after: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)

View file

@ -1,14 +1,67 @@
"""Seed the database with a dev user.""" """Seed the database with roles, permissions, and dev users."""
import asyncio import asyncio
import os import os
from typing import List
from sqlalchemy import select from sqlalchemy import select
from database import engine, async_session, Base from database import engine, async_session, Base
from models import User from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS
from auth import get_password_hash from auth import get_password_hash
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"] DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"] DEV_USER_PASSWORD = os.environ["DEV_USER_PASSWORD"]
DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
async def upsert_role(db, name: str, description: str, permissions: List[Permission]) -> Role:
"""Create or update a role with the given permissions."""
result = await db.execute(select(Role).where(Role.name == name))
role = result.scalar_one_or_none()
if role:
role.description = description
print(f"Updated role: {name}")
else:
role = Role(name=name, description=description)
db.add(role)
await db.flush() # Get the role ID
print(f"Created role: {name}")
# Set permissions for the role
await role.set_permissions(db, permissions)
print(f" Permissions: {', '.join(p.value for p in permissions)}")
return role
async def upsert_user(db, email: str, password: str, role_names: List[str]) -> User:
"""Create or update a user with the given credentials and roles."""
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
# Get roles
roles = []
for role_name in role_names:
result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none()
if role:
roles.append(role)
if user:
user.hashed_password = get_password_hash(password)
user.roles = roles
print(f"Updated user: {email} with roles: {role_names}")
else:
user = User(
email=email,
hashed_password=get_password_hash(password),
roles=roles,
)
db.add(user)
print(f"Created user: {email} with roles: {role_names}")
return user
async def seed(): async def seed():
@ -16,23 +69,25 @@ async def seed():
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
async with async_session() as db: async with async_session() as db:
result = await db.execute(select(User).where(User.email == DEV_USER_EMAIL)) print("\n=== Seeding Roles ===")
user = result.scalar_one_or_none() for role_name, role_config in ROLE_DEFINITIONS.items():
await upsert_role(
if user: db,
user.hashed_password = get_password_hash(DEV_USER_PASSWORD) role_name,
await db.commit() role_config["description"],
print(f"Updated dev user: {DEV_USER_EMAIL} / {DEV_USER_PASSWORD}") role_config["permissions"],
else:
user = User(
email=DEV_USER_EMAIL,
hashed_password=get_password_hash(DEV_USER_PASSWORD),
) )
db.add(user)
await db.commit() print("\n=== Seeding Users ===")
print(f"Created dev user: {DEV_USER_EMAIL} / {DEV_USER_PASSWORD}") # Create regular dev user
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, ["regular"])
# Create admin dev user
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, ["admin"])
await db.commit()
print("\n=== Seeding Complete ===\n")
if __name__ == "__main__": if __name__ == "__main__":
asyncio.run(seed()) asyncio.run(seed())

View file

@ -1,15 +1,19 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import List
# Set required env vars before importing app # Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only") os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from database import Base, get_db from database import Base, get_db
from main import app from main import app
from models import User, Role, Permission, ROLE_DEFINITIONS
from auth import get_password_hash
TEST_DATABASE_URL = os.getenv( TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL", "TEST_DATABASE_URL",
@ -20,9 +24,10 @@ TEST_DATABASE_URL = os.getenv(
class ClientFactory: class ClientFactory:
"""Factory for creating httpx clients with optional cookies.""" """Factory for creating httpx clients with optional cookies."""
def __init__(self, transport, base_url): def __init__(self, transport, base_url, session_factory):
self._transport = transport self._transport = transport
self._base_url = base_url self._base_url = base_url
self._session_factory = session_factory
@asynccontextmanager @asynccontextmanager
async def create(self, cookies: dict | None = None): async def create(self, cookies: dict | None = None):
@ -45,6 +50,59 @@ class ClientFactory:
async def post(self, url: str, **kwargs): async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs) return await self.request("POST", url, **kwargs)
@asynccontextmanager
async def get_db_session(self):
"""Get a database session for direct DB operations in tests."""
async with self._session_factory() as session:
yield session
async def setup_roles(db: AsyncSession) -> dict[str, Role]:
"""Create all roles with their permissions from ROLE_DEFINITIONS."""
roles = {}
for role_name, config in ROLE_DEFINITIONS.items():
# Check if role exists
result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none()
if not role:
role = Role(name=role_name, description=config["description"])
db.add(role)
await db.flush()
# Set permissions
await role.set_permissions(db, config["permissions"])
roles[role_name] = role
await db.commit()
return roles
async def create_user_with_roles(
db: AsyncSession,
email: str,
password: str,
role_names: List[str],
) -> User:
"""Create a user with specified roles."""
# Get roles
roles = []
for role_name in role_names:
result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none()
if role:
roles.append(role)
user = User(
email=email,
hashed_password=get_password_hash(password),
roles=roles,
)
db.add(user)
await db.commit()
await db.refresh(user)
return user
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def client_factory(): async def client_factory():
@ -57,6 +115,10 @@ async def client_factory():
await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
# Setup roles
async with session_factory() as db:
await setup_roles(db)
async def override_get_db(): async def override_get_db():
async with session_factory() as session: async with session_factory() as session:
yield session yield session
@ -64,7 +126,7 @@ async def client_factory():
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
factory = ClientFactory(transport, "http://test") factory = ClientFactory(transport, "http://test", session_factory)
yield factory yield factory
@ -77,3 +139,78 @@ async def client(client_factory):
"""Fixture for a simple client without cookies (backwards compatible).""" """Fixture for a simple client without cookies (backwards compatible)."""
async with client_factory.create() as c: async with client_factory.create() as c:
yield c yield c
@pytest.fixture(scope="function")
async def regular_user(client_factory):
"""Create a regular user and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("regular")
password = "password123"
async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, ["regular"])
# Login to get cookies
response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": password},
)
return {
"email": email,
"password": password,
"cookies": dict(response.cookies),
"response": response,
}
@pytest.fixture(scope="function")
async def admin_user(client_factory):
"""Create an admin user and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("admin")
password = "password123"
async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, ["admin"])
# Login to get cookies
response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": password},
)
return {
"email": email,
"password": password,
"cookies": dict(response.cookies),
"response": response,
}
@pytest.fixture(scope="function")
async def user_no_roles(client_factory):
"""Create a user with NO roles and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("noroles")
password = "password123"
async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, [])
# Login to get cookies
response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": password},
)
return {
"email": email,
"password": password,
"cookies": dict(response.cookies),
"response": response,
}

View file

@ -16,6 +16,10 @@ async def test_register_success(client):
data = response.json() data = response.json()
assert data["email"] == email assert data["email"] == email
assert "id" in data assert "id" in data
assert "roles" in data
assert "permissions" in data
# New users get regular role by default
assert "regular" in data["roles"]
# Cookie should be set # Cookie should be set
assert COOKIE_NAME in response.cookies assert COOKIE_NAME in response.cookies
@ -83,6 +87,8 @@ async def test_login_success(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["email"] == email assert data["email"] == email
assert "roles" in data
assert "permissions" in data
assert COOKIE_NAME in response.cookies assert COOKIE_NAME in response.cookies
@ -146,6 +152,8 @@ async def test_get_me_success(client_factory):
data = response.json() data = response.json()
assert data["email"] == email assert data["email"] == email
assert "id" in data assert "id" in data
assert "roles" in data
assert "permissions" in data
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@ -0,0 +1,461 @@
"""
Permission and Role-Based Access Control Tests
These tests verify that:
1. Users can only access endpoints they have permission for
2. Users without proper roles are denied access (403)
3. Unauthenticated users are denied access (401)
4. The permission system cannot be bypassed
"""
import pytest
from models import Permission
# =============================================================================
# Role Assignment Tests
# =============================================================================
class TestRoleAssignment:
"""Test that roles are properly assigned and returned."""
@pytest.mark.asyncio
async def test_regular_user_has_correct_roles(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me")
assert response.status_code == 200
data = response.json()
assert "regular" in data["roles"]
assert "admin" not in data["roles"]
@pytest.mark.asyncio
async def test_admin_user_has_correct_roles(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/auth/me")
assert response.status_code == 200
data = response.json()
assert "admin" in data["roles"]
assert "regular" not in data["roles"]
@pytest.mark.asyncio
async def test_regular_user_has_correct_permissions(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/auth/me")
data = response.json()
permissions = data["permissions"]
# Should have counter and sum permissions
assert Permission.VIEW_COUNTER.value in permissions
assert Permission.INCREMENT_COUNTER.value in permissions
assert Permission.USE_SUM.value in permissions
# Should NOT have audit permission
assert Permission.VIEW_AUDIT.value not in permissions
@pytest.mark.asyncio
async def test_admin_user_has_correct_permissions(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/auth/me")
data = response.json()
permissions = data["permissions"]
# Should have audit permission
assert Permission.VIEW_AUDIT.value in permissions
# Should NOT have counter/sum permissions
assert Permission.VIEW_COUNTER.value not in permissions
assert Permission.INCREMENT_COUNTER.value not in permissions
assert Permission.USE_SUM.value not in permissions
@pytest.mark.asyncio
async def test_user_with_no_roles_has_no_permissions(self, client_factory, user_no_roles):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/auth/me")
data = response.json()
assert data["roles"] == []
assert data["permissions"] == []
# =============================================================================
# Counter Endpoint Access Tests
# =============================================================================
class TestCounterAccess:
"""Test access control for counter endpoints."""
@pytest.mark.asyncio
async def test_regular_user_can_view_counter(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_regular_user_can_increment_counter(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/counter/increment")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_admin_cannot_view_counter(self, client_factory, admin_user):
"""Admin users should be forbidden from counter endpoints."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 403
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_admin_cannot_increment_counter(self, client_factory, admin_user):
"""Admin users should be forbidden from incrementing counter."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/counter/increment")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_counter(self, client_factory, user_no_roles):
"""Users with no roles should be forbidden."""
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_counter(self, client):
"""Unauthenticated requests should get 401."""
response = await client.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_unauthenticated_cannot_increment_counter(self, client):
"""Unauthenticated requests should get 401."""
response = await client.post("/api/counter/increment")
assert response.status_code == 401
# =============================================================================
# Sum Endpoint Access Tests
# =============================================================================
class TestSumAccess:
"""Test access control for sum endpoint."""
@pytest.mark.asyncio
async def test_regular_user_can_use_sum(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 200
data = response.json()
assert data["result"] == 8
@pytest.mark.asyncio
async def test_admin_cannot_use_sum(self, client_factory, admin_user):
"""Admin users should be forbidden from sum endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_use_sum(self, client_factory, user_no_roles):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_use_sum(self, client):
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 401
# =============================================================================
# Audit Endpoint Access Tests
# =============================================================================
class TestAuditAccess:
"""Test access control for audit endpoints."""
@pytest.mark.asyncio
async def test_admin_can_view_counter_audit(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
assert response.status_code == 200
data = response.json()
assert "records" in data
assert "total" in data
@pytest.mark.asyncio
async def test_admin_can_view_sum_audit(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 200
data = response.json()
assert "records" in data
assert "total" in data
@pytest.mark.asyncio
async def test_regular_user_cannot_view_counter_audit(self, client_factory, regular_user):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
assert response.status_code == 403
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_regular_user_cannot_view_sum_audit(self, client_factory, regular_user):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_audit(self, client_factory, user_no_roles):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/audit/counter")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_counter_audit(self, client):
response = await client.get("/api/audit/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_sum_audit(self, client):
response = await client.get("/api/audit/sum")
assert response.status_code == 401
# =============================================================================
# Offensive Security Tests - Bypass Attempts
# =============================================================================
class TestSecurityBypassAttempts:
"""
Offensive tests that attempt to bypass security controls.
These simulate potential attack vectors.
"""
@pytest.mark.asyncio
async def test_cannot_access_audit_with_forged_role_claim(self, client_factory, regular_user):
"""
Attempt to access audit by somehow claiming admin role.
The server should verify roles from DB, not trust client claims.
"""
# Regular user tries to access audit endpoint
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
# Should be denied regardless of any manipulation attempts
assert response.status_code == 403
@pytest.mark.asyncio
async def test_cannot_access_counter_with_expired_session(self, client_factory):
"""Test that invalid/expired tokens are rejected."""
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid"
async with client_factory.create(cookies={"auth_token": fake_token}) as client:
response = await client.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_cannot_access_with_tampered_token(self, client_factory, regular_user):
"""Test that tokens signed with wrong key are rejected."""
# Take a valid token and modify it
original_token = regular_user["cookies"].get("auth_token", "")
if original_token:
tampered_token = original_token[:-5] + "XXXXX"
async with client_factory.create(cookies={"auth_token": tampered_token}) as client:
response = await client.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_cannot_escalate_to_admin_via_registration(self, client_factory):
"""
Test that new registrations cannot claim admin role.
New users should only get 'regular' role by default.
"""
from tests.helpers import unique_email
response = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "password123"},
)
assert response.status_code == 200
data = response.json()
# Should only have regular role, not admin
assert "admin" not in data["roles"]
assert Permission.VIEW_AUDIT.value not in data["permissions"]
# Try to access audit with this new user
async with client_factory.create(cookies=dict(response.cookies)) as client:
audit_response = await client.get("/api/audit/counter")
assert audit_response.status_code == 403
@pytest.mark.asyncio
async def test_deleted_user_token_is_invalid(self, client_factory):
"""
If a user is deleted, their token should no longer work.
This tests that tokens are validated against current DB state.
"""
from tests.helpers import unique_email
from sqlalchemy import delete
from models import User
email = unique_email("deleted")
# Create and login user
async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles
user = await create_user_with_roles(db, email, "password123", ["regular"])
user_id = user.id
login_response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
cookies = dict(login_response.cookies)
# Delete the user from DB
async with client_factory.get_db_session() as db:
await db.execute(delete(User).where(User.id == user_id))
await db.commit()
# Try to use the old token
async with client_factory.create(cookies=cookies) as client:
response = await client.get("/api/auth/me")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_role_change_reflected_immediately(self, client_factory):
"""
If a user's role is changed, the change should be reflected
in subsequent requests (no stale permission cache).
"""
from tests.helpers import unique_email
from sqlalchemy import select
from models import User, Role
email = unique_email("rolechange")
# Create regular user
async with client_factory.get_db_session() as db:
from tests.conftest import create_user_with_roles
await create_user_with_roles(db, email, "password123", ["regular"])
login_response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
cookies = dict(login_response.cookies)
# Verify can access counter but not audit
async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/counter")).status_code == 200
assert (await client.get("/api/audit/counter")).status_code == 403
# Change user's role from regular to admin
async with client_factory.get_db_session() as db:
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "admin"))
admin_role = result.scalar_one()
result = await db.execute(select(Role).where(Role.name == "regular"))
regular_role = result.scalar_one()
user.roles = [admin_role] # Remove regular, add admin
await db.commit()
# Now should have audit access but not counter access
async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/audit/counter")).status_code == 200
assert (await client.get("/api/counter")).status_code == 403
# =============================================================================
# Audit Record Tests
# =============================================================================
class TestAuditRecords:
"""Test that actions are properly recorded in audit logs."""
@pytest.mark.asyncio
async def test_counter_increment_creates_audit_record(
self, client_factory, regular_user, admin_user
):
"""Verify that counter increments are recorded and visible in audit."""
# Regular user increments counter
async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/counter/increment")
# Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Find record for our user
records = data["records"]
user_records = [r for r in records if r["user_email"] == regular_user["email"]]
assert len(user_records) >= 1
@pytest.mark.asyncio
async def test_sum_operation_creates_audit_record(
self, client_factory, regular_user, admin_user
):
"""Verify that sum operations are recorded and visible in audit."""
# Regular user uses sum
async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/sum", json={"a": 10, "b": 20})
# Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Find record with our values
records = data["records"]
matching = [r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30]
assert len(matching) >= 1

View file

@ -2,7 +2,7 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useAuth } from "../auth-context"; import { useAuth, Permission } from "../auth-context";
import { API_URL } from "../config"; import { API_URL } from "../config";
interface CounterRecord { interface CounterRecord {
@ -35,26 +35,32 @@ export default function AuditPage() {
const [sumData, setSumData] = useState<PaginatedResponse<SumRecord> | null>(null); const [sumData, setSumData] = useState<PaginatedResponse<SumRecord> | null>(null);
const [counterPage, setCounterPage] = useState(1); const [counterPage, setCounterPage] = useState(1);
const [sumPage, setSumPage] = useState(1); const [sumPage, setSumPage] = useState(1);
const { user, isLoading, logout } = useAuth(); const { user, isLoading, logout, hasPermission } = useAuth();
const router = useRouter(); const router = useRouter();
useEffect(() => { const canViewAudit = hasPermission(Permission.VIEW_AUDIT);
if (!isLoading && !user) {
router.push("/login");
}
}, [isLoading, user, router]);
useEffect(() => { useEffect(() => {
if (user) { if (!isLoading) {
if (!user) {
router.push("/login");
} else if (!canViewAudit) {
router.push("/");
}
}
}, [isLoading, user, router, canViewAudit]);
useEffect(() => {
if (user && canViewAudit) {
fetchCounterRecords(counterPage); fetchCounterRecords(counterPage);
} }
}, [user, counterPage]); }, [user, counterPage, canViewAudit]);
useEffect(() => { useEffect(() => {
if (user) { if (user && canViewAudit) {
fetchSumRecords(sumPage); fetchSumRecords(sumPage);
} }
}, [user, sumPage]); }, [user, sumPage, canViewAudit]);
const fetchCounterRecords = async (page: number) => { const fetchCounterRecords = async (page: number) => {
try { try {
@ -97,7 +103,7 @@ export default function AuditPage() {
); );
} }
if (!user) { if (!user || !canViewAudit) {
return null; return null;
} }
@ -105,10 +111,6 @@ export default function AuditPage() {
<main style={styles.main}> <main style={styles.main}>
<div style={styles.header}> <div style={styles.header}>
<div style={styles.nav}> <div style={styles.nav}>
<a href="/" style={styles.navLink}>Counter</a>
<span style={styles.navDivider}></span>
<a href="/sum" style={styles.navLink}>Sum</a>
<span style={styles.navDivider}></span>
<span style={styles.navCurrent}>Audit</span> <span style={styles.navCurrent}>Audit</span>
</div> </div>
<div style={styles.userInfo}> <div style={styles.userInfo}>

View file

@ -4,9 +4,21 @@ import { createContext, useContext, useState, useEffect, ReactNode } from "react
import { API_URL } from "./config"; import { API_URL } from "./config";
// Permission constants matching backend
export const Permission = {
VIEW_COUNTER: "view_counter",
INCREMENT_COUNTER: "increment_counter",
USE_SUM: "use_sum",
VIEW_AUDIT: "view_audit",
} as const;
export type PermissionType = typeof Permission[keyof typeof Permission];
interface User { interface User {
id: number; id: number;
email: string; email: string;
roles: string[];
permissions: string[];
} }
interface AuthContextType { interface AuthContextType {
@ -15,6 +27,9 @@ interface AuthContextType {
login: (email: string, password: string) => Promise<void>; login: (email: string, password: string) => Promise<void>;
register: (email: string, password: string) => Promise<void>; register: (email: string, password: string) => Promise<void>;
logout: () => Promise<void>; logout: () => Promise<void>;
hasPermission: (permission: PermissionType) => boolean;
hasAnyPermission: (...permissions: PermissionType[]) => boolean;
hasRole: (role: string) => boolean;
} }
const AuthContext = createContext<AuthContextType | null>(null); const AuthContext = createContext<AuthContextType | null>(null);
@ -85,8 +100,31 @@ export function AuthProvider({ children }: { children: ReactNode }) {
setUser(null); setUser(null);
}; };
const hasPermission = (permission: PermissionType): boolean => {
return user?.permissions.includes(permission) ?? false;
};
const hasAnyPermission = (...permissions: PermissionType[]): boolean => {
return permissions.some((p) => user?.permissions.includes(p) ?? false);
};
const hasRole = (role: string): boolean => {
return user?.roles.includes(role) ?? false;
};
return ( return (
<AuthContext.Provider value={{ user, isLoading, login, register, logout }}> <AuthContext.Provider
value={{
user,
isLoading,
login,
register,
logout,
hasPermission,
hasAnyPermission,
hasRole,
}}
>
{children} {children}
</AuthContext.Provider> </AuthContext.Provider>
); );

View file

@ -11,23 +11,46 @@ vi.mock("next/navigation", () => ({
})); }));
// Default mock values // Default mock values
let mockUser: { id: number; email: string } | null = { id: 1, email: "test@example.com" }; let mockUser: { id: number; email: string; roles: string[]; permissions: string[] } | null = {
id: 1,
email: "test@example.com",
roles: ["regular"],
permissions: ["view_counter", "increment_counter", "use_sum"],
};
let mockIsLoading = false; let mockIsLoading = false;
const mockLogout = vi.fn(); const mockLogout = vi.fn();
const mockHasPermission = vi.fn((permission: string) =>
mockUser?.permissions.includes(permission) ?? false
);
vi.mock("./auth-context", () => ({ vi.mock("./auth-context", () => ({
useAuth: () => ({ useAuth: () => ({
user: mockUser, user: mockUser,
isLoading: mockIsLoading, isLoading: mockIsLoading,
logout: mockLogout, logout: mockLogout,
hasPermission: mockHasPermission,
}), }),
Permission: {
VIEW_COUNTER: "view_counter",
INCREMENT_COUNTER: "increment_counter",
USE_SUM: "use_sum",
VIEW_AUDIT: "view_audit",
},
})); }));
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Reset to authenticated state // Reset to authenticated state
mockUser = { id: 1, email: "test@example.com" }; mockUser = {
id: 1,
email: "test@example.com",
roles: ["regular"],
permissions: ["view_counter", "increment_counter", "use_sum"],
};
mockIsLoading = false; mockIsLoading = false;
mockHasPermission.mockImplementation((permission: string) =>
mockUser?.permissions.includes(permission) ?? false
);
}); });
afterEach(() => { afterEach(() => {

View file

@ -2,19 +2,26 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useAuth } from "./auth-context"; import { useAuth, Permission } from "./auth-context";
import { API_URL } from "./config"; import { API_URL } from "./config";
export default function Home() { export default function Home() {
const [count, setCount] = useState<number | null>(null); const [count, setCount] = useState<number | null>(null);
const { user, isLoading, logout } = useAuth(); const { user, isLoading, logout, hasPermission } = useAuth();
const router = useRouter(); const router = useRouter();
const canViewCounter = hasPermission(Permission.VIEW_COUNTER);
useEffect(() => { useEffect(() => {
if (!isLoading && !user) { if (!isLoading) {
router.push("/login"); if (!user) {
router.push("/login");
} else if (!canViewCounter) {
// Redirect to audit if user has audit permission, otherwise to login
router.push(hasPermission(Permission.VIEW_AUDIT) ? "/audit" : "/login");
}
} }
}, [isLoading, user, router]); }, [isLoading, user, router, canViewCounter, hasPermission]);
useEffect(() => { useEffect(() => {
if (user) { if (user) {
@ -49,7 +56,7 @@ export default function Home() {
); );
} }
if (!user) { if (!user || !canViewCounter) {
return null; return null;
} }
@ -60,8 +67,6 @@ export default function Home() {
<span style={styles.navCurrent}>Counter</span> <span style={styles.navCurrent}>Counter</span>
<span style={styles.navDivider}></span> <span style={styles.navDivider}></span>
<a href="/sum" style={styles.navLink}>Sum</a> <a href="/sum" style={styles.navLink}>Sum</a>
<span style={styles.navDivider}></span>
<a href="/audit" style={styles.navLink}>Audit</a>
</div> </div>
<div style={styles.userInfo}> <div style={styles.userInfo}>
<span style={styles.userEmail}>{user.email}</span> <span style={styles.userEmail}>{user.email}</span>

View file

@ -2,7 +2,7 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useAuth } from "../auth-context"; import { useAuth, Permission } from "../auth-context";
import { API_URL } from "../config"; import { API_URL } from "../config";
export default function SumPage() { export default function SumPage() {
@ -10,14 +10,20 @@ export default function SumPage() {
const [b, setB] = useState(""); const [b, setB] = useState("");
const [result, setResult] = useState<number | null>(null); const [result, setResult] = useState<number | null>(null);
const [showResult, setShowResult] = useState(false); const [showResult, setShowResult] = useState(false);
const { user, isLoading, logout } = useAuth(); const { user, isLoading, logout, hasPermission } = useAuth();
const router = useRouter(); const router = useRouter();
const canUseSum = hasPermission(Permission.USE_SUM);
useEffect(() => { useEffect(() => {
if (!isLoading && !user) { if (!isLoading) {
router.push("/login"); if (!user) {
router.push("/login");
} else if (!canUseSum) {
router.push(hasPermission(Permission.VIEW_AUDIT) ? "/audit" : "/login");
}
} }
}, [isLoading, user, router]); }, [isLoading, user, router, canUseSum, hasPermission]);
const handleSum = async () => { const handleSum = async () => {
const numA = parseFloat(a) || 0; const numA = parseFloat(a) || 0;
@ -60,7 +66,7 @@ export default function SumPage() {
); );
} }
if (!user) { if (!user || !canUseSum) {
return null; return null;
} }
@ -71,8 +77,6 @@ export default function SumPage() {
<a href="/" style={styles.navLink}>Counter</a> <a href="/" style={styles.navLink}>Counter</a>
<span style={styles.navDivider}></span> <span style={styles.navDivider}></span>
<span style={styles.navCurrent}>Sum</span> <span style={styles.navCurrent}>Sum</span>
<span style={styles.navDivider}></span>
<a href="/audit" style={styles.navLink}>Audit</a>
</div> </div>
<div style={styles.userInfo}> <div style={styles.userInfo}>
<span style={styles.userEmail}>{user.email}</span> <span style={styles.userEmail}>{user.email}</span>

View file

@ -46,10 +46,22 @@ test.describe("Counter - Authenticated", () => {
await expect(page.locator("h1")).not.toHaveText("..."); await expect(page.locator("h1")).not.toHaveText("...");
const before = Number(await page.locator("h1").textContent()); const before = Number(await page.locator("h1").textContent());
// Click increment and wait for each update to complete
await page.click("text=Increment"); await page.click("text=Increment");
await expect(page.locator("h1")).not.toHaveText(String(before));
const afterFirst = Number(await page.locator("h1").textContent());
await page.click("text=Increment"); await page.click("text=Increment");
await expect(page.locator("h1")).not.toHaveText(String(afterFirst));
const afterSecond = Number(await page.locator("h1").textContent());
await page.click("text=Increment"); await page.click("text=Increment");
await expect(page.locator("h1")).toHaveText(String(before + 3)); await expect(page.locator("h1")).not.toHaveText(String(afterSecond));
// Final value should be at least 3 more than we started with
const final = Number(await page.locator("h1").textContent());
expect(final).toBeGreaterThanOrEqual(before + 3);
}); });
test("counter persists after page reload", async ({ page }) => { test("counter persists after page reload", async ({ page }) => {
@ -73,21 +85,28 @@ test.describe("Counter - Authenticated", () => {
const initialValue = Number(await page.locator("h1").textContent()); const initialValue = Number(await page.locator("h1").textContent());
await page.click("text=Increment"); await page.click("text=Increment");
await page.click("text=Increment"); await page.click("text=Increment");
const afterFirst = initialValue + 2; // Wait for the counter to update (value should increase by 2 from what this user started with)
await expect(page.locator("h1")).toHaveText(String(afterFirst)); await expect(page.locator("h1")).not.toHaveText(String(initialValue));
const afterFirstUser = Number(await page.locator("h1").textContent());
expect(afterFirstUser).toBeGreaterThan(initialValue);
// Second user in new context sees the same value // Second user in new context sees the current value
const page2 = await browser.newPage(); const page2 = await browser.newPage();
await authenticate(page2); await authenticate(page2);
await expect(page2.locator("h1")).toHaveText(String(afterFirst)); await expect(page2.locator("h1")).not.toHaveText("...");
const page2InitialValue = Number(await page2.locator("h1").textContent());
// The value should be at least what user 1 saw (might be higher due to parallel tests)
expect(page2InitialValue).toBeGreaterThanOrEqual(afterFirstUser);
// Second user increments // Second user increments
await page2.click("text=Increment"); await page2.click("text=Increment");
await expect(page2.locator("h1")).toHaveText(String(afterFirst + 1)); await expect(page2.locator("h1")).toHaveText(String(page2InitialValue + 1));
// First user reloads and sees the increment // First user reloads and sees the increment (value should be >= what page2 has)
await page.reload(); await page.reload();
await expect(page.locator("h1")).toHaveText(String(afterFirst + 1)); await expect(page.locator("h1")).not.toHaveText("...");
const page1Reloaded = Number(await page.locator("h1").textContent());
expect(page1Reloaded).toBeGreaterThanOrEqual(page2InitialValue + 1);
await page2.close(); await page2.close();
}); });
@ -129,8 +148,9 @@ test.describe("Counter - Session Integration", () => {
await page.click('button[type="submit"]'); await page.click('button[type="submit"]');
await expect(page).toHaveURL("/"); await expect(page).toHaveURL("/");
// Counter should be visible // Counter should be visible - wait for it to load (not showing "...")
await expect(page.locator("h1")).toBeVisible(); await expect(page.locator("h1")).toBeVisible();
await expect(page.locator("h1")).not.toHaveText("...");
const text = await page.locator("h1").textContent(); const text = await page.locator("h1").textContent();
expect(text).toMatch(/^\d+$/); expect(text).toMatch(/^\d+$/);
}); });

View file

@ -0,0 +1,324 @@
import { test, expect, Page, APIRequestContext } from "@playwright/test";
/**
* Permission-based E2E tests
*
* These tests verify that:
* 1. Regular users can only access Counter and Sum pages
* 2. Admin users can only access the Audit page
* 3. Users are properly redirected based on their permissions
* 4. API calls respect permission boundaries
*/
const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
// Test credentials - must match what's seeded in the database via seed.py
// These come from environment variables DEV_USER_EMAIL/PASSWORD and DEV_ADMIN_EMAIL/PASSWORD
const REGULAR_USER = {
email: process.env.DEV_USER_EMAIL || "user@example.com",
password: process.env.DEV_USER_PASSWORD || "user123",
};
const ADMIN_USER = {
email: process.env.DEV_ADMIN_EMAIL || "admin@example.com",
password: process.env.DEV_ADMIN_PASSWORD || "admin123",
};
// Helper to clear auth cookies
async function clearAuth(page: Page) {
await page.context().clearCookies();
}
// Helper to create a user with specific role via API
async function createUserWithRole(
request: APIRequestContext,
email: string,
password: string,
roleName: string
): Promise<void> {
// This requires direct DB access or a test endpoint
// For now, we'll use the seeded users from conftest
}
// Helper to login a user
async function loginUser(page: Page, email: string, password: string) {
await page.goto("/login");
await page.fill('input[type="email"]', email);
await page.fill('input[type="password"]', password);
await page.click('button[type="submit"]');
// Wait for navigation away from login page
await page.waitForURL((url) => !url.pathname.includes("/login"), { timeout: 10000 });
}
// Setup: Users are pre-seeded via seed.py before e2e tests run
// The seed script creates:
// - A regular user (DEV_USER_EMAIL/PASSWORD) with "regular" role
// - An admin user (DEV_ADMIN_EMAIL/PASSWORD) with "admin" role
test.beforeAll(async () => {
// No need to create users - they are seeded by scripts/e2e.sh
});
test.describe("Regular User Access", () => {
test.beforeEach(async ({ page }) => {
await clearAuth(page);
await loginUser(page, REGULAR_USER.email, REGULAR_USER.password);
});
test("can access counter page", async ({ page }) => {
await page.goto("/");
// Should stay on counter page
await expect(page).toHaveURL("/");
// Should see counter UI
await expect(page.getByText("Current Count")).toBeVisible();
await expect(page.getByRole("button", { name: /increment/i })).toBeVisible();
});
test("can access sum page", async ({ page }) => {
await page.goto("/sum");
// Should stay on sum page
await expect(page).toHaveURL("/sum");
// Should see sum UI
await expect(page.getByText("Sum Calculator")).toBeVisible();
});
test("cannot access audit page - redirected to counter", async ({ page }) => {
await page.goto("/audit");
// Should be redirected to counter page (home)
await expect(page).toHaveURL("/");
});
test("navigation only shows Counter and Sum", async ({ page }) => {
await page.goto("/");
// Should see Counter and Sum in nav
await expect(page.getByText("Counter")).toBeVisible();
await expect(page.getByText("Sum")).toBeVisible();
// Should NOT see Audit in nav (for regular users)
const auditLinks = page.locator('a[href="/audit"]');
await expect(auditLinks).toHaveCount(0);
});
test("can navigate between Counter and Sum", async ({ page }) => {
await page.goto("/");
// Go to Sum
await page.click('a[href="/sum"]');
await expect(page).toHaveURL("/sum");
// Go back to Counter
await page.click('a[href="/"]');
await expect(page).toHaveURL("/");
});
test("can use counter functionality", async ({ page }) => {
await page.goto("/");
// Get initial count (might be any number)
const countElement = page.locator("h1").first();
await expect(countElement).toBeVisible();
// Click increment
await page.click('button:has-text("Increment")');
// Wait for update
await page.waitForTimeout(500);
// Counter should have updated (we just verify no error occurred)
await expect(countElement).toBeVisible();
});
test("can use sum functionality", async ({ page }) => {
await page.goto("/sum");
// Fill in numbers
await page.fill('input[aria-label="First number"]', "5");
await page.fill('input[aria-label="Second number"]', "3");
// Calculate
await page.click('button:has-text("Calculate")');
// Should show result
await expect(page.getByText("8")).toBeVisible();
});
});
test.describe("Admin User Access", () => {
// Skip these tests if admin user isn't set up
// In real scenario, you'd create admin user in beforeAll
test.skip(
!process.env.DEV_ADMIN_EMAIL,
"Admin tests require DEV_ADMIN_EMAIL and DEV_ADMIN_PASSWORD env vars"
);
const adminEmail = process.env.DEV_ADMIN_EMAIL || ADMIN_USER.email;
const adminPassword = process.env.DEV_ADMIN_PASSWORD || ADMIN_USER.password;
test.beforeEach(async ({ page }) => {
await clearAuth(page);
await loginUser(page, adminEmail, adminPassword);
});
test("redirected from counter page to audit", async ({ page }) => {
await page.goto("/");
// Should be redirected to audit page
await expect(page).toHaveURL("/audit");
});
test("redirected from sum page to audit", async ({ page }) => {
await page.goto("/sum");
// Should be redirected to audit page
await expect(page).toHaveURL("/audit");
});
test("can access audit page", async ({ page }) => {
await page.goto("/audit");
// Should stay on audit page
await expect(page).toHaveURL("/audit");
// Should see audit tables
await expect(page.getByText("Counter Activity")).toBeVisible();
await expect(page.getByText("Sum Activity")).toBeVisible();
});
test("navigation only shows Audit", async ({ page }) => {
await page.goto("/audit");
// Should see Audit as current
await expect(page.getByText("Audit")).toBeVisible();
// Should NOT see Counter or Sum links (for admin users)
const counterLinks = page.locator('a[href="/"]');
const sumLinks = page.locator('a[href="/sum"]');
await expect(counterLinks).toHaveCount(0);
await expect(sumLinks).toHaveCount(0);
});
test("audit page shows records", async ({ page }) => {
await page.goto("/audit");
// Should see the tables
await expect(page.getByRole("table")).toHaveCount(2);
// Should see column headers (use first() since there are two tables with same headers)
await expect(page.getByRole("columnheader", { name: "User" }).first()).toBeVisible();
await expect(page.getByRole("columnheader", { name: "Date" }).first()).toBeVisible();
});
});
test.describe("Unauthenticated Access", () => {
test.beforeEach(async ({ page }) => {
await clearAuth(page);
});
test("counter page redirects to login", async ({ page }) => {
await page.goto("/");
await expect(page).toHaveURL("/login");
});
test("sum page redirects to login", async ({ page }) => {
await page.goto("/sum");
await expect(page).toHaveURL("/login");
});
test("audit page redirects to login", async ({ page }) => {
await page.goto("/audit");
await expect(page).toHaveURL("/login");
});
});
test.describe("Permission Boundary via API", () => {
test("regular user API call to audit returns 403", async ({ page, request }) => {
// Login as regular user
await clearAuth(page);
await loginUser(page, REGULAR_USER.email, REGULAR_USER.password);
// Get cookies
const cookies = await page.context().cookies();
const authCookie = cookies.find(c => c.name === "auth_token");
if (authCookie) {
// Try to call audit API directly
const response = await request.get(`${API_URL}/api/audit/counter`, {
headers: {
Cookie: `auth_token=${authCookie.value}`,
},
});
expect(response.status()).toBe(403);
}
});
test("admin user API call to counter returns 403", async ({ page, request }) => {
const adminEmail = process.env.DEV_ADMIN_EMAIL;
const adminPassword = process.env.DEV_ADMIN_PASSWORD;
if (!adminEmail || !adminPassword) {
test.skip();
return;
}
// Login as admin
await clearAuth(page);
await loginUser(page, adminEmail, adminPassword);
// Get cookies
const cookies = await page.context().cookies();
const authCookie = cookies.find(c => c.name === "auth_token");
if (authCookie) {
// Try to call counter API directly
const response = await request.get(`${API_URL}/api/counter`, {
headers: {
Cookie: `auth_token=${authCookie.value}`,
},
});
expect(response.status()).toBe(403);
}
});
});
test.describe("Session and Logout", () => {
test("logout clears permissions - cannot access protected pages", async ({ page }) => {
// Login
await clearAuth(page);
await loginUser(page, REGULAR_USER.email, REGULAR_USER.password);
await expect(page).toHaveURL("/");
// Logout
await page.click("text=Sign out");
await expect(page).toHaveURL("/login");
// Try to access counter
await page.goto("/");
await expect(page).toHaveURL("/login");
});
test("cannot access pages with tampered cookie", async ({ page, context }) => {
// Set a fake auth cookie
await context.addCookies([
{
name: "auth_token",
value: "fake-token-that-should-not-work",
domain: "localhost",
path: "/",
},
]);
// Try to access protected page
await page.goto("/");
// Should be redirected to login
await expect(page).toHaveURL("/login");
});
});

View file

@ -3,6 +3,13 @@ set -e
cd "$(dirname "$0")/.." cd "$(dirname "$0")/.."
# Load environment variables if .env exists
if [ -f .env ]; then
set -a
source .env
set +a
fi
# Kill any existing backend # Kill any existing backend
pkill -f "uvicorn main:app" 2>/dev/null || true pkill -f "uvicorn main:app" 2>/dev/null || true
sleep 1 sleep 1
@ -10,6 +17,15 @@ sleep 1
# Start db # Start db
docker compose up -d db docker compose up -d db
# Wait for db to be ready
sleep 2
# Seed the database with roles and test users
cd backend
echo "Seeding database..."
uv run python seed.py
cd ..
# Start backend (SECRET_KEY should be set via .envrc or environment) # Start backend (SECRET_KEY should be set via .envrc or environment)
cd backend cd backend
uv run uvicorn main:app --port 8000 & uv run uvicorn main:app --port 8000 &