"""Invite repository for database queries."""

from sqlalchemy import desc, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from models import Invite, InviteStatus


class InviteRepository:
    """Repository for invite-related database queries.

    Wraps an ``AsyncSession`` and exposes read-only query helpers for
    ``Invite`` rows. The session is owned by the caller; this class never
    commits or closes it.
    """

    def __init__(self, db: AsyncSession):
        self.db = db

    @staticmethod
    def _apply_filters(
        query,
        status: InviteStatus | None,
        godfather_id: int | None,
    ):
        """Return ``query`` narrowed by the optional status/godfather filters.

        Shared by ``count`` and ``list_paginated`` so both always filter the
        same way. ``None`` means "don't filter"; every other value is applied,
        including falsy ones such as an id of ``0``.
        """
        if status is not None:
            query = query.where(Invite.status == status)
        if godfather_id is not None:
            query = query.where(Invite.godfather_id == godfather_id)
        return query

    async def get_by_identifier(self, identifier: str) -> Invite | None:
        """Get an invite by identifier, eagerly loading godfather and used_by.

        Returns ``None`` when no invite matches.
        """
        result = await self.db.execute(
            select(Invite)
            .options(joinedload(Invite.godfather), joinedload(Invite.used_by))
            .where(Invite.identifier == identifier)
        )
        return result.scalar_one_or_none()

    async def get_by_id(self, invite_id: int) -> Invite | None:
        """Get an invite by primary key, eagerly loading godfather and used_by.

        Returns ``None`` when no invite matches.
        """
        result = await self.db.execute(
            select(Invite)
            .options(joinedload(Invite.godfather), joinedload(Invite.used_by))
            .where(Invite.id == invite_id)
        )
        return result.scalar_one_or_none()

    async def get_by_godfather_id(
        self, godfather_id: int, order_by_desc: bool = True
    ) -> list[Invite]:
        """Get all invites created by a godfather user.

        Only the ``used_by`` relationship is eagerly loaded; ``godfather`` is
        the user being queried and is not joined.

        Args:
            godfather_id: Id of the user who issued the invites.
            order_by_desc: Newest first when True, oldest first otherwise.

        Returns:
            The matching invites, ordered by ``created_at`` then ``id``.
        """
        query = (
            select(Invite)
            .options(joinedload(Invite.used_by))
            .where(Invite.godfather_id == godfather_id)
        )
        # ``id`` breaks ties between invites sharing a ``created_at`` value so
        # the ordering is fully deterministic.
        if order_by_desc:
            query = query.order_by(desc(Invite.created_at), desc(Invite.id))
        else:
            query = query.order_by(Invite.created_at, Invite.id)
        result = await self.db.execute(query)
        return list(result.scalars().all())

    async def count(
        self,
        status: InviteStatus | None = None,
        godfather_id: int | None = None,
    ) -> int:
        """Count invites matching the optional filters.

        Args:
            status: Only count invites with this status; ``None`` = any.
            godfather_id: Only count invites from this user; ``None`` = any.
        """
        query = self._apply_filters(
            select(func.count(Invite.id)), status, godfather_id
        )
        result = await self.db.execute(query)
        return result.scalar() or 0

    async def list_paginated(
        self,
        page: int,
        per_page: int,
        status: InviteStatus | None = None,
        godfather_id: int | None = None,
    ) -> list[Invite]:
        """Get one page of invites, eagerly loading godfather and used_by.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            per_page: Maximum number of invites to return.
            status: Only include invites with this status; ``None`` = any.
            godfather_id: Only include invites from this user; ``None`` = any.

        Returns:
            Invites ordered newest first (``created_at`` desc, then ``id``
            desc).
        """
        # Clamp so a page of 0 or below never produces a negative OFFSET,
        # which some databases reject and others silently treat as 0.
        offset = max(page - 1, 0) * per_page
        query = select(Invite).options(
            joinedload(Invite.godfather), joinedload(Invite.used_by)
        )
        query = self._apply_filters(query, status, godfather_id)
        # ``id`` is a tiebreaker: without a total order, rows sharing a
        # ``created_at`` value can repeat or be skipped across pages.
        query = (
            query.order_by(desc(Invite.created_at), desc(Invite.id))
            .offset(offset)
            .limit(per_page)
        )
        result = await self.db.execute(query)
        return list(result.scalars().all())