diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..94a6067 --- /dev/null +++ b/.envrc @@ -0,0 +1,11 @@ +# Local development environment variables +# To use: install direnv (https://direnv.net), then run `direnv allow` + +# Backend +export SECRET_KEY="dev-secret-key-change-in-production" +export DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret" +export TEST_DATABASE_URL="postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test" + +# Frontend +export NEXT_PUBLIC_API_URL="http://localhost:8000" + diff --git a/backend/auth.py b/backend/auth.py index b845c7e..bd04d16 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -3,8 +3,7 @@ from datetime import datetime, timedelta, timezone from typing import Optional import bcrypt -from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import Depends, HTTPException, Request, status from jose import JWTError, jwt from pydantic import BaseModel, EmailStr from sqlalchemy import select @@ -13,10 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from database import get_db from models import User -SECRET_KEY = os.getenv("SECRET_KEY", "dev-secret-key-change-in-production") +SECRET_KEY = os.environ["SECRET_KEY"] # Required - see .env.example ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days -security = HTTPBearer() +COOKIE_NAME = "auth_token" class UserCreate(BaseModel): @@ -74,16 +73,19 @@ async def authenticate_user(db: AsyncSession, email: str, password: str) -> Opti async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), + request: Request, db: AsyncSession = Depends(get_db), ) -> User: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", - headers={"WWW-Authenticate": "Bearer"}, ) + + token = request.cookies.get(COOKIE_NAME) + if not token: + raise credentials_exception + try: - token = credentials.credentials payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) user_id_str = payload.get("sub") if user_id_str is None: diff --git a/backend/env.example b/backend/env.example new file mode 100644 index 0000000..8766365 --- /dev/null +++ b/backend/env.example @@ -0,0 +1,11 @@ +# Environment variables for the backend +# For local dev: use direnv with the root .envrc file (recommended) +# For production: set these in your deployment environment + +# Required: Secret key for JWT token signing +# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))" +SECRET_KEY= + +# Database URL +DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/arbret + diff --git a/backend/main.py b/backend/main.py index 20de1ca..7a4f18b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,5 @@ from contextlib import asynccontextmanager -from fastapi import FastAPI, Depends, HTTPException, status +from fastapi import FastAPI, Depends, HTTPException, Response, status from fastapi.middleware.cors import CORSMiddleware from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -7,10 +7,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from database import engine, get_db, Base from models import Counter, User from auth import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + COOKIE_NAME, UserCreate, UserLogin, UserResponse, - TokenResponse, get_password_hash, get_user_by_email, authenticate_user, @@ -37,9 +38,24 @@ app.add_middleware( ) +def set_auth_cookie(response: Response, token: str) -> None: + response.set_cookie( + key=COOKIE_NAME, + value=token, + httponly=True, + secure=False, # Set to True in production with HTTPS + samesite="lax", + max_age=ACCESS_TOKEN_EXPIRE_MINUTES * 60, + ) + + # Auth endpoints -@app.post("/api/auth/register", response_model=TokenResponse) -async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): +@app.post("/api/auth/register", response_model=UserResponse) +async def register( + user_data: UserCreate, + response: Response, + db: AsyncSession = Depends(get_db), +): existing_user = await get_user_by_email(db, user_data.email) if existing_user: raise HTTPException( @@ -56,15 +72,16 @@ async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)): await db.refresh(user) access_token = create_access_token(data={"sub": str(user.id)}) - return TokenResponse( - access_token=access_token, - token_type="bearer", - user=UserResponse(id=user.id, email=user.email), - ) + set_auth_cookie(response, access_token) + return UserResponse(id=user.id, email=user.email) -@app.post("/api/auth/login", response_model=TokenResponse) -async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)): +@app.post("/api/auth/login", response_model=UserResponse) +async def login( + user_data: UserLogin, + response: Response, + db: AsyncSession = Depends(get_db), +): user = await authenticate_user(db, user_data.email, user_data.password) if not user: raise HTTPException( @@ -73,11 +90,14 @@ async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)): ) access_token = create_access_token(data={"sub": str(user.id)}) - return TokenResponse( - access_token=access_token, - token_type="bearer", - user=UserResponse(id=user.id, email=user.email), - ) + set_auth_cookie(response, access_token) + return UserResponse(id=user.id, email=user.email) + + +@app.post("/api/auth/logout") +async def logout(response: Response): + response.delete_cookie(key=COOKIE_NAME) + return {"ok": True} @app.get("/api/auth/me", response_model=UserResponse) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4e01cb8..b2aa797 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,4 +1,9 @@ import os +from contextlib import asynccontextmanager + +# Set required env vars before importing app +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only") + import pytest from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker @@ -12,8 +17,38 @@ TEST_DATABASE_URL = os.getenv( ) +class ClientFactory: + """Factory for creating httpx clients with optional cookies.""" + + def __init__(self, transport, base_url): + self._transport = transport + self._base_url = base_url + + @asynccontextmanager + async def create(self, cookies: dict | None = None): + """Create a new client, optionally with cookies set.""" + async with AsyncClient( + transport=self._transport, + base_url=self._base_url, + cookies=cookies or {}, + ) as client: + yield client + + async def request(self, method: str, url: str, **kwargs): + """Make a one-off request without cookies.""" + async with self.create() as client: + return await client.request(method, url, **kwargs) + + async def get(self, url: str, **kwargs): + return await self.request("GET", url, **kwargs) + + async def post(self, url: str, **kwargs): + return await self.request("POST", url, **kwargs) + + @pytest.fixture(scope="function") -async def client(): +async def client_factory(): + """Fixture that provides a factory for creating clients.""" engine = create_async_engine(TEST_DATABASE_URL) session_factory = async_sessionmaker(engine, expire_on_commit=False) @@ -28,8 +63,17 @@ async def client(): app.dependency_overrides[get_db] = override_get_db - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c: - yield c + transport = ASGITransport(app=app) + factory = ClientFactory(transport, "http://test") + + yield factory app.dependency_overrides.clear() await engine.dispose() + + +@pytest.fixture(scope="function") +async def client(client_factory): + """Fixture for a simple client without cookies (backwards compatible).""" + async with client_factory.create() as c: + yield c diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index 497776d..edc2f98 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -1,28 +1,14 @@ import pytest import uuid +from auth import COOKIE_NAME + def unique_email(prefix: str = "test") -> str: """Generate a unique email for tests sharing the same database.""" return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com" -async def create_user_and_get_token(client, email: str = None, password: str = "testpass123") -> str: - """Helper to create a user and return their auth token.""" - if email is None: - email = unique_email() - response = await client.post( - "/api/auth/register", - json={"email": email, "password": password}, - ) - return response.json()["access_token"] - - -def auth_header(token: str) -> dict: - """Helper to create auth headers from token.""" - return {"Authorization": f"Bearer {token}"} - - # Registration tests @pytest.mark.asyncio async def test_register_success(client): @@ -33,10 +19,10 @@ async def test_register_success(client): ) assert response.status_code == 200 data = response.json() - assert "access_token" in data - assert data["token_type"] == "bearer" - assert data["user"]["email"] == email - assert "id" in data["user"] + assert data["email"] == email + assert "id" in data + # Cookie should be set + assert COOKIE_NAME in response.cookies @pytest.mark.asyncio @@ -101,9 +87,8 @@ async def test_login_success(client): ) assert response.status_code == 200 data = response.json() - assert "access_token" in data - assert data["token_type"] == "bearer" - assert data["user"]["email"] == email + assert data["email"] == email + assert COOKIE_NAME in response.cookies @pytest.mark.asyncio @@ -148,10 +133,20 @@ async def test_login_missing_fields(client): # Get current user tests @pytest.mark.asyncio -async def test_get_me_success(client): +async def test_get_me_success(client_factory): email = unique_email("me") - token = await create_user_and_get_token(client, email) - response = await client.get("/api/auth/me", headers=auth_header(token)) + + # Register and get cookies + reg_response = await client_factory.post( + "/api/auth/register", + json={"email": email, "password": "password123"}, + ) + cookies = dict(reg_response.cookies) + + # Use authenticated client + async with client_factory.create(cookies=cookies) as authed: + response = await authed.get("/api/auth/me") + assert response.status_code == 200 data = response.json() assert data["email"] == email @@ -159,94 +154,89 @@ async def test_get_me_success(client): @pytest.mark.asyncio -async def test_get_me_no_token(client): +async def test_get_me_no_cookie(client): response = await client.get("/api/auth/me") - # HTTPBearer returns 401/403 when credentials are missing - assert response.status_code in [401, 403] + assert response.status_code == 401 @pytest.mark.asyncio -async def test_get_me_invalid_token(client): - response = await client.get( - "/api/auth/me", - headers={"Authorization": "Bearer invalidtoken123"}, - ) +async def test_get_me_invalid_cookie(client_factory): + async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken123"}) as authed: + response = await authed.get("/api/auth/me") assert response.status_code == 401 assert response.json()["detail"] == "Invalid authentication credentials" @pytest.mark.asyncio -async def test_get_me_malformed_auth_header(client): - response = await client.get( - "/api/auth/me", - headers={"Authorization": "NotBearer token123"}, - ) - # Invalid scheme returns 401/403 - assert response.status_code in [401, 403] - - -@pytest.mark.asyncio -async def test_get_me_expired_token(client): - response = await client.get( - "/api/auth/me", - headers={"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid"}, - ) +async def test_get_me_expired_token(client_factory): + bad_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOjEsImV4cCI6MH0.invalid" + async with client_factory.create(cookies={COOKIE_NAME: bad_token}) as authed: + response = await authed.get("/api/auth/me") assert response.status_code == 401 -# Token validation tests +# Cookie validation tests @pytest.mark.asyncio -async def test_token_from_register_works_for_me(client): +async def test_cookie_from_register_works_for_me(client_factory): email = unique_email("tokentest") - register_response = await client.post( + + reg_response = await client_factory.post( "/api/auth/register", json={"email": email, "password": "password123"}, ) - token = register_response.json()["access_token"] + cookies = dict(reg_response.cookies) + + async with client_factory.create(cookies=cookies) as authed: + me_response = await authed.get("/api/auth/me") - me_response = await client.get("/api/auth/me", headers=auth_header(token)) assert me_response.status_code == 200 assert me_response.json()["email"] == email @pytest.mark.asyncio -async def test_token_from_login_works_for_me(client): +async def test_cookie_from_login_works_for_me(client_factory): email = unique_email("logintoken") - await client.post( + + await client_factory.post( "/api/auth/register", json={"email": email, "password": "password123"}, ) - login_response = await client.post( + login_response = await client_factory.post( "/api/auth/login", json={"email": email, "password": "password123"}, ) - token = login_response.json()["access_token"] + cookies = dict(login_response.cookies) + + async with client_factory.create(cookies=cookies) as authed: + me_response = await authed.get("/api/auth/me") - me_response = await client.get("/api/auth/me", headers=auth_header(token)) assert me_response.status_code == 200 assert me_response.json()["email"] == email # Multiple users tests @pytest.mark.asyncio -async def test_multiple_users_isolated(client): +async def test_multiple_users_isolated(client_factory): email1 = unique_email("user1") email2 = unique_email("user2") - resp1 = await client.post( + resp1 = await client_factory.post( "/api/auth/register", json={"email": email1, "password": "password1"}, ) - resp2 = await client.post( + resp2 = await client_factory.post( "/api/auth/register", json={"email": email2, "password": "password2"}, ) - token1 = resp1.json()["access_token"] - token2 = resp2.json()["access_token"] + cookies1 = dict(resp1.cookies) + cookies2 = dict(resp2.cookies) - me1 = await client.get("/api/auth/me", headers=auth_header(token1)) - me2 = await client.get("/api/auth/me", headers=auth_header(token2)) + async with client_factory.create(cookies=cookies1) as user1: + me1 = await user1.get("/api/auth/me") + + async with client_factory.create(cookies=cookies2) as user2: + me2 = await user2.get("/api/auth/me") assert me1.json()["email"] == email1 assert me2.json()["email"] == email2 @@ -280,3 +270,21 @@ async def test_case_sensitive_password(client): json={"email": email, "password": "password123"}, ) assert response.status_code == 401 + + +# Logout tests +@pytest.mark.asyncio +async def test_logout_success(client_factory): + email = unique_email("logout") + + reg_response = await client_factory.post( + "/api/auth/register", + json={"email": email, "password": "password123"}, + ) + cookies = dict(reg_response.cookies) + + async with client_factory.create(cookies=cookies) as authed: + logout_response = await authed.post("/api/auth/logout") + + assert logout_response.status_code == 200 + assert logout_response.json() == {"ok": True} diff --git a/backend/tests/test_counter.py b/backend/tests/test_counter.py index 545aa4b..20f1690 100644 --- a/backend/tests/test_counter.py +++ b/backend/tests/test_counter.py @@ -1,128 +1,149 @@ import pytest import uuid +from auth import COOKIE_NAME + def unique_email(prefix: str = "counter") -> str: """Generate a unique email for tests sharing the same database.""" return f"{prefix}-{uuid.uuid4().hex[:8]}@example.com" -async def create_user_and_get_headers(client, email: str = None) -> dict: - """Create a user and return auth headers for authenticated requests.""" - if email is None: - email = unique_email() - response = await client.post( - "/api/auth/register", - json={"email": email, "password": "testpass123"}, - ) - token = response.json()["access_token"] - return {"Authorization": f"Bearer {token}"} - - # Protected endpoint tests - without auth @pytest.mark.asyncio async def test_get_counter_requires_auth(client): response = await client.get("/api/counter") - assert response.status_code in [401, 403] + assert response.status_code == 401 @pytest.mark.asyncio async def test_increment_counter_requires_auth(client): response = await client.post("/api/counter/increment") - assert response.status_code in [401, 403] - - -@pytest.mark.asyncio -async def test_get_counter_invalid_token(client): - response = await client.get( - "/api/counter", - headers={"Authorization": "Bearer invalidtoken"}, - ) assert response.status_code == 401 @pytest.mark.asyncio -async def test_increment_counter_invalid_token(client): - response = await client.post( - "/api/counter/increment", - headers={"Authorization": "Bearer invalidtoken"}, - ) +async def test_get_counter_invalid_cookie(client_factory): + async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed: + response = await authed.get("/api/counter") + assert response.status_code == 401 + + +@pytest.mark.asyncio +async def test_increment_counter_invalid_cookie(client_factory): + async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed: + response = await authed.post("/api/counter/increment") assert response.status_code == 401 # Authenticated counter tests @pytest.mark.asyncio -async def test_get_counter_authenticated(client): - auth_headers = await create_user_and_get_headers(client) - response = await client.get("/api/counter", headers=auth_headers) +async def test_get_counter_authenticated(client_factory): + reg = await client_factory.post( + "/api/auth/register", + json={"email": unique_email(), "password": "testpass123"}, + ) + cookies = dict(reg.cookies) + + async with client_factory.create(cookies=cookies) as authed: + response = await authed.get("/api/counter") + assert response.status_code == 200 assert "value" in response.json() @pytest.mark.asyncio -async def test_increment_counter(client): - auth_headers = await create_user_and_get_headers(client) +async def test_increment_counter(client_factory): + reg = await client_factory.post( + "/api/auth/register", + json={"email": unique_email(), "password": "testpass123"}, + ) + cookies = dict(reg.cookies) - # Get current value - before = await client.get("/api/counter", headers=auth_headers) - before_value = before.json()["value"] - - # Increment - response = await client.post("/api/counter/increment", headers=auth_headers) - assert response.status_code == 200 - assert response.json()["value"] == before_value + 1 + async with client_factory.create(cookies=cookies) as authed: + # Get current value + before = await authed.get("/api/counter") + before_value = before.json()["value"] + + # Increment + response = await authed.post("/api/counter/increment") + assert response.status_code == 200 + assert response.json()["value"] == before_value + 1 @pytest.mark.asyncio -async def test_increment_counter_multiple(client): - auth_headers = await create_user_and_get_headers(client) +async def test_increment_counter_multiple(client_factory): + reg = await client_factory.post( + "/api/auth/register", + json={"email": unique_email(), "password": "testpass123"}, + ) + cookies = dict(reg.cookies) - # Get starting value - before = await client.get("/api/counter", headers=auth_headers) - start = before.json()["value"] - - # Increment 3 times - await client.post("/api/counter/increment", headers=auth_headers) - await client.post("/api/counter/increment", headers=auth_headers) - response = await client.post("/api/counter/increment", headers=auth_headers) - - assert response.json()["value"] == start + 3 + async with client_factory.create(cookies=cookies) as authed: + # Get starting value + before = await authed.get("/api/counter") + start = before.json()["value"] + + # Increment 3 times + await authed.post("/api/counter/increment") + await authed.post("/api/counter/increment") + response = await authed.post("/api/counter/increment") + + assert response.json()["value"] == start + 3 @pytest.mark.asyncio -async def test_get_counter_after_increment(client): - auth_headers = await create_user_and_get_headers(client) +async def test_get_counter_after_increment(client_factory): + reg = await client_factory.post( + "/api/auth/register", + json={"email": unique_email(), "password": "testpass123"}, + ) + cookies = dict(reg.cookies) - before = await client.get("/api/counter", headers=auth_headers) - start = before.json()["value"] - - await client.post("/api/counter/increment", headers=auth_headers) - await client.post("/api/counter/increment", headers=auth_headers) - - response = await client.get("/api/counter", headers=auth_headers) - assert response.json()["value"] == start + 2 + async with client_factory.create(cookies=cookies) as authed: + before = await authed.get("/api/counter") + start = before.json()["value"] + + await authed.post("/api/counter/increment") + await authed.post("/api/counter/increment") + + response = await authed.get("/api/counter") + assert response.json()["value"] == start + 2 # Counter is shared between users @pytest.mark.asyncio -async def test_counter_shared_between_users(client): - headers1 = await create_user_and_get_headers(client, unique_email("share1")) +async def test_counter_shared_between_users(client_factory): + # Create first user + reg1 = await client_factory.post( + "/api/auth/register", + json={"email": unique_email("share1"), "password": "testpass123"}, + ) + cookies1 = dict(reg1.cookies) - # Get starting value - before = await client.get("/api/counter", headers=headers1) - start = before.json()["value"] + async with client_factory.create(cookies=cookies1) as user1: + # Get starting value + before = await user1.get("/api/counter") + start = before.json()["value"] + + await user1.post("/api/counter/increment") + await user1.post("/api/counter/increment") - await client.post("/api/counter/increment", headers=headers1) - await client.post("/api/counter/increment", headers=headers1) + # Create second user - should see the increments + reg2 = await client_factory.post( + "/api/auth/register", + json={"email": unique_email("share2"), "password": "testpass123"}, + ) + cookies2 = dict(reg2.cookies) - # Second user sees the increments - headers2 = await create_user_and_get_headers(client, unique_email("share2")) - response = await client.get("/api/counter", headers=headers2) - assert response.json()["value"] == start + 2 - - # Second user increments - await client.post("/api/counter/increment", headers=headers2) + async with client_factory.create(cookies=cookies2) as user2: + response = await user2.get("/api/counter") + assert response.json()["value"] == start + 2 + + # Second user increments + await user2.post("/api/counter/increment") # First user sees the increment - response = await client.get("/api/counter", headers=headers1) - assert response.json()["value"] == start + 3 + async with client_factory.create(cookies=cookies1) as user1: + response = await user1.get("/api/counter") + assert response.json()["value"] == start + 3 diff --git a/frontend/app/auth-context.tsx b/frontend/app/auth-context.tsx index a1450bb..2e8c79c 100644 --- a/frontend/app/auth-context.tsx +++ b/frontend/app/auth-context.tsx @@ -2,6 +2,8 @@ import { createContext, useContext, useState, useEffect, ReactNode } from "react"; +const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000"; + interface User { id: number; email: string; @@ -9,54 +11,43 @@ interface User { interface AuthContextType { user: User | null; - token: string | null; isLoading: boolean; login: (email: string, password: string) => Promise; register: (email: string, password: string) => Promise; - logout: () => void; + logout: () => Promise; } const AuthContext = createContext(null); export function AuthProvider({ children }: { children: ReactNode }) { const [user, setUser] = useState(null); - const [token, setToken] = useState(null); const [isLoading, setIsLoading] = useState(true); useEffect(() => { - const storedToken = localStorage.getItem("token"); - if (storedToken) { - setToken(storedToken); - fetchUser(storedToken); - } else { - setIsLoading(false); - } + checkAuth(); }, []); - const fetchUser = async (authToken: string) => { + const checkAuth = async () => { try { - const res = await fetch("http://localhost:8000/api/auth/me", { - headers: { Authorization: `Bearer ${authToken}` }, + const res = await fetch(`${API_URL}/api/auth/me`, { + credentials: "include", }); if (res.ok) { const userData = await res.json(); setUser(userData); - } else { - localStorage.removeItem("token"); - setToken(null); } } catch { - localStorage.removeItem("token"); - setToken(null); + // Not authenticated } finally { setIsLoading(false); } }; const login = async (email: string, password: string) => { - const res = await fetch("http://localhost:8000/api/auth/login", { + const res = await fetch(`${API_URL}/api/auth/login`, { method: "POST", headers: { "Content-Type": "application/json" }, + credentials: "include", body: JSON.stringify({ email, password }), }); @@ -65,16 +56,15 @@ export function AuthProvider({ children }: { children: ReactNode }) { throw new Error(error.detail || "Login failed"); } - const data = await res.json(); - localStorage.setItem("token", data.access_token); - setToken(data.access_token); - setUser(data.user); + const userData = await res.json(); + setUser(userData); }; const register = async (email: string, password: string) => { - const res = await fetch("http://localhost:8000/api/auth/register", { + const res = await fetch(`${API_URL}/api/auth/register`, { method: "POST", headers: { "Content-Type": "application/json" }, + credentials: "include", body: JSON.stringify({ email, password }), }); @@ -83,20 +73,20 @@ export function AuthProvider({ children }: { children: ReactNode }) { throw new Error(error.detail || "Registration failed"); } - const data = await res.json(); - localStorage.setItem("token", data.access_token); - setToken(data.access_token); - setUser(data.user); + const userData = await res.json(); + setUser(userData); }; - const logout = () => { - localStorage.removeItem("token"); - setToken(null); + const logout = async () => { + await fetch(`${API_URL}/api/auth/logout`, { + method: "POST", + credentials: "include", + }); setUser(null); }; return ( - + {children} ); @@ -109,4 +99,3 @@ export function useAuth() { } return context; } - diff --git a/frontend/app/page.test.tsx b/frontend/app/page.test.tsx index b7d27f1..4feae15 100644 --- a/frontend/app/page.test.tsx +++ b/frontend/app/page.test.tsx @@ -12,14 +12,12 @@ vi.mock("next/navigation", () => ({ // Default mock values let mockUser: { id: number; email: string } | null = { id: 1, email: "test@example.com" }; -let mockToken: string | null = "valid-token"; let mockIsLoading = false; const mockLogout = vi.fn(); vi.mock("./auth-context", () => ({ useAuth: () => ({ user: mockUser, - token: mockToken, isLoading: mockIsLoading, logout: mockLogout, }), @@ -29,7 +27,6 @@ beforeEach(() => { vi.clearAllMocks(); // Reset to authenticated state mockUser = { id: 1, email: "test@example.com" }; - mockToken = "valid-token"; mockIsLoading = false; }); @@ -64,14 +61,18 @@ describe("Home - Authenticated", () => { expect(screen.getByText("Sign out")).toBeDefined(); }); - test("clicking sign out calls logout", async () => { + test("clicking sign out calls logout and redirects", async () => { vi.spyOn(global, "fetch").mockResolvedValue({ json: () => Promise.resolve({ value: 42 }), } as Response); render(); fireEvent.click(screen.getByText("Sign out")); - expect(mockLogout).toHaveBeenCalled(); + + await waitFor(() => { + expect(mockLogout).toHaveBeenCalled(); + expect(mockPush).toHaveBeenCalledWith("/login"); + }); }); test("renders counter value after fetch", async () => { @@ -85,7 +86,7 @@ describe("Home - Authenticated", () => { }); }); - test("fetches counter with auth header", async () => { + test("fetches counter with credentials", async () => { const fetchSpy = vi.spyOn(global, "fetch").mockResolvedValue({ json: () => Promise.resolve({ value: 0 }), } as Response); @@ -96,7 +97,7 @@ describe("Home - Authenticated", () => { expect(fetchSpy).toHaveBeenCalledWith( "http://localhost:8000/api/counter", expect.objectContaining({ - headers: { Authorization: "Bearer valid-token" }, + credentials: "include", }) ); }); @@ -111,7 +112,7 @@ describe("Home - Authenticated", () => { expect(screen.getByText("Increment")).toBeDefined(); }); - test("clicking increment button calls API with auth header", async () => { + test("clicking increment button calls API with credentials", async () => { const fetchSpy = vi .spyOn(global, "fetch") .mockResolvedValueOnce({ json: () => Promise.resolve({ value: 0 }) } as Response) @@ -127,7 +128,7 @@ describe("Home - Authenticated", () => { "http://localhost:8000/api/counter/increment", expect.objectContaining({ method: "POST", - headers: { Authorization: "Bearer valid-token" }, + credentials: "include", }) ); }); @@ -149,7 +150,6 @@ describe("Home - Authenticated", () => { describe("Home - Unauthenticated", () => { test("redirects to login when not authenticated", async () => { mockUser = null; - mockToken = null; render(); @@ -160,16 +160,14 @@ describe("Home - Unauthenticated", () => { test("returns null when not authenticated", () => { mockUser = null; - mockToken = null; const { container } = render(); // Should render nothing (just redirects) expect(container.querySelector("main")).toBeNull(); }); - test("does not fetch counter when no token", () => { + test("does not fetch counter when no user", () => { mockUser = null; - mockToken = null; const fetchSpy = vi.spyOn(global, "fetch"); render(); @@ -182,7 +180,6 @@ describe("Home - Loading State", () => { test("does not redirect while loading", () => { mockIsLoading = true; mockUser = null; - mockToken = null; render(); diff --git a/frontend/app/page.tsx b/frontend/app/page.tsx index fb884e6..718d68f 100644 --- a/frontend/app/page.tsx +++ b/frontend/app/page.tsx @@ -4,9 +4,11 @@ import { useEffect, useState } from "react"; import { useRouter } from "next/navigation"; import { useAuth } from "./auth-context"; +const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000"; + export default function Home() { const [count, setCount] = useState(null); - const { user, token, isLoading, logout } = useAuth(); + const { user, isLoading, logout } = useAuth(); const router = useRouter(); useEffect(() => { @@ -16,26 +18,30 @@ export default function Home() { }, [isLoading, user, router]); useEffect(() => { - if (token) { - fetch("http://localhost:8000/api/counter", { - headers: { Authorization: `Bearer ${token}` }, + if (user) { + fetch(`${API_URL}/api/counter`, { + credentials: "include", }) .then((res) => res.json()) .then((data) => setCount(data.value)) .catch(() => setCount(null)); } - }, [token]); + }, [user]); const increment = async () => { - if (!token) return; - const res = await fetch("http://localhost:8000/api/counter/increment", { + const res = await fetch(`${API_URL}/api/counter/increment`, { method: "POST", - headers: { Authorization: `Bearer ${token}` }, + credentials: "include", }); const data = await res.json(); setCount(data.value); }; + const handleLogout = async () => { + await logout(); + router.push("/login"); + }; + if (isLoading) { return (
@@ -53,7 +59,7 @@ export default function Home() {
{user.email} -
diff --git a/frontend/e2e/auth.spec.ts b/frontend/e2e/auth.spec.ts index b307fda..d68c38c 100644 --- a/frontend/e2e/auth.spec.ts +++ b/frontend/e2e/auth.spec.ts @@ -5,9 +5,9 @@ function uniqueEmail(): string { return `test-${Date.now()}-${Math.random().toString(36).substring(7)}@example.com`; } -// Helper to clear localStorage +// Helper to clear auth cookies async function clearAuth(page: Page) { - await page.evaluate(() => localStorage.clear()); + await page.context().clearCookies(); } test.describe("Authentication Flow", () => { @@ -83,7 +83,7 @@ test.describe("Signup", () => { await page.click('button[type="submit"]'); await expect(page).toHaveURL("/"); - // Clear and try again with same email + // Clear cookies and try again with same email await clearAuth(page); await page.goto("/signup"); await page.fill('input[type="email"]', email); @@ -248,7 +248,7 @@ test.describe("Session Persistence", () => { await expect(page.getByText(email)).toBeVisible(); }); - test("token is stored in localStorage", async ({ page }) => { + test("auth cookie is set after login", async ({ page }) => { const email = uniqueEmail(); await page.goto("/signup"); @@ -258,13 +258,14 @@ test.describe("Session Persistence", () => { await page.click('button[type="submit"]'); await expect(page).toHaveURL("/"); - // Check localStorage - const token = await page.evaluate(() => localStorage.getItem("token")); - expect(token).toBeTruthy(); - expect(token!.length).toBeGreaterThan(10); + // Check cookies + const cookies = await page.context().cookies(); + const authCookie = cookies.find((c) => c.name === "auth_token"); + expect(authCookie).toBeTruthy(); + expect(authCookie!.httpOnly).toBe(true); }); - test("token is cleared on logout", async ({ page }) => { + test("auth cookie is cleared on logout", async ({ page }) => { const email = uniqueEmail(); await page.goto("/signup"); @@ -276,8 +277,9 @@ test.describe("Session Persistence", () => { await page.click("text=Sign out"); - const token = await page.evaluate(() => localStorage.getItem("token")); - expect(token).toBeNull(); + const cookies = await page.context().cookies(); + const authCookie = cookies.find((c) => c.name === "auth_token"); + // Cookie should be deleted or have empty value + expect(!authCookie || authCookie.value === "").toBe(true); }); }); - diff --git a/frontend/e2e/counter.spec.ts b/frontend/e2e/counter.spec.ts index 1f395d8..5c14570 100644 --- a/frontend/e2e/counter.spec.ts +++ b/frontend/e2e/counter.spec.ts @@ -8,7 +8,7 @@ function uniqueEmail(): string { // Helper to authenticate a user async function authenticate(page: Page): Promise { const email = uniqueEmail(); - await page.evaluate(() => localStorage.clear()); + await page.context().clearCookies(); await page.goto("/signup"); await page.fill('input[type="email"]', email); await page.fill('input[type="password"]', "password123"); @@ -95,13 +95,13 @@ test.describe("Counter - Authenticated", () => { test.describe("Counter - Unauthenticated", () => { test("redirects to login when accessing counter without auth", async ({ page }) => { - await page.evaluate(() => localStorage.clear()); + await page.context().clearCookies(); await page.goto("/"); await expect(page).toHaveURL("/login"); }); test("shows login form when redirected", async ({ page }) => { - await page.evaluate(() => localStorage.clear()); + await page.context().clearCookies(); await page.goto("/"); await expect(page.locator("h1")).toHaveText("Welcome back"); }); @@ -138,11 +138,11 @@ test.describe("Counter - Session Integration", () => { test("counter API requires authentication", async ({ page }) => { // Try to access counter API directly without auth const response = await page.request.get("http://localhost:8000/api/counter"); - expect(response.status()).toBe(403); + expect(response.status()).toBe(401); }); test("counter increment API requires authentication", async ({ page }) => { const response = await page.request.post("http://localhost:8000/api/counter/increment"); - expect(response.status()).toBe(403); + expect(response.status()).toBe(401); }); }); diff --git a/frontend/env.example b/frontend/env.example new file mode 100644 index 0000000..03852f5 --- /dev/null +++ b/frontend/env.example @@ -0,0 +1,7 @@ +# Environment variables for the frontend +# For local dev: use direnv with the root .envrc file (recommended) +# For production: set these in your deployment environment + +# API URL for the backend +NEXT_PUBLIC_API_URL=http://localhost:8000 + diff --git a/scripts/e2e.sh b/scripts/e2e.sh index 7bbf8bc..3f91de2 100755 --- a/scripts/e2e.sh +++ b/scripts/e2e.sh @@ -10,7 +10,7 @@ sleep 1 # Start db docker compose up -d db -# Start backend +# Start backend (SECRET_KEY should be set via .envrc or environment) cd backend uv run uvicorn main:app --port 8000 & PID=$!