224 lines
6.3 KiB
Python
224 lines
6.3 KiB
Python
"""
|
|
Redis-based rate limiter service.
|
|
|
|
Implements a sliding window counter algorithm using Redis for distributed
|
|
rate limiting across multiple server instances.
|
|
"""
|
|
|
|
import hashlib
|
|
import logging
|
|
import time
|
|
from typing import Optional
|
|
|
|
import redis.asyncio as redis
|
|
from fastapi import Request, WebSocket
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Rate limit configurations: (max_requests, window_seconds)
|
|
RATE_LIMITS = {
|
|
"api_general": (100, 60), # 100 requests per minute
|
|
"api_auth": (10, 60), # 10 auth attempts per minute
|
|
"api_create_room": (5, 60), # 5 room creations per minute
|
|
"websocket_connect": (10, 60), # 10 WS connections per minute
|
|
"websocket_message": (30, 10), # 30 messages per 10 seconds
|
|
"email_send": (3, 300), # 3 emails per 5 minutes
|
|
}
|
|
|
|
|
|
class RateLimiter:
|
|
"""Token bucket rate limiter using Redis."""
|
|
|
|
def __init__(self, redis_client: redis.Redis):
|
|
"""
|
|
Initialize rate limiter with Redis client.
|
|
|
|
Args:
|
|
redis_client: Async Redis client for state storage.
|
|
"""
|
|
self.redis = redis_client
|
|
|
|
async def is_allowed(
|
|
self,
|
|
key: str,
|
|
limit: int,
|
|
window_seconds: int,
|
|
) -> tuple[bool, dict]:
|
|
"""
|
|
Check if request is allowed under rate limit.
|
|
|
|
Uses a sliding window counter algorithm:
|
|
- Divides time into fixed windows
|
|
- Counts requests in current window
|
|
- Atomically increments and checks limit
|
|
|
|
Args:
|
|
key: Unique identifier for the rate limit bucket.
|
|
limit: Maximum requests allowed in window.
|
|
window_seconds: Time window in seconds.
|
|
|
|
Returns:
|
|
Tuple of (allowed, info) where info contains:
|
|
- remaining: requests remaining in window
|
|
- reset: seconds until window resets
|
|
- limit: the limit that was applied
|
|
"""
|
|
now = int(time.time())
|
|
window_key = f"ratelimit:{key}:{now // window_seconds}"
|
|
|
|
try:
|
|
async with self.redis.pipeline(transaction=True) as pipe:
|
|
pipe.incr(window_key)
|
|
pipe.expire(window_key, window_seconds + 1) # Extra second for safety
|
|
results = await pipe.execute()
|
|
|
|
current_count = results[0]
|
|
remaining = max(0, limit - current_count)
|
|
reset = window_seconds - (now % window_seconds)
|
|
|
|
info = {
|
|
"remaining": remaining,
|
|
"reset": reset,
|
|
"limit": limit,
|
|
}
|
|
|
|
allowed = current_count <= limit
|
|
if not allowed:
|
|
logger.warning(f"Rate limit exceeded for {key}: {current_count}/{limit}")
|
|
|
|
return allowed, info
|
|
|
|
except redis.RedisError as e:
|
|
# If Redis is unavailable, fail open (allow request)
|
|
logger.error(f"Rate limiter Redis error: {e}")
|
|
return True, {"remaining": limit, "reset": window_seconds, "limit": limit}
|
|
|
|
def get_client_key(
|
|
self,
|
|
request: Request | WebSocket,
|
|
user_id: Optional[str] = None,
|
|
) -> str:
|
|
"""
|
|
Generate rate limit key for client.
|
|
|
|
Uses user ID if authenticated, otherwise hashes client IP.
|
|
|
|
Args:
|
|
request: HTTP request or WebSocket.
|
|
user_id: Authenticated user ID, if available.
|
|
|
|
Returns:
|
|
Unique client identifier string.
|
|
"""
|
|
if user_id:
|
|
return f"user:{user_id}"
|
|
|
|
# For anonymous users, use IP hash
|
|
client_ip = self._get_client_ip(request)
|
|
|
|
# Hash IP for privacy
|
|
ip_hash = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
|
|
return f"ip:{ip_hash}"
|
|
|
|
def _get_client_ip(self, request: Request | WebSocket) -> str:
|
|
"""
|
|
Extract client IP from request, handling proxies.
|
|
|
|
Args:
|
|
request: HTTP request or WebSocket.
|
|
|
|
Returns:
|
|
Client IP address string.
|
|
"""
|
|
# Check X-Forwarded-For header (from reverse proxy)
|
|
forwarded = request.headers.get("X-Forwarded-For")
|
|
if forwarded:
|
|
# Take the first IP (original client)
|
|
return forwarded.split(",")[0].strip()
|
|
|
|
# Check X-Real-IP header (nginx)
|
|
real_ip = request.headers.get("X-Real-IP")
|
|
if real_ip:
|
|
return real_ip.strip()
|
|
|
|
# Fall back to direct connection
|
|
if request.client:
|
|
return request.client.host
|
|
|
|
return "unknown"
|
|
|
|
|
|
class ConnectionMessageLimiter:
|
|
"""
|
|
In-memory rate limiter for WebSocket message frequency.
|
|
|
|
Used to limit messages within a single connection without
|
|
requiring Redis round-trips for every message.
|
|
"""
|
|
|
|
def __init__(self, max_messages: int = 30, window_seconds: int = 10):
|
|
"""
|
|
Initialize connection message limiter.
|
|
|
|
Args:
|
|
max_messages: Maximum messages allowed in window.
|
|
window_seconds: Time window in seconds.
|
|
"""
|
|
self.max_messages = max_messages
|
|
self.window_seconds = window_seconds
|
|
self.timestamps: list[float] = []
|
|
|
|
def check(self) -> bool:
|
|
"""
|
|
Check if another message is allowed.
|
|
|
|
Maintains a sliding window of message timestamps.
|
|
|
|
Returns:
|
|
True if message is allowed, False if rate limited.
|
|
"""
|
|
now = time.time()
|
|
cutoff = now - self.window_seconds
|
|
|
|
# Remove old timestamps
|
|
self.timestamps = [t for t in self.timestamps if t > cutoff]
|
|
|
|
# Check limit
|
|
if len(self.timestamps) >= self.max_messages:
|
|
return False
|
|
|
|
# Record this message
|
|
self.timestamps.append(now)
|
|
return True
|
|
|
|
def reset(self):
|
|
"""Reset the limiter (e.g., on reconnection)."""
|
|
self.timestamps = []
|
|
|
|
|
|
# Global rate limiter instance
|
|
_rate_limiter: Optional[RateLimiter] = None
|
|
|
|
|
|
async def get_rate_limiter(redis_client: redis.Redis) -> RateLimiter:
|
|
"""
|
|
Get or create the global rate limiter instance.
|
|
|
|
Args:
|
|
redis_client: Redis client for state storage.
|
|
|
|
Returns:
|
|
RateLimiter instance.
|
|
"""
|
|
global _rate_limiter
|
|
if _rate_limiter is None:
|
|
_rate_limiter = RateLimiter(redis_client)
|
|
return _rate_limiter
|
|
|
|
|
|
def close_rate_limiter():
|
|
"""Close the global rate limiter."""
|
|
global _rate_limiter
|
|
_rate_limiter = None
|