arbret/backend/auth.py
2025-12-25 00:59:57 +01:00

129 lines
3.9 KiB
Python

import os
from datetime import UTC, datetime, timedelta
import bcrypt
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import Permission, User
from repositories.user import UserRepository
from schemas import UserResponse
# Key used to sign and verify JWTs in this module. It is read with
# os.environ[...] at import time, so importing this module raises KeyError
# when the variable is not set.
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
# JWT signing algorithm passed to jwt.encode / jwt.decode.
ALGORITHM = "HS256"
# Default access-token lifetime, used when create_access_token() is called
# without an explicit expires_delta.
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
# Name of the request cookie from which get_current_user() reads the token.
COOKIE_NAME = "auth_token"
# True only when the COOKIE_SECURE environment variable equals "true"
# (case-insensitive); defaults to False when unset.
# NOTE(review): not used in this file - presumably consumed where the auth
# cookie is set; confirm against the callers.
COOKIE_SECURE = os.environ.get("COOKIE_SECURE", "false").lower() == "true"
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether ``plain_password`` matches the bcrypt ``hashed_password``.

    Both values are UTF-8 encoded and compared with ``bcrypt.checkpw``.

    Args:
        plain_password: The candidate password, typically user-supplied.
        hashed_password: The stored bcrypt hash as text.

    Returns:
        ``True`` when the password matches the hash; ``False`` when it does
        not match or when either value cannot be processed at all.
    """
    try:
        password_bytes = plain_password.encode("utf-8")
        hash_bytes = hashed_password.encode("utf-8")
        return bcrypt.checkpw(password_bytes, hash_bytes)
    except ValueError:
        # ValueError covers UnicodeEncodeError (e.g. a password containing
        # lone surrogates) as well as bcrypt rejecting its input, such as a
        # malformed stored hash or a password it refuses to hash. A password
        # that cannot be checked cannot match, so report a failed
        # verification instead of letting the error escape to the caller.
        return False
def get_password_hash(password: str) -> str:
    """Hash ``password`` with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing and the resulting hash
    bytes are decoded back to ``str`` before being returned.
    """
    password_bytes = password.encode("utf-8")
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password_bytes, salt)
    return digest.decode("utf-8")
def create_access_token(
    data: dict[str, str],
    expires_delta: timedelta | None = None,
) -> str:
    """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed in the token. The mapping is copied, so the
            caller's dict is not modified.
        expires_delta: Lifetime of the token, measured from the current UTC
            time. When ``None``, ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes is
            used.

    Returns:
        The token encoded with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Test against None explicitly: timedelta(0) is falsy, so the shortcut
    # `expires_delta or default` would silently replace an explicitly
    # requested zero lifetime with the long-lived default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode: dict[str, str | datetime] = dict(data)
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Look up a user by email address.

    Thin backwards-compatibility wrapper: builds a ``UserRepository`` on the
    given session and returns whatever its ``get_by_email`` returns.
    """
    return await UserRepository(db).get_by_email(email)
# Lazily computed bcrypt hash that never belongs to a real account; see
# _dummy_password_hash().
_DUMMY_PASSWORD_HASH: str | None = None


def _dummy_password_hash() -> str:
    """Return a cached bcrypt hash used only to equalize login timing."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = get_password_hash("timing-equalization-placeholder")
    return _DUMMY_PASSWORD_HASH


async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
    """Return the user matching ``email`` and ``password``, or ``None``.

    Args:
        db: Session used to look the user up.
        email: Email address identifying the account.
        password: Plain-text password to check against the stored hash.

    Returns:
        The ``User`` when the email is known and the password verifies;
        ``None`` otherwise. Both failure cases are indistinguishable to the
        caller.
    """
    user = await get_user_by_email(db, email)
    if not user:
        # Do the same bcrypt work as for an existing account. Returning
        # immediately would make unknown emails answer measurably faster
        # than known ones, letting an attacker enumerate registered
        # addresses by timing the login endpoint.
        try:
            verify_password(password, _dummy_password_hash())
        except ValueError:
            # Only the time spent matters here; the result is discarded and
            # an unknown email must always yield None.
            pass
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the JWT stored in the auth cookie.

    The token is read from the ``COOKIE_NAME`` cookie, decoded with
    ``SECRET_KEY`` / ``ALGORITHM``, and its ``sub`` claim is parsed as the
    integer user id that is then loaded from the database.

    Raises:
        HTTPException: 401 when the cookie is missing or empty, the token
            fails to decode, the ``sub`` claim is absent or not an integer,
            or no user with that id exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    raw_token = request.cookies.get(COOKIE_NAME)
    if not raw_token:
        raise unauthorized

    try:
        claims = jwt.decode(raw_token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
        user_id = None if subject is None else int(subject)
    except (JWTError, ValueError):
        # Suppress the chained decoding/parsing error; callers only need to
        # see the generic 401.
        raise unauthorized from None
    if user_id is None:
        raise unauthorized

    query = select(User).where(User.id == user_id)
    rows = await db.execute(query)
    found = rows.scalar_one_or_none()
    if found is None:
        raise unauthorized
    return found
def require_permission(*required_permissions: Permission):
    """
    Build a dependency that only lets through users holding ALL of the
    given permissions.

    The returned coroutine authenticates the request with
    ``get_current_user``, loads the user's permissions, and raises a 403
    ``HTTPException`` naming every required permission the user lacks.
    On success it returns the authenticated ``User``.

    Usage:
        @app.get("/api/profile")
        async def get_profile(
            user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
        ):
            ...
    """

    async def permission_checker(
        request: Request,
        db: AsyncSession = Depends(get_db),
    ) -> User:
        current = await get_current_user(request, db)
        granted = await current.get_permissions(db)

        # Keep the order in which the permissions were required so the
        # error message lists them the same way.
        lacking = [perm for perm in required_permissions if perm not in granted]
        if not lacking:
            return current

        missing_str = ", ".join(perm.value for perm in lacking)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required permissions: {missing_str}",
        )

    return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
    """Assemble the ``UserResponse`` for ``user``.

    Copies the user's id, email and role names, and adds the ``value`` of
    each permission returned by ``user.get_permissions(db)``.
    """
    granted = await user.get_permissions(db)
    fields = {
        "id": user.id,
        "email": user.email,
        "roles": user.role_names,
        "permissions": [permission.value for permission in granted],
    }
    return UserResponse(**fields)