first round of review

This commit is contained in:
counterweight 2025-12-18 22:24:46 +01:00
parent 7ebfb7a2dd
commit da5a0d03eb
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
14 changed files with 362 additions and 244 deletions

11
.envrc Normal file
View file

@ -0,0 +1,11 @@
# Local development environment variables
# To use: install direnv (https://direnv.net), then run `direnv allow`
# Backend
export SECRET_KEY="dev-secret-key-change-in-production"
export DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret"
export TEST_DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test"
# Frontend
export NEXT_PUBLIC_API_URL="http://localhost:8000"

View file

@ -3,8 +3,7 @@ from datetime import datetime, timedelta, timezone
from typing import Optional
import bcrypt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import Depends, HTTPException, Request, status
from jose import JWTError, jwt
from pydantic import BaseModel, EmailStr
from sqlalchemy import select
@ -13,10 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database import get_db
from models import User
SECRET_KEY = os.getenv("SECRET_KEY", "dev-secret-key-change-in-production")
SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
security = HTTPBearer()
COOKIE_NAME = "auth_token"
class UserCreate(BaseModel):
@ -74,16 +73,19 @@ async def authenticate_user(db: AsyncSession, email: str, password: str) -> Opti
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security),
request: Request,
db: AsyncSession = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
token = request.cookies.get(COOKIE_NAME)
if not token:
raise credentials_exception
try:
token = credentials.credentials
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id_str = payload.get("sub")
if user_id_str is None:

11
backend/env.example Normal file
View file

@ -0,0 +1,11 @@
# Environment variables for the backend
# For local dev: use direnv with the root .envrc file (recommended)
# For production: set these in your deployment environment
# Required: Secret key for JWT token signing
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
SECRET_KEY=
# Database URL
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret

View file

@ -1,5 +1,5 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi import FastAPI, Depends, HTTPException, Response, status
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@ -7,10 +7,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database import engine, get_db, Base
from models import Counter, User
from auth import (
ACCESS_TOKEN_EXPIRE_MINUTES,
COOKIE_NAME,
UserCreate,
UserLogin,
UserResponse,
TokenResponse,
get_password_hash,
get_user_by_email,
authenticate_user,
@ -37,9 +38,24 @@ app.add_middleware(
)
def set_auth_cookie(response: Response, token: str) -> None:
response.set_cookie(
key=COOKIE_NAME,
value=token,
httponly=True,
secure=False, # Set to True in production with HTTPS
samesite="lax",
max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
)
# Auth endpoints
@app.post("/api/auth/register", response_model=TokenResponse)
async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
@app.post("/api/auth/register", response_model=UserResponse)
async def register(
user_data: UserCreate,
response: Response,
db: AsyncSession = Depends(get_db),
):
existing_user = await get_user_by_email(db, user_data.email)
if existing_user:
raise HTTPException(
@ -56,15 +72,16 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
await db.refresh(user)
access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse(
access_token=access_token,
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
)
set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email)
@app.post("/api/auth/login", response_model=TokenResponse)
async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
@app.post("/api/auth/login", response_model=UserResponse)
async def login(
user_data: UserLogin,
response: Response,
db: AsyncSession = Depends(get_db),
):
user = await authenticate_user(db, user_data.email, user_data.password)
if not user:
raise HTTPException(
@ -73,11 +90,14 @@ async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
)
access_token = create_access_token(data={"sub": str(user.id)})
return TokenResponse(
access_token=access_token,
token_type="bearer",
user=UserResponse(id=user.id, email=user.email),
)
set_auth_cookie(response, access_token)
return UserResponse(id=user.id, email=user.email)
@app.post("/api/auth/logout")
async def logout(response: Response):
response.delete_cookie(key=COOKIE_NAME)
return {"ok": True}
@app.get("/api/auth/me", response_model=UserResponse)

View file

@ -1,4 +1,9 @@
import os
from contextlib import asynccontextmanager
# Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
@ -12,8 +17,38 @@ TEST_DATABASE_URL = os.getenv(
)
class ClientFactory:
"""Factory for creating httpx clients with optional cookies."""
def __init__(self, transport, base_url):
self._transport = transport
self._base_url = base_url
@asynccontextmanager
async def create(self, cookies: dict | None = None):
"""Create a new client, optionally with cookies set."""
async with AsyncClient(
transport=self._transport,
base_url=self._base_url,
cookies=cookies or {},
) as client:
yield client
async def request(self, method: str, url: str, **kwargs):
"""Make a one-off request without cookies."""
async with self.create() as client:
return await client.request(method, url, **kwargs)
async def get(self, url: str, **kwargs):
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs)
@pytest.fixture(scope="function")
async def client():
async def client_factory():
"""Fixture that provides a factory for creating clients."""
engine = create_async_engine(TEST_DATABASE_URL)
session_factory = async_sessionmaker(engine, expire_on_commit=False)
@ -28,8 +63,17 @@ async def client():
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
transport = ASGITransport(app=app)
factory = ClientFactory(transport, "http://test")
yield factory
app.dependency_overrides.clear()
await engine.dispose()
@pytest.fixture(scope="function")
async def client(client_factory):
"""Fixture for a simple client without cookies (backwards compatible)."""
async with client_factory.create() as c:
yield c

View file

@ -1,28 +1,14 @@
import pytest
import uuid
from auth import COOKIE_NAME
def unique_email(prefix: str = "test") -> str:
"""Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_user_and_get_token(client, email: str = None, password: str = "testpass123") -> str:
"""Helper to create a user and return their auth token."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": password},
)
return response.json()["access_token"]
def auth_header(token: str) -> dict:
"""Helper to create auth headers from token."""
return {"Authorization": f"Bearer {token}"}
# Registration tests
@pytest.mark.asyncio
async def test_register_success(client):
@ -33,10 +19,10 @@ async def test_register_success(client):
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
assert data["user"]["email"] == email
assert "id" in data["user"]
assert data["email"] == email
assert "id" in data
# Cookie should be set
assert COOKIE_NAME in response.cookies
@pytest.mark.asyncio
@ -101,9 +87,8 @@ async def test_login_success(client):
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
assert data["user"]["email"] == email
assert data["email"] == email
assert COOKIE_NAME in response.cookies
@pytest.mark.asyncio
@ -148,10 +133,20 @@ async def test_login_missing_fields(client):
# Get current user tests
@pytest.mark.asyncio
async def test_get_me_success(client):
async def test_get_me_success(client_factory):
email = unique_email("me")
token = await create_user_and_get_token(client, email)
response = await client.get("/api/auth/me", headers=auth_header(token))
# Register and get cookies
reg_response = await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
cookies = dict(reg_response.cookies)
# Use authenticated client
async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/auth/me")
assert response.status_code == 200
data = response.json()
assert data["email"] == email
@ -159,94 +154,89 @@ async def test_get_me_success(client):
@pytest.mark.asyncio
async def test_get_me_no_token(client):
async def test_get_me_no_cookie(client):
response = await client.get("/api/auth/me")
# HTTPBearer returns 401/403 when credentials are missing
assert response.status_code in [401, 403]
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_me_invalid_token(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "Bearer invalidtoken123"},
)
async def test_get_me_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed:
response = await authed.get("/api/auth/me")
assert response.status_code == 401
assert response.json()["detail"] == "Invalid authentication credentials"
@pytest.mark.asyncio
async def test_get_me_malformed_auth_header(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "NotBearer token123"},
)
# Invalid scheme returns 401/403
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_me_expired_token(client):
response = await client.get(
"/api/auth/me",
headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"},
)
async def test_get_me_expired_token(client_factory):
bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"
async with client_factory.create(cookies={COOKIE_NAME: bad_token}) as authed:
response = await authed.get("/api/auth/me")
assert response.status_code == 401
# Token validation tests
# Cookie validation tests
@pytest.mark.asyncio
async def test_token_from_register_works_for_me(client):
async def test_cookie_from_register_works_for_me(client_factory):
email = unique_email("tokentest")
register_response = await client.post(
reg_response = await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
token = register_response.json()["access_token"]
cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me")
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200
assert me_response.json()["email"] == email
@pytest.mark.asyncio
async def test_token_from_login_works_for_me(client):
async def test_cookie_from_login_works_for_me(client_factory):
email = unique_email("logintoken")
await client.post(
await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
login_response = await client.post(
login_response = await client_factory.post(
"/api/auth/login",
json={"email": email, "password": "password123"},
)
token = login_response.json()["access_token"]
cookies = dict(login_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me")
me_response = await client.get("/api/auth/me", headers=auth_header(token))
assert me_response.status_code == 200
assert me_response.json()["email"] == email
# Multiple users tests
@pytest.mark.asyncio
async def test_multiple_users_isolated(client):
async def test_multiple_users_isolated(client_factory):
email1 = unique_email("user1")
email2 = unique_email("user2")
resp1 = await client.post(
resp1 = await client_factory.post(
"/api/auth/register",
json={"email": email1, "password": "password1"},
)
resp2 = await client.post(
resp2 = await client_factory.post(
"/api/auth/register",
json={"email": email2, "password": "password2"},
)
token1 = resp1.json()["access_token"]
token2 = resp2.json()["access_token"]
cookies1 = dict(resp1.cookies)
cookies2 = dict(resp2.cookies)
me1 = await client.get("/api/auth/me", headers=auth_header(token1))
me2 = await client.get("/api/auth/me", headers=auth_header(token2))
async with client_factory.create(cookies=cookies1) as user1:
me1 = await user1.get("/api/auth/me")
async with client_factory.create(cookies=cookies2) as user2:
me2 = await user2.get("/api/auth/me")
assert me1.json()["email"] == email1
assert me2.json()["email"] == email2
@ -280,3 +270,21 @@ async def test_case_sensitive_password(client):
json={"email": email, "password": "password123"},
)
assert response.status_code == 401
# Logout tests
@pytest.mark.asyncio
async def test_logout_success(client_factory):
email = unique_email("logout")
reg_response = await client_factory.post(
"/api/auth/register",
json={"email": email, "password": "password123"},
)
cookies = dict(reg_response.cookies)
async with client_factory.create(cookies=cookies) as authed:
logout_response = await authed.post("/api/auth/logout")
assert logout_response.status_code == 200
assert logout_response.json() == {"ok": True}

View file

@ -1,128 +1,149 @@
import pytest
import uuid
from auth import COOKIE_NAME
def unique_email(prefix: str = "counter") -> str:
"""Generate a unique email for tests sharing the same database."""
return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com"
async def create_user_and_get_headers(client, email: str = None) -> dict:
"""Create a user and return auth headers for authenticated requests."""
if email is None:
email = unique_email()
response = await client.post(
"/api/auth/register",
json={"email": email, "password": "testpass123"},
)
token = response.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
# Protected endpoint tests - without auth
@pytest.mark.asyncio
async def test_get_counter_requires_auth(client):
response = await client.get("/api/counter")
assert response.status_code in [401, 403]
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_requires_auth(client):
response = await client.post("/api/counter/increment")
assert response.status_code in [401, 403]
@pytest.mark.asyncio
async def test_get_counter_invalid_token(client):
response = await client.get(
"/api/counter",
headers={"Authorization": "Bearer invalidtoken"},
)
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_invalid_token(client):
response = await client.post(
"/api/counter/increment",
headers={"Authorization": "Bearer invalidtoken"},
)
async def test_get_counter_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
response = await authed.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
response = await authed.post("/api/counter/increment")
assert response.status_code == 401
# Authenticated counter tests
@pytest.mark.asyncio
async def test_get_counter_authenticated(client):
auth_headers = await create_user_and_get_headers(client)
response = await client.get("/api/counter", headers=auth_headers)
async def test_get_counter_authenticated(client_factory):
reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/counter")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_increment_counter(client):
auth_headers = await create_user_and_get_headers(client)
async def test_increment_counter(client_factory):
reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
# Get current value
before = await client.get("/api/counter", headers=auth_headers)
before_value = before.json()["value"]
# Increment
response = await client.post("/api/counter/increment", headers=auth_headers)
assert response.status_code == 200
assert response.json()["value"] == before_value + 1
async with client_factory.create(cookies=cookies) as authed:
# Get current value
before = await authed.get("/api/counter")
before_value = before.json()["value"]
# Increment
response = await authed.post("/api/counter/increment")
assert response.status_code == 200
assert response.json()["value"] == before_value + 1
@pytest.mark.asyncio
async def test_increment_counter_multiple(client):
auth_headers = await create_user_and_get_headers(client)
async def test_increment_counter_multiple(client_factory):
reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
# Get starting value
before = await client.get("/api/counter", headers=auth_headers)
start = before.json()["value"]
# Increment 3 times
await client.post("/api/counter/increment", headers=auth_headers)
await client.post("/api/counter/increment", headers=auth_headers)
response = await client.post("/api/counter/increment", headers=auth_headers)
assert response.json()["value"] == start + 3
async with client_factory.create(cookies=cookies) as authed:
# Get starting value
before = await authed.get("/api/counter")
start = before.json()["value"]
# Increment 3 times
await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment")
response = await authed.post("/api/counter/increment")
assert response.json()["value"] == start + 3
@pytest.mark.asyncio
async def test_get_counter_after_increment(client):
auth_headers = await create_user_and_get_headers(client)
async def test_get_counter_after_increment(client_factory):
reg = await client_factory.post(
"/api/auth/register",
json={"email": unique_email(), "password": "testpass123"},
)
cookies = dict(reg.cookies)
before = await client.get("/api/counter", headers=auth_headers)
start = before.json()["value"]
await client.post("/api/counter/increment", headers=auth_headers)
await client.post("/api/counter/increment", headers=auth_headers)
response = await client.get("/api/counter", headers=auth_headers)
assert response.json()["value"] == start + 2
async with client_factory.create(cookies=cookies) as authed:
before = await authed.get("/api/counter")
start = before.json()["value"]
await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment")
response = await authed.get("/api/counter")
assert response.json()["value"] == start + 2
# Counter is shared between users
@pytest.mark.asyncio
async def test_counter_shared_between_users(client):
headers1 = await create_user_and_get_headers(client, unique_email("share1"))
async def test_counter_shared_between_users(client_factory):
# Create first user
reg1 = await client_factory.post(
"/api/auth/register",
json={"email": unique_email("share1"), "password": "testpass123"},
)
cookies1 = dict(reg1.cookies)
# Get starting value
before = await client.get("/api/counter", headers=headers1)
start = before.json()["value"]
async with client_factory.create(cookies=cookies1) as user1:
# Get starting value
before = await user1.get("/api/counter")
start = before.json()["value"]
await user1.post("/api/counter/increment")
await user1.post("/api/counter/increment")
await client.post("/api/counter/increment", headers=headers1)
await client.post("/api/counter/increment", headers=headers1)
# Create second user - should see the increments
reg2 = await client_factory.post(
"/api/auth/register",
json={"email": unique_email("share2"), "password": "testpass123"},
)
cookies2 = dict(reg2.cookies)
# Second user sees the increments
headers2 = await create_user_and_get_headers(client, unique_email("share2"))
response = await client.get("/api/counter", headers=headers2)
assert response.json()["value"] == start + 2
# Second user increments
await client.post("/api/counter/increment", headers=headers2)
async with client_factory.create(cookies=cookies2) as user2:
response = await user2.get("/api/counter")
assert response.json()["value"] == start + 2
# Second user increments
await user2.post("/api/counter/increment")
# First user sees the increment
response = await client.get("/api/counter", headers=headers1)
assert response.json()["value"] == start + 3
async with client_factory.create(cookies=cookies1) as user1:
response = await user1.get("/api/counter")
assert response.json()["value"] == start + 3

View file

@ -2,6 +2,8 @@
import { createContext, useContext, useState, useEffect, ReactNode } from "react";
const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
interface User {
id: number;
email: string;
@ -9,54 +11,43 @@ interface User {
interface AuthContextType {
user: User | null;
token: string | null;
isLoading: boolean;
login: (email: string, password: string) => Promise<void>;
register: (email: string, password: string) => Promise<void>;
logout: () => void;
logout: () => Promise<void>;
}
const AuthContext = createContext<AuthContextType | null>(null);
export function AuthProvider({ children }: { children: ReactNode }) {
const [user, setUser] = useState<User | null>(null);
const [token, setToken] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true);
useEffect(() => {
const storedToken = localStorage.getItem("token");
if (storedToken) {
setToken(storedToken);
fetchUser(storedToken);
} else {
setIsLoading(false);
}
checkAuth();
}, []);
const fetchUser = async (authToken: string) => {
const checkAuth = async () => {
try {
const res = await fetch("http://localhost:8000/api/auth/me", {
headers: { Authorization: `Bearer ${authToken}` },
const res = await fetch(`${API_URL}/api/auth/me`, {
credentials: "include",
});
if (res.ok) {
const userData = await res.json();
setUser(userData);
} else {
localStorage.removeItem("token");
setToken(null);
}
} catch {
localStorage.removeItem("token");
setToken(null);
// Not authenticated
} finally {
setIsLoading(false);
}
};
const login = async (email: string, password: string) => {
const res = await fetch("http://localhost:8000/api/auth/login", {
const res = await fetch(`${API_URL}/api/auth/login`, {
method: "POST",
headers: { "Content-Type": "application/json" },
credentials: "include",
body: JSON.stringify({ email, password }),
});
@ -65,16 +56,15 @@ export function AuthProvider({ children }: { children: ReactNode }) {
throw new Error(error.detail || "Login failed");
}
const data = await res.json();
localStorage.setItem("token", data.access_token);
setToken(data.access_token);
setUser(data.user);
const userData = await res.json();
setUser(userData);
};
const register = async (email: string, password: string) => {
const res = await fetch("http://localhost:8000/api/auth/register", {
const res = await fetch(`${API_URL}/api/auth/register`, {
method: "POST",
headers: { "Content-Type": "application/json" },
credentials: "include",
body: JSON.stringify({ email, password }),
});
@ -83,20 +73,20 @@ export function AuthProvider({ children }: { children: ReactNode }) {
throw new Error(error.detail || "Registration failed");
}
const data = await res.json();
localStorage.setItem("token", data.access_token);
setToken(data.access_token);
setUser(data.user);
const userData = await res.json();
setUser(userData);
};
const logout = () => {
localStorage.removeItem("token");
setToken(null);
const logout = async () => {
await fetch(`${API_URL}/api/auth/logout`, {
method: "POST",
credentials: "include",
});
setUser(null);
};
return (
<AuthContext.Provider value={{ user, token, isLoading, login, register, logout }}>
<AuthContext.Provider value={{ user, isLoading, login, register, logout }}>
{children}
</AuthContext.Provider>
);
@ -109,4 +99,3 @@ export function useAuth() {
}
return context;
}

View file

@ -12,14 +12,12 @@ vi.mock("next/navigation", () => ({
// Default mock values
let mockUser: { id: number; email: string } | null = { id: 1, email: "test@example.com" };
let mockToken: string | null = "valid-token";
let mockIsLoading = false;
const mockLogout = vi.fn();
vi.mock("./auth-context", () => ({
useAuth: () => ({
user: mockUser,
token: mockToken,
isLoading: mockIsLoading,
logout: mockLogout,
}),
@ -29,7 +27,6 @@ beforeEach(() => {
vi.clearAllMocks();
// Reset to authenticated state
mockUser = { id: 1, email: "test@example.com" };
mockToken = "valid-token";
mockIsLoading = false;
});
@ -64,14 +61,18 @@ describe("Home - Authenticated", () => {
expect(screen.getByText("Sign out")).toBeDefined();
});
test("clicking sign out calls logout", async () => {
test("clicking sign out calls logout and redirects", async () => {
vi.spyOn(global, "fetch").mockResolvedValue({
json: () => Promise.resolve({ value: 42 }),
} as Response);
render(<Home />);
fireEvent.click(screen.getByText("Sign out"));
expect(mockLogout).toHaveBeenCalled();
await waitFor(() => {
expect(mockLogout).toHaveBeenCalled();
expect(mockPush).toHaveBeenCalledWith("/login");
});
});
test("renders counter value after fetch", async () => {
@ -85,7 +86,7 @@ describe("Home - Authenticated", () => {
});
});
test("fetches counter with auth header", async () => {
test("fetches counter with credentials", async () => {
const fetchSpy = vi.spyOn(global, "fetch").mockResolvedValue({
json: () => Promise.resolve({ value: 0 }),
} as Response);
@ -96,7 +97,7 @@ describe("Home - Authenticated", () => {
expect(fetchSpy).toHaveBeenCalledWith(
"http://localhost:8000/api/counter",
expect.objectContaining({
headers: { Authorization: "Bearer valid-token" },
credentials: "include",
})
);
});
@ -111,7 +112,7 @@ describe("Home - Authenticated", () => {
expect(screen.getByText("Increment")).toBeDefined();
});
test("clicking increment button calls API with auth header", async () => {
test("clicking increment button calls API with credentials", async () => {
const fetchSpy = vi
.spyOn(global, "fetch")
.mockResolvedValueOnce({ json: () => Promise.resolve({ value: 0 }) } as Response)
@ -127,7 +128,7 @@ describe("Home - Authenticated", () => {
"http://localhost:8000/api/counter/increment",
expect.objectContaining({
method: "POST",
headers: { Authorization: "Bearer valid-token" },
credentials: "include",
})
);
});
@ -149,7 +150,6 @@ describe("Home - Authenticated", () => {
describe("Home - Unauthenticated", () => {
test("redirects to login when not authenticated", async () => {
mockUser = null;
mockToken = null;
render(<Home />);
@ -160,16 +160,14 @@ describe("Home - Unauthenticated", () => {
test("returns null when not authenticated", () => {
mockUser = null;
mockToken = null;
const { container } = render(<Home />);
// Should render nothing (just redirects)
expect(container.querySelector("main")).toBeNull();
});
test("does not fetch counter when no token", () => {
test("does not fetch counter when no user", () => {
mockUser = null;
mockToken = null;
const fetchSpy = vi.spyOn(global, "fetch");
render(<Home />);
@ -182,7 +180,6 @@ describe("Home - Loading State", () => {
test("does not redirect while loading", () => {
mockIsLoading = true;
mockUser = null;
mockToken = null;
render(<Home />);

View file

@ -4,9 +4,11 @@ import { useEffect, useState } from "react";
import { useRouter } from "next/navigation";
import { useAuth } from "./auth-context";
const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000";
export default function Home() {
const [count, setCount] = useState<number | null>(null);
const { user, token, isLoading, logout } = useAuth();
const { user, isLoading, logout } = useAuth();
const router = useRouter();
useEffect(() => {
@ -16,26 +18,30 @@ export default function Home() {
}, [isLoading, user, router]);
useEffect(() => {
if (token) {
fetch("http://localhost:8000/api/counter", {
headers: { Authorization: `Bearer ${token}` },
if (user) {
fetch(`${API_URL}/api/counter`, {
credentials: "include",
})
.then((res) => res.json())
.then((data) => setCount(data.value))
.catch(() => setCount(null));
}
}, [token]);
}, [user]);
const increment = async () => {
if (!token) return;
const res = await fetch("http://localhost:8000/api/counter/increment", {
const res = await fetch(`${API_URL}/api/counter/increment`, {
method: "POST",
headers: { Authorization: `Bearer ${token}` },
credentials: "include",
});
const data = await res.json();
setCount(data.value);
};
const handleLogout = async () => {
await logout();
router.push("/login");
};
if (isLoading) {
return (
<main style={styles.main}>
@ -53,7 +59,7 @@ export default function Home() {
<div style={styles.header}>
<div style={styles.userInfo}>
<span style={styles.userEmail}>{user.email}</span>
<button onClick={logout} style={styles.logoutBtn}>
<button onClick={handleLogout} style={styles.logoutBtn}>
Sign out
</button>
</div>

View file

@ -5,9 +5,9 @@ function uniqueEmail(): string {
return `test-${Date.now()}-${Math.random().toString(36).substring(7)}@example.com`;
}
// Helper to clear localStorage
// Helper to clear auth cookies
async function clearAuth(page: Page) {
await page.evaluate(() => localStorage.clear());
await page.context().clearCookies();
}
test.describe("Authentication Flow", () => {
@ -83,7 +83,7 @@ test.describe("Signup", () => {
await page.click('button[type="submit"]');
await expect(page).toHaveURL("/");
// Clear and try again with same email
// Clear cookies and try again with same email
await clearAuth(page);
await page.goto("/signup");
await page.fill('input[type="email"]', email);
@ -248,7 +248,7 @@ test.describe("Session Persistence", () => {
await expect(page.getByText(email)).toBeVisible();
});
test("token is stored in localStorage", async ({ page }) => {
test("auth cookie is set after login", async ({ page }) => {
const email = uniqueEmail();
await page.goto("/signup");
@ -258,13 +258,14 @@ test.describe("Session Persistence", () => {
await page.click('button[type="submit"]');
await expect(page).toHaveURL("/");
// Check localStorage
const token = await page.evaluate(() => localStorage.getItem("token"));
expect(token).toBeTruthy();
expect(token!.length).toBeGreaterThan(10);
// Check cookies
const cookies = await page.context().cookies();
const authCookie = cookies.find((c) => c.name === "auth_token");
expect(authCookie).toBeTruthy();
expect(authCookie!.httpOnly).toBe(true);
});
test("token is cleared on logout", async ({ page }) => {
test("auth cookie is cleared on logout", async ({ page }) => {
const email = uniqueEmail();
await page.goto("/signup");
@ -276,8 +277,9 @@ test.describe("Session Persistence", () => {
await page.click("text=Sign out");
const token = await page.evaluate(() => localStorage.getItem("token"));
expect(token).toBeNull();
const cookies = await page.context().cookies();
const authCookie = cookies.find((c) => c.name === "auth_token");
// Cookie should be deleted or have empty value
expect(!authCookie || authCookie.value === "").toBe(true);
});
});

View file

@ -8,7 +8,7 @@ function uniqueEmail(): string {
// Helper to authenticate a user
async function authenticate(page: Page): Promise<string> {
const email = uniqueEmail();
await page.evaluate(() => localStorage.clear());
await page.context().clearCookies();
await page.goto("/signup");
await page.fill('input[type="email"]', email);
await page.fill('input[type="password"]', "password123");
@ -95,13 +95,13 @@ test.describe("Counter - Authenticated", () => {
test.describe("Counter - Unauthenticated", () => {
test("redirects to login when accessing counter without auth", async ({ page }) => {
await page.evaluate(() => localStorage.clear());
await page.context().clearCookies();
await page.goto("/");
await expect(page).toHaveURL("/login");
});
test("shows login form when redirected", async ({ page }) => {
await page.evaluate(() => localStorage.clear());
await page.context().clearCookies();
await page.goto("/");
await expect(page.locator("h1")).toHaveText("Welcome back");
});
@ -138,11 +138,11 @@ test.describe("Counter - Session Integration", () => {
test("counter API requires authentication", async ({ page }) => {
// Try to access counter API directly without auth
const response = await page.request.get("http://localhost:8000/api/counter");
expect(response.status()).toBe(403);
expect(response.status()).toBe(401);
});
test("counter increment API requires authentication", async ({ page }) => {
const response = await page.request.post("http://localhost:8000/api/counter/increment");
expect(response.status()).toBe(403);
expect(response.status()).toBe(401);
});
});

7
frontend/env.example Normal file
View file

@ -0,0 +1,7 @@
# Environment variables for the frontend
# For local dev: use direnv with the root .envrc file (recommended)
# For production: set these in your deployment environment
# API URL for the backend
NEXT_PUBLIC_API_URL=http://localhost:8000

View file

@ -10,7 +10,7 @@ sleep 1
# Start db
docker compose up -d db
# Start backend
# Start backend (SECRET_KEY should be set via .envrc or environment)
cd backend
uv run uvicorn main:app --port 8000 &
PID=$!