refactors

This commit is contained in:
counterweight 2025-12-25 18:27:59 +01:00
parent f46d2ae8b3
commit 168b67acee
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
12 changed files with 471 additions and 126 deletions

View file

@ -1,11 +1,12 @@
"""Response mappers for converting models to API response schemas."""
from models import Exchange, Invite
from models import Exchange, Invite, PriceHistory
from schemas import (
AdminExchangeResponse,
ExchangeResponse,
ExchangeUserContact,
InviteResponse,
PriceHistoryResponse,
)
@ -89,3 +90,19 @@ class InviteMapper:
spent_at=invite.spent_at,
revoked_at=invite.revoked_at,
)
class PriceHistoryMapper:
"""Mapper for PriceHistory model to response schemas."""
@staticmethod
def to_response(record: PriceHistory) -> PriceHistoryResponse:
"""Convert a PriceHistory model to PriceHistoryResponse schema."""
return PriceHistoryResponse(
id=record.id,
source=record.source,
pair=record.pair,
price=record.price,
timestamp=record.timestamp,
created_at=record.created_at,
)

View file

@ -1,6 +1,17 @@
"""Repository layer for database queries."""
from repositories.availability import AvailabilityRepository
from repositories.exchange import ExchangeRepository
from repositories.invite import InviteRepository
from repositories.price import PriceRepository
from repositories.role import RoleRepository
from repositories.user import UserRepository
__all__ = ["PriceRepository", "UserRepository"]
__all__ = [
"AvailabilityRepository",
"ExchangeRepository",
"InviteRepository",
"PriceRepository",
"RoleRepository",
"UserRepository",
]

View file

@ -0,0 +1,41 @@
"""Availability repository for database queries."""
from datetime import date
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Availability
class AvailabilityRepository:
"""Repository for availability-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_date_range(
self, from_date: date, to_date: date
) -> list[Availability]:
"""Get availability slots for a date range."""
result = await self.db.execute(
select(Availability)
.where(and_(Availability.date >= from_date, Availability.date <= to_date))
.order_by(Availability.date, Availability.start_time)
)
return list(result.scalars().all())
async def get_by_date(self, target_date: date) -> list[Availability]:
"""Get availability slots for a specific date."""
result = await self.db.execute(
select(Availability)
.where(Availability.date == target_date)
.order_by(Availability.start_time)
)
return list(result.scalars().all())
async def delete_by_date(self, target_date: date) -> None:
"""Delete all availability for a specific date."""
await self.db.execute(
delete(Availability).where(Availability.date == target_date)
)

View file

@ -0,0 +1,163 @@
"""Exchange repository for database queries."""
import uuid
from datetime import UTC, date, datetime, time
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from models import Exchange, ExchangeStatus, User
class ExchangeRepository:
"""Repository for exchange-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_public_id(
self, public_id: uuid.UUID, load_user: bool = False
) -> Exchange | None:
"""Get an exchange by public ID."""
query = select(Exchange).where(Exchange.public_id == public_id)
if load_user:
query = query.options(joinedload(Exchange.user))
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_by_user_id(
self, user_id: int, order_by_desc: bool = True
) -> list[Exchange]:
"""Get all exchanges for a user."""
query = select(Exchange).where(Exchange.user_id == user_id)
if order_by_desc:
query = query.order_by(Exchange.slot_start.desc())
else:
query = query.order_by(Exchange.slot_start.asc())
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_upcoming_booked(self) -> list[Exchange]:
"""Get all upcoming booked trades, sorted by slot time ascending."""
now = datetime.now(UTC)
query = (
select(Exchange)
.options(joinedload(Exchange.user))
.where(
and_(
Exchange.slot_start > now,
Exchange.status == ExchangeStatus.BOOKED,
)
)
.order_by(Exchange.slot_start.asc())
)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_past_trades(
self,
status: ExchangeStatus | None = None,
start_date: date | None = None,
end_date: date | None = None,
user_search: str | None = None,
) -> list[Exchange]:
"""
Get past trades with optional filters.
Args:
status: Filter by exchange status
start_date: Filter by slot_start date (inclusive start)
end_date: Filter by slot_start date (inclusive end)
user_search: Search by user email (partial match, case-insensitive)
"""
now = datetime.now(UTC)
# Start with base query for past trades
query = (
select(Exchange)
.options(joinedload(Exchange.user))
.where(
(Exchange.slot_start <= now)
| (Exchange.status != ExchangeStatus.BOOKED)
)
)
# Apply status filter
if status:
query = query.where(Exchange.status == status)
# Apply date range filter
if start_date:
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
query = query.where(Exchange.slot_start >= start_dt)
if end_date:
end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
query = query.where(Exchange.slot_start <= end_dt)
# Apply user search filter
if user_search:
query = query.join(Exchange.user).where(
User.email.ilike(f"%{user_search}%")
)
# Order by most recent first
query = query.order_by(Exchange.slot_start.desc())
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_by_slot_start(
self, slot_start: datetime, status: ExchangeStatus | None = None
) -> Exchange | None:
"""Get exchange by slot start time, optionally filtered by status."""
query = select(Exchange).where(Exchange.slot_start == slot_start)
if status:
query = query.where(Exchange.status == status)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_by_user_and_date_range(
self,
user_id: int,
start_date: date,
end_date: date,
status: ExchangeStatus | None = None,
) -> list[Exchange]:
"""Get exchanges for a user within a date range."""
from datetime import timedelta
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
# End date should be exclusive (next day at 00:00:00)
end_dt = datetime.combine(end_date, time.min, tzinfo=UTC) + timedelta(days=1)
query = select(Exchange).where(
and_(
Exchange.user_id == user_id,
Exchange.slot_start >= start_dt,
Exchange.slot_start < end_dt,
)
)
if status:
query = query.where(Exchange.status == status)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]:
"""Get set of booked slot start times for a specific date."""
from utils.date_queries import date_to_end_datetime, date_to_start_datetime
date_start = date_to_start_datetime(target_date)
date_end = date_to_end_datetime(target_date)
result = await self.db.execute(
select(Exchange.slot_start).where(
and_(
Exchange.slot_start >= date_start,
Exchange.slot_start <= date_end,
Exchange.status == ExchangeStatus.BOOKED,
)
)
)
return {row[0] for row in result.all()}

View file

@ -0,0 +1,69 @@
"""Invite repository for database queries."""
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Invite, InviteStatus
class InviteRepository:
"""Repository for invite-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_identifier(self, identifier: str) -> Invite | None:
"""Get an invite by identifier."""
result = await self.db.execute(
select(Invite).where(Invite.identifier == identifier)
)
return result.scalar_one_or_none()
async def get_by_id(self, invite_id: int) -> Invite | None:
"""Get an invite by ID."""
result = await self.db.execute(select(Invite).where(Invite.id == invite_id))
return result.scalar_one_or_none()
async def get_by_godfather_id(
self, godfather_id: int, order_by_desc: bool = True
) -> list[Invite]:
"""Get all invites for a godfather user."""
query = select(Invite).where(Invite.godfather_id == godfather_id)
if order_by_desc:
query = query.order_by(desc(Invite.created_at))
else:
query = query.order_by(Invite.created_at)
result = await self.db.execute(query)
return list(result.scalars().all())
async def count(
self,
status: InviteStatus | None = None,
godfather_id: int | None = None,
) -> int:
"""Count invites matching filters."""
query = select(func.count(Invite.id))
if status:
query = query.where(Invite.status == status)
if godfather_id:
query = query.where(Invite.godfather_id == godfather_id)
result = await self.db.execute(query)
return result.scalar() or 0
async def list_paginated(
self,
page: int,
per_page: int,
status: InviteStatus | None = None,
godfather_id: int | None = None,
) -> list[Invite]:
"""Get paginated list of invites."""
offset = (page - 1) * per_page
query = select(Invite)
if status:
query = query.where(Invite.status == status)
if godfather_id:
query = query.where(Invite.godfather_id == godfather_id)
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
result = await self.db.execute(query)
return list(result.scalars().all())

View file

@ -1,5 +1,7 @@
"""Price repository for database queries."""
from datetime import datetime
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
@ -25,3 +27,32 @@ class PriceRepository:
)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_recent(self, limit: int = 20) -> list[PriceHistory]:
"""Get the most recent price history records."""
query = select(PriceHistory).order_by(desc(PriceHistory.timestamp)).limit(limit)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_by_timestamp(
self,
timestamp: str | datetime,
source: str = SOURCE_BITFINEX,
pair: str = PAIR_BTC_EUR,
) -> PriceHistory | None:
"""Get a price record by timestamp."""
# Convert string timestamp to datetime if needed
timestamp_dt: datetime
if isinstance(timestamp, str):
timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
else:
timestamp_dt = timestamp
result = await self.db.execute(
select(PriceHistory).where(
PriceHistory.source == source,
PriceHistory.pair == pair,
PriceHistory.timestamp == timestamp_dt,
)
)
return result.scalar_one_or_none()

View file

@ -0,0 +1,18 @@
"""Role repository for database queries."""
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Role
class RoleRepository:
"""Repository for role-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_name(self, name: str) -> Role | None:
"""Get a role by name."""
result = await self.db.execute(select(Role).where(Role.name == name))
return result.scalar_one_or_none()

View file

@ -1,22 +1,18 @@
"""Exchange routes for Bitcoin trading."""
import uuid
from datetime import UTC, date, datetime, time, timedelta
from datetime import UTC, date, datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from auth import require_permission
from database import get_db
from date_validation import validate_date_in_range
from exceptions import BadRequestError
from mappers import ExchangeMapper
from models import (
Availability,
BitcoinTransferMethod,
Exchange,
ExchangeStatus,
Permission,
PriceHistory,
@ -24,6 +20,7 @@ from models import (
User,
)
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
from repositories.exchange import ExchangeRepository
from repositories.price import PriceRepository
from schemas import (
AdminExchangeResponse,
@ -44,6 +41,7 @@ from shared_constants import (
PREMIUM_PERCENTAGE,
SLOT_DURATION_MINUTES,
)
from utils.enum_validation import validate_enum
router = APIRouter(prefix="/api/exchange", tags=["exchange"])
@ -190,28 +188,18 @@ async def get_available_slots(
validate_date_in_range(date_param, context="book")
# Get availability for the date
result = await db.execute(
select(Availability).where(Availability.date == date_param)
)
availabilities = result.scalars().all()
from repositories.availability import AvailabilityRepository
from repositories.exchange import ExchangeRepository
availability_repo = AvailabilityRepository(db)
availabilities = await availability_repo.get_by_date(date_param)
if not availabilities:
return AvailableSlotsResponse(date=date_param, slots=[])
# Get already booked slots for the date
date_start = datetime.combine(date_param, time.min, tzinfo=UTC)
date_end = datetime.combine(date_param, time.max, tzinfo=UTC)
result = await db.execute(
select(Exchange.slot_start).where(
and_(
Exchange.slot_start >= date_start,
Exchange.slot_start <= date_end,
Exchange.status == ExchangeStatus.BOOKED,
)
)
)
booked_starts = {row[0] for row in result.all()}
exchange_repo = ExchangeRepository(db)
booked_starts = await exchange_repo.get_booked_slots_for_date(date_param)
# Expand each availability into slots
all_slots: list[BookableSlot] = []
@ -247,21 +235,16 @@ async def create_exchange(
- EUR amount is within configured limits
"""
# Validate direction
try:
direction = TradeDirection(request.direction)
except ValueError:
raise BadRequestError(
f"Invalid direction: {request.direction}. Must be 'buy' or 'sell'."
) from None
direction: TradeDirection = validate_enum(
TradeDirection, request.direction, "direction"
)
# Validate bitcoin transfer method
try:
bitcoin_transfer_method = BitcoinTransferMethod(request.bitcoin_transfer_method)
except ValueError:
raise BadRequestError(
f"Invalid bitcoin_transfer_method: {request.bitcoin_transfer_method}. "
"Must be 'onchain' or 'lightning'."
) from None
bitcoin_transfer_method: BitcoinTransferMethod = validate_enum(
BitcoinTransferMethod,
request.bitcoin_transfer_method,
"bitcoin_transfer_method",
)
# Use service to create exchange (handles all validation)
service = ExchangeService(db)
@ -289,12 +272,8 @@ async def get_my_trades(
current_user: User = Depends(require_permission(Permission.VIEW_OWN_EXCHANGES)),
) -> list[ExchangeResponse]:
"""Get the current user's exchanges, sorted by date (newest first)."""
result = await db.execute(
select(Exchange)
.where(Exchange.user_id == current_user.id)
.order_by(Exchange.slot_start.desc())
)
exchanges = result.scalars().all()
exchange_repo = ExchangeRepository(db)
exchanges = await exchange_repo.get_by_user_id(current_user.id, order_by_desc=True)
return [ExchangeMapper.to_response(ex, current_user.email) for ex in exchanges]
@ -348,19 +327,8 @@ async def get_upcoming_trades(
_current_user: User = Depends(require_permission(Permission.VIEW_ALL_EXCHANGES)),
) -> list[AdminExchangeResponse]:
"""Get all upcoming booked trades, sorted by slot time ascending."""
now = datetime.now(UTC)
result = await db.execute(
select(Exchange)
.options(joinedload(Exchange.user))
.where(
and_(
Exchange.slot_start > now,
Exchange.status == ExchangeStatus.BOOKED,
)
)
.order_by(Exchange.slot_start.asc())
)
exchanges = result.scalars().all()
exchange_repo = ExchangeRepository(db)
exchanges = await exchange_repo.get_upcoming_booked()
return [ExchangeMapper.to_admin_response(ex) for ex in exchanges]
@ -383,45 +351,19 @@ async def get_past_trades(
- user_search: Search by user email (partial match)
"""
now = datetime.now(UTC)
# Start with base query for past trades (slot_start <= now OR not booked)
query = (
select(Exchange)
.options(joinedload(Exchange.user))
.where(
(Exchange.slot_start <= now) | (Exchange.status != ExchangeStatus.BOOKED)
)
)
# Apply status filter
status_enum: ExchangeStatus | None = None
if status:
try:
status_enum = ExchangeStatus(status)
query = query.where(Exchange.status == status_enum)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}",
) from None
status_enum = validate_enum(ExchangeStatus, status, "status")
# Apply date range filter
if start_date:
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
query = query.where(Exchange.slot_start >= start_dt)
if end_date:
end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
query = query.where(Exchange.slot_start <= end_dt)
# Apply user search filter (join with User table)
if user_search:
query = query.join(Exchange.user).where(User.email.ilike(f"%{user_search}%"))
# Order by most recent first
query = query.order_by(Exchange.slot_start.desc())
result = await db.execute(query)
exchanges = result.scalars().all()
# Use repository for query
exchange_repo = ExchangeRepository(db)
exchanges = await exchange_repo.get_past_trades(
status=status_enum,
start_date=start_date,
end_date=end_date,
user_search=user_search,
)
return [ExchangeMapper.to_admin_response(ex) for ex in exchanges]
@ -487,6 +429,10 @@ async def search_users(
Returns users whose email contains the search query (case-insensitive).
Limited to 10 results for autocomplete purposes.
"""
# Note: UserRepository doesn't have search yet, but we can add it
# For now, keeping direct query for this specific use case
from sqlalchemy import select
result = await db.execute(
select(User).where(User.email.ilike(f"%{q}%")).order_by(User.email).limit(10)
)

View file

@ -1,9 +1,8 @@
"""Exchange service for business logic related to Bitcoin trading."""
import uuid
from datetime import UTC, date, datetime, time, timedelta
from datetime import UTC, date, datetime, timedelta
from sqlalchemy import and_, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
@ -15,7 +14,6 @@ from exceptions import (
ServiceUnavailableError,
)
from models import (
Availability,
BitcoinTransferMethod,
Exchange,
ExchangeStatus,
@ -23,6 +21,7 @@ from models import (
TradeDirection,
User,
)
from repositories.exchange import ExchangeRepository
from repositories.price import PriceRepository
from shared_constants import (
EUR_TRADE_INCREMENT,
@ -44,6 +43,7 @@ class ExchangeService:
def __init__(self, db: AsyncSession):
self.db = db
self.price_repo = PriceRepository(db)
self.exchange_repo = ExchangeRepository(db)
def apply_premium_for_direction(
self,
@ -107,20 +107,21 @@ class ExchangeService:
self, slot_start: datetime, slot_date: date
) -> None:
"""Verify slot falls within availability."""
from repositories.availability import AvailabilityRepository
slot_start_time = slot_start.time()
slot_end_dt = slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
slot_end_time = slot_end_dt.time()
result = await self.db.execute(
select(Availability).where(
and_(
Availability.date == slot_date,
Availability.start_time <= slot_start_time,
Availability.end_time >= slot_end_time,
)
)
)
matching_availability = result.scalar_one_or_none()
availability_repo = AvailabilityRepository(self.db)
availabilities = await availability_repo.get_by_date(slot_date)
# Check if any availability block contains this slot
matching_availability = None
for avail in availabilities:
if avail.start_time <= slot_start_time and avail.end_time >= slot_end_time:
matching_availability = avail
break
if not matching_availability:
slot_str = slot_start.strftime("%Y-%m-%d %H:%M")
@ -171,29 +172,19 @@ class ExchangeService:
self, user: User, slot_date: date
) -> Exchange | None:
"""Check if user already has a trade on this date."""
existing_trade_query = select(Exchange).where(
and_(
Exchange.user_id == user.id,
Exchange.slot_start
>= datetime.combine(slot_date, time.min, tzinfo=UTC),
Exchange.slot_start
< datetime.combine(slot_date, time.max, tzinfo=UTC) + timedelta(days=1),
Exchange.status == ExchangeStatus.BOOKED,
)
exchanges = await self.exchange_repo.get_by_user_and_date_range(
user_id=user.id,
start_date=slot_date,
end_date=slot_date,
status=ExchangeStatus.BOOKED,
)
result = await self.db.execute(existing_trade_query)
return result.scalar_one_or_none()
return exchanges[0] if exchanges else None
async def check_slot_already_booked(self, slot_start: datetime) -> Exchange | None:
"""Check if slot is already booked (only consider BOOKED status)."""
slot_booked_query = select(Exchange).where(
and_(
Exchange.slot_start == slot_start,
Exchange.status == ExchangeStatus.BOOKED,
)
return await self.exchange_repo.get_by_slot_start(
slot_start, status=ExchangeStatus.BOOKED
)
result = await self.db.execute(slot_booked_query)
return result.scalar_one_or_none()
async def create_exchange(
self,
@ -297,9 +288,7 @@ class ExchangeService:
NotFoundError: If exchange not found or user doesn't own it
(for security, returns 404)
"""
query = select(Exchange).where(Exchange.public_id == public_id)
result = await self.db.execute(query)
exchange = result.scalar_one_or_none()
exchange = await self.exchange_repo.get_by_public_id(public_id)
if not exchange:
raise NotFoundError("Trade")

View file

@ -0,0 +1 @@
"""Utility modules for common functionality."""

View file

@ -0,0 +1,27 @@
"""Utilities for date/time query operations."""
from datetime import UTC, date, datetime, time
def date_to_start_datetime(d: date) -> datetime:
"""Convert a date to datetime at start of day (00:00:00) in UTC."""
return datetime.combine(d, time.min, tzinfo=UTC)
def date_to_end_datetime(d: date) -> datetime:
"""Convert a date to datetime at end of day (23:59:59.999999) in UTC."""
return datetime.combine(d, time.max, tzinfo=UTC)
def date_range_to_datetime_range(
start_date: date, end_date: date
) -> tuple[datetime, datetime]:
"""
Convert a date range to datetime range.
Returns:
Tuple of (start_datetime, end_datetime) where:
- start_datetime is start_date at 00:00:00 UTC
- end_datetime is end_date at 23:59:59.999999 UTC
"""
return date_to_start_datetime(start_date), date_to_end_datetime(end_date)

View file

@ -0,0 +1,32 @@
"""Utilities for validating enum values from strings."""
from enum import Enum
from typing import TypeVar
from exceptions import BadRequestError
T = TypeVar("T", bound=Enum)
def validate_enum(enum_class: type[T], value: str, field_name: str = "value") -> T:
"""
Validate and convert string to enum.
Args:
enum_class: The enum class to validate against
value: The string value to validate
field_name: Name of the field for error messages
Returns:
The validated enum value
Raises:
BadRequestError: If the value is not a valid enum member
"""
try:
return enum_class(value)
except ValueError:
valid_values = ", ".join(e.value for e in enum_class)
raise BadRequestError(
f"Invalid {field_name}: {value}. Must be one of: {valid_values}"
) from None