golfgame/server/services/ratelimit.py
adlee-was-taken 6461a7f0c7 Add metered open signups, per-IP limits, and auth security hardening
Enables public beta signup metering: DAILY_OPEN_SIGNUPS env var controls
how many users can register without an invite code per day (0=disabled,
-1=unlimited, N=daily cap). Invite codes always bypass the limit.

Also adds per-IP signup throttling (DAILY_SIGNUPS_PER_IP, default 3/day)
and fail-closed rate limiting on auth endpoints when Redis is down.

Client dynamically fetches /api/auth/signup-info to show invite field
as optional with remaining slots when open signups are enabled.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 14:28:28 -05:00

368 lines
11 KiB
Python

"""
Redis-based rate limiter service.
Implements a fixed-window counter algorithm using Redis for distributed
rate limiting across multiple server instances (each window is one Redis
counter keyed by the window index).
"""
import hashlib
import logging
import time
from typing import Optional
import redis.asyncio as redis
from fastapi import Request, WebSocket
logger = logging.getLogger(__name__)
# Named rate-limit policies.
# Every entry maps a policy name to (max_requests, window_seconds).
RATE_LIMITS: dict[str, tuple[int, int]] = {
    # 100 requests per minute
    "api_general": (100, 60),
    # 10 auth attempts per minute
    "api_auth": (10, 60),
    # 5 room creations per minute
    "api_create_room": (5, 60),
    # 10 WS connections per minute
    "websocket_connect": (10, 60),
    # 30 messages per 10 seconds
    "websocket_message": (30, 10),
    # 3 emails per 5 minutes
    "email_send": (3, 5 * 60),
}
class RateLimiter:
    """
    Fixed-window rate limiter backed by Redis counters.

    Time is divided into consecutive windows of ``window_seconds``. Each
    (key, window) pair owns one Redis counter named
    ``ratelimit:{key}:{window_index}``; a request is allowed while the
    counter, after being incremented for that request, does not exceed the
    limit. Because windows are fixed rather than sliding, a client can
    spend a full quota at the end of one window and another full quota at
    the start of the next.
    """

    def __init__(self, redis_client: redis.Redis):
        """
        Initialize rate limiter with Redis client.

        Args:
            redis_client: Async Redis client for state storage.
        """
        self.redis = redis_client

    async def is_allowed(
        self,
        key: str,
        limit: int,
        window_seconds: int,
        *,
        fail_closed: bool = False,
    ) -> tuple[bool, dict]:
        """
        Count this request against the current window and report the result.

        The counter is incremented on every call, including calls that end
        up being denied.

        Args:
            key: Unique identifier for the rate limit bucket.
            limit: Maximum requests allowed in window.
            window_seconds: Time window in seconds.
            fail_closed: What to do when Redis raises ``RedisError``.
                False (default) allows the request; True denies it. Use
                True for security-critical paths such as auth endpoints.

        Returns:
            Tuple of (allowed, info) where info contains:
            - remaining: requests remaining in window
            - reset: seconds until window resets
            - limit: the limit that was applied
        """
        now = int(time.time())
        window_key = f"ratelimit:{key}:{now // window_seconds}"

        try:
            # INCR and EXPIRE are sent as one MULTI/EXEC transaction so a
            # counter is never left behind without a TTL.
            async with self.redis.pipeline(transaction=True) as pipe:
                pipe.incr(window_key)
                pipe.expire(window_key, window_seconds + 1)  # Extra second for safety
                results = await pipe.execute()
        except redis.RedisError as e:
            if fail_closed:
                logger.error("Rate limiter Redis error (fail-closed): %s", e)
                return False, {
                    "remaining": 0,
                    "reset": window_seconds,
                    "limit": limit,
                }
            # Fail open: availability wins over strict limiting here.
            logger.error("Rate limiter Redis error: %s", e)
            return True, {
                "remaining": limit,
                "reset": window_seconds,
                "limit": limit,
            }

        current_count = results[0]
        allowed = current_count <= limit
        if not allowed:
            logger.warning(
                "Rate limit exceeded for %s: %s/%s", key, current_count, limit
            )

        info = {
            "remaining": max(0, limit - current_count),
            "reset": window_seconds - (now % window_seconds),
            "limit": limit,
        }
        return allowed, info

    async def is_allowed_strict(
        self,
        key: str,
        limit: int,
        window_seconds: int,
    ) -> tuple[bool, dict]:
        """
        Like is_allowed but fails closed (denies) when Redis is unavailable.

        Use for security-critical paths like auth endpoints. Equivalent to
        ``is_allowed(key, limit, window_seconds, fail_closed=True)``.
        """
        return await self.is_allowed(
            key, limit, window_seconds, fail_closed=True
        )

    def get_client_key(
        self,
        request: Request | WebSocket,
        user_id: Optional[str] = None,
    ) -> str:
        """
        Generate rate limit key for client.

        Uses user ID if authenticated, otherwise hashes client IP.

        Args:
            request: HTTP request or WebSocket.
            user_id: Authenticated user ID, if available.

        Returns:
            ``user:{user_id}`` for authenticated callers, otherwise
            ``ip:{hash}`` where hash is the first 16 hex characters of the
            SHA-256 digest of the client IP.
        """
        if user_id:
            return f"user:{user_id}"

        # Anonymous callers are bucketed by IP; the address is hashed so the
        # raw IP is not stored in Redis key names.
        client_ip = self._get_client_ip(request)
        ip_hash = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
        return f"ip:{ip_hash}"

    def _get_client_ip(self, request: Request | WebSocket) -> str:
        """
        Extract client IP from request, handling proxies.

        Sources are tried in order: first entry of ``X-Forwarded-For``,
        then ``X-Real-IP``, then the direct peer address. A header that is
        present but blank is skipped instead of yielding an empty address.

        Args:
            request: HTTP request or WebSocket.

        Returns:
            Client IP address string, or "unknown" if none is available.
        """
        # NOTE(review): both headers are supplied by the client unless a
        # trusted reverse proxy overwrites them. If this service can be
        # reached directly, a caller can vary these headers to dodge
        # per-IP limits - confirm the deployment always sits behind a
        # proxy that sets them.
        forwarded = request.headers.get("X-Forwarded-For")
        if forwarded:
            # Take the first IP (original client)
            first_hop = forwarded.split(",")[0].strip()
            if first_hop:
                return first_hop

        # Check X-Real-IP header (nginx)
        real_ip = request.headers.get("X-Real-IP")
        if real_ip:
            real_ip = real_ip.strip()
            if real_ip:
                return real_ip

        # Fall back to direct connection
        if request.client:
            return request.client.host
        return "unknown"
class ConnectionMessageLimiter:
    """
    In-memory rate limiter for WebSocket message frequency.

    Used to limit messages within a single connection without
    requiring Redis round-trips for every message. Elapsed time is
    measured with the monotonic clock, so system clock adjustments can
    neither lock a connection out nor let it exceed the limit.
    """

    def __init__(self, max_messages: int = 30, window_seconds: int = 10):
        """
        Initialize connection message limiter.

        Args:
            max_messages: Maximum messages allowed in window.
            window_seconds: Time window in seconds.
        """
        self.max_messages = max_messages
        self.window_seconds = window_seconds
        # time.monotonic() readings of the accepted messages that are still
        # inside the window (not epoch timestamps).
        self.timestamps: list[float] = []

    def check(self) -> bool:
        """
        Check if another message is allowed.

        Maintains a sliding window of message timestamps: entries older
        than ``window_seconds`` are dropped, then the message is accepted
        and recorded only if fewer than ``max_messages`` remain. Rejected
        messages are not recorded.

        Returns:
            True if message is allowed, False if rate limited.
        """
        now = time.monotonic()
        cutoff = now - self.window_seconds

        # Remove old timestamps
        self.timestamps = [t for t in self.timestamps if t > cutoff]

        # Check limit
        if len(self.timestamps) >= self.max_messages:
            return False

        # Record this message
        self.timestamps.append(now)
        return True

    def reset(self) -> None:
        """Reset the limiter (e.g., on reconnection)."""
        self.timestamps = []
class SignupLimiter:
    """
    Daily signup metering for public beta.

    Tracks two counters in Redis:
    - Global daily open signups (no invite code)
    - Per-IP daily signups (with or without invite code)

    Keys are scoped to the current UTC date and carry a TTL of 24 hours
    plus one minute that is refreshed on every increment.

    NOTE(review): checking a limit and incrementing its counter are
    separate calls, so concurrent signups that all pass the check before
    any of them increments can push a counter past its cap.
    """

    def __init__(self, redis_client: redis.Redis):
        """
        Initialize signup limiter with Redis client.

        Args:
            redis_client: Async Redis client for counter storage.
        """
        self.redis = redis_client

    def _today_key(self, prefix: str) -> str:
        """Generate a Redis key scoped to today's date (UTC)."""
        from datetime import datetime, timezone

        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        return f"signup:{prefix}:{today}"

    async def _check(self, key: str, limit: int, label: str) -> tuple[bool, int]:
        """Read the counter at *key* and compare it to a positive *limit*; fails closed."""
        try:
            raw = await self.redis.get(key)
        except redis.RedisError as e:
            logger.error("Signup limiter Redis error (%s check): %s", label, e)
            return False, 0  # Fail closed
        current = int(raw) if raw else 0
        return current < limit, max(0, limit - current)

    async def _increment(self, key: str, label: str) -> None:
        """Increment the counter at *key* and refresh its TTL; errors are logged, not raised."""
        try:
            async with self.redis.pipeline(transaction=True) as pipe:
                pipe.incr(key)
                pipe.expire(key, 86400 + 60)  # 24h + 1min buffer
                await pipe.execute()
        except redis.RedisError as e:
            logger.error("Signup limiter Redis error (%s incr): %s", label, e)

    async def check_daily_limit(self, daily_limit: int) -> tuple[bool, int]:
        """
        Check if global daily open signup limit allows another registration.

        This only reads the counter; call increment_daily() after a
        successful signup to record it.

        Args:
            daily_limit: Max open signups per day. -1 = unlimited, 0 = disabled.

        Returns:
            Tuple of (allowed, remaining). remaining is -1 when unlimited.
            Returns (False, 0) if Redis is unavailable.
        """
        if daily_limit == 0:
            return False, 0
        if daily_limit < 0:
            return True, -1
        return await self._check(self._today_key("daily_open"), daily_limit, "daily")

    async def check_ip_limit(self, ip_hash: str, ip_limit: int) -> tuple[bool, int]:
        """
        Check if per-IP daily signup limit allows another registration.

        This only reads the counter; call increment_ip() after a
        successful signup to record it.

        Args:
            ip_hash: Hashed client IP.
            ip_limit: Max signups per IP per day. 0 or negative = unlimited.

        Returns:
            Tuple of (allowed, remaining). remaining is -1 when unlimited.
            Returns (False, 0) if Redis is unavailable.
        """
        if ip_limit <= 0:
            return True, -1
        return await self._check(self._today_key(f"ip:{ip_hash}"), ip_limit, "IP")

    async def increment_daily(self) -> None:
        """Increment the global daily open signup counter."""
        await self._increment(self._today_key("daily_open"), "daily")

    async def increment_ip(self, ip_hash: str) -> None:
        """Increment the per-IP daily signup counter."""
        await self._increment(self._today_key(f"ip:{ip_hash}"), "IP")

    async def get_daily_count(self) -> int:
        """
        Get current daily open signup count.

        Returns:
            Today's counter value, or 0 if the key is absent or Redis is
            unavailable (the error is logged).
        """
        key = self._today_key("daily_open")
        try:
            count = await self.redis.get(key)
        except redis.RedisError as e:
            logger.error("Signup limiter Redis error (daily count): %s", e)
            return 0
        return int(count) if count else 0
# Module-level limiter singletons. Both start as None; get_rate_limiter() and
# get_signup_limiter() populate them on first use, and close_rate_limiter()
# sets them back to None.
_rate_limiter: Optional[RateLimiter] = None
_signup_limiter: Optional[SignupLimiter] = None
async def get_rate_limiter(redis_client: redis.Redis) -> RateLimiter:
    """
    Return the shared RateLimiter, building it on first use.

    The instance is cached in a module-level variable, so ``redis_client``
    is only used by the call that actually creates the limiter; later
    calls return the cached instance unchanged.

    Args:
        redis_client: Redis client for state storage.

    Returns:
        RateLimiter instance.
    """
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = RateLimiter(redis_client)
        _rate_limiter = limiter
    return limiter
async def get_signup_limiter(redis_client: redis.Redis) -> SignupLimiter:
    """
    Return the shared SignupLimiter, building it on first use.

    The instance is cached in a module-level variable, so ``redis_client``
    is only used by the call that actually creates the limiter.

    Args:
        redis_client: Redis client for counter storage.

    Returns:
        SignupLimiter instance.
    """
    global _signup_limiter
    limiter = _signup_limiter
    if limiter is None:
        limiter = SignupLimiter(redis_client)
        _signup_limiter = limiter
    return limiter
def close_rate_limiter():
    """
    Discard the cached rate limiter and signup limiter instances.

    Only the module-level references are cleared; nothing is closed on the
    Redis client itself. The next get_* call builds a fresh instance.
    """
    global _rate_limiter, _signup_limiter
    _rate_limiter = _signup_limiter = None