Phase 0.1: Remove backend deprecated code
- Delete routes: counter.py, sum.py - Delete jobs.py and worker.py - Delete tests: test_counter.py, test_jobs.py - Update audit.py: keep only price-history endpoints - Update models.py: remove VIEW_COUNTER, INCREMENT_COUNTER, USE_SUM permissions - Update models.py: remove Counter, SumRecord, CounterRecord, RandomNumberOutcome models - Update schemas.py: remove sum/counter related schemas - Update main.py: remove deleted router imports - Update test_permissions.py: remove tests for deprecated features - Update test_price_history.py: remove worker-related tests - Update conftest.py: remove mock_enqueue_job fixture - Update auth.py: fix example in docstring
This commit is contained in:
parent
ea85198171
commit
5bad1e7e17
14 changed files with 35 additions and 1393 deletions
|
|
@ -90,9 +90,9 @@ def require_permission(*required_permissions: Permission):
|
|||
Dependency factory that checks if user has ALL required permissions.
|
||||
|
||||
Usage:
|
||||
@app.get("/api/counter")
|
||||
async def get_counter(
|
||||
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
|
||||
@app.get("/api/profile")
|
||||
async def get_profile(
|
||||
user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
|
||||
):
|
||||
...
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
"""Job definitions and enqueueing utilities using pgqueuer."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import asyncpg
|
||||
from pgqueuer.queries import Queries
|
||||
|
||||
from database import ASYNCPG_DATABASE_URL
|
||||
|
||||
# Job type constants
|
||||
JOB_RANDOM_NUMBER = "random_number"
|
||||
|
||||
# Connection pool for job enqueueing (lazy initialized)
|
||||
_pool: asyncpg.Pool | None = None
|
||||
_pool_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_job_pool() -> asyncpg.Pool:
|
||||
"""Get or create the connection pool for job enqueueing."""
|
||||
global _pool
|
||||
if _pool is not None:
|
||||
return _pool
|
||||
async with _pool_lock:
|
||||
# Double-check after acquiring lock
|
||||
if _pool is None:
|
||||
_pool = await asyncpg.create_pool(
|
||||
ASYNCPG_DATABASE_URL, min_size=1, max_size=5
|
||||
)
|
||||
return _pool
|
||||
|
||||
|
||||
async def close_job_pool() -> None:
|
||||
"""Close the connection pool. Call on app shutdown."""
|
||||
global _pool
|
||||
if _pool is not None:
|
||||
await _pool.close()
|
||||
_pool = None
|
||||
|
||||
|
||||
async def enqueue_random_number_job(user_id: int) -> int:
|
||||
"""
|
||||
Enqueue a random number job for the given user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user who triggered the job.
|
||||
|
||||
Returns:
|
||||
The job ID.
|
||||
|
||||
Raises:
|
||||
Exception: If enqueueing fails.
|
||||
"""
|
||||
pool = await get_job_pool()
|
||||
async with pool.acquire() as conn:
|
||||
queries = Queries.from_asyncpg_connection(conn)
|
||||
payload = json.dumps({"user_id": user_id}).encode()
|
||||
job_ids = await queries.enqueue(JOB_RANDOM_NUMBER, payload)
|
||||
return job_ids[0]
|
||||
|
|
@ -6,16 +6,13 @@ from fastapi import FastAPI
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from database import Base, engine
|
||||
from jobs import close_job_pool
|
||||
from routes import audit as audit_routes
|
||||
from routes import auth as auth_routes
|
||||
from routes import availability as availability_routes
|
||||
from routes import booking as booking_routes
|
||||
from routes import counter as counter_routes
|
||||
from routes import invites as invites_routes
|
||||
from routes import meta as meta_routes
|
||||
from routes import profile as profile_routes
|
||||
from routes import sum as sum_routes
|
||||
from validate_constants import validate_shared_constants
|
||||
|
||||
|
||||
|
|
@ -28,8 +25,6 @@ async def lifespan(app: FastAPI):
|
|||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
# Cleanup on shutdown
|
||||
await close_job_pool()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
|
@ -44,8 +39,6 @@ app.add_middleware(
|
|||
|
||||
# Include routers - modules with single router
|
||||
app.include_router(auth_routes.router)
|
||||
app.include_router(sum_routes.router)
|
||||
app.include_router(counter_routes.router)
|
||||
app.include_router(audit_routes.router)
|
||||
app.include_router(profile_routes.router)
|
||||
app.include_router(availability_routes.router)
|
||||
|
|
|
|||
|
|
@ -30,13 +30,6 @@ class RoleConfig(TypedDict):
|
|||
class Permission(str, PyEnum):
|
||||
"""All available permissions in the system."""
|
||||
|
||||
# Counter permissions
|
||||
VIEW_COUNTER = "view_counter"
|
||||
INCREMENT_COUNTER = "increment_counter"
|
||||
|
||||
# Sum permissions
|
||||
USE_SUM = "use_sum"
|
||||
|
||||
# Audit permissions
|
||||
VIEW_AUDIT = "view_audit"
|
||||
FETCH_PRICE = "fetch_price"
|
||||
|
|
@ -93,11 +86,8 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
|
|||
],
|
||||
},
|
||||
ROLE_REGULAR: {
|
||||
"description": "Regular user with counter, sum, invite, and booking access",
|
||||
"description": "Regular user with profile, invite, and booking access",
|
||||
"permissions": [
|
||||
Permission.VIEW_COUNTER,
|
||||
Permission.INCREMENT_COUNTER,
|
||||
Permission.USE_SUM,
|
||||
Permission.MANAGE_OWN_PROFILE,
|
||||
Permission.VIEW_OWN_INVITES,
|
||||
Permission.BOOK_APPOINTMENT,
|
||||
|
|
@ -231,42 +221,6 @@ class User(Base):
|
|||
return [role.name for role in self.roles]
|
||||
|
||||
|
||||
class Counter(Base):
|
||||
__tablename__ = "counter"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
|
||||
value: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
|
||||
class SumRecord(Base):
|
||||
__tablename__ = "sum_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
a: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
b: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
result: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
class CounterRecord(Base):
|
||||
__tablename__ = "counter_records"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
class Invite(Base):
|
||||
__tablename__ = "invites"
|
||||
|
||||
|
|
@ -359,27 +313,6 @@ class Appointment(Base):
|
|||
)
|
||||
|
||||
|
||||
class RandomNumberOutcome(Base):
|
||||
"""Outcome of a random number job execution."""
|
||||
|
||||
__tablename__ = "random_number_outcomes"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
job_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
|
||||
triggered_by_user_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=False, index=True
|
||||
)
|
||||
triggered_by: Mapped[User] = relationship(
|
||||
"User", foreign_keys=[triggered_by_user_id], lazy="joined"
|
||||
)
|
||||
value: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), nullable=False, default="completed")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(UTC)
|
||||
)
|
||||
|
||||
|
||||
class PriceHistory(Base):
|
||||
"""Price history records from external exchanges."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,98 +1,18 @@
|
|||
"""Audit routes for viewing action records."""
|
||||
"""Audit routes for price history."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import desc, func, select
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import (
|
||||
CounterRecord,
|
||||
Permission,
|
||||
PriceHistory,
|
||||
RandomNumberOutcome,
|
||||
SumRecord,
|
||||
User,
|
||||
)
|
||||
from pagination import (
|
||||
calculate_offset,
|
||||
calculate_total_pages,
|
||||
create_paginated_response,
|
||||
)
|
||||
from models import Permission, PriceHistory, User
|
||||
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
|
||||
from schemas import (
|
||||
CounterRecordResponse,
|
||||
PaginatedCounterRecords,
|
||||
PaginatedSumRecords,
|
||||
PriceHistoryResponse,
|
||||
RandomNumberOutcomeResponse,
|
||||
SumRecordResponse,
|
||||
)
|
||||
from schemas import PriceHistoryResponse
|
||||
|
||||
router = APIRouter(prefix="/api/audit", tags=["audit"])
|
||||
|
||||
R = TypeVar("R", bound=BaseModel)
|
||||
|
||||
|
||||
async def paginate_with_user_email(
|
||||
db: AsyncSession,
|
||||
model: type[SumRecord] | type[CounterRecord],
|
||||
page: int,
|
||||
per_page: int,
|
||||
row_mapper: Callable[..., R],
|
||||
) -> tuple[list[R], int, int]:
|
||||
"""
|
||||
Generic pagination helper for audit records that need user email.
|
||||
|
||||
Returns: (records, total, total_pages)
|
||||
"""
|
||||
# Get total count
|
||||
count_result = await db.execute(select(func.count(model.id)))
|
||||
total = count_result.scalar() or 0
|
||||
|
||||
# Get paginated records with user email
|
||||
offset = calculate_offset(page, per_page)
|
||||
query = (
|
||||
select(model, User.email)
|
||||
.join(User, model.user_id == User.id)
|
||||
.order_by(desc(model.created_at))
|
||||
.offset(offset)
|
||||
.limit(per_page)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
records: list[R] = [row_mapper(record, email) for record, email in rows]
|
||||
return records, total, calculate_total_pages(total, per_page)
|
||||
|
||||
|
||||
def _to_counter_record_response(
|
||||
record: CounterRecord, email: str
|
||||
) -> CounterRecordResponse:
|
||||
return CounterRecordResponse(
|
||||
id=record.id,
|
||||
user_email=email,
|
||||
value_before=record.value_before,
|
||||
value_after=record.value_after,
|
||||
created_at=record.created_at,
|
||||
)
|
||||
|
||||
|
||||
def _to_sum_record_response(record: SumRecord, email: str) -> SumRecordResponse:
|
||||
return SumRecordResponse(
|
||||
id=record.id,
|
||||
user_email=email,
|
||||
a=record.a,
|
||||
b=record.b,
|
||||
result=record.result,
|
||||
created_at=record.created_at,
|
||||
)
|
||||
|
||||
|
||||
def _to_price_history_response(record: PriceHistory) -> PriceHistoryResponse:
|
||||
return PriceHistoryResponse(
|
||||
|
|
@ -105,64 +25,6 @@ def _to_price_history_response(record: PriceHistory) -> PriceHistoryResponse:
|
|||
)
|
||||
|
||||
|
||||
@router.get("/counter", response_model=PaginatedCounterRecords)
|
||||
async def get_counter_records(
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||
) -> PaginatedCounterRecords:
|
||||
"""Get paginated counter action records."""
|
||||
records, total, _ = await paginate_with_user_email(
|
||||
db, CounterRecord, page, per_page, _to_counter_record_response
|
||||
)
|
||||
return create_paginated_response(records, total, page, per_page)
|
||||
|
||||
|
||||
@router.get("/sum", response_model=PaginatedSumRecords)
|
||||
async def get_sum_records(
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(10, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||
) -> PaginatedSumRecords:
|
||||
"""Get paginated sum action records."""
|
||||
records, total, _ = await paginate_with_user_email(
|
||||
db, SumRecord, page, per_page, _to_sum_record_response
|
||||
)
|
||||
return create_paginated_response(records, total, page, per_page)
|
||||
|
||||
|
||||
@router.get("/random-jobs", response_model=list[RandomNumberOutcomeResponse])
|
||||
async def get_random_job_outcomes(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
|
||||
) -> list[RandomNumberOutcomeResponse]:
|
||||
"""Get all random number job outcomes, newest first."""
|
||||
# Explicit join to avoid N+1 query
|
||||
query = (
|
||||
select(RandomNumberOutcome, User.email)
|
||||
.join(User, RandomNumberOutcome.triggered_by_user_id == User.id)
|
||||
.order_by(desc(RandomNumberOutcome.created_at))
|
||||
)
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
return [
|
||||
RandomNumberOutcomeResponse(
|
||||
id=outcome.id,
|
||||
job_id=outcome.job_id,
|
||||
triggered_by_user_id=outcome.triggered_by_user_id,
|
||||
triggered_by_email=email,
|
||||
value=outcome.value,
|
||||
duration_ms=outcome.duration_ms,
|
||||
status=outcome.status,
|
||||
created_at=outcome.created_at,
|
||||
)
|
||||
for outcome, email in rows
|
||||
]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Price History Endpoints
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
"""Counter routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from jobs import enqueue_random_number_job
|
||||
from models import Counter, CounterRecord, Permission, User
|
||||
|
||||
router = APIRouter(prefix="/api/counter", tags=["counter"])
|
||||
|
||||
|
||||
async def get_or_create_counter(db: AsyncSession) -> Counter:
|
||||
"""Get the singleton counter, creating it if it doesn't exist."""
|
||||
result = await db.execute(select(Counter).where(Counter.id == 1))
|
||||
counter = result.scalar_one_or_none()
|
||||
if not counter:
|
||||
counter = Counter(id=1, value=0)
|
||||
db.add(counter)
|
||||
await db.commit()
|
||||
await db.refresh(counter)
|
||||
return counter
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_counter(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
|
||||
) -> dict[str, int]:
|
||||
"""Get the current counter value."""
|
||||
counter = await get_or_create_counter(db)
|
||||
return {"value": counter.value}
|
||||
|
||||
|
||||
@router.post("/increment")
|
||||
async def increment_counter(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
|
||||
) -> dict[str, int]:
|
||||
"""Increment the counter, record the action, and enqueue a random number job."""
|
||||
counter = await get_or_create_counter(db)
|
||||
value_before = counter.value
|
||||
counter.value += 1
|
||||
|
||||
record = CounterRecord(
|
||||
user_id=current_user.id,
|
||||
value_before=value_before,
|
||||
value_after=counter.value,
|
||||
)
|
||||
db.add(record)
|
||||
|
||||
# Enqueue random number job - if this fails, the request fails
|
||||
try:
|
||||
await enqueue_random_number_job(current_user.id)
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to enqueue job: {e}"
|
||||
) from e
|
||||
|
||||
await db.commit()
|
||||
return {"value": counter.value}
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
"""Sum calculation routes."""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from auth import require_permission
|
||||
from database import get_db
|
||||
from models import Permission, SumRecord, User
|
||||
from schemas import SumRequest, SumResponse
|
||||
|
||||
router = APIRouter(prefix="/api/sum", tags=["sum"])
|
||||
|
||||
|
||||
@router.post("", response_model=SumResponse)
|
||||
async def calculate_sum(
|
||||
data: SumRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(require_permission(Permission.USE_SUM)),
|
||||
) -> SumResponse:
|
||||
"""Calculate the sum of two numbers and record it."""
|
||||
result = data.a + data.b
|
||||
record = SumRecord(
|
||||
user_id=current_user.id,
|
||||
a=data.a,
|
||||
b=data.b,
|
||||
result=result,
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
return SumResponse(a=data.a, b=data.b, result=result)
|
||||
|
|
@ -37,42 +37,6 @@ class RegisterWithInvite(BaseModel):
|
|||
invite_identifier: str
|
||||
|
||||
|
||||
class SumRequest(BaseModel):
|
||||
"""Request model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
|
||||
|
||||
class SumResponse(BaseModel):
|
||||
"""Response model for sum calculation."""
|
||||
|
||||
a: float
|
||||
b: float
|
||||
result: float
|
||||
|
||||
|
||||
class CounterRecordResponse(BaseModel):
|
||||
"""Response model for a counter audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
value_before: int
|
||||
value_after: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SumRecordResponse(BaseModel):
|
||||
"""Response model for a sum audit record."""
|
||||
|
||||
id: int
|
||||
user_email: str
|
||||
a: float
|
||||
b: float
|
||||
result: float
|
||||
created_at: datetime
|
||||
|
||||
|
||||
RecordT = TypeVar("RecordT", bound=BaseModel)
|
||||
|
||||
|
||||
|
|
@ -86,10 +50,6 @@ class PaginatedResponse(BaseModel, Generic[RecordT]):
|
|||
total_pages: int
|
||||
|
||||
|
||||
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
|
||||
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
|
||||
|
||||
|
||||
class ProfileResponse(BaseModel):
|
||||
"""Response model for profile data."""
|
||||
|
||||
|
|
@ -258,24 +218,6 @@ class AppointmentResponse(BaseModel):
|
|||
PaginatedAppointments = PaginatedResponse[AppointmentResponse]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Random Number Job Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class RandomNumberOutcomeResponse(BaseModel):
|
||||
"""Response model for a random number job outcome."""
|
||||
|
||||
id: int
|
||||
job_id: int
|
||||
triggered_by_user_id: int
|
||||
triggered_by_email: str
|
||||
value: int
|
||||
duration_ms: int
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Price History Schemas
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
# Set required env vars before importing app
|
||||
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
|
||||
|
|
@ -239,18 +238,3 @@ async def user_no_roles(client_factory):
|
|||
"cookies": dict(response.cookies),
|
||||
"response": response,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_enqueue_job():
|
||||
"""Mock job enqueueing for tests that hit the counter increment endpoint.
|
||||
|
||||
pgqueuer requires PostgreSQL-specific features that aren't available
|
||||
in the test database setup. We mock the enqueue function to avoid
|
||||
connection issues while still testing the counter logic.
|
||||
|
||||
Tests that call POST /api/counter/increment must use this fixture.
|
||||
"""
|
||||
mock = AsyncMock(return_value=1) # Return a fake job ID
|
||||
with patch("routes.counter.enqueue_random_number_job", mock):
|
||||
yield mock
|
||||
|
|
|
|||
|
|
@ -1,239 +0,0 @@
|
|||
"""Tests for counter endpoints.
|
||||
|
||||
Note: Registration now requires an invite code.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from auth import COOKIE_NAME
|
||||
from models import ROLE_REGULAR
|
||||
from tests.conftest import create_user_with_roles
|
||||
from tests.helpers import create_invite_for_godfather, unique_email
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_enqueues_job_with_user_id(client_factory, mock_enqueue_job):
|
||||
"""Verify that incrementing the counter enqueues a job with the user's ID."""
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email(),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
# Get user ID from the me endpoint
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
me_response = await authed.get("/api/auth/me")
|
||||
user_id = me_response.json()["id"]
|
||||
|
||||
# Increment counter
|
||||
response = await authed.post("/api/counter/increment")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify enqueue was called with the correct user_id
|
||||
mock_enqueue_job.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
# Protected endpoint tests - without auth
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_requires_auth(client):
|
||||
response = await client.get("/api/counter")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_requires_auth(client):
|
||||
response = await client.post("/api/counter/increment")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_invalid_cookie(client_factory):
|
||||
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
|
||||
response = await authed.get("/api/counter")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_invalid_cookie(client_factory):
|
||||
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
|
||||
response = await authed.post("/api/counter/increment")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# Authenticated counter tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_authenticated(client_factory):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email(),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
response = await authed.get("/api/counter")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter(client_factory, mock_enqueue_job):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email(),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
# Get current value
|
||||
before = await authed.get("/api/counter")
|
||||
before_value = before.json()["value"]
|
||||
|
||||
# Increment
|
||||
response = await authed.post("/api/counter/increment")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["value"] == before_value + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_increment_counter_multiple(client_factory, mock_enqueue_job):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email(),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
# Get starting value
|
||||
before = await authed.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
# Increment 3 times
|
||||
await authed.post("/api/counter/increment")
|
||||
await authed.post("/api/counter/increment")
|
||||
response = await authed.post("/api/counter/increment")
|
||||
|
||||
assert response.json()["value"] == start + 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_counter_after_increment(client_factory, mock_enqueue_job):
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite_code = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
reg = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email(),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite_code,
|
||||
},
|
||||
)
|
||||
cookies = dict(reg.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies) as authed:
|
||||
before = await authed.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
await authed.post("/api/counter/increment")
|
||||
await authed.post("/api/counter/increment")
|
||||
|
||||
response = await authed.get("/api/counter")
|
||||
assert response.json()["value"] == start + 2
|
||||
|
||||
|
||||
# Counter is shared between users
|
||||
@pytest.mark.asyncio
|
||||
async def test_counter_shared_between_users(client_factory, mock_enqueue_job):
|
||||
# Create godfather and invites for two users
|
||||
async with client_factory.get_db_session() as db:
|
||||
godfather = await create_user_with_roles(
|
||||
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
|
||||
)
|
||||
invite1 = await create_invite_for_godfather(db, godfather.id)
|
||||
invite2 = await create_invite_for_godfather(db, godfather.id)
|
||||
|
||||
# Create first user
|
||||
reg1 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email("share1"),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite1,
|
||||
},
|
||||
)
|
||||
cookies1 = dict(reg1.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies1) as user1:
|
||||
# Get starting value
|
||||
before = await user1.get("/api/counter")
|
||||
start = before.json()["value"]
|
||||
|
||||
await user1.post("/api/counter/increment")
|
||||
await user1.post("/api/counter/increment")
|
||||
|
||||
# Create second user - should see the increments
|
||||
reg2 = await client_factory.post(
|
||||
"/api/auth/register",
|
||||
json={
|
||||
"email": unique_email("share2"),
|
||||
"password": "testpass123",
|
||||
"invite_identifier": invite2,
|
||||
},
|
||||
)
|
||||
cookies2 = dict(reg2.cookies)
|
||||
|
||||
async with client_factory.create(cookies=cookies2) as user2:
|
||||
response = await user2.get("/api/counter")
|
||||
assert response.json()["value"] == start + 2
|
||||
|
||||
# Second user increments
|
||||
await user2.post("/api/counter/increment")
|
||||
|
||||
# First user sees the increment
|
||||
async with client_factory.create(cookies=cookies1) as user1:
|
||||
response = await user1.get("/api/counter")
|
||||
assert response.json()["value"] == start + 3
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
"""Tests for job handler logic."""
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from worker import process_random_number_job
|
||||
|
||||
|
||||
def create_mock_pool(mock_conn: AsyncMock) -> MagicMock:
|
||||
"""Create a mock asyncpg pool with proper async context manager behavior."""
|
||||
mock_pool = MagicMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_acquire():
|
||||
yield mock_conn
|
||||
|
||||
mock_pool.acquire = mock_acquire
|
||||
return mock_pool
|
||||
|
||||
|
||||
class TestRandomNumberJobHandler:
|
||||
"""Tests for the random number job handler logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generates_random_number_in_range(self):
|
||||
"""Verify random number is in range [0, 100]."""
|
||||
# Create mock job
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = json.dumps({"user_id": 1}).encode()
|
||||
|
||||
# Create mock db pool
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
# Run the job handler
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
# Verify execute was called
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Extract the value argument (position 3 in the args)
|
||||
# Args: (query, job_id, user_id, value, duration_ms, status)
|
||||
value = call_args[0][3]
|
||||
|
||||
assert 0 <= value <= 100, f"Value {value} is not in range [0, 100]"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_correct_user_id(self):
|
||||
"""Verify the correct user_id is stored in the outcome."""
|
||||
user_id = 42
|
||||
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = json.dumps({"user_id": user_id}).encode()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Args: (query, job_id, user_id, value, duration_ms, status)
|
||||
stored_user_id = call_args[0][2]
|
||||
assert stored_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_job_id(self):
|
||||
"""Verify the job_id is stored in the outcome."""
|
||||
job_id = 456
|
||||
|
||||
job = MagicMock()
|
||||
job.id = job_id
|
||||
job.payload = json.dumps({"user_id": 1}).encode()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Args: (query, job_id, user_id, value, duration_ms, status)
|
||||
stored_job_id = call_args[0][1]
|
||||
assert stored_job_id == job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_status_completed(self):
|
||||
"""Verify the status is set to 'completed'."""
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = json.dumps({"user_id": 1}).encode()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Args: (query, job_id, user_id, value, duration_ms, status)
|
||||
status = call_args[0][5]
|
||||
assert status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_records_duration_ms(self):
|
||||
"""Verify duration_ms is recorded (should be >= 0)."""
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = json.dumps({"user_id": 1}).encode()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Args: (query, job_id, user_id, value, duration_ms, status)
|
||||
duration_ms = call_args[0][4]
|
||||
assert isinstance(duration_ms, int)
|
||||
assert duration_ms >= 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_user_id_does_not_insert(self):
|
||||
"""Verify no insert happens if user_id is missing from payload."""
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = json.dumps({}).encode() # Missing user_id
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
# Should not have called execute
|
||||
mock_conn.execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_payload_does_not_insert(self):
|
||||
"""Verify no insert happens with empty payload."""
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = None
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
# Should not have called execute
|
||||
mock_conn.execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_json_payload_does_not_insert(self):
|
||||
"""Verify no insert happens with malformed JSON payload."""
|
||||
job = MagicMock()
|
||||
job.id = 123
|
||||
job.payload = b"not valid json {"
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
await process_random_number_job(job, mock_pool)
|
||||
|
||||
# Should not have called execute
|
||||
mock_conn.execute.assert_not_called()
|
||||
|
|
@ -50,10 +50,10 @@ class TestRoleAssignment:
|
|||
data = response.json()
|
||||
permissions = data["permissions"]
|
||||
|
||||
# Should have counter and sum permissions
|
||||
assert Permission.VIEW_COUNTER.value in permissions
|
||||
assert Permission.INCREMENT_COUNTER.value in permissions
|
||||
assert Permission.USE_SUM.value in permissions
|
||||
# Should have profile and booking permissions
|
||||
assert Permission.MANAGE_OWN_PROFILE.value in permissions
|
||||
assert Permission.BOOK_APPOINTMENT.value in permissions
|
||||
assert Permission.VIEW_OWN_APPOINTMENTS.value in permissions
|
||||
|
||||
# Should NOT have audit permission
|
||||
assert Permission.VIEW_AUDIT.value not in permissions
|
||||
|
|
@ -69,10 +69,8 @@ class TestRoleAssignment:
|
|||
# Should have audit permission
|
||||
assert Permission.VIEW_AUDIT.value in permissions
|
||||
|
||||
# Should NOT have counter/sum permissions
|
||||
assert Permission.VIEW_COUNTER.value not in permissions
|
||||
assert Permission.INCREMENT_COUNTER.value not in permissions
|
||||
assert Permission.USE_SUM.value not in permissions
|
||||
# Should NOT have booking permissions (those are for regular users)
|
||||
assert Permission.BOOK_APPOINTMENT.value not in permissions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_with_no_roles_has_no_permissions(
|
||||
|
|
@ -86,124 +84,6 @@ class TestRoleAssignment:
|
|||
assert data["permissions"] == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Counter Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCounterAccess:
|
||||
"""Test access control for counter endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_view_counter(self, client_factory, regular_user):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_increment_counter(
|
||||
self, client_factory, regular_user, mock_enqueue_job
|
||||
):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post("/api/counter/increment")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "value" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_view_counter(self, client_factory, admin_user):
|
||||
"""Admin users should be forbidden from counter endpoints."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_increment_counter(self, client_factory, admin_user):
|
||||
"""Admin users should be forbidden from incrementing counter."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post("/api/counter/increment")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_counter(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
"""Users with no roles should be forbidden."""
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/counter")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_view_counter(self, client):
|
||||
"""Unauthenticated requests should get 401."""
|
||||
response = await client.get("/api/counter")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_increment_counter(self, client):
|
||||
"""Unauthenticated requests should get 401."""
|
||||
response = await client.post("/api/counter/increment")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sum Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSumAccess:
|
||||
"""Test access control for sum endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_can_use_sum(self, client_factory, regular_user):
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["result"] == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_use_sum(self, client_factory, admin_user):
|
||||
"""Admin users should be forbidden from sum endpoint."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_use_sum(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_use_sum(self, client):
|
||||
response = await client.post(
|
||||
"/api/sum",
|
||||
json={"a": 5, "b": 3},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audit Endpoint Access Tests
|
||||
# =============================================================================
|
||||
|
|
@ -213,89 +93,37 @@ class TestAuditAccess:
|
|||
"""Test access control for audit endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_view_counter_audit(self, client_factory, admin_user):
|
||||
async def test_admin_can_view_price_history(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
response = await client.get("/api/audit/price-history")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "records" in data
|
||||
assert "total" in data
|
||||
# Returns a list
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_view_sum_audit(self, client_factory, admin_user):
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "records" in data
|
||||
assert "total" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_counter_audit(
|
||||
async def test_regular_user_cannot_view_price_history(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
response = await client.get("/api/audit/price-history")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "permission" in response.json()["detail"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_sum_audit(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from audit endpoints."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_without_roles_cannot_view_audit(
|
||||
self, client_factory, user_no_roles
|
||||
):
|
||||
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
response = await client.get("/api/audit/price-history")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_view_counter_audit(self, client):
|
||||
response = await client.get("/api/audit/counter")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_view_sum_audit(self, client):
|
||||
response = await client.get("/api/audit/sum")
|
||||
assert response.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_view_random_jobs(self, client_factory, admin_user):
|
||||
"""Admin should be able to view random job outcomes."""
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/random-jobs")
|
||||
|
||||
assert response.status_code == 200
|
||||
# Returns a list (no pagination)
|
||||
assert isinstance(response.json(), list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_cannot_view_random_jobs(
|
||||
self, client_factory, regular_user
|
||||
):
|
||||
"""Regular users should be forbidden from random-jobs endpoint."""
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/random-jobs")
|
||||
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unauthenticated_cannot_view_random_jobs(self, client):
|
||||
"""Unauthenticated users should get 401."""
|
||||
response = await client.get("/api/audit/random-jobs")
|
||||
async def test_unauthenticated_cannot_view_price_history(self, client):
|
||||
response = await client.get("/api/audit/price-history")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
|
|
@ -320,18 +148,18 @@ class TestSecurityBypassAttempts:
|
|||
"""
|
||||
# Regular user tries to access audit endpoint
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
response = await client.get("/api/audit/price-history")
|
||||
|
||||
# Should be denied regardless of any manipulation attempts
|
||||
assert response.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_access_counter_with_expired_session(self, client_factory):
|
||||
async def test_cannot_access_with_expired_session(self, client_factory):
|
||||
"""Test that invalid/expired tokens are rejected."""
|
||||
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid"
|
||||
|
||||
async with client_factory.create(cookies={"auth_token": fake_token}) as client:
|
||||
response = await client.get("/api/counter")
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
|
@ -348,7 +176,7 @@ class TestSecurityBypassAttempts:
|
|||
async with client_factory.create(
|
||||
cookies={"auth_token": tampered_token}
|
||||
) as client:
|
||||
response = await client.get("/api/counter")
|
||||
response = await client.get("/api/profile")
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
|
|
@ -386,7 +214,7 @@ class TestSecurityBypassAttempts:
|
|||
|
||||
# Try to access audit with this new user
|
||||
async with client_factory.create(cookies=dict(response.cookies)) as client:
|
||||
audit_response = await client.get("/api/audit/counter")
|
||||
audit_response = await client.get("/api/audit/price-history")
|
||||
|
||||
assert audit_response.status_code == 403
|
||||
|
||||
|
|
@ -452,10 +280,10 @@ class TestSecurityBypassAttempts:
|
|||
)
|
||||
cookies = dict(login_response.cookies)
|
||||
|
||||
# Verify can access counter but not audit
|
||||
# Verify can access profile but not audit
|
||||
async with client_factory.create(cookies=cookies) as client:
|
||||
assert (await client.get("/api/counter")).status_code == 200
|
||||
assert (await client.get("/api/audit/counter")).status_code == 403
|
||||
assert (await client.get("/api/profile")).status_code == 200
|
||||
assert (await client.get("/api/audit/price-history")).status_code == 403
|
||||
|
||||
# Change user's role from regular to admin
|
||||
async with client_factory.get_db_session() as db:
|
||||
|
|
@ -468,62 +296,7 @@ class TestSecurityBypassAttempts:
|
|||
user.roles = [admin_role] # Replace roles with admin only
|
||||
await db.commit()
|
||||
|
||||
# Now should have audit access but not counter access
|
||||
# Now should have audit access but not profile access (admin doesn't have MANAGE_OWN_PROFILE)
|
||||
async with client_factory.create(cookies=cookies) as client:
|
||||
assert (await client.get("/api/audit/counter")).status_code == 200
|
||||
assert (await client.get("/api/counter")).status_code == 403
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Audit Record Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAuditRecords:
|
||||
"""Test that actions are properly recorded in audit logs."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_counter_increment_creates_audit_record(
|
||||
self, client_factory, regular_user, admin_user, mock_enqueue_job
|
||||
):
|
||||
"""Verify that counter increments are recorded and visible in audit."""
|
||||
# Regular user increments counter
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post("/api/counter/increment")
|
||||
|
||||
# Admin checks audit
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/counter")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
# Find record for our user
|
||||
records = data["records"]
|
||||
user_records = [r for r in records if r["user_email"] == regular_user["email"]]
|
||||
assert len(user_records) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sum_operation_creates_audit_record(
|
||||
self, client_factory, regular_user, admin_user
|
||||
):
|
||||
"""Verify that sum operations are recorded and visible in audit."""
|
||||
# Regular user uses sum
|
||||
async with client_factory.create(cookies=regular_user["cookies"]) as client:
|
||||
await client.post("/api/sum", json={"a": 10, "b": 20})
|
||||
|
||||
# Admin checks audit
|
||||
async with client_factory.create(cookies=admin_user["cookies"]) as client:
|
||||
response = await client.get("/api/audit/sum")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] >= 1
|
||||
|
||||
# Find record with our values
|
||||
records = data["records"]
|
||||
matching = [
|
||||
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
|
||||
]
|
||||
assert len(matching) >= 1
|
||||
assert (await client.get("/api/audit/price-history")).status_code == 200
|
||||
assert (await client.get("/api/profile")).status_code == 403
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Tests for price history feature."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -9,7 +8,6 @@ import pytest
|
|||
|
||||
from models import PriceHistory
|
||||
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
|
||||
from worker import process_bitcoin_price_job
|
||||
|
||||
|
||||
def create_mock_httpx_client(
|
||||
|
|
@ -293,76 +291,3 @@ class TestManualFetch:
|
|||
data = response.json()
|
||||
assert data["id"] == existing_id
|
||||
assert data["price"] == 90000.0 # Original price, not the new one
|
||||
|
||||
|
||||
def create_mock_pool(mock_conn: AsyncMock) -> MagicMock:
|
||||
"""Create a mock asyncpg pool with proper async context manager behavior."""
|
||||
mock_pool = MagicMock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def mock_acquire():
|
||||
yield mock_conn
|
||||
|
||||
mock_pool.acquire = mock_acquire
|
||||
return mock_pool
|
||||
|
||||
|
||||
class TestProcessBitcoinPriceJob:
|
||||
"""Tests for the scheduled Bitcoin price job handler."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_price_on_success(self):
|
||||
"""Verify price is stored in database on successful fetch."""
|
||||
mock_http_client = create_mock_httpx_client(
|
||||
json_response=create_bitfinex_ticker_response(95000.0)
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
|
||||
await process_bitcoin_price_job(mock_pool)
|
||||
|
||||
# Verify execute was called with correct values
|
||||
mock_conn.execute.assert_called_once()
|
||||
call_args = mock_conn.execute.call_args
|
||||
|
||||
# Check the SQL parameters
|
||||
assert call_args[0][1] == SOURCE_BITFINEX # source
|
||||
assert call_args[0][2] == PAIR_BTC_EUR # pair
|
||||
assert call_args[0][3] == 95000.0 # price
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_silently_on_api_error(self):
|
||||
"""Verify no exception is raised and no DB insert on API error."""
|
||||
import httpx
|
||||
|
||||
error = httpx.HTTPStatusError(
|
||||
"Server Error", request=MagicMock(), response=MagicMock()
|
||||
)
|
||||
mock_http_client = create_mock_httpx_client(raise_for_status_error=error)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
|
||||
# Should not raise an exception
|
||||
await process_bitcoin_price_job(mock_pool)
|
||||
|
||||
# Should not have called execute
|
||||
mock_conn.execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_silently_on_db_error(self):
|
||||
"""Verify no exception is raised on database error."""
|
||||
mock_http_client = create_mock_httpx_client(
|
||||
json_response=create_bitfinex_ticker_response(95000.0)
|
||||
)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.execute.side_effect = Exception("Database connection error")
|
||||
mock_pool = create_mock_pool(mock_conn)
|
||||
|
||||
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
|
||||
# Should not raise an exception
|
||||
await process_bitcoin_price_job(mock_pool)
|
||||
|
|
|
|||
|
|
@ -1,202 +0,0 @@
|
|||
"""Background job worker using pgqueuer."""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import asyncpg
|
||||
from pgqueuer import Job, QueueManager, SchedulerManager
|
||||
from pgqueuer.db import AsyncpgDriver
|
||||
from pgqueuer.models import Schedule
|
||||
from pgqueuer.queries import Queries
|
||||
|
||||
from database import ASYNCPG_DATABASE_URL
|
||||
from jobs import JOB_RANDOM_NUMBER
|
||||
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
|
||||
|
||||
# Scheduled job type (internal to worker, not enqueued externally)
|
||||
JOB_FETCH_BITCOIN_PRICE = "fetch_bitcoin_price"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("worker")
|
||||
|
||||
|
||||
async def install_schema() -> None:
|
||||
"""Install pgqueuer schema if not already present."""
|
||||
conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
|
||||
try:
|
||||
queries = Queries.from_asyncpg_connection(conn)
|
||||
# Check if schema is already installed by looking for the main table
|
||||
if not await queries.has_table("pgqueuer"):
|
||||
await queries.install()
|
||||
logger.info("pgqueuer schema installed")
|
||||
else:
|
||||
logger.info("pgqueuer schema already exists")
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def process_random_number_job(job: Job, db_pool: asyncpg.Pool) -> None:
|
||||
"""
|
||||
Process a random number job.
|
||||
|
||||
- Parse user_id from payload
|
||||
- Generate random number 0-100
|
||||
- Record execution duration
|
||||
- Store outcome in database
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Parse payload
|
||||
payload_str = job.payload.decode() if job.payload else "{}"
|
||||
try:
|
||||
payload = json.loads(payload_str)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Job {job.id}: Invalid JSON payload: {e}")
|
||||
return
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if user_id is None:
|
||||
logger.error(f"Job {job.id}: Missing user_id in payload")
|
||||
return
|
||||
|
||||
# Generate random number
|
||||
value = random.randint(0, 100)
|
||||
|
||||
# Calculate duration
|
||||
duration_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
|
||||
# Store outcome
|
||||
async with db_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO random_number_outcomes
|
||||
(job_id, triggered_by_user_id, value, duration_ms, status, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
""",
|
||||
job.id,
|
||||
user_id,
|
||||
value,
|
||||
duration_ms,
|
||||
"completed",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Job {job.id}: Generated random number {value} for user {user_id} "
|
||||
f"(duration: {duration_ms}ms)"
|
||||
)
|
||||
|
||||
|
||||
def register_job_handlers(qm: QueueManager, db_pool: asyncpg.Pool) -> None:
|
||||
"""Register all job handlers with the queue manager."""
|
||||
|
||||
@qm.entrypoint(JOB_RANDOM_NUMBER)
|
||||
async def handle_random_number(job: Job) -> None:
|
||||
"""Handle random_number job entrypoint."""
|
||||
await process_random_number_job(job, db_pool)
|
||||
|
||||
|
||||
async def process_bitcoin_price_job(db_pool: asyncpg.Pool) -> None:
|
||||
"""
|
||||
Fetch and store Bitcoin price from Bitfinex.
|
||||
|
||||
This function is designed to fail silently - exceptions are caught and logged
|
||||
so the scheduler can continue with the next scheduled run.
|
||||
"""
|
||||
try:
|
||||
price, timestamp = await fetch_btc_eur_price()
|
||||
|
||||
async with db_pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO price_history
|
||||
(source, pair, price, timestamp, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (source, pair, timestamp) DO NOTHING
|
||||
""",
|
||||
SOURCE_BITFINEX,
|
||||
PAIR_BTC_EUR,
|
||||
price,
|
||||
timestamp,
|
||||
datetime.now(UTC),
|
||||
)
|
||||
|
||||
logger.info(f"Fetched BTC/EUR price: €{price:.2f}")
|
||||
except Exception as e:
|
||||
# Fail silently - next scheduled job will continue
|
||||
logger.error(f"Failed to fetch Bitcoin price: {e}")
|
||||
|
||||
|
||||
def register_scheduled_jobs(sm: SchedulerManager, db_pool: asyncpg.Pool) -> None:
|
||||
"""Register all scheduled jobs with the scheduler manager."""
|
||||
|
||||
# Run every minute: "* * * * *" means every minute of every hour of every day
|
||||
@sm.schedule(JOB_FETCH_BITCOIN_PRICE, "* * * * *")
|
||||
async def fetch_bitcoin_price(schedule: Schedule) -> None:
|
||||
"""Fetch Bitcoin price from Bitfinex every minute."""
|
||||
await process_bitcoin_price_job(db_pool)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Main worker entry point."""
|
||||
logger.info("Installing pgqueuer schema...")
|
||||
await install_schema()
|
||||
|
||||
logger.info("Connecting to database...")
|
||||
# Connection for queue manager
|
||||
queue_conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
|
||||
# Connection for scheduler manager
|
||||
scheduler_conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
|
||||
# Connection pool for application data
|
||||
db_pool = await asyncpg.create_pool(ASYNCPG_DATABASE_URL, min_size=1, max_size=5)
|
||||
|
||||
try:
|
||||
# Setup queue manager for on-demand jobs
|
||||
queue_driver = AsyncpgDriver(queue_conn)
|
||||
qm = QueueManager(queue_driver)
|
||||
register_job_handlers(qm, db_pool)
|
||||
|
||||
# Setup scheduler manager for periodic jobs
|
||||
scheduler_driver = AsyncpgDriver(scheduler_conn)
|
||||
sm = SchedulerManager(scheduler_driver)
|
||||
register_scheduled_jobs(sm, db_pool)
|
||||
|
||||
logger.info("Worker started, processing queue jobs and scheduled jobs...")
|
||||
|
||||
# Run both managers concurrently - if either fails, both stop
|
||||
queue_task = asyncio.create_task(qm.run(), name="queue_manager")
|
||||
scheduler_task = asyncio.create_task(sm.run(), name="scheduler_manager")
|
||||
|
||||
done, pending = await asyncio.wait(
|
||||
[queue_task, scheduler_task],
|
||||
return_when=asyncio.FIRST_EXCEPTION,
|
||||
)
|
||||
|
||||
# Cancel any pending tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
# Check for exceptions in completed tasks
|
||||
for task in done:
|
||||
exc = task.exception()
|
||||
if exc is not None:
|
||||
logger.error(f"Task '{task.get_name()}' failed: {exc}")
|
||||
raise exc
|
||||
finally:
|
||||
await queue_conn.close()
|
||||
await scheduler_conn.close()
|
||||
await db_pool.close()
|
||||
logger.info("Worker stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue