arbret/backend/tests/conftest.py

79 lines
2.4 KiB
Python

import os
from contextlib import asynccontextmanager
# Required environment variables must exist before the app module (`main`)
# is imported further down; only supply a test value if none is set already.
if "SECRET_KEY" not in os.environ:
    os.environ["SECRET_KEY"] = "test-secret-key-for-testing-only"
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from database import Base, get_db
from main import app
# Async (asyncpg) Postgres URL used by the test fixtures; point the suite at a
# different database by exporting TEST_DATABASE_URL.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
)
class ClientFactory:
"""Factory for creating httpx clients with optional cookies."""
def __init__(self, transport, base_url):
self._transport = transport
self._base_url = base_url
@asynccontextmanager
async def create(self, cookies: dict | None = None):
"""Create a new client, optionally with cookies set."""
async with AsyncClient(
transport=self._transport,
base_url=self._base_url,
cookies=cookies or {},
) as client:
yield client
async def request(self, method: str, url: str, **kwargs):
"""Make a one-off request without cookies."""
async with self.create() as client:
return await client.request(method, url, **kwargs)
async def get(self, url: str, **kwargs):
return await self.request("GET", url, **kwargs)
async def post(self, url: str, **kwargs):
return await self.request("POST", url, **kwargs)
@pytest.fixture(scope="function")
async def client_factory():
    """Fixture that provides a factory for creating clients.

    Setup drops and recreates every table in ``Base.metadata`` on the test
    database and overrides the app's ``get_db`` dependency so request
    handlers receive sessions bound to the test engine.

    Teardown clears *all* of the app's dependency overrides (not only
    ``get_db``) and disposes of the engine. Teardown also runs when schema
    setup itself fails.
    """
    engine = create_async_engine(TEST_DATABASE_URL)
    try:
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        # Drop before creating so each test starts from a freshly built schema.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        yield ClientFactory(transport, "http://test")
    finally:
        # In a finally block so a failed drop_all/create_all cannot leak the
        # engine's pooled connections.
        app.dependency_overrides.clear()
        await engine.dispose()
@pytest.fixture(scope="function")
async def client(client_factory):
    """Yield one cookie-less client for the test app.

    Backwards-compatible shortcut for tests that need only a single plain
    client; the client is closed once the test finishes.
    """
    async with client_factory.create(cookies=None) as plain_client:
        yield plain_client