tests passing

This commit is contained in:
counterweight 2025-12-18 22:08:31 +01:00
parent 0995e1cc77
commit 7ebfb7a2dd
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
20 changed files with 2009 additions and 126 deletions

100
backend/auth.py Normal file
View file

@ -0,0 +1,100 @@
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User
SECRET_KEY = os.getenv("SECRET_KEY", "dev-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
security = HTTPBearer()
class UserCreate(BaseModel):
email: EmailStr
password: str
class UserLogin(BaseModel):
email: EmailStr
password: str
class UserResponse(BaseModel):
id: int
email: str
class TokenResponse(BaseModel):
access_token: str
token_type: str
user: UserResponse
def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(
plain_password.encode("utf-8"),
hashed_password.encode("utf-8"),
)
def get_password_hash(password: str) -> str:
return bcrypt.hashpw(
password.encode("utf-8"),
bcrypt.gensalt(),
).decode("utf-8")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
expire = datetime.now(timezone.utc) + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_user_by_email(db: AsyncSession, email: str) -> Optional[User]:
result = await db.execute(select(User).where(User.email == email))
return result.scalar_one_or_none()
async def authenticate_user(db: AsyncSession, email: str, password: str) -> Optional[User]:
user = await get_user_by_email(db, email)
if not user or not verify_password(password, user.hashed_password):
return None
return user
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
token = credentials.credentials
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str = payload.get("sub")
if user_id_str is None:
raise credentials_exception
user_id = int(user_id_str)
except (JWTError, ValueError):
raise credentials_exception
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
return user

View file

@ -1,11 +1,22 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, get_db, Base
from models import Counter
from models import Counter, User
from auth import (
UserCreate,
UserLogin,
UserResponse,
TokenResponse,
get_password_hash,
get_user_by_email,
authenticate_user,
create_access_token,
get_current_user,
)
@asynccontextmanager
@ -22,9 +33,59 @@ app.add_middleware(
allow_origins=["http://localhost:3000"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)
# Auth endpoints
@app.post("/api/auth/register", response_model=TokenResponse)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
existing_user = await get_user_by_email(db, user_data.email)
if existing_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email already registered",
)
user = User(
email=user_data.email,
hashed_password=get_password_hash(user_data.password),
)
db.add(user)
await db.commit()
await db.refresh(user)
access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse(
access_token=access_token,
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
)
@app.post("/api/auth/login", response_model=TokenResponse)
async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
user = await authenticate_user(db, user_data.email, user_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect email or password",
)
access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse(
access_token=access_token,
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
)
@app.get("/api/auth/me", response_model=UserResponse)
async def get_me(current_user: User = Depends(get_current_user)):
return UserResponse(id=current_user.id, email=current_user.email)
# Counter endpoints
async def get_or_create_counter(db: AsyncSession) -> Counter:
result = await db.execute(select(Counter).where(Counter.id == 1))
counter = result.scalar_one_or_none()
@ -37,13 +98,19 @@ async def get_or_create_counter(db: AsyncSession) -> Counter:
@app.get("/api/counter")
async def get_counter(db: AsyncSession = Depends(get_db)):
async def get_counter(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user),
):
counter = await get_or_create_counter(db)
return {"value": counter.value}
@app.post("/api/counter/increment")
async def increment_counter(db: AsyncSession = Depends(get_db)):
async def increment_counter(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(get_current_user),
):
counter = await get_or_create_counter(db)
counter.value += 1
await db.commit()

View file

@ -1,4 +1,4 @@
from sqlalchemy import Integer
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from database import Base
@ -9,3 +9,11 @@ class Counter(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
value: Mapped[int] = mapped_column(Integer, default=0)
class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)

View file

@ -7,6 +7,9 @@ dependencies = [
"uvicorn>=0.34.0",
"sqlalchemy[asyncio]>=2.0.36",
"asyncpg>=0.30.0",
"bcrypt>=4.0.0",
"python-jose[cryptography]>=3.3.0",
"email-validator>=2.0.0",
]
[dependency-groups]

View file

@ -1,4 +1,3 @@
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function

35
backend/tests/conftest.py Normal file
View file

@ -0,0 +1,35 @@
import os
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from database import Base, get_db
from main import app
TEST_DATABASE_URL = os.getenv(
"TEST_DATABASE_URL",
"postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
)
@pytest.fixture(scope="function")
async def client():
engine = create_async_engine(TEST_DATABASE_URL)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
# Create tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)
async def override_get_db():
async with session_factory() as session:
yield session
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
app.dependency_overrides.clear()
await engine.dispose()

282
backend/tests/test_auth.py Normal file
View file

@ -0,0 +1,282 @@
import pytest
import uuid
def unique_email(prefix: str = "test") -> str:
"""Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_user_and_get_token(client, email: str = None, password: str = "testpass123") -> str:
"""Helper to create a user and return their auth token."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": password},
)
return response.json()["access_token"]
def auth_header(token: str) -> dict:
"""Helper to create auth headers from token."""
return {"Authorization": f"Bearer {token}"}
# Registration tests
@pytest.mark.asyncio
async def test_register_success(client):
email = unique_email("register")
response = await client.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
assert data["user"]["email"] == email
assert "id" in data["user"]
@pytest.mark.asyncio
async def test_register_duplicate_email(client):
email = unique_email("duplicate")
await client.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
response = await client.post(
"/api/auth/register",
json={"email": email, "password": "differentpass"},
)
assert response.status_code == 400
assert response.json()["detail"] == "Email already registered"
@pytest.mark.asyncio
async def test_register_invalid_email(client):
response = await client.post(
"/api/auth/register",
json={"email": "notanemail", "password": "password123"},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_register_missing_password(client):
response = await client.post(
"/api/auth/register",
json={"email": unique_email()},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_register_missing_email(client):
response = await client.post(
"/api/auth/register",
json={"password": "password123"},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_register_empty_body(client):
response = await client.post("/api/auth/register", json={})
assert response.status_code == 422
# Login tests
@pytest.mark.asyncio
async def test_login_success(client):
email = unique_email("login")
await client.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
response = await client.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
assert data["user"]["email"] == email
@pytest.mark.asyncio
async def test_login_wrong_password(client):
email = unique_email("wrongpass")
await client.post(
"/api/auth/register",
json={"email": email, "password": "correctpassword"},
)
response = await client.post(
"/api/auth/login",
json={"email": email, "password": "wrongpassword"},
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect email or password"
@pytest.mark.asyncio
async def test_login_nonexistent_user(client):
response = await client.post(
"/api/auth/login",
json={"email": unique_email("nonexistent"), "password": "password123"},
)
assert response.status_code == 401
assert response.json()["detail"] == "Incorrect email or password"
@pytest.mark.asyncio
async def test_login_invalid_email_format(client):
response = await client.post(
"/api/auth/login",
json={"email": "invalidemail", "password": "password123"},
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_login_missing_fields(client):
response = await client.post("/api/auth/login", json={})
assert response.status_code == 422
# Get current user tests
@pytest.mark.asyncio
async def test_get_me_success(client):
email = unique_email("me")
token = await create_user_and_get_token(client, email)
response = await client.get("/api/auth/me", headers=auth_header(token))
assert response.status_code == 200
data = response.json()
assert data["email"] == email
assert "id" in data
@pytest.mark.asyncio
async def test_get_me_no_token(client):
response = await client.get("/api/auth/me")
# HTTPBearer returns 401/403 when credentials are missing
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_me_invalid_token(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "Bearer invalidtoken123"},
)
assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials"
@pytest.mark.asyncio
async def test_get_me_malformed_auth_header(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "NotBearer token123"},
)
# Invalid scheme returns 401/403
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_me_expired_token(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"},
)
assert response.status_code == 401
# Token validation tests
@pytest.mark.asyncio
async def test_token_from_register_works_for_me(client):
email = unique_email("tokentest")
register_response = await client.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
token = register_response.json()["access_token"]
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200
assert me_response.json()["email"] == email
@pytest.mark.asyncio
async def test_token_from_login_works_for_me(client):
email = unique_email("logintoken")
await client.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
login_response = await client.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
token = login_response.json()["access_token"]
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200
assert me_response.json()["email"] == email
# Multiple users tests
@pytest.mark.asyncio
async def test_multiple_users_isolated(client):
email1 = unique_email("user1")
email2 = unique_email("user2")
resp1 = await client.post(
"/api/auth/register",
json={"email": email1, "password": "password1"},
)
resp2 = await client.post(
"/api/auth/register",
json={"email": email2, "password": "password2"},
)
token1 = resp1.json()["access_token"]
token2 = resp2.json()["access_token"]
me1 = await client.get("/api/auth/me", headers=auth_header(token1))
me2 = await client.get("/api/auth/me", headers=auth_header(token2))
assert me1.json()["email"] == email1
assert me2.json()["email"] == email2
assert me1.json()["id"] != me2.json()["id"]
# Password tests
@pytest.mark.asyncio
async def test_password_is_hashed(client):
email = unique_email("hashtest")
await client.post(
"/api/auth/register",
json={"email": email, "password": "mySecurePassword123"},
)
response = await client.post(
"/api/auth/login",
json={"email": email, "password": "mySecurePassword123"},
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_case_sensitive_password(client):
email = unique_email("casetest")
await client.post(
"/api/auth/register",
json={"email": email, "password": "Password123"},
)
response = await client.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
assert response.status_code == 401

View file

@ -1,60 +1,128 @@
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from database import Base, get_db
from main import app
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
import uuid
@pytest.fixture
async def client():
engine = create_async_engine(TEST_DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False)
def unique_email(prefix: str = "counter") -> str:
"""Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def override_get_db():
async with async_session() as session:
yield session
async def create_user_and_get_headers(client, email: str = None) -> dict:
"""Create a user and return auth headers for authenticated requests."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": "testpass123"},
)
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
app.dependency_overrides.clear()
await engine.dispose()
# Protected endpoint tests - without auth
@pytest.mark.asyncio
async def test_get_counter_requires_auth(client):
response = await client.get("/api/counter")
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_counter_initial(client):
response = await client.get("/api/counter")
async def test_increment_counter_requires_auth(client):
response = await client.post("/api/counter/increment")
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_counter_invalid_token(client):
response = await client.get(
"/api/counter",
headers={"Authorization": "Bearer invalidtoken"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_invalid_token(client):
response = await client.post(
"/api/counter/increment",
headers={"Authorization": "Bearer invalidtoken"},
)
assert response.status_code == 401
# Authenticated counter tests
@pytest.mark.asyncio
async def test_get_counter_authenticated(client):
auth_headers = await create_user_and_get_headers(client)
response = await client.get("/api/counter", headers=auth_headers)
assert response.status_code == 200
assert response.json() == {"value": 0}
assert "value" in response.json()
@pytest.mark.asyncio
async def test_increment_counter(client):
response = await client.post("/api/counter/increment")
auth_headers = await create_user_and_get_headers(client)
# Get current value
before = await client.get("/api/counter", headers=auth_headers)
before_value = before.json()["value"]
# Increment
response = await client.post("/api/counter/increment", headers=auth_headers)
assert response.status_code == 200
assert response.json() == {"value": 1}
assert response.json()["value"] == before_value + 1
@pytest.mark.asyncio
async def test_increment_counter_multiple(client):
await client.post("/api/counter/increment")
await client.post("/api/counter/increment")
response = await client.post("/api/counter/increment")
assert response.json() == {"value": 3}
auth_headers = await create_user_and_get_headers(client)
# Get starting value
before = await client.get("/api/counter", headers=auth_headers)
start = before.json()["value"]
# Increment 3 times
await client.post("/api/counter/increment", headers=auth_headers)
await client.post("/api/counter/increment", headers=auth_headers)
response = await client.post("/api/counter/increment", headers=auth_headers)
assert response.json()["value"] == start + 3
@pytest.mark.asyncio
async def test_get_counter_after_increment(client):
await client.post("/api/counter/increment")
await client.post("/api/counter/increment")
response = await client.get("/api/counter")
assert response.json() == {"value": 2}
auth_headers = await create_user_and_get_headers(client)
before = await client.get("/api/counter", headers=auth_headers)
start = before.json()["value"]
await client.post("/api/counter/increment", headers=auth_headers)
await client.post("/api/counter/increment", headers=auth_headers)
response = await client.get("/api/counter", headers=auth_headers)
assert response.json()["value"] == start + 2
# Counter is shared between users
@pytest.mark.asyncio
async def test_counter_shared_between_users(client):
headers1 = await create_user_and_get_headers(client, unique_email("share1"))
# Get starting value
before = await client.get("/api/counter", headers=headers1)
start = before.json()["value"]
await client.post("/api/counter/increment", headers=headers1)
await client.post("/api/counter/increment", headers=headers1)
# Second user sees the increments
headers2 = await create_user_and_get_headers(client, unique_email("share2"))
response = await client.get("/api/counter", headers=headers2)
assert response.json()["value"] == start + 2
# Second user increments
await client.post("/api/counter/increment", headers=headers2)
# First user sees the increment
response = await client.get("/api/counter", headers=headers1)
assert response.json()["value"] == start + 3