"""Invite repository for database queries.""" from sqlalchemy import desc, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from models import Invite, InviteStatus class InviteRepository: """Repository for invite-related database queries.""" def __init__(self, db: AsyncSession): self.db = db async def get_by_identifier(self, identifier: str) -> Invite | None: """Get an invite by identifier, eagerly loading relationships.""" result = await self.db.execute( select(Invite) .options(joinedload(Invite.godfather), joinedload(Invite.used_by)) .where(Invite.identifier == identifier) ) return result.scalar_one_or_none() async def get_by_id(self, invite_id: int) -> Invite | None: """Get an invite by ID, eagerly loading relationships.""" result = await self.db.execute( select(Invite) .options(joinedload(Invite.godfather), joinedload(Invite.used_by)) .where(Invite.id == invite_id) ) return result.scalar_one_or_none() async def get_by_godfather_id( self, godfather_id: int, order_by_desc: bool = True ) -> list[Invite]: """Get all invites for a godfather user, eagerly loading relationships.""" query = ( select(Invite) .options(joinedload(Invite.used_by)) .where(Invite.godfather_id == godfather_id) ) if order_by_desc: query = query.order_by(desc(Invite.created_at)) else: query = query.order_by(Invite.created_at) result = await self.db.execute(query) return list(result.scalars().all()) async def count( self, status: InviteStatus | None = None, godfather_id: int | None = None, ) -> int: """Count invites matching filters.""" query = select(func.count(Invite.id)) if status: query = query.where(Invite.status == status) if godfather_id: query = query.where(Invite.godfather_id == godfather_id) result = await self.db.execute(query) return result.scalar() or 0 async def list_paginated( self, page: int, per_page: int, status: InviteStatus | None = None, godfather_id: int | None = None, ) -> list[Invite]: """Get paginated list of invites, eagerly loading relationships.""" offset = (page - 1) * per_page query = select(Invite).options( joinedload(Invite.godfather), joinedload(Invite.used_by) ) if status: query = query.where(Invite.status == status) if godfather_id: query = query.where(Invite.godfather_id == godfather_id) query = query.order_by(desc(Invite.created_at)).offset(offset).limit(per_page) result = await self.db.execute(query) return list(result.scalars().all()) async def create(self, invite: Invite) -> Invite: """ Create a new invite record. Args: invite: Invite instance to persist Returns: Created Invite record (committed and refreshed) """ self.db.add(invite) await self.db.commit() await self.db.refresh(invite) return invite async def update(self, invite: Invite) -> Invite: """ Update an existing invite record. Args: invite: Invite instance to update Returns: Updated Invite record (committed and refreshed) """ await self.db.commit() await self.db.refresh(invite) return invite async def reload_with_relationships(self, invite_id: int) -> Invite: """ Reload an invite with all relationships eagerly loaded. Args: invite_id: ID of the invite to reload Returns: Invite record with relationships loaded """ result = await self.db.execute( select(Invite) .options(joinedload(Invite.godfather), joinedload(Invite.used_by)) .where(Invite.id == invite_id) ) return result.scalar_one()