"""Rate limiter and API call budget tracker.

Enforces per-platform rate limits and a global call budget so the agent
can't hammer APIs or run up unbounded costs.
"""

from __future__ import annotations

import asyncio
import time
from dataclasses import dataclass, field

from sentiment_agent.config import RateLimitConfig

# Length, in seconds, of the sliding window used for requests-per-minute checks.
_RPM_WINDOW_SECONDS = 60.0


class BudgetExhaustedError(Exception):
    """Raised when the global API call budget is spent."""


class RateLimitExceededError(Exception):
    """Raised when a platform's rate limit is hit and cooldown hasn't elapsed."""


@dataclass
class _PlatformState:
    """Tracks call timestamps and active request count for one platform."""

    # Limits for this platform; this module reads `requests_per_minute`,
    # `burst_size` and `cooldown_after_429` from it.
    config: RateLimitConfig
    # `time.monotonic()` values of admitted calls, appended in call order.
    call_timestamps: list[float] = field(default_factory=list)
    # Calls that have been admitted but not yet released.
    active_requests: int = 0
    # `time.monotonic()` value of the last recorded 429; 0.0 means "none".
    last_429_at: float = 0.0


class RateLimiter:
    """Manages rate limiting across all platforms + a global call budget.

    Limits are enforced by raising, not by waiting: a call that is not
    allowed right now fails immediately with `RateLimitExceededError` or
    `BudgetExhaustedError`.

    Usage:
        limiter = RateLimiter(max_total_calls=50)
        limiter.register_platform("reddit", RateLimitConfig(...))

        async with limiter.acquire("reddit"):
            await do_reddit_call()
    """

    def __init__(self, max_total_calls: int = 50):
        """Create a limiter with a global budget of `max_total_calls` calls."""
        self._max_total = max_total_calls
        self._total_calls = 0
        self._platforms: dict[str, _PlatformState] = {}
        # Serializes the check-and-record step in `_acquire` and `_release`.
        self._lock = asyncio.Lock()

    @property
    def total_calls(self) -> int:
        """Number of calls admitted so far, across all platforms."""
        return self._total_calls

    @property
    def remaining_calls(self) -> int:
        """Calls left in the global budget, never negative."""
        return max(0, self._max_total - self._total_calls)

    def register_platform(self, name: str, config: RateLimitConfig) -> None:
        """Register `name` with fresh state; re-registering resets its state."""
        self._platforms[name] = _PlatformState(config=config)

    def acquire(self, platform: str) -> _AcquireContext:
        """Context manager that enforces rate limits before allowing a call."""
        return _AcquireContext(self, platform)

    async def _acquire(self, platform: str) -> None:
        """Admit one call for `platform` or raise if it is not allowed now.

        Checks run in this order: global budget, platform registration,
        429 cooldown, concurrent (burst) limit, requests-per-minute window.
        Only when all pass is the call recorded.

        Raises:
            BudgetExhaustedError: the global call budget is spent.
            ValueError: `platform` was never registered.
            RateLimitExceededError: the platform is in cooldown, at its
                burst limit, or at its per-minute limit.
        """
        async with self._lock:
            if self._total_calls >= self._max_total:
                raise BudgetExhaustedError(
                    f"Global API call budget exhausted ({self._max_total} calls). "
                    "Increase max_total_api_calls in SafetyConfig to allow more."
                )

            state = self._platforms.get(platform)
            if state is None:
                raise ValueError(f"Platform '{platform}' not registered with rate limiter")

            now = time.monotonic()

            # Check 429 cooldown
            if state.last_429_at:
                elapsed = now - state.last_429_at
                if elapsed < state.config.cooldown_after_429:
                    remaining = state.config.cooldown_after_429 - elapsed
                    raise RateLimitExceededError(
                        f"Platform '{platform}' is in cooldown after 429. "
                        f"Try again in {remaining:.0f}s."
                    )
                # Cooldown has elapsed: clear the marker.
                state.last_429_at = 0.0

            # Check burst limit
            if state.active_requests >= state.config.burst_size:
                raise RateLimitExceededError(
                    f"Platform '{platform}' burst limit reached "
                    f"({state.config.burst_size} concurrent). Wait for a request to finish."
                )

            # Check RPM: discard timestamps older than 60s, then check count
            cutoff = now - _RPM_WINDOW_SECONDS
            state.call_timestamps = [t for t in state.call_timestamps if t > cutoff]
            if len(state.call_timestamps) >= state.config.requests_per_minute:
                # Timestamps are appended in order, so index 0 is the oldest.
                oldest = state.call_timestamps[0]
                wait_time = _RPM_WINDOW_SECONDS - (now - oldest)
                raise RateLimitExceededError(
                    f"Platform '{platform}' rate limit: {state.config.requests_per_minute}/min. "
                    f"Try again in {wait_time:.0f}s."
                )

            # All clear — record the call
            state.call_timestamps.append(now)
            state.active_requests += 1
            self._total_calls += 1

    async def _release(self, platform: str) -> None:
        """Mark one admitted call for `platform` as finished.

        Unknown platforms are ignored, and the counter never drops below 0.
        """
        async with self._lock:
            state = self._platforms.get(platform)
            if state:
                state.active_requests = max(0, state.active_requests - 1)

    def record_429(self, platform: str) -> None:
        """Call this when an API returns 429 to trigger cooldown."""
        state = self._platforms.get(platform)
        if state:
            state.last_429_at = time.monotonic()

    def get_stats(self) -> dict:
        """Return current usage stats for logging/reporting."""
        stats: dict = {
            "total_calls": self._total_calls,
            "remaining_calls": self.remaining_calls,
            "platforms": {},
        }
        # One clock reading so every platform is measured against the same instant.
        now = time.monotonic()
        cutoff = now - _RPM_WINDOW_SECONDS
        for name, state in self._platforms.items():
            recent = [t for t in state.call_timestamps if t > cutoff]
            stats["platforms"][name] = {
                "calls_last_60s": len(recent),
                "active_requests": state.active_requests,
                "rpm_limit": state.config.requests_per_minute,
                "in_cooldown": bool(
                    state.last_429_at
                    and (now - state.last_429_at) < state.config.cooldown_after_429
                ),
            }
        return stats


class _AcquireContext:
    """Async context manager for rate-limited API calls."""

    def __init__(self, limiter: RateLimiter, platform: str):
        self._limiter = limiter
        self._platform = platform

    async def __aenter__(self) -> None:
        """Admit the call, raising if the limiter refuses it."""
        await self._limiter._acquire(self._platform)

    async def __aexit__(self, *exc_info) -> None:
        """Record a 429 if the call raised one, then always release the slot.

        Returns None, so an exception raised inside the `async with` body
        is never suppressed.
        """
        try:
            # Check if the call got a 429
            if exc_info[0] is not None and self._is_http_429(exc_info[1]):
                self._limiter.record_429(self._platform)
        finally:
            # Release even if the 429 check itself fails; otherwise the
            # platform's active-request count would stay raised forever.
            await self._limiter._release(self._platform)

    @staticmethod
    def _is_http_429(exc: BaseException | None) -> bool:
        """Return True if `exc` is an httpx status error with code 429."""
        try:
            # Imported lazily so httpx is only needed when a call fails.
            import httpx
        except ImportError:
            # Without httpx the error cannot be an httpx 429.
            return False
        return (
            isinstance(exc, httpx.HTTPStatusError)
            and exc.response.status_code == 429
        )