Huge v2 uplift, now deployable with real user management and tooling!

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-27 11:32:15 -05:00
parent c912a56c2d
commit bea85e6b28
61 changed files with 25153 additions and 362 deletions

View File

@@ -0,0 +1,18 @@
"""
Middleware components for Golf game server.
Provides:
- RateLimitMiddleware: API rate limiting with Redis backend
- SecurityHeadersMiddleware: Security headers (CSP, HSTS, etc.)
- RequestIDMiddleware: Request tracing with X-Request-ID
"""
from .ratelimit import RateLimitMiddleware
from .security import SecurityHeadersMiddleware
from .request_id import RequestIDMiddleware
__all__ = [
"RateLimitMiddleware",
"SecurityHeadersMiddleware",
"RequestIDMiddleware",
]

View File

@@ -0,0 +1,173 @@
"""
Rate limiting middleware for FastAPI.
Applies per-endpoint rate limits and adds X-RateLimit-* headers to responses.
"""
import logging
from typing import Callable, Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response
from services.ratelimit import RateLimiter, RATE_LIMITS
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""
HTTP middleware for rate limiting API requests.
Applies rate limits based on request path and adds standard
rate limit headers to all responses.
"""
def __init__(
self,
app,
rate_limiter: RateLimiter,
enabled: bool = True,
get_user_id: Optional[Callable[[Request], Optional[str]]] = None,
):
"""
Initialize rate limit middleware.
Args:
app: FastAPI application.
rate_limiter: RateLimiter service instance.
enabled: Whether rate limiting is enabled.
get_user_id: Optional callback to extract user ID from request.
"""
super().__init__(app)
self.limiter = rate_limiter
self.enabled = enabled
self.get_user_id = get_user_id
async def dispatch(self, request: Request, call_next) -> Response:
"""
Process request through rate limiter.
Args:
request: Incoming HTTP request.
call_next: Next middleware/handler in chain.
Returns:
HTTP response with rate limit headers.
"""
# Skip if disabled
if not self.enabled:
return await call_next(request)
# Determine rate limit tier based on path
path = request.url.path
limit_config = self._get_limit_config(path, request.method)
# No rate limiting for this endpoint
if limit_config is None:
return await call_next(request)
limit, window = limit_config
# Get user ID if authenticated
user_id = None
if self.get_user_id:
try:
user_id = self.get_user_id(request)
except Exception:
pass
# Generate client key
client_key = self.limiter.get_client_key(request, user_id)
# Check rate limit
endpoint_key = self._get_endpoint_key(path)
full_key = f"{endpoint_key}:{client_key}"
allowed, info = await self.limiter.is_allowed(full_key, limit, window)
# Build response
if allowed:
response = await call_next(request)
else:
response = JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"message": f"Too many requests. Please wait {info['reset']} seconds.",
"retry_after": info["reset"],
},
)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(info["limit"])
response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
response.headers["X-RateLimit-Reset"] = str(info["reset"])
if not allowed:
response.headers["Retry-After"] = str(info["reset"])
return response
def _get_limit_config(
self,
path: str,
method: str,
) -> Optional[tuple[int, int]]:
"""
Get rate limit configuration for a path.
Args:
path: Request URL path.
method: HTTP method.
Returns:
Tuple of (limit, window_seconds) or None for no limiting.
"""
# No rate limiting for health checks
if path in ("/health", "/ready", "/metrics"):
return None
# No rate limiting for static files
if path.endswith((".js", ".css", ".html", ".ico", ".png", ".jpg")):
return None
# Authentication endpoints - stricter limits
if path.startswith("/api/auth"):
return RATE_LIMITS["api_auth"]
# Room creation - moderate limits
if path == "/api/rooms" and method == "POST":
return RATE_LIMITS["api_create_room"]
# Email endpoints - very strict
if "email" in path or "verify" in path:
return RATE_LIMITS["email_send"]
# General API endpoints
if path.startswith("/api"):
return RATE_LIMITS["api_general"]
# Default: no rate limiting for non-API paths
return None
def _get_endpoint_key(self, path: str) -> str:
"""
Normalize path to endpoint key for rate limiting.
Groups similar endpoints together (e.g., /api/users/123 -> /api/users/:id).
Args:
path: Request URL path.
Returns:
Normalized endpoint key.
"""
# Simple normalization - strip trailing slashes
key = path.rstrip("/")
# Could add more sophisticated path parameter normalization here
# For example: /api/users/123 -> /api/users/:id
return key or "/"

View File

@@ -0,0 +1,93 @@
"""
Request ID middleware for request tracing.
Generates or propagates X-Request-ID header for distributed tracing.
"""
import logging
import uuid
from typing import Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from logging_config import request_id_var
logger = logging.getLogger(__name__)
class RequestIDMiddleware(BaseHTTPMiddleware):
"""
HTTP middleware for request ID generation and propagation.
- Extracts X-Request-ID from incoming request headers
- Generates a new UUID if not present
- Sets request_id in context var for logging
- Adds X-Request-ID to response headers
"""
def __init__(
self,
app,
header_name: str = "X-Request-ID",
generator: Optional[callable] = None,
):
"""
Initialize request ID middleware.
Args:
app: FastAPI application.
header_name: Header name for request ID.
generator: Optional custom ID generator function.
"""
super().__init__(app)
self.header_name = header_name
self.generator = generator or (lambda: str(uuid.uuid4()))
async def dispatch(self, request: Request, call_next) -> Response:
"""
Process request with request ID.
Args:
request: Incoming HTTP request.
call_next: Next middleware/handler in chain.
Returns:
HTTP response with X-Request-ID header.
"""
# Get or generate request ID
request_id = request.headers.get(self.header_name)
if not request_id:
request_id = self.generator()
# Set in request state for access in handlers
request.state.request_id = request_id
# Set in context var for logging
token = request_id_var.set(request_id)
try:
# Process request
response = await call_next(request)
# Add request ID to response
response.headers[self.header_name] = request_id
return response
finally:
# Reset context var
request_id_var.reset(token)
def get_request_id(request: Request) -> Optional[str]:
"""
Get request ID from request state.
Args:
request: FastAPI request object.
Returns:
Request ID string or None.
"""
return getattr(request.state, "request_id", None)

View File

@@ -0,0 +1,140 @@
"""
Security headers middleware for FastAPI.
Adds security headers to all responses:
- Content-Security-Policy (CSP)
- X-Content-Type-Options
- X-Frame-Options
- X-XSS-Protection
- Referrer-Policy
- Permissions-Policy
- Strict-Transport-Security (HSTS)
"""
import logging
from typing import Optional
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
logger = logging.getLogger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""
HTTP middleware for adding security headers.
Configurable CSP and HSTS settings for different environments.
"""
def __init__(
self,
app,
environment: str = "development",
csp_report_uri: Optional[str] = None,
allowed_hosts: Optional[list[str]] = None,
):
"""
Initialize security headers middleware.
Args:
app: FastAPI application.
environment: Environment name (production enables HSTS).
csp_report_uri: Optional URI for CSP violation reports.
allowed_hosts: List of allowed hosts for connect-src directive.
"""
super().__init__(app)
self.environment = environment
self.csp_report_uri = csp_report_uri
self.allowed_hosts = allowed_hosts or []
async def dispatch(self, request: Request, call_next) -> Response:
"""
Add security headers to response.
Args:
request: Incoming HTTP request.
call_next: Next middleware/handler in chain.
Returns:
HTTP response with security headers.
"""
response = await call_next(request)
# Basic security headers
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Permissions Policy (formerly Feature-Policy)
response.headers["Permissions-Policy"] = (
"geolocation=(), "
"microphone=(), "
"camera=(), "
"payment=(), "
"usb=()"
)
# Content Security Policy
csp = self._build_csp(request)
response.headers["Content-Security-Policy"] = csp
# HSTS (only in production with HTTPS)
if self.environment == "production":
# Only add HSTS if request came via HTTPS
forwarded_proto = request.headers.get("X-Forwarded-Proto", "")
if forwarded_proto == "https" or request.url.scheme == "https":
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains; preload"
)
return response
def _build_csp(self, request: Request) -> str:
"""
Build Content-Security-Policy header.
Args:
request: HTTP request (for host-specific directives).
Returns:
CSP header value string.
"""
# Get the host for WebSocket connections
host = request.headers.get("host", "localhost")
# Build connect-src directive
connect_sources = ["'self'"]
# Add WebSocket URLs
if self.environment == "production":
connect_sources.append(f"wss://{host}")
for allowed_host in self.allowed_hosts:
connect_sources.append(f"wss://{allowed_host}")
else:
# Development - allow ws:// and wss://
connect_sources.append(f"ws://{host}")
connect_sources.append(f"wss://{host}")
connect_sources.append("ws://localhost:*")
connect_sources.append("wss://localhost:*")
directives = [
"default-src 'self'",
"script-src 'self'",
# Allow inline styles for UI (cards, animations)
"style-src 'self' 'unsafe-inline'",
"img-src 'self' data:",
"font-src 'self'",
f"connect-src {' '.join(connect_sources)}",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
]
# Add report-uri if configured
if self.csp_report_uri:
directives.append(f"report-uri {self.csp_report_uri}")
return "; ".join(directives)