finish branch

This commit is contained in:
counterweight 2025-12-19 00:12:43 +01:00
parent 66bc4c5a45
commit 40ca82bb45
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
11 changed files with 139 additions and 128 deletions

View file

@ -1,4 +1,4 @@
.PHONY: install-backend install-frontend install backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e .PHONY: install-backend install-frontend install backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck
-include .env -include .env
export export
@ -52,3 +52,6 @@ test-e2e:
./scripts/e2e.sh ./scripts/e2e.sh
test: test-backend test-frontend test-e2e test: test-backend test-frontend test-e2e
typecheck:
cd backend && uv run mypy .

View file

@ -1,6 +1,5 @@
import os import os
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import List, Optional
import bcrypt import bcrypt
from fastapi import Depends, HTTPException, Request, status from fastapi import Depends, HTTPException, Request, status
@ -30,8 +29,8 @@ UserLogin = UserCredentials
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: int id: int
email: str email: str
roles: List[str] roles: list[str]
permissions: List[str] permissions: list[str]
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
@ -54,19 +53,20 @@ def get_password_hash(password: str) -> str:
).decode("utf-8") ).decode("utf-8")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
to_encode = data.copy() to_encode: dict[str, str | datetime] = dict(data)
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({"exp": expire}) to_encode["exp"] = expire
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]: async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]: async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
user = await get_user_by_email(db, email) user = await get_user_by_email(db, email)
if not user or not verify_password(password, user.hashed_password): if not user or not verify_password(password, user.hashed_password):
return None return None

View file

@ -1,6 +1,6 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from typing import List from typing import Any, Callable, Generic, TypeVar
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -9,7 +9,43 @@ from sqlalchemy import select, func, desc
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, get_db, Base from database import engine, get_db, Base
from models import Counter, User, SumRecord, CounterRecord, Permission, Role from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR
R = TypeVar("R", bound=BaseModel)
async def paginate_with_user_email(
db: AsyncSession,
model: type[SumRecord] | type[CounterRecord],
page: int,
per_page: int,
row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
"""
Generic pagination helper for audit records that need user email.
Returns: (records, total, total_pages)
"""
# Get total count
count_result = await db.execute(select(func.count(model.id)))
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(model, User.email)
.join(User, model.user_id == User.id)
.order_by(desc(model.created_at))
.offset(offset)
.limit(per_page)
)
result = await db.execute(query)
rows = result.all()
records: list[R] = [row_mapper(record, email) for record, email in rows]
return records, total, total_pages
from auth import ( from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES, ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME, COOKIE_NAME,
@ -57,7 +93,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
async def get_default_role(db: AsyncSession) -> Role | None: async def get_default_role(db: AsyncSession) -> Role | None:
"""Get the default 'regular' role for new users.""" """Get the default 'regular' role for new users."""
result = await db.execute(select(Role).where(Role.name == "regular")) result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
return result.scalar_one_or_none() return result.scalar_one_or_none()
@ -214,20 +250,30 @@ class SumRecordResponse(BaseModel):
created_at: datetime created_at: datetime
class PaginatedCounterRecords(BaseModel): RecordT = TypeVar("RecordT", bound=BaseModel)
records: List[CounterRecordResponse]
class PaginatedResponse(BaseModel, Generic[RecordT]):
"""Generic paginated response wrapper."""
records: list[RecordT]
total: int total: int
page: int page: int
per_page: int per_page: int
total_pages: int total_pages: int
class PaginatedSumRecords(BaseModel): PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
records: List[SumRecordResponse] PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
total: int
page: int
per_page: int def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
total_pages: int return CounterRecordResponse(
id=record.id,
user_email=email,
value_before=record.value_before,
value_after=record.value_after,
created_at=record.created_at,
)
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords) @app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
@ -237,34 +283,9 @@ async def get_counter_records(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)), _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
): ):
# Get total count records, total, total_pages = await paginate_with_user_email(
count_result = await db.execute(select(func.count(CounterRecord.id))) db, CounterRecord, page, per_page, _map_counter_record
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(CounterRecord, User.email)
.join(User, CounterRecord.user_id == User.id)
.order_by(desc(CounterRecord.created_at))
.offset(offset)
.limit(per_page)
) )
result = await db.execute(query)
rows = result.all()
records = [
CounterRecordResponse(
id=record.id,
user_email=email,
value_before=record.value_before,
value_after=record.value_after,
created_at=record.created_at,
)
for record, email in rows
]
return PaginatedCounterRecords( return PaginatedCounterRecords(
records=records, records=records,
total=total, total=total,
@ -274,6 +295,17 @@ async def get_counter_records(
) )
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
return SumRecordResponse(
id=record.id,
user_email=email,
a=record.a,
b=record.b,
result=record.result,
created_at=record.created_at,
)
@app.get("/api/audit/sum", response_model=PaginatedSumRecords) @app.get("/api/audit/sum", response_model=PaginatedSumRecords)
async def get_sum_records( async def get_sum_records(
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
@ -281,35 +313,9 @@ async def get_sum_records(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)), _current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
): ):
# Get total count records, total, total_pages = await paginate_with_user_email(
count_result = await db.execute(select(func.count(SumRecord.id))) db, SumRecord, page, per_page, _map_sum_record
total = count_result.scalar() or 0
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
# Get paginated records with user email
offset = (page - 1) * per_page
query = (
select(SumRecord, User.email)
.join(User, SumRecord.user_id == User.id)
.order_by(desc(SumRecord.created_at))
.offset(offset)
.limit(per_page)
) )
result = await db.execute(query)
rows = result.all()
records = [
SumRecordResponse(
id=record.id,
user_email=email,
a=record.a,
b=record.b,
result=record.result,
created_at=record.created_at,
)
for record, email in rows
]
return PaginatedSumRecords( return PaginatedSumRecords(
records=records, records=records,
total=total, total=total,

View file

@ -1,12 +1,17 @@
from datetime import datetime, UTC from datetime import datetime, UTC
from enum import Enum as PyEnum from enum import Enum as PyEnum
from typing import List, Set from typing import TypedDict
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import Base from database import Base
class RoleConfig(TypedDict):
description: str
permissions: list["Permission"]
class Permission(str, PyEnum): class Permission(str, PyEnum):
"""All available permissions in the system.""" """All available permissions in the system."""
# Counter permissions # Counter permissions
@ -20,15 +25,19 @@ class Permission(str, PyEnum):
VIEW_AUDIT = "view_audit" VIEW_AUDIT = "view_audit"
# Role name constants
ROLE_ADMIN = "admin"
ROLE_REGULAR = "regular"
# Role definitions with their permissions # Role definitions with their permissions
ROLE_DEFINITIONS = { ROLE_DEFINITIONS: dict[str, RoleConfig] = {
"admin": { ROLE_ADMIN: {
"description": "Administrator with audit access", "description": "Administrator with audit access",
"permissions": [ "permissions": [
Permission.VIEW_AUDIT, Permission.VIEW_AUDIT,
], ],
}, },
"regular": { ROLE_REGULAR: {
"description": "Regular user with counter and sum access", "description": "Regular user with counter and sum access",
"permissions": [ "permissions": [
Permission.VIEW_COUNTER, Permission.VIEW_COUNTER,
@ -65,24 +74,20 @@ class Role(Base):
description: Mapped[str] = mapped_column(String(255), nullable=True) description: Mapped[str] = mapped_column(String(255), nullable=True)
# Relationship to users # Relationship to users
users: Mapped[List["User"]] = relationship( users: Mapped[list["User"]] = relationship(
"User", "User",
secondary=user_roles, secondary=user_roles,
back_populates="roles", back_populates="roles",
) )
async def get_permissions(self, db: AsyncSession) -> Set[Permission]: async def get_permissions(self, db: AsyncSession) -> set[Permission]:
"""Get all permissions for this role.""" """Get all permissions for this role."""
result = await db.execute( result = await db.execute(
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id) select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
) )
return {row[0] for row in result.fetchall()} return {row[0] for row in result.fetchall()}
async def add_permission(self, db: AsyncSession, permission: Permission) -> None: async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
"""Add a permission to this role."""
await db.execute(role_permissions.insert().values(role_id=self.id, permission=permission))
async def set_permissions(self, db: AsyncSession, permissions: List[Permission]) -> None:
"""Set all permissions for this role (replaces existing).""" """Set all permissions for this role (replaces existing)."""
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id)) await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
for perm in permissions: for perm in permissions:
@ -97,20 +102,21 @@ class User(Base):
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False) hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
# Relationship to roles # Relationship to roles
roles: Mapped[List[Role]] = relationship( roles: Mapped[list[Role]] = relationship(
"Role", "Role",
secondary=user_roles, secondary=user_roles,
back_populates="users", back_populates="users",
lazy="selectin", lazy="selectin",
) )
async def get_permissions(self, db: AsyncSession) -> Set[Permission]: async def get_permissions(self, db: AsyncSession) -> set[Permission]:
"""Get all permissions from all roles.""" """Get all permissions from all roles in a single query."""
permissions: Set[Permission] = set() result = await db.execute(
for role in self.roles: select(role_permissions.c.permission)
role_perms = await role.get_permissions(db) .join(user_roles, role_permissions.c.role_id == user_roles.c.role_id)
permissions.update(role_perms) .where(user_roles.c.user_id == self.id)
return permissions )
return {row[0] for row in result.fetchall()}
async def has_permission(self, db: AsyncSession, permission: Permission) -> bool: async def has_permission(self, db: AsyncSession, permission: Permission) -> bool:
"""Check if user has a specific permission through any of their roles.""" """Check if user has a specific permission through any of their roles."""
@ -118,7 +124,7 @@ class User(Base):
return permission in permissions return permission in permissions
@property @property
def role_names(self) -> List[str]: def role_names(self) -> list[str]:
"""Get list of role names for API responses.""" """Get list of role names for API responses."""
return [role.name for role in self.roles] return [role.name for role in self.roles]

View file

@ -18,5 +18,14 @@ dev = [
"pytest-asyncio>=0.25.0", "pytest-asyncio>=0.25.0",
"httpx>=0.28.1", "httpx>=0.28.1",
"aiosqlite>=0.20.0", "aiosqlite>=0.20.0",
"mypy>=1.13.0",
] ]
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_ignores = true
check_untyped_defs = true
ignore_missing_imports = true
exclude = ["tests/"]

View file

@ -1,11 +1,11 @@
"""Seed the database with roles, permissions, and dev users.""" """Seed the database with roles, permissions, and dev users."""
import asyncio import asyncio
import os import os
from typing import List
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, async_session, Base from database import engine, async_session, Base
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash from auth import get_password_hash
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"] DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
@ -14,7 +14,7 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"] DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
async def upsert_role(db, name: str, description: str, permissions: List[Permission]) -> Role: async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
"""Create or update a role with the given permissions.""" """Create or update a role with the given permissions."""
result = await db.execute(select(Role).where(Role.name == name)) result = await db.execute(select(Role).where(Role.name == name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
@ -35,7 +35,7 @@ async def upsert_role(db, name: str, description: str, permissions: List[Permiss
return role return role
async def upsert_user(db, email: str, password: str, role_names: List[str]) -> User: async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
"""Create or update a user with the given credentials and roles.""" """Create or update a user with the given credentials and roles."""
result = await db.execute(select(User).where(User.email == email)) result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -45,12 +45,13 @@ async def upsert_user(db, email: str, password: str, role_names: List[str]) -> U
for role_name in role_names: for role_name in role_names:
result = await db.execute(select(Role).where(Role.name == role_name)) result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if role: if not role:
roles.append(role) raise ValueError(f"Role '{role_name}' not found")
roles.append(role)
if user: if user:
user.hashed_password = get_password_hash(password) user.hashed_password = get_password_hash(password)
user.roles = roles user.roles = roles # type: ignore[assignment]
print(f"Updated user: {email} with roles: {role_names}") print(f"Updated user: {email} with roles: {role_names}")
else: else:
user = User( user = User(
@ -64,7 +65,7 @@ async def upsert_user(db, email: str, password: str, role_names: List[str]) -> U
return user return user
async def seed(): async def seed() -> None:
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
@ -80,10 +81,10 @@ async def seed():
print("\n=== Seeding Users ===") print("\n=== Seeding Users ===")
# Create regular dev user # Create regular dev user
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, ["regular"]) await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR])
# Create admin dev user # Create admin dev user
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, ["admin"]) await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN])
await db.commit() await db.commit()
print("\n=== Seeding Complete ===\n") print("\n=== Seeding Complete ===\n")

View file

@ -1,6 +1,5 @@
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import List
# Set required env vars before importing app # Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only") os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
@ -12,8 +11,9 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, Asyn
from database import Base, get_db from database import Base, get_db
from main import app from main import app
from models import User, Role, Permission, ROLE_DEFINITIONS from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
from auth import get_password_hash from auth import get_password_hash
from tests.helpers import unique_email
TEST_DATABASE_URL = os.getenv( TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL", "TEST_DATABASE_URL",
@ -82,7 +82,7 @@ async def create_user_with_roles(
db: AsyncSession, db: AsyncSession,
email: str, email: str,
password: str, password: str,
role_names: List[str], role_names: list[str],
) -> User: ) -> User:
"""Create a user with specified roles.""" """Create a user with specified roles."""
# Get roles # Get roles
@ -90,8 +90,9 @@ async def create_user_with_roles(
for role_name in role_names: for role_name in role_names:
result = await db.execute(select(Role).where(Role.name == role_name)) result = await db.execute(select(Role).where(Role.name == role_name))
role = result.scalar_one_or_none() role = result.scalar_one_or_none()
if role: if not role:
roles.append(role) raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
roles.append(role)
user = User( user = User(
email=email, email=email,
@ -144,13 +145,11 @@ async def client(client_factory):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def regular_user(client_factory): async def regular_user(client_factory):
"""Create a regular user and return their credentials and cookies.""" """Create a regular user and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("regular") email = unique_email("regular")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, ["regular"]) await create_user_with_roles(db, email, password, [ROLE_REGULAR])
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
@ -169,13 +168,11 @@ async def regular_user(client_factory):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def admin_user(client_factory): async def admin_user(client_factory):
"""Create an admin user and return their credentials and cookies.""" """Create an admin user and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("admin") email = unique_email("admin")
password = "password123" password = "password123"
async with client_factory.get_db_session() as db: async with client_factory.get_db_session() as db:
await create_user_with_roles(db, email, password, ["admin"]) await create_user_with_roles(db, email, password, [ROLE_ADMIN])
# Login to get cookies # Login to get cookies
response = await client_factory.post( response = await client_factory.post(
@ -194,8 +191,6 @@ async def admin_user(client_factory):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def user_no_roles(client_factory): async def user_no_roles(client_factory):
"""Create a user with NO roles and return their credentials and cookies.""" """Create a user with NO roles and return their credentials and cookies."""
from tests.helpers import unique_email
email = unique_email("noroles") email = unique_email("noroles")
password = "password123" password = "password123"

View file

@ -393,4 +393,3 @@ const pageStyles: Record<string, React.CSSProperties> = {
}; };
const styles = { ...sharedStyles, ...pageStyles }; const styles = { ...sharedStyles, ...pageStyles };

View file

@ -28,7 +28,6 @@ interface AuthContextType {
register: (email: string, password: string) => Promise<void>; register: (email: string, password: string) => Promise<void>;
logout: () => Promise<void>; logout: () => Promise<void>;
hasPermission: (permission: PermissionType) => boolean; hasPermission: (permission: PermissionType) => boolean;
hasAnyPermission: (...permissions: PermissionType[]) => boolean;
hasRole: (role: string) => boolean; hasRole: (role: string) => boolean;
} }
@ -104,10 +103,6 @@ export function AuthProvider({ children }: { children: ReactNode }) {
return user?.permissions.includes(permission) ?? false; return user?.permissions.includes(permission) ?? false;
}, [user]); }, [user]);
const hasAnyPermission = useCallback((...permissions: PermissionType[]): boolean => {
return permissions.some((p) => user?.permissions.includes(p) ?? false);
}, [user]);
const hasRole = useCallback((role: string): boolean => { const hasRole = useCallback((role: string): boolean => {
return user?.roles.includes(role) ?? false; return user?.roles.includes(role) ?? false;
}, [user]); }, [user]);
@ -121,7 +116,6 @@ export function AuthProvider({ children }: { children: ReactNode }) {
register, register,
logout, logout,
hasPermission, hasPermission,
hasAnyPermission,
hasRole, hasRole,
}} }}
> >

View file

@ -79,4 +79,3 @@ export const sharedStyles: Record<string, React.CSSProperties> = {
padding: "2rem", padding: "2rem",
}, },
}; };

View file

@ -288,4 +288,3 @@ const pageStyles: Record<string, React.CSSProperties> = {
}; };
const styles = { ...sharedStyles, ...pageStyles }; const styles = { ...sharedStyles, ...pageStyles };