Minor fixes

This commit is contained in:
adlee-was-taken
2026-04-04 16:29:20 -04:00
parent 05382c4081
commit 4607ff27dd
23 changed files with 1772 additions and 28 deletions

View File

@@ -0,0 +1,169 @@
"""Rate limiter and API call budget tracker.
Enforces per-platform rate limits and a global call budget so the agent
can't hammer APIs or run up unbounded costs.
"""
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass, field
from sentiment_agent.config import RateLimitConfig
class BudgetExhaustedError(Exception):
"""Raised when the global API call budget is spent."""
class RateLimitExceededError(Exception):
"""Raised when a platform's rate limit is hit and cooldown hasn't elapsed."""
@dataclass
class _PlatformState:
"""Tracks call timestamps and active request count for one platform."""
config: RateLimitConfig
call_timestamps: list[float] = field(default_factory=list)
active_requests: int = 0
last_429_at: float = 0.0
class RateLimiter:
"""Manages rate limiting across all platforms + a global call budget.
Usage:
limiter = RateLimiter(max_total_calls=50)
limiter.register_platform("reddit", RateLimitConfig(...))
async with limiter.acquire("reddit"):
await do_reddit_call()
"""
def __init__(self, max_total_calls: int = 50):
self._max_total = max_total_calls
self._total_calls = 0
self._platforms: dict[str, _PlatformState] = {}
self._lock = asyncio.Lock()
@property
def total_calls(self) -> int:
return self._total_calls
@property
def remaining_calls(self) -> int:
return max(0, self._max_total - self._total_calls)
def register_platform(self, name: str, config: RateLimitConfig) -> None:
self._platforms[name] = _PlatformState(config=config)
def acquire(self, platform: str) -> _AcquireContext:
"""Context manager that enforces rate limits before allowing a call."""
return _AcquireContext(self, platform)
async def _acquire(self, platform: str) -> None:
async with self._lock:
if self._total_calls >= self._max_total:
raise BudgetExhaustedError(
f"Global API call budget exhausted ({self._max_total} calls). "
"Increase max_total_api_calls in SafetyConfig to allow more."
)
state = self._platforms.get(platform)
if not state:
raise ValueError(f"Platform '{platform}' not registered with rate limiter")
now = time.monotonic()
# Check 429 cooldown
if state.last_429_at:
elapsed = now - state.last_429_at
if elapsed < state.config.cooldown_after_429:
remaining = state.config.cooldown_after_429 - elapsed
raise RateLimitExceededError(
f"Platform '{platform}' is in cooldown after 429. "
f"Try again in {remaining:.0f}s."
)
state.last_429_at = 0.0
# Check burst limit
if state.active_requests >= state.config.burst_size:
raise RateLimitExceededError(
f"Platform '{platform}' burst limit reached "
f"({state.config.burst_size} concurrent). Wait for a request to finish."
)
# Check RPM: discard timestamps older than 60s, then check count
cutoff = now - 60.0
state.call_timestamps = [t for t in state.call_timestamps if t > cutoff]
if len(state.call_timestamps) >= state.config.requests_per_minute:
oldest = state.call_timestamps[0]
wait_time = 60.0 - (now - oldest)
raise RateLimitExceededError(
f"Platform '{platform}' rate limit: {state.config.requests_per_minute}/min. "
f"Try again in {wait_time:.0f}s."
)
# All clear — record the call
state.call_timestamps.append(now)
state.active_requests += 1
self._total_calls += 1
async def _release(self, platform: str) -> None:
async with self._lock:
state = self._platforms.get(platform)
if state:
state.active_requests = max(0, state.active_requests - 1)
def record_429(self, platform: str) -> None:
"""Call this when an API returns 429 to trigger cooldown."""
state = self._platforms.get(platform)
if state:
state.last_429_at = time.monotonic()
def get_stats(self) -> dict:
"""Return current usage stats for logging/reporting."""
stats: dict = {
"total_calls": self._total_calls,
"remaining_calls": self.remaining_calls,
"platforms": {},
}
for name, state in self._platforms.items():
now = time.monotonic()
cutoff = now - 60.0
recent = [t for t in state.call_timestamps if t > cutoff]
stats["platforms"][name] = {
"calls_last_60s": len(recent),
"active_requests": state.active_requests,
"rpm_limit": state.config.requests_per_minute,
"in_cooldown": bool(
state.last_429_at
and (now - state.last_429_at) < state.config.cooldown_after_429
),
}
return stats
class _AcquireContext:
"""Async context manager for rate-limited API calls."""
def __init__(self, limiter: RateLimiter, platform: str):
self._limiter = limiter
self._platform = platform
async def __aenter__(self) -> None:
await self._limiter._acquire(self._platform)
async def __aexit__(self, *exc_info) -> None:
# Check if the call got a 429
if exc_info[0] is not None:
import httpx
exc = exc_info[1]
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429:
self._limiter.record_429(self._platform)
await self._limiter._release(self._platform)