Huge v2 uplift, now deployable with real user management and tooling!

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Aaron D. Lee
2026-01-27 11:32:15 -05:00
parent c912a56c2d
commit bea85e6b28
61 changed files with 25153 additions and 362 deletions

View File

@@ -0,0 +1,33 @@
"""Services package for Golf game V2 business logic."""
from .recovery_service import RecoveryService, RecoveryResult
from .email_service import EmailService, get_email_service
from .auth_service import AuthService, AuthResult, RegistrationResult, get_auth_service, close_auth_service
from .admin_service import (
AdminService,
UserDetails,
AuditEntry,
SystemStats,
InviteCode,
get_admin_service,
close_admin_service,
)
__all__ = [
"RecoveryService",
"RecoveryResult",
"EmailService",
"get_email_service",
"AuthService",
"AuthResult",
"RegistrationResult",
"get_auth_service",
"close_auth_service",
"AdminService",
"UserDetails",
"AuditEntry",
"SystemStats",
"InviteCode",
"get_admin_service",
"close_admin_service",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,654 @@
"""
Authentication service for Golf game.
Provides business logic for user registration, login, password management,
and session handling.
"""
import logging
import secrets
from dataclasses import dataclass
from datetime import datetime, timezone, timedelta
from typing import Optional
import bcrypt
from config import config
from models.user import User, UserRole, UserSession, GuestSession
from stores.user_store import UserStore
from services.email_service import EmailService
logger = logging.getLogger(__name__)
@dataclass
class AuthResult:
"""Result of an authentication operation."""
success: bool
user: Optional[User] = None
token: Optional[str] = None
expires_at: Optional[datetime] = None
error: Optional[str] = None
@dataclass
class RegistrationResult:
"""Result of a registration operation."""
success: bool
user: Optional[User] = None
requires_verification: bool = False
error: Optional[str] = None
class AuthService:
"""
Authentication service.
Handles all authentication business logic:
- User registration with optional email verification
- Login/logout with session management
- Password reset flow
- Guest-to-user conversion
- Account deletion (soft delete)
"""
def __init__(
self,
user_store: UserStore,
email_service: EmailService,
session_expiry_hours: int = 168,
require_email_verification: bool = False,
):
"""
Initialize auth service.
Args:
user_store: User persistence store.
email_service: Email sending service.
session_expiry_hours: Session lifetime in hours.
require_email_verification: Whether to require email verification.
"""
self.user_store = user_store
self.email_service = email_service
self.session_expiry_hours = session_expiry_hours
self.require_email_verification = require_email_verification
@classmethod
async def create(cls, user_store: UserStore) -> "AuthService":
"""
Create AuthService from config.
Args:
user_store: User persistence store.
"""
from services.email_service import get_email_service
return cls(
user_store=user_store,
email_service=get_email_service(),
session_expiry_hours=config.SESSION_EXPIRY_HOURS,
require_email_verification=config.REQUIRE_EMAIL_VERIFICATION,
)
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
async def register(
self,
username: str,
password: str,
email: Optional[str] = None,
guest_id: Optional[str] = None,
) -> RegistrationResult:
"""
Register a new user account.
Args:
username: Desired username.
password: Plain text password.
email: Optional email address.
guest_id: Guest session ID if converting.
Returns:
RegistrationResult with user or error.
"""
# Validate inputs
if len(username) < 2 or len(username) > 50:
return RegistrationResult(success=False, error="Username must be 2-50 characters")
if len(password) < 8:
return RegistrationResult(success=False, error="Password must be at least 8 characters")
# Check for existing username
existing = await self.user_store.get_user_by_username(username)
if existing:
return RegistrationResult(success=False, error="Username already taken")
# Check for existing email
if email:
existing = await self.user_store.get_user_by_email(email)
if existing:
return RegistrationResult(success=False, error="Email already registered")
# Hash password
password_hash = self._hash_password(password)
# Generate verification token if needed
verification_token = None
verification_expires = None
if email and self.require_email_verification:
verification_token = secrets.token_urlsafe(32)
verification_expires = datetime.now(timezone.utc) + timedelta(hours=24)
# Create user
user = await self.user_store.create_user(
username=username,
password_hash=password_hash,
email=email,
role=UserRole.USER,
guest_id=guest_id,
verification_token=verification_token,
verification_expires=verification_expires,
)
if not user:
return RegistrationResult(success=False, error="Failed to create account")
# Mark guest as converted if applicable
if guest_id:
await self.user_store.mark_guest_converted(guest_id, user.id)
# Send verification email if needed
requires_verification = False
if email and self.require_email_verification and verification_token:
await self.email_service.send_verification_email(
to=email,
token=verification_token,
username=username,
)
await self.user_store.log_email(user.id, "verification", email)
requires_verification = True
return RegistrationResult(
success=True,
user=user,
requires_verification=requires_verification,
)
async def verify_email(self, token: str) -> AuthResult:
"""
Verify email with token.
Args:
token: Verification token from email.
Returns:
AuthResult with success status.
"""
user = await self.user_store.get_user_by_verification_token(token)
if not user:
return AuthResult(success=False, error="Invalid verification token")
# Check expiration
if user.verification_expires and user.verification_expires < datetime.now(timezone.utc):
return AuthResult(success=False, error="Verification token expired")
# Mark as verified
await self.user_store.clear_verification_token(user.id)
# Refresh user
user = await self.user_store.get_user_by_id(user.id)
return AuthResult(success=True, user=user)
async def resend_verification(self, email: str) -> bool:
"""
Resend verification email.
Args:
email: Email address to send to.
Returns:
True if email was sent.
"""
user = await self.user_store.get_user_by_email(email)
if not user or user.email_verified:
return False
# Generate new token
verification_token = secrets.token_urlsafe(32)
verification_expires = datetime.now(timezone.utc) + timedelta(hours=24)
await self.user_store.update_user(
user.id,
verification_token=verification_token,
verification_expires=verification_expires,
)
await self.email_service.send_verification_email(
to=email,
token=verification_token,
username=user.username,
)
await self.user_store.log_email(user.id, "verification", email)
return True
# -------------------------------------------------------------------------
# Login/Logout
# -------------------------------------------------------------------------
async def login(
self,
username: str,
password: str,
device_info: Optional[dict] = None,
ip_address: Optional[str] = None,
) -> AuthResult:
"""
Authenticate user and create session.
Args:
username: Username or email.
password: Plain text password.
device_info: Client device information.
ip_address: Client IP address.
Returns:
AuthResult with session token or error.
"""
# Try username first, then email
user = await self.user_store.get_user_by_username(username)
if not user:
user = await self.user_store.get_user_by_email(username)
if not user:
return AuthResult(success=False, error="Invalid credentials")
if not user.can_login():
return AuthResult(success=False, error="Account is disabled")
# Check email verification if required
if self.require_email_verification and user.email and not user.email_verified:
return AuthResult(success=False, error="Please verify your email first")
# Verify password
if not self._verify_password(password, user.password_hash):
return AuthResult(success=False, error="Invalid credentials")
# Create session
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(hours=self.session_expiry_hours)
await self.user_store.create_session(
user_id=user.id,
token=token,
expires_at=expires_at,
device_info=device_info,
ip_address=ip_address,
)
# Update last login
await self.user_store.update_user(user.id, last_login=datetime.now(timezone.utc))
return AuthResult(
success=True,
user=user,
token=token,
expires_at=expires_at,
)
async def logout(self, token: str) -> bool:
"""
Invalidate a session.
Args:
token: Session token to invalidate.
Returns:
True if session was revoked.
"""
return await self.user_store.revoke_session_by_token(token)
async def logout_all(self, user_id: str, except_token: Optional[str] = None) -> int:
"""
Invalidate all sessions for a user.
Args:
user_id: User ID.
except_token: Optional token to keep active.
Returns:
Number of sessions revoked.
"""
return await self.user_store.revoke_all_sessions(user_id, except_token)
async def get_user_from_token(self, token: str) -> Optional[User]:
"""
Get user from session token.
Args:
token: Session token.
Returns:
User if valid session, None otherwise.
"""
session = await self.user_store.get_session_by_token(token)
if not session or not session.is_valid():
return None
# Update last used
await self.user_store.update_session_last_used(session.id)
user = await self.user_store.get_user_by_id(session.user_id)
if not user or not user.can_login():
return None
return user
# -------------------------------------------------------------------------
# Password Management
# -------------------------------------------------------------------------
async def forgot_password(self, email: str) -> bool:
"""
Initiate password reset flow.
Args:
email: Email address.
Returns:
True if reset email was sent (always returns True to prevent enumeration).
"""
user = await self.user_store.get_user_by_email(email)
if not user:
# Don't reveal if email exists
return True
# Generate reset token
reset_token = secrets.token_urlsafe(32)
reset_expires = datetime.now(timezone.utc) + timedelta(hours=1)
await self.user_store.update_user(
user.id,
reset_token=reset_token,
reset_expires=reset_expires,
)
await self.email_service.send_password_reset_email(
to=email,
token=reset_token,
username=user.username,
)
await self.user_store.log_email(user.id, "password_reset", email)
return True
async def reset_password(self, token: str, new_password: str) -> AuthResult:
"""
Reset password using token.
Args:
token: Reset token from email.
new_password: New password.
Returns:
AuthResult with success status.
"""
if len(new_password) < 8:
return AuthResult(success=False, error="Password must be at least 8 characters")
user = await self.user_store.get_user_by_reset_token(token)
if not user:
return AuthResult(success=False, error="Invalid reset token")
# Check expiration
if user.reset_expires and user.reset_expires < datetime.now(timezone.utc):
return AuthResult(success=False, error="Reset token expired")
# Update password
password_hash = self._hash_password(new_password)
await self.user_store.update_user(user.id, password_hash=password_hash)
await self.user_store.clear_reset_token(user.id)
# Revoke all sessions
await self.user_store.revoke_all_sessions(user.id)
# Send notification
if user.email:
await self.email_service.send_password_changed_notification(
to=user.email,
username=user.username,
)
await self.user_store.log_email(user.id, "password_changed", user.email)
return AuthResult(success=True, user=user)
async def change_password(
self,
user_id: str,
current_password: str,
new_password: str,
current_token: Optional[str] = None,
) -> AuthResult:
"""
Change password for authenticated user.
Args:
user_id: User ID.
current_password: Current password for verification.
new_password: New password.
current_token: Current session token to keep active.
Returns:
AuthResult with success status.
"""
if len(new_password) < 8:
return AuthResult(success=False, error="Password must be at least 8 characters")
user = await self.user_store.get_user_by_id(user_id)
if not user:
return AuthResult(success=False, error="User not found")
# Verify current password
if not self._verify_password(current_password, user.password_hash):
return AuthResult(success=False, error="Current password is incorrect")
# Update password
password_hash = self._hash_password(new_password)
await self.user_store.update_user(user.id, password_hash=password_hash)
# Revoke all sessions except current
await self.user_store.revoke_all_sessions(user.id, except_token=current_token)
# Send notification
if user.email:
await self.email_service.send_password_changed_notification(
to=user.email,
username=user.username,
)
await self.user_store.log_email(user.id, "password_changed", user.email)
return AuthResult(success=True, user=user)
# -------------------------------------------------------------------------
# User Profile
# -------------------------------------------------------------------------
async def update_preferences(self, user_id: str, preferences: dict) -> Optional[User]:
"""
Update user preferences.
Args:
user_id: User ID.
preferences: New preferences dict.
Returns:
Updated user or None.
"""
return await self.user_store.update_user(user_id, preferences=preferences)
async def get_sessions(self, user_id: str) -> list[UserSession]:
"""
Get all active sessions for a user.
Args:
user_id: User ID.
Returns:
List of active sessions.
"""
return await self.user_store.get_sessions_for_user(user_id)
async def revoke_session(self, user_id: str, session_id: str) -> bool:
"""
Revoke a specific session.
Args:
user_id: User ID (for authorization).
session_id: Session ID to revoke.
Returns:
True if session was revoked.
"""
# Verify session belongs to user
sessions = await self.user_store.get_sessions_for_user(user_id)
if not any(s.id == session_id for s in sessions):
return False
return await self.user_store.revoke_session(session_id)
# -------------------------------------------------------------------------
# Guest Conversion
# -------------------------------------------------------------------------
async def convert_guest(
self,
guest_id: str,
username: str,
password: str,
email: Optional[str] = None,
) -> RegistrationResult:
"""
Convert guest session to full user account.
Args:
guest_id: Guest session ID.
username: Desired username.
password: Password.
email: Optional email.
Returns:
RegistrationResult with user or error.
"""
# Verify guest exists and not already converted
guest = await self.user_store.get_guest_session(guest_id)
if not guest:
return RegistrationResult(success=False, error="Guest session not found")
if guest.is_converted():
return RegistrationResult(success=False, error="Guest already converted")
# Register with guest ID
return await self.register(
username=username,
password=password,
email=email,
guest_id=guest_id,
)
# -------------------------------------------------------------------------
# Account Deletion
# -------------------------------------------------------------------------
async def delete_account(self, user_id: str) -> bool:
"""
Soft delete user account.
Args:
user_id: User ID to delete.
Returns:
True if account was deleted.
"""
# Revoke all sessions
await self.user_store.revoke_all_sessions(user_id)
# Soft delete
user = await self.user_store.update_user(
user_id,
is_active=False,
deleted_at=datetime.now(timezone.utc),
)
return user is not None
# -------------------------------------------------------------------------
# Guest Sessions
# -------------------------------------------------------------------------
async def create_guest_session(
self,
guest_id: str,
display_name: Optional[str] = None,
) -> GuestSession:
"""
Create or get guest session.
Args:
guest_id: Guest session ID.
display_name: Display name for guest.
Returns:
GuestSession.
"""
existing = await self.user_store.get_guest_session(guest_id)
if existing:
await self.user_store.update_guest_last_seen(guest_id)
return existing
return await self.user_store.create_guest_session(guest_id, display_name)
# -------------------------------------------------------------------------
# Password Hashing
# -------------------------------------------------------------------------
def _hash_password(self, password: str) -> str:
"""Hash a password using bcrypt."""
salt = bcrypt.gensalt()
hashed = bcrypt.hashpw(password.encode(), salt)
return hashed.decode()
def _verify_password(self, password: str, password_hash: str) -> bool:
"""Verify a password against its hash."""
try:
return bcrypt.checkpw(password.encode(), password_hash.encode())
except Exception:
return False
# Global auth service instance
_auth_service: Optional[AuthService] = None
async def get_auth_service(user_store: UserStore) -> AuthService:
"""
Get or create the global auth service instance.
Args:
user_store: User persistence store.
Returns:
AuthService instance.
"""
global _auth_service
if _auth_service is None:
_auth_service = await AuthService.create(user_store)
return _auth_service
async def close_auth_service() -> None:
"""Close the global auth service."""
global _auth_service
_auth_service = None

View File

@@ -0,0 +1,215 @@
"""
Email service for Golf game authentication.
Provides email sending via Resend for verification, password reset, and notifications.
"""
import logging
from typing import Optional
from config import config
logger = logging.getLogger(__name__)
class EmailService:
"""
Email service using Resend API.
Handles all transactional emails for authentication:
- Email verification
- Password reset
- Password changed notification
"""
def __init__(self, api_key: str, from_address: str, base_url: str):
"""
Initialize email service.
Args:
api_key: Resend API key.
from_address: Sender email address.
base_url: Base URL for verification/reset links.
"""
self.api_key = api_key
self.from_address = from_address
self.base_url = base_url.rstrip("/")
self._client = None
@classmethod
def create(cls) -> "EmailService":
"""Create EmailService from config."""
return cls(
api_key=config.RESEND_API_KEY,
from_address=config.EMAIL_FROM,
base_url=config.BASE_URL,
)
@property
def client(self):
"""Lazy-load Resend client."""
if self._client is None:
try:
import resend
resend.api_key = self.api_key
self._client = resend
except ImportError:
logger.warning("resend package not installed, emails will be logged only")
self._client = None
return self._client
def is_configured(self) -> bool:
"""Check if email service is properly configured."""
return bool(self.api_key)
async def send_verification_email(
self,
to: str,
token: str,
username: str,
) -> Optional[str]:
"""
Send email verification email.
Args:
to: Recipient email address.
token: Verification token.
username: User's display name.
Returns:
Resend message ID if sent, None if not configured.
"""
if not self.is_configured():
logger.info(f"Email not configured. Would send verification to {to}")
return None
verify_url = f"{self.base_url}/verify-email?token={token}"
subject = "Verify your Golf Game account"
html = f"""
<h2>Welcome to Golf Game, {username}!</h2>
<p>Please verify your email address by clicking the link below:</p>
<p><a href="{verify_url}">Verify Email Address</a></p>
<p>Or copy and paste this URL into your browser:</p>
<p>{verify_url}</p>
<p>This link will expire in 24 hours.</p>
<p>If you didn't create this account, you can safely ignore this email.</p>
"""
return await self._send_email(to, subject, html)
async def send_password_reset_email(
self,
to: str,
token: str,
username: str,
) -> Optional[str]:
"""
Send password reset email.
Args:
to: Recipient email address.
token: Reset token.
username: User's display name.
Returns:
Resend message ID if sent, None if not configured.
"""
if not self.is_configured():
logger.info(f"Email not configured. Would send password reset to {to}")
return None
reset_url = f"{self.base_url}/reset-password?token={token}"
subject = "Reset your Golf Game password"
html = f"""
<h2>Password Reset Request</h2>
<p>Hi {username},</p>
<p>We received a request to reset your password. Click the link below to set a new password:</p>
<p><a href="{reset_url}">Reset Password</a></p>
<p>Or copy and paste this URL into your browser:</p>
<p>{reset_url}</p>
<p>This link will expire in 1 hour.</p>
<p>If you didn't request this, you can safely ignore this email. Your password will remain unchanged.</p>
"""
return await self._send_email(to, subject, html)
async def send_password_changed_notification(
self,
to: str,
username: str,
) -> Optional[str]:
"""
Send password changed notification email.
Args:
to: Recipient email address.
username: User's display name.
Returns:
Resend message ID if sent, None if not configured.
"""
if not self.is_configured():
logger.info(f"Email not configured. Would send password change notification to {to}")
return None
subject = "Your Golf Game password was changed"
html = f"""
<h2>Password Changed</h2>
<p>Hi {username},</p>
<p>Your password was successfully changed.</p>
<p>If you did not make this change, please contact support immediately.</p>
"""
return await self._send_email(to, subject, html)
async def _send_email(
self,
to: str,
subject: str,
html: str,
) -> Optional[str]:
"""
Send an email via Resend.
Args:
to: Recipient email address.
subject: Email subject.
html: HTML email body.
Returns:
Resend message ID if sent, None on error.
"""
if not self.client:
logger.warning(f"Resend not available. Email to {to}: {subject}")
return None
try:
params = {
"from": self.from_address,
"to": [to],
"subject": subject,
"html": html,
}
response = self.client.Emails.send(params)
message_id = response.get("id") if isinstance(response, dict) else getattr(response, "id", None)
logger.info(f"Email sent to {to}: {message_id}")
return message_id
except Exception as e:
logger.error(f"Failed to send email to {to}: {e}")
return None
# Global email service instance
_email_service: Optional[EmailService] = None
def get_email_service() -> EmailService:
"""Get or create the global email service instance."""
global _email_service
if _email_service is None:
_email_service = EmailService.create()
return _email_service

View File

@@ -0,0 +1,223 @@
"""
Redis-based rate limiter service.
Implements a sliding window counter algorithm using Redis for distributed
rate limiting across multiple server instances.
"""
import hashlib
import logging
import time
from typing import Optional
import redis.asyncio as redis
from fastapi import Request, WebSocket
logger = logging.getLogger(__name__)
# Rate limit configurations: (max_requests, window_seconds)
RATE_LIMITS = {
"api_general": (100, 60), # 100 requests per minute
"api_auth": (10, 60), # 10 auth attempts per minute
"api_create_room": (5, 60), # 5 room creations per minute
"websocket_connect": (10, 60), # 10 WS connections per minute
"websocket_message": (30, 10), # 30 messages per 10 seconds
"email_send": (3, 300), # 3 emails per 5 minutes
}
class RateLimiter:
"""Token bucket rate limiter using Redis."""
def __init__(self, redis_client: redis.Redis):
"""
Initialize rate limiter with Redis client.
Args:
redis_client: Async Redis client for state storage.
"""
self.redis = redis_client
async def is_allowed(
self,
key: str,
limit: int,
window_seconds: int,
) -> tuple[bool, dict]:
"""
Check if request is allowed under rate limit.
Uses a sliding window counter algorithm:
- Divides time into fixed windows
- Counts requests in current window
- Atomically increments and checks limit
Args:
key: Unique identifier for the rate limit bucket.
limit: Maximum requests allowed in window.
window_seconds: Time window in seconds.
Returns:
Tuple of (allowed, info) where info contains:
- remaining: requests remaining in window
- reset: seconds until window resets
- limit: the limit that was applied
"""
now = int(time.time())
window_key = f"ratelimit:{key}:{now // window_seconds}"
try:
async with self.redis.pipeline(transaction=True) as pipe:
pipe.incr(window_key)
pipe.expire(window_key, window_seconds + 1) # Extra second for safety
results = await pipe.execute()
current_count = results[0]
remaining = max(0, limit - current_count)
reset = window_seconds - (now % window_seconds)
info = {
"remaining": remaining,
"reset": reset,
"limit": limit,
}
allowed = current_count <= limit
if not allowed:
logger.warning(f"Rate limit exceeded for {key}: {current_count}/{limit}")
return allowed, info
except redis.RedisError as e:
# If Redis is unavailable, fail open (allow request)
logger.error(f"Rate limiter Redis error: {e}")
return True, {"remaining": limit, "reset": window_seconds, "limit": limit}
def get_client_key(
self,
request: Request | WebSocket,
user_id: Optional[str] = None,
) -> str:
"""
Generate rate limit key for client.
Uses user ID if authenticated, otherwise hashes client IP.
Args:
request: HTTP request or WebSocket.
user_id: Authenticated user ID, if available.
Returns:
Unique client identifier string.
"""
if user_id:
return f"user:{user_id}"
# For anonymous users, use IP hash
client_ip = self._get_client_ip(request)
# Hash IP for privacy
ip_hash = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
return f"ip:{ip_hash}"
def _get_client_ip(self, request: Request | WebSocket) -> str:
"""
Extract client IP from request, handling proxies.
Args:
request: HTTP request or WebSocket.
Returns:
Client IP address string.
"""
# Check X-Forwarded-For header (from reverse proxy)
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# Take the first IP (original client)
return forwarded.split(",")[0].strip()
# Check X-Real-IP header (nginx)
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip.strip()
# Fall back to direct connection
if request.client:
return request.client.host
return "unknown"
class ConnectionMessageLimiter:
"""
In-memory rate limiter for WebSocket message frequency.
Used to limit messages within a single connection without
requiring Redis round-trips for every message.
"""
def __init__(self, max_messages: int = 30, window_seconds: int = 10):
"""
Initialize connection message limiter.
Args:
max_messages: Maximum messages allowed in window.
window_seconds: Time window in seconds.
"""
self.max_messages = max_messages
self.window_seconds = window_seconds
self.timestamps: list[float] = []
def check(self) -> bool:
"""
Check if another message is allowed.
Maintains a sliding window of message timestamps.
Returns:
True if message is allowed, False if rate limited.
"""
now = time.time()
cutoff = now - self.window_seconds
# Remove old timestamps
self.timestamps = [t for t in self.timestamps if t > cutoff]
# Check limit
if len(self.timestamps) >= self.max_messages:
return False
# Record this message
self.timestamps.append(now)
return True
def reset(self):
"""Reset the limiter (e.g., on reconnection)."""
self.timestamps = []
# Global rate limiter instance
_rate_limiter: Optional[RateLimiter] = None
async def get_rate_limiter(redis_client: redis.Redis) -> RateLimiter:
"""
Get or create the global rate limiter instance.
Args:
redis_client: Redis client for state storage.
Returns:
RateLimiter instance.
"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter(redis_client)
return _rate_limiter
def close_rate_limiter():
"""Close the global rate limiter."""
global _rate_limiter
_rate_limiter = None

View File

@@ -0,0 +1,353 @@
"""
Game recovery service for rebuilding active games from event store.
On server restart, all in-memory game state is lost. This service:
1. Queries the event store for active games
2. Rebuilds game state by replaying events
3. Caches the rebuilt state in Redis
4. Handles partial recovery (applying only new events to cached state)
This ensures games can survive server restarts without data loss.
Usage:
recovery = RecoveryService(event_store, state_cache)
results = await recovery.recover_all_games()
print(f"Recovered {results['recovered']} games")
"""
import logging
from dataclasses import dataclass
from typing import Optional, Any
from stores.event_store import EventStore
from stores.state_cache import StateCache
from models.events import EventType
from models.game_state import RebuiltGameState, rebuild_state, CardState
logger = logging.getLogger(__name__)
@dataclass
class RecoveryResult:
"""Result of a game recovery attempt."""
game_id: str
room_code: str
success: bool
phase: Optional[str] = None
sequence_num: int = 0
error: Optional[str] = None
class RecoveryService:
"""
Recovers games from event store on startup.
Works with the event store (PostgreSQL) as source of truth
and state cache (Redis) for fast access during gameplay.
"""
def __init__(
self,
event_store: EventStore,
state_cache: StateCache,
):
"""
Initialize recovery service.
Args:
event_store: PostgreSQL event store.
state_cache: Redis state cache.
"""
self.event_store = event_store
self.state_cache = state_cache
async def recover_all_games(self) -> dict[str, Any]:
"""
Recover all active games from event store.
Queries PostgreSQL for active games and rebuilds their state
from events, then caches in Redis.
Returns:
Dict with recovery statistics:
- recovered: Number of games successfully recovered
- failed: Number of games that failed recovery
- skipped: Number of games skipped (already ended)
- games: List of recovered game info
"""
results = {
"recovered": 0,
"failed": 0,
"skipped": 0,
"games": [],
}
# Get active games from PostgreSQL
active_games = await self.event_store.get_active_games()
logger.info(f"Found {len(active_games)} active games to recover")
for game_meta in active_games:
game_id = str(game_meta["id"])
room_code = game_meta["room_code"]
try:
result = await self.recover_game(game_id, room_code)
if result.success:
results["recovered"] += 1
results["games"].append({
"game_id": game_id,
"room_code": room_code,
"phase": result.phase,
"sequence": result.sequence_num,
})
else:
if result.error == "game_ended":
results["skipped"] += 1
else:
results["failed"] += 1
logger.warning(f"Failed to recover {game_id}: {result.error}")
except Exception as e:
logger.error(f"Error recovering game {game_id}: {e}", exc_info=True)
results["failed"] += 1
return results
async def recover_game(
self,
game_id: str,
room_code: Optional[str] = None,
) -> RecoveryResult:
"""
Recover a single game from event store.
Args:
game_id: Game UUID.
room_code: Room code (optional, will be read from events).
Returns:
RecoveryResult with success status and game info.
"""
# Get all events for this game
events = await self.event_store.get_events(game_id)
if not events:
return RecoveryResult(
game_id=game_id,
room_code=room_code or "",
success=False,
error="no_events",
)
# Check if game is actually active (not ended)
last_event = events[-1]
if last_event.event_type == EventType.GAME_ENDED:
return RecoveryResult(
game_id=game_id,
room_code=room_code or "",
success=False,
error="game_ended",
)
# Rebuild state from events
state = rebuild_state(events)
# Get room code from state if not provided
if not room_code:
room_code = state.room_code
# Convert state to cacheable dict
state_dict = self._state_to_dict(state)
# Save to Redis cache
await self.state_cache.save_game_state(game_id, state_dict)
# Also create/update room in cache
await self._ensure_room_in_cache(state)
logger.info(
f"Recovered game {game_id} (room {room_code}) "
f"at sequence {state.sequence_num}, phase {state.phase.value}"
)
return RecoveryResult(
game_id=game_id,
room_code=room_code,
success=True,
phase=state.phase.value,
sequence_num=state.sequence_num,
)
async def recover_from_sequence(
self,
game_id: str,
cached_state: dict,
cached_sequence: int,
) -> Optional[dict]:
"""
Recover game by applying only new events to cached state.
More efficient than full rebuild when we have a recent cache.
Args:
game_id: Game UUID.
cached_state: Previously cached state dict.
cached_sequence: Sequence number of cached state.
Returns:
Updated state dict, or None if no new events.
"""
# Get events after cached sequence
new_events = await self.event_store.get_events(
game_id,
from_sequence=cached_sequence + 1,
)
if not new_events:
return None # No new events
# Rebuild state from cache + new events
state = self._dict_to_state(cached_state)
for event in new_events:
state.apply(event)
# Convert back to dict
new_state = self._state_to_dict(state)
# Update cache
await self.state_cache.save_game_state(game_id, new_state)
return new_state
async def _ensure_room_in_cache(self, state: RebuiltGameState) -> None:
"""
Ensure room exists in Redis cache after recovery.
Args:
state: Rebuilt game state.
"""
room_code = state.room_code
if not room_code:
return
# Check if room already exists
if await self.state_cache.room_exists(room_code):
return
# Create room in cache
await self.state_cache.create_room(
room_code=room_code,
game_id=state.game_id,
host_id=state.host_id or "",
server_id="recovered",
)
# Set room status based on game phase
if state.phase.value == "waiting":
status = "waiting"
elif state.phase.value in ("game_over", "round_over"):
status = "finished"
else:
status = "playing"
await self.state_cache.set_room_status(room_code, status)
def _state_to_dict(self, state: RebuiltGameState) -> dict:
"""
Convert RebuiltGameState to dict for caching.
Args:
state: Game state to convert.
Returns:
Cacheable dict representation.
"""
return {
"game_id": state.game_id,
"room_code": state.room_code,
"phase": state.phase.value,
"current_round": state.current_round,
"total_rounds": state.total_rounds,
"current_player_idx": state.current_player_idx,
"player_order": state.player_order,
"players": {
pid: {
"id": p.id,
"name": p.name,
"cards": [c.to_dict() for c in p.cards],
"score": p.score,
"total_score": p.total_score,
"rounds_won": p.rounds_won,
"is_cpu": p.is_cpu,
"cpu_profile": p.cpu_profile,
}
for pid, p in state.players.items()
},
"deck_remaining": state.deck_remaining,
"discard_pile": [c.to_dict() for c in state.discard_pile],
"discard_top": state.discard_pile[-1].to_dict() if state.discard_pile else None,
"drawn_card": state.drawn_card.to_dict() if state.drawn_card else None,
"drawn_from_discard": state.drawn_from_discard,
"options": state.options,
"sequence_num": state.sequence_num,
"finisher_id": state.finisher_id,
"host_id": state.host_id,
"initial_flips_done": list(state.initial_flips_done),
"players_with_final_turn": list(state.players_with_final_turn),
}
def _dict_to_state(self, d: dict) -> RebuiltGameState:
"""
Convert dict back to RebuiltGameState.
Args:
d: Cached state dict.
Returns:
Reconstructed game state.
"""
from models.game_state import GamePhase, PlayerState
state = RebuiltGameState(game_id=d["game_id"])
state.room_code = d.get("room_code", "")
state.phase = GamePhase(d.get("phase", "waiting"))
state.current_round = d.get("current_round", 0)
state.total_rounds = d.get("total_rounds", 1)
state.current_player_idx = d.get("current_player_idx", 0)
state.player_order = d.get("player_order", [])
state.deck_remaining = d.get("deck_remaining", 0)
state.options = d.get("options", {})
state.sequence_num = d.get("sequence_num", 0)
state.finisher_id = d.get("finisher_id")
state.host_id = d.get("host_id")
state.initial_flips_done = set(d.get("initial_flips_done", []))
state.players_with_final_turn = set(d.get("players_with_final_turn", []))
state.drawn_from_discard = d.get("drawn_from_discard", False)
# Rebuild players
players_data = d.get("players", {})
for pid, pdata in players_data.items():
player = PlayerState(
id=pdata["id"],
name=pdata["name"],
is_cpu=pdata.get("is_cpu", False),
cpu_profile=pdata.get("cpu_profile"),
score=pdata.get("score", 0),
total_score=pdata.get("total_score", 0),
rounds_won=pdata.get("rounds_won", 0),
)
player.cards = [CardState.from_dict(c) for c in pdata.get("cards", [])]
state.players[pid] = player
# Rebuild discard pile
discard_data = d.get("discard_pile", [])
state.discard_pile = [CardState.from_dict(c) for c in discard_data]
# Rebuild drawn card
drawn = d.get("drawn_card")
if drawn:
state.drawn_card = CardState.from_dict(drawn)
return state

View File

@@ -0,0 +1,583 @@
"""
Replay service for Golf game.
Provides game replay functionality, share link generation, and game export/import.
Leverages the event-sourced architecture for perfect game reconstruction.
"""
import json
import logging
import secrets
from dataclasses import dataclass, asdict
from datetime import datetime, timezone, timedelta
from typing import Optional, List
import asyncpg
from stores.event_store import EventStore
from models.events import GameEvent, EventType
from models.game_state import rebuild_state, RebuiltGameState, CardState
logger = logging.getLogger(__name__)
# SQL schema for replay/sharing tables
REPLAY_SCHEMA_SQL = """
-- Public share links for completed games
CREATE TABLE IF NOT EXISTS shared_games (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
game_id UUID NOT NULL,
share_code VARCHAR(12) UNIQUE NOT NULL,
created_by VARCHAR(50),
created_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ,
view_count INTEGER DEFAULT 0,
is_public BOOLEAN DEFAULT true,
title VARCHAR(100),
description TEXT
);
CREATE INDEX IF NOT EXISTS idx_shared_games_code ON shared_games(share_code);
CREATE INDEX IF NOT EXISTS idx_shared_games_game ON shared_games(game_id);
-- Track replay views for analytics
CREATE TABLE IF NOT EXISTS replay_views (
id SERIAL PRIMARY KEY,
shared_game_id UUID REFERENCES shared_games(id),
viewer_id VARCHAR(50),
viewed_at TIMESTAMPTZ DEFAULT NOW(),
ip_hash VARCHAR(64),
watch_duration_seconds INTEGER
);
CREATE INDEX IF NOT EXISTS idx_replay_views_shared ON replay_views(shared_game_id);
"""
@dataclass
class ReplayFrame:
"""Single frame in a replay."""
event_index: int
event_type: str
event_data: dict
game_state: dict
timestamp: float # Seconds from start
player_id: Optional[str] = None
@dataclass
class GameReplay:
"""Complete replay of a game."""
game_id: str
frames: List[ReplayFrame]
total_duration_seconds: float
player_names: List[str]
final_scores: dict
winner: Optional[str]
options: dict
room_code: str
total_rounds: int
class ReplayService:
"""
Service for game replay, export, and sharing.
Provides:
- Replay building from event store
- Share link creation and retrieval
- Game export/import
"""
EXPORT_VERSION = "1.0"
def __init__(self, pool: asyncpg.Pool, event_store: EventStore):
"""
Initialize replay service.
Args:
pool: asyncpg connection pool.
event_store: Event store for retrieving game events.
"""
self.pool = pool
self.event_store = event_store
async def initialize_schema(self) -> None:
"""Create replay tables if they don't exist."""
async with self.pool.acquire() as conn:
await conn.execute(REPLAY_SCHEMA_SQL)
logger.info("Replay schema initialized")
# -------------------------------------------------------------------------
# Replay Building
# -------------------------------------------------------------------------
async def build_replay(self, game_id: str) -> GameReplay:
"""
Build complete replay from event store.
Args:
game_id: Game UUID.
Returns:
GameReplay with all frames and metadata.
Raises:
ValueError: If no events found for game.
"""
events = await self.event_store.get_events(game_id)
if not events:
raise ValueError(f"No events found for game {game_id}")
frames = []
state = RebuiltGameState(game_id=game_id)
start_time = None
for i, event in enumerate(events):
if start_time is None:
start_time = event.timestamp
# Apply event to get state
state.apply(event)
# Calculate timestamp relative to start
elapsed = (event.timestamp - start_time).total_seconds()
frames.append(ReplayFrame(
event_index=i,
event_type=event.event_type.value,
event_data=event.data,
game_state=self._state_to_dict(state),
timestamp=elapsed,
player_id=event.player_id,
))
# Extract final game info
player_names = [p.name for p in state.players.values()]
final_scores = {p.name: p.total_score for p in state.players.values()}
# Determine winner (lowest total score)
winner = None
if state.phase.value == "game_over" and state.players:
winner_player = min(state.players.values(), key=lambda p: p.total_score)
winner = winner_player.name
return GameReplay(
game_id=game_id,
frames=frames,
total_duration_seconds=frames[-1].timestamp if frames else 0,
player_names=player_names,
final_scores=final_scores,
winner=winner,
options=state.options,
room_code=state.room_code,
total_rounds=state.total_rounds,
)
async def get_replay_frame(
self,
game_id: str,
frame_index: int
) -> Optional[ReplayFrame]:
"""
Get a specific frame from a replay.
Useful for seeking to a specific point without loading entire replay.
Args:
game_id: Game UUID.
frame_index: Index of frame to retrieve (0-based).
Returns:
ReplayFrame or None if index out of range.
"""
events = await self.event_store.get_events(
game_id,
from_sequence=1,
to_sequence=frame_index + 1
)
if not events or len(events) <= frame_index:
return None
state = RebuiltGameState(game_id=game_id)
start_time = events[0].timestamp if events else None
for event in events:
state.apply(event)
last_event = events[-1]
elapsed = (last_event.timestamp - start_time).total_seconds() if start_time else 0
return ReplayFrame(
event_index=frame_index,
event_type=last_event.event_type.value,
event_data=last_event.data,
game_state=self._state_to_dict(state),
timestamp=elapsed,
player_id=last_event.player_id,
)
def _state_to_dict(self, state: RebuiltGameState) -> dict:
"""Convert RebuiltGameState to serializable dict."""
players = []
for pid in state.player_order:
if pid in state.players:
p = state.players[pid]
players.append({
"id": p.id,
"name": p.name,
"cards": [c.to_dict() for c in p.cards],
"score": p.score,
"total_score": p.total_score,
"rounds_won": p.rounds_won,
"is_cpu": p.is_cpu,
"all_face_up": p.all_face_up(),
})
return {
"phase": state.phase.value,
"players": players,
"current_player_idx": state.current_player_idx,
"current_player_id": state.player_order[state.current_player_idx] if state.player_order else None,
"deck_remaining": state.deck_remaining,
"discard_pile": [c.to_dict() for c in state.discard_pile],
"discard_top": state.discard_pile[-1].to_dict() if state.discard_pile else None,
"drawn_card": state.drawn_card.to_dict() if state.drawn_card else None,
"current_round": state.current_round,
"total_rounds": state.total_rounds,
"finisher_id": state.finisher_id,
"options": state.options,
}
# -------------------------------------------------------------------------
# Share Links
# -------------------------------------------------------------------------
async def create_share_link(
self,
game_id: str,
user_id: Optional[str] = None,
title: Optional[str] = None,
description: Optional[str] = None,
expires_days: Optional[int] = None,
) -> str:
"""
Generate shareable link for a game.
Args:
game_id: Game UUID.
user_id: ID of user creating the share.
title: Optional custom title.
description: Optional description.
expires_days: Days until link expires (None = never).
Returns:
12-character share code.
"""
share_code = secrets.token_urlsafe(9)[:12]
expires_at = None
if expires_days:
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_days)
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO shared_games
(game_id, share_code, created_by, title, description, expires_at)
VALUES ($1, $2, $3, $4, $5, $6)
""", game_id, share_code, user_id, title, description, expires_at)
logger.info(f"Created share link {share_code} for game {game_id}")
return share_code
async def get_shared_game(self, share_code: str) -> Optional[dict]:
"""
Retrieve shared game by code.
Args:
share_code: 12-character share code.
Returns:
Shared game metadata dict, or None if not found/expired.
"""
async with self.pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT sg.*, g.room_code, g.completed_at, g.num_players, g.num_rounds
FROM shared_games sg
JOIN games_v2 g ON sg.game_id = g.id
WHERE sg.share_code = $1
AND sg.is_public = true
AND (sg.expires_at IS NULL OR sg.expires_at > NOW())
""", share_code)
if row:
# Increment view count
await conn.execute("""
UPDATE shared_games SET view_count = view_count + 1
WHERE share_code = $1
""", share_code)
return dict(row)
return None
async def record_replay_view(
self,
shared_game_id: str,
viewer_id: Optional[str] = None,
ip_hash: Optional[str] = None,
duration_seconds: Optional[int] = None,
) -> None:
"""
Record a replay view for analytics.
Args:
shared_game_id: UUID of the shared_games record.
viewer_id: Optional user ID of viewer.
ip_hash: Optional hashed IP for rate limiting.
duration_seconds: Optional watch duration.
"""
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO replay_views
(shared_game_id, viewer_id, ip_hash, watch_duration_seconds)
VALUES ($1, $2, $3, $4)
""", shared_game_id, viewer_id, ip_hash, duration_seconds)
async def get_user_shared_games(self, user_id: str) -> List[dict]:
"""
Get all shared games created by a user.
Args:
user_id: User ID.
Returns:
List of shared game metadata dicts.
"""
async with self.pool.acquire() as conn:
rows = await conn.fetch("""
SELECT sg.*, g.room_code, g.completed_at
FROM shared_games sg
JOIN games_v2 g ON sg.game_id = g.id
WHERE sg.created_by = $1
ORDER BY sg.created_at DESC
""", user_id)
return [dict(row) for row in rows]
async def delete_share_link(self, share_code: str, user_id: str) -> bool:
"""
Delete a share link.
Args:
share_code: Share code to delete.
user_id: User requesting deletion (must be creator).
Returns:
True if deleted, False if not found or not authorized.
"""
async with self.pool.acquire() as conn:
result = await conn.execute("""
DELETE FROM shared_games
WHERE share_code = $1 AND created_by = $2
""", share_code, user_id)
return result == "DELETE 1"
# -------------------------------------------------------------------------
# Export/Import
# -------------------------------------------------------------------------
async def export_game(self, game_id: str) -> dict:
"""
Export game as portable JSON format.
Args:
game_id: Game UUID.
Returns:
Export data dict suitable for JSON serialization.
"""
replay = await self.build_replay(game_id)
# Get raw events for export
events = await self.event_store.get_events(game_id)
start_time = events[0].timestamp if events else datetime.now(timezone.utc)
return {
"version": self.EXPORT_VERSION,
"exported_at": datetime.now(timezone.utc).isoformat(),
"game": {
"id": replay.game_id,
"room_code": replay.room_code,
"players": replay.player_names,
"winner": replay.winner,
"final_scores": replay.final_scores,
"duration_seconds": replay.total_duration_seconds,
"total_rounds": replay.total_rounds,
"options": replay.options,
},
"events": [
{
"type": event.event_type.value,
"sequence": event.sequence_num,
"player_id": event.player_id,
"data": event.data,
"timestamp": (event.timestamp - start_time).total_seconds(),
}
for event in events
],
}
async def import_game(self, export_data: dict, user_id: str) -> str:
"""
Import a game from exported JSON.
Creates a new game record with the imported events.
Args:
export_data: Exported game data.
user_id: User performing the import.
Returns:
New game ID.
Raises:
ValueError: If export format is invalid.
"""
version = export_data.get("version")
if version != self.EXPORT_VERSION:
raise ValueError(f"Unsupported export version: {version}")
if "events" not in export_data or not export_data["events"]:
raise ValueError("Export contains no events")
# Generate new game ID
import uuid
new_game_id = str(uuid.uuid4())
# Calculate base timestamp
base_time = datetime.now(timezone.utc)
# Import events with new game ID
events = []
for event_data in export_data["events"]:
event = GameEvent(
event_type=EventType(event_data["type"]),
game_id=new_game_id,
sequence_num=event_data["sequence"],
player_id=event_data.get("player_id"),
data=event_data["data"],
timestamp=base_time + timedelta(seconds=event_data.get("timestamp", 0)),
)
events.append(event)
# Batch insert events
await self.event_store.append_batch(events)
# Create game metadata record
game_info = export_data.get("game", {})
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO games_v2
(id, room_code, status, num_rounds, options, completed_at)
VALUES ($1, $2, 'imported', $3, $4, NOW())
""",
new_game_id,
f"IMP-{secrets.token_hex(2).upper()}", # Generate room code for imported games
game_info.get("total_rounds", 1),
json.dumps(game_info.get("options", {})),
)
logger.info(f"Imported game as {new_game_id} by user {user_id}")
return new_game_id
# -------------------------------------------------------------------------
# Game History Queries
# -------------------------------------------------------------------------
async def get_user_game_history(
self,
user_id: str,
limit: int = 20,
offset: int = 0,
) -> List[dict]:
"""
Get game history for a user.
Args:
user_id: User ID.
limit: Max games to return.
offset: Pagination offset.
Returns:
List of game summary dicts.
"""
async with self.pool.acquire() as conn:
rows = await conn.fetch("""
SELECT g.id, g.room_code, g.status, g.completed_at,
g.num_players, g.num_rounds, g.winner_id,
$1 = ANY(g.player_ids) as participated
FROM games_v2 g
WHERE $1 = ANY(g.player_ids)
AND g.status IN ('completed', 'imported')
ORDER BY g.completed_at DESC NULLS LAST
LIMIT $2 OFFSET $3
""", user_id, limit, offset)
return [dict(row) for row in rows]
async def can_view_game(self, user_id: Optional[str], game_id: str) -> bool:
"""
Check if user can view a game replay.
Users can view games they played in or games that are shared publicly.
Args:
user_id: User ID (None for anonymous).
game_id: Game UUID.
Returns:
True if user can view the game.
"""
async with self.pool.acquire() as conn:
# Check if user played in the game
if user_id:
row = await conn.fetchrow("""
SELECT 1 FROM games_v2
WHERE id = $1 AND $2 = ANY(player_ids)
""", game_id, user_id)
if row:
return True
# Check if game has a public share link
row = await conn.fetchrow("""
SELECT 1 FROM shared_games
WHERE game_id = $1
AND is_public = true
AND (expires_at IS NULL OR expires_at > NOW())
""", game_id)
return row is not None
# Global instance
_replay_service: Optional[ReplayService] = None
async def get_replay_service(pool: asyncpg.Pool, event_store: EventStore) -> ReplayService:
"""Get or create the replay service instance."""
global _replay_service
if _replay_service is None:
_replay_service = ReplayService(pool, event_store)
await _replay_service.initialize_schema()
return _replay_service
def set_replay_service(service: ReplayService) -> None:
"""Set the global replay service instance."""
global _replay_service
_replay_service = service
def close_replay_service() -> None:
"""Close the replay service."""
global _replay_service
_replay_service = None

View File

@@ -0,0 +1,265 @@
"""
Spectator manager for Golf game.
Enables spectators to watch live games in progress via WebSocket connections.
Spectators receive game state updates but cannot interact with the game.
"""
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from datetime import datetime, timezone
from fastapi import WebSocket
logger = logging.getLogger(__name__)
# Maximum spectators per game to prevent resource exhaustion
MAX_SPECTATORS_PER_GAME = 50
@dataclass
class SpectatorInfo:
"""Information about a spectator connection."""
websocket: WebSocket
joined_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
user_id: Optional[str] = None
username: Optional[str] = None
class SpectatorManager:
"""
Manage spectators watching live games.
Spectators can join any active game and receive real-time updates.
They see the same state as players but cannot take actions.
"""
def __init__(self):
# game_id -> list of SpectatorInfo
self._spectators: Dict[str, List[SpectatorInfo]] = {}
# websocket -> game_id (for reverse lookup on disconnect)
self._ws_to_game: Dict[WebSocket, str] = {}
async def add_spectator(
self,
game_id: str,
websocket: WebSocket,
user_id: Optional[str] = None,
username: Optional[str] = None,
) -> bool:
"""
Add spectator to a game.
Args:
game_id: Game UUID.
websocket: Spectator's WebSocket connection.
user_id: Optional user ID.
username: Optional display name.
Returns:
True if added, False if game is at spectator limit.
"""
if game_id not in self._spectators:
self._spectators[game_id] = []
# Check spectator limit
if len(self._spectators[game_id]) >= MAX_SPECTATORS_PER_GAME:
logger.warning(f"Game {game_id} at spectator limit ({MAX_SPECTATORS_PER_GAME})")
return False
info = SpectatorInfo(
websocket=websocket,
user_id=user_id,
username=username or "Spectator",
)
self._spectators[game_id].append(info)
self._ws_to_game[websocket] = game_id
logger.info(f"Spectator joined game {game_id} (total: {len(self._spectators[game_id])})")
return True
async def remove_spectator(self, game_id: str, websocket: WebSocket) -> None:
"""
Remove spectator from a game.
Args:
game_id: Game UUID.
websocket: Spectator's WebSocket connection.
"""
if game_id in self._spectators:
# Find and remove the spectator
self._spectators[game_id] = [
info for info in self._spectators[game_id]
if info.websocket != websocket
]
logger.info(f"Spectator left game {game_id} (remaining: {len(self._spectators[game_id])})")
# Clean up empty games
if not self._spectators[game_id]:
del self._spectators[game_id]
# Clean up reverse lookup
self._ws_to_game.pop(websocket, None)
async def remove_spectator_by_ws(self, websocket: WebSocket) -> None:
"""
Remove spectator by WebSocket (for disconnect handling).
Args:
websocket: Spectator's WebSocket connection.
"""
game_id = self._ws_to_game.get(websocket)
if game_id:
await self.remove_spectator(game_id, websocket)
async def broadcast_to_spectators(self, game_id: str, message: dict) -> None:
"""
Send update to all spectators of a game.
Args:
game_id: Game UUID.
message: Message to broadcast.
"""
if game_id not in self._spectators:
return
dead_connections: List[SpectatorInfo] = []
for info in self._spectators[game_id]:
try:
await info.websocket.send_json(message)
except Exception as e:
logger.debug(f"Failed to send to spectator: {e}")
dead_connections.append(info)
# Clean up dead connections
for info in dead_connections:
self._spectators[game_id] = [
s for s in self._spectators[game_id]
if s.websocket != info.websocket
]
self._ws_to_game.pop(info.websocket, None)
# Clean up empty games
if game_id in self._spectators and not self._spectators[game_id]:
del self._spectators[game_id]
async def send_game_state(
self,
game_id: str,
game_state: dict,
event_type: Optional[str] = None,
) -> None:
"""
Send current game state to all spectators.
Args:
game_id: Game UUID.
game_state: Current game state dict.
event_type: Optional event type that triggered this update.
"""
message = {
"type": "game_state",
"game_state": game_state,
"spectator_count": self.get_spectator_count(game_id),
}
if event_type:
message["event_type"] = event_type
await self.broadcast_to_spectators(game_id, message)
def get_spectator_count(self, game_id: str) -> int:
"""
Get number of spectators for a game.
Args:
game_id: Game UUID.
Returns:
Spectator count.
"""
return len(self._spectators.get(game_id, []))
def get_spectator_usernames(self, game_id: str) -> list[str]:
"""
Get list of spectator usernames.
Args:
game_id: Game UUID.
Returns:
List of spectator usernames.
"""
if game_id not in self._spectators:
return []
return [
info.username or "Anonymous"
for info in self._spectators[game_id]
]
def get_games_with_spectators(self) -> dict[str, int]:
"""
Get all games that have spectators.
Returns:
Dict of game_id -> spectator count.
"""
return {
game_id: len(spectators)
for game_id, spectators in self._spectators.items()
if spectators
}
async def notify_game_ended(self, game_id: str, final_state: dict) -> None:
"""
Notify spectators that a game has ended.
Args:
game_id: Game UUID.
final_state: Final game state with scores.
"""
await self.broadcast_to_spectators(game_id, {
"type": "game_ended",
"final_state": final_state,
})
async def close_all_for_game(self, game_id: str) -> None:
"""
Close all spectator connections for a game.
Use when a game is being cleaned up.
Args:
game_id: Game UUID.
"""
if game_id not in self._spectators:
return
for info in list(self._spectators[game_id]):
try:
await info.websocket.close(code=1000, reason="Game ended")
except Exception:
pass
self._ws_to_game.pop(info.websocket, None)
del self._spectators[game_id]
logger.info(f"Closed all spectators for game {game_id}")
# Global instance
_spectator_manager: Optional[SpectatorManager] = None
def get_spectator_manager() -> SpectatorManager:
"""Get the global spectator manager instance."""
global _spectator_manager
if _spectator_manager is None:
_spectator_manager = SpectatorManager()
return _spectator_manager
def close_spectator_manager() -> None:
"""Close the spectator manager."""
global _spectator_manager
_spectator_manager = None

View File

@@ -0,0 +1,977 @@
"""
Stats service for Golf game leaderboards and achievements.
Provides player statistics aggregation, leaderboard queries, and achievement tracking.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional, List
from uuid import UUID
import asyncpg
from stores.event_store import EventStore
from models.events import EventType
logger = logging.getLogger(__name__)
@dataclass
class PlayerStats:
"""Full player statistics."""
user_id: str
username: str
games_played: int = 0
games_won: int = 0
win_rate: float = 0.0
rounds_played: int = 0
rounds_won: int = 0
avg_score: float = 0.0
best_round_score: Optional[int] = None
worst_round_score: Optional[int] = None
knockouts: int = 0
perfect_rounds: int = 0
wolfpacks: int = 0
current_win_streak: int = 0
best_win_streak: int = 0
first_game_at: Optional[datetime] = None
last_game_at: Optional[datetime] = None
achievements: List[str] = field(default_factory=list)
@dataclass
class LeaderboardEntry:
"""Single entry on a leaderboard."""
rank: int
user_id: str
username: str
value: float
games_played: int
secondary_value: Optional[float] = None
@dataclass
class Achievement:
"""Achievement definition."""
id: str
name: str
description: str
icon: str
category: str
threshold: int
@dataclass
class UserAchievement:
"""Achievement earned by a user."""
id: str
name: str
description: str
icon: str
earned_at: datetime
game_id: Optional[str] = None
class StatsService:
"""
Player statistics and leaderboards service.
Provides methods for:
- Querying player stats
- Fetching leaderboards by various metrics
- Processing game completion for stats aggregation
- Achievement checking and awarding
"""
def __init__(self, pool: asyncpg.Pool, event_store: Optional[EventStore] = None):
"""
Initialize stats service.
Args:
pool: asyncpg connection pool.
event_store: Optional EventStore for event-based stats processing.
"""
self.pool = pool
self.event_store = event_store
# -------------------------------------------------------------------------
# Stats Queries
# -------------------------------------------------------------------------
async def get_player_stats(self, user_id: str) -> Optional[PlayerStats]:
"""
Get full stats for a specific player.
Args:
user_id: User UUID.
Returns:
PlayerStats or None if player not found.
"""
async with self.pool.acquire() as conn:
row = await conn.fetchrow("""
SELECT s.*, u.username,
ROUND(s.games_won::numeric / NULLIF(s.games_played, 0) * 100, 1) as win_rate,
ROUND(s.total_points::numeric / NULLIF(s.total_rounds, 0), 1) as avg_score_calc
FROM player_stats s
JOIN users_v2 u ON s.user_id = u.id
WHERE s.user_id = $1
""", user_id)
if not row:
# Check if user exists but has no stats
user_row = await conn.fetchrow(
"SELECT username FROM users_v2 WHERE id = $1",
user_id
)
if user_row:
return PlayerStats(
user_id=user_id,
username=user_row["username"],
)
return None
# Get achievements
achievements = await conn.fetch("""
SELECT achievement_id FROM user_achievements
WHERE user_id = $1
""", user_id)
return PlayerStats(
user_id=str(row["user_id"]),
username=row["username"],
games_played=row["games_played"] or 0,
games_won=row["games_won"] or 0,
win_rate=float(row["win_rate"] or 0),
rounds_played=row["total_rounds"] or 0,
rounds_won=row["rounds_won"] or 0,
avg_score=float(row["avg_score_calc"] or 0),
best_round_score=row["best_score"],
worst_round_score=row["worst_score"],
knockouts=row["knockouts"] or 0,
perfect_rounds=row["perfect_rounds"] or 0,
wolfpacks=row["wolfpacks"] or 0,
current_win_streak=row["current_win_streak"] or 0,
best_win_streak=row["best_win_streak"] or 0,
first_game_at=row["first_game_at"].replace(tzinfo=timezone.utc) if row["first_game_at"] else None,
last_game_at=row["last_game_at"].replace(tzinfo=timezone.utc) if row["last_game_at"] else None,
achievements=[a["achievement_id"] for a in achievements],
)
async def get_leaderboard(
self,
metric: str = "wins",
limit: int = 50,
offset: int = 0,
) -> List[LeaderboardEntry]:
"""
Get leaderboard by metric.
Args:
metric: Ranking metric - wins, win_rate, avg_score, knockouts, streak.
limit: Maximum entries to return.
offset: Pagination offset.
Returns:
List of LeaderboardEntry sorted by metric.
"""
order_map = {
"wins": ("games_won", "DESC"),
"win_rate": ("win_rate", "DESC"),
"avg_score": ("avg_score", "ASC"), # Lower is better
"knockouts": ("knockouts", "DESC"),
"streak": ("best_win_streak", "DESC"),
}
if metric not in order_map:
metric = "wins"
column, direction = order_map[metric]
async with self.pool.acquire() as conn:
# Check if materialized view exists
view_exists = await conn.fetchval(
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
)
if view_exists:
# Use materialized view for performance
rows = await conn.fetch(f"""
SELECT
user_id, username, games_played, games_won,
win_rate, avg_score, knockouts, best_win_streak,
ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
FROM leaderboard_overall
ORDER BY {column} {direction}
LIMIT $1 OFFSET $2
""", limit, offset)
else:
# Fall back to direct query
rows = await conn.fetch(f"""
SELECT
s.user_id, u.username, s.games_played, s.games_won,
ROUND(s.games_won::numeric / NULLIF(s.games_played, 0) * 100, 1) as win_rate,
ROUND(s.total_points::numeric / NULLIF(s.total_rounds, 0), 1) as avg_score,
s.knockouts, s.best_win_streak,
ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
FROM player_stats s
JOIN users_v2 u ON s.user_id = u.id
WHERE s.games_played >= 5
AND u.deleted_at IS NULL
AND (u.is_banned = false OR u.is_banned IS NULL)
ORDER BY {column} {direction}
LIMIT $1 OFFSET $2
""", limit, offset)
return [
LeaderboardEntry(
rank=row["rank"],
user_id=str(row["user_id"]),
username=row["username"],
value=float(row[column] or 0),
games_played=row["games_played"],
secondary_value=float(row["win_rate"] or 0) if metric != "win_rate" else None,
)
for row in rows
]
async def get_player_rank(self, user_id: str, metric: str = "wins") -> Optional[int]:
"""
Get a player's rank on a leaderboard.
Args:
user_id: User UUID.
metric: Ranking metric.
Returns:
Rank number or None if not ranked (< 5 games or not found).
"""
order_map = {
"wins": ("games_won", "DESC"),
"win_rate": ("win_rate", "DESC"),
"avg_score": ("avg_score", "ASC"),
"knockouts": ("knockouts", "DESC"),
"streak": ("best_win_streak", "DESC"),
}
if metric not in order_map:
return None
column, direction = order_map[metric]
async with self.pool.acquire() as conn:
# Check if user qualifies (5+ games)
games = await conn.fetchval(
"SELECT games_played FROM player_stats WHERE user_id = $1",
user_id
)
if not games or games < 5:
return None
view_exists = await conn.fetchval(
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
)
if view_exists:
row = await conn.fetchrow(f"""
SELECT rank FROM (
SELECT user_id, ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
FROM leaderboard_overall
) ranked
WHERE user_id = $1
""", user_id)
else:
row = await conn.fetchrow(f"""
SELECT rank FROM (
SELECT s.user_id, ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
FROM player_stats s
JOIN users_v2 u ON s.user_id = u.id
WHERE s.games_played >= 5
AND u.deleted_at IS NULL
AND (u.is_banned = false OR u.is_banned IS NULL)
) ranked
WHERE user_id = $1
""", user_id)
return row["rank"] if row else None
async def refresh_leaderboard(self) -> bool:
"""
Refresh the materialized leaderboard view.
Returns:
True if refresh succeeded.
"""
async with self.pool.acquire() as conn:
try:
# Check if view exists
view_exists = await conn.fetchval(
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
)
if view_exists:
await conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY leaderboard_overall")
logger.info("Refreshed leaderboard materialized view")
return True
except Exception as e:
logger.error(f"Failed to refresh leaderboard: {e}")
return False
# -------------------------------------------------------------------------
# Achievement Queries
# -------------------------------------------------------------------------
async def get_achievements(self) -> List[Achievement]:
"""Get all available achievements."""
async with self.pool.acquire() as conn:
rows = await conn.fetch("""
SELECT id, name, description, icon, category, threshold
FROM achievements
ORDER BY sort_order
""")
return [
Achievement(
id=row["id"],
name=row["name"],
description=row["description"] or "",
icon=row["icon"] or "",
category=row["category"] or "",
threshold=row["threshold"] or 0,
)
for row in rows
]
async def get_user_achievements(self, user_id: str) -> List[UserAchievement]:
"""
Get achievements earned by a user.
Args:
user_id: User UUID.
Returns:
List of earned achievements.
"""
async with self.pool.acquire() as conn:
rows = await conn.fetch("""
SELECT a.id, a.name, a.description, a.icon, ua.earned_at, ua.game_id
FROM user_achievements ua
JOIN achievements a ON ua.achievement_id = a.id
WHERE ua.user_id = $1
ORDER BY ua.earned_at DESC
""", user_id)
return [
UserAchievement(
id=row["id"],
name=row["name"],
description=row["description"] or "",
icon=row["icon"] or "",
earned_at=row["earned_at"].replace(tzinfo=timezone.utc) if row["earned_at"] else datetime.now(timezone.utc),
game_id=str(row["game_id"]) if row["game_id"] else None,
)
for row in rows
]
# -------------------------------------------------------------------------
# Stats Processing (Game Completion)
# -------------------------------------------------------------------------
async def process_game_end(self, game_id: str) -> List[str]:
"""
Process a completed game and update player stats.
Extracts game data from events and updates player_stats table.
Args:
game_id: Game UUID.
Returns:
List of newly awarded achievement IDs.
"""
if not self.event_store:
logger.warning("No event store configured, skipping stats processing")
return []
# Get game events
try:
events = await self.event_store.get_events(game_id)
except Exception as e:
logger.error(f"Failed to get events for game {game_id}: {e}")
return []
if not events:
logger.warning(f"No events found for game {game_id}")
return []
# Extract game data from events
game_data = self._extract_game_data(events)
if not game_data:
logger.warning(f"Could not extract game data from events for {game_id}")
return []
all_new_achievements = []
async with self.pool.acquire() as conn:
async with conn.transaction():
for player_id, player_data in game_data["players"].items():
# Skip CPU players (they don't have user accounts)
if player_data.get("is_cpu"):
continue
# Check if this is a valid user UUID
try:
UUID(player_id)
except (ValueError, TypeError):
# Not a UUID - likely a websocket session ID, skip
continue
# Ensure stats row exists
await conn.execute("""
INSERT INTO player_stats (user_id)
VALUES ($1)
ON CONFLICT (user_id) DO NOTHING
""", player_id)
# Calculate values
is_winner = player_id == game_data["winner_id"]
total_score = player_data["total_score"]
rounds_won = player_data["rounds_won"]
num_rounds = game_data["num_rounds"]
knockouts = player_data.get("knockouts", 0)
best_round = player_data.get("best_round")
worst_round = player_data.get("worst_round")
perfect_rounds = player_data.get("perfect_rounds", 0)
wolfpacks = player_data.get("wolfpacks", 0)
has_human_opponents = game_data.get("has_human_opponents", False)
# Update stats
await conn.execute("""
UPDATE player_stats SET
games_played = games_played + 1,
games_won = games_won + $2,
total_rounds = total_rounds + $3,
rounds_won = rounds_won + $4,
total_points = total_points + $5,
knockouts = knockouts + $6,
perfect_rounds = perfect_rounds + $7,
wolfpacks = wolfpacks + $8,
best_score = CASE
WHEN best_score IS NULL THEN $9
WHEN $9 IS NOT NULL AND $9 < best_score THEN $9
ELSE best_score
END,
worst_score = CASE
WHEN worst_score IS NULL THEN $10
WHEN $10 IS NOT NULL AND $10 > worst_score THEN $10
ELSE worst_score
END,
current_win_streak = CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE 0 END,
best_win_streak = GREATEST(best_win_streak,
CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE best_win_streak END),
first_game_at = COALESCE(first_game_at, NOW()),
last_game_at = NOW(),
games_vs_humans = games_vs_humans + $11,
games_won_vs_humans = games_won_vs_humans + $12,
updated_at = NOW()
WHERE user_id = $1
""",
player_id,
1 if is_winner else 0,
num_rounds,
rounds_won,
total_score,
knockouts,
perfect_rounds,
wolfpacks,
best_round,
worst_round,
1 if has_human_opponents else 0,
1 if is_winner and has_human_opponents else 0,
)
# Check for new achievements
new_achievements = await self._check_achievements(
conn, player_id, game_id, player_data, is_winner
)
all_new_achievements.extend(new_achievements)
logger.info(f"Processed stats for game {game_id}, awarded {len(all_new_achievements)} achievements")
return all_new_achievements
def _extract_game_data(self, events) -> Optional[dict]:
"""
Extract game statistics from event stream.
Args:
events: List of GameEvent objects.
Returns:
Dict with players, num_rounds, winner_id, etc.
"""
data = {
"players": {},
"num_rounds": 0,
"winner_id": None,
"has_human_opponents": False,
}
human_count = 0
for event in events:
if event.event_type == EventType.PLAYER_JOINED:
is_cpu = event.data.get("is_cpu", False)
if not is_cpu:
human_count += 1
data["players"][event.player_id] = {
"is_cpu": is_cpu,
"total_score": 0,
"rounds_won": 0,
"knockouts": 0,
"perfect_rounds": 0,
"wolfpacks": 0,
"best_round": None,
"worst_round": None,
}
elif event.event_type == EventType.ROUND_ENDED:
data["num_rounds"] += 1
scores = event.data.get("scores", {})
finisher_id = event.data.get("finisher_id")
# Track who went out first (knockout)
if finisher_id and finisher_id in data["players"]:
data["players"][finisher_id]["knockouts"] += 1
# Find round winner (lowest score)
if scores:
min_score = min(scores.values())
for pid, score in scores.items():
if pid in data["players"]:
p = data["players"][pid]
p["total_score"] += score
# Track best/worst rounds
if p["best_round"] is None or score < p["best_round"]:
p["best_round"] = score
if p["worst_round"] is None or score > p["worst_round"]:
p["worst_round"] = score
# Check for perfect round (score <= 0)
if score <= 0:
p["perfect_rounds"] += 1
# Award round win
if score == min_score:
p["rounds_won"] += 1
# Check for wolfpack (4 Jacks) in final hands
final_hands = event.data.get("final_hands", {})
for pid, hand in final_hands.items():
if pid in data["players"]:
jack_count = sum(1 for card in hand if card.get("rank") == "J")
if jack_count >= 4:
data["players"][pid]["wolfpacks"] += 1
elif event.event_type == EventType.GAME_ENDED:
data["winner_id"] = event.data.get("winner_id")
# Mark if there were human opponents
data["has_human_opponents"] = human_count > 1
return data if data["num_rounds"] > 0 else None
async def _check_achievements(
self,
conn: asyncpg.Connection,
user_id: str,
game_id: str,
player_data: dict,
is_winner: bool,
) -> List[str]:
"""
Check and award new achievements to a player.
Args:
conn: Database connection (within transaction).
user_id: Player's user ID.
game_id: Current game ID.
player_data: Player's data from this game.
is_winner: Whether player won the game.
Returns:
List of newly awarded achievement IDs.
"""
new_achievements = []
# Get current stats (after update)
stats = await conn.fetchrow("""
SELECT games_won, knockouts, best_win_streak, current_win_streak, perfect_rounds, wolfpacks
FROM player_stats
WHERE user_id = $1
""", user_id)
if not stats:
return []
# Get already earned achievements
earned = await conn.fetch("""
SELECT achievement_id FROM user_achievements WHERE user_id = $1
""", user_id)
earned_ids = {e["achievement_id"] for e in earned}
# Check win milestones
wins = stats["games_won"]
if wins >= 1 and "first_win" not in earned_ids:
new_achievements.append("first_win")
if wins >= 10 and "win_10" not in earned_ids:
new_achievements.append("win_10")
if wins >= 50 and "win_50" not in earned_ids:
new_achievements.append("win_50")
if wins >= 100 and "win_100" not in earned_ids:
new_achievements.append("win_100")
# Check streak achievements
streak = stats["current_win_streak"]
if streak >= 5 and "streak_5" not in earned_ids:
new_achievements.append("streak_5")
if streak >= 10 and "streak_10" not in earned_ids:
new_achievements.append("streak_10")
# Check knockout achievements
if stats["knockouts"] >= 10 and "knockout_10" not in earned_ids:
new_achievements.append("knockout_10")
# Check round-specific achievements from this game
best_round = player_data.get("best_round")
if best_round is not None:
if best_round <= 0 and "perfect_round" not in earned_ids:
new_achievements.append("perfect_round")
if best_round < 0 and "negative_round" not in earned_ids:
new_achievements.append("negative_round")
# Check wolfpack
if player_data.get("wolfpacks", 0) > 0 and "wolfpack" not in earned_ids:
new_achievements.append("wolfpack")
# Award new achievements
for achievement_id in new_achievements:
try:
await conn.execute("""
INSERT INTO user_achievements (user_id, achievement_id, game_id)
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
""", user_id, achievement_id, game_id)
except Exception as e:
logger.error(f"Failed to award achievement {achievement_id}: {e}")
return new_achievements
# -------------------------------------------------------------------------
# Direct Game State Processing (for legacy games without event sourcing)
# -------------------------------------------------------------------------
async def process_game_from_state(
self,
players: list,
winner_id: Optional[str],
num_rounds: int,
player_user_ids: dict[str, str] = None,
) -> List[str]:
"""
Process game stats directly from game state (for legacy games).
This is used when games don't have event sourcing. Stats are updated
based on final game state.
Args:
players: List of game.Player objects with final scores.
winner_id: Player ID of the winner.
num_rounds: Total rounds played.
player_user_ids: Optional mapping of player_id to user_id (for authenticated players).
Returns:
List of newly awarded achievement IDs.
"""
if not players:
return []
# Count human players for has_human_opponents calculation
# For legacy games, we assume all players are human unless otherwise indicated
human_count = len(players)
has_human_opponents = human_count > 1
all_new_achievements = []
async with self.pool.acquire() as conn:
async with conn.transaction():
for player in players:
# Get user_id - could be the player_id itself if it's a UUID,
# or mapped via player_user_ids
user_id = None
if player_user_ids and player.id in player_user_ids:
user_id = player_user_ids[player.id]
else:
# Try to use player.id as user_id if it looks like a UUID
try:
UUID(player.id)
user_id = player.id
except (ValueError, TypeError):
# Not a UUID, skip this player
continue
if not user_id:
continue
# Ensure stats row exists
await conn.execute("""
INSERT INTO player_stats (user_id)
VALUES ($1)
ON CONFLICT (user_id) DO NOTHING
""", user_id)
is_winner = player.id == winner_id
total_score = player.total_score
rounds_won = player.rounds_won
# We don't have per-round data in legacy mode, so some stats are limited
# Use total_score / num_rounds as an approximation for avg round score
avg_round_score = total_score / num_rounds if num_rounds > 0 else None
# Update stats
await conn.execute("""
UPDATE player_stats SET
games_played = games_played + 1,
games_won = games_won + $2,
total_rounds = total_rounds + $3,
rounds_won = rounds_won + $4,
total_points = total_points + $5,
best_score = CASE
WHEN best_score IS NULL THEN $6
WHEN $6 IS NOT NULL AND $6 < best_score THEN $6
ELSE best_score
END,
worst_score = CASE
WHEN worst_score IS NULL THEN $7
WHEN $7 IS NOT NULL AND $7 > worst_score THEN $7
ELSE worst_score
END,
current_win_streak = CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE 0 END,
best_win_streak = GREATEST(best_win_streak,
CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE best_win_streak END),
first_game_at = COALESCE(first_game_at, NOW()),
last_game_at = NOW(),
games_vs_humans = games_vs_humans + $8,
games_won_vs_humans = games_won_vs_humans + $9,
updated_at = NOW()
WHERE user_id = $1
""",
user_id,
1 if is_winner else 0,
num_rounds,
rounds_won,
total_score,
avg_round_score, # Approximation for best_score
avg_round_score, # Approximation for worst_score
1 if has_human_opponents else 0,
1 if is_winner and has_human_opponents else 0,
)
# Check achievements (limited data in legacy mode)
new_achievements = await self._check_achievements_legacy(
conn, user_id, is_winner
)
all_new_achievements.extend(new_achievements)
logger.info(f"Processed stats for legacy game with {len(players)} players")
return all_new_achievements
async def _check_achievements_legacy(
self,
conn: asyncpg.Connection,
user_id: str,
is_winner: bool,
) -> List[str]:
"""
Check and award achievements for legacy games (limited data).
Only checks win-based achievements since we don't have round-level data.
"""
new_achievements = []
# Get current stats
stats = await conn.fetchrow("""
SELECT games_won, current_win_streak FROM player_stats
WHERE user_id = $1
""", user_id)
if not stats:
return []
# Get already earned achievements
earned = await conn.fetch("""
SELECT achievement_id FROM user_achievements WHERE user_id = $1
""", user_id)
earned_ids = {e["achievement_id"] for e in earned}
# Check win milestones
wins = stats["games_won"]
if wins >= 1 and "first_win" not in earned_ids:
new_achievements.append("first_win")
if wins >= 10 and "win_10" not in earned_ids:
new_achievements.append("win_10")
if wins >= 50 and "win_50" not in earned_ids:
new_achievements.append("win_50")
if wins >= 100 and "win_100" not in earned_ids:
new_achievements.append("win_100")
# Check streak achievements
streak = stats["current_win_streak"]
if streak >= 5 and "streak_5" not in earned_ids:
new_achievements.append("streak_5")
if streak >= 10 and "streak_10" not in earned_ids:
new_achievements.append("streak_10")
# Award new achievements
for achievement_id in new_achievements:
try:
await conn.execute("""
INSERT INTO user_achievements (user_id, achievement_id)
VALUES ($1, $2)
ON CONFLICT DO NOTHING
""", user_id, achievement_id)
except Exception as e:
logger.error(f"Failed to award achievement {achievement_id}: {e}")
return new_achievements
# -------------------------------------------------------------------------
# Stats Queue Management
# -------------------------------------------------------------------------
async def queue_game_for_processing(self, game_id: str) -> int:
"""
Add a game to the stats processing queue.
Args:
game_id: Game UUID.
Returns:
Queue entry ID.
"""
async with self.pool.acquire() as conn:
row = await conn.fetchrow("""
INSERT INTO stats_queue (game_id)
VALUES ($1)
RETURNING id
""", game_id)
return row["id"]
async def process_pending_queue(self, limit: int = 100) -> int:
"""
Process pending games in the stats queue.
Args:
limit: Maximum games to process.
Returns:
Number of games processed.
"""
processed = 0
async with self.pool.acquire() as conn:
# Get pending games
games = await conn.fetch("""
SELECT id, game_id FROM stats_queue
WHERE status = 'pending'
ORDER BY created_at
LIMIT $1
""", limit)
for game in games:
try:
# Mark as processing
await conn.execute("""
UPDATE stats_queue SET status = 'processing' WHERE id = $1
""", game["id"])
# Process
await self.process_game_end(str(game["game_id"]))
# Mark complete
await conn.execute("""
UPDATE stats_queue
SET status = 'completed', processed_at = NOW()
WHERE id = $1
""", game["id"])
processed += 1
except Exception as e:
logger.error(f"Failed to process game {game['game_id']}: {e}")
# Mark failed
await conn.execute("""
UPDATE stats_queue
SET status = 'failed', error_message = $2, processed_at = NOW()
WHERE id = $1
""", game["id"], str(e))
return processed
async def cleanup_old_queue_entries(self, days: int = 7) -> int:
"""
Clean up old completed/failed queue entries.
Args:
days: Delete entries older than this many days.
Returns:
Number of entries deleted.
"""
async with self.pool.acquire() as conn:
result = await conn.execute("""
DELETE FROM stats_queue
WHERE status IN ('completed', 'failed')
AND processed_at < NOW() - INTERVAL '1 day' * $1
""", days)
# Parse "DELETE N" result
return int(result.split()[1]) if result else 0
# Global stats service instance
_stats_service: Optional[StatsService] = None
async def get_stats_service(
pool: asyncpg.Pool,
event_store: Optional[EventStore] = None,
) -> StatsService:
"""
Get or create the global stats service instance.
Args:
pool: asyncpg connection pool.
event_store: Optional EventStore.
Returns:
StatsService instance.
"""
global _stats_service
if _stats_service is None:
_stats_service = StatsService(pool, event_store)
return _stats_service
def set_stats_service(service: StatsService) -> None:
"""Set the global stats service instance."""
global _stats_service
_stats_service = service
def close_stats_service() -> None:
"""Close the global stats service."""
global _stats_service
_stats_service = None