import os
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch

# Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from auth import get_password_hash
from database import Base, get_db
from main import app
from models import ROLE_ADMIN, ROLE_DEFINITIONS, ROLE_REGULAR, Role, User
from tests.helpers import unique_email

# Database used by the test suite; override with the TEST_DATABASE_URL env var.
TEST_DATABASE_URL = os.getenv(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
)


class ClientFactory:
    """Factory for creating httpx clients with optional cookies.

    Wraps a shared ASGI transport (so every client talks to the in-process
    app) and the session factory used for direct DB access from tests.
    """

    def __init__(self, transport, base_url, session_factory):
        self._transport = transport
        self._base_url = base_url
        self._session_factory = session_factory

    @asynccontextmanager
    async def create(self, cookies: dict | None = None):
        """Create a new client, optionally with cookies set.

        Args:
            cookies: Cookies to pre-load into the client; ``None`` means none.

        Yields:
            An ``AsyncClient`` bound to the app's transport; it is closed on exit.
        """
        async with AsyncClient(
            transport=self._transport,
            base_url=self._base_url,
            cookies=cookies or {},
        ) as client:
            yield client

    async def request(self, method: str, url: str, **kwargs):
        """Make a one-off request without cookies and return the response."""
        async with self.create() as client:
            return await client.request(method, url, **kwargs)

    async def get(self, url: str, **kwargs):
        """Shortcut for a one-off cookie-less GET request."""
        return await self.request("GET", url, **kwargs)

    async def post(self, url: str, **kwargs):
        """Shortcut for a one-off cookie-less POST request."""
        return await self.request("POST", url, **kwargs)

    @asynccontextmanager
    async def get_db_session(self):
        """Get a database session for direct DB operations in tests."""
        async with self._session_factory() as session:
            yield session


async def setup_roles(db: AsyncSession) -> dict[str, Role]:
    """Create all roles with their permissions from ROLE_DEFINITIONS.

    Existing roles are reused (their description is left untouched), but
    permissions are always (re)applied from ROLE_DEFINITIONS.

    Args:
        db: Session used for the inserts; it is committed before returning.

    Returns:
        Mapping of role name to its ``Role`` row.
    """
    roles = {}
    for role_name, config in ROLE_DEFINITIONS.items():
        # Check if role exists
        result = await db.execute(select(Role).where(Role.name == role_name))
        role = result.scalar_one_or_none()
        if not role:
            role = Role(name=role_name, description=config["description"])
            db.add(role)
            # Flush so the new role is persisted before permissions reference it.
            await db.flush()
        # Set permissions
        await role.set_permissions(db, config["permissions"])
        roles[role_name] = role
    await db.commit()
    return roles


async def create_user_with_roles(
    db: AsyncSession,
    email: str,
    password: str,
    role_names: list[str],
) -> User:
    """Create a user with specified roles.

    Args:
        db: Session used to look up roles and persist the user (committed).
        email: Email for the new user.
        password: Plain-text password; it is stored hashed.
        role_names: Names of roles to attach; may be empty.

    Returns:
        The committed and refreshed ``User``.

    Raises:
        ValueError: If any role name does not exist in the database.
    """
    # Get roles
    roles = []
    for role_name in role_names:
        result = await db.execute(select(Role).where(Role.name == role_name))
        role = result.scalar_one_or_none()
        if not role:
            raise ValueError(
                f"Role '{role_name}' not found. Did you run setup_roles()?"
            )
        roles.append(role)

    user = User(
        email=email,
        hashed_password=get_password_hash(password),
        roles=roles,
    )
    db.add(user)
    await db.commit()
    await db.refresh(user)
    return user


@pytest.fixture(scope="function")
async def client_factory():
    """Fixture that provides a factory for creating clients.

    Recreates the schema and seeds roles for every test, points the app's
    ``get_db`` dependency at the test database, and always disposes of the
    engine afterwards, even if setup itself fails.
    """
    engine = create_async_engine(TEST_DATABASE_URL)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    try:
        # Create tables (drop first so each test starts from a clean schema)
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        # Setup roles
        async with session_factory() as db:
            await setup_roles(db)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            yield ClientFactory(transport, "http://test", session_factory)
        finally:
            app.dependency_overrides.clear()
    finally:
        await engine.dispose()


@pytest.fixture(scope="function")
async def client(client_factory):
    """Fixture for a simple client without cookies (backwards compatible)."""
    async with client_factory.create() as c:
        yield c


async def _create_logged_in_user(
    client_factory: ClientFactory,
    email_prefix: str,
    role_names: list[str],
    password: str = "password123",
) -> dict:
    """Create a user with ``role_names``, log in, and return fixture data.

    Returns:
        Dict with ``email``, ``password``, ``cookies`` (from the login
        response), the raw login ``response`` and ``user`` (id and email).
    """
    email = unique_email(email_prefix)

    async with client_factory.get_db_session() as db:
        user = await create_user_with_roles(db, email, password, role_names)
        user_id = user.id

    # Login to get cookies
    response = await client_factory.post(
        "/api/auth/login",
        json={"email": email, "password": password},
    )

    return {
        "email": email,
        "password": password,
        "cookies": dict(response.cookies),
        "response": response,
        "user": {"id": user_id, "email": email},
    }


@pytest.fixture(scope="function")
async def regular_user(client_factory):
    """Create a regular user and return their credentials and cookies."""
    return await _create_logged_in_user(client_factory, "regular", [ROLE_REGULAR])


@pytest.fixture(scope="function")
async def alt_regular_user(client_factory):
    """Create a second regular user for tests needing multiple users."""
    return await _create_logged_in_user(
        client_factory, "alt_regular", [ROLE_REGULAR]
    )


@pytest.fixture(scope="function")
async def admin_user(client_factory):
    """Create an admin user and return their credentials and cookies."""
    return await _create_logged_in_user(client_factory, "admin", [ROLE_ADMIN])


@pytest.fixture(scope="function")
async def user_no_roles(client_factory):
    """Create a user with NO roles and return their credentials and cookies."""
    return await _create_logged_in_user(client_factory, "noroles", [])


@pytest.fixture(autouse=True)
def mock_enqueue_job():
    """Mock job enqueueing for all tests.

    pgqueuer requires PostgreSQL-specific features that aren't available in
    the test database setup. We mock the enqueue function to avoid connection
    issues while still testing the counter logic.

    Yields:
        The ``AsyncMock`` patched over ``routes.counter.enqueue_random_number_job``.
    """
    mock = AsyncMock(return_value=1)  # Return a fake job ID
    with patch("routes.counter.enqueue_random_number_job", mock):
        yield mock