with some tests

This commit is contained in:
counterweight 2025-12-18 21:48:41 +01:00
parent a764c92a0b
commit 0995e1cc77
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
18 changed files with 3020 additions and 16 deletions

18
backend/database.py Normal file
View file

@ -0,0 +1,18 @@
import os
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
engine = create_async_engine(DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
pass
async def get_db():
async with async_session() as session:
yield session

View file

@ -1,7 +1,21 @@
from fastapi import FastAPI
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
app = FastAPI()
from database import engine, get_db, Base
from models import Counter
@asynccontextmanager
async def lifespan(app: FastAPI):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
app = FastAPI(lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
@ -11,7 +25,27 @@ app.add_middleware(
)
@app.get("/api/hello")
def hello():
return {"message": "Hello from FastAPI"}
async def get_or_create_counter(db: AsyncSession) -> Counter:
result = await db.execute(select(Counter).where(Counter.id == 1))
counter = result.scalar_one_or_none()
if not counter:
counter = Counter(id=1, value=0)
db.add(counter)
await db.commit()
await db.refresh(counter)
return counter
@app.get("/api/counter")
async def get_counter(db: AsyncSession = Depends(get_db)):
counter = await get_or_create_counter(db)
return {"value": counter.value}
@app.post("/api/counter/increment")
async def increment_counter(db: AsyncSession = Depends(get_db)):
counter = await get_or_create_counter(db)
counter.value += 1
await db.commit()
return {"value": counter.value}

11
backend/models.py Normal file
View file

@ -0,0 +1,11 @@
from sqlalchemy import Integer
from sqlalchemy.orm import Mapped, mapped_column
from database import Base
class Counter(Base):
__tablename__ = "counter"
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
value: Mapped[int] = mapped_column(Integer, default=0)

View file

@ -5,5 +5,15 @@ requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.6",
"uvicorn>=0.34.0",
"sqlalchemy[asyncio]>=2.0.36",
"asyncpg>=0.30.0",
]
[dependency-groups]
dev = [
"pytest>=8.3.4",
"pytest-asyncio>=0.25.0",
"httpx>=0.28.1",
"aiosqlite>=0.20.0",
]

4
backend/pytest.ini Normal file
View file

@ -0,0 +1,4 @@
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function

View file

View file

@ -0,0 +1,60 @@
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from database import Base, get_db
from main import app
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest.fixture
async def client():
engine = create_async_engine(TEST_DATABASE_URL)
async_session = async_sessionmaker(engine, expire_on_commit=False)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def override_get_db():
async with async_session() as session:
yield session
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
yield c
app.dependency_overrides.clear()
await engine.dispose()
@pytest.mark.asyncio
async def test_get_counter_initial(client):
response = await client.get("/api/counter")
assert response.status_code == 200
assert response.json() == {"value": 0}
@pytest.mark.asyncio
async def test_increment_counter(client):
response = await client.post("/api/counter/increment")
assert response.status_code == 200
assert response.json() == {"value": 1}
@pytest.mark.asyncio
async def test_increment_counter_multiple(client):
await client.post("/api/counter/increment")
await client.post("/api/counter/increment")
response = await client.post("/api/counter/increment")
assert response.json() == {"value": 3}
@pytest.mark.asyncio
async def test_get_counter_after_increment(client):
await client.post("/api/counter/increment")
await client.post("/api/counter/increment")
response = await client.get("/api/counter")
assert response.json() == {"value": 2}