tests passing
This commit is contained in:
parent
322bdd3e6e
commit
b173b47925
18 changed files with 1414 additions and 93 deletions
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
|
@ -10,7 +10,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_db
|
||||
from models import User
|
||||
from models import User, Permission
|
||||
|
||||
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
|
||||
ALGORITHM = "HS256"
|
||||
|
|
@ -30,6 +30,8 @@ UserLogin = UserCredentials
|
|||
class UserResponse(BaseModel):
|
||||
id: int
|
||||
email: str
|
||||
roles: List[str]
|
||||
permissions: List[str]
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
|
|
@ -99,3 +101,64 @@ async def get_current_user(
|
|||
raise credentials_exception
|
||||
return user
|
||||
|
||||
|
||||
def require_permission(*required_permissions: Permission):
|
||||
"""
|
||||
Dependency factory that checks if user has ALL of the required permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/counter")
|
||||
async def get_counter(user: User = Depends(require_permission(Permission.VIEW_COUNTER))):
|
||||
...
|
||||
"""
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
user = await get_current_user(request, db)
|
||||
user_permissions = await user.get_permissions(db)
|
||||
|
||||
missing = [p for p in required_permissions if p not in user_permissions]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permissions: {', '.join(p.value for p in missing)}",
|
||||
)
|
||||
return user
|
||||
return permission_checker
|
||||
|
||||
|
||||
def require_any_permission(*required_permissions: Permission):
|
||||
"""
|
||||
Dependency factory that checks if user has ANY of the required permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/resource")
|
||||
async def get_resource(user: User = Depends(require_any_permission(Permission.VIEW, Permission.ADMIN))):
|
||||
...
|
||||
"""
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
user = await get_current_user(request, db)
|
||||
user_permissions = await user.get_permissions(db)
|
||||
|
||||
if not any(p in user_permissions for p in required_permissions):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Requires one of: {', '.join(p.value for p in required_permissions)}",
|
||||
)
|
||||
return user
|
||||
return permission_checker
|
||||
|
||||
|
||||
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
|
||||
"""Build a UserResponse with roles and permissions."""
|
||||
permissions = await user.get_permissions(db)
|
||||
return UserResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
roles=user.role_names,
|
||||
permissions=[p.value for p in permissions],
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue