arbret/backend/auth.py

130 lines
3.9 KiB
Python
Raw Normal View History

2025-12-18 22:08:31 +01:00
import os
from datetime import UTC, datetime, timedelta
2025-12-18 22:08:31 +01:00
import bcrypt
2025-12-18 22:24:46 +01:00
from fastapi import Depends, HTTPException, Request, status
2025-12-18 22:08:31 +01:00
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import Permission, User
2025-12-25 00:59:57 +01:00
from repositories.user import UserRepository
2025-12-20 22:18:14 +01:00
from schemas import UserResponse
2025-12-18 22:08:31 +01:00
2025-12-18 22:24:46 +01:00
# Signing key for auth JWTs. Read with [] on purpose: if the variable is
# missing, importing this module fails with KeyError instead of falling back
# to an insecure default.
SECRET_KEY = os.environ["SECRET_KEY"]  # Required - see .env.example

# JWT signing algorithm (HMAC with SHA-256).
ALGORITHM = "HS256"

# Default access-token lifetime in minutes: 7 days.
ACCESS_TOKEN_EXPIRE_MINUTES = 7 * 24 * 60

# Cookie from which get_current_user reads the JWT.
COOKIE_NAME = "auth_token"

# True only when COOKIE_SECURE is set to "true" (any letter case); unset or
# any other value yields False. Presumably used as the cookie's `secure` flag
# where the cookie is set (not visible here) — TODO confirm.
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "false").lower() == "true"
2025-12-18 22:08:31 +01:00
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Both arguments are UTF-8 encoded and handed to ``bcrypt.checkpw``.

    Returns:
        True when the password matches the hash, False otherwise.

    NOTE(review): bcrypt is believed to raise ValueError when
    ``hashed_password`` is not a well-formed bcrypt hash; that is not handled
    here — confirm against the bcrypt docs.
    """
    candidate = plain_password.encode("utf-8")
    stored = hashed_password.encode("utf-8")
    return bcrypt.checkpw(candidate, stored)
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with bcrypt using a freshly generated salt.

    Returns:
        The bcrypt hash decoded to a UTF-8 string, in the form accepted by
        ``verify_password`` as its ``hashed_password`` argument.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
def create_access_token(
    data: dict[str, str],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT containing ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed (``get_current_user`` expects the user id as a
            string under ``"sub"``). The dict is copied, never mutated; any
            ``"exp"`` key it contains is overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None`` selects
            the default of ``ACCESS_TOKEN_EXPIRE_MINUTES``. Any explicit value,
            including ``timedelta(0)`` or a negative delta, is used as given.

    Returns:
        The encoded JWT, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so
    # ``expires_delta or default`` would silently turn a zero lifetime into
    # the 7-day default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode: dict[str, str | datetime] = dict(data)
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded
2025-12-18 22:08:31 +01:00
2025-12-19 00:12:43 +01:00
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Look up a user by email address.

    Backwards-compatibility wrapper that delegates to
    ``UserRepository.get_by_email`` on a repository bound to ``db``.

    Returns:
        The matching ``User``, or ``None`` when the repository finds none.
    """
    return await UserRepository(db).get_by_email(email)
2025-12-18 22:08:31 +01:00
2025-12-19 00:12:43 +01:00
# bcrypt hash of a throwaway password, produced with get_password_hash so its
# work factor matches the hashes this module stores. It is not tied to any
# account and its comparison result is always discarded. It exists only so
# that a login for an unknown email still performs one bcrypt comparison.
# Computed once at import time (costs one bcrypt hash).
_DUMMY_PASSWORD_HASH = get_password_hash("timing-equalization-placeholder")


async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
    """Return the user for ``email`` if ``password`` matches, else ``None``.

    Unknown emails and wrong passwords both return ``None`` after doing the
    same amount of bcrypt work. This keeps response time from revealing
    whether an account exists (user enumeration via timing).

    NOTE(review): bcrypt runs synchronously inside this coroutine and blocks
    the event loop for the duration of the comparison; consider
    ``asyncio.to_thread`` if login latency under load matters.
    """
    user = await get_user_by_email(db, email)
    if user is None:
        # Burn the same work a real check costs so this path isn't
        # measurably faster than the wrong-password path.
        verify_password(password, _DUMMY_PASSWORD_HASH)
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the auth cookie.

    Reads the JWT from the ``COOKIE_NAME`` cookie, verifies it with
    ``SECRET_KEY`` / ``ALGORITHM``, parses the ``sub`` claim as an integer
    user id and loads that ``User`` from the database.

    Raises:
        HTTPException: 401 when the cookie is missing or empty, the token
            fails to decode, ``sub`` is absent or not an integer string, or
            no user has that id.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    token = request.cookies.get(COOKIE_NAME)
    if not token:
        raise unauthorized

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except (JWTError, ValueError):
        raise unauthorized from None

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    try:
        user_id = int(subject)
    except ValueError:
        raise unauthorized from None

    query = select(User).where(User.id == user_id)
    user = (await db.execute(query)).scalar_one_or_none()
    if user is None:
        raise unauthorized
    return user
2025-12-18 23:33:32 +01:00
def require_permission(*required_permissions: Permission):
    """
    Build a FastAPI dependency that admits only users holding every listed permission.

    The returned coroutine authenticates the request via ``get_current_user``
    (401 on failure), then answers 403 naming each missing permission.
    With no permissions given, any authenticated user passes.

    Usage:
        @app.get("/api/profile")
        async def get_profile(
            user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
        ):
            ...
    """

    async def permission_checker(
        request: Request,
        db: AsyncSession = Depends(get_db),
    ) -> User:
        user = await get_current_user(request, db)
        granted = await user.get_permissions(db)

        # Preserve the caller's ordering so the error message is stable.
        lacking = [perm for perm in required_permissions if perm not in granted]
        if not lacking:
            return user

        missing_str = ", ".join(perm.value for perm in lacking)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required permissions: {missing_str}",
        )

    return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
    """Assemble the ``UserResponse`` for ``user``.

    Roles come from ``user.role_names``. Permissions are fetched with
    ``user.get_permissions(db)`` and serialized to their ``.value`` strings.
    """
    granted = await user.get_permissions(db)
    permission_values = [perm.value for perm in granted]
    return UserResponse(
        id=user.id,
        email=user.email,
        roles=user.role_names,
        permissions=permission_values,
    )