first round of review

This commit is contained in:
counterweight 2025-12-18 22:24:46 +01:00
parent 7ebfb7a2dd
commit da5a0d03eb
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
14 changed files with 362 additions and 244 deletions

11
.envrc Normal file
View file

@ -0,0 +1,11 @@
# Local development environment variables
# To use: install direnv (https://direnv.net), then run `direnv allow`
# Backend
export SECRET_KEY="dev-secret-key-change-in-production"
export DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
export TEST_DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
# Frontend
export NEXT_PUBLIC_API_URL="http://localhost:8000"

View file

@ -3,8 +3,7 @@ from datetime import datetime, timedelta, timezone
from typing import Optional from typing import Optional
import bcrypt import bcrypt
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError, jwt from jose import JWTError, jwt
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from sqlalchemy import select from sqlalchemy import select
@ -13,10 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db from database import get_db
from models import User from models import User
SECRET_KEY = os.getenv("SECRET_KEY", "dev-secret-key-change-in-production") SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
security = HTTPBearer() COOKIE_NAME = "auth_token"
class UserCreate(BaseModel): class UserCreate(BaseModel):
@ -74,16 +73,19 @@ async def authenticate_user(db: AsyncSession, email: str, password: str) -> Opti
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), request: Request,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
) -> User: ) -> User:
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials", detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
) )
token = request.cookies.get(COOKIE_NAME)
if not token:
raise credentials_exception
try: try:
token = credentials.credentials
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str = payload.get("sub") user_id_str = payload.get("sub")
if user_id_str is None: if user_id_str is None:

11
backend/env.example Normal file
View file

@ -0,0 +1,11 @@
# Environment variables for the backend
# For local dev: use direnv with the root .envrc file (recommended)
# For production: set these in your deployment environment
# Required: Secret key for JWT token signing
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
SECRET_KEY=
# Database URL
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret

View file

@ -1,5 +1,5 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, status from fastapi import FastAPI, Depends, HTTPException, Response, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@ -7,10 +7,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, get_db, Base from database import engine, get_db, Base
from models import Counter, User from models import Counter, User
from auth import ( from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
UserCreate, UserCreate,
UserLogin, UserLogin,
UserResponse, UserResponse,
TokenResponse,
get_password_hash, get_password_hash,
get_user_by_email, get_user_by_email,
authenticate_user, authenticate_user,
@ -37,9 +38,24 @@ app.add_middleware(
) )
def set_auth_cookie(response: Response, token: str) -> None:
response.set_cookie(
key=COOKIE_NAME,
value=token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
# Auth endpoints # Auth endpoints
@app.post("/api/auth/register", response_model=TokenResponse) @app.post("/api/auth/register", response_model=UserResponse)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): async def register(
user_data: UserCreate,
response: Response,
db: AsyncSession = Depends(get_db),
):
existing_user = await get_user_by_email(db, user_data.email) existing_user = await get_user_by_email(db, user_data.email)
if existing_user: if existing_user:
raise HTTPException( raise HTTPException(
@ -56,15 +72,16 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
await db.refresh(user) await db.refresh(user)
access_token = create_access_token(data={"sub": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse( set_auth_cookie(response, access_token)
access_token=access_token, return UserResponse(id=user.id, email=user.email)
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
)
@app.post("/api/auth/login", response_model=TokenResponse) @app.post("/api/auth/login", response_model=UserResponse)
async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)): async def login(
user_data: UserLogin,
response: Response,
db: AsyncSession = Depends(get_db),
):
user = await authenticate_user(db, user_data.email, user_data.password) user = await authenticate_user(db, user_data.email, user_data.password)
if not user: if not user:
raise HTTPException( raise HTTPException(
@ -73,11 +90,14 @@ async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
) )
access_token = create_access_token(data={"sub": str(user.id)}) access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse( set_auth_cookie(response, access_token)
access_token=access_token, return UserResponse(id=user.id, email=user.email)
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
) @app.post("/api/auth/logout")
async def logout(response: Response):
response.delete_cookie(key=COOKIE_NAME)
return {"ok": True}
@app.get("/api/auth/me", response_model=UserResponse) @app.get("/api/auth/me", response_model=UserResponse)

View file

@ -1,4 +1,9 @@
import os import os
from contextlib import asynccontextmanager
# Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
@ -12,8 +17,38 @@ TEST_DATABASE_URL = os.getenv(
) )
class ClientFactory:
"""Factory for creating httpx clients with optional cookies."""
def __init__(self, transport, base_url):
self._transport = transport
self._base_url = base_url
@asynccontextmanager
async def create(self, cookies: dict | None = None):
"""Create a new client, optionally with cookies set."""
async with AsyncClient(
transport=self._transport,
base_url=self._base_url,
cookies=cookies or {},
) as client:
yield client
async def request(self, method: str, url: str, **kwargs):
"""Make a one-off request without cookies."""
async with self.create() as client:
return await client.request(method, url, **kwargs)
async def get(self, url: str, **kwargs):
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
async def client(): async def client_factory():
"""Fixture that provides a factory for creating clients."""
engine = create_async_engine(TEST_DATABASE_URL) engine = create_async_engine(TEST_DATABASE_URL)
session_factory = async_sessionmaker(engine, expire_on_commit=False) session_factory = async_sessionmaker(engine, expire_on_commit=False)
@ -28,8 +63,17 @@ async def client():
app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: transport = ASGITransport(app=app)
yield c factory = ClientFactory(transport, "http://test")
yield factory
app.dependency_overrides.clear() app.dependency_overrides.clear()
await engine.dispose() await engine.dispose()
@pytest.fixture(scope="function")
async def client(client_factory):
"""Fixture for a simple client without cookies (backwards compatible)."""
async with client_factory.create() as c:
yield c

View file

@ -1,28 +1,14 @@
import pytest import pytest
import uuid import uuid
from auth import COOKIE_NAME
def unique_email(prefix: str = "test") -> str: def unique_email(prefix: str = "test") -> str:
"""Generate a unique email for tests sharing the same database.""" """Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com" return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_user_and_get_token(client, email: str = None, password: str = "testpass123") -> str:
"""Helper to create a user and return their auth token."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": password},
)
return response.json()["access_token"]
def auth_header(token: str) -> dict:
"""Helper to create auth headers from token."""
return {"Authorization": f"Bearer {token}"}
# Registration tests # Registration tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_success(client): async def test_register_success(client):
@ -33,10 +19,10 @@ async def test_register_success(client):
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "access_token" in data assert data["email"] == email
assert data["token_type"] == "bearer" assert "id" in data
assert data["user"]["email"] == email # Cookie should be set
assert "id" in data["user"] assert COOKIE_NAME in response.cookies
@pytest.mark.asyncio @pytest.mark.asyncio
@ -101,9 +87,8 @@ async def test_login_success(client):
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "access_token" in data assert data["email"] == email
assert data["token_type"] == "bearer" assert COOKIE_NAME in response.cookies
assert data["user"]["email"] == email
@pytest.mark.asyncio @pytest.mark.asyncio
@ -148,10 +133,20 @@ async def test_login_missing_fields(client):
# Get current user tests # Get current user tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_success(client): async def test_get_me_success(client_factory):
email = unique_email("me") email = unique_email("me")
token = await create_user_and_get_token(client, email)
response = await client.get("/api/auth/me", headers=auth_header(token)) # Register and get cookies
reg_response = await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
cookies = dict(reg_response.cookies)
# Use authenticated client
async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/auth/me")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["email"] == email assert data["email"] == email
@ -159,94 +154,89 @@ async def test_get_me_success(client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_no_token(client): async def test_get_me_no_cookie(client):
response = await client.get("/api/auth/me") response = await client.get("/api/auth/me")
# HTTPBearer returns 401/403 when credentials are missing assert response.status_code == 401
assert response.status_code in [401, 403]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_invalid_token(client): async def test_get_me_invalid_cookie(client_factory):
response = await client.get( async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
"/api/auth/me", response = await authed.get("/api/auth/me")
headers={"Authorization": "Bearer invalidtoken123"},
)
assert response.status_code == 401 assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials" assert response.json()["detail"] == "Invalid authentication credentials"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_me_malformed_auth_header(client): async def test_get_me_expired_token(client_factory):
response = await client.get( bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"
"/api/auth/me", async with client_factory.create(cookies={COOKIE_NAME: bad_token}) as authed:
headers={"Authorization": "NotBearer token123"}, response = await authed.get("/api/auth/me")
)
# Invalid scheme returns 401/403
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_me_expired_token(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"},
)
assert response.status_code == 401 assert response.status_code == 401
# Token validation tests # Cookie validation tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_from_register_works_for_me(client): async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest") email = unique_email("tokentest")
register_response = await client.post(
reg_response = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
token = register_response.json()["access_token"] cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me")
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200 assert me_response.status_code == 200
assert me_response.json()["email"] == email assert me_response.json()["email"] == email
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_token_from_login_works_for_me(client): async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken") email = unique_email("logintoken")
await client.post(
await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
login_response = await client.post( login_response = await client_factory.post(
"/api/auth/login", "/api/auth/login",
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
token = login_response.json()["access_token"] cookies = dict(login_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me")
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200 assert me_response.status_code == 200
assert me_response.json()["email"] == email assert me_response.json()["email"] == email
# Multiple users tests # Multiple users tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_multiple_users_isolated(client): async def test_multiple_users_isolated(client_factory):
email1 = unique_email("user1") email1 = unique_email("user1")
email2 = unique_email("user2") email2 = unique_email("user2")
resp1 = await client.post( resp1 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={"email": email1, "password": "password1"}, json={"email": email1, "password": "password1"},
) )
resp2 = await client.post( resp2 = await client_factory.post(
"/api/auth/register", "/api/auth/register",
json={"email": email2, "password": "password2"}, json={"email": email2, "password": "password2"},
) )
token1 = resp1.json()["access_token"] cookies1 = dict(resp1.cookies)
token2 = resp2.json()["access_token"] cookies2 = dict(resp2.cookies)
me1 = await client.get("/api/auth/me", headers=auth_header(token1)) async with client_factory.create(cookies=cookies1) as user1:
me2 = await client.get("/api/auth/me", headers=auth_header(token2)) me1 = await user1.get("/api/auth/me")
async with client_factory.create(cookies=cookies2) as user2:
me2 = await user2.get("/api/auth/me")
assert me1.json()["email"] == email1 assert me1.json()["email"] == email1
assert me2.json()["email"] == email2 assert me2.json()["email"] == email2
@ -280,3 +270,21 @@ async def test_case_sensitive_password(client):
json={"email": email, "password": "password123"}, json={"email": email, "password": "password123"},
) )
assert response.status_code == 401 assert response.status_code == 401
# Logout tests
@pytest.mark.asyncio
async def test_logout_success(client_factory):
email = unique_email("logout")
reg_response = await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
logout_response = await authed.post("/api/auth/logout")
assert logout_response.status_code == 200
assert logout_response.json() == {"ok": True}

View file

@ -1,128 +1,149 @@
import pytest import pytest
import uuid import uuid
from auth import COOKIE_NAME
def unique_email(prefix: str = "counter") -> str: def unique_email(prefix: str = "counter") -> str:
"""Generate a unique email for tests sharing the same database.""" """Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com" return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_user_and_get_headers(client, email: str = None) -> dict:
"""Create a user and return auth headers for authenticated requests."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": "testpass123"},
)
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
# Protected endpoint tests - without auth # Protected endpoint tests - without auth
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_requires_auth(client): async def test_get_counter_requires_auth(client):
response = await client.get("/api/counter") response = await client.get("/api/counter")
assert response.status_code in [401, 403] assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_requires_auth(client): async def test_increment_counter_requires_auth(client):
response = await client.post("/api/counter/increment") response = await client.post("/api/counter/increment")
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_counter_invalid_token(client):
response = await client.get(
"/api/counter",
headers={"Authorization": "Bearer invalidtoken"},
)
assert response.status_code == 401 assert response.status_code == 401
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_invalid_token(client): async def test_get_counter_invalid_cookie(client_factory):
response = await client.post( async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
"/api/counter/increment", response = await authed.get("/api/counter")
headers={"Authorization": "Bearer invalidtoken"}, assert response.status_code == 401
)
@pytest.mark.asyncio
async def test_increment_counter_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
response = await authed.post("/api/counter/increment")
assert response.status_code == 401 assert response.status_code == 401
# Authenticated counter tests # Authenticated counter tests
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_authenticated(client): async def test_get_counter_authenticated(client_factory):
auth_headers = await create_user_and_get_headers(client) reg = await client_factory.post(
response = await client.get("/api/counter", headers=auth_headers) "/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/counter")
assert response.status_code == 200 assert response.status_code == 200
assert "value" in response.json() assert "value" in response.json()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter(client): async def test_increment_counter(client_factory):
auth_headers = await create_user_and_get_headers(client) reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
# Get current value async with client_factory.create(cookies=cookies) as authed:
before = await client.get("/api/counter", headers=auth_headers) # Get current value
before_value = before.json()["value"] before = await authed.get("/api/counter")
before_value = before.json()["value"]
# Increment
response = await client.post("/api/counter/increment", headers=auth_headers) # Increment
assert response.status_code == 200 response = await authed.post("/api/counter/increment")
assert response.json()["value"] == before_value + 1 assert response.status_code == 200
assert response.json()["value"] == before_value + 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_increment_counter_multiple(client): async def test_increment_counter_multiple(client_factory):
auth_headers = await create_user_and_get_headers(client) reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
# Get starting value async with client_factory.create(cookies=cookies) as authed:
before = await client.get("/api/counter", headers=auth_headers) # Get starting value
start = before.json()["value"] before = await authed.get("/api/counter")
start = before.json()["value"]
# Increment 3 times
await client.post("/api/counter/increment", headers=auth_headers) # Increment 3 times
await client.post("/api/counter/increment", headers=auth_headers) await authed.post("/api/counter/increment")
response = await client.post("/api/counter/increment", headers=auth_headers) await authed.post("/api/counter/increment")
response = await authed.post("/api/counter/increment")
assert response.json()["value"] == start + 3
assert response.json()["value"] == start + 3
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_counter_after_increment(client): async def test_get_counter_after_increment(client_factory):
auth_headers = await create_user_and_get_headers(client) reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
before = await client.get("/api/counter", headers=auth_headers) async with client_factory.create(cookies=cookies) as authed:
start = before.json()["value"] before = await authed.get("/api/counter")
start = before.json()["value"]
await client.post("/api/counter/increment", headers=auth_headers)
await client.post("/api/counter/increment", headers=auth_headers) await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment")
response = await client.get("/api/counter", headers=auth_headers)
assert response.json()["value"] == start + 2 response = await authed.get("/api/counter")
assert response.json()["value"] == start + 2
# Counter is shared between users # Counter is shared between users
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_counter_shared_between_users(client): async def test_counter_shared_between_users(client_factory):
headers1 = await create_user_and_get_headers(client, unique_email("share1")) # Create first user
reg1 = await client_factory.post(
"/api/auth/register",
json={"email": unique_email("share1"), "password": "testpass123"},
)
cookies1 = dict(reg1.cookies)
# Get starting value async with client_factory.create(cookies=cookies1) as user1:
before = await client.get("/api/counter", headers=headers1) # Get starting value
start = before.json()["value"] before = await user1.get("/api/counter")
start = before.json()["value"]
await user1.post("/api/counter/increment")
await user1.post("/api/counter/increment")
await client.post("/api/counter/increment", headers=headers1) # Create second user - should see the increments
await client.post("/api/counter/increment", headers=headers1) reg2 = await client_factory.post(
"/api/auth/register",
json={"email": unique_email("share2"), "password": "testpass123"},
)
cookies2 = dict(reg2.cookies)
# Second user sees the increments async with client_factory.create(cookies=cookies2) as user2:
headers2 = await create_user_and_get_headers(client, unique_email("share2")) response = await user2.get("/api/counter")
response = await client.get("/api/counter", headers=headers2) assert response.json()["value"] == start + 2
assert response.json()["value"] == start + 2
# Second user increments
# Second user increments await user2.post("/api/counter/increment")
await client.post("/api/counter/increment", headers=headers2)
# First user sees the increment # First user sees the increment
response = await client.get("/api/counter", headers=headers1) async with client_factory.create(cookies=cookies1) as user1:
assert response.json()["value"] == start + 3 response = await user1.get("/api/counter")
assert response.json()["value"] == start + 3

View file

@ -2,6 +2,8 @@
import { createContext, useContext, useState, useEffect, ReactNode } from "react"; import { createContext, useContext, useState, useEffect, ReactNode } from "react";
const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
interface User { interface User {
id: number; id: number;
email: string; email: string;
@ -9,54 +11,43 @@ interface User {
interface AuthContextType { interface AuthContextType {
user: User | null; user: User | null;
token: string | null;
isLoading: boolean; isLoading: boolean;
login: (email: string, password: string) => Promise<void>; login: (email: string, password: string) => Promise<void>;
register: (email: string, password: string) => Promise<void>; register: (email: string, password: string) => Promise<void>;
logout: () => void; logout: () => Promise<void>;
} }
const AuthContext = createContext<AuthContextType | null>(null); const AuthContext = createContext<AuthContextType | null>(null);
export function AuthProvider({ children }: { children: ReactNode }) { export function AuthProvider({ children }: { children: ReactNode }) {
const [user, setUser] = useState<User | null>(null); const [user, setUser] = useState<User | null>(null);
const [token, setToken] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true); const [isLoading, setIsLoading] = useState(true);
useEffect(() => { useEffect(() => {
const storedToken = localStorage.getItem("token"); checkAuth();
if (storedToken) {
setToken(storedToken);
fetchUser(storedToken);
} else {
setIsLoading(false);
}
}, []); }, []);
const fetchUser = async (authToken: string) => { const checkAuth = async () => {
try { try {
const res = await fetch("http://localhost:8000/api/auth/me", { const res = await fetch(`${API_URL}/api/auth/me`, {
headers: { Authorization: `Bearer ${authToken}` }, credentials: "include",
}); });
if (res.ok) { if (res.ok) {
const userData = await res.json(); const userData = await res.json();
setUser(userData); setUser(userData);
} else {
localStorage.removeItem("token");
setToken(null);
} }
} catch { } catch {
localStorage.removeItem("token"); // Not authenticated
setToken(null);
} finally { } finally {
setIsLoading(false); setIsLoading(false);
} }
}; };
const login = async (email: string, password: string) => { const login = async (email: string, password: string) => {
const res = await fetch("http://localhost:8000/api/auth/login", { const res = await fetch(`${API_URL}/api/auth/login`, {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
credentials: "include",
body: JSON.stringify({ email, password }), body: JSON.stringify({ email, password }),
}); });
@ -65,16 +56,15 @@ export function AuthProvider({ children }: { children: ReactNode }) {
throw new Error(error.detail || "Login failed"); throw new Error(error.detail || "Login failed");
} }
const data = await res.json(); const userData = await res.json();
localStorage.setItem("token", data.access_token); setUser(userData);
setToken(data.access_token);
setUser(data.user);
}; };
const register = async (email: string, password: string) => { const register = async (email: string, password: string) => {
const res = await fetch("http://localhost:8000/api/auth/register", { const res = await fetch(`${API_URL}/api/auth/register`, {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
credentials: "include",
body: JSON.stringify({ email, password }), body: JSON.stringify({ email, password }),
}); });
@ -83,20 +73,20 @@ export function AuthProvider({ children }: { children: ReactNode }) {
throw new Error(error.detail || "Registration failed"); throw new Error(error.detail || "Registration failed");
} }
const data = await res.json(); const userData = await res.json();
localStorage.setItem("token", data.access_token); setUser(userData);
setToken(data.access_token);
setUser(data.user);
}; };
const logout = () => { const logout = async () => {
localStorage.removeItem("token"); await fetch(`${API_URL}/api/auth/logout`, {
setToken(null); method: "POST",
credentials: "include",
});
setUser(null); setUser(null);
}; };
return ( return (
<AuthContext.Provider value={{ user, token, isLoading, login, register, logout }}> <AuthContext.Provider value={{ user, isLoading, login, register, logout }}>
{children} {children}
</AuthContext.Provider> </AuthContext.Provider>
); );
@ -109,4 +99,3 @@ export function useAuth() {
} }
return context; return context;
} }

View file

@ -12,14 +12,12 @@ vi.mock("next/navigation", () => ({
// Default mock values // Default mock values
let mockUser: { id: number; email: string } | null = { id: 1, email: "test@example.com" }; let mockUser: { id: number; email: string } | null = { id: 1, email: "test@example.com" };
let mockToken: string | null = "valid-token";
let mockIsLoading = false; let mockIsLoading = false;
const mockLogout = vi.fn(); const mockLogout = vi.fn();
vi.mock("./auth-context", () => ({ vi.mock("./auth-context", () => ({
useAuth: () => ({ useAuth: () => ({
user: mockUser, user: mockUser,
token: mockToken,
isLoading: mockIsLoading, isLoading: mockIsLoading,
logout: mockLogout, logout: mockLogout,
}), }),
@ -29,7 +27,6 @@ beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Reset to authenticated state // Reset to authenticated state
mockUser = { id: 1, email: "test@example.com" }; mockUser = { id: 1, email: "test@example.com" };
mockToken = "valid-token";
mockIsLoading = false; mockIsLoading = false;
}); });
@ -64,14 +61,18 @@ describe("Home - Authenticated", () => {
expect(screen.getByText("Sign out")).toBeDefined(); expect(screen.getByText("Sign out")).toBeDefined();
}); });
test("clicking sign out calls logout", async () => { test("clicking sign out calls logout and redirects", async () => {
vi.spyOn(global, "fetch").mockResolvedValue({ vi.spyOn(global, "fetch").mockResolvedValue({
json: () => Promise.resolve({ value: 42 }), json: () => Promise.resolve({ value: 42 }),
} as Response); } as Response);
render(<Home />); render(<Home />);
fireEvent.click(screen.getByText("Sign out")); fireEvent.click(screen.getByText("Sign out"));
expect(mockLogout).toHaveBeenCalled();
await waitFor(() => {
expect(mockLogout).toHaveBeenCalled();
expect(mockPush).toHaveBeenCalledWith("/login");
});
}); });
test("renders counter value after fetch", async () => { test("renders counter value after fetch", async () => {
@ -85,7 +86,7 @@ describe("Home - Authenticated", () => {
}); });
}); });
test("fetches counter with auth header", async () => { test("fetches counter with credentials", async () => {
const fetchSpy = vi.spyOn(global, "fetch").mockResolvedValue({ const fetchSpy = vi.spyOn(global, "fetch").mockResolvedValue({
json: () => Promise.resolve({ value: 0 }), json: () => Promise.resolve({ value: 0 }),
} as Response); } as Response);
@ -96,7 +97,7 @@ describe("Home - Authenticated", () => {
expect(fetchSpy).toHaveBeenCalledWith( expect(fetchSpy).toHaveBeenCalledWith(
"http://localhost:8000/api/counter", "http://localhost:8000/api/counter",
expect.objectContaining({ expect.objectContaining({
headers: { Authorization: "Bearer valid-token" }, credentials: "include",
}) })
); );
}); });
@ -111,7 +112,7 @@ describe("Home - Authenticated", () => {
expect(screen.getByText("Increment")).toBeDefined(); expect(screen.getByText("Increment")).toBeDefined();
}); });
test("clicking increment button calls API with auth header", async () => { test("clicking increment button calls API with credentials", async () => {
const fetchSpy = vi const fetchSpy = vi
.spyOn(global, "fetch") .spyOn(global, "fetch")
.mockResolvedValueOnce({ json: () => Promise.resolve({ value: 0 }) } as Response) .mockResolvedValueOnce({ json: () => Promise.resolve({ value: 0 }) } as Response)
@ -127,7 +128,7 @@ describe("Home - Authenticated", () => {
"http://localhost:8000/api/counter/increment", "http://localhost:8000/api/counter/increment",
expect.objectContaining({ expect.objectContaining({
method: "POST", method: "POST",
headers: { Authorization: "Bearer valid-token" }, credentials: "include",
}) })
); );
}); });
@ -149,7 +150,6 @@ describe("Home - Authenticated", () => {
describe("Home - Unauthenticated", () => { describe("Home - Unauthenticated", () => {
test("redirects to login when not authenticated", async () => { test("redirects to login when not authenticated", async () => {
mockUser = null; mockUser = null;
mockToken = null;
render(<Home />); render(<Home />);
@ -160,16 +160,14 @@ describe("Home - Unauthenticated", () => {
test("returns null when not authenticated", () => { test("returns null when not authenticated", () => {
mockUser = null; mockUser = null;
mockToken = null;
const { container } = render(<Home />); const { container } = render(<Home />);
// Should render nothing (just redirects) // Should render nothing (just redirects)
expect(container.querySelector("main")).toBeNull(); expect(container.querySelector("main")).toBeNull();
}); });
test("does not fetch counter when no token", () => { test("does not fetch counter when no user", () => {
mockUser = null; mockUser = null;
mockToken = null;
const fetchSpy = vi.spyOn(global, "fetch"); const fetchSpy = vi.spyOn(global, "fetch");
render(<Home />); render(<Home />);
@ -182,7 +180,6 @@ describe("Home - Loading State", () => {
test("does not redirect while loading", () => { test("does not redirect while loading", () => {
mockIsLoading = true; mockIsLoading = true;
mockUser = null; mockUser = null;
mockToken = null;
render(<Home />); render(<Home />);

View file

@ -4,9 +4,11 @@ import { useEffect, useState } from "react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { useAuth } from "./auth-context"; import { useAuth } from "./auth-context";
const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
export default function Home() { export default function Home() {
const [count, setCount] = useState<number | null>(null); const [count, setCount] = useState<number | null>(null);
const { user, token, isLoading, logout } = useAuth(); const { user, isLoading, logout } = useAuth();
const router = useRouter(); const router = useRouter();
useEffect(() => { useEffect(() => {
@ -16,26 +18,30 @@ export default function Home() {
}, [isLoading, user, router]); }, [isLoading, user, router]);
useEffect(() => { useEffect(() => {
if (token) { if (user) {
fetch("http://localhost:8000/api/counter", { fetch(`${API_URL}/api/counter`, {
headers: { Authorization: `Bearer ${token}` }, credentials: "include",
}) })
.then((res) => res.json()) .then((res) => res.json())
.then((data) => setCount(data.value)) .then((data) => setCount(data.value))
.catch(() => setCount(null)); .catch(() => setCount(null));
} }
}, [token]); }, [user]);
const increment = async () => { const increment = async () => {
if (!token) return; const res = await fetch(`${API_URL}/api/counter/increment`, {
const res = await fetch("http://localhost:8000/api/counter/increment", {
method: "POST", method: "POST",
headers: { Authorization: `Bearer ${token}` }, credentials: "include",
}); });
const data = await res.json(); const data = await res.json();
setCount(data.value); setCount(data.value);
}; };
const handleLogout = async () => {
await logout();
router.push("/login");
};
if (isLoading) { if (isLoading) {
return ( return (
<main style={styles.main}> <main style={styles.main}>
@ -53,7 +59,7 @@ export default function Home() {
<div style={styles.header}> <div style={styles.header}>
<div style={styles.userInfo}> <div style={styles.userInfo}>
<span style={styles.userEmail}>{user.email}</span> <span style={styles.userEmail}>{user.email}</span>
<button onClick={logout} style={styles.logoutBtn}> <button onClick={handleLogout} style={styles.logoutBtn}>
Sign out Sign out
</button> </button>
</div> </div>

View file

@ -5,9 +5,9 @@ function uniqueEmail(): string {
return `test-${Date.now()}-${Math.random().toString(36).substring(7)}@example.com`; return `test-${Date.now()}-${Math.random().toString(36).substring(7)}@example.com`;
} }
// Helper to clear localStorage // Helper to clear auth cookies
async function clearAuth(page: Page) { async function clearAuth(page: Page) {
await page.evaluate(() => localStorage.clear()); await page.context().clearCookies();
} }
test.describe("Authentication Flow", () => { test.describe("Authentication Flow", () => {
@ -83,7 +83,7 @@ test.describe("Signup", () => {
await page.click('button[type="submit"]'); await page.click('button[type="submit"]');
await expect(page).toHaveURL("/"); await expect(page).toHaveURL("/");
// Clear and try again with same email // Clear cookies and try again with same email
await clearAuth(page); await clearAuth(page);
await page.goto("/signup"); await page.goto("/signup");
await page.fill('input[type="email"]', email); await page.fill('input[type="email"]', email);
@ -248,7 +248,7 @@ test.describe("Session Persistence", () => {
await expect(page.getByText(email)).toBeVisible(); await expect(page.getByText(email)).toBeVisible();
}); });
test("token is stored in localStorage", async ({ page }) => { test("auth cookie is set after login", async ({ page }) => {
const email = uniqueEmail(); const email = uniqueEmail();
await page.goto("/signup"); await page.goto("/signup");
@ -258,13 +258,14 @@ test.describe("Session Persistence", () => {
await page.click('button[type="submit"]'); await page.click('button[type="submit"]');
await expect(page).toHaveURL("/"); await expect(page).toHaveURL("/");
// Check localStorage // Check cookies
const token = await page.evaluate(() => localStorage.getItem("token")); const cookies = await page.context().cookies();
expect(token).toBeTruthy(); const authCookie = cookies.find((c) => c.name === "auth_token");
expect(token!.length).toBeGreaterThan(10); expect(authCookie).toBeTruthy();
expect(authCookie!.httpOnly).toBe(true);
}); });
test("token is cleared on logout", async ({ page }) => { test("auth cookie is cleared on logout", async ({ page }) => {
const email = uniqueEmail(); const email = uniqueEmail();
await page.goto("/signup"); await page.goto("/signup");
@ -276,8 +277,9 @@ test.describe("Session Persistence", () => {
await page.click("text=Sign out"); await page.click("text=Sign out");
const token = await page.evaluate(() => localStorage.getItem("token")); const cookies = await page.context().cookies();
expect(token).toBeNull(); const authCookie = cookies.find((c) => c.name === "auth_token");
// Cookie should be deleted or have empty value
expect(!authCookie || authCookie.value === "").toBe(true);
}); });
}); });

View file

@ -8,7 +8,7 @@ function uniqueEmail(): string {
// Helper to authenticate a user // Helper to authenticate a user
async function authenticate(page: Page): Promise<string> { async function authenticate(page: Page): Promise<string> {
const email = uniqueEmail(); const email = uniqueEmail();
await page.evaluate(() => localStorage.clear()); await page.context().clearCookies();
await page.goto("/signup"); await page.goto("/signup");
await page.fill('input[type="email"]', email); await page.fill('input[type="email"]', email);
await page.fill('input[type="password"]', "password123"); await page.fill('input[type="password"]', "password123");
@ -95,13 +95,13 @@ test.describe("Counter - Authenticated", () => {
test.describe("Counter - Unauthenticated", () => { test.describe("Counter - Unauthenticated", () => {
test("redirects to login when accessing counter without auth", async ({ page }) => { test("redirects to login when accessing counter without auth", async ({ page }) => {
await page.evaluate(() => localStorage.clear()); await page.context().clearCookies();
await page.goto("/"); await page.goto("/");
await expect(page).toHaveURL("/login"); await expect(page).toHaveURL("/login");
}); });
test("shows login form when redirected", async ({ page }) => { test("shows login form when redirected", async ({ page }) => {
await page.evaluate(() => localStorage.clear()); await page.context().clearCookies();
await page.goto("/"); await page.goto("/");
await expect(page.locator("h1")).toHaveText("Welcome back"); await expect(page.locator("h1")).toHaveText("Welcome back");
}); });
@ -138,11 +138,11 @@ test.describe("Counter - Session Integration", () => {
test("counter API requires authentication", async ({ page }) => { test("counter API requires authentication", async ({ page }) => {
// Try to access counter API directly without auth // Try to access counter API directly without auth
const response = await page.request.get("http://localhost:8000/api/counter"); const response = await page.request.get("http://localhost:8000/api/counter");
expect(response.status()).toBe(403); expect(response.status()).toBe(401);
}); });
test("counter increment API requires authentication", async ({ page }) => { test("counter increment API requires authentication", async ({ page }) => {
const response = await page.request.post("http://localhost:8000/api/counter/increment"); const response = await page.request.post("http://localhost:8000/api/counter/increment");
expect(response.status()).toBe(403); expect(response.status()).toBe(401);
}); });
}); });

7
frontend/env.example Normal file
View file

@ -0,0 +1,7 @@
# Environment variables for the frontend
# For local dev: use direnv with the root .envrc file (recommended)
# For production: set these in your deployment environment
# API URL for the backend
NEXT_PUBLIC_API_URL=http://localhost:8000

View file

@ -10,7 +10,7 @@ sleep 1
# Start db # Start db
docker compose up -d db docker compose up -d db
# Start backend # Start backend (SECRET_KEY should be set via .envrc or environment)
cd backend cd backend
uv run uvicorn main:app --port 8000 & uv run uvicorn main:app --port 8000 &
PID=$! PID=$!