From 168b67aceea00e98cc881c7c4d7305a38772e335 Mon Sep 17 00:00:00 2001 From: counterweight Date: Thu, 25 Dec 2025 18:27:59 +0100 Subject: [PATCH] refactors --- backend/mappers.py | 19 +++- backend/repositories/__init__.py | 13 ++- backend/repositories/availability.py | 41 +++++++ backend/repositories/exchange.py | 163 +++++++++++++++++++++++++++ backend/repositories/invite.py | 69 ++++++++++++ backend/repositories/price.py | 31 +++++ backend/repositories/role.py | 18 +++ backend/routes/exchange.py | 126 ++++++--------------- backend/services/exchange.py | 57 ++++------ backend/utils/__init__.py | 1 + backend/utils/date_queries.py | 27 +++++ backend/utils/enum_validation.py | 32 ++++++ 12 files changed, 471 insertions(+), 126 deletions(-) create mode 100644 backend/repositories/availability.py create mode 100644 backend/repositories/exchange.py create mode 100644 backend/repositories/invite.py create mode 100644 backend/repositories/role.py create mode 100644 backend/utils/__init__.py create mode 100644 backend/utils/date_queries.py create mode 100644 backend/utils/enum_validation.py diff --git a/backend/mappers.py b/backend/mappers.py index 8f2ad4a..5ebd8fa 100644 --- a/backend/mappers.py +++ b/backend/mappers.py @@ -1,11 +1,12 @@ """Response mappers for converting models to API response schemas.""" -from models import Exchange, Invite +from models import Exchange, Invite, PriceHistory from schemas import ( AdminExchangeResponse, ExchangeResponse, ExchangeUserContact, InviteResponse, + PriceHistoryResponse, ) @@ -89,3 +90,19 @@ class InviteMapper: spent_at=invite.spent_at, revoked_at=invite.revoked_at, ) + + +class PriceHistoryMapper: + """Mapper for PriceHistory model to response schemas.""" + + @staticmethod + def to_response(record: PriceHistory) -> PriceHistoryResponse: + """Convert a PriceHistory model to PriceHistoryResponse schema.""" + return PriceHistoryResponse( + id=record.id, + source=record.source, + pair=record.pair, + price=record.price, + timestamp=record.timestamp, + created_at=record.created_at, + ) diff --git a/backend/repositories/__init__.py b/backend/repositories/__init__.py index aff0836..805dee7 100644 --- a/backend/repositories/__init__.py +++ b/backend/repositories/__init__.py @@ -1,6 +1,17 @@ """Repository layer for database queries.""" +from repositories.availability import AvailabilityRepository +from repositories.exchange import ExchangeRepository +from repositories.invite import InviteRepository from repositories.price import PriceRepository +from repositories.role import RoleRepository from repositories.user import UserRepository -__all__ = ["PriceRepository", "UserRepository"] +__all__ = [ + "AvailabilityRepository", + "ExchangeRepository", + "InviteRepository", + "PriceRepository", + "RoleRepository", + "UserRepository", +] diff --git a/backend/repositories/availability.py b/backend/repositories/availability.py new file mode 100644 index 0000000..8718cce --- /dev/null +++ b/backend/repositories/availability.py @@ -0,0 +1,41 @@ +"""Availability repository for database queries.""" + +from datetime import date + +from sqlalchemy import and_, delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from models import Availability + + +class AvailabilityRepository: + """Repository for availability-related database queries.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def get_by_date_range( + self, from_date: date, to_date: date + ) -> list[Availability]: + """Get availability slots for a date range.""" + result = await self.db.execute( + select(Availability) + .where(and_(Availability.date >= from_date, Availability.date <= to_date)) + .order_by(Availability.date, Availability.start_time) + ) + return list(result.scalars().all()) + + async def get_by_date(self, target_date: date) -> list[Availability]: + """Get availability slots for a specific date.""" + result = await self.db.execute( + select(Availability) + .where(Availability.date == target_date) + .order_by(Availability.start_time) + ) + return list(result.scalars().all()) + + async def delete_by_date(self, target_date: date) -> None: + """Delete all availability for a specific date.""" + await self.db.execute( + delete(Availability).where(Availability.date == target_date) + ) diff --git a/backend/repositories/exchange.py b/backend/repositories/exchange.py new file mode 100644 index 0000000..eb28494 --- /dev/null +++ b/backend/repositories/exchange.py @@ -0,0 +1,163 @@ +"""Exchange repository for database queries.""" + +import uuid +from datetime import UTC, date, datetime, time + +from sqlalchemy import and_, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from models import Exchange, ExchangeStatus, User + + +class ExchangeRepository: + """Repository for exchange-related database queries.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def get_by_public_id( + self, public_id: uuid.UUID, load_user: bool = False + ) -> Exchange | None: + """Get an exchange by public ID.""" + query = select(Exchange).where(Exchange.public_id == public_id) + if load_user: + query = query.options(joinedload(Exchange.user)) + result = await self.db.execute(query) + return result.scalar_one_or_none() + + async def get_by_user_id( + self, user_id: int, order_by_desc: bool = True + ) -> list[Exchange]: + """Get all exchanges for a user.""" + query = select(Exchange).where(Exchange.user_id == user_id) + if order_by_desc: + query = query.order_by(Exchange.slot_start.desc()) + else: + query = query.order_by(Exchange.slot_start.asc()) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def get_upcoming_booked(self) -> list[Exchange]: + """Get all upcoming booked trades, sorted by slot time ascending.""" + now = datetime.now(UTC) + query = ( + select(Exchange) + .options(joinedload(Exchange.user)) + .where( + and_( + Exchange.slot_start > now, + Exchange.status == ExchangeStatus.BOOKED, + ) + ) + .order_by(Exchange.slot_start.asc()) + ) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def get_past_trades( + self, + status: ExchangeStatus | None = None, + start_date: date | None = None, + end_date: date | None = None, + user_search: str | None = None, + ) -> list[Exchange]: + """ + Get past trades with optional filters. + + Args: + status: Filter by exchange status + start_date: Filter by slot_start date (inclusive start) + end_date: Filter by slot_start date (inclusive end) + user_search: Search by user email (partial match, case-insensitive) + """ + now = datetime.now(UTC) + + # Start with base query for past trades + query = ( + select(Exchange) + .options(joinedload(Exchange.user)) + .where( + (Exchange.slot_start <= now) + | (Exchange.status != ExchangeStatus.BOOKED) + ) + ) + + # Apply status filter + if status: + query = query.where(Exchange.status == status) + + # Apply date range filter + if start_date: + start_dt = datetime.combine(start_date, time.min, tzinfo=UTC) + query = query.where(Exchange.slot_start >= start_dt) + if end_date: + end_dt = datetime.combine(end_date, time.max, tzinfo=UTC) + query = query.where(Exchange.slot_start <= end_dt) + + # Apply user search filter + if user_search: + query = query.join(Exchange.user).where( + User.email.ilike(f"%{user_search}%") + ) + + # Order by most recent first + query = query.order_by(Exchange.slot_start.desc()) + + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def get_by_slot_start( + self, slot_start: datetime, status: ExchangeStatus | None = None + ) -> Exchange | None: + """Get exchange by slot start time, optionally filtered by status.""" + query = select(Exchange).where(Exchange.slot_start == slot_start) + if status: + query = query.where(Exchange.status == status) + result = await self.db.execute(query) + return result.scalar_one_or_none() + + async def get_by_user_and_date_range( + self, + user_id: int, + start_date: date, + end_date: date, + status: ExchangeStatus | None = None, + ) -> list[Exchange]: + """Get exchanges for a user within a date range.""" + from datetime import timedelta + + start_dt = datetime.combine(start_date, time.min, tzinfo=UTC) + # End date should be exclusive (next day at 00:00:00) + end_dt = datetime.combine(end_date, time.min, tzinfo=UTC) + timedelta(days=1) + + query = select(Exchange).where( + and_( + Exchange.user_id == user_id, + Exchange.slot_start >= start_dt, + Exchange.slot_start < end_dt, + ) + ) + if status: + query = query.where(Exchange.status == status) + + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def get_booked_slots_for_date(self, target_date: date) -> set[datetime]: + """Get set of booked slot start times for a specific date.""" + from utils.date_queries import date_to_end_datetime, date_to_start_datetime + + date_start = date_to_start_datetime(target_date) + date_end = date_to_end_datetime(target_date) + + result = await self.db.execute( + select(Exchange.slot_start).where( + and_( + Exchange.slot_start >= date_start, + Exchange.slot_start <= date_end, + Exchange.status == ExchangeStatus.BOOKED, + ) + ) + ) + return {row[0] for row in result.all()} diff --git a/backend/repositories/invite.py b/backend/repositories/invite.py new file mode 100644 index 0000000..8790439 --- /dev/null +++ b/backend/repositories/invite.py @@ -0,0 +1,69 @@ +"""Invite repository for database queries.""" + +from sqlalchemy import desc, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from models import Invite, InviteStatus + + +class InviteRepository: + """Repository for invite-related database queries.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def get_by_identifier(self, identifier: str) -> Invite | None: + """Get an invite by identifier.""" + result = await self.db.execute( + select(Invite).where(Invite.identifier == identifier) + ) + return result.scalar_one_or_none() + + async def get_by_id(self, invite_id: int) -> Invite | None: + """Get an invite by ID.""" + result = await self.db.execute(select(Invite).where(Invite.id == invite_id)) + return result.scalar_one_or_none() + + async def get_by_godfather_id( + self, godfather_id: int, order_by_desc: bool = True + ) -> list[Invite]: + """Get all invites for a godfather user.""" + query = select(Invite).where(Invite.godfather_id == godfather_id) + if order_by_desc: + query = query.order_by(desc(Invite.created_at)) + else: + query = query.order_by(Invite.created_at) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def count( + self, + status: InviteStatus | None = None, + godfather_id: int | None = None, + ) -> int: + """Count invites matching filters.""" + query = select(func.count(Invite.id)) + if status: + query = query.where(Invite.status == status) + if godfather_id: + query = query.where(Invite.godfather_id == godfather_id) + result = await self.db.execute(query) + return result.scalar() or 0 + + async def list_paginated( + self, + page: int, + per_page: int, + status: InviteStatus | None = None, + godfather_id: int | None = None, + ) -> list[Invite]: + """Get paginated list of invites.""" + offset = (page - 1) * per_page + query = select(Invite) + if status: + query = query.where(Invite.status == status) + if godfather_id: + query = query.where(Invite.godfather_id == godfather_id) + query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page) + result = await self.db.execute(query) + return list(result.scalars().all()) diff --git a/backend/repositories/price.py b/backend/repositories/price.py index b8322da..40b9b16 100644 --- a/backend/repositories/price.py +++ b/backend/repositories/price.py @@ -1,5 +1,7 @@ """Price repository for database queries.""" +from datetime import datetime + from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession @@ -25,3 +27,32 @@ class PriceRepository: ) result = await self.db.execute(query) return result.scalar_one_or_none() + + async def get_recent(self, limit: int = 20) -> list[PriceHistory]: + """Get the most recent price history records.""" + query = select(PriceHistory).order_by(desc(PriceHistory.timestamp)).limit(limit) + result = await self.db.execute(query) + return list(result.scalars().all()) + + async def get_by_timestamp( + self, + timestamp: str | datetime, + source: str = SOURCE_BITFINEX, + pair: str = PAIR_BTC_EUR, + ) -> PriceHistory | None: + """Get a price record by timestamp.""" + # Convert string timestamp to datetime if needed + timestamp_dt: datetime + if isinstance(timestamp, str): + timestamp_dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + else: + timestamp_dt = timestamp + + result = await self.db.execute( + select(PriceHistory).where( + PriceHistory.source == source, + PriceHistory.pair == pair, + PriceHistory.timestamp == timestamp_dt, + ) + ) + return result.scalar_one_or_none() diff --git a/backend/repositories/role.py b/backend/repositories/role.py new file mode 100644 index 0000000..c56c734 --- /dev/null +++ b/backend/repositories/role.py @@ -0,0 +1,18 @@ +"""Role repository for database queries.""" + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from models import Role + + +class RoleRepository: + """Repository for role-related database queries.""" + + def __init__(self, db: AsyncSession): + self.db = db + + async def get_by_name(self, name: str) -> Role | None: + """Get a role by name.""" + result = await self.db.execute(select(Role).where(Role.name == name)) + return result.scalar_one_or_none() diff --git a/backend/routes/exchange.py b/backend/routes/exchange.py index 96879af..21690ea 100644 --- a/backend/routes/exchange.py +++ b/backend/routes/exchange.py @@ -1,22 +1,18 @@ """Exchange routes for Bitcoin trading.""" import uuid -from datetime import UTC, date, datetime, time, timedelta +from datetime import UTC, date, datetime, timedelta from fastapi import APIRouter, Depends, HTTPException, Query, status -from sqlalchemy import and_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload from auth import require_permission from database import get_db from date_validation import validate_date_in_range -from exceptions import BadRequestError from mappers import ExchangeMapper from models import ( Availability, BitcoinTransferMethod, - Exchange, ExchangeStatus, Permission, PriceHistory, @@ -24,6 +20,7 @@ from models import ( User, ) from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price +from repositories.exchange import ExchangeRepository from repositories.price import PriceRepository from schemas import ( AdminExchangeResponse, @@ -44,6 +41,7 @@ from shared_constants import ( PREMIUM_PERCENTAGE, SLOT_DURATION_MINUTES, ) +from utils.enum_validation import validate_enum router = APIRouter(prefix="/api/exchange", tags=["exchange"]) @@ -190,28 +188,18 @@ async def get_available_slots( validate_date_in_range(date_param, context="book") # Get availability for the date - result = await db.execute( - select(Availability).where(Availability.date == date_param) - ) - availabilities = result.scalars().all() + from repositories.availability import AvailabilityRepository + from repositories.exchange import ExchangeRepository + + availability_repo = AvailabilityRepository(db) + availabilities = await availability_repo.get_by_date(date_param) if not availabilities: return AvailableSlotsResponse(date=date_param, slots=[]) # Get already booked slots for the date - date_start = datetime.combine(date_param, time.min, tzinfo=UTC) - date_end = datetime.combine(date_param, time.max, tzinfo=UTC) - - result = await db.execute( - select(Exchange.slot_start).where( - and_( - Exchange.slot_start >= date_start, - Exchange.slot_start <= date_end, - Exchange.status == ExchangeStatus.BOOKED, - ) - ) - ) - booked_starts = {row[0] for row in result.all()} + exchange_repo = ExchangeRepository(db) + booked_starts = await exchange_repo.get_booked_slots_for_date(date_param) # Expand each availability into slots all_slots: list[BookableSlot] = [] @@ -247,21 +235,16 @@ async def create_exchange( - EUR amount is within configured limits """ # Validate direction - try: - direction = TradeDirection(request.direction) - except ValueError: - raise BadRequestError( - f"Invalid direction: {request.direction}. Must be 'buy' or 'sell'." - ) from None + direction: TradeDirection = validate_enum( + TradeDirection, request.direction, "direction" + ) # Validate bitcoin transfer method - try: - bitcoin_transfer_method = BitcoinTransferMethod(request.bitcoin_transfer_method) - except ValueError: - raise BadRequestError( - f"Invalid bitcoin_transfer_method: {request.bitcoin_transfer_method}. " - "Must be 'onchain' or 'lightning'." - ) from None + bitcoin_transfer_method: BitcoinTransferMethod = validate_enum( + BitcoinTransferMethod, + request.bitcoin_transfer_method, + "bitcoin_transfer_method", + ) # Use service to create exchange (handles all validation) service = ExchangeService(db) @@ -289,12 +272,8 @@ async def get_my_trades( current_user: User = Depends(require_permission(Permission.VIEW_OWN_EXCHANGES)), ) -> list[ExchangeResponse]: """Get the current user's exchanges, sorted by date (newest first).""" - result = await db.execute( - select(Exchange) - .where(Exchange.user_id == current_user.id) - .order_by(Exchange.slot_start.desc()) - ) - exchanges = result.scalars().all() + exchange_repo = ExchangeRepository(db) + exchanges = await exchange_repo.get_by_user_id(current_user.id, order_by_desc=True) return [ExchangeMapper.to_response(ex, current_user.email) for ex in exchanges] @@ -348,19 +327,8 @@ async def get_upcoming_trades( _current_user: User = Depends(require_permission(Permission.VIEW_ALL_EXCHANGES)), ) -> list[AdminExchangeResponse]: """Get all upcoming booked trades, sorted by slot time ascending.""" - now = datetime.now(UTC) - result = await db.execute( - select(Exchange) - .options(joinedload(Exchange.user)) - .where( - and_( - Exchange.slot_start > now, - Exchange.status == ExchangeStatus.BOOKED, - ) - ) - .order_by(Exchange.slot_start.asc()) - ) - exchanges = result.scalars().all() + exchange_repo = ExchangeRepository(db) + exchanges = await exchange_repo.get_upcoming_booked() return [ExchangeMapper.to_admin_response(ex) for ex in exchanges] @@ -383,45 +351,19 @@ async def get_past_trades( - user_search: Search by user email (partial match) """ - now = datetime.now(UTC) - - # Start with base query for past trades (slot_start <= now OR not booked) - query = ( - select(Exchange) - .options(joinedload(Exchange.user)) - .where( - (Exchange.slot_start <= now) | (Exchange.status != ExchangeStatus.BOOKED) - ) - ) - # Apply status filter + status_enum: ExchangeStatus | None = None if status: - try: - status_enum = ExchangeStatus(status) - query = query.where(Exchange.status == status_enum) - except ValueError: - raise HTTPException( - status_code=400, - detail=f"Invalid status: {status}", - ) from None + status_enum = validate_enum(ExchangeStatus, status, "status") - # Apply date range filter - if start_date: - start_dt = datetime.combine(start_date, time.min, tzinfo=UTC) - query = query.where(Exchange.slot_start >= start_dt) - if end_date: - end_dt = datetime.combine(end_date, time.max, tzinfo=UTC) - query = query.where(Exchange.slot_start <= end_dt) - - # Apply user search filter (join with User table) - if user_search: - query = query.join(Exchange.user).where(User.email.ilike(f"%{user_search}%")) - - # Order by most recent first - query = query.order_by(Exchange.slot_start.desc()) - - result = await db.execute(query) - exchanges = result.scalars().all() + # Use repository for query + exchange_repo = ExchangeRepository(db) + exchanges = await exchange_repo.get_past_trades( + status=status_enum, + start_date=start_date, + end_date=end_date, + user_search=user_search, + ) return [ExchangeMapper.to_admin_response(ex) for ex in exchanges] @@ -487,6 +429,10 @@ async def search_users( Returns users whose email contains the search query (case-insensitive). Limited to 10 results for autocomplete purposes. """ + # Note: UserRepository doesn't have search yet, but we can add it + # For now, keeping direct query for this specific use case + from sqlalchemy import select + result = await db.execute( select(User).where(User.email.ilike(f"%{q}%")).order_by(User.email).limit(10) ) diff --git a/backend/services/exchange.py b/backend/services/exchange.py index cf1f641..f21b03d 100644 --- a/backend/services/exchange.py +++ b/backend/services/exchange.py @@ -1,9 +1,8 @@ """Exchange service for business logic related to Bitcoin trading.""" import uuid -from datetime import UTC, date, datetime, time, timedelta +from datetime import UTC, date, datetime, timedelta -from sqlalchemy import and_, select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -15,7 +14,6 @@ from exceptions import ( ServiceUnavailableError, ) from models import ( - Availability, BitcoinTransferMethod, Exchange, ExchangeStatus, @@ -23,6 +21,7 @@ from models import ( TradeDirection, User, ) +from repositories.exchange import ExchangeRepository from repositories.price import PriceRepository from shared_constants import ( EUR_TRADE_INCREMENT, @@ -44,6 +43,7 @@ class ExchangeService: def __init__(self, db: AsyncSession): self.db = db self.price_repo = PriceRepository(db) + self.exchange_repo = ExchangeRepository(db) def apply_premium_for_direction( self, @@ -107,20 +107,21 @@ class ExchangeService: self, slot_start: datetime, slot_date: date ) -> None: """Verify slot falls within availability.""" + from repositories.availability import AvailabilityRepository + slot_start_time = slot_start.time() slot_end_dt = slot_start + timedelta(minutes=SLOT_DURATION_MINUTES) slot_end_time = slot_end_dt.time() - result = await self.db.execute( - select(Availability).where( - and_( - Availability.date == slot_date, - Availability.start_time <= slot_start_time, - Availability.end_time >= slot_end_time, - ) - ) - ) - matching_availability = result.scalar_one_or_none() + availability_repo = AvailabilityRepository(self.db) + availabilities = await availability_repo.get_by_date(slot_date) + + # Check if any availability block contains this slot + matching_availability = None + for avail in availabilities: + if avail.start_time <= slot_start_time and avail.end_time >= slot_end_time: + matching_availability = avail + break if not matching_availability: slot_str = slot_start.strftime("%Y-%m-%d %H:%M") @@ -171,29 +172,19 @@ class ExchangeService: self, user: User, slot_date: date ) -> Exchange | None: """Check if user already has a trade on this date.""" - existing_trade_query = select(Exchange).where( - and_( - Exchange.user_id == user.id, - Exchange.slot_start - >= datetime.combine(slot_date, time.min, tzinfo=UTC), - Exchange.slot_start - < datetime.combine(slot_date, time.max, tzinfo=UTC) + timedelta(days=1), - Exchange.status == ExchangeStatus.BOOKED, - ) + exchanges = await self.exchange_repo.get_by_user_and_date_range( + user_id=user.id, + start_date=slot_date, + end_date=slot_date, + status=ExchangeStatus.BOOKED, ) - result = await self.db.execute(existing_trade_query) - return result.scalar_one_or_none() + return exchanges[0] if exchanges else None async def check_slot_already_booked(self, slot_start: datetime) -> Exchange | None: """Check if slot is already booked (only consider BOOKED status).""" - slot_booked_query = select(Exchange).where( - and_( - Exchange.slot_start == slot_start, - Exchange.status == ExchangeStatus.BOOKED, - ) + return await self.exchange_repo.get_by_slot_start( + slot_start, status=ExchangeStatus.BOOKED ) - result = await self.db.execute(slot_booked_query) - return result.scalar_one_or_none() async def create_exchange( self, @@ -297,9 +288,7 @@ class ExchangeService: NotFoundError: If exchange not found or user doesn't own it (for security, returns 404) """ - query = select(Exchange).where(Exchange.public_id == public_id) - result = await self.db.execute(query) - exchange = result.scalar_one_or_none() + exchange = await self.exchange_repo.get_by_public_id(public_id) if not exchange: raise NotFoundError("Trade") diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 0000000..a245cc5 --- /dev/null +++ b/backend/utils/__init__.py @@ -0,0 +1 @@ +"""Utility modules for common functionality.""" diff --git a/backend/utils/date_queries.py b/backend/utils/date_queries.py new file mode 100644 index 0000000..71f5051 --- /dev/null +++ b/backend/utils/date_queries.py @@ -0,0 +1,27 @@ +"""Utilities for date/time query operations.""" + +from datetime import UTC, date, datetime, time + + +def date_to_start_datetime(d: date) -> datetime: + """Convert a date to datetime at start of day (00:00:00) in UTC.""" + return datetime.combine(d, time.min, tzinfo=UTC) + + +def date_to_end_datetime(d: date) -> datetime: + """Convert a date to datetime at end of day (23:59:59.999999) in UTC.""" + return datetime.combine(d, time.max, tzinfo=UTC) + + +def date_range_to_datetime_range( + start_date: date, end_date: date +) -> tuple[datetime, datetime]: + """ + Convert a date range to datetime range. + + Returns: + Tuple of (start_datetime, end_datetime) where: + - start_datetime is start_date at 00:00:00 UTC + - end_datetime is end_date at 23:59:59.999999 UTC + """ + return date_to_start_datetime(start_date), date_to_end_datetime(end_date) diff --git a/backend/utils/enum_validation.py b/backend/utils/enum_validation.py new file mode 100644 index 0000000..2b4b78b --- /dev/null +++ b/backend/utils/enum_validation.py @@ -0,0 +1,32 @@ +"""Utilities for validating enum values from strings.""" + +from enum import Enum +from typing import TypeVar + +from exceptions import BadRequestError + +T = TypeVar("T", bound=Enum) + + +def validate_enum(enum_class: type[T], value: str, field_name: str = "value") -> T: + """ + Validate and convert string to enum. + + Args: + enum_class: The enum class to validate against + value: The string value to validate + field_name: Name of the field for error messages + + Returns: + The validated enum value + + Raises: + BadRequestError: If the value is not a valid enum member + """ + try: + return enum_class(value) + except ValueError: + valid_values = ", ".join(e.value for e in enum_class) + raise BadRequestError( + f"Invalid {field_name}: {value}. Must be one of: {valid_values}" + ) from None