2025-12-18 22:08:31 +01:00
|
|
|
import os
|
2025-12-21 21:54:26 +01:00
|
|
|
from datetime import UTC, datetime, timedelta
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
import bcrypt
|
2025-12-18 22:24:46 +01:00
|
|
|
from fastapi import Depends, HTTPException, Request, status
|
2025-12-18 22:08:31 +01:00
|
|
|
from jose import JWTError, jwt
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
from database import get_db
|
2025-12-21 21:54:26 +01:00
|
|
|
from models import Permission, User
|
2025-12-25 00:59:57 +01:00
|
|
|
from repositories.user import UserRepository
|
2025-12-20 22:18:14 +01:00
|
|
|
from schemas import UserResponse
|
2025-12-18 22:08:31 +01:00
|
|
|
|
2025-12-18 22:24:46 +01:00
|
|
|
# JWT signing configuration. SECRET_KEY is read with no fallback: if the
# variable is unset, importing this module raises KeyError instead of using a
# default key. Required - see .env.example.
SECRET_KEY = os.environ["SECRET_KEY"]
# Used both to sign (create_access_token) and as the only accepted algorithm
# when decoding (get_current_user).
ALGORITHM = "HS256"

# Default token lifetime in minutes: 7 days * 24 hours * 60 minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 7 * 24 * 60

# Name of the cookie that get_current_user reads the JWT from.
COOKIE_NAME = "auth_token"

# True only when COOKIE_SECURE is the string "true" (case-insensitive); unset or
# any other value yields False. NOTE(review): presumably controls the cookie's
# Secure attribute where the cookie is set (not visible here) - confirm.
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "false").lower() == "true"
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored bcrypt hash.

    Args:
        plain_password: The password supplied by the user.
        hashed_password: The stored hash, as produced by ``get_password_hash``.

    Returns:
        ``True`` if the password matches the hash, ``False`` otherwise. This
        includes inputs bcrypt refuses to process, such as a malformed or empty
        stored hash, or an over-long password on bcrypt releases that enforce
        the 72-byte limit. Those cases return ``False`` rather than propagating
        ``ValueError`` to the caller.
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    except ValueError:
        # An unverifiable password/hash pair cannot be a match; report it as a
        # failed check instead of letting the error escape a boolean predicate.
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_password_hash(password: str) -> str:
    """Hash ``password`` with bcrypt using a freshly generated salt.

    Returns:
        The bcrypt hash (salt included, as bcrypt embeds it) decoded to a
        ``str`` for storage; it is the format ``verify_password`` expects.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode("utf-8"), salt)
    return digest.decode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
2025-12-21 21:54:26 +01:00
|
|
|
def create_access_token(
    data: dict[str, str],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed (e.g. a ``"sub"`` claim, which
            ``get_current_user`` later parses as an integer user id). The
            mapping is copied, not mutated. Any ``"exp"`` key in it is
            overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None`` uses
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. Every other value is honoured as
            given, including ``timedelta(0)`` and negative deltas.

    Returns:
        The encoded token, signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    to_encode: dict[str, str | datetime] = dict(data)
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so an ``or``
    # fallback would silently turn a zero lifetime into the 7-day default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
2025-12-19 00:12:43 +01:00
|
|
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Look up a user by email address.

    Kept as a thin backwards-compatibility wrapper; the query is delegated to
    ``UserRepository.get_by_email`` on a repository bound to ``db``.

    Returns:
        The matching ``User``, or ``None`` when the repository finds none.
    """
    return await UserRepository(db).get_by_email(email)
|
2025-12-18 22:08:31 +01:00
|
|
|
|
|
|
|
|
|
2025-12-19 00:12:43 +01:00
|
|
|
async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
    """Return the user identified by ``email`` if ``password`` matches.

    Args:
        db: Async session used for the user lookup.
        email: Email address to look up.
        password: Plaintext password to check against the stored hash.

    Returns:
        The ``User`` on success. ``None`` if no user has that email or the
        password does not match the stored hash.
    """
    import asyncio  # local import: only this function needs it

    user = await get_user_by_email(db, email)
    if user is None:
        return None
    # bcrypt is deliberately slow and CPU-bound. Calling it inline inside this
    # coroutine would stall the event loop for the whole check, so run it in a
    # worker thread instead.
    matches = await asyncio.to_thread(verify_password, password, user.hashed_password)
    if not matches:
        return None
    return user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the JWT stored in the auth cookie.

    Intended for use as a FastAPI dependency. The token is read from the
    ``COOKIE_NAME`` cookie and verified with ``SECRET_KEY``/``ALGORITHM``. Its
    ``sub`` claim is parsed as the integer user id.

    Raises:
        HTTPException: 401 "Invalid authentication credentials" when the
            cookie is missing or empty, or when the token cannot be decoded or
            verified. The same error is raised when ``sub`` is absent or not an
            integer string, and when no user has that id.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    token = request.cookies.get(COOKIE_NAME)
    if not token:
        raise unauthorized

    # Step 1: verify and decode the token. Failure details are suppressed so
    # the client only ever sees the generic 401.
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except (JWTError, ValueError):
        raise unauthorized from None

    # Step 2: extract the subject and parse it as the numeric user id.
    subject = claims.get("sub")
    if subject is None:
        raise unauthorized
    try:
        user_id = int(subject)
    except ValueError:
        raise unauthorized from None

    # Step 3: the user must still exist.
    row = await db.execute(select(User).where(User.id == user_id))
    user = row.scalar_one_or_none()
    if user is None:
        raise unauthorized
    return user
|
|
|
|
|
|
2025-12-18 23:33:32 +01:00
|
|
|
|
|
|
|
|
def require_permission(*required_permissions: Permission):
    """Build a FastAPI dependency that requires ALL of the listed permissions.

    The returned dependency authenticates the caller via ``get_current_user``.
    It then checks that every entry of ``required_permissions`` is among the
    permissions returned by ``user.get_permissions(db)``.

    Usage:
        @app.get("/api/profile")
        async def get_profile(
            user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
        ):
            ...

    The dependency resolves to the authenticated ``User``. Failures from
    ``get_current_user`` (401) propagate unchanged. If any permission is
    lacking, it raises 403 naming the missing permissions in the order they
    were requested.
    """

    async def permission_checker(
        request: Request,
        db: AsyncSession = Depends(get_db),
    ) -> User:
        user = await get_current_user(request, db)
        granted = await user.get_permissions(db)

        missing = [perm for perm in required_permissions if perm not in granted]
        if not missing:
            return user

        listed = ", ".join(perm.value for perm in missing)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required permissions: {listed}",
        )

    return permission_checker
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
    """Assemble the ``UserResponse`` payload for ``user``.

    Permissions are loaded with ``user.get_permissions(db)`` and serialised by
    their ``.value``. Roles are taken from ``user.role_names``.
    """
    permission_values = [perm.value for perm in await user.get_permissions(db)]
    return UserResponse(
        id=user.id,
        email=user.email,
        roles=user.role_names,
        permissions=permission_values,
    )
|