arbret/backend/auth.py

102 lines
2.7 KiB
Python
Raw Normal View History

2025-12-18 22:08:31 +01:00
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
2025-12-18 22:24:46 +01:00
from fastapi import Depends, HTTPException, Request, status
2025-12-18 22:08:31 +01:00
from jose import JWTError, jwt
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User
2025-12-18 22:24:46 +01:00
# HMAC signing key for the session JWTs. Required - see .env.example.
# A missing variable still raises KeyError at import time.
SECRET_KEY = os.environ["SECRET_KEY"]
if not SECRET_KEY.strip():
    # An empty (or blank) HS256 key is known to everyone, so anyone could mint
    # tokens that decode successfully. Refuse to start instead.
    raise RuntimeError("SECRET_KEY must be set to a non-empty value - see .env.example")

ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days

# Name of the cookie that carries the JWT.
COOKIE_NAME = "auth_token"
class UserCredentials(BaseModel):
    """Email address plus plain-text password."""

    email: EmailStr
    password: str


# Sign-up and login take the same payload; both names are bound to the very
# same model class object.
UserCreate = UserLogin = UserCredentials
class UserResponse(BaseModel):
    """Serialized user: numeric id and email (no password hash is exposed)."""

    id: int
    email: str
class TokenResponse(BaseModel):
    """Access-token string, its type label, and the user it belongs to."""

    access_token: str
    token_type: str
    user: UserResponse
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash.

    Both arguments are UTF-8 encoded before being passed to ``bcrypt.checkpw``.
    """
    password_bytes = plain_password.encode("utf-8")
    stored_hash_bytes = hashed_password.encode("utf-8")
    return bcrypt.checkpw(password_bytes, stored_hash_bytes)
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt under a freshly generated salt.

    The bcrypt output is decoded back to a UTF-8 ``str`` so it can be stored
    and later compared with ``verify_password``.
    """
    salt = bcrypt.gensalt()  # default cost factor
    hashed = bcrypt.hashpw(password.encode("utf-8"), salt)
    return hashed.decode("utf-8")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode *data* as a JWT signed with SECRET_KEY/ALGORITHM.

    A copy of *data* receives an ``exp`` claim. The caller's dict is never
    mutated, and an existing ``exp`` key is overwritten.

    Args:
        data: Claims to embed in the token.
        expires_delta: Token lifetime. ``None`` means the default of
            ACCESS_TOKEN_EXPIRE_MINUTES. Any other value, including
            ``timedelta(0)``, is honoured as given.

    Returns:
        The encoded token string.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, and an ``or``
    # fallback would silently turn it into the multi-day default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
    """Fetch the user whose ``email`` column equals *email*, or None.

    ``scalar_one_or_none`` raises if more than one row matches.
    """
    query = select(User).where(User.email == email)
    return (await db.execute(query)).scalar_one_or_none()
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
    """Return the user matching *email* and *password*, otherwise None.

    Both failure modes (unknown email, wrong password) return None and each
    performs one bcrypt operation. This keeps response timing from revealing
    whether an email address is registered.
    """
    user = await get_user_by_email(db, email)
    if user is None:
        # Unknown email: still spend one bcrypt hash (same default cost as the
        # stored hashes produced by get_password_hash) so this path is not
        # measurably faster than a wrong-password check.
        # NOTE(review): bcrypt runs synchronously here and blocks the event
        # loop, exactly as verify_password already does on the other path.
        get_password_hash(password)
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the logged-in user from the JWT stored in the auth cookie.

    The token is read from the ``COOKIE_NAME`` cookie and decoded with
    SECRET_KEY/ALGORITHM. Its ``sub`` claim is parsed as the user id, and that
    user is loaded from the database.

    Raises:
        HTTPException: 401 if the cookie is missing or empty, the token fails
            to decode, ``sub`` is absent or not parseable by ``int()``, or no
            user row has that id.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    token = request.cookies.get(COOKIE_NAME)
    if not token:
        raise unauthorized

    # One try block on purpose: decode errors and a non-integer subject are
    # both reported as the same 401.
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        subject = claims.get("sub")
        if subject is None:
            raise unauthorized
        user_id = int(subject)
    except (JWTError, ValueError):
        raise unauthorized

    lookup = await db.execute(select(User).where(User.id == user_id))
    user = lookup.scalar_one_or_none()
    if user is None:
        raise unauthorized
    return user