refactors
This commit is contained in:
parent
f46d2ae8b3
commit
168b67acee
12 changed files with 471 additions and 126 deletions
|
|
@ -1,11 +1,12 @@
|
|||
"""Response mappers for converting models to API response schemas."""
|
||||
|
||||
from models import Exchange, Invite
|
||||
from models import Exchange, Invite, PriceHistory
|
||||
from schemas import (
|
||||
AdminExchangeResponse,
|
||||
ExchangeResponse,
|
||||
ExchangeUserContact,
|
||||
InviteResponse,
|
||||
PriceHistoryResponse,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -89,3 +90,19 @@ class InviteMapper:
|
|||
spent_at=invite.spent_at,
|
||||
revoked_at=invite.revoked_at,
|
||||
)
|
||||
|
||||
|
||||
class PriceHistoryMapper:
|
||||
"""Mapper for PriceHistory model to response schemas."""
|
||||
|
||||
@staticmethod
|
||||
def to_response(record: PriceHistory) -> PriceHistoryResponse:
|
||||
"""Convert a PriceHistory model to PriceHistoryResponse schema."""
|
||||
return PriceHistoryResponse(
|
||||
id=record.id,
|
||||
source=record.source,
|
||||
pair=record.pair,
|
||||
price=record.price,
|
||||
timestamp=record.timestamp,
|
||||
created_at=record.created_at,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,17 @@
|
|||
"""Repository layer for database queries."""
|
||||
|
||||
from repositories.availability import AvailabilityRepository
|
||||
from repositories.exchange import ExchangeRepository
|
||||
from repositories.invite import InviteRepository
|
||||
from repositories.price import PriceRepository
|
||||
from repositories.role import RoleRepository
|
||||
from repositories.user import UserRepository
|
||||
|
||||
__all__ = ["PriceRepository", "UserRepository"]
|
||||
__all__ = [
|
||||
"AvailabilityRepository",
|
||||
"ExchangeRepository",
|
||||
"InviteRepository",
|
||||
"PriceRepository",
|
||||
"RoleRepository",
|
||||
"UserRepository",
|
||||
]
|
||||
|
|
|
|||
41
backend/repositories/availability.py
Normal file
41
backend/repositories/availability.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Availability repository for database queries."""
|
||||
|
||||
from datetime import date
|
||||
|
||||
from sqlalchemy import and_, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import Availability
|
||||
|
||||
|
||||
class AvailabilityRepository:
|
||||
"""Repository for availability-related database queries."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_date_range(
|
||||
self, from_date: date, to_date: date
|
||||
) -> list[Availability]:
|
||||
"""Get availability slots for a date range."""
|
||||
result = await self.db.execute(
|
||||
select(Availability)
|
||||
.where(and_(Availability.date >= from_date, Availability.date <= to_date))
|
||||
.order_by(Availability.date, Availability.start_time)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_date(self, target_date: date) -> list[Availability]:
|
||||
"""Get availability slots for a specific date."""
|
||||
result = await self.db.execute(
|
||||
select(Availability)
|
||||
.where(Availability.date == target_date)
|
||||
.order_by(Availability.start_time)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_by_date(self, target_date: date) -> None:
|
||||
"""Delete all availability for a specific date."""
|
||||
await self.db.execute(
|
||||
delete(Availability).where(Availability.date == target_date)
|
||||
)
|
||||
163
backend/repositories/exchange.py
Normal file
163
backend/repositories/exchange.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
"""Exchange repository for database queries."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime, time
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from models import Exchange, ExchangeStatus, User
|
||||
|
||||
|
||||
class ExchangeRepository:
|
||||
"""Repository for exchange-related database queries."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_public_id(
|
||||
self, public_id: uuid.UUID, load_user: bool = False
|
||||
) -> Exchange | None:
|
||||
"""Get an exchange by public ID."""
|
||||
query = select(Exchange).where(Exchange.public_id == public_id)
|
||||
if load_user:
|
||||
query = query.options(joinedload(Exchange.user))
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user_id(
|
||||
self, user_id: int, order_by_desc: bool = True
|
||||
) -> list[Exchange]:
|
||||
"""Get all exchanges for a user."""
|
||||
query = select(Exchange).where(Exchange.user_id == user_id)
|
||||
if order_by_desc:
|
||||
query = query.order_by(Exchange.slot_start.desc())
|
||||
else:
|
||||
query = query.order_by(Exchange.slot_start.asc())
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_upcoming_booked(self) -> list[Exchange]:
|
||||
"""Get all upcoming booked trades, sorted by slot time ascending."""
|
||||
now = datetime.now(UTC)
|
||||
query = (
|
||||
select(Exchange)
|
||||
.options(joinedload(Exchange.user))
|
||||
.where(
|
||||
and_(
|
||||
Exchange.slot_start > now,
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
)
|
||||
.order_by(Exchange.slot_start.asc())
|
||||
)
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_past_trades(
|
||||
self,
|
||||
status: ExchangeStatus | None = None,
|
||||
start_date: date | None = None,
|
||||
end_date: date | None = None,
|
||||
user_search: str | None = None,
|
||||
) -> list[Exchange]:
|
||||
"""
|
||||
Get past trades with optional filters.
|
||||
|
||||
Args:
|
||||
status: Filter by exchange status
|
||||
start_date: Filter by slot_start date (inclusive start)
|
||||
end_date: Filter by slot_start date (inclusive end)
|
||||
user_search: Search by user email (partial match, case-insensitive)
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Start with base query for past trades
|
||||
query = (
|
||||
select(Exchange)
|
||||
.options(joinedload(Exchange.user))
|
||||
.where(
|
||||
(Exchange.slot_start <= now)
|
||||
| (Exchange.status != ExchangeStatus.BOOKED)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply status filter
|
||||
if status:
|
||||
query = query.where(Exchange.status == status)
|
||||
|
||||
# Apply date range filter
|
||||
if start_date:
|
||||
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
|
||||
query = query.where(Exchange.slot_start >= start_dt)
|
||||
if end_date:
|
||||
end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
|
||||
query = query.where(Exchange.slot_start <= end_dt)
|
||||
|
||||
# Apply user search filter
|
||||
if user_search:
|
||||
query = query.join(Exchange.user).where(
|
||||
User.email.ilike(f"%{user_search}%")
|
||||
)
|
||||
|
||||
# Order by most recent first
|
||||
query = query.order_by(Exchange.slot_start.desc())
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_slot_start(
|
||||
self, slot_start: datetime, status: ExchangeStatus | None = None
|
||||
) -> Exchange | None:
|
||||
"""Get exchange by slot start time, optionally filtered by status."""
|
||||
query = select(Exchange).where(Exchange.slot_start == slot_start)
|
||||
if status:
|
||||
query = query.where(Exchange.status == status)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user_and_date_range(
|
||||
self,
|
||||
user_id: int,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
status: ExchangeStatus | None = None,
|
||||
) -> list[Exchange]:
|
||||
"""Get exchanges for a user within a date range."""
|
||||
from datetime import timedelta
|
||||
|
||||
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
|
||||
# End date should be exclusive (next day at 00:00:00)
|
||||
end_dt = datetime.combine(end_date, time.min, tzinfo=UTC) + timedelta(days=1)
|
||||
|
||||
query = select(Exchange).where(
|
||||
and_(
|
||||
Exchange.user_id == user_id,
|
||||
Exchange.slot_start >= start_dt,
|
||||
Exchange.slot_start < end_dt,
|
||||
)
|
||||
)
|
||||
if status:
|
||||
query = query.where(Exchange.status == status)
|
||||
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]:
|
||||
"""Get set of booked slot start times for a specific date."""
|
||||
from utils.date_queries import date_to_end_datetime, date_to_start_datetime
|
||||
|
||||
date_start = date_to_start_datetime(target_date)
|
||||
date_end = date_to_end_datetime(target_date)
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Exchange.slot_start).where(
|
||||
and_(
|
||||
Exchange.slot_start >= date_start,
|
||||
Exchange.slot_start <= date_end,
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
)
|
||||
)
|
||||
return {row[0] for row in result.all()}
|
||||
69
backend/repositories/invite.py
Normal file
69
backend/repositories/invite.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
"""Invite repository for database queries."""
|
||||
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import Invite, InviteStatus
|
||||
|
||||
|
||||
class InviteRepository:
|
||||
"""Repository for invite-related database queries."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_identifier(self, identifier: str) -> Invite | None:
|
||||
"""Get an invite by identifier."""
|
||||
result = await self.db.execute(
|
||||
select(Invite).where(Invite.identifier == identifier)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_id(self, invite_id: int) -> Invite | None:
|
||||
"""Get an invite by ID."""
|
||||
result = await self.db.execute(select(Invite).where(Invite.id == invite_id))
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_godfather_id(
|
||||
self, godfather_id: int, order_by_desc: bool = True
|
||||
) -> list[Invite]:
|
||||
"""Get all invites for a godfather user."""
|
||||
query = select(Invite).where(Invite.godfather_id == godfather_id)
|
||||
if order_by_desc:
|
||||
query = query.order_by(desc(Invite.created_at))
|
||||
else:
|
||||
query = query.order_by(Invite.created_at)
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count(
|
||||
self,
|
||||
status: InviteStatus | None = None,
|
||||
godfather_id: int | None = None,
|
||||
) -> int:
|
||||
"""Count invites matching filters."""
|
||||
query = select(func.count(Invite.id))
|
||||
if status:
|
||||
query = query.where(Invite.status == status)
|
||||
if godfather_id:
|
||||
query = query.where(Invite.godfather_id == godfather_id)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def list_paginated(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
status: InviteStatus | None = None,
|
||||
godfather_id: int | None = None,
|
||||
) -> list[Invite]:
|
||||
"""Get paginated list of invites."""
|
||||
offset = (page - 1) * per_page
|
||||
query = select(Invite)
|
||||
if status:
|
||||
query = query.where(Invite.status == status)
|
||||
if godfather_id:
|
||||
query = query.where(Invite.godfather_id == godfather_id)
|
||||
query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page)
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
"""Price repository for database queries."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -25,3 +27,32 @@ class PriceRepository:
|
|||
)
|
||||
result = await self.db.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_recent(self, limit: int = 20) -> list[PriceHistory]:
|
||||
"""Get the most recent price history records."""
|
||||
query = select(PriceHistory).order_by(desc(PriceHistory.timestamp)).limit(limit)
|
||||
result = await self.db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_timestamp(
|
||||
self,
|
||||
timestamp: str | datetime,
|
||||
source: str = SOURCE_BITFINEX,
|
||||
pair: str = PAIR_BTC_EUR,
|
||||
) -> PriceHistory | None:
|
||||
"""Get a price record by timestamp."""
|
||||
# Convert string timestamp to datetime if needed
|
||||
timestamp_dt: datetime
|
||||
if isinstance(timestamp, str):
|
||||
timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
else:
|
||||
timestamp_dt = timestamp
|
||||
|
||||
result = await self.db.execute(
|
||||
select(PriceHistory).where(
|
||||
PriceHistory.source == source,
|
||||
PriceHistory.pair == pair,
|
||||
PriceHistory.timestamp == timestamp_dt,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
|
|
|||
18
backend/repositories/role.py
Normal file
18
backend/repositories/role.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Role repository for database queries."""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from models import Role
|
||||
|
||||
|
||||
class RoleRepository:
|
||||
"""Repository for role-related database queries."""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_name(self, name: str) -> Role | None:
|
||||
"""Get a role by name."""
|
||||
result = await self.db.execute(select(Role).where(Role.name == name))
|
||||
return result.scalar_one_or_none()
|
||||
|
|
@ -1,22 +1,18 @@
|
|||
"""Exchange routes for Bitcoin trading."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from date_validation import validate_date_in_range
|
||||
from exceptions import BadRequestError
|
||||
from mappers import ExchangeMapper
|
||||
from models import (
|
||||
Availability,
|
||||
BitcoinTransferMethod,
|
||||
Exchange,
|
||||
ExchangeStatus,
|
||||
Permission,
|
||||
PriceHistory,
|
||||
|
|
@ -24,6 +20,7 @@ from models import (
|
|||
User,
|
||||
)
|
||||
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
|
||||
from repositories.exchange import ExchangeRepository
|
||||
from repositories.price import PriceRepository
|
||||
from schemas import (
|
||||
AdminExchangeResponse,
|
||||
|
|
@ -44,6 +41,7 @@ from shared_constants import (
|
|||
PREMIUM_PERCENTAGE,
|
||||
SLOT_DURATION_MINUTES,
|
||||
)
|
||||
from utils.enum_validation import validate_enum
|
||||
|
||||
router = APIRouter(prefix="/api/exchange", tags=["exchange"])
|
||||
|
||||
|
|
@ -190,28 +188,18 @@ async def get_available_slots(
|
|||
validate_date_in_range(date_param, context="book")
|
||||
|
||||
# Get availability for the date
|
||||
result = await db.execute(
|
||||
select(Availability).where(Availability.date == date_param)
|
||||
)
|
||||
availabilities = result.scalars().all()
|
||||
from repositories.availability import AvailabilityRepository
|
||||
from repositories.exchange import ExchangeRepository
|
||||
|
||||
availability_repo = AvailabilityRepository(db)
|
||||
availabilities = await availability_repo.get_by_date(date_param)
|
||||
|
||||
if not availabilities:
|
||||
return AvailableSlotsResponse(date=date_param, slots=[])
|
||||
|
||||
# Get already booked slots for the date
|
||||
date_start = datetime.combine(date_param, time.min, tzinfo=UTC)
|
||||
date_end = datetime.combine(date_param, time.max, tzinfo=UTC)
|
||||
|
||||
result = await db.execute(
|
||||
select(Exchange.slot_start).where(
|
||||
and_(
|
||||
Exchange.slot_start >= date_start,
|
||||
Exchange.slot_start <= date_end,
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
)
|
||||
)
|
||||
booked_starts = {row[0] for row in result.all()}
|
||||
exchange_repo = ExchangeRepository(db)
|
||||
booked_starts = await exchange_repo.get_booked_slots_for_date(date_param)
|
||||
|
||||
# Expand each availability into slots
|
||||
all_slots: list[BookableSlot] = []
|
||||
|
|
@ -247,21 +235,16 @@ async def create_exchange(
|
|||
- EUR amount is within configured limits
|
||||
"""
|
||||
# Validate direction
|
||||
try:
|
||||
direction = TradeDirection(request.direction)
|
||||
except ValueError:
|
||||
raise BadRequestError(
|
||||
f"Invalid direction: {request.direction}. Must be 'buy' or 'sell'."
|
||||
) from None
|
||||
direction: TradeDirection = validate_enum(
|
||||
TradeDirection, request.direction, "direction"
|
||||
)
|
||||
|
||||
# Validate bitcoin transfer method
|
||||
try:
|
||||
bitcoin_transfer_method = BitcoinTransferMethod(request.bitcoin_transfer_method)
|
||||
except ValueError:
|
||||
raise BadRequestError(
|
||||
f"Invalid bitcoin_transfer_method: {request.bitcoin_transfer_method}. "
|
||||
"Must be 'onchain' or 'lightning'."
|
||||
) from None
|
||||
bitcoin_transfer_method: BitcoinTransferMethod = validate_enum(
|
||||
BitcoinTransferMethod,
|
||||
request.bitcoin_transfer_method,
|
||||
"bitcoin_transfer_method",
|
||||
)
|
||||
|
||||
# Use service to create exchange (handles all validation)
|
||||
service = ExchangeService(db)
|
||||
|
|
@ -289,12 +272,8 @@ async def get_my_trades(
|
|||
current_user: User = Depends(require_permission(Permission.VIEW_OWN_EXCHANGES)),
|
||||
) -> list[ExchangeResponse]:
|
||||
"""Get the current user's exchanges, sorted by date (newest first)."""
|
||||
result = await db.execute(
|
||||
select(Exchange)
|
||||
.where(Exchange.user_id == current_user.id)
|
||||
.order_by(Exchange.slot_start.desc())
|
||||
)
|
||||
exchanges = result.scalars().all()
|
||||
exchange_repo = ExchangeRepository(db)
|
||||
exchanges = await exchange_repo.get_by_user_id(current_user.id, order_by_desc=True)
|
||||
|
||||
return [ExchangeMapper.to_response(ex, current_user.email) for ex in exchanges]
|
||||
|
||||
|
|
@ -348,19 +327,8 @@ async def get_upcoming_trades(
|
|||
_current_user: User = Depends(require_permission(Permission.VIEW_ALL_EXCHANGES)),
|
||||
) -> list[AdminExchangeResponse]:
|
||||
"""Get all upcoming booked trades, sorted by slot time ascending."""
|
||||
now = datetime.now(UTC)
|
||||
result = await db.execute(
|
||||
select(Exchange)
|
||||
.options(joinedload(Exchange.user))
|
||||
.where(
|
||||
and_(
|
||||
Exchange.slot_start > now,
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
)
|
||||
.order_by(Exchange.slot_start.asc())
|
||||
)
|
||||
exchanges = result.scalars().all()
|
||||
exchange_repo = ExchangeRepository(db)
|
||||
exchanges = await exchange_repo.get_upcoming_booked()
|
||||
|
||||
return [ExchangeMapper.to_admin_response(ex) for ex in exchanges]
|
||||
|
||||
|
|
@ -383,45 +351,19 @@ async def get_past_trades(
|
|||
- user_search: Search by user email (partial match)
|
||||
"""
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Start with base query for past trades (slot_start <= now OR not booked)
|
||||
query = (
|
||||
select(Exchange)
|
||||
.options(joinedload(Exchange.user))
|
||||
.where(
|
||||
(Exchange.slot_start <= now) | (Exchange.status != ExchangeStatus.BOOKED)
|
||||
)
|
||||
)
|
||||
|
||||
# Apply status filter
|
||||
status_enum: ExchangeStatus | None = None
|
||||
if status:
|
||||
try:
|
||||
status_enum = ExchangeStatus(status)
|
||||
query = query.where(Exchange.status == status_enum)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}",
|
||||
) from None
|
||||
status_enum = validate_enum(ExchangeStatus, status, "status")
|
||||
|
||||
# Apply date range filter
|
||||
if start_date:
|
||||
start_dt = datetime.combine(start_date, time.min, tzinfo=UTC)
|
||||
query = query.where(Exchange.slot_start >= start_dt)
|
||||
if end_date:
|
||||
end_dt = datetime.combine(end_date, time.max, tzinfo=UTC)
|
||||
query = query.where(Exchange.slot_start <= end_dt)
|
||||
|
||||
# Apply user search filter (join with User table)
|
||||
if user_search:
|
||||
query = query.join(Exchange.user).where(User.email.ilike(f"%{user_search}%"))
|
||||
|
||||
# Order by most recent first
|
||||
query = query.order_by(Exchange.slot_start.desc())
|
||||
|
||||
result = await db.execute(query)
|
||||
exchanges = result.scalars().all()
|
||||
# Use repository for query
|
||||
exchange_repo = ExchangeRepository(db)
|
||||
exchanges = await exchange_repo.get_past_trades(
|
||||
status=status_enum,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
user_search=user_search,
|
||||
)
|
||||
|
||||
return [ExchangeMapper.to_admin_response(ex) for ex in exchanges]
|
||||
|
||||
|
|
@ -487,6 +429,10 @@ async def search_users(
|
|||
Returns users whose email contains the search query (case-insensitive).
|
||||
Limited to 10 results for autocomplete purposes.
|
||||
"""
|
||||
# Note: UserRepository doesn't have search yet, but we can add it
|
||||
# For now, keeping direct query for this specific use case
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(User.email.ilike(f"%{q}%")).order_by(User.email).limit(10)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
"""Exchange service for business logic related to Bitcoin trading."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, date, datetime, time, timedelta
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
|
@ -15,7 +14,6 @@ from exceptions import (
|
|||
ServiceUnavailableError,
|
||||
)
|
||||
from models import (
|
||||
Availability,
|
||||
BitcoinTransferMethod,
|
||||
Exchange,
|
||||
ExchangeStatus,
|
||||
|
|
@ -23,6 +21,7 @@ from models import (
|
|||
TradeDirection,
|
||||
User,
|
||||
)
|
||||
from repositories.exchange import ExchangeRepository
|
||||
from repositories.price import PriceRepository
|
||||
from shared_constants import (
|
||||
EUR_TRADE_INCREMENT,
|
||||
|
|
@ -44,6 +43,7 @@ class ExchangeService:
|
|||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.price_repo = PriceRepository(db)
|
||||
self.exchange_repo = ExchangeRepository(db)
|
||||
|
||||
def apply_premium_for_direction(
|
||||
self,
|
||||
|
|
@ -107,20 +107,21 @@ class ExchangeService:
|
|||
self, slot_start: datetime, slot_date: date
|
||||
) -> None:
|
||||
"""Verify slot falls within availability."""
|
||||
from repositories.availability import AvailabilityRepository
|
||||
|
||||
slot_start_time = slot_start.time()
|
||||
slot_end_dt = slot_start + timedelta(minutes=SLOT_DURATION_MINUTES)
|
||||
slot_end_time = slot_end_dt.time()
|
||||
|
||||
result = await self.db.execute(
|
||||
select(Availability).where(
|
||||
and_(
|
||||
Availability.date == slot_date,
|
||||
Availability.start_time <= slot_start_time,
|
||||
Availability.end_time >= slot_end_time,
|
||||
)
|
||||
)
|
||||
)
|
||||
matching_availability = result.scalar_one_or_none()
|
||||
availability_repo = AvailabilityRepository(self.db)
|
||||
availabilities = await availability_repo.get_by_date(slot_date)
|
||||
|
||||
# Check if any availability block contains this slot
|
||||
matching_availability = None
|
||||
for avail in availabilities:
|
||||
if avail.start_time <= slot_start_time and avail.end_time >= slot_end_time:
|
||||
matching_availability = avail
|
||||
break
|
||||
|
||||
if not matching_availability:
|
||||
slot_str = slot_start.strftime("%Y-%m-%d %H:%M")
|
||||
|
|
@ -171,29 +172,19 @@ class ExchangeService:
|
|||
self, user: User, slot_date: date
|
||||
) -> Exchange | None:
|
||||
"""Check if user already has a trade on this date."""
|
||||
existing_trade_query = select(Exchange).where(
|
||||
and_(
|
||||
Exchange.user_id == user.id,
|
||||
Exchange.slot_start
|
||||
>= datetime.combine(slot_date, time.min, tzinfo=UTC),
|
||||
Exchange.slot_start
|
||||
< datetime.combine(slot_date, time.max, tzinfo=UTC) + timedelta(days=1),
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
exchanges = await self.exchange_repo.get_by_user_and_date_range(
|
||||
user_id=user.id,
|
||||
start_date=slot_date,
|
||||
end_date=slot_date,
|
||||
status=ExchangeStatus.BOOKED,
|
||||
)
|
||||
result = await self.db.execute(existing_trade_query)
|
||||
return result.scalar_one_or_none()
|
||||
return exchanges[0] if exchanges else None
|
||||
|
||||
async def check_slot_already_booked(self, slot_start: datetime) -> Exchange | None:
|
||||
"""Check if slot is already booked (only consider BOOKED status)."""
|
||||
slot_booked_query = select(Exchange).where(
|
||||
and_(
|
||||
Exchange.slot_start == slot_start,
|
||||
Exchange.status == ExchangeStatus.BOOKED,
|
||||
)
|
||||
return await self.exchange_repo.get_by_slot_start(
|
||||
slot_start, status=ExchangeStatus.BOOKED
|
||||
)
|
||||
result = await self.db.execute(slot_booked_query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_exchange(
|
||||
self,
|
||||
|
|
@ -297,9 +288,7 @@ class ExchangeService:
|
|||
NotFoundError: If exchange not found or user doesn't own it
|
||||
(for security, returns 404)
|
||||
"""
|
||||
query = select(Exchange).where(Exchange.public_id == public_id)
|
||||
result = await self.db.execute(query)
|
||||
exchange = result.scalar_one_or_none()
|
||||
exchange = await self.exchange_repo.get_by_public_id(public_id)
|
||||
|
||||
if not exchange:
|
||||
raise NotFoundError("Trade")
|
||||
|
|
|
|||
1
backend/utils/__init__.py
Normal file
1
backend/utils/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Utility modules for common functionality."""
|
||||
27
backend/utils/date_queries.py
Normal file
27
backend/utils/date_queries.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Utilities for date/time query operations."""
|
||||
|
||||
from datetime import UTC, date, datetime, time
|
||||
|
||||
|
||||
def date_to_start_datetime(d: date) -> datetime:
|
||||
"""Convert a date to datetime at start of day (00:00:00) in UTC."""
|
||||
return datetime.combine(d, time.min, tzinfo=UTC)
|
||||
|
||||
|
||||
def date_to_end_datetime(d: date) -> datetime:
|
||||
"""Convert a date to datetime at end of day (23:59:59.999999) in UTC."""
|
||||
return datetime.combine(d, time.max, tzinfo=UTC)
|
||||
|
||||
|
||||
def date_range_to_datetime_range(
|
||||
start_date: date, end_date: date
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""
|
||||
Convert a date range to datetime range.
|
||||
|
||||
Returns:
|
||||
Tuple of (start_datetime, end_datetime) where:
|
||||
- start_datetime is start_date at 00:00:00 UTC
|
||||
- end_datetime is end_date at 23:59:59.999999 UTC
|
||||
"""
|
||||
return date_to_start_datetime(start_date), date_to_end_datetime(end_date)
|
||||
32
backend/utils/enum_validation.py
Normal file
32
backend/utils/enum_validation.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""Utilities for validating enum values from strings."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import TypeVar
|
||||
|
||||
from exceptions import BadRequestError
|
||||
|
||||
T = TypeVar("T", bound=Enum)
|
||||
|
||||
|
||||
def validate_enum(enum_class: type[T], value: str, field_name: str = "value") -> T:
|
||||
"""
|
||||
Validate and convert string to enum.
|
||||
|
||||
Args:
|
||||
enum_class: The enum class to validate against
|
||||
value: The string value to validate
|
||||
field_name: Name of the field for error messages
|
||||
|
||||
Returns:
|
||||
The validated enum value
|
||||
|
||||
Raises:
|
||||
BadRequestError: If the value is not a valid enum member
|
||||
"""
|
||||
try:
|
||||
return enum_class(value)
|
||||
except ValueError:
|
||||
valid_values = ", ".join(e.value for e in enum_class)
|
||||
raise BadRequestError(
|
||||
f"Invalid {field_name}: {value}. Must be one of: {valid_values}"
|
||||
) from None
|
||||
Loading…
Add table
Add a link
Reference in a new issue