"""Exchange repository for database queries."""

import uuid
from datetime import UTC, date, datetime, time, timedelta

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, joinedload

from models import Exchange, ExchangeStatus, User

# Escape character for LIKE patterns built from user input ("/" is the same
# character SQLAlchemy itself uses for ``autoescape=True``).
_LIKE_ESCAPE = "/"


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards so *value* is matched literally.

    The escape character is doubled first, then ``%`` and ``_`` are prefixed
    with it. The result must be used with ``escape=_LIKE_ESCAPE``.
    """
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


def _start_of_day(day: date) -> datetime:
    """Return midnight (00:00:00) of *day* as a UTC-aware datetime."""
    return datetime.combine(day, time.min, tzinfo=UTC)


class ExchangeRepository:
    """Repository for exchange-related database queries."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_public_id(
        self, public_id: uuid.UUID, load_user: bool = False
    ) -> Exchange | None:
        """Get an exchange by public ID.

        Args:
            public_id: Public UUID of the exchange.
            load_user: If True, eagerly load ``Exchange.user`` in the same query.

        Returns:
            The matching exchange, or None if no row matches.
        """
        query = select(Exchange).where(Exchange.public_id == public_id)
        if load_user:
            query = query.options(joinedload(Exchange.user))
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_by_user_id(
        self, user_id: int, order_by_desc: bool = True
    ) -> list[Exchange]:
        """Get all exchanges for a user.

        Args:
            user_id: Internal user ID to filter on.
            order_by_desc: Order by ``slot_start`` descending (default) or ascending.

        Returns:
            List of the user's exchanges in the requested order.
        """
        query = select(Exchange).where(Exchange.user_id == user_id)
        if order_by_desc:
            query = query.order_by(Exchange.slot_start.desc())
        else:
            query = query.order_by(Exchange.slot_start.asc())
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def get_upcoming_booked(self) -> list[Exchange]:
        """Get all upcoming booked trades, sorted by slot time ascending.

        "Upcoming" means ``slot_start`` is strictly after the current UTC time.
        ``Exchange.user`` is eagerly loaded.
        """
        now = datetime.now(UTC)
        query = (
            select(Exchange)
            .options(joinedload(Exchange.user))
            .where(
                and_(
                    Exchange.slot_start > now,
                    Exchange.status == ExchangeStatus.BOOKED,
                )
            )
            .order_by(Exchange.slot_start.asc())
        )
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def get_past_trades(
        self,
        status: ExchangeStatus | None = None,
        start_date: date | None = None,
        end_date: date | None = None,
        user_search: str | None = None,
    ) -> list[Exchange]:
        """
        Get past trades with optional filters.

        A trade is "past" when its slot has already started, or when it is no
        longer BOOKED (so future exchanges in another status are included too).

        Args:
            status: Filter by exchange status
            start_date: Filter by slot_start date (inclusive start, UTC day)
            end_date: Filter by slot_start date (inclusive end, UTC day)
            user_search: Search by user email (partial match, case-insensitive;
                ``%`` and ``_`` are matched literally, not as wildcards)

        Returns:
            Matching exchanges with ``Exchange.user`` loaded, most recent first.
        """
        now = datetime.now(UTC)

        # Base query for past trades.
        query = select(Exchange).where(
            (Exchange.slot_start <= now) | (Exchange.status != ExchangeStatus.BOOKED)
        )

        if status is not None:
            query = query.where(Exchange.status == status)

        if start_date is not None:
            query = query.where(Exchange.slot_start >= _start_of_day(start_date))
        if end_date is not None:
            # Exclusive bound at the next midnight covers the whole end day
            # (same convention as get_by_user_and_date_range).
            query = query.where(
                Exchange.slot_start < _start_of_day(end_date) + timedelta(days=1)
            )

        if user_search:
            # Populate Exchange.user from the explicit JOIN rather than letting
            # joinedload add a second, aliased join to the same users table.
            pattern = f"%{_escape_like(user_search)}%"
            query = (
                query.join(Exchange.user)
                .where(User.email.ilike(pattern, escape=_LIKE_ESCAPE))
                .options(contains_eager(Exchange.user))
            )
        else:
            query = query.options(joinedload(Exchange.user))

        # Most recent first.
        query = query.order_by(Exchange.slot_start.desc())

        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def get_by_slot_start(
        self, slot_start: datetime, status: ExchangeStatus | None = None
    ) -> Exchange | None:
        """Get exchange by slot start time, optionally filtered by status.

        Args:
            slot_start: Exact slot start time to match.
            status: If given, only match exchanges in this status.

        Returns:
            The matching exchange, or None.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one exchange
                matches (e.g. several exchanges share the slot and ``status``
                is None).
        """
        query = select(Exchange).where(Exchange.slot_start == slot_start)
        if status is not None:
            query = query.where(Exchange.status == status)
        result = await self.db.execute(query)
        return result.scalar_one_or_none()

    async def get_by_user_and_date_range(
        self,
        user_id: int,
        start_date: date,
        end_date: date,
        status: ExchangeStatus | None = None,
    ) -> list[Exchange]:
        """Get exchanges for a user within a date range.

        Args:
            user_id: Internal user ID to filter on.
            start_date: First UTC day to include.
            end_date: Last UTC day to include (inclusive).
            status: If given, only include exchanges in this status.

        Returns:
            Matching exchanges (no particular order is requested).
        """
        start_dt = _start_of_day(start_date)
        # End bound is exclusive: midnight of the day after end_date.
        end_dt = _start_of_day(end_date) + timedelta(days=1)

        query = select(Exchange).where(
            and_(
                Exchange.user_id == user_id,
                Exchange.slot_start >= start_dt,
                Exchange.slot_start < end_dt,
            )
        )
        if status is not None:
            query = query.where(Exchange.status == status)

        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]:
        """Get set of booked slot start times for a specific date.

        Uses the day bounds produced by ``utils.date_queries`` (inclusive on
        both ends) and only counts exchanges in BOOKED status.
        """
        # Imported at call time, as in the original; possibly to avoid an
        # import cycle — TODO confirm.
        from utils.date_queries import date_to_end_datetime, date_to_start_datetime

        date_start = date_to_start_datetime(target_date)
        date_end = date_to_end_datetime(target_date)

        result = await self.db.execute(
            select(Exchange.slot_start).where(
                and_(
                    Exchange.slot_start >= date_start,
                    Exchange.slot_start <= date_end,
                    Exchange.status == ExchangeStatus.BOOKED,
                )
            )
        )
        return {row[0] for row in result.all()}

    async def create(self, exchange: Exchange) -> Exchange:
        """
        Create a new exchange record.

        Args:
            exchange: Exchange instance to persist

        Returns:
            Created Exchange record (committed and refreshed)
        """
        self.db.add(exchange)
        await self.db.commit()
        await self.db.refresh(exchange)
        return exchange

    async def update(self, exchange: Exchange) -> Exchange:
        """
        Update an existing exchange record.

        The instance is expected to already be attached to this session; the
        commit applies every pending change in the session, not only this one.

        Args:
            exchange: Exchange instance to update

        Returns:
            Updated Exchange record (committed and refreshed)
        """
        await self.db.commit()
        await self.db.refresh(exchange)
        return exchange