2025-12-25 18:27:59 +01:00
|
|
|
"""Exchange repository for database queries."""
|
|
|
|
|
|
|
|
|
|
import uuid
|
|
|
|
|
from datetime import UTC, date, datetime, time
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import and_, select
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
|
|
|
|
|
|
from models import Exchange, ExchangeStatus, User
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExchangeRepository:
    """Data-access layer that runs exchange queries on an async session."""

    def __init__(self, db: AsyncSession):
        """Store the async session used for database access.

        Args:
            db: Session kept on the instance as ``self.db``.
        """
        self.db = db
|
|
|
|
|
|
|
|
|
|
async def get_by_public_id(
    self, public_id: uuid.UUID, load_user: bool = False
) -> Exchange | None:
    """Look up a single exchange by its public UUID.

    Args:
        public_id: Value compared against ``Exchange.public_id``.
        load_user: When true, ``Exchange.user`` is fetched in the same
            statement through a joined eager load.

    Returns:
        The matching exchange, or ``None`` when no row matches.
    """
    base_stmt = select(Exchange).where(Exchange.public_id == public_id)
    # Only attach the eager-load option when the caller asked for the user.
    stmt = base_stmt.options(joinedload(Exchange.user)) if load_user else base_stmt
    rows = await self.db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
async def get_by_user_id(
    self, user_id: int, order_by_desc: bool = True
) -> list[Exchange]:
    """Return every exchange that belongs to one user, sorted by slot start.

    Args:
        user_id: Value compared against ``Exchange.user_id``.
        order_by_desc: Sort by ``slot_start`` descending when true,
            ascending otherwise.

    Returns:
        A list of matching exchanges (empty when the user has none).
    """
    slot_order = (
        Exchange.slot_start.desc() if order_by_desc else Exchange.slot_start.asc()
    )
    stmt = select(Exchange).where(Exchange.user_id == user_id).order_by(slot_order)
    rows = await self.db.execute(stmt)
    return list(rows.scalars().all())
|
|
|
|
|
|
|
|
|
|
async def get_upcoming_booked(self) -> list[Exchange]:
    """Return booked exchanges whose slot has not started yet, earliest first.

    "Upcoming" means ``slot_start`` is strictly later than the current UTC
    time, taken once when the method runs. ``Exchange.user`` is eagerly
    loaded for every returned row.
    """
    cutoff = datetime.now(UTC)
    stmt = (
        select(Exchange)
        .options(joinedload(Exchange.user))
        # Multiple criteria passed to where() are combined with AND.
        .where(
            Exchange.slot_start > cutoff,
            Exchange.status == ExchangeStatus.BOOKED,
        )
        .order_by(Exchange.slot_start.asc())
    )
    rows = await self.db.execute(stmt)
    return list(rows.scalars().all())
|
|
|
|
|
|
|
|
|
|
async def get_past_trades(
    self,
    status: ExchangeStatus | None = None,
    start_date: date | None = None,
    end_date: date | None = None,
    user_search: str | None = None,
) -> list[Exchange]:
    """
    Get past trades with optional filters.

    A trade is included when its slot has already started (``slot_start``
    is not later than the current UTC time) or when its status is anything
    other than ``BOOKED``. All supplied filters are combined with AND.

    Args:
        status: Filter by exchange status
        start_date: Filter by slot_start date (inclusive start, 00:00 UTC)
        end_date: Filter by slot_start date (inclusive end, through
            23:59:59.999999 UTC)
        user_search: Search by user email (partial match, case-insensitive).
            The text is matched literally: ``%`` and ``_`` are not treated
            as wildcards. An empty string applies no filter.

    Returns:
        Matching exchanges with ``user`` eagerly loaded, most recent
        ``slot_start`` first.
    """
    now = datetime.now(UTC)

    # Start with base query for past trades
    query = (
        select(Exchange)
        .options(joinedload(Exchange.user))
        .where(
            (Exchange.slot_start <= now)
            | (Exchange.status != ExchangeStatus.BOOKED)
        )
    )

    # Apply status filter. Compare against None explicitly so that a
    # falsy-valued enum member could never silently disable the filter.
    if status is not None:
        query = query.where(Exchange.status == status)

    # Apply date range filter
    if start_date is not None:
        start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
        query = query.where(Exchange.slot_start >= start_dt)
    if end_date is not None:
        end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
        query = query.where(Exchange.slot_start <= end_dt)

    # Apply user search filter
    if user_search:
        # Escape LIKE metacharacters so user input is matched literally.
        # The escape character itself must be escaped first.
        escaped_search = (
            user_search.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        # This explicit join exists only for filtering; the joinedload
        # option above is applied independently of it.
        query = query.join(Exchange.user).where(
            User.email.ilike(f"%{escaped_search}%", escape="\\")
        )

    # Order by most recent first
    query = query.order_by(Exchange.slot_start.desc())

    result = await self.db.execute(query)
    return list(result.scalars().all())
|
|
|
|
|
|
|
|
|
|
async def get_by_slot_start(
    self, slot_start: datetime, status: ExchangeStatus | None = None
) -> Exchange | None:
    """Find the exchange whose slot starts exactly at ``slot_start``.

    Args:
        slot_start: Value compared for equality with ``Exchange.slot_start``.
        status: When given (and truthy), additionally require this status.

    Returns:
        The matching exchange, or ``None`` when no row matches. The lookup
        uses ``scalar_one_or_none()``, so more than one matching row is
        reported by SQLAlchemy as an error rather than returned.
    """
    criteria = [Exchange.slot_start == slot_start]
    if status:
        criteria.append(Exchange.status == status)
    rows = await self.db.execute(select(Exchange).where(*criteria))
    return rows.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
async def get_by_user_and_date_range(
    self,
    user_id: int,
    start_date: date,
    end_date: date,
    status: ExchangeStatus | None = None,
) -> list[Exchange]:
    """Return one user's exchanges whose slot starts inside a date window.

    The window covers whole UTC days from ``start_date`` through
    ``end_date``, both inclusive.

    Args:
        user_id: Value compared against ``Exchange.user_id``.
        start_date: First day of the window.
        end_date: Last day of the window.
        status: When given (and truthy), additionally require this status.

    Returns:
        A list of matching exchanges; no ordering is applied.
    """
    from datetime import timedelta

    window_open = datetime.combine(start_date, time.min, tzinfo=UTC)
    # Half-open window: the upper bound is midnight of the day after
    # end_date and is compared with "<", so all of end_date is included.
    window_close = datetime.combine(end_date, time.min, tzinfo=UTC) + timedelta(
        days=1
    )

    criteria = [
        Exchange.user_id == user_id,
        Exchange.slot_start >= window_open,
        Exchange.slot_start < window_close,
    ]
    if status:
        criteria.append(Exchange.status == status)

    rows = await self.db.execute(select(Exchange).where(*criteria))
    return list(rows.scalars().all())
|
|
|
|
|
|
|
|
|
|
async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]:
    """Collect the start times of all booked slots on one date.

    The day's bounds come from the ``utils.date_queries`` helpers and both
    bounds are treated as inclusive.
    NOTE(review): the timezone of those bounds is decided by the helpers,
    which are not visible here — confirm it matches ``slot_start``.

    Args:
        target_date: The day to inspect.

    Returns:
        Distinct ``slot_start`` values of exchanges with status ``BOOKED``.
    """
    from utils.date_queries import date_to_end_datetime, date_to_start_datetime

    day_open = date_to_start_datetime(target_date)
    day_close = date_to_end_datetime(target_date)

    stmt = select(Exchange.slot_start).where(
        Exchange.slot_start >= day_open,
        Exchange.slot_start <= day_close,
        Exchange.status == ExchangeStatus.BOOKED,
    )
    rows = await self.db.execute(stmt)
    # scalars() yields the first (and only) selected column of each row.
    return set(rows.scalars().all())
|
2025-12-25 18:54:29 +01:00
|
|
|
|
|
|
|
|
async def create(self, exchange: Exchange) -> Exchange:
    """
    Create a new exchange record.

    Adds the instance to the session, commits, then refreshes it from the
    database.

    Args:
        exchange: Exchange instance to persist

    Returns:
        Created Exchange record (committed and refreshed)

    Raises:
        Exception: Any error raised by the commit is re-raised unchanged
            after the session has been rolled back.
    """
    self.db.add(exchange)
    try:
        await self.db.commit()
    except Exception:
        # A failed commit leaves the session's transaction unusable until
        # it is rolled back; reset it so the session can be reused, then
        # let the caller see the original error.
        await self.db.rollback()
        raise
    await self.db.refresh(exchange)
    return exchange
|
|
|
|
|
|
|
|
|
|
async def update(self, exchange: Exchange) -> Exchange:
    """
    Update an existing exchange record.

    Commits whatever changes are pending on the session, then refreshes
    the given instance from the database.
    NOTE(review): the instance is not added to the session here, so this
    presumably expects ``exchange`` to already be attached to ``self.db``
    — confirm against callers.

    Args:
        exchange: Exchange instance to update

    Returns:
        Updated Exchange record (committed and refreshed)

    Raises:
        Exception: Any error raised by the commit is re-raised unchanged
            after the session has been rolled back.
    """
    try:
        await self.db.commit()
    except Exception:
        # A failed commit leaves the session's transaction unusable until
        # it is rolled back; reset it so the session can be reused, then
        # let the caller see the original error.
        await self.db.rollback()
        raise
    await self.db.refresh(exchange)
    return exchange
|