79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
# The app needs SECRET_KEY at import time, so this must execute before the
# `from main import app` import further down. A value already exported in
# the environment is left untouched; only a missing key gets the test-only
# placeholder.
if "SECRET_KEY" not in os.environ:
    os.environ["SECRET_KEY"] = "test-secret-key-for-testing-only"
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
|
|
|
from database import Base, get_db
|
|
from main import app
|
|
|
|
# Connection URL for the database the test suite runs against. It can be
# overridden with the TEST_DATABASE_URL environment variable; the fallback
# targets a local PostgreSQL server through the asyncpg driver.
# WARNING: the client_factory fixture drops and recreates every table in
# this database, so never point it at data you care about.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
)
|
|
|
|
|
|
class ClientFactory:
|
|
"""Factory for creating httpx clients with optional cookies."""
|
|
|
|
def __init__(self, transport, base_url):
|
|
self._transport = transport
|
|
self._base_url = base_url
|
|
|
|
@asynccontextmanager
|
|
async def create(self, cookies: dict | None = None):
|
|
"""Create a new client, optionally with cookies set."""
|
|
async with AsyncClient(
|
|
transport=self._transport,
|
|
base_url=self._base_url,
|
|
cookies=cookies or {},
|
|
) as client:
|
|
yield client
|
|
|
|
async def request(self, method: str, url: str, **kwargs):
|
|
"""Make a one-off request without cookies."""
|
|
async with self.create() as client:
|
|
return await client.request(method, url, **kwargs)
|
|
|
|
async def get(self, url: str, **kwargs):
|
|
return await self.request("GET", url, **kwargs)
|
|
|
|
async def post(self, url: str, **kwargs):
|
|
return await self.request("POST", url, **kwargs)
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client_factory():
    """Provide a ``ClientFactory`` wired to a freshly reset test database.

    Setup drops and recreates every table in ``Base.metadata`` on
    ``TEST_DATABASE_URL``. It then overrides the app's ``get_db`` dependency
    so that sessions come from this fixture's engine. Teardown clears the
    app's dependency overrides and disposes the engine.

    Setup and ``yield`` sit inside ``try``/``finally``. If table creation
    fails before the fixture yields, the engine and any connection it pooled
    are still disposed instead of being leaked.
    """
    engine = create_async_engine(TEST_DATABASE_URL)
    try:
        # With expire_on_commit=False, loaded objects are not expired on
        # commit, so their attributes stay readable afterwards.
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        # Start each test from an empty schema: drop everything, then rebuild.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        async def override_get_db():
            # Yields a new session each time the dependency is resolved; the
            # `async with` closes it when the dependency is torn down.
            async with session_factory() as session:
                yield session

        app.dependency_overrides[get_db] = override_get_db

        yield ClientFactory(ASGITransport(app=app), "http://test")
    finally:
        # NOTE: this clears *all* overrides on the app, not just get_db.
        app.dependency_overrides.clear()
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client(client_factory):
    """Yield one cookie-less client taken from ``client_factory``.

    Kept for backwards compatibility: tests can depend on ``client``
    directly instead of going through ``client_factory.create()``. The
    client is closed once the test finishes.
    """
    cookie_less_ctx = client_factory.create()
    async with cookie_less_ctx as http_client:
        yield http_client
|