tests passing

This commit is contained in:
counterweight 2025-12-18 23:33:32 +01:00
parent 322bdd3e6e
commit b173b47925
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
18 changed files with 1414 additions and 93 deletions

View file

@ -1,6 +1,6 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from typing import List, Optional
import bcrypt
from fastapi import Depends, HTTPException, Request, status
@ -10,7 +10,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User
from models import User, Permission
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
ALGORITHM = "HS256"
@ -30,6 +30,8 @@ UserLogin = UserCredentials
class UserResponse(BaseModel):
id: int
email: str
roles: List[str]
permissions: List[str]
class TokenResponse(BaseModel):
@ -99,3 +101,64 @@ async def get_current_user(
raise credentials_exception
return user
def require_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ALL of the required permissions.
Usage:
@app.get("/api/counter")
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db)
missing = [p for p in required_permissions if p not in user_permissions]
if missing:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
)
return user
return permission_checker
def require_any_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ANY of the required permissions.
Usage:
@app.get("/api/resource")
async def get_resource(user: User = Depends(require_any_permission(Permission.VIEW, Permission.ADMIN))):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db)
if not any(p in user_permissions for p in required_permissions):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Requires one of: {', '.join(p.value for p in required_permissions)}",
)
return user
return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
"""Build a UserResponse with roles and permissions."""
permissions = await user.get_permissions(db)
return UserResponse(
id=user.id,
email=user.email,
roles=user.role_names,
permissions=[p.value for p in permissions],
)