tests passing
This commit is contained in:
parent
322bdd3e6e
commit
b173b47925
18 changed files with 1414 additions and 93 deletions
|
|
@ -1,15 +1,19 @@
|
|||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List
|
||||
|
||||
# Set required env vars before importing app
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||
|
||||
from database import Base, get_db
|
||||
from main import app
|
||||
from models import User, Role, Permission, ROLE_DEFINITIONS
|
||||
from auth import get_password_hash
|
||||
|
||||
TEST_DATABASE_URL = os.getenv(
|
||||
"TEST_DATABASE_URL",
|
||||
|
|
@ -20,9 +24,10 @@ TEST_DATABASE_URL = os.getenv(
|
|||
class ClientFactory:
|
||||
"""Factory for creating httpx clients with optional cookies."""
|
||||
|
||||
def __init__(self, transport, base_url):
|
||||
def __init__(self, transport, base_url, session_factory):
|
||||
self._transport = transport
|
||||
self._base_url = base_url
|
||||
self._session_factory = session_factory
|
||||
|
||||
@asynccontextmanager
|
||||
async def create(self, cookies: dict | None = None):
|
||||
|
|
@ -45,6 +50,59 @@ class ClientFactory:
|
|||
async def post(self, url: str, **kwargs):
|
||||
return await self.request("POST", url, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session(self):
|
||||
"""Get a database session for direct DB operations in tests."""
|
||||
async with self._session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def setup_roles(db: AsyncSession) -> dict[str, Role]:
|
||||
"""Create all roles with their permissions from ROLE_DEFINITIONS."""
|
||||
roles = {}
|
||||
for role_name, config in ROLE_DEFINITIONS.items():
|
||||
# Check if role exists
|
||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||
role = result.scalar_one_or_none()
|
||||
|
||||
if not role:
|
||||
role = Role(name=role_name, description=config["description"])
|
||||
db.add(role)
|
||||
await db.flush()
|
||||
|
||||
# Set permissions
|
||||
await role.set_permissions(db, config["permissions"])
|
||||
roles[role_name] = role
|
||||
|
||||
await db.commit()
|
||||
return roles
|
||||
|
||||
|
||||
async def create_user_with_roles(
|
||||
db: AsyncSession,
|
||||
email: str,
|
||||
password: str,
|
||||
role_names: List[str],
|
||||
) -> User:
|
||||
"""Create a user with specified roles."""
|
||||
# Get roles
|
||||
roles = []
|
||||
for role_name in role_names:
|
||||
result = await db.execute(select(Role).where(Role.name == role_name))
|
||||
role = result.scalar_one_or_none()
|
||||
if role:
|
||||
roles.append(role)
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
hashed_password=get_password_hash(password),
|
||||
roles=roles,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def client_factory():
|
||||
|
|
@ -57,6 +115,10 @@ async def client_factory():
|
|||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Setup roles
|
||||
async with session_factory() as db:
|
||||
await setup_roles(db)
|
||||
|
||||
async def override_get_db():
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
|
@ -64,7 +126,7 @@ async def client_factory():
|
|||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
factory = ClientFactory(transport, "http://test")
|
||||
factory = ClientFactory(transport, "http://test", session_factory)
|
||||
|
||||
yield factory
|
||||
|
||||
|
|
@ -77,3 +139,78 @@ async def client(client_factory):
|
|||
"""Fixture for a simple client without cookies (backwards compatible)."""
|
||||
async with client_factory.create() as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def regular_user(client_factory):
|
||||
"""Create a regular user and return their credentials and cookies."""
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("regular")
|
||||
password = "password123"
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
await create_user_with_roles(db, email, password, ["regular"])
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
"cookies": dict(response.cookies),
|
||||
"response": response,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def admin_user(client_factory):
|
||||
"""Create an admin user and return their credentials and cookies."""
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("admin")
|
||||
password = "password123"
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
await create_user_with_roles(db, email, password, ["admin"])
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
"cookies": dict(response.cookies),
|
||||
"response": response,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def user_no_roles(client_factory):
|
||||
"""Create a user with NO roles and return their credentials and cookies."""
|
||||
from tests.helpers import unique_email
|
||||
|
||||
email = unique_email("noroles")
|
||||
password = "password123"
|
||||
|
||||
async with client_factory.get_db_session() as db:
|
||||
await create_user_with_roles(db, email, password, [])
|
||||
|
||||
# Login to get cookies
|
||||
response = await client_factory.post(
|
||||
"/api/auth/login",
|
||||
json={"email": email, "password": password},
|
||||
)
|
||||
|
||||
return {
|
||||
"email": email,
|
||||
"password": password,
|
||||
"cookies": dict(response.cookies),
|
||||
"response": response,
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue