"""SQLAlchemy ORM models and role/permission definitions.

Defines the ``Permission`` enum, the built-in role definitions, the
role<->permission and user<->role association tables, and the ``Role``,
``User``, ``Counter``, ``SumRecord`` and ``CounterRecord`` mapped classes.
"""

from datetime import datetime, UTC
from enum import Enum as PyEnum
from typing import TypedDict

from sqlalchemy import Integer, String, Float, DateTime, ForeignKey, Table, Column, Enum, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import AsyncSession

from database import Base


class RoleConfig(TypedDict):
    """Shape of a single entry in ``ROLE_DEFINITIONS``."""

    description: str
    permissions: list["Permission"]


class Permission(str, PyEnum):
    """All available permissions in the system."""

    # Counter permissions
    VIEW_COUNTER = "view_counter"
    INCREMENT_COUNTER = "increment_counter"

    # Sum permissions
    USE_SUM = "use_sum"

    # Audit permissions
    VIEW_AUDIT = "view_audit"


# Role name constants
ROLE_ADMIN = "admin"
ROLE_REGULAR = "regular"

# Role definitions with their permissions, keyed by role name.
# NOTE(review): presumably consumed by a role-seeding step elsewhere; the
# consumer is not visible in this module.
ROLE_DEFINITIONS: dict[str, RoleConfig] = {
    ROLE_ADMIN: {
        "description": "Administrator with audit access",
        "permissions": [
            Permission.VIEW_AUDIT,
        ],
    },
    ROLE_REGULAR: {
        "description": "Regular user with counter and sum access",
        "permissions": [
            Permission.VIEW_COUNTER,
            Permission.INCREMENT_COUNTER,
            Permission.USE_SUM,
        ],
    },
}

# Association table: Role <-> Permission (many-to-many).
# (role_id, permission) is the composite primary key, so a role can hold each
# permission at most once; rows are removed when the role is deleted.
role_permissions = Table(
    "role_permissions",
    Base.metadata,
    Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
    Column("permission", Enum(Permission), primary_key=True),
)

# Association table: User <-> Role (many-to-many).
# Rows are removed when either the user or the role is deleted.
user_roles = Table(
    "user_roles",
    Base.metadata,
    Column("user_id", Integer, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True),
    Column("role_id", Integer, ForeignKey("roles.id", ondelete="CASCADE"), primary_key=True),
)


class Role(Base):
    """A named role whose permissions are stored in ``role_permissions``."""

    __tablename__ = "roles"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    # The column is nullable, so the Python-side type must admit None too.
    description: Mapped[str | None] = mapped_column(String(255), nullable=True)

    # Relationship to users
    users: Mapped[list["User"]] = relationship(
        "User",
        secondary=user_roles,
        back_populates="roles",
    )

    async def get_permissions(self, db: AsyncSession) -> set[Permission]:
        """Return the set of permissions granted to this role.

        Args:
            db: Session used to query ``role_permissions``.

        Returns:
            The permissions currently stored for ``self.id``.
        """
        result = await db.execute(
            select(role_permissions.c.permission).where(role_permissions.c.role_id == self.id)
        )
        return set(result.scalars().all())

    async def set_permissions(self, db: AsyncSession, permissions: list[Permission]) -> None:
        """Replace all of this role's permissions with ``permissions``.

        Existing ``role_permissions`` rows for this role are deleted and the new
        ones are inserted in a single batched statement. Nothing is flushed or
        committed here; that is left to the caller. ``self.id`` must already be
        assigned, since it is written into every inserted row.

        Args:
            db: Session used to execute the DELETE and INSERT.
            permissions: The complete new permission set. Duplicates are
                ignored; an empty list simply clears the role's permissions.
        """
        await db.execute(role_permissions.delete().where(role_permissions.c.role_id == self.id))

        # De-duplicate (order-preserving) so a repeated entry cannot violate the
        # (role_id, permission) composite primary key.
        unique_permissions = list(dict.fromkeys(permissions))
        if not unique_permissions:
            # Nothing to insert; don't issue an INSERT with an empty parameter list.
            return

        # One executemany instead of one round-trip per permission.
        await db.execute(
            role_permissions.insert(),
            [{"role_id": self.id, "permission": perm} for perm in unique_permissions],
        )


class User(Base):
    """An account that holds permissions through its assigned roles."""

    __tablename__ = "users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)

    # Relationship to roles; "selectin" loads roles alongside the user, which
    # keeps the synchronous ``role_names`` property usable under AsyncSession.
    roles: Mapped[list[Role]] = relationship(
        "Role",
        secondary=user_roles,
        back_populates="users",
        lazy="selectin",
    )

    async def get_permissions(self, db: AsyncSession) -> set[Permission]:
        """Return the union of permissions from all of this user's roles.

        Uses a single query joining ``role_permissions`` to ``user_roles``;
        permissions granted by several roles appear once in the result set.

        Args:
            db: Session used to run the query.
        """
        result = await db.execute(
            select(role_permissions.c.permission)
            .join(user_roles, role_permissions.c.role_id == user_roles.c.role_id)
            .where(user_roles.c.user_id == self.id)
        )
        return set(result.scalars().all())

    async def has_permission(self, db: AsyncSession, permission: Permission) -> bool:
        """Return True if any of the user's roles grants ``permission``."""
        permissions = await self.get_permissions(db)
        return permission in permissions

    @property
    def role_names(self) -> list[str]:
        """Names of the user's roles, for API responses."""
        return [role.name for role in self.roles]


class Counter(Base):
    """Counter row holding an integer value (id defaults to 1, value to 0)."""

    __tablename__ = "counter"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
    value: Mapped[int] = mapped_column(Integer, default=0)


class SumRecord(Base):
    """Record of one sum computation (``a``, ``b``, ``result``) made by a user."""

    __tablename__ = "sum_records"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    a: Mapped[float] = mapped_column(Float, nullable=False)
    b: Mapped[float] = mapped_column(Float, nullable=False)
    result: Mapped[float] = mapped_column(Float, nullable=False)
    # Timezone-aware UTC timestamp, evaluated per insert (callable default).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(UTC)
    )


class CounterRecord(Base):
    """Record of one counter change by a user, with the values before and after."""

    __tablename__ = "counter_records"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
    value_before: Mapped[int] = mapped_column(Integer, nullable=False)
    value_after: Mapped[int] = mapped_column(Integer, nullable=False)
    # Timezone-aware UTC timestamp, evaluated per insert (callable default).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=lambda: datetime.now(UTC)
    )