refactors

This commit is contained in:
counterweight 2025-12-25 18:27:59 +01:00
parent f46d2ae8b3
commit 168b67acee
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
12 changed files with 471 additions and 126 deletions

View file

@ -1,6 +1,17 @@
"""Repository layer for database queries."""
from repositories.availability import AvailabilityRepository
from repositories.exchange import ExchangeRepository
from repositories.invite import InviteRepository
from repositories.price import PriceRepository
from repositories.role import RoleRepository
from repositories.user import UserRepository
__all__ = ["PriceRepository", "UserRepository"]
__all__ = [
"AvailabilityRepository",
"ExchangeRepository",
"InviteRepository",
"PriceRepository",
"RoleRepository",
"UserRepository",
]

View file

@ -0,0 +1,41 @@
"""Availability repository for database queries."""
from datetime import date
from sqlalchemy import and_, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Availability
class AvailabilityRepository:
"""Repository for availability-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_date_range(
self, from_date: date, to_date: date
) -> list[Availability]:
"""Get availability slots for a date range."""
result = await self.db.execute(
select(Availability)
.where(and_(Availability.date >= from_date, Availability.date <= to_date))
.order_by(Availability.date, Availability.start_time)
)
return list(result.scalars().all())
async def get_by_date(self, target_date: date) -> list[Availability]:
"""Get availability slots for a specific date."""
result = await self.db.execute(
select(Availability)
.where(Availability.date == target_date)
.order_by(Availability.start_time)
)
return list(result.scalars().all())
async def delete_by_date(self, target_date: date) -> None:
"""Delete all availability for a specific date."""
await self.db.execute(
delete(Availability).where(Availability.date == target_date)
)

View file

@ -0,0 +1,163 @@
"""Exchange repository for database queries."""
import uuid
from datetime import UTC, date, datetime, time
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from models import Exchange, ExchangeStatus, User
class ExchangeRepository:
"""Repository for exchange-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_public_id(
self, public_id: uuid.UUID, load_user: bool = False
) -> Exchange | None:
"""Get an exchange by public ID."""
query = select(Exchange).where(Exchange.public_id == public_id)
if load_user:
query = query.options(joinedload(Exchange.user))
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_by_user_id(
self, user_id: int, order_by_desc: bool = True
) -> list[Exchange]:
"""Get all exchanges for a user."""
query = select(Exchange).where(Exchange.user_id == user_id)
if order_by_desc:
query = query.order_by(Exchange.slot_start.desc())
else:
query = query.order_by(Exchange.slot_start.asc())
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_upcoming_booked(self) -> list[Exchange]:
"""Get all upcoming booked trades, sorted by slot time ascending."""
now = datetime.now(UTC)
query = (
select(Exchange)
.options(joinedload(Exchange.user))
.where(
and_(
Exchange.slot_start > now,
Exchange.status == ExchangeStatus.BOOKED,
)
)
.order_by(Exchange.slot_start.asc())
)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_past_trades(
self,
status: ExchangeStatus | None = None,
start_date: date | None = None,
end_date: date | None = None,
user_search: str | None = None,
) -> list[Exchange]:
"""
Get past trades with optional filters.
Args:
status: Filter by exchange status
start_date: Filter by slot_start date (inclusive start)
end_date: Filter by slot_start date (inclusive end)
user_search: Search by user email (partial match, case-insensitive)
"""
now = datetime.now(UTC)
# Start with base query for past trades
query = (
select(Exchange)
.options(joinedload(Exchange.user))
.where(
(Exchange.slot_start <= now)
| (Exchange.status != ExchangeStatus.BOOKED)
)
)
# Apply status filter
if status:
query = query.where(Exchange.status == status)
# Apply date range filter
if start_date:
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
query = query.where(Exchange.slot_start >= start_dt)
if end_date:
end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
query = query.where(Exchange.slot_start <= end_dt)
# Apply user search filter
if user_search:
query = query.join(Exchange.user).where(
User.email.ilike(f"%{user_search}%")
)
# Order by most recent first
query = query.order_by(Exchange.slot_start.desc())
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_by_slot_start(
self, slot_start: datetime, status: ExchangeStatus | None = None
) -> Exchange | None:
"""Get exchange by slot start time, optionally filtered by status."""
query = select(Exchange).where(Exchange.slot_start == slot_start)
if status:
query = query.where(Exchange.status == status)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_by_user_and_date_range(
self,
user_id: int,
start_date: date,
end_date: date,
status: ExchangeStatus | None = None,
) -> list[Exchange]:
"""Get exchanges for a user within a date range."""
from datetime import timedelta
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
# End date should be exclusive (next day at 00:00:00)
end_dt = datetime.combine(end_date, time.min, tzinfo=UTC) + timedelta(days=1)
query = select(Exchange).where(
and_(
Exchange.user_id == user_id,
Exchange.slot_start >= start_dt,
Exchange.slot_start < end_dt,
)
)
if status:
query = query.where(Exchange.status == status)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]:
"""Get set of booked slot start times for a specific date."""
from utils.date_queries import date_to_end_datetime, date_to_start_datetime
date_start = date_to_start_datetime(target_date)
date_end = date_to_end_datetime(target_date)
result = await self.db.execute(
select(Exchange.slot_start).where(
and_(
Exchange.slot_start >= date_start,
Exchange.slot_start <= date_end,
Exchange.status == ExchangeStatus.BOOKED,
)
)
)
return {row[0] for row in result.all()}

View file

@ -0,0 +1,69 @@
"""Invite repository for database queries."""
from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Invite, InviteStatus
class InviteRepository:
"""Repository for invite-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_identifier(self, identifier: str) -> Invite | None:
"""Get an invite by identifier."""
result = await self.db.execute(
select(Invite).where(Invite.identifier == identifier)
)
return result.scalar_one_or_none()
async def get_by_id(self, invite_id: int) -> Invite | None:
"""Get an invite by ID."""
result = await self.db.execute(select(Invite).where(Invite.id == invite_id))
return result.scalar_one_or_none()
async def get_by_godfather_id(
self, godfather_id: int, order_by_desc: bool = True
) -> list[Invite]:
"""Get all invites for a godfather user."""
query = select(Invite).where(Invite.godfather_id == godfather_id)
if order_by_desc:
query = query.order_by(desc(Invite.created_at))
else:
query = query.order_by(Invite.created_at)
result = await self.db.execute(query)
return list(result.scalars().all())
async def count(
self,
status: InviteStatus | None = None,
godfather_id: int | None = None,
) -> int:
"""Count invites matching filters."""
query = select(func.count(Invite.id))
if status:
query = query.where(Invite.status == status)
if godfather_id:
query = query.where(Invite.godfather_id == godfather_id)
result = await self.db.execute(query)
return result.scalar() or 0
async def list_paginated(
self,
page: int,
per_page: int,
status: InviteStatus | None = None,
godfather_id: int | None = None,
) -> list[Invite]:
"""Get paginated list of invites."""
offset = (page - 1) * per_page
query = select(Invite)
if status:
query = query.where(Invite.status == status)
if godfather_id:
query = query.where(Invite.godfather_id == godfather_id)
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
result = await self.db.execute(query)
return list(result.scalars().all())

View file

@ -1,5 +1,7 @@
"""Price repository for database queries."""
from datetime import datetime
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession
@ -25,3 +27,32 @@ class PriceRepository:
)
result = await self.db.execute(query)
return result.scalar_one_or_none()
async def get_recent(self, limit: int = 20) -> list[PriceHistory]:
"""Get the most recent price history records."""
query = select(PriceHistory).order_by(desc(PriceHistory.timestamp)).limit(limit)
result = await self.db.execute(query)
return list(result.scalars().all())
async def get_by_timestamp(
self,
timestamp: str | datetime,
source: str = SOURCE_BITFINEX,
pair: str = PAIR_BTC_EUR,
) -> PriceHistory | None:
"""Get a price record by timestamp."""
# Convert string timestamp to datetime if needed
timestamp_dt: datetime
if isinstance(timestamp, str):
timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
else:
timestamp_dt = timestamp
result = await self.db.execute(
select(PriceHistory).where(
PriceHistory.source == source,
PriceHistory.pair == pair,
PriceHistory.timestamp == timestamp_dt,
)
)
return result.scalar_one_or_none()

View file

@ -0,0 +1,18 @@
"""Role repository for database queries."""
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from models import Role
class RoleRepository:
"""Repository for role-related database queries."""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_name(self, name: str) -> Role | None:
"""Get a role by name."""
result = await self.db.execute(select(Role).where(Role.name == name))
return result.scalar_one_or_none()