- Remove EUR_TRADE_MIN, EUR_TRADE_MAX, and PREMIUM_PERCENTAGE from shared_constants.py
- Remove eurTradeMin, eurTradeMax, and premiumPercentage from shared/constants.json
- Update validate_constants.py so it no longer requires the removed fields
- Update seed.py and seed_e2e.py to fall back to default values when these fields are absent
- Update tests to handle the missing constants gracefully
157 lines
5 KiB
Python
157 lines
5 KiB
Python
"""Seed the database with roles, permissions, and dev users."""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from auth import get_password_hash
|
|
from database import Base, async_session, engine
|
|
from models import (
|
|
ROLE_ADMIN,
|
|
ROLE_DEFINITIONS,
|
|
ROLE_REGULAR,
|
|
Permission,
|
|
Role,
|
|
User,
|
|
)
|
|
from repositories.pricing import PricingRepository
|
|
|
|
# Credentials for the seeded development accounts. Each value is read from
# the process environment when this module is imported; a missing variable
# raises KeyError at import time, before seed() can run.
(
    DEV_USER_EMAIL,
    DEV_USER_PASSWORD,
    DEV_ADMIN_EMAIL,
    DEV_ADMIN_PASSWORD,
) = (
    os.environ[_env_name]
    for _env_name in (
        "DEV_USER_EMAIL",
        "DEV_USER_PASSWORD",
        "DEV_ADMIN_EMAIL",
        "DEV_ADMIN_PASSWORD",
    )
)
|
|
|
|
|
|
async def upsert_role(
    db: AsyncSession, name: str, description: str, permissions: list[Permission]
) -> Role:
    """Insert the role called ``name`` if it is missing, otherwise refresh it.

    The stored description is overwritten with ``description`` and the given
    permissions are applied through ``Role.set_permissions``.

    Args:
        db: Open async session; this function does not commit.
        name: Role name used for the lookup (and for a newly created row).
        description: Human-readable description to store on the role.
        permissions: Permissions to apply to the role.

    Returns:
        The created or updated ``Role`` instance.
    """
    lookup = await db.execute(select(Role).where(Role.name == name))
    existing = lookup.scalar_one_or_none()

    if not existing:
        existing = Role(name=name, description=description)
        db.add(existing)
        # Flush so the freshly added row gets its primary key before
        # permissions are attached to it.
        await db.flush()
        print(f"Created role: {name}")
    else:
        existing.description = description
        print(f"Updated role: {name}")

    await existing.set_permissions(db, permissions)
    permission_labels = ", ".join(perm.value for perm in permissions)
    print(f" Permissions: {permission_labels}")

    return existing
|
|
|
|
|
|
async def upsert_user(
    db: AsyncSession, email: str, password: str, role_names: list[str]
) -> User:
    """Create the user identified by ``email`` or refresh an existing one.

    The password is hashed with ``get_password_hash`` and the user's ``roles``
    attribute is set to the roles named in ``role_names``.

    Args:
        db: Open async session; this function does not commit.
        email: Email used to look up (or create) the user.
        password: Plain-text password to hash and store.
        role_names: Names of roles to assign; each must already exist.

    Returns:
        The created or updated ``User`` instance.

    Raises:
        ValueError: If any name in ``role_names`` has no matching role.
    """
    user_lookup = await db.execute(select(User).where(User.email == email))
    user = user_lookup.scalar_one_or_none()

    # Resolve every requested role first, so a missing one aborts before the
    # user row is created or modified.
    resolved_roles = []
    for wanted in role_names:
        role_lookup = await db.execute(select(Role).where(Role.name == wanted))
        found = role_lookup.scalar_one_or_none()
        if not found:
            raise ValueError(f"Role '{wanted}' not found")
        resolved_roles.append(found)

    if not user:
        user = User(
            email=email,
            hashed_password=get_password_hash(password),
            roles=resolved_roles,
        )
        db.add(user)
        print(f"Created user: {email} with roles: {role_names}")
    else:
        user.hashed_password = get_password_hash(password)
        user.roles = resolved_roles  # type: ignore[assignment]
        print(f"Updated user: {email} with roles: {role_names}")

    return user
|
|
|
|
|
|
# Fallbacks used when shared/constants.json does not define the exchange
# fields (they were removed from the shared constants file).
_DEFAULT_PREMIUM_PERCENTAGE = 5
_DEFAULT_EUR_TRADE_MIN = 100
_DEFAULT_EUR_TRADE_MAX = 3000


def _eur_to_cents(amount: float) -> int:
    """Convert an EUR amount (int or float) to integer cents.

    ``round`` is used so fractional inputs map to an exact ``int``; a bare
    ``0.29 * 100`` would give ``28.999999999999996``.

    Raises:
        TypeError: If ``amount`` is not an int or float (bools are rejected).
    """
    if isinstance(amount, bool) or not isinstance(amount, (int, float)):
        raise TypeError(f"EUR amount must be a number, got {amount!r}")
    return round(amount * 100)


def _load_exchange_config(constants_path: Path) -> dict:
    """Return the ``exchange`` section of the shared constants JSON file.

    A missing or ``null`` section yields an empty dict so that callers fall
    back to their defaults.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with constants_path.open(encoding="utf-8") as f:
        constants = json.load(f)
    return constants.get("exchange") or {}


async def seed_pricing_config(db: AsyncSession) -> None:
    """Seed the pricing configuration from ``shared/constants.json``.

    The premium percentage and EUR trade limits are read from the file's
    ``exchange`` section, falling back to module defaults for any missing
    field. The same premium and limits are used for both BUY and SELL, and
    the small-trade fields are set to 0.

    Args:
        db: Open async session; this function does not commit.

    Raises:
        TypeError: If a trade limit in the file is not a number.
        ValueError: If the minimum trade amount exceeds the maximum.
    """
    constants_path = Path(__file__).parent.parent / "shared" / "constants.json"
    exchange_config = _load_exchange_config(constants_path)

    # Use defaults if fields don't exist (for backward compatibility during
    # migration away from these shared constants).
    premium_percentage = exchange_config.get(
        "premiumPercentage", _DEFAULT_PREMIUM_PERCENTAGE
    )
    eur_trade_min = exchange_config.get("eurTradeMin", _DEFAULT_EUR_TRADE_MIN)
    eur_trade_max = exchange_config.get("eurTradeMax", _DEFAULT_EUR_TRADE_MAX)

    # Limits are stored in integer cents.
    eur_min_cents = _eur_to_cents(eur_trade_min)
    eur_max_cents = _eur_to_cents(eur_trade_max)
    if eur_min_cents > eur_max_cents:
        raise ValueError(
            f"eurTradeMin ({eur_trade_min}) must not exceed "
            f"eurTradeMax ({eur_trade_max})"
        )

    repo = PricingRepository(db)
    # NOTE(review): the small-trade fields are passed as 0 on every run; if
    # create_or_update overwrites an existing row, re-seeding resets values
    # an admin may have set — confirm this is intended.
    config = await repo.create_or_update(
        premium_buy=premium_percentage,
        premium_sell=premium_percentage,
        small_trade_threshold_eur=0,  # Default: no small trade extra (admin will set)
        small_trade_extra_premium=0,  # Default: no extra premium
        eur_min_buy=eur_min_cents,
        eur_max_buy=eur_max_cents,
        eur_min_sell=eur_min_cents,
        eur_max_sell=eur_max_cents,
    )

    # NOTE(review): this assumes create_or_update leaves ``id`` unset on a
    # newly created (unflushed) row. If the repository flushes, ``id`` is
    # always populated and this always reports "Updated" — confirm.
    if config.id:
        print("Updated pricing config")
    else:
        print("Created pricing config")
    print(f" Premium BUY: {config.premium_buy}%, Premium SELL: {config.premium_sell}%")
    print(
        f" Trade limits BUY: €{config.eur_min_buy / 100:.0f} - "
        f"€{config.eur_max_buy / 100:.0f}"
    )
    print(
        f" Trade limits SELL: €{config.eur_min_sell / 100:.0f} - "
        f"€{config.eur_max_sell / 100:.0f}"
    )
|
|
|
|
|
|
async def seed() -> None:
    """Create all tables, then seed roles, dev users, and pricing config.

    Table creation runs in its own ``engine.begin()`` transaction. The seed
    data is written through a single session and committed once after all
    steps. The engine's pooled connections are disposed afterwards, even on
    error, so they are closed inside the running event loop rather than
    during garbage collection after ``asyncio.run`` has closed the loop.
    The engine stays usable after ``dispose()``.
    """
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async with async_session() as db:
            print("\n=== Seeding Roles ===")
            for role_name, role_config in ROLE_DEFINITIONS.items():
                await upsert_role(
                    db,
                    role_name,
                    role_config["description"],
                    role_config["permissions"],
                )

            print("\n=== Seeding Users ===")
            # Create regular dev user
            await upsert_user(db, DEV_USER_EMAIL, DEV_USER_PASSWORD, [ROLE_REGULAR])

            # Create admin dev user
            await upsert_user(db, DEV_ADMIN_EMAIL, DEV_ADMIN_PASSWORD, [ROLE_ADMIN])

            print("\n=== Seeding Pricing Config ===")
            await seed_pricing_config(db)

            await db.commit()
            print("\n=== Seeding Complete ===\n")
    finally:
        # Close pooled async connections while the event loop is still alive.
        await engine.dispose()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async seed coroutine to completion on a
    # new event loop created by asyncio.run.
    _seed_coroutine = seed()
    asyncio.run(_seed_coroutine)
|