211 lines
6.3 KiB
Python
211 lines
6.3 KiB
Python
import os
|
|
from contextlib import asynccontextmanager
|
|
|
|
# The app must find SECRET_KEY in the environment before ``main`` is imported
# below, so supply a test-only fallback here. A value that is already set
# (e.g. exported by CI) is left exactly as it is.
if "SECRET_KEY" not in os.environ:
    os.environ["SECRET_KEY"] = "test-secret-key-for-testing-only"
|
|
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
|
|
|
from database import Base, get_db
|
|
from main import app
|
|
from models import User, Role, Permission, ROLE_DEFINITIONS, ROLE_REGULAR, ROLE_ADMIN
|
|
from auth import get_password_hash
|
|
from tests.helpers import unique_email
|
|
|
|
# Database the test suite runs against. Override it with the
# TEST_DATABASE_URL environment variable; the fallback is a local Postgres
# reached through the asyncpg driver.
TEST_DATABASE_URL = os.environ.get(
    "TEST_DATABASE_URL",
    "postgresql+asyncpg://postgres:postgres@localhost:5432/arbret_test",
)
|
|
|
|
|
|
class ClientFactory:
|
|
"""Factory for creating httpx clients with optional cookies."""
|
|
|
|
def __init__(self, transport, base_url, session_factory):
|
|
self._transport = transport
|
|
self._base_url = base_url
|
|
self._session_factory = session_factory
|
|
|
|
@asynccontextmanager
|
|
async def create(self, cookies: dict | None = None):
|
|
"""Create a new client, optionally with cookies set."""
|
|
async with AsyncClient(
|
|
transport=self._transport,
|
|
base_url=self._base_url,
|
|
cookies=cookies or {},
|
|
) as client:
|
|
yield client
|
|
|
|
async def request(self, method: str, url: str, **kwargs):
|
|
"""Make a one-off request without cookies."""
|
|
async with self.create() as client:
|
|
return await client.request(method, url, **kwargs)
|
|
|
|
async def get(self, url: str, **kwargs):
|
|
return await self.request("GET", url, **kwargs)
|
|
|
|
async def post(self, url: str, **kwargs):
|
|
return await self.request("POST", url, **kwargs)
|
|
|
|
@asynccontextmanager
|
|
async def get_db_session(self):
|
|
"""Get a database session for direct DB operations in tests."""
|
|
async with self._session_factory() as session:
|
|
yield session
|
|
|
|
|
|
async def setup_roles(db: AsyncSession) -> dict[str, Role]:
    """Make sure every role in ROLE_DEFINITIONS exists and has its permissions.

    A role already present in the database is reused; a missing one is
    inserted with the configured description. Either way, ``set_permissions``
    is called with the configured permissions. All changes are committed once,
    at the end.

    Returns:
        Mapping of role name to the corresponding ``Role`` object.
    """
    by_name = {}
    for name, spec in ROLE_DEFINITIONS.items():
        lookup = select(Role).where(Role.name == name)
        role = (await db.execute(lookup)).scalar_one_or_none()

        if not role:
            role = Role(name=name, description=spec["description"])
            db.add(role)
            # Flush before set_permissions runs; presumably it needs the new
            # row (and its primary key) to exist — confirm in models.
            await db.flush()

        await role.set_permissions(db, spec["permissions"])
        by_name[name] = role

    await db.commit()
    return by_name
|
|
|
|
|
|
async def create_user_with_roles(
    db: AsyncSession,
    email: str,
    password: str,
    role_names: list[str],
) -> User:
    """Insert a user that holds the named roles and return it after refresh.

    The password is stored as ``get_password_hash(password)``. Roles are
    looked up by name, one query per name, in the order given.

    Raises:
        ValueError: if any name in ``role_names`` has no matching ``Role`` row
            (roles are expected to have been created by ``setup_roles()``).
    """
    resolved = []
    for wanted in role_names:
        found = (
            await db.execute(select(Role).where(Role.name == wanted))
        ).scalar_one_or_none()
        if not found:
            raise ValueError(f"Role '{wanted}' not found. Did you run setup_roles()?")
        resolved.append(found)

    new_user = User(
        email=email,
        hashed_password=get_password_hash(password),
        roles=resolved,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client_factory():
    """Provide a ClientFactory backed by a freshly reset test database.

    Setup:
      * drops and recreates every table in ``Base.metadata``;
      * seeds roles via ``setup_roles``;
      * overrides the app's ``get_db`` dependency so request handlers get
        sessions from this fixture's session factory.

    Teardown clears ``app.dependency_overrides`` (all entries, as before) and
    disposes the engine. Disposal sits in a ``finally`` so the engine's
    connection pool is released even when table creation or role seeding
    fails during setup. Previously, a setup failure skipped disposal and
    leaked the pool.
    """
    engine = create_async_engine(TEST_DATABASE_URL)
    try:
        # expire_on_commit=False keeps ORM objects usable after commit,
        # e.g. a User returned from create_user_with_roles.
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        # Start every test from an empty schema.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        async with session_factory() as db:
            await setup_roles(db)

        async def override_get_db():
            async with session_factory() as session:
                yield session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            yield ClientFactory(transport, "http://test", session_factory)
        finally:
            # Only reached once the override above was installed.
            app.dependency_overrides.clear()
    finally:
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def client(client_factory):
    """Yield a client with no cookies.

    Kept for backwards compatibility with tests that take a plain ``client``.
    """
    async with client_factory.create(cookies=None) as anonymous_client:
        yield anonymous_client
|
|
|
|
|
|
async def _create_user_and_login(client_factory, prefix: str, role_names: list[str]) -> dict:
    """Create a user holding ``role_names``, log them in, and return their details.

    The email comes from ``unique_email(prefix)``; every fixture user has the
    password ``"password123"``. The login result is NOT checked here; the raw
    response is returned so tests can inspect it themselves.

    Returns:
        dict with keys ``email``, ``password``, ``cookies`` (cookies from the
        login response, suitable for ``client_factory.create``) and
        ``response`` (the login response object).
    """
    email = unique_email(prefix)
    password = "password123"

    async with client_factory.get_db_session() as db:
        await create_user_with_roles(db, email, password, role_names)

    # Log in to obtain the session cookies.
    response = await client_factory.post(
        "/api/auth/login",
        json={"email": email, "password": password},
    )

    return {
        "email": email,
        "password": password,
        "cookies": dict(response.cookies),
        "response": response,
    }


@pytest.fixture(scope="function")
async def regular_user(client_factory):
    """Create a regular user and return their credentials and cookies."""
    return await _create_user_and_login(client_factory, "regular", [ROLE_REGULAR])


@pytest.fixture(scope="function")
async def admin_user(client_factory):
    """Create an admin user and return their credentials and cookies."""
    return await _create_user_and_login(client_factory, "admin", [ROLE_ADMIN])


@pytest.fixture(scope="function")
async def user_no_roles(client_factory):
    """Create a user with NO roles and return their credentials and cookies."""
    return await _create_user_and_login(client_factory, "noroles", [])
|