Minor fixes
This commit is contained in:
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
169
agentstuff/sentiment_agent/ratelimit.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Rate limiter and API call budget tracker.

Enforces per-platform rate limits and a global call budget so the agent
can't hammer APIs or run up unbounded costs.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
import math
import time
from dataclasses import dataclass, field
|
||||
|
||||
from sentiment_agent.config import RateLimitConfig
|
||||
|
||||
|
||||
class BudgetExhaustedError(Exception):
    """Signal that no calls remain in the global API call budget."""
|
||||
|
||||
|
||||
class RateLimitExceededError(Exception):
    """Signal that a platform limit was hit and its cooldown is still running."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PlatformState:
|
||||
"""Tracks call timestamps and active request count for one platform."""
|
||||
|
||||
config: RateLimitConfig
|
||||
call_timestamps: list[float] = field(default_factory=list)
|
||||
active_requests: int = 0
|
||||
last_429_at: float = 0.0
|
||||
|
||||
|
||||
class RateLimiter:
    """Manages rate limiting across all platforms + a global call budget.

    Each successful acquisition consumes one unit of the global budget and
    is checked against three per-platform limits, in this order: the
    cooldown that follows a recorded 429, the number of concurrently active
    requests (``burst_size``) and the number of calls recorded in the last
    60 seconds (``requests_per_minute``). A limit that is hit raises
    immediately; the limiter never sleeps or queues.

    Usage:
        limiter = RateLimiter(max_total_calls=50)
        limiter.register_platform("reddit", RateLimitConfig(...))

        async with limiter.acquire("reddit"):
            await do_reddit_call()
    """

    # Length of the sliding window used for requests-per-minute accounting.
    _WINDOW_SECONDS = 60.0

    def __init__(self, max_total_calls: int = 50):
        """Create a limiter that allows at most ``max_total_calls`` acquisitions."""
        self._max_total = max_total_calls
        self._total_calls = 0
        self._platforms: dict[str, _PlatformState] = {}
        self._lock = asyncio.Lock()

    @property
    def total_calls(self) -> int:
        """Number of acquisitions granted so far."""
        return self._total_calls

    @property
    def remaining_calls(self) -> int:
        """Acquisitions left in the global budget (never negative)."""
        return max(0, self._max_total - self._total_calls)

    def register_platform(self, name: str, config: RateLimitConfig) -> None:
        """Start tracking ``name`` under ``config``.

        Registering an existing name replaces its state with a fresh one.
        """
        self._platforms[name] = _PlatformState(config=config)

    def acquire(self, platform: str) -> _AcquireContext:
        """Context manager that enforces rate limits before allowing a call."""
        return _AcquireContext(self, platform)

    async def _acquire(self, platform: str) -> None:
        """Check every limit for ``platform`` and, if all pass, record the call.

        Raises:
            ValueError: ``platform`` was never registered.
            BudgetExhaustedError: the global call budget is spent.
            RateLimitExceededError: the platform is in 429 cooldown, at its
                burst limit, or at its requests-per-minute limit.
        """
        async with self._lock:
            # Validate the platform first so a misconfigured caller gets a
            # ValueError even after the budget has run out.
            state = self._platforms.get(platform)
            if state is None:
                raise ValueError(f"Platform '{platform}' not registered with rate limiter")

            if self._total_calls >= self._max_total:
                raise BudgetExhaustedError(
                    f"Global API call budget exhausted ({self._max_total} calls). "
                    "Increase max_total_api_calls in SafetyConfig to allow more."
                )

            config = state.config
            now = time.monotonic()

            # Check 429 cooldown
            if state.last_429_at:
                elapsed = now - state.last_429_at
                if elapsed < config.cooldown_after_429:
                    remaining = config.cooldown_after_429 - elapsed
                    # Round up: rounding to nearest could advertise "0s"
                    # while the cooldown is still running.
                    raise RateLimitExceededError(
                        f"Platform '{platform}' is in cooldown after 429. "
                        f"Try again in {math.ceil(remaining)}s."
                    )
                # Cooldown has elapsed; clear the marker.
                state.last_429_at = 0.0

            # Check burst limit
            if state.active_requests >= config.burst_size:
                raise RateLimitExceededError(
                    f"Platform '{platform}' burst limit reached "
                    f"({config.burst_size} concurrent). Wait for a request to finish."
                )

            # Check RPM: discard timestamps older than the window, then count
            cutoff = now - self._WINDOW_SECONDS
            state.call_timestamps = [t for t in state.call_timestamps if t > cutoff]

            if len(state.call_timestamps) >= config.requests_per_minute:
                limit_text = (
                    f"Platform '{platform}' rate limit: {config.requests_per_minute}/min. "
                )
                if not state.call_timestamps:
                    # Only reachable when the configured limit is <= 0; there
                    # is no oldest call to wait out.
                    raise RateLimitExceededError(limit_text + "No calls are permitted.")
                # Timestamps are appended in increasing order, so index 0 is
                # the oldest call still inside the window.
                wait_time = self._WINDOW_SECONDS - (now - state.call_timestamps[0])
                raise RateLimitExceededError(
                    limit_text + f"Try again in {math.ceil(wait_time)}s."
                )

            # All clear — record the call
            state.call_timestamps.append(now)
            state.active_requests += 1
            self._total_calls += 1

    async def _release(self, platform: str) -> None:
        """Mark one in-flight request for ``platform`` as finished.

        Unknown platforms are ignored and the counter never drops below zero.
        """
        async with self._lock:
            state = self._platforms.get(platform)
            if state is not None:
                state.active_requests = max(0, state.active_requests - 1)

    def record_429(self, platform: str) -> None:
        """Call this when an API returns 429 to trigger cooldown.

        Unknown platforms are ignored.
        """
        state = self._platforms.get(platform)
        if state is not None:
            state.last_429_at = time.monotonic()

    def get_stats(self) -> dict:
        """Return current usage stats for logging/reporting.

        The result has ``total_calls``, ``remaining_calls`` and a
        ``platforms`` mapping with one entry per registered platform.
        """
        # Take one clock reading so every platform is measured at the same instant.
        now = time.monotonic()
        cutoff = now - self._WINDOW_SECONDS

        platforms: dict = {}
        for name, state in self._platforms.items():
            cooling_down = bool(
                state.last_429_at
                and (now - state.last_429_at) < state.config.cooldown_after_429
            )
            platforms[name] = {
                "calls_last_60s": sum(1 for t in state.call_timestamps if t > cutoff),
                "active_requests": state.active_requests,
                "rpm_limit": state.config.requests_per_minute,
                "in_cooldown": cooling_down,
            }

        return {
            "total_calls": self._total_calls,
            "remaining_calls": self.remaining_calls,
            "platforms": platforms,
        }
|
||||
|
||||
|
||||
class _AcquireContext:
|
||||
"""Async context manager for rate-limited API calls."""
|
||||
|
||||
def __init__(self, limiter: RateLimiter, platform: str):
|
||||
self._limiter = limiter
|
||||
self._platform = platform
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self._limiter._acquire(self._platform)
|
||||
|
||||
async def __aexit__(self, *exc_info) -> None:
|
||||
# Check if the call got a 429
|
||||
if exc_info[0] is not None:
|
||||
import httpx
|
||||
|
||||
exc = exc_info[1]
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code == 429:
|
||||
self._limiter.record_429(self._platform)
|
||||
|
||||
await self._limiter._release(self._platform)
|
||||
Reference in New Issue
Block a user