arbret/backend/auth.py
counterweight 5bad1e7e17
Phase 0.1: Remove backend deprecated code
- Delete routes: counter.py, sum.py
- Delete jobs.py and worker.py
- Delete tests: test_counter.py, test_jobs.py
- Update audit.py: keep only price-history endpoints
- Update models.py: remove VIEW_COUNTER, INCREMENT_COUNTER, USE_SUM permissions
- Update models.py: remove Counter, SumRecord, CounterRecord, RandomNumberOutcome models
- Update schemas.py: remove sum/counter related schemas
- Update main.py: remove deleted router imports
- Update test_permissions.py: remove tests for deprecated features
- Update test_price_history.py: remove worker-related tests
- Update conftest.py: remove mock_enqueue_job fixture
- Update auth.py: fix example in docstring
2025-12-22 18:07:14 +01:00

127 lines
3.8 KiB
Python

import os
from datetime import UTC, datetime, timedelta
import bcrypt
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import Permission, User
from schemas import UserResponse
# Key used to sign and verify JWTs in this module. Required: the subscript
# lookup raises KeyError at import time when the variable is unset
# (see .env.example).
SECRET_KEY = os.environ["SECRET_KEY"]
# Algorithm passed to jwt.encode / jwt.decode below.
ALGORITHM = "HS256"
# Default access-token lifetime in minutes (60 * 24 * 7 = 7 days).
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7
# Name of the request cookie that get_current_user reads the token from.
COOKIE_NAME = "auth_token"
# True only when the COOKIE_SECURE env var equals "true" (case-insensitive);
# unset or any other value yields False. Not used in the code shown here;
# presumably the Secure flag for the auth cookie - confirm at the call site.
COOKIE_SECURE = os.environ.get("COOKIE_SECURE", "false").lower() == "true"
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the bcrypt *hashed_password*.

    Both arguments are UTF-8 encoded before being handed to ``bcrypt.checkpw``.

    Args:
        plain_password: The candidate password as entered by the user.
        hashed_password: The stored bcrypt hash, as text.

    Returns:
        True on a match; False on a mismatch or when the inputs cannot be
        checked at all (see the comment below).
    """
    try:
        return bcrypt.checkpw(
            plain_password.encode("utf-8"),
            hashed_password.encode("utf-8"),
        )
    except ValueError:
        # bcrypt raises ValueError for a malformed stored hash ("Invalid
        # salt") and, in releases that reject over-long passwords, for a
        # password above 72 bytes. UnicodeEncodeError (un-encodable input)
        # is a ValueError subclass as well. None of these can be a valid
        # credential, so report "no match" instead of letting a login
        # attempt surface as an unhandled server error.
        return False
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    Args:
        password: The plain-text password; it is UTF-8 encoded before hashing.

    Returns:
        The bcrypt hash decoded to text.
    """
    password_bytes = password.encode("utf-8")
    hashed = bcrypt.hashpw(password_bytes, bcrypt.gensalt())
    return hashed.decode("utf-8")
def create_access_token(
    data: dict[str, str],
    expires_delta: timedelta | None = None,
) -> str:
    """Encode *data* plus an ``exp`` claim into a signed JWT.

    Args:
        data: Claims to embed in the token. The mapping is copied, so the
            caller's dict is not modified.
        expires_delta: Lifetime of the token. When None, the default of
            ACCESS_TOKEN_EXPIRE_MINUTES minutes is used.

    Returns:
        The token produced by ``jwt.encode`` with SECRET_KEY and ALGORITHM.
    """
    # Compare against None explicitly: timedelta(0) is falsy, so an
    # ``expires_delta or default`` expression would silently swap an explicit
    # zero lifetime for the multi-day default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode: dict[str, str | datetime] = dict(data)
    # Timezone-aware UTC timestamp for the expiry claim.
    to_encode["exp"] = datetime.now(UTC) + expires_delta
    encoded: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Look up the user whose ``email`` column equals *email*.

    Args:
        db: Async session used to run the query.
        email: Address to match exactly against ``User.email``.

    Returns:
        The matching User, or None when no row matches.
    """
    query = select(User).where(User.email == email)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
async def authenticate_user(db: AsyncSession, email: str, password: str) -> User | None:
    """Return the user identified by *email* if *password* is correct.

    Args:
        db: Async session used for the user lookup.
        email: Email address of the account to authenticate.
        password: Plain-text password to verify against the stored hash.

    Returns:
        The User on success; None when no user has that email or the
        password does not verify.

    NOTE(review): the password check is skipped entirely when the email is
    unknown, so the two failure paths do different amounts of work.
    """
    candidate = await get_user_by_email(db, email)
    if candidate and verify_password(password, candidate.hashed_password):
        return candidate
    return None
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the request's auth cookie.

    The token is read from the COOKIE_NAME cookie, decoded with SECRET_KEY
    and ALGORITHM, and its ``sub`` claim is interpreted as the integer
    primary key of the user.

    Args:
        request: Incoming request whose cookies carry the token.
        db: Async session used to load the user.

    Returns:
        The User whose id matches the token's ``sub`` claim.

    Raises:
        HTTPException: 401 when the cookie is missing or empty, decoding
            raises JWTError, ``sub`` is absent or not an integer, or no
            user with that id exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    raw_token = request.cookies.get(COOKIE_NAME)
    if not raw_token:
        raise unauthorized

    try:
        claims = jwt.decode(raw_token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
        # A missing subject is rejected just below, outside the handler.
        user_id = None if subject is None else int(subject)
    except (JWTError, ValueError):
        # Drop the decoding error from the chain; callers only see the 401.
        raise unauthorized from None

    if user_id is None:
        raise unauthorized

    lookup = await db.execute(select(User).where(User.id == user_id))
    current = lookup.scalar_one_or_none()
    if current is None:
        raise unauthorized
    return current
def require_permission(*required_permissions: Permission):
    """
    Build a dependency that admits only users holding ALL given permissions.

    The returned coroutine authenticates the request via get_current_user,
    loads the user's permissions, and responds with 403 listing every
    permission that is missing. On success it returns the User.

    Usage:
        @app.get("/api/profile")
        async def get_profile(
            user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
        ):
            ...
    """

    async def permission_checker(
        request: Request,
        db: AsyncSession = Depends(get_db),
    ) -> User:
        current = await get_current_user(request, db)
        granted = await current.get_permissions(db)

        # Collect the names of all absent permissions, in the order required.
        missing_names = [
            perm.value for perm in required_permissions if perm not in granted
        ]
        if not missing_names:
            return current

        missing_str = ", ".join(missing_names)
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Missing required permissions: {missing_str}",
        )

    return permission_checker
async def build_user_response(user: User, db: AsyncSession) -> UserResponse:
    """Assemble the UserResponse for *user*.

    Args:
        user: The user to serialize.
        db: Async session passed to ``user.get_permissions``.

    Returns:
        A UserResponse carrying the user's id, email, role names, and the
        ``value`` of each permission returned by ``user.get_permissions``.
    """
    granted = await user.get_permissions(db)
    permission_values = [perm.value for perm in granted]
    return UserResponse(
        id=user.id,
        email=user.email,
        roles=user.role_names,
        permissions=permission_values,
    )