arbret/backend/auth.py

101 lines
2.7 KiB
Python

import os
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User
# HMAC secret used to sign and verify JWTs. Required: a missing variable
# raises KeyError at import time.
SECRET_KEY = os.environ["SECRET_KEY"]  # Required - see .env.example
if not SECRET_KEY.strip():
    # With a symmetric algorithm (HS256 below) a blank key lets anyone mint
    # tokens this service will accept, so refuse to start instead.
    raise RuntimeError("SECRET_KEY environment variable must not be empty")
# Symmetric HMAC-SHA256 signing for JWTs.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
# Name of the cookie that get_current_user reads the JWT from.
COOKIE_NAME = "auth_token"
# Shared email/password payload. Registration and login take the same
# fields, so both names below are aliases of this one model.
class UserCredentials(BaseModel):
    # Validated as an email address by pydantic's EmailStr.
    email: EmailStr
    # Plain-text password exactly as submitted.
    password: str


# Both names are bound to the very same class object.
UserCreate = UserLogin = UserCredentials
# Outward-facing view of a user: id and email only, with no password-hash
# field. (Kept docstring-free: pydantic would publish a docstring as the
# schema description.)
class UserResponse(BaseModel):
    id: int
    email: str
# Response model bundling an access token, its type label and the user it
# was issued for.
class TokenResponse(BaseModel):
    access_token: str
    # NOTE(review): presumably "bearer"; the value is supplied by the caller,
    # not constrained here.
    token_type: str
    user: UserResponse
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the bcrypt *hashed_password*.

    Inputs that cannot be checked at all are reported as a non-match instead
    of propagating an exception. Such a failure then becomes an ordinary
    failed login rather than an unhandled server error.

    Args:
        plain_password: Candidate password as supplied by the client.
        hashed_password: Stored bcrypt hash string.

    Returns:
        True on a match; False on a mismatch or on input bcrypt rejects.
    """
    try:
        # Encode before touching bcrypt: str.encode raises UnicodeEncodeError
        # (a ValueError subclass) for lone surrogates.
        candidate = plain_password.encode("utf-8")
        stored = hashed_password.encode("utf-8")
        return bcrypt.checkpw(candidate, stored)
    except ValueError:
        # bcrypt raises ValueError for a malformed stored hash (e.g. "Invalid
        # salt") and, in newer releases, for passwords longer than 72 bytes.
        return False
def get_password_hash(password: str) -> str:
    """Hash *password* with bcrypt and return the hash as a UTF-8 string.

    ``bcrypt.gensalt()`` generates a new salt on every call, so hashing the
    same password twice yields different strings. Use ``verify_password`` to
    compare a candidate against a stored hash.
    """
    raw = password.encode("utf-8")
    hashed = bcrypt.hashpw(raw, bcrypt.gensalt())
    return hashed.decode("utf-8")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode *data* as a signed JWT carrying an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's dict is
            not modified; any ``exp`` key it contains is overwritten.
        expires_delta: Token lifetime measured from now (UTC). ``None``
            selects the default of ``ACCESS_TOKEN_EXPIRE_MINUTES``. Any
            explicit value, including ``timedelta(0)``, is honoured as given.

    Returns:
        The JWT string signed with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Compare against None explicitly: timedelta(0) is falsy, and a plain
    # `expires_delta or default` would silently turn it into a 7-day token.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
    """Fetch the user whose ``email`` column equals *email*.

    This is a plain equality comparison with no case folding in this
    function. Returns the matching ``User`` or ``None``. Per
    ``scalar_one_or_none``, it raises if more than one row matches.
    """
    query = select(User).where(User.email == email)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
# bcrypt hash of a throwaway password, built on first use. authenticate_user
# checks candidates against it when the email is unknown, so that path costs
# about the same bcrypt work as a real password check.
_DUMMY_PASSWORD_HASH: Optional[str] = None


def _get_dummy_password_hash() -> str:
    """Return (creating on first call) the hash used only for timing equalization."""
    global _DUMMY_PASSWORD_HASH
    if _DUMMY_PASSWORD_HASH is None:
        _DUMMY_PASSWORD_HASH = get_password_hash("timing-equalization-placeholder")
    return _DUMMY_PASSWORD_HASH


async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
    """Return the user matching *email* and *password*, or ``None``.

    When no account has *email*, the password is still run through bcrypt
    against a dummy hash and the result is discarded. Without this, the
    unknown-email path would finish much faster than a wrong-password
    attempt, and the response time would reveal which addresses are
    registered.
    """
    user = await get_user_by_email(db, email)
    if user is None:
        # Spend comparable bcrypt work, then fail regardless of the outcome.
        verify_password(password, _get_dummy_password_hash())
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
async def get_current_user(
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the logged-in user from the auth cookie (FastAPI dependency).

    Steps:
        1. Read the JWT from the ``COOKIE_NAME`` cookie.
        2. Verify and decode it with ``SECRET_KEY`` / ``ALGORITHM``.
        3. Read its ``sub`` claim as an integer user id.
        4. Load that ``User`` row.

    Raises:
        HTTPException: 401 if the cookie is missing or empty, the token fails
            to decode, ``sub`` is absent or not an integer, or no user has
            that id.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
    )

    raw_token = request.cookies.get(COOKIE_NAME)
    if not raw_token:
        raise unauthorized

    try:
        claims = jwt.decode(raw_token, SECRET_KEY, algorithms=[ALGORITHM])
    except (JWTError, ValueError):
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized

    try:
        user_id = int(subject)
    except ValueError:
        raise unauthorized

    lookup = await db.execute(select(User).where(User.id == user_id))
    current_user = lookup.scalar_one_or_none()
    if current_user is None:
        raise unauthorized
    return current_user