with some tests
This commit is contained in:
parent
a764c92a0b
commit
0995e1cc77
18 changed files with 3020 additions and 16 deletions
18
backend/database.py
Normal file
18
backend/database.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import os
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret")
|
||||
|
||||
engine = create_async_engine(DATABASE_URL)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
async def get_db():
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
|
@ -1,7 +1,21 @@
|
|||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
app = FastAPI()
|
||||
from database import engine, get_db, Base
|
||||
from models import Counter
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
@ -11,7 +25,27 @@ app.add_middleware(
|
|||
)
|
||||
|
||||
|
||||
@app.get("/api/hello")
|
||||
def hello():
|
||||
return {"message": "Hello from FastAPI"}
|
||||
async def get_or_create_counter(db: AsyncSession) -> Counter:
|
||||
result = await db.execute(select(Counter).where(Counter.id == 1))
|
||||
counter = result.scalar_one_or_none()
|
||||
if not counter:
|
||||
counter = Counter(id=1, value=0)
|
||||
db.add(counter)
|
||||
await db.commit()
|
||||
await db.refresh(counter)
|
||||
return counter
|
||||
|
||||
|
||||
@app.get("/api/counter")
|
||||
async def get_counter(db: AsyncSession = Depends(get_db)):
|
||||
counter = await get_or_create_counter(db)
|
||||
return {"value": counter.value}
|
||||
|
||||
|
||||
@app.post("/api/counter/increment")
|
||||
async def increment_counter(db: AsyncSession = Depends(get_db)):
|
||||
counter = await get_or_create_counter(db)
|
||||
counter.value += 1
|
||||
await db.commit()
|
||||
return {"value": counter.value}
|
||||
|
||||
|
|
|
|||
11
backend/models.py
Normal file
11
backend/models.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from sqlalchemy import Integer
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from database import Base
|
||||
|
||||
|
||||
class Counter(Base):
|
||||
__tablename__ = "counter"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
|
||||
value: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
|
@ -5,5 +5,15 @@ requires-python = ">=3.11"
|
|||
dependencies = [
|
||||
"fastapi>=0.115.6",
|
||||
"uvicorn>=0.34.0",
|
||||
"sqlalchemy[asyncio]>=2.0.36",
|
||||
"asyncpg>=0.30.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.3.4",
|
||||
"pytest-asyncio>=0.25.0",
|
||||
"httpx>=0.28.1",
|
||||
"aiosqlite>=0.20.0",
|
||||
]
|
||||
|
||||
|
|
|
|||
4
backend/pytest.ini
Normal file
4
backend/pytest.ini
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
[pytest]
|
||||
asyncio_mode = auto
|
||||
asyncio_default_fixture_loop_scope = function
|
||||
|
||||
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
60
backend/tests/test_counter.py
Normal file
60
backend/tests/test_counter.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
|
||||
from database import Base, get_db
|
||||
from main import app
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
engine = create_async_engine(TEST_DATABASE_URL)
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def override_get_db():
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_initial(client):
|
||||
response = await client.get("/api/counter")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"value": 0}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter(client):
|
||||
response = await client.post("/api/counter/increment")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"value": 1}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_multiple(client):
|
||||
await client.post("/api/counter/increment")
|
||||
await client.post("/api/counter/increment")
|
||||
response = await client.post("/api/counter/increment")
|
||||
assert response.json() == {"value": 3}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_after_increment(client):
|
||||
await client.post("/api/counter/increment")
|
||||
await client.post("/api/counter/increment")
|
||||
response = await client.get("/api/counter")
|
||||
assert response.json() == {"value": 2}
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue