170 lines
5.9 KiB
Python
170 lines
5.9 KiB
Python
"""Rate limiter and API call budget tracker.
|
|
|
|
Enforces per-platform rate limits and a global call budget so the agent
|
|
can't hammer APIs or run up unbounded costs.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
|
|
from sentiment_agent.config import RateLimitConfig
|
|
|
|
|
|
class BudgetExhaustedError(Exception):
    """Signal that the limiter's global API call budget has been used up.

    ``RateLimiter`` raises this once the number of admitted calls reaches
    ``max_total_calls``. The counter is never decremented, so every later
    acquisition on the same limiter raises it as well.
    """
|
|
|
|
|
|
class RateLimitExceededError(Exception):
    """Signal that a platform refused a call under its current limits.

    ``RateLimiter`` raises this when any of the following holds:

    - the platform is still cooling down after a recorded 429;
    - it already has ``burst_size`` requests in flight;
    - it has used its ``requests_per_minute`` allowance for the trailing 60s.
    """
|
|
|
|
|
|
@dataclass
class _PlatformState:
    """Mutable bookkeeping that ``RateLimiter`` keeps for one platform.

    Pairs the platform's configured limits with three pieces of state:

    - the recent call times;
    - the number of requests currently in flight;
    - when the last HTTP 429 was recorded.
    """

    # Limits read by the limiter: requests_per_minute, burst_size, cooldown_after_429.
    config: RateLimitConfig
    # time.monotonic() values of admitted calls; the limiter prunes entries older than 60s.
    call_timestamps: list[float] = field(default_factory=list)
    # Requests admitted but not yet released.
    active_requests: int = field(default=0)
    # time.monotonic() of the most recent 429; 0.0 means no cooldown is pending.
    last_429_at: float = field(default=0.0)
|
|
|
|
|
|
class RateLimiter:
    """Manages rate limiting across all platforms + a global call budget.

    Acquisition fails fast by raising rather than waiting. Before admitting
    a call it checks, in order:

    1. the global budget (``max_total_calls``), which is never replenished;
    2. that the platform was registered via ``register_platform``;
    3. any cooldown still active after ``record_429``;
    4. the number of in-flight requests against ``burst_size``;
    5. calls in the trailing 60 seconds against ``requests_per_minute``.

    Usage:
        limiter = RateLimiter(max_total_calls=50)
        limiter.register_platform("reddit", RateLimitConfig(...))

        async with limiter.acquire("reddit"):
            await do_reddit_call()
    """

    def __init__(self, max_total_calls: int = 50):
        self._max_total = max_total_calls
        # Calls admitted so far across all platforms; only ever incremented.
        self._total_calls = 0
        self._platforms: dict[str, _PlatformState] = {}
        # Guards the check-then-record sequences in _acquire/_release.
        self._lock = asyncio.Lock()

    @property
    def total_calls(self) -> int:
        """Number of calls admitted so far across all platforms."""
        return self._total_calls

    @property
    def remaining_calls(self) -> int:
        """Calls still allowed by the global budget (never negative)."""
        return max(0, self._max_total - self._total_calls)

    def register_platform(self, name: str, config: RateLimitConfig) -> None:
        """Register *name* with its limits.

        Registering an existing name replaces its state. This discards its
        call history, in-flight count and any pending 429 cooldown.
        """
        self._platforms[name] = _PlatformState(config=config)

    def acquire(self, platform: str) -> _AcquireContext:
        """Context manager that enforces rate limits before allowing a call."""
        return _AcquireContext(self, platform)

    async def _acquire(self, platform: str) -> None:
        """Admit one call for *platform* or raise without recording anything.

        Raises:
            BudgetExhaustedError: the global call budget is spent.
            ValueError: *platform* was never registered.
            RateLimitExceededError: the platform is cooling down after a 429,
                is at its burst limit, or is at its per-minute limit.
        """
        async with self._lock:
            if self._total_calls >= self._max_total:
                raise BudgetExhaustedError(
                    f"Global API call budget exhausted ({self._max_total} calls). "
                    "Increase max_total_api_calls in SafetyConfig to allow more."
                )

            state = self._platforms.get(platform)
            if state is None:
                raise ValueError(f"Platform '{platform}' not registered with rate limiter")

            now = time.monotonic()

            # Check 429 cooldown; once it has fully elapsed, clear it.
            if state.last_429_at:
                elapsed = now - state.last_429_at
                if elapsed < state.config.cooldown_after_429:
                    remaining = state.config.cooldown_after_429 - elapsed
                    raise RateLimitExceededError(
                        f"Platform '{platform}' is in cooldown after 429. "
                        f"Try again in {remaining:.0f}s."
                    )
                state.last_429_at = 0.0

            # Check burst limit (requests currently in flight).
            if state.active_requests >= state.config.burst_size:
                raise RateLimitExceededError(
                    f"Platform '{platform}' burst limit reached "
                    f"({state.config.burst_size} concurrent). Wait for a request to finish."
                )

            # Check RPM: discard timestamps older than 60s, then check count.
            cutoff = now - 60.0
            state.call_timestamps = [t for t in state.call_timestamps if t > cutoff]

            rpm = state.config.requests_per_minute
            if len(state.call_timestamps) >= rpm:
                if state.call_timestamps:
                    # Timestamps are appended in monotonic order, so [0] is the
                    # oldest; a slot frees up once it ages out of the 60s window.
                    wait_time = 60.0 - (now - state.call_timestamps[0])
                    hint = f"Try again in {wait_time:.0f}s."
                else:
                    # Only reachable when rpm <= 0: nothing can ever be admitted,
                    # and there is no oldest timestamp to compute a wait from
                    # (indexing [0] here used to raise IndexError).
                    hint = "No calls are allowed at this limit."
                raise RateLimitExceededError(
                    f"Platform '{platform}' rate limit: {rpm}/min. {hint}"
                )

            # All clear — record the call.
            state.call_timestamps.append(now)
            state.active_requests += 1
            self._total_calls += 1

    async def _release(self, platform: str) -> None:
        """Return one in-flight slot for *platform* (never drops below zero)."""
        async with self._lock:
            state = self._platforms.get(platform)
            if state:
                state.active_requests = max(0, state.active_requests - 1)

    def record_429(self, platform: str) -> None:
        """Call this when an API returns 429 to trigger cooldown.

        Starts, or restarts, the platform's ``cooldown_after_429`` window.
        Unregistered platforms are ignored.
        """
        state = self._platforms.get(platform)
        if state:
            state.last_429_at = time.monotonic()

    def get_stats(self) -> dict:
        """Return current usage stats for logging/reporting.

        Every per-platform figure is computed from a single
        ``time.monotonic()`` reading, so the snapshot is self-consistent.
        Stored timestamps are counted but not pruned.
        """
        now = time.monotonic()
        cutoff = now - 60.0
        platforms: dict = {}
        for name, state in self._platforms.items():
            in_cooldown = bool(
                state.last_429_at
                and (now - state.last_429_at) < state.config.cooldown_after_429
            )
            platforms[name] = {
                "calls_last_60s": sum(1 for t in state.call_timestamps if t > cutoff),
                "active_requests": state.active_requests,
                "rpm_limit": state.config.requests_per_minute,
                "in_cooldown": in_cooldown,
            }
        return {
            "total_calls": self._total_calls,
            "remaining_calls": self.remaining_calls,
            "platforms": platforms,
        }
|
|
|
|
|
|
class _AcquireContext:
    """Async context manager for rate-limited API calls.

    Entering awaits ``limiter._acquire(platform)``; any error it raises
    propagates and the body never runs. Exiting always awaits
    ``limiter._release(platform)``. If the body raised an
    ``httpx.HTTPStatusError`` with status 429, ``limiter.record_429`` is
    called first. Exceptions from the body are never suppressed.
    """

    def __init__(self, limiter: RateLimiter, platform: str):
        self._limiter = limiter
        self._platform = platform

    async def __aenter__(self) -> None:
        await self._limiter._acquire(self._platform)

    async def __aexit__(self, *exc_info) -> None:
        try:
            # Check if the call got a 429.
            if exc_info[0] is not None and self._is_http_429(exc_info[1]):
                self._limiter.record_429(self._platform)
        finally:
            # Always hand the in-flight slot back, even if the 429 check fails.
            # Otherwise active_requests leaks and the platform eventually hits
            # its burst limit forever.
            await self._limiter._release(self._platform)

    @staticmethod
    def _is_http_429(exc: BaseException | None) -> bool:
        """Return True if *exc* is an ``httpx.HTTPStatusError`` with status 429."""
        try:
            import httpx  # imported lazily, as before
        except ImportError:
            # Without httpx installed, exc cannot be an httpx error. Raising here
            # would mask the caller's real exception.
            return False
        return (
            isinstance(exc, httpx.HTTPStatusError)
            and exc.response.status_code == 429
        )
|