"""API tests for the ``/api/counter`` endpoints.

Each test gets its own freshly created in-memory SQLite database via the
function-scoped ``client`` fixture, so counter state never leaks between tests.
"""

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from database import Base, get_db
from main import app

TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"


# NOTE(review): under pytest-asyncio's strict mode an async fixture must be
# declared with ``@pytest_asyncio.fixture``; plain ``@pytest.fixture`` only
# works when the project runs with ``asyncio_mode = auto`` — confirm against
# the pytest configuration.
@pytest.fixture
async def client():
    """Yield an ``AsyncClient`` wired to ``app`` with ``get_db`` overridden.

    The override hands out sessions from a private in-memory SQLite engine
    whose tables are created from ``Base.metadata`` before the client is
    yielded. On teardown (including a failed setup) the override is removed
    and the engine is disposed.
    """
    # Each new connection to ":memory:" opens a separate, empty database.
    # StaticPool pins a single connection so the tables created below are the
    # same ones the request handlers query.
    engine = create_async_engine(TEST_DATABASE_URL, poolclass=StaticPool)
    async_session = async_sessionmaker(engine, expire_on_commit=False)

    async def override_get_db():
        async with async_session() as session:
            yield session

    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as c:
            yield c
    finally:
        # Remove only the override this fixture installed; `app` is a shared
        # global and other overrides may belong to someone else.
        app.dependency_overrides.pop(get_db, None)
        await engine.dispose()


@pytest.mark.asyncio
async def test_get_counter_initial(client):
    """A fresh database reports a counter value of 0."""
    response = await client.get("/api/counter")
    assert response.status_code == 200
    assert response.json() == {"value": 0}


@pytest.mark.asyncio
async def test_increment_counter(client):
    """A single increment returns the new value 1."""
    response = await client.post("/api/counter/increment")
    assert response.status_code == 200
    assert response.json() == {"value": 1}


@pytest.mark.asyncio
async def test_increment_counter_multiple(client):
    """Each increment adds exactly one and returns the running total."""
    for expected in (1, 2, 3):
        response = await client.post("/api/counter/increment")
        assert response.status_code == 200
        assert response.json() == {"value": expected}


@pytest.mark.asyncio
async def test_get_counter_after_increment(client):
    """GET reflects the increments persisted by earlier POSTs."""
    for _ in range(2):
        response = await client.post("/api/counter/increment")
        assert response.status_code == 200

    response = await client.get("/api/counter")
    assert response.status_code == 200
    assert response.json() == {"value": 2}