finish branch
This commit is contained in:
parent
66bc4c5a45
commit
40ca82bb45
11 changed files with 139 additions and 128 deletions
5
Makefile
5
Makefile
|
|
@ -1,4 +1,4 @@
|
||||||
.PHONY: install-backend install-frontend install backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e
|
.PHONY: install-backend install-frontend install backend frontend db db-stop db-ready db-seed dev test test-backend test-frontend test-e2e typecheck
|
||||||
|
|
||||||
-include .env
|
-include .env
|
||||||
export
|
export
|
||||||
|
|
@ -52,3 +52,6 @@ test-e2e:
|
||||||
./scripts/e2e.sh
|
./scripts/e2e.sh
|
||||||
|
|
||||||
test: test-backend test-frontend test-e2e
|
test: test-backend test-frontend test-e2e
|
||||||
|
|
||||||
|
typecheck:
|
||||||
|
cd backend && uv run mypy .
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import Depends, HTTPException, Request, status
|
from fastapi import Depends, HTTPException, Request, status
|
||||||
|
|
@ -30,8 +29,8 @@ UserLogin = UserCredentials
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
email: str
|
email: str
|
||||||
roles: List[str]
|
roles: list[str]
|
||||||
permissions: List[str]
|
permissions: list[str]
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
class TokenResponse(BaseModel):
|
||||||
|
|
@ -54,19 +53,20 @@ def get_password_hash(password: str) -> str:
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
def create_access_token(data: dict[str, str], expires_delta: timedelta | None = None) -> str:
|
||||||
to_encode = data.copy()
|
to_encode: dict[str, str | datetime] = dict(data)
|
||||||
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||||
to_encode.update({"exp": expire})
|
to_encode["exp"] = expire
|
||||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
return encoded
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
|
async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
|
||||||
user = await get_user_by_email(db, email)
|
user = await get_user_by_email(db, email)
|
||||||
if not user or not verify_password(password, user.hashed_password):
|
if not user or not verify_password(password, user.hashed_password):
|
||||||
return None
|
return None
|
||||||
|
|
|
||||||
138
backend/main.py
138
backend/main.py
|
|
@ -1,6 +1,6 @@
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import Any, Callable, Generic, TypeVar
|
||||||
|
|
||||||
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
|
from fastapi import FastAPI, Depends, HTTPException, Response, status, Query
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
@ -9,7 +9,43 @@ from sqlalchemy import select, func, desc
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database import engine, get_db, Base
|
from database import engine, get_db, Base
|
||||||
from models import Counter, User, SumRecord, CounterRecord, Permission, Role
|
from models import Counter, User, SumRecord, CounterRecord, Permission, Role, ROLE_REGULAR
|
||||||
|
|
||||||
|
|
||||||
|
R = TypeVar("R", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
async def paginate_with_user_email(
|
||||||
|
db: AsyncSession,
|
||||||
|
model: type[SumRecord] | type[CounterRecord],
|
||||||
|
page: int,
|
||||||
|
per_page: int,
|
||||||
|
row_mapper: Callable[..., R],
|
||||||
|
) -> tuple[list[R], int, int]:
|
||||||
|
"""
|
||||||
|
Generic pagination helper for audit records that need user email.
|
||||||
|
|
||||||
|
Returns: (records, total, total_pages)
|
||||||
|
"""
|
||||||
|
# Get total count
|
||||||
|
count_result = await db.execute(select(func.count(model.id)))
|
||||||
|
total = count_result.scalar() or 0
|
||||||
|
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
||||||
|
|
||||||
|
# Get paginated records with user email
|
||||||
|
offset = (page - 1) * per_page
|
||||||
|
query = (
|
||||||
|
select(model, User.email)
|
||||||
|
.join(User, model.user_id == User.id)
|
||||||
|
.order_by(desc(model.created_at))
|
||||||
|
.offset(offset)
|
||||||
|
.limit(per_page)
|
||||||
|
)
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.all()
|
||||||
|
|
||||||
|
records: list[R] = [row_mapper(record, email) for record, email in rows]
|
||||||
|
return records, total, total_pages
|
||||||
from auth import (
|
from auth import (
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
COOKIE_NAME,
|
COOKIE_NAME,
|
||||||
|
|
@ -57,7 +93,7 @@ def set_auth_cookie(response: Response, token: str) -> None:
|
||||||
|
|
||||||
async def get_default_role(db: AsyncSession) -> Role | None:
|
async def get_default_role(db: AsyncSession) -> Role | None:
|
||||||
"""Get the default 'regular' role for new users."""
|
"""Get the default 'regular' role for new users."""
|
||||||
result = await db.execute(select(Role).where(Role.name == "regular"))
|
result = await db.execute(select(Role).where(Role.name == ROLE_REGULAR))
|
||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -214,20 +250,30 @@ class SumRecordResponse(BaseModel):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class PaginatedCounterRecords(BaseModel):
|
RecordT = TypeVar("RecordT", bound=BaseModel)
|
||||||
records: List[CounterRecordResponse]
|
|
||||||
|
|
||||||
|
class PaginatedResponse(BaseModel, Generic[RecordT]):
|
||||||
|
"""Generic paginated response wrapper."""
|
||||||
|
records: list[RecordT]
|
||||||
total: int
|
total: int
|
||||||
page: int
|
page: int
|
||||||
per_page: int
|
per_page: int
|
||||||
total_pages: int
|
total_pages: int
|
||||||
|
|
||||||
|
|
||||||
class PaginatedSumRecords(BaseModel):
|
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
|
||||||
records: List[SumRecordResponse]
|
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
||||||
total: int
|
|
||||||
page: int
|
|
||||||
per_page: int
|
def _map_counter_record(record: CounterRecord, email: str) -> CounterRecordResponse:
|
||||||
total_pages: int
|
return CounterRecordResponse(
|
||||||
|
id=record.id,
|
||||||
|
user_email=email,
|
||||||
|
value_before=record.value_before,
|
||||||
|
value_after=record.value_after,
|
||||||
|
created_at=record.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
|
@app.get("/api/audit/counter", response_model=PaginatedCounterRecords)
|
||||||
|
|
@ -237,34 +283,9 @@ async def get_counter_records(
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||||
):
|
):
|
||||||
# Get total count
|
records, total, total_pages = await paginate_with_user_email(
|
||||||
count_result = await db.execute(select(func.count(CounterRecord.id)))
|
db, CounterRecord, page, per_page, _map_counter_record
|
||||||
total = count_result.scalar() or 0
|
|
||||||
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
|
||||||
|
|
||||||
# Get paginated records with user email
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
query = (
|
|
||||||
select(CounterRecord, User.email)
|
|
||||||
.join(User, CounterRecord.user_id == User.id)
|
|
||||||
.order_by(desc(CounterRecord.created_at))
|
|
||||||
.offset(offset)
|
|
||||||
.limit(per_page)
|
|
||||||
)
|
)
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
records = [
|
|
||||||
CounterRecordResponse(
|
|
||||||
id=record.id,
|
|
||||||
user_email=email,
|
|
||||||
value_before=record.value_before,
|
|
||||||
value_after=record.value_after,
|
|
||||||
created_at=record.created_at,
|
|
||||||
)
|
|
||||||
for record, email in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return PaginatedCounterRecords(
|
return PaginatedCounterRecords(
|
||||||
records=records,
|
records=records,
|
||||||
total=total,
|
total=total,
|
||||||
|
|
@ -274,6 +295,17 @@ async def get_counter_records(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _map_sum_record(record: SumRecord, email: str) -> SumRecordResponse:
|
||||||
|
return SumRecordResponse(
|
||||||
|
id=record.id,
|
||||||
|
user_email=email,
|
||||||
|
a=record.a,
|
||||||
|
b=record.b,
|
||||||
|
result=record.result,
|
||||||
|
created_at=record.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
|
@app.get("/api/audit/sum", response_model=PaginatedSumRecords)
|
||||||
async def get_sum_records(
|
async def get_sum_records(
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
|
|
@ -281,35 +313,9 @@ async def get_sum_records(
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||||
):
|
):
|
||||||
# Get total count
|
records, total, total_pages = await paginate_with_user_email(
|
||||||
count_result = await db.execute(select(func.count(SumRecord.id)))
|
db, SumRecord, page, per_page, _map_sum_record
|
||||||
total = count_result.scalar() or 0
|
|
||||||
total_pages = (total + per_page - 1) // per_page if total > 0 else 1
|
|
||||||
|
|
||||||
# Get paginated records with user email
|
|
||||||
offset = (page - 1) * per_page
|
|
||||||
query = (
|
|
||||||
select(SumRecord, User.email)
|
|
||||||
.join(User, SumRecord.user_id == User.id)
|
|
||||||
.order_by(desc(SumRecord.created_at))
|
|
||||||
.offset(offset)
|
|
||||||
.limit(per_page)
|
|
||||||
)
|
)
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.all()
|
|
||||||
|
|
||||||
records = [
|
|
||||||
SumRecordResponse(
|
|
||||||
id=record.id,
|
|
||||||
user_email=email,
|
|
||||||
a=record.a,
|
|
||||||
b=record.b,
|
|
||||||
result=record.result,
|
|
||||||
created_at=record.created_at,
|
|
||||||
)
|
|
||||||
for record, email in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
return PaginatedSumRecords(
|
return PaginatedSumRecords(
|
||||||
records=records,
|
records=records,
|
||||||
total=total,
|
total=total,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,17 @@
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
from typing import List, Set
|
from typing import TypedDict
|
||||||
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select
|
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from database import Base
|
from database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class RoleConfig(TypedDict):
|
||||||
|
description: str
|
||||||
|
permissions: list["Permission"]
|
||||||
|
|
||||||
|
|
||||||
class Permission(str, PyEnum):
|
class Permission(str, PyEnum):
|
||||||
"""All available permissions in the system."""
|
"""All available permissions in the system."""
|
||||||
# Counter permissions
|
# Counter permissions
|
||||||
|
|
@ -20,15 +25,19 @@ class Permission(str, PyEnum):
|
||||||
VIEW_AUDIT = "view_audit"
|
VIEW_AUDIT = "view_audit"
|
||||||
|
|
||||||
|
|
||||||
|
# Role name constants
|
||||||
|
ROLE_ADMIN = "admin"
|
||||||
|
ROLE_REGULAR = "regular"
|
||||||
|
|
||||||
# Role definitions with their permissions
|
# Role definitions with their permissions
|
||||||
ROLE_DEFINITIONS = {
|
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
||||||
"admin": {
|
ROLE_ADMIN: {
|
||||||
"description": "Administrator with audit access",
|
"description": "Administrator with audit access",
|
||||||
"permissions": [
|
"permissions": [
|
||||||
Permission.VIEW_AUDIT,
|
Permission.VIEW_AUDIT,
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"regular": {
|
ROLE_REGULAR: {
|
||||||
"description": "Regular user with counter and sum access",
|
"description": "Regular user with counter and sum access",
|
||||||
"permissions": [
|
"permissions": [
|
||||||
Permission.VIEW_COUNTER,
|
Permission.VIEW_COUNTER,
|
||||||
|
|
@ -65,24 +74,20 @@ class Role(Base):
|
||||||
description: Mapped[str] = mapped_column(String(255), nullable=True)
|
description: Mapped[str] = mapped_column(String(255), nullable=True)
|
||||||
|
|
||||||
# Relationship to users
|
# Relationship to users
|
||||||
users: Mapped[List["User"]] = relationship(
|
users: Mapped[list["User"]] = relationship(
|
||||||
"User",
|
"User",
|
||||||
secondary=user_roles,
|
secondary=user_roles,
|
||||||
back_populates="roles",
|
back_populates="roles",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_permissions(self, db: AsyncSession) -> Set[Permission]:
|
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
||||||
"""Get all permissions for this role."""
|
"""Get all permissions for this role."""
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
|
select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
|
||||||
)
|
)
|
||||||
return {row[0] for row in result.fetchall()}
|
return {row[0] for row in result.fetchall()}
|
||||||
|
|
||||||
async def add_permission(self, db: AsyncSession, permission: Permission) -> None:
|
async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
|
||||||
"""Add a permission to this role."""
|
|
||||||
await db.execute(role_permissions.insert().values(role_id=self.id, permission=permission))
|
|
||||||
|
|
||||||
async def set_permissions(self, db: AsyncSession, permissions: List[Permission]) -> None:
|
|
||||||
"""Set all permissions for this role (replaces existing)."""
|
"""Set all permissions for this role (replaces existing)."""
|
||||||
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
|
await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))
|
||||||
for perm in permissions:
|
for perm in permissions:
|
||||||
|
|
@ -97,20 +102,21 @@ class User(Base):
|
||||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
|
||||||
# Relationship to roles
|
# Relationship to roles
|
||||||
roles: Mapped[List[Role]] = relationship(
|
roles: Mapped[list[Role]] = relationship(
|
||||||
"Role",
|
"Role",
|
||||||
secondary=user_roles,
|
secondary=user_roles,
|
||||||
back_populates="users",
|
back_populates="users",
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_permissions(self, db: AsyncSession) -> Set[Permission]:
|
async def get_permissions(self, db: AsyncSession) -> set[Permission]:
|
||||||
"""Get all permissions from all roles."""
|
"""Get all permissions from all roles in a single query."""
|
||||||
permissions: Set[Permission] = set()
|
result = await db.execute(
|
||||||
for role in self.roles:
|
select(role_permissions.c.permission)
|
||||||
role_perms = await role.get_permissions(db)
|
.join(user_roles, role_permissions.c.role_id == user_roles.c.role_id)
|
||||||
permissions.update(role_perms)
|
.where(user_roles.c.user_id == self.id)
|
||||||
return permissions
|
)
|
||||||
|
return {row[0] for row in result.fetchall()}
|
||||||
|
|
||||||
async def has_permission(self, db: AsyncSession, permission: Permission) -> bool:
|
async def has_permission(self, db: AsyncSession, permission: Permission) -> bool:
|
||||||
"""Check if user has a specific permission through any of their roles."""
|
"""Check if user has a specific permission through any of their roles."""
|
||||||
|
|
@ -118,7 +124,7 @@ class User(Base):
|
||||||
return permission in permissions
|
return permission in permissions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def role_names(self) -> List[str]:
|
def role_names(self) -> list[str]:
|
||||||
"""Get list of role names for API responses."""
|
"""Get list of role names for API responses."""
|
||||||
return [role.name for role in self.roles]
|
return [role.name for role in self.roles]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,5 +18,14 @@ dev = [
|
||||||
"pytest-asyncio>=0.25.0",
|
"pytest-asyncio>=0.25.0",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"aiosqlite>=0.20.0",
|
"aiosqlite>=0.20.0",
|
||||||
|
"mypy>=1.13.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
ignore_missing_imports = true
|
||||||
|
exclude = ["tests/"]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
"""Seed the database with roles, permissions, and dev users."""
|
"""Seed the database with roles, permissions, and dev users."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import List
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database import engine, async_session, Base
|
from database import engine, async_session, Base
|
||||||
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS
|
from models import User, Role, Permission, role_permissions, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||||
from auth import get_password_hash
|
from auth import get_password_hash
|
||||||
|
|
||||||
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
DEV_USER_EMAIL = os.environ["DEV_USER_EMAIL"]
|
||||||
|
|
@ -14,7 +14,7 @@ DEV_ADMIN_EMAIL = os.environ["DEV_ADMIN_EMAIL"]
|
||||||
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
DEV_ADMIN_PASSWORD = os.environ["DEV_ADMIN_PASSWORD"]
|
||||||
|
|
||||||
|
|
||||||
async def upsert_role(db, name: str, description: str, permissions: List[Permission]) -> Role:
|
async def upsert_role(db: AsyncSession, name: str, description: str, permissions: list[Permission]) -> Role:
|
||||||
"""Create or update a role with the given permissions."""
|
"""Create or update a role with the given permissions."""
|
||||||
result = await db.execute(select(Role).where(Role.name == name))
|
result = await db.execute(select(Role).where(Role.name == name))
|
||||||
role = result.scalar_one_or_none()
|
role = result.scalar_one_or_none()
|
||||||
|
|
@ -35,7 +35,7 @@ async def upsert_role(db, name: str, description: str, permissions: List[Permiss
|
||||||
return role
|
return role
|
||||||
|
|
||||||
|
|
||||||
async def upsert_user(db, email: str, password: str, role_names: List[str]) -> User:
|
async def upsert_user(db: AsyncSession, email: str, password: str, role_names: list[str]) -> User:
|
||||||
"""Create or update a user with the given credentials and roles."""
|
"""Create or update a user with the given credentials and roles."""
|
||||||
result = await db.execute(select(User).where(User.email == email))
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
@ -45,12 +45,13 @@ async def upsert_user(db, email: str, password: str, role_names: List[str]) -> U
|
||||||
for role_name in role_names:
|
for role_name in role_names:
|
||||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||||
role = result.scalar_one_or_none()
|
role = result.scalar_one_or_none()
|
||||||
if role:
|
if not role:
|
||||||
roles.append(role)
|
raise ValueError(f"Role '{role_name}' not found")
|
||||||
|
roles.append(role)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
user.hashed_password = get_password_hash(password)
|
user.hashed_password = get_password_hash(password)
|
||||||
user.roles = roles
|
user.roles = roles # type: ignore[assignment]
|
||||||
print(f"Updated user: {email} with roles: {role_names}")
|
print(f"Updated user: {email} with roles: {role_names}")
|
||||||
else:
|
else:
|
||||||
user = User(
|
user = User(
|
||||||
|
|
@ -64,7 +65,7 @@ async def upsert_user(db, email: str, password: str, role_names: List[str]) -> U
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def seed():
|
async def seed() -> None:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
@ -80,10 +81,10 @@ async def seed():
|
||||||
|
|
||||||
print("\n=== Seeding Users ===")
|
print("\n=== Seeding Users ===")
|
||||||
# Create regular dev user
|
# Create regular dev user
|
||||||
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, ["regular"])
|
await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR])
|
||||||
|
|
||||||
# Create admin dev user
|
# Create admin dev user
|
||||||
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, ["admin"])
|
await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN])
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
print("\n=== Seeding Complete ===\n")
|
print("\n=== Seeding Complete ===\n")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import List
|
|
||||||
|
|
||||||
# Set required env vars before importing app
|
# Set required env vars before importing app
|
||||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
||||||
|
|
@ -12,8 +11,9 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, Asyn
|
||||||
|
|
||||||
from database import Base, get_db
|
from database import Base, get_db
|
||||||
from main import app
|
from main import app
|
||||||
from models import User, Role, Permission, ROLE_DEFINITIONS
|
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
||||||
from auth import get_password_hash
|
from auth import get_password_hash
|
||||||
|
from tests.helpers import unique_email
|
||||||
|
|
||||||
TEST_DATABASE_URL = os.getenv(
|
TEST_DATABASE_URL = os.getenv(
|
||||||
"TEST_DATABASE_URL",
|
"TEST_DATABASE_URL",
|
||||||
|
|
@ -82,7 +82,7 @@ async def create_user_with_roles(
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
role_names: List[str],
|
role_names: list[str],
|
||||||
) -> User:
|
) -> User:
|
||||||
"""Create a user with specified roles."""
|
"""Create a user with specified roles."""
|
||||||
# Get roles
|
# Get roles
|
||||||
|
|
@ -90,8 +90,9 @@ async def create_user_with_roles(
|
||||||
for role_name in role_names:
|
for role_name in role_names:
|
||||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||||
role = result.scalar_one_or_none()
|
role = result.scalar_one_or_none()
|
||||||
if role:
|
if not role:
|
||||||
roles.append(role)
|
raise ValueError(f"Role '{role_name}' not found. Did you run setup_roles()?")
|
||||||
|
roles.append(role)
|
||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
email=email,
|
email=email,
|
||||||
|
|
@ -144,13 +145,11 @@ async def client(client_factory):
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def regular_user(client_factory):
|
async def regular_user(client_factory):
|
||||||
"""Create a regular user and return their credentials and cookies."""
|
"""Create a regular user and return their credentials and cookies."""
|
||||||
from tests.helpers import unique_email
|
|
||||||
|
|
||||||
email = unique_email("regular")
|
email = unique_email("regular")
|
||||||
password = "password123"
|
password = "password123"
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
await create_user_with_roles(db, email, password, ["regular"])
|
await create_user_with_roles(db, email, password, [ROLE_REGULAR])
|
||||||
|
|
||||||
# Login to get cookies
|
# Login to get cookies
|
||||||
response = await client_factory.post(
|
response = await client_factory.post(
|
||||||
|
|
@ -169,13 +168,11 @@ async def regular_user(client_factory):
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def admin_user(client_factory):
|
async def admin_user(client_factory):
|
||||||
"""Create an admin user and return their credentials and cookies."""
|
"""Create an admin user and return their credentials and cookies."""
|
||||||
from tests.helpers import unique_email
|
|
||||||
|
|
||||||
email = unique_email("admin")
|
email = unique_email("admin")
|
||||||
password = "password123"
|
password = "password123"
|
||||||
|
|
||||||
async with client_factory.get_db_session() as db:
|
async with client_factory.get_db_session() as db:
|
||||||
await create_user_with_roles(db, email, password, ["admin"])
|
await create_user_with_roles(db, email, password, [ROLE_ADMIN])
|
||||||
|
|
||||||
# Login to get cookies
|
# Login to get cookies
|
||||||
response = await client_factory.post(
|
response = await client_factory.post(
|
||||||
|
|
@ -194,8 +191,6 @@ async def admin_user(client_factory):
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
async def user_no_roles(client_factory):
|
async def user_no_roles(client_factory):
|
||||||
"""Create a user with NO roles and return their credentials and cookies."""
|
"""Create a user with NO roles and return their credentials and cookies."""
|
||||||
from tests.helpers import unique_email
|
|
||||||
|
|
||||||
email = unique_email("noroles")
|
email = unique_email("noroles")
|
||||||
password = "password123"
|
password = "password123"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -393,4 +393,3 @@ const pageStyles: Record<string, React.CSSProperties> = {
|
||||||
};
|
};
|
||||||
|
|
||||||
const styles = { ...sharedStyles, ...pageStyles };
|
const styles = { ...sharedStyles, ...pageStyles };
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,6 @@ interface AuthContextType {
|
||||||
register: (email: string, password: string) => Promise<void>;
|
register: (email: string, password: string) => Promise<void>;
|
||||||
logout: () => Promise<void>;
|
logout: () => Promise<void>;
|
||||||
hasPermission: (permission: PermissionType) => boolean;
|
hasPermission: (permission: PermissionType) => boolean;
|
||||||
hasAnyPermission: (...permissions: PermissionType[]) => boolean;
|
|
||||||
hasRole: (role: string) => boolean;
|
hasRole: (role: string) => boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -104,10 +103,6 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||||
return user?.permissions.includes(permission) ?? false;
|
return user?.permissions.includes(permission) ?? false;
|
||||||
}, [user]);
|
}, [user]);
|
||||||
|
|
||||||
const hasAnyPermission = useCallback((...permissions: PermissionType[]): boolean => {
|
|
||||||
return permissions.some((p) => user?.permissions.includes(p) ?? false);
|
|
||||||
}, [user]);
|
|
||||||
|
|
||||||
const hasRole = useCallback((role: string): boolean => {
|
const hasRole = useCallback((role: string): boolean => {
|
||||||
return user?.roles.includes(role) ?? false;
|
return user?.roles.includes(role) ?? false;
|
||||||
}, [user]);
|
}, [user]);
|
||||||
|
|
@ -121,7 +116,6 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||||
register,
|
register,
|
||||||
logout,
|
logout,
|
||||||
hasPermission,
|
hasPermission,
|
||||||
hasAnyPermission,
|
|
||||||
hasRole,
|
hasRole,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
|
|
|
||||||
|
|
@ -79,4 +79,3 @@ export const sharedStyles: Record<string, React.CSSProperties> = {
|
||||||
padding: "2rem",
|
padding: "2rem",
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -288,4 +288,3 @@ const pageStyles: Record<string, React.CSSProperties> = {
|
||||||
};
|
};
|
||||||
|
|
||||||
const styles = { ...sharedStyles, ...pageStyles };
|
const styles = { ...sharedStyles, ...pageStyles };
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue