Phase 0.1: Remove backend deprecated code

- Delete routes: counter.py, sum.py
- Delete jobs.py and worker.py
- Delete tests: test_counter.py, test_jobs.py
- Update audit.py: keep only price-history endpoints
- Update models.py: remove VIEW_COUNTER, INCREMENT_COUNTER, USE_SUM permissions
- Update models.py: remove Counter, SumRecord, CounterRecord, RandomNumberOutcome models
- Update schemas.py: remove sum/counter related schemas
- Update main.py: remove deleted router imports
- Update test_permissions.py: remove tests for deprecated features
- Update test_price_history.py: remove worker-related tests
- Update conftest.py: remove mock_enqueue_job fixture
- Update auth.py: fix example in docstring
This commit is contained in:
counterweight 2025-12-22 18:07:14 +01:00
parent ea85198171
commit 5bad1e7e17
Signed by: counterweight
GPG key ID: 883EDBAA726BD96C
14 changed files with 35 additions and 1393 deletions

View file

@ -90,9 +90,9 @@ def require_permission(*required_permissions: Permission):
Dependency factory that checks if user has ALL required permissions.
Usage:
@app.get("/api/counter")
async def get_counter(
user: User = Depends(require_permission(Permission.VIEW_COUNTER))
@app.get("/api/profile")
async def get_profile(
user: User = Depends(require_permission(Permission.MANAGE_OWN_PROFILE))
):
...
"""

View file

@ -1,59 +0,0 @@
"""Job definitions and enqueueing utilities using pgqueuer."""
import asyncio
import json
import asyncpg
from pgqueuer.queries import Queries
from database import ASYNCPG_DATABASE_URL
# Job type constants
JOB_RANDOM_NUMBER = "random_number"
# Connection pool for job enqueueing (lazy initialized)
_pool: asyncpg.Pool | None = None
_pool_lock = asyncio.Lock()
async def get_job_pool() -> asyncpg.Pool:
"""Get or create the connection pool for job enqueueing."""
global _pool
if _pool is not None:
return _pool
async with _pool_lock:
# Double-check after acquiring lock
if _pool is None:
_pool = await asyncpg.create_pool(
ASYNCPG_DATABASE_URL, min_size=1, max_size=5
)
return _pool
async def close_job_pool() -> None:
"""Close the connection pool. Call on app shutdown."""
global _pool
if _pool is not None:
await _pool.close()
_pool = None
async def enqueue_random_number_job(user_id: int) -> int:
"""
Enqueue a random number job for the given user.
Args:
user_id: The ID of the user who triggered the job.
Returns:
The job ID.
Raises:
Exception: If enqueueing fails.
"""
pool = await get_job_pool()
async with pool.acquire() as conn:
queries = Queries.from_asyncpg_connection(conn)
payload = json.dumps({"user_id": user_id}).encode()
job_ids = await queries.enqueue(JOB_RANDOM_NUMBER, payload)
return job_ids[0]

View file

@ -6,16 +6,13 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from database import Base, engine
from jobs import close_job_pool
from routes import audit as audit_routes
from routes import auth as auth_routes
from routes import availability as availability_routes
from routes import booking as booking_routes
from routes import counter as counter_routes
from routes import invites as invites_routes
from routes import meta as meta_routes
from routes import profile as profile_routes
from routes import sum as sum_routes
from validate_constants import validate_shared_constants
@ -28,8 +25,6 @@ async def lifespan(app: FastAPI):
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
# Cleanup on shutdown
await close_job_pool()
app = FastAPI(lifespan=lifespan)
@ -44,8 +39,6 @@ app.add_middleware(
# Include routers - modules with single router
app.include_router(auth_routes.router)
app.include_router(sum_routes.router)
app.include_router(counter_routes.router)
app.include_router(audit_routes.router)
app.include_router(profile_routes.router)
app.include_router(availability_routes.router)

View file

@ -30,13 +30,6 @@ class RoleConfig(TypedDict):
class Permission(str, PyEnum):
"""All available permissions in the system."""
# Counter permissions
VIEW_COUNTER = "view_counter"
INCREMENT_COUNTER = "increment_counter"
# Sum permissions
USE_SUM = "use_sum"
# Audit permissions
VIEW_AUDIT = "view_audit"
FETCH_PRICE = "fetch_price"
@ -93,11 +86,8 @@ ROLE_DEFINITIONS: dict[str, RoleConfig] = {
],
},
ROLE_REGULAR: {
"description": "Regular user with counter, sum, invite, and booking access",
"description": "Regular user with profile, invite, and booking access",
"permissions": [
Permission.VIEW_COUNTER,
Permission.INCREMENT_COUNTER,
Permission.USE_SUM,
Permission.MANAGE_OWN_PROFILE,
Permission.VIEW_OWN_INVITES,
Permission.BOOK_APPOINTMENT,
@ -231,42 +221,6 @@ class User(Base):
return [role.name for role in self.roles]
class Counter(Base):
__tablename__ = "counter"
id: Mapped[int] = mapped_column(Integer, primary_key=True, default=1)
value: Mapped[int] = mapped_column(Integer, default=0)
class SumRecord(Base):
__tablename__ = "sum_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
a: Mapped[float] = mapped_column(Float, nullable=False)
b: Mapped[float] = mapped_column(Float, nullable=False)
result: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
class CounterRecord(Base):
__tablename__ = "counter_records"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
value_before: Mapped[int] = mapped_column(Integer, nullable=False)
value_after: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
class Invite(Base):
__tablename__ = "invites"
@ -359,27 +313,6 @@ class Appointment(Base):
)
class RandomNumberOutcome(Base):
"""Outcome of a random number job execution."""
__tablename__ = "random_number_outcomes"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
job_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
triggered_by_user_id: Mapped[int] = mapped_column(
Integer, ForeignKey("users.id"), nullable=False, index=True
)
triggered_by: Mapped[User] = relationship(
"User", foreign_keys=[triggered_by_user_id], lazy="joined"
)
value: Mapped[int] = mapped_column(Integer, nullable=False)
duration_ms: Mapped[int] = mapped_column(Integer, nullable=False)
status: Mapped[str] = mapped_column(String(20), nullable=False, default="completed")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(UTC)
)
class PriceHistory(Base):
"""Price history records from external exchanges."""

View file

@ -1,98 +1,18 @@
"""Audit routes for viewing action records."""
"""Audit routes for price history."""
from collections.abc import Callable
from typing import TypeVar
from fastapi import APIRouter, Depends, Query
from pydantic import BaseModel
from sqlalchemy import desc, func, select
from fastapi import APIRouter, Depends
from sqlalchemy import desc, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import (
CounterRecord,
Permission,
PriceHistory,
RandomNumberOutcome,
SumRecord,
User,
)
from pagination import (
calculate_offset,
calculate_total_pages,
create_paginated_response,
)
from models import Permission, PriceHistory, User
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
from schemas import (
CounterRecordResponse,
PaginatedCounterRecords,
PaginatedSumRecords,
PriceHistoryResponse,
RandomNumberOutcomeResponse,
SumRecordResponse,
)
from schemas import PriceHistoryResponse
router = APIRouter(prefix="/api/audit", tags=["audit"])
R = TypeVar("R", bound=BaseModel)
async def paginate_with_user_email(
db: AsyncSession,
model: type[SumRecord] | type[CounterRecord],
page: int,
per_page: int,
row_mapper: Callable[..., R],
) -> tuple[list[R], int, int]:
"""
Generic pagination helper for audit records that need user email.
Returns: (records, total, total_pages)
"""
# Get total count
count_result = await db.execute(select(func.count(model.id)))
total = count_result.scalar() or 0
# Get paginated records with user email
offset = calculate_offset(page, per_page)
query = (
select(model, User.email)
.join(User, model.user_id == User.id)
.order_by(desc(model.created_at))
.offset(offset)
.limit(per_page)
)
result = await db.execute(query)
rows = result.all()
records: list[R] = [row_mapper(record, email) for record, email in rows]
return records, total, calculate_total_pages(total, per_page)
def _to_counter_record_response(
record: CounterRecord, email: str
) -> CounterRecordResponse:
return CounterRecordResponse(
id=record.id,
user_email=email,
value_before=record.value_before,
value_after=record.value_after,
created_at=record.created_at,
)
def _to_sum_record_response(record: SumRecord, email: str) -> SumRecordResponse:
return SumRecordResponse(
id=record.id,
user_email=email,
a=record.a,
b=record.b,
result=record.result,
created_at=record.created_at,
)
def _to_price_history_response(record: PriceHistory) -> PriceHistoryResponse:
return PriceHistoryResponse(
@ -105,64 +25,6 @@ def _to_price_history_response(record: PriceHistory) -> PriceHistoryResponse:
)
@router.get("/counter", response_model=PaginatedCounterRecords)
async def get_counter_records(
page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
) -> PaginatedCounterRecords:
"""Get paginated counter action records."""
records, total, _ = await paginate_with_user_email(
db, CounterRecord, page, per_page, _to_counter_record_response
)
return create_paginated_response(records, total, page, per_page)
@router.get("/sum", response_model=PaginatedSumRecords)
async def get_sum_records(
page: int = Query(1, ge=1),
per_page: int = Query(10, ge=1, le=100),
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
) -> PaginatedSumRecords:
"""Get paginated sum action records."""
records, total, _ = await paginate_with_user_email(
db, SumRecord, page, per_page, _to_sum_record_response
)
return create_paginated_response(records, total, page, per_page)
@router.get("/random-jobs", response_model=list[RandomNumberOutcomeResponse])
async def get_random_job_outcomes(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_AUDIT)),
) -> list[RandomNumberOutcomeResponse]:
"""Get all random number job outcomes, newest first."""
# Explicit join to avoid N+1 query
query = (
select(RandomNumberOutcome, User.email)
.join(User, RandomNumberOutcome.triggered_by_user_id == User.id)
.order_by(desc(RandomNumberOutcome.created_at))
)
result = await db.execute(query)
rows = result.all()
return [
RandomNumberOutcomeResponse(
id=outcome.id,
job_id=outcome.job_id,
triggered_by_user_id=outcome.triggered_by_user_id,
triggered_by_email=email,
value=outcome.value,
duration_ms=outcome.duration_ms,
status=outcome.status,
created_at=outcome.created_at,
)
for outcome, email in rows
]
# =============================================================================
# Price History Endpoints
# =============================================================================

View file

@ -1,64 +0,0 @@
"""Counter routes."""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from jobs import enqueue_random_number_job
from models import Counter, CounterRecord, Permission, User
router = APIRouter(prefix="/api/counter", tags=["counter"])
async def get_or_create_counter(db: AsyncSession) -> Counter:
"""Get the singleton counter, creating it if it doesn't exist."""
result = await db.execute(select(Counter).where(Counter.id == 1))
counter = result.scalar_one_or_none()
if not counter:
counter = Counter(id=1, value=0)
db.add(counter)
await db.commit()
await db.refresh(counter)
return counter
@router.get("")
async def get_counter(
db: AsyncSession = Depends(get_db),
_current_user: User = Depends(require_permission(Permission.VIEW_COUNTER)),
) -> dict[str, int]:
"""Get the current counter value."""
counter = await get_or_create_counter(db)
return {"value": counter.value}
@router.post("/increment")
async def increment_counter(
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_permission(Permission.INCREMENT_COUNTER)),
) -> dict[str, int]:
"""Increment the counter, record the action, and enqueue a random number job."""
counter = await get_or_create_counter(db)
value_before = counter.value
counter.value += 1
record = CounterRecord(
user_id=current_user.id,
value_before=value_before,
value_after=counter.value,
)
db.add(record)
# Enqueue random number job - if this fails, the request fails
try:
await enqueue_random_number_job(current_user.id)
except Exception as e:
await db.rollback()
raise HTTPException(
status_code=500, detail=f"Failed to enqueue job: {e}"
) from e
await db.commit()
return {"value": counter.value}

View file

@ -1,30 +0,0 @@
"""Sum calculation routes."""
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from auth import require_permission
from database import get_db
from models import Permission, SumRecord, User
from schemas import SumRequest, SumResponse
router = APIRouter(prefix="/api/sum", tags=["sum"])
@router.post("", response_model=SumResponse)
async def calculate_sum(
data: SumRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(require_permission(Permission.USE_SUM)),
) -> SumResponse:
"""Calculate the sum of two numbers and record it."""
result = data.a + data.b
record = SumRecord(
user_id=current_user.id,
a=data.a,
b=data.b,
result=result,
)
db.add(record)
await db.commit()
return SumResponse(a=data.a, b=data.b, result=result)

View file

@ -37,42 +37,6 @@ class RegisterWithInvite(BaseModel):
invite_identifier: str
class SumRequest(BaseModel):
"""Request model for sum calculation."""
a: float
b: float
class SumResponse(BaseModel):
"""Response model for sum calculation."""
a: float
b: float
result: float
class CounterRecordResponse(BaseModel):
"""Response model for a counter audit record."""
id: int
user_email: str
value_before: int
value_after: int
created_at: datetime
class SumRecordResponse(BaseModel):
"""Response model for a sum audit record."""
id: int
user_email: str
a: float
b: float
result: float
created_at: datetime
RecordT = TypeVar("RecordT", bound=BaseModel)
@ -86,10 +50,6 @@ class PaginatedResponse(BaseModel, Generic[RecordT]):
total_pages: int
PaginatedCounterRecords = PaginatedResponse[CounterRecordResponse]
PaginatedSumRecords = PaginatedResponse[SumRecordResponse]
class ProfileResponse(BaseModel):
"""Response model for profile data."""
@ -258,24 +218,6 @@ class AppointmentResponse(BaseModel):
PaginatedAppointments = PaginatedResponse[AppointmentResponse]
# =============================================================================
# Random Number Job Schemas
# =============================================================================
class RandomNumberOutcomeResponse(BaseModel):
"""Response model for a random number job outcome."""
id: int
job_id: int
triggered_by_user_id: int
triggered_by_email: str
value: int
duration_ms: int
status: str
created_at: datetime
# =============================================================================
# Price History Schemas
# =============================================================================

View file

@ -1,6 +1,5 @@
import os
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch
# Set required env vars before importing app
os.environ.setdefault("SECRET_KEY", "test-secret-key-for-testing-only")
@ -239,18 +238,3 @@ async def user_no_roles(client_factory):
"cookies": dict(response.cookies),
"response": response,
}
@pytest.fixture
def mock_enqueue_job():
"""Mock job enqueueing for tests that hit the counter increment endpoint.
pgqueuer requires PostgreSQL-specific features that aren't available
in the test database setup. We mock the enqueue function to avoid
connection issues while still testing the counter logic.
Tests that call POST /api/counter/increment must use this fixture.
"""
mock = AsyncMock(return_value=1) # Return a fake job ID
with patch("routes.counter.enqueue_random_number_job", mock):
yield mock

View file

@ -1,239 +0,0 @@
"""Tests for counter endpoints.
Note: Registration now requires an invite code.
"""
import pytest
from auth import COOKIE_NAME
from models import ROLE_REGULAR
from tests.conftest import create_user_with_roles
from tests.helpers import create_invite_for_godfather, unique_email
@pytest.mark.asyncio
async def test_increment_enqueues_job_with_user_id(client_factory, mock_enqueue_job):
"""Verify that incrementing the counter enqueues a job with the user's ID."""
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email(),
"password": "testpass123",
"invite_identifier": invite_code,
},
)
cookies = dict(reg.cookies)
# Get user ID from the me endpoint
async with client_factory.create(cookies=cookies) as authed:
me_response = await authed.get("/api/auth/me")
user_id = me_response.json()["id"]
# Increment counter
response = await authed.post("/api/counter/increment")
assert response.status_code == 200
# Verify enqueue was called with the correct user_id
mock_enqueue_job.assert_called_once_with(user_id)
# Protected endpoint tests - without auth
@pytest.mark.asyncio
async def test_get_counter_requires_auth(client):
response = await client.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_requires_auth(client):
response = await client.post("/api/counter/increment")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_get_counter_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
response = await authed.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_increment_counter_invalid_cookie(client_factory):
async with client_factory.create(cookies={COOKIE_NAME: "invalidtoken"}) as authed:
response = await authed.post("/api/counter/increment")
assert response.status_code == 401
# Authenticated counter tests
@pytest.mark.asyncio
async def test_get_counter_authenticated(client_factory):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email(),
"password": "testpass123",
"invite_identifier": invite_code,
},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
response = await authed.get("/api/counter")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_increment_counter(client_factory, mock_enqueue_job):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email(),
"password": "testpass123",
"invite_identifier": invite_code,
},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
# Get current value
before = await authed.get("/api/counter")
before_value = before.json()["value"]
# Increment
response = await authed.post("/api/counter/increment")
assert response.status_code == 200
assert response.json()["value"] == before_value + 1
@pytest.mark.asyncio
async def test_increment_counter_multiple(client_factory, mock_enqueue_job):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email(),
"password": "testpass123",
"invite_identifier": invite_code,
},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
# Get starting value
before = await authed.get("/api/counter")
start = before.json()["value"]
# Increment 3 times
await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment")
response = await authed.post("/api/counter/increment")
assert response.json()["value"] == start + 3
@pytest.mark.asyncio
async def test_get_counter_after_increment(client_factory, mock_enqueue_job):
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite_code = await create_invite_for_godfather(db, godfather.id)
reg = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email(),
"password": "testpass123",
"invite_identifier": invite_code,
},
)
cookies = dict(reg.cookies)
async with client_factory.create(cookies=cookies) as authed:
before = await authed.get("/api/counter")
start = before.json()["value"]
await authed.post("/api/counter/increment")
await authed.post("/api/counter/increment")
response = await authed.get("/api/counter")
assert response.json()["value"] == start + 2
# Counter is shared between users
@pytest.mark.asyncio
async def test_counter_shared_between_users(client_factory, mock_enqueue_job):
# Create godfather and invites for two users
async with client_factory.get_db_session() as db:
godfather = await create_user_with_roles(
db, unique_email("gf"), "pass123", [ROLE_REGULAR]
)
invite1 = await create_invite_for_godfather(db, godfather.id)
invite2 = await create_invite_for_godfather(db, godfather.id)
# Create first user
reg1 = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email("share1"),
"password": "testpass123",
"invite_identifier": invite1,
},
)
cookies1 = dict(reg1.cookies)
async with client_factory.create(cookies=cookies1) as user1:
# Get starting value
before = await user1.get("/api/counter")
start = before.json()["value"]
await user1.post("/api/counter/increment")
await user1.post("/api/counter/increment")
# Create second user - should see the increments
reg2 = await client_factory.post(
"/api/auth/register",
json={
"email": unique_email("share2"),
"password": "testpass123",
"invite_identifier": invite2,
},
)
cookies2 = dict(reg2.cookies)
async with client_factory.create(cookies=cookies2) as user2:
response = await user2.get("/api/counter")
assert response.json()["value"] == start + 2
# Second user increments
await user2.post("/api/counter/increment")
# First user sees the increment
async with client_factory.create(cookies=cookies1) as user1:
response = await user1.get("/api/counter")
assert response.json()["value"] == start + 3

View file

@ -1,176 +0,0 @@
"""Tests for job handler logic."""
import json
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
import pytest
from worker import process_random_number_job
def create_mock_pool(mock_conn: AsyncMock) -> MagicMock:
"""Create a mock asyncpg pool with proper async context manager behavior."""
mock_pool = MagicMock()
@asynccontextmanager
async def mock_acquire():
yield mock_conn
mock_pool.acquire = mock_acquire
return mock_pool
class TestRandomNumberJobHandler:
"""Tests for the random number job handler logic."""
@pytest.mark.asyncio
async def test_generates_random_number_in_range(self):
"""Verify random number is in range [0, 100]."""
# Create mock job
job = MagicMock()
job.id = 123
job.payload = json.dumps({"user_id": 1}).encode()
# Create mock db pool
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
# Run the job handler
await process_random_number_job(job, mock_pool)
# Verify execute was called
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Extract the value argument (position 3 in the args)
# Args: (query, job_id, user_id, value, duration_ms, status)
value = call_args[0][3]
assert 0 <= value <= 100, f"Value {value} is not in range [0, 100]"
@pytest.mark.asyncio
async def test_stores_correct_user_id(self):
"""Verify the correct user_id is stored in the outcome."""
user_id = 42
job = MagicMock()
job.id = 123
job.payload = json.dumps({"user_id": user_id}).encode()
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Args: (query, job_id, user_id, value, duration_ms, status)
stored_user_id = call_args[0][2]
assert stored_user_id == user_id
@pytest.mark.asyncio
async def test_stores_job_id(self):
"""Verify the job_id is stored in the outcome."""
job_id = 456
job = MagicMock()
job.id = job_id
job.payload = json.dumps({"user_id": 1}).encode()
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Args: (query, job_id, user_id, value, duration_ms, status)
stored_job_id = call_args[0][1]
assert stored_job_id == job_id
@pytest.mark.asyncio
async def test_stores_status_completed(self):
"""Verify the status is set to 'completed'."""
job = MagicMock()
job.id = 123
job.payload = json.dumps({"user_id": 1}).encode()
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Args: (query, job_id, user_id, value, duration_ms, status)
status = call_args[0][5]
assert status == "completed"
@pytest.mark.asyncio
async def test_records_duration_ms(self):
"""Verify duration_ms is recorded (should be >= 0)."""
job = MagicMock()
job.id = 123
job.payload = json.dumps({"user_id": 1}).encode()
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Args: (query, job_id, user_id, value, duration_ms, status)
duration_ms = call_args[0][4]
assert isinstance(duration_ms, int)
assert duration_ms >= 0
@pytest.mark.asyncio
async def test_missing_user_id_does_not_insert(self):
"""Verify no insert happens if user_id is missing from payload."""
job = MagicMock()
job.id = 123
job.payload = json.dumps({}).encode() # Missing user_id
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
# Should not have called execute
mock_conn.execute.assert_not_called()
@pytest.mark.asyncio
async def test_empty_payload_does_not_insert(self):
"""Verify no insert happens with empty payload."""
job = MagicMock()
job.id = 123
job.payload = None
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
# Should not have called execute
mock_conn.execute.assert_not_called()
@pytest.mark.asyncio
async def test_malformed_json_payload_does_not_insert(self):
"""Verify no insert happens with malformed JSON payload."""
job = MagicMock()
job.id = 123
job.payload = b"not valid json {"
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
await process_random_number_job(job, mock_pool)
# Should not have called execute
mock_conn.execute.assert_not_called()

View file

@ -50,10 +50,10 @@ class TestRoleAssignment:
data = response.json()
permissions = data["permissions"]
# Should have counter and sum permissions
assert Permission.VIEW_COUNTER.value in permissions
assert Permission.INCREMENT_COUNTER.value in permissions
assert Permission.USE_SUM.value in permissions
# Should have profile and booking permissions
assert Permission.MANAGE_OWN_PROFILE.value in permissions
assert Permission.BOOK_APPOINTMENT.value in permissions
assert Permission.VIEW_OWN_APPOINTMENTS.value in permissions
# Should NOT have audit permission
assert Permission.VIEW_AUDIT.value not in permissions
@ -69,10 +69,8 @@ class TestRoleAssignment:
# Should have audit permission
assert Permission.VIEW_AUDIT.value in permissions
# Should NOT have counter/sum permissions
assert Permission.VIEW_COUNTER.value not in permissions
assert Permission.INCREMENT_COUNTER.value not in permissions
assert Permission.USE_SUM.value not in permissions
# Should NOT have booking permissions (those are for regular users)
assert Permission.BOOK_APPOINTMENT.value not in permissions
@pytest.mark.asyncio
async def test_user_with_no_roles_has_no_permissions(
@ -86,124 +84,6 @@ class TestRoleAssignment:
assert data["permissions"] == []
# =============================================================================
# Counter Endpoint Access Tests
# =============================================================================
class TestCounterAccess:
"""Test access control for counter endpoints."""
@pytest.mark.asyncio
async def test_regular_user_can_view_counter(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_regular_user_can_increment_counter(
self, client_factory, regular_user, mock_enqueue_job
):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post("/api/counter/increment")
assert response.status_code == 200
assert "value" in response.json()
@pytest.mark.asyncio
async def test_admin_cannot_view_counter(self, client_factory, admin_user):
"""Admin users should be forbidden from counter endpoints."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 403
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_admin_cannot_increment_counter(self, client_factory, admin_user):
"""Admin users should be forbidden from incrementing counter."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post("/api/counter/increment")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_counter(
self, client_factory, user_no_roles
):
"""Users with no roles should be forbidden."""
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/counter")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_counter(self, client):
"""Unauthenticated requests should get 401."""
response = await client.get("/api/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_unauthenticated_cannot_increment_counter(self, client):
"""Unauthenticated requests should get 401."""
response = await client.post("/api/counter/increment")
assert response.status_code == 401
# =============================================================================
# Sum Endpoint Access Tests
# =============================================================================
class TestSumAccess:
"""Test access control for sum endpoint."""
@pytest.mark.asyncio
async def test_regular_user_can_use_sum(self, client_factory, regular_user):
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 200
data = response.json()
assert data["result"] == 8
@pytest.mark.asyncio
async def test_admin_cannot_use_sum(self, client_factory, admin_user):
"""Admin users should be forbidden from sum endpoint."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_use_sum(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_use_sum(self, client):
response = await client.post(
"/api/sum",
json={"a": 5, "b": 3},
)
assert response.status_code == 401
# =============================================================================
# Audit Endpoint Access Tests
# =============================================================================
@ -213,89 +93,37 @@ class TestAuditAccess:
"""Test access control for audit endpoints."""
@pytest.mark.asyncio
async def test_admin_can_view_counter_audit(self, client_factory, admin_user):
async def test_admin_can_view_price_history(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
response = await client.get("/api/audit/price-history")
assert response.status_code == 200
data = response.json()
assert "records" in data
assert "total" in data
# Returns a list
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_admin_can_view_sum_audit(self, client_factory, admin_user):
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 200
data = response.json()
assert "records" in data
assert "total" in data
@pytest.mark.asyncio
async def test_regular_user_cannot_view_counter_audit(
async def test_regular_user_cannot_view_price_history(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
response = await client.get("/api/audit/price-history")
assert response.status_code == 403
assert "permission" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_regular_user_cannot_view_sum_audit(
self, client_factory, regular_user
):
"""Regular users should be forbidden from audit endpoints."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_user_without_roles_cannot_view_audit(
self, client_factory, user_no_roles
):
async with client_factory.create(cookies=user_no_roles["cookies"]) as client:
response = await client.get("/api/audit/counter")
response = await client.get("/api/audit/price-history")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_counter_audit(self, client):
response = await client.get("/api/audit/counter")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_sum_audit(self, client):
response = await client.get("/api/audit/sum")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_admin_can_view_random_jobs(self, client_factory, admin_user):
"""Admin should be able to view random job outcomes."""
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/random-jobs")
assert response.status_code == 200
# Returns a list (no pagination)
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_regular_user_cannot_view_random_jobs(
self, client_factory, regular_user
):
"""Regular users should be forbidden from random-jobs endpoint."""
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/random-jobs")
assert response.status_code == 403
@pytest.mark.asyncio
async def test_unauthenticated_cannot_view_random_jobs(self, client):
"""Unauthenticated users should get 401."""
response = await client.get("/api/audit/random-jobs")
async def test_unauthenticated_cannot_view_price_history(self, client):
response = await client.get("/api/audit/price-history")
assert response.status_code == 401
@ -320,18 +148,18 @@ class TestSecurityBypassAttempts:
"""
# Regular user tries to access audit endpoint
async with client_factory.create(cookies=regular_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
response = await client.get("/api/audit/price-history")
# Should be denied regardless of any manipulation attempts
assert response.status_code == 403
@pytest.mark.asyncio
async def test_cannot_access_counter_with_expired_session(self, client_factory):
async def test_cannot_access_with_expired_session(self, client_factory):
"""Test that invalid/expired tokens are rejected."""
fake_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5OTk5IiwiZXhwIjoxfQ.invalid"
async with client_factory.create(cookies={"auth_token": fake_token}) as client:
response = await client.get("/api/counter")
response = await client.get("/api/profile")
assert response.status_code == 401
@ -348,7 +176,7 @@ class TestSecurityBypassAttempts:
async with client_factory.create(
cookies={"auth_token": tampered_token}
) as client:
response = await client.get("/api/counter")
response = await client.get("/api/profile")
assert response.status_code == 401
@ -386,7 +214,7 @@ class TestSecurityBypassAttempts:
# Try to access audit with this new user
async with client_factory.create(cookies=dict(response.cookies)) as client:
audit_response = await client.get("/api/audit/counter")
audit_response = await client.get("/api/audit/price-history")
assert audit_response.status_code == 403
@ -452,10 +280,10 @@ class TestSecurityBypassAttempts:
)
cookies = dict(login_response.cookies)
# Verify can access counter but not audit
# Verify can access profile but not audit
async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/counter")).status_code == 200
assert (await client.get("/api/audit/counter")).status_code == 403
assert (await client.get("/api/profile")).status_code == 200
assert (await client.get("/api/audit/price-history")).status_code == 403
# Change user's role from regular to admin
async with client_factory.get_db_session() as db:
@ -468,62 +296,7 @@ class TestSecurityBypassAttempts:
user.roles = [admin_role] # Replace roles with admin only
await db.commit()
# Now should have audit access but not counter access
# Now should have audit access but not profile access (admin doesn't have MANAGE_OWN_PROFILE)
async with client_factory.create(cookies=cookies) as client:
assert (await client.get("/api/audit/counter")).status_code == 200
assert (await client.get("/api/counter")).status_code == 403
# =============================================================================
# Audit Record Tests
# =============================================================================
class TestAuditRecords:
"""Test that actions are properly recorded in audit logs."""
@pytest.mark.asyncio
async def test_counter_increment_creates_audit_record(
self, client_factory, regular_user, admin_user, mock_enqueue_job
):
"""Verify that counter increments are recorded and visible in audit."""
# Regular user increments counter
async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/counter/increment")
# Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/counter")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Find record for our user
records = data["records"]
user_records = [r for r in records if r["user_email"] == regular_user["email"]]
assert len(user_records) >= 1
@pytest.mark.asyncio
async def test_sum_operation_creates_audit_record(
self, client_factory, regular_user, admin_user
):
"""Verify that sum operations are recorded and visible in audit."""
# Regular user uses sum
async with client_factory.create(cookies=regular_user["cookies"]) as client:
await client.post("/api/sum", json={"a": 10, "b": 20})
# Admin checks audit
async with client_factory.create(cookies=admin_user["cookies"]) as client:
response = await client.get("/api/audit/sum")
assert response.status_code == 200
data = response.json()
assert data["total"] >= 1
# Find record with our values
records = data["records"]
matching = [
r for r in records if r["a"] == 10 and r["b"] == 20 and r["result"] == 30
]
assert len(matching) >= 1
assert (await client.get("/api/audit/price-history")).status_code == 200
assert (await client.get("/api/profile")).status_code == 403

View file

@ -1,6 +1,5 @@
"""Tests for price history feature."""
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@ -9,7 +8,6 @@ import pytest
from models import PriceHistory
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
from worker import process_bitcoin_price_job
def create_mock_httpx_client(
@ -293,76 +291,3 @@ class TestManualFetch:
data = response.json()
assert data["id"] == existing_id
assert data["price"] == 90000.0 # Original price, not the new one
def create_mock_pool(mock_conn: AsyncMock) -> MagicMock:
"""Create a mock asyncpg pool with proper async context manager behavior."""
mock_pool = MagicMock()
@asynccontextmanager
async def mock_acquire():
yield mock_conn
mock_pool.acquire = mock_acquire
return mock_pool
class TestProcessBitcoinPriceJob:
"""Tests for the scheduled Bitcoin price job handler."""
@pytest.mark.asyncio
async def test_stores_price_on_success(self):
"""Verify price is stored in database on successful fetch."""
mock_http_client = create_mock_httpx_client(
json_response=create_bitfinex_ticker_response(95000.0)
)
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
await process_bitcoin_price_job(mock_pool)
# Verify execute was called with correct values
mock_conn.execute.assert_called_once()
call_args = mock_conn.execute.call_args
# Check the SQL parameters
assert call_args[0][1] == SOURCE_BITFINEX # source
assert call_args[0][2] == PAIR_BTC_EUR # pair
assert call_args[0][3] == 95000.0 # price
@pytest.mark.asyncio
async def test_fails_silently_on_api_error(self):
"""Verify no exception is raised and no DB insert on API error."""
import httpx
error = httpx.HTTPStatusError(
"Server Error", request=MagicMock(), response=MagicMock()
)
mock_http_client = create_mock_httpx_client(raise_for_status_error=error)
mock_conn = AsyncMock()
mock_pool = create_mock_pool(mock_conn)
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
# Should not raise an exception
await process_bitcoin_price_job(mock_pool)
# Should not have called execute
mock_conn.execute.assert_not_called()
@pytest.mark.asyncio
async def test_fails_silently_on_db_error(self):
"""Verify no exception is raised on database error."""
mock_http_client = create_mock_httpx_client(
json_response=create_bitfinex_ticker_response(95000.0)
)
mock_conn = AsyncMock()
mock_conn.execute.side_effect = Exception("Database connection error")
mock_pool = create_mock_pool(mock_conn)
with patch("price_fetcher.httpx.AsyncClient", return_value=mock_http_client):
# Should not raise an exception
await process_bitcoin_price_job(mock_pool)

View file

@ -1,202 +0,0 @@
"""Background job worker using pgqueuer."""
import asyncio
import contextlib
import json
import logging
import random
import time
from datetime import UTC, datetime
import asyncpg
from pgqueuer import Job, QueueManager, SchedulerManager
from pgqueuer.db import AsyncpgDriver
from pgqueuer.models import Schedule
from pgqueuer.queries import Queries
from database import ASYNCPG_DATABASE_URL
from jobs import JOB_RANDOM_NUMBER
from price_fetcher import PAIR_BTC_EUR, SOURCE_BITFINEX, fetch_btc_eur_price
# Scheduled job type (internal to worker, not enqueued externally)
JOB_FETCH_BITCOIN_PRICE = "fetch_bitcoin_price"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("worker")
async def install_schema() -> None:
"""Install pgqueuer schema if not already present."""
conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
try:
queries = Queries.from_asyncpg_connection(conn)
# Check if schema is already installed by looking for the main table
if not await queries.has_table("pgqueuer"):
await queries.install()
logger.info("pgqueuer schema installed")
else:
logger.info("pgqueuer schema already exists")
finally:
await conn.close()
async def process_random_number_job(job: Job, db_pool: asyncpg.Pool) -> None:
"""
Process a random number job.
- Parse user_id from payload
- Generate random number 0-100
- Record execution duration
- Store outcome in database
"""
start_time = time.perf_counter()
# Parse payload
payload_str = job.payload.decode() if job.payload else "{}"
try:
payload = json.loads(payload_str)
except json.JSONDecodeError as e:
logger.error(f"Job {job.id}: Invalid JSON payload: {e}")
return
user_id = payload.get("user_id")
if user_id is None:
logger.error(f"Job {job.id}: Missing user_id in payload")
return
# Generate random number
value = random.randint(0, 100)
# Calculate duration
duration_ms = int((time.perf_counter() - start_time) * 1000)
# Store outcome
async with db_pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO random_number_outcomes
(job_id, triggered_by_user_id, value, duration_ms, status, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())
""",
job.id,
user_id,
value,
duration_ms,
"completed",
)
logger.info(
f"Job {job.id}: Generated random number {value} for user {user_id} "
f"(duration: {duration_ms}ms)"
)
def register_job_handlers(qm: QueueManager, db_pool: asyncpg.Pool) -> None:
"""Register all job handlers with the queue manager."""
@qm.entrypoint(JOB_RANDOM_NUMBER)
async def handle_random_number(job: Job) -> None:
"""Handle random_number job entrypoint."""
await process_random_number_job(job, db_pool)
async def process_bitcoin_price_job(db_pool: asyncpg.Pool) -> None:
"""
Fetch and store Bitcoin price from Bitfinex.
This function is designed to fail silently - exceptions are caught and logged
so the scheduler can continue with the next scheduled run.
"""
try:
price, timestamp = await fetch_btc_eur_price()
async with db_pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO price_history
(source, pair, price, timestamp, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (source, pair, timestamp) DO NOTHING
""",
SOURCE_BITFINEX,
PAIR_BTC_EUR,
price,
timestamp,
datetime.now(UTC),
)
logger.info(f"Fetched BTC/EUR price: €{price:.2f}")
except Exception as e:
# Fail silently - next scheduled job will continue
logger.error(f"Failed to fetch Bitcoin price: {e}")
def register_scheduled_jobs(sm: SchedulerManager, db_pool: asyncpg.Pool) -> None:
"""Register all scheduled jobs with the scheduler manager."""
# Run every minute: "* * * * *" means every minute of every hour of every day
@sm.schedule(JOB_FETCH_BITCOIN_PRICE, "* * * * *")
async def fetch_bitcoin_price(schedule: Schedule) -> None:
"""Fetch Bitcoin price from Bitfinex every minute."""
await process_bitcoin_price_job(db_pool)
async def main() -> None:
"""Main worker entry point."""
logger.info("Installing pgqueuer schema...")
await install_schema()
logger.info("Connecting to database...")
# Connection for queue manager
queue_conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
# Connection for scheduler manager
scheduler_conn = await asyncpg.connect(ASYNCPG_DATABASE_URL)
# Connection pool for application data
db_pool = await asyncpg.create_pool(ASYNCPG_DATABASE_URL, min_size=1, max_size=5)
try:
# Setup queue manager for on-demand jobs
queue_driver = AsyncpgDriver(queue_conn)
qm = QueueManager(queue_driver)
register_job_handlers(qm, db_pool)
# Setup scheduler manager for periodic jobs
scheduler_driver = AsyncpgDriver(scheduler_conn)
sm = SchedulerManager(scheduler_driver)
register_scheduled_jobs(sm, db_pool)
logger.info("Worker started, processing queue jobs and scheduled jobs...")
# Run both managers concurrently - if either fails, both stop
queue_task = asyncio.create_task(qm.run(), name="queue_manager")
scheduler_task = asyncio.create_task(sm.run(), name="scheduler_manager")
done, pending = await asyncio.wait(
[queue_task, scheduler_task],
return_when=asyncio.FIRST_EXCEPTION,
)
# Cancel any pending tasks
for task in pending:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
# Check for exceptions in completed tasks
for task in done:
exc = task.exception()
if exc is not None:
logger.error(f"Task '{task.get_name()}' failed: {exc}")
raise exc
finally:
await queue_conn.close()
await scheduler_conn.close()
await db_pool.close()
logger.info("Worker stopped")
if __name__ == "__main__":
asyncio.run(main())