arbret/backend/auth.py
counterweight 6c218130e9
Add ruff linter/formatter for Python
- Add ruff as dev dependency
- Configure ruff in pyproject.toml with strict 88-char line limit
- Ignore B008 (FastAPI Depends pattern is standard)
- Allow longer lines in tests for readability
- Fix all lint issues in source files
- Add Makefile targets: lint-backend, format-backend, fix-backend
2025-12-21 21:54:26 +01:00

127 lines
3.8 KiB
Python

import os
from datetime import UTC, datetime, timedelta
import bcrypt
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import Permission, User
from schemas import UserResponse
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
COOKIE_NAME = "auth_token"
COOKIE_SECURE = os.environ.get("COOKIE_SECURE", "false").lower() == "true"
def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(
plain_password.encode("utf-8"),
hashed_password.encode("utf-8"),
)
def get_password_hash(password: str) -> str:
return bcrypt.hashpw(
password.encode("utf-8"),
bcrypt.gensalt(),
).decode("utf-8")
def create_access_token(
data: dict[str, str],
expires_delta: timedelta | None = None,
) -> str:
to_encode: dict[str, str | datetime] = dict(data)
delta = expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(UTC) + delta
to_encode["exp"] = expire
encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
user = await get_user_by_email(db, email)
if not user or not verify_password(password, user.hashed_password):
return None
return user
async def get_current_user(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
)
token = request.cookies.get(COOKIE_NAME)
if not token:
raise credentials_exception
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str = payload.get("sub")
if user_id_str is None:
raise credentials_exception
user_id = int(user_id_str)
except (JWTError, ValueError):
raise credentials_exception from None
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
return user
def require_permission(*required_permissions: Permission):
"""
Dependency factory that checks if user has ALL required permissions.
Usage:
@app.get("/api/counter")
async def get_counter(
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
):
...
"""
async def permission_checker(
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
user = await get_current_user(request, db)
user_permissions = await user.get_permissions(db)
missing = [p for p in required_permissions if p not in user_permissions]
if missing:
missing_str = ", ".join(p.value for p in missing)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {missing_str}",
)
return user
return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
"""Build a UserResponse with roles and permissions."""
permissions = await user.get_permissions(db)
return UserResponse(
id=user.id,
email=user.email,
roles=user.role_names,
permissions=[p.value for p in permissions],
)