Huge v2 uplift, now deployable with real user management and tooling!
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
33
server/services/__init__.py
Normal file
33
server/services/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Services package for Golf game V2 business logic."""
|
||||
|
||||
from .recovery_service import RecoveryService, RecoveryResult
|
||||
from .email_service import EmailService, get_email_service
|
||||
from .auth_service import AuthService, AuthResult, RegistrationResult, get_auth_service, close_auth_service
|
||||
from .admin_service import (
|
||||
AdminService,
|
||||
UserDetails,
|
||||
AuditEntry,
|
||||
SystemStats,
|
||||
InviteCode,
|
||||
get_admin_service,
|
||||
close_admin_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RecoveryService",
|
||||
"RecoveryResult",
|
||||
"EmailService",
|
||||
"get_email_service",
|
||||
"AuthService",
|
||||
"AuthResult",
|
||||
"RegistrationResult",
|
||||
"get_auth_service",
|
||||
"close_auth_service",
|
||||
"AdminService",
|
||||
"UserDetails",
|
||||
"AuditEntry",
|
||||
"SystemStats",
|
||||
"InviteCode",
|
||||
"get_admin_service",
|
||||
"close_admin_service",
|
||||
]
|
||||
1243
server/services/admin_service.py
Normal file
1243
server/services/admin_service.py
Normal file
File diff suppressed because it is too large
Load Diff
654
server/services/auth_service.py
Normal file
654
server/services/auth_service.py
Normal file
@@ -0,0 +1,654 @@
|
||||
"""
|
||||
Authentication service for Golf game.
|
||||
|
||||
Provides business logic for user registration, login, password management,
|
||||
and session handling.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
import bcrypt
|
||||
|
||||
from config import config
|
||||
from models.user import User, UserRole, UserSession, GuestSession
|
||||
from stores.user_store import UserStore
|
||||
from services.email_service import EmailService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Result of an authentication operation."""
|
||||
success: bool
|
||||
user: Optional[User] = None
|
||||
token: Optional[str] = None
|
||||
expires_at: Optional[datetime] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationResult:
|
||||
"""Result of a registration operation."""
|
||||
success: bool
|
||||
user: Optional[User] = None
|
||||
requires_verification: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""
|
||||
Authentication service.
|
||||
|
||||
Handles all authentication business logic:
|
||||
- User registration with optional email verification
|
||||
- Login/logout with session management
|
||||
- Password reset flow
|
||||
- Guest-to-user conversion
|
||||
- Account deletion (soft delete)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_store: UserStore,
|
||||
email_service: EmailService,
|
||||
session_expiry_hours: int = 168,
|
||||
require_email_verification: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize auth service.
|
||||
|
||||
Args:
|
||||
user_store: User persistence store.
|
||||
email_service: Email sending service.
|
||||
session_expiry_hours: Session lifetime in hours.
|
||||
require_email_verification: Whether to require email verification.
|
||||
"""
|
||||
self.user_store = user_store
|
||||
self.email_service = email_service
|
||||
self.session_expiry_hours = session_expiry_hours
|
||||
self.require_email_verification = require_email_verification
|
||||
|
||||
@classmethod
|
||||
async def create(cls, user_store: UserStore) -> "AuthService":
|
||||
"""
|
||||
Create AuthService from config.
|
||||
|
||||
Args:
|
||||
user_store: User persistence store.
|
||||
"""
|
||||
from services.email_service import get_email_service
|
||||
|
||||
return cls(
|
||||
user_store=user_store,
|
||||
email_service=get_email_service(),
|
||||
session_expiry_hours=config.SESSION_EXPIRY_HOURS,
|
||||
require_email_verification=config.REQUIRE_EMAIL_VERIFICATION,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Registration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def register(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
email: Optional[str] = None,
|
||||
guest_id: Optional[str] = None,
|
||||
) -> RegistrationResult:
|
||||
"""
|
||||
Register a new user account.
|
||||
|
||||
Args:
|
||||
username: Desired username.
|
||||
password: Plain text password.
|
||||
email: Optional email address.
|
||||
guest_id: Guest session ID if converting.
|
||||
|
||||
Returns:
|
||||
RegistrationResult with user or error.
|
||||
"""
|
||||
# Validate inputs
|
||||
if len(username) < 2 or len(username) > 50:
|
||||
return RegistrationResult(success=False, error="Username must be 2-50 characters")
|
||||
|
||||
if len(password) < 8:
|
||||
return RegistrationResult(success=False, error="Password must be at least 8 characters")
|
||||
|
||||
# Check for existing username
|
||||
existing = await self.user_store.get_user_by_username(username)
|
||||
if existing:
|
||||
return RegistrationResult(success=False, error="Username already taken")
|
||||
|
||||
# Check for existing email
|
||||
if email:
|
||||
existing = await self.user_store.get_user_by_email(email)
|
||||
if existing:
|
||||
return RegistrationResult(success=False, error="Email already registered")
|
||||
|
||||
# Hash password
|
||||
password_hash = self._hash_password(password)
|
||||
|
||||
# Generate verification token if needed
|
||||
verification_token = None
|
||||
verification_expires = None
|
||||
if email and self.require_email_verification:
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
verification_expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
# Create user
|
||||
user = await self.user_store.create_user(
|
||||
username=username,
|
||||
password_hash=password_hash,
|
||||
email=email,
|
||||
role=UserRole.USER,
|
||||
guest_id=guest_id,
|
||||
verification_token=verification_token,
|
||||
verification_expires=verification_expires,
|
||||
)
|
||||
|
||||
if not user:
|
||||
return RegistrationResult(success=False, error="Failed to create account")
|
||||
|
||||
# Mark guest as converted if applicable
|
||||
if guest_id:
|
||||
await self.user_store.mark_guest_converted(guest_id, user.id)
|
||||
|
||||
# Send verification email if needed
|
||||
requires_verification = False
|
||||
if email and self.require_email_verification and verification_token:
|
||||
await self.email_service.send_verification_email(
|
||||
to=email,
|
||||
token=verification_token,
|
||||
username=username,
|
||||
)
|
||||
await self.user_store.log_email(user.id, "verification", email)
|
||||
requires_verification = True
|
||||
|
||||
return RegistrationResult(
|
||||
success=True,
|
||||
user=user,
|
||||
requires_verification=requires_verification,
|
||||
)
|
||||
|
||||
async def verify_email(self, token: str) -> AuthResult:
|
||||
"""
|
||||
Verify email with token.
|
||||
|
||||
Args:
|
||||
token: Verification token from email.
|
||||
|
||||
Returns:
|
||||
AuthResult with success status.
|
||||
"""
|
||||
user = await self.user_store.get_user_by_verification_token(token)
|
||||
if not user:
|
||||
return AuthResult(success=False, error="Invalid verification token")
|
||||
|
||||
# Check expiration
|
||||
if user.verification_expires and user.verification_expires < datetime.now(timezone.utc):
|
||||
return AuthResult(success=False, error="Verification token expired")
|
||||
|
||||
# Mark as verified
|
||||
await self.user_store.clear_verification_token(user.id)
|
||||
|
||||
# Refresh user
|
||||
user = await self.user_store.get_user_by_id(user.id)
|
||||
|
||||
return AuthResult(success=True, user=user)
|
||||
|
||||
async def resend_verification(self, email: str) -> bool:
|
||||
"""
|
||||
Resend verification email.
|
||||
|
||||
Args:
|
||||
email: Email address to send to.
|
||||
|
||||
Returns:
|
||||
True if email was sent.
|
||||
"""
|
||||
user = await self.user_store.get_user_by_email(email)
|
||||
if not user or user.email_verified:
|
||||
return False
|
||||
|
||||
# Generate new token
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
verification_expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
await self.user_store.update_user(
|
||||
user.id,
|
||||
verification_token=verification_token,
|
||||
verification_expires=verification_expires,
|
||||
)
|
||||
|
||||
await self.email_service.send_verification_email(
|
||||
to=email,
|
||||
token=verification_token,
|
||||
username=user.username,
|
||||
)
|
||||
await self.user_store.log_email(user.id, "verification", email)
|
||||
|
||||
return True
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Login/Logout
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
device_info: Optional[dict] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Authenticate user and create session.
|
||||
|
||||
Args:
|
||||
username: Username or email.
|
||||
password: Plain text password.
|
||||
device_info: Client device information.
|
||||
ip_address: Client IP address.
|
||||
|
||||
Returns:
|
||||
AuthResult with session token or error.
|
||||
"""
|
||||
# Try username first, then email
|
||||
user = await self.user_store.get_user_by_username(username)
|
||||
if not user:
|
||||
user = await self.user_store.get_user_by_email(username)
|
||||
|
||||
if not user:
|
||||
return AuthResult(success=False, error="Invalid credentials")
|
||||
|
||||
if not user.can_login():
|
||||
return AuthResult(success=False, error="Account is disabled")
|
||||
|
||||
# Check email verification if required
|
||||
if self.require_email_verification and user.email and not user.email_verified:
|
||||
return AuthResult(success=False, error="Please verify your email first")
|
||||
|
||||
# Verify password
|
||||
if not self._verify_password(password, user.password_hash):
|
||||
return AuthResult(success=False, error="Invalid credentials")
|
||||
|
||||
# Create session
|
||||
token = secrets.token_urlsafe(32)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=self.session_expiry_hours)
|
||||
|
||||
await self.user_store.create_session(
|
||||
user_id=user.id,
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
)
|
||||
|
||||
# Update last login
|
||||
await self.user_store.update_user(user.id, last_login=datetime.now(timezone.utc))
|
||||
|
||||
return AuthResult(
|
||||
success=True,
|
||||
user=user,
|
||||
token=token,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def logout(self, token: str) -> bool:
|
||||
"""
|
||||
Invalidate a session.
|
||||
|
||||
Args:
|
||||
token: Session token to invalidate.
|
||||
|
||||
Returns:
|
||||
True if session was revoked.
|
||||
"""
|
||||
return await self.user_store.revoke_session_by_token(token)
|
||||
|
||||
async def logout_all(self, user_id: str, except_token: Optional[str] = None) -> int:
|
||||
"""
|
||||
Invalidate all sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
except_token: Optional token to keep active.
|
||||
|
||||
Returns:
|
||||
Number of sessions revoked.
|
||||
"""
|
||||
return await self.user_store.revoke_all_sessions(user_id, except_token)
|
||||
|
||||
async def get_user_from_token(self, token: str) -> Optional[User]:
|
||||
"""
|
||||
Get user from session token.
|
||||
|
||||
Args:
|
||||
token: Session token.
|
||||
|
||||
Returns:
|
||||
User if valid session, None otherwise.
|
||||
"""
|
||||
session = await self.user_store.get_session_by_token(token)
|
||||
if not session or not session.is_valid():
|
||||
return None
|
||||
|
||||
# Update last used
|
||||
await self.user_store.update_session_last_used(session.id)
|
||||
|
||||
user = await self.user_store.get_user_by_id(session.user_id)
|
||||
if not user or not user.can_login():
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Password Management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def forgot_password(self, email: str) -> bool:
|
||||
"""
|
||||
Initiate password reset flow.
|
||||
|
||||
Args:
|
||||
email: Email address.
|
||||
|
||||
Returns:
|
||||
True if reset email was sent (always returns True to prevent enumeration).
|
||||
"""
|
||||
user = await self.user_store.get_user_by_email(email)
|
||||
if not user:
|
||||
# Don't reveal if email exists
|
||||
return True
|
||||
|
||||
# Generate reset token
|
||||
reset_token = secrets.token_urlsafe(32)
|
||||
reset_expires = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
await self.user_store.update_user(
|
||||
user.id,
|
||||
reset_token=reset_token,
|
||||
reset_expires=reset_expires,
|
||||
)
|
||||
|
||||
await self.email_service.send_password_reset_email(
|
||||
to=email,
|
||||
token=reset_token,
|
||||
username=user.username,
|
||||
)
|
||||
await self.user_store.log_email(user.id, "password_reset", email)
|
||||
|
||||
return True
|
||||
|
||||
async def reset_password(self, token: str, new_password: str) -> AuthResult:
|
||||
"""
|
||||
Reset password using token.
|
||||
|
||||
Args:
|
||||
token: Reset token from email.
|
||||
new_password: New password.
|
||||
|
||||
Returns:
|
||||
AuthResult with success status.
|
||||
"""
|
||||
if len(new_password) < 8:
|
||||
return AuthResult(success=False, error="Password must be at least 8 characters")
|
||||
|
||||
user = await self.user_store.get_user_by_reset_token(token)
|
||||
if not user:
|
||||
return AuthResult(success=False, error="Invalid reset token")
|
||||
|
||||
# Check expiration
|
||||
if user.reset_expires and user.reset_expires < datetime.now(timezone.utc):
|
||||
return AuthResult(success=False, error="Reset token expired")
|
||||
|
||||
# Update password
|
||||
password_hash = self._hash_password(new_password)
|
||||
await self.user_store.update_user(user.id, password_hash=password_hash)
|
||||
await self.user_store.clear_reset_token(user.id)
|
||||
|
||||
# Revoke all sessions
|
||||
await self.user_store.revoke_all_sessions(user.id)
|
||||
|
||||
# Send notification
|
||||
if user.email:
|
||||
await self.email_service.send_password_changed_notification(
|
||||
to=user.email,
|
||||
username=user.username,
|
||||
)
|
||||
await self.user_store.log_email(user.id, "password_changed", user.email)
|
||||
|
||||
return AuthResult(success=True, user=user)
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
current_token: Optional[str] = None,
|
||||
) -> AuthResult:
|
||||
"""
|
||||
Change password for authenticated user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
current_password: Current password for verification.
|
||||
new_password: New password.
|
||||
current_token: Current session token to keep active.
|
||||
|
||||
Returns:
|
||||
AuthResult with success status.
|
||||
"""
|
||||
if len(new_password) < 8:
|
||||
return AuthResult(success=False, error="Password must be at least 8 characters")
|
||||
|
||||
user = await self.user_store.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return AuthResult(success=False, error="User not found")
|
||||
|
||||
# Verify current password
|
||||
if not self._verify_password(current_password, user.password_hash):
|
||||
return AuthResult(success=False, error="Current password is incorrect")
|
||||
|
||||
# Update password
|
||||
password_hash = self._hash_password(new_password)
|
||||
await self.user_store.update_user(user.id, password_hash=password_hash)
|
||||
|
||||
# Revoke all sessions except current
|
||||
await self.user_store.revoke_all_sessions(user.id, except_token=current_token)
|
||||
|
||||
# Send notification
|
||||
if user.email:
|
||||
await self.email_service.send_password_changed_notification(
|
||||
to=user.email,
|
||||
username=user.username,
|
||||
)
|
||||
await self.user_store.log_email(user.id, "password_changed", user.email)
|
||||
|
||||
return AuthResult(success=True, user=user)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# User Profile
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def update_preferences(self, user_id: str, preferences: dict) -> Optional[User]:
|
||||
"""
|
||||
Update user preferences.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
preferences: New preferences dict.
|
||||
|
||||
Returns:
|
||||
Updated user or None.
|
||||
"""
|
||||
return await self.user_store.update_user(user_id, preferences=preferences)
|
||||
|
||||
async def get_sessions(self, user_id: str) -> list[UserSession]:
|
||||
"""
|
||||
Get all active sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
|
||||
Returns:
|
||||
List of active sessions.
|
||||
"""
|
||||
return await self.user_store.get_sessions_for_user(user_id)
|
||||
|
||||
async def revoke_session(self, user_id: str, session_id: str) -> bool:
|
||||
"""
|
||||
Revoke a specific session.
|
||||
|
||||
Args:
|
||||
user_id: User ID (for authorization).
|
||||
session_id: Session ID to revoke.
|
||||
|
||||
Returns:
|
||||
True if session was revoked.
|
||||
"""
|
||||
# Verify session belongs to user
|
||||
sessions = await self.user_store.get_sessions_for_user(user_id)
|
||||
if not any(s.id == session_id for s in sessions):
|
||||
return False
|
||||
|
||||
return await self.user_store.revoke_session(session_id)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Guest Conversion
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def convert_guest(
|
||||
self,
|
||||
guest_id: str,
|
||||
username: str,
|
||||
password: str,
|
||||
email: Optional[str] = None,
|
||||
) -> RegistrationResult:
|
||||
"""
|
||||
Convert guest session to full user account.
|
||||
|
||||
Args:
|
||||
guest_id: Guest session ID.
|
||||
username: Desired username.
|
||||
password: Password.
|
||||
email: Optional email.
|
||||
|
||||
Returns:
|
||||
RegistrationResult with user or error.
|
||||
"""
|
||||
# Verify guest exists and not already converted
|
||||
guest = await self.user_store.get_guest_session(guest_id)
|
||||
if not guest:
|
||||
return RegistrationResult(success=False, error="Guest session not found")
|
||||
|
||||
if guest.is_converted():
|
||||
return RegistrationResult(success=False, error="Guest already converted")
|
||||
|
||||
# Register with guest ID
|
||||
return await self.register(
|
||||
username=username,
|
||||
password=password,
|
||||
email=email,
|
||||
guest_id=guest_id,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Account Deletion
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def delete_account(self, user_id: str) -> bool:
|
||||
"""
|
||||
Soft delete user account.
|
||||
|
||||
Args:
|
||||
user_id: User ID to delete.
|
||||
|
||||
Returns:
|
||||
True if account was deleted.
|
||||
"""
|
||||
# Revoke all sessions
|
||||
await self.user_store.revoke_all_sessions(user_id)
|
||||
|
||||
# Soft delete
|
||||
user = await self.user_store.update_user(
|
||||
user_id,
|
||||
is_active=False,
|
||||
deleted_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
return user is not None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Guest Sessions
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def create_guest_session(
|
||||
self,
|
||||
guest_id: str,
|
||||
display_name: Optional[str] = None,
|
||||
) -> GuestSession:
|
||||
"""
|
||||
Create or get guest session.
|
||||
|
||||
Args:
|
||||
guest_id: Guest session ID.
|
||||
display_name: Display name for guest.
|
||||
|
||||
Returns:
|
||||
GuestSession.
|
||||
"""
|
||||
existing = await self.user_store.get_guest_session(guest_id)
|
||||
if existing:
|
||||
await self.user_store.update_guest_last_seen(guest_id)
|
||||
return existing
|
||||
|
||||
return await self.user_store.create_guest_session(guest_id, display_name)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Password Hashing
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _hash_password(self, password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed.decode()
|
||||
|
||||
def _verify_password(self, password: str, password_hash: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
try:
|
||||
return bcrypt.checkpw(password.encode(), password_hash.encode())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Global auth service instance
|
||||
_auth_service: Optional[AuthService] = None
|
||||
|
||||
|
||||
async def get_auth_service(user_store: UserStore) -> AuthService:
|
||||
"""
|
||||
Get or create the global auth service instance.
|
||||
|
||||
Args:
|
||||
user_store: User persistence store.
|
||||
|
||||
Returns:
|
||||
AuthService instance.
|
||||
"""
|
||||
global _auth_service
|
||||
if _auth_service is None:
|
||||
_auth_service = await AuthService.create(user_store)
|
||||
return _auth_service
|
||||
|
||||
|
||||
async def close_auth_service() -> None:
|
||||
"""Close the global auth service."""
|
||||
global _auth_service
|
||||
_auth_service = None
|
||||
215
server/services/email_service.py
Normal file
215
server/services/email_service.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Email service for Golf game authentication.
|
||||
|
||||
Provides email sending via Resend for verification, password reset, and notifications.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""
|
||||
Email service using Resend API.
|
||||
|
||||
Handles all transactional emails for authentication:
|
||||
- Email verification
|
||||
- Password reset
|
||||
- Password changed notification
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, from_address: str, base_url: str):
|
||||
"""
|
||||
Initialize email service.
|
||||
|
||||
Args:
|
||||
api_key: Resend API key.
|
||||
from_address: Sender email address.
|
||||
base_url: Base URL for verification/reset links.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.from_address = from_address
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self._client = None
|
||||
|
||||
@classmethod
|
||||
def create(cls) -> "EmailService":
|
||||
"""Create EmailService from config."""
|
||||
return cls(
|
||||
api_key=config.RESEND_API_KEY,
|
||||
from_address=config.EMAIL_FROM,
|
||||
base_url=config.BASE_URL,
|
||||
)
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
"""Lazy-load Resend client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
import resend
|
||||
resend.api_key = self.api_key
|
||||
self._client = resend
|
||||
except ImportError:
|
||||
logger.warning("resend package not installed, emails will be logged only")
|
||||
self._client = None
|
||||
return self._client
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if email service is properly configured."""
|
||||
return bool(self.api_key)
|
||||
|
||||
async def send_verification_email(
|
||||
self,
|
||||
to: str,
|
||||
token: str,
|
||||
username: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Send email verification email.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
token: Verification token.
|
||||
username: User's display name.
|
||||
|
||||
Returns:
|
||||
Resend message ID if sent, None if not configured.
|
||||
"""
|
||||
if not self.is_configured():
|
||||
logger.info(f"Email not configured. Would send verification to {to}")
|
||||
return None
|
||||
|
||||
verify_url = f"{self.base_url}/verify-email?token={token}"
|
||||
|
||||
subject = "Verify your Golf Game account"
|
||||
html = f"""
|
||||
<h2>Welcome to Golf Game, {username}!</h2>
|
||||
<p>Please verify your email address by clicking the link below:</p>
|
||||
<p><a href="{verify_url}">Verify Email Address</a></p>
|
||||
<p>Or copy and paste this URL into your browser:</p>
|
||||
<p>{verify_url}</p>
|
||||
<p>This link will expire in 24 hours.</p>
|
||||
<p>If you didn't create this account, you can safely ignore this email.</p>
|
||||
"""
|
||||
|
||||
return await self._send_email(to, subject, html)
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to: str,
|
||||
token: str,
|
||||
username: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Send password reset email.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
token: Reset token.
|
||||
username: User's display name.
|
||||
|
||||
Returns:
|
||||
Resend message ID if sent, None if not configured.
|
||||
"""
|
||||
if not self.is_configured():
|
||||
logger.info(f"Email not configured. Would send password reset to {to}")
|
||||
return None
|
||||
|
||||
reset_url = f"{self.base_url}/reset-password?token={token}"
|
||||
|
||||
subject = "Reset your Golf Game password"
|
||||
html = f"""
|
||||
<h2>Password Reset Request</h2>
|
||||
<p>Hi {username},</p>
|
||||
<p>We received a request to reset your password. Click the link below to set a new password:</p>
|
||||
<p><a href="{reset_url}">Reset Password</a></p>
|
||||
<p>Or copy and paste this URL into your browser:</p>
|
||||
<p>{reset_url}</p>
|
||||
<p>This link will expire in 1 hour.</p>
|
||||
<p>If you didn't request this, you can safely ignore this email. Your password will remain unchanged.</p>
|
||||
"""
|
||||
|
||||
return await self._send_email(to, subject, html)
|
||||
|
||||
async def send_password_changed_notification(
|
||||
self,
|
||||
to: str,
|
||||
username: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Send password changed notification email.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
username: User's display name.
|
||||
|
||||
Returns:
|
||||
Resend message ID if sent, None if not configured.
|
||||
"""
|
||||
if not self.is_configured():
|
||||
logger.info(f"Email not configured. Would send password change notification to {to}")
|
||||
return None
|
||||
|
||||
subject = "Your Golf Game password was changed"
|
||||
html = f"""
|
||||
<h2>Password Changed</h2>
|
||||
<p>Hi {username},</p>
|
||||
<p>Your password was successfully changed.</p>
|
||||
<p>If you did not make this change, please contact support immediately.</p>
|
||||
"""
|
||||
|
||||
return await self._send_email(to, subject, html)
|
||||
|
||||
async def _send_email(
|
||||
self,
|
||||
to: str,
|
||||
subject: str,
|
||||
html: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Send an email via Resend.
|
||||
|
||||
Args:
|
||||
to: Recipient email address.
|
||||
subject: Email subject.
|
||||
html: HTML email body.
|
||||
|
||||
Returns:
|
||||
Resend message ID if sent, None on error.
|
||||
"""
|
||||
if not self.client:
|
||||
logger.warning(f"Resend not available. Email to {to}: {subject}")
|
||||
return None
|
||||
|
||||
try:
|
||||
params = {
|
||||
"from": self.from_address,
|
||||
"to": [to],
|
||||
"subject": subject,
|
||||
"html": html,
|
||||
}
|
||||
|
||||
response = self.client.Emails.send(params)
|
||||
message_id = response.get("id") if isinstance(response, dict) else getattr(response, "id", None)
|
||||
logger.info(f"Email sent to {to}: {message_id}")
|
||||
return message_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email to {to}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global email service instance
|
||||
_email_service: Optional[EmailService] = None
|
||||
|
||||
|
||||
def get_email_service() -> EmailService:
|
||||
"""Get or create the global email service instance."""
|
||||
global _email_service
|
||||
if _email_service is None:
|
||||
_email_service = EmailService.create()
|
||||
return _email_service
|
||||
223
server/services/ratelimit.py
Normal file
223
server/services/ratelimit.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Redis-based rate limiter service.
|
||||
|
||||
Implements a sliding window counter algorithm using Redis for distributed
|
||||
rate limiting across multiple server instances.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from fastapi import Request, WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Rate limit configurations: (max_requests, window_seconds)
|
||||
RATE_LIMITS = {
|
||||
"api_general": (100, 60), # 100 requests per minute
|
||||
"api_auth": (10, 60), # 10 auth attempts per minute
|
||||
"api_create_room": (5, 60), # 5 room creations per minute
|
||||
"websocket_connect": (10, 60), # 10 WS connections per minute
|
||||
"websocket_message": (30, 10), # 30 messages per 10 seconds
|
||||
"email_send": (3, 300), # 3 emails per 5 minutes
|
||||
}
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token bucket rate limiter using Redis."""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis):
|
||||
"""
|
||||
Initialize rate limiter with Redis client.
|
||||
|
||||
Args:
|
||||
redis_client: Async Redis client for state storage.
|
||||
"""
|
||||
self.redis = redis_client
|
||||
|
||||
async def is_allowed(
|
||||
self,
|
||||
key: str,
|
||||
limit: int,
|
||||
window_seconds: int,
|
||||
) -> tuple[bool, dict]:
|
||||
"""
|
||||
Check if request is allowed under rate limit.
|
||||
|
||||
Uses a sliding window counter algorithm:
|
||||
- Divides time into fixed windows
|
||||
- Counts requests in current window
|
||||
- Atomically increments and checks limit
|
||||
|
||||
Args:
|
||||
key: Unique identifier for the rate limit bucket.
|
||||
limit: Maximum requests allowed in window.
|
||||
window_seconds: Time window in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, info) where info contains:
|
||||
- remaining: requests remaining in window
|
||||
- reset: seconds until window resets
|
||||
- limit: the limit that was applied
|
||||
"""
|
||||
now = int(time.time())
|
||||
window_key = f"ratelimit:{key}:{now // window_seconds}"
|
||||
|
||||
try:
|
||||
async with self.redis.pipeline(transaction=True) as pipe:
|
||||
pipe.incr(window_key)
|
||||
pipe.expire(window_key, window_seconds + 1) # Extra second for safety
|
||||
results = await pipe.execute()
|
||||
|
||||
current_count = results[0]
|
||||
remaining = max(0, limit - current_count)
|
||||
reset = window_seconds - (now % window_seconds)
|
||||
|
||||
info = {
|
||||
"remaining": remaining,
|
||||
"reset": reset,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
allowed = current_count <= limit
|
||||
if not allowed:
|
||||
logger.warning(f"Rate limit exceeded for {key}: {current_count}/{limit}")
|
||||
|
||||
return allowed, info
|
||||
|
||||
except redis.RedisError as e:
|
||||
# If Redis is unavailable, fail open (allow request)
|
||||
logger.error(f"Rate limiter Redis error: {e}")
|
||||
return True, {"remaining": limit, "reset": window_seconds, "limit": limit}
|
||||
|
||||
def get_client_key(
|
||||
self,
|
||||
request: Request | WebSocket,
|
||||
user_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate rate limit key for client.
|
||||
|
||||
Uses user ID if authenticated, otherwise hashes client IP.
|
||||
|
||||
Args:
|
||||
request: HTTP request or WebSocket.
|
||||
user_id: Authenticated user ID, if available.
|
||||
|
||||
Returns:
|
||||
Unique client identifier string.
|
||||
"""
|
||||
if user_id:
|
||||
return f"user:{user_id}"
|
||||
|
||||
# For anonymous users, use IP hash
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# Hash IP for privacy
|
||||
ip_hash = hashlib.sha256(client_ip.encode()).hexdigest()[:16]
|
||||
return f"ip:{ip_hash}"
|
||||
|
||||
def _get_client_ip(self, request: Request | WebSocket) -> str:
|
||||
"""
|
||||
Extract client IP from request, handling proxies.
|
||||
|
||||
Args:
|
||||
request: HTTP request or WebSocket.
|
||||
|
||||
Returns:
|
||||
Client IP address string.
|
||||
"""
|
||||
# Check X-Forwarded-For header (from reverse proxy)
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# Take the first IP (original client)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
# Check X-Real-IP header (nginx)
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip.strip()
|
||||
|
||||
# Fall back to direct connection
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
class ConnectionMessageLimiter:
|
||||
"""
|
||||
In-memory rate limiter for WebSocket message frequency.
|
||||
|
||||
Used to limit messages within a single connection without
|
||||
requiring Redis round-trips for every message.
|
||||
"""
|
||||
|
||||
def __init__(self, max_messages: int = 30, window_seconds: int = 10):
|
||||
"""
|
||||
Initialize connection message limiter.
|
||||
|
||||
Args:
|
||||
max_messages: Maximum messages allowed in window.
|
||||
window_seconds: Time window in seconds.
|
||||
"""
|
||||
self.max_messages = max_messages
|
||||
self.window_seconds = window_seconds
|
||||
self.timestamps: list[float] = []
|
||||
|
||||
def check(self) -> bool:
|
||||
"""
|
||||
Check if another message is allowed.
|
||||
|
||||
Maintains a sliding window of message timestamps.
|
||||
|
||||
Returns:
|
||||
True if message is allowed, False if rate limited.
|
||||
"""
|
||||
now = time.time()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old timestamps
|
||||
self.timestamps = [t for t in self.timestamps if t > cutoff]
|
||||
|
||||
# Check limit
|
||||
if len(self.timestamps) >= self.max_messages:
|
||||
return False
|
||||
|
||||
# Record this message
|
||||
self.timestamps.append(now)
|
||||
return True
|
||||
|
||||
def reset(self):
|
||||
"""Reset the limiter (e.g., on reconnection)."""
|
||||
self.timestamps = []
|
||||
|
||||
|
||||
# Global rate limiter instance
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
async def get_rate_limiter(redis_client: redis.Redis) -> RateLimiter:
|
||||
"""
|
||||
Get or create the global rate limiter instance.
|
||||
|
||||
Args:
|
||||
redis_client: Redis client for state storage.
|
||||
|
||||
Returns:
|
||||
RateLimiter instance.
|
||||
"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter(redis_client)
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
def close_rate_limiter():
|
||||
"""Close the global rate limiter."""
|
||||
global _rate_limiter
|
||||
_rate_limiter = None
|
||||
353
server/services/recovery_service.py
Normal file
353
server/services/recovery_service.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Game recovery service for rebuilding active games from event store.
|
||||
|
||||
On server restart, all in-memory game state is lost. This service:
|
||||
1. Queries the event store for active games
|
||||
2. Rebuilds game state by replaying events
|
||||
3. Caches the rebuilt state in Redis
|
||||
4. Handles partial recovery (applying only new events to cached state)
|
||||
|
||||
This ensures games can survive server restarts without data loss.
|
||||
|
||||
Usage:
|
||||
recovery = RecoveryService(event_store, state_cache)
|
||||
results = await recovery.recover_all_games()
|
||||
print(f"Recovered {results['recovered']} games")
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
from stores.event_store import EventStore
|
||||
from stores.state_cache import StateCache
|
||||
from models.events import EventType
|
||||
from models.game_state import RebuiltGameState, rebuild_state, CardState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecoveryResult:
|
||||
"""Result of a game recovery attempt."""
|
||||
|
||||
game_id: str
|
||||
room_code: str
|
||||
success: bool
|
||||
phase: Optional[str] = None
|
||||
sequence_num: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class RecoveryService:
|
||||
"""
|
||||
Recovers games from event store on startup.
|
||||
|
||||
Works with the event store (PostgreSQL) as source of truth
|
||||
and state cache (Redis) for fast access during gameplay.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_store: EventStore,
|
||||
state_cache: StateCache,
|
||||
):
|
||||
"""
|
||||
Initialize recovery service.
|
||||
|
||||
Args:
|
||||
event_store: PostgreSQL event store.
|
||||
state_cache: Redis state cache.
|
||||
"""
|
||||
self.event_store = event_store
|
||||
self.state_cache = state_cache
|
||||
|
||||
async def recover_all_games(self) -> dict[str, Any]:
|
||||
"""
|
||||
Recover all active games from event store.
|
||||
|
||||
Queries PostgreSQL for active games and rebuilds their state
|
||||
from events, then caches in Redis.
|
||||
|
||||
Returns:
|
||||
Dict with recovery statistics:
|
||||
- recovered: Number of games successfully recovered
|
||||
- failed: Number of games that failed recovery
|
||||
- skipped: Number of games skipped (already ended)
|
||||
- games: List of recovered game info
|
||||
"""
|
||||
results = {
|
||||
"recovered": 0,
|
||||
"failed": 0,
|
||||
"skipped": 0,
|
||||
"games": [],
|
||||
}
|
||||
|
||||
# Get active games from PostgreSQL
|
||||
active_games = await self.event_store.get_active_games()
|
||||
logger.info(f"Found {len(active_games)} active games to recover")
|
||||
|
||||
for game_meta in active_games:
|
||||
game_id = str(game_meta["id"])
|
||||
room_code = game_meta["room_code"]
|
||||
|
||||
try:
|
||||
result = await self.recover_game(game_id, room_code)
|
||||
|
||||
if result.success:
|
||||
results["recovered"] += 1
|
||||
results["games"].append({
|
||||
"game_id": game_id,
|
||||
"room_code": room_code,
|
||||
"phase": result.phase,
|
||||
"sequence": result.sequence_num,
|
||||
})
|
||||
else:
|
||||
if result.error == "game_ended":
|
||||
results["skipped"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
logger.warning(f"Failed to recover {game_id}: {result.error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error recovering game {game_id}: {e}", exc_info=True)
|
||||
results["failed"] += 1
|
||||
|
||||
return results
|
||||
|
||||
async def recover_game(
|
||||
self,
|
||||
game_id: str,
|
||||
room_code: Optional[str] = None,
|
||||
) -> RecoveryResult:
|
||||
"""
|
||||
Recover a single game from event store.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
room_code: Room code (optional, will be read from events).
|
||||
|
||||
Returns:
|
||||
RecoveryResult with success status and game info.
|
||||
"""
|
||||
# Get all events for this game
|
||||
events = await self.event_store.get_events(game_id)
|
||||
|
||||
if not events:
|
||||
return RecoveryResult(
|
||||
game_id=game_id,
|
||||
room_code=room_code or "",
|
||||
success=False,
|
||||
error="no_events",
|
||||
)
|
||||
|
||||
# Check if game is actually active (not ended)
|
||||
last_event = events[-1]
|
||||
if last_event.event_type == EventType.GAME_ENDED:
|
||||
return RecoveryResult(
|
||||
game_id=game_id,
|
||||
room_code=room_code or "",
|
||||
success=False,
|
||||
error="game_ended",
|
||||
)
|
||||
|
||||
# Rebuild state from events
|
||||
state = rebuild_state(events)
|
||||
|
||||
# Get room code from state if not provided
|
||||
if not room_code:
|
||||
room_code = state.room_code
|
||||
|
||||
# Convert state to cacheable dict
|
||||
state_dict = self._state_to_dict(state)
|
||||
|
||||
# Save to Redis cache
|
||||
await self.state_cache.save_game_state(game_id, state_dict)
|
||||
|
||||
# Also create/update room in cache
|
||||
await self._ensure_room_in_cache(state)
|
||||
|
||||
logger.info(
|
||||
f"Recovered game {game_id} (room {room_code}) "
|
||||
f"at sequence {state.sequence_num}, phase {state.phase.value}"
|
||||
)
|
||||
|
||||
return RecoveryResult(
|
||||
game_id=game_id,
|
||||
room_code=room_code,
|
||||
success=True,
|
||||
phase=state.phase.value,
|
||||
sequence_num=state.sequence_num,
|
||||
)
|
||||
|
||||
async def recover_from_sequence(
|
||||
self,
|
||||
game_id: str,
|
||||
cached_state: dict,
|
||||
cached_sequence: int,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Recover game by applying only new events to cached state.
|
||||
|
||||
More efficient than full rebuild when we have a recent cache.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
cached_state: Previously cached state dict.
|
||||
cached_sequence: Sequence number of cached state.
|
||||
|
||||
Returns:
|
||||
Updated state dict, or None if no new events.
|
||||
"""
|
||||
# Get events after cached sequence
|
||||
new_events = await self.event_store.get_events(
|
||||
game_id,
|
||||
from_sequence=cached_sequence + 1,
|
||||
)
|
||||
|
||||
if not new_events:
|
||||
return None # No new events
|
||||
|
||||
# Rebuild state from cache + new events
|
||||
state = self._dict_to_state(cached_state)
|
||||
for event in new_events:
|
||||
state.apply(event)
|
||||
|
||||
# Convert back to dict
|
||||
new_state = self._state_to_dict(state)
|
||||
|
||||
# Update cache
|
||||
await self.state_cache.save_game_state(game_id, new_state)
|
||||
|
||||
return new_state
|
||||
|
||||
async def _ensure_room_in_cache(self, state: RebuiltGameState) -> None:
|
||||
"""
|
||||
Ensure room exists in Redis cache after recovery.
|
||||
|
||||
Args:
|
||||
state: Rebuilt game state.
|
||||
"""
|
||||
room_code = state.room_code
|
||||
if not room_code:
|
||||
return
|
||||
|
||||
# Check if room already exists
|
||||
if await self.state_cache.room_exists(room_code):
|
||||
return
|
||||
|
||||
# Create room in cache
|
||||
await self.state_cache.create_room(
|
||||
room_code=room_code,
|
||||
game_id=state.game_id,
|
||||
host_id=state.host_id or "",
|
||||
server_id="recovered",
|
||||
)
|
||||
|
||||
# Set room status based on game phase
|
||||
if state.phase.value == "waiting":
|
||||
status = "waiting"
|
||||
elif state.phase.value in ("game_over", "round_over"):
|
||||
status = "finished"
|
||||
else:
|
||||
status = "playing"
|
||||
|
||||
await self.state_cache.set_room_status(room_code, status)
|
||||
|
||||
def _state_to_dict(self, state: RebuiltGameState) -> dict:
|
||||
"""
|
||||
Convert RebuiltGameState to dict for caching.
|
||||
|
||||
Args:
|
||||
state: Game state to convert.
|
||||
|
||||
Returns:
|
||||
Cacheable dict representation.
|
||||
"""
|
||||
return {
|
||||
"game_id": state.game_id,
|
||||
"room_code": state.room_code,
|
||||
"phase": state.phase.value,
|
||||
"current_round": state.current_round,
|
||||
"total_rounds": state.total_rounds,
|
||||
"current_player_idx": state.current_player_idx,
|
||||
"player_order": state.player_order,
|
||||
"players": {
|
||||
pid: {
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"cards": [c.to_dict() for c in p.cards],
|
||||
"score": p.score,
|
||||
"total_score": p.total_score,
|
||||
"rounds_won": p.rounds_won,
|
||||
"is_cpu": p.is_cpu,
|
||||
"cpu_profile": p.cpu_profile,
|
||||
}
|
||||
for pid, p in state.players.items()
|
||||
},
|
||||
"deck_remaining": state.deck_remaining,
|
||||
"discard_pile": [c.to_dict() for c in state.discard_pile],
|
||||
"discard_top": state.discard_pile[-1].to_dict() if state.discard_pile else None,
|
||||
"drawn_card": state.drawn_card.to_dict() if state.drawn_card else None,
|
||||
"drawn_from_discard": state.drawn_from_discard,
|
||||
"options": state.options,
|
||||
"sequence_num": state.sequence_num,
|
||||
"finisher_id": state.finisher_id,
|
||||
"host_id": state.host_id,
|
||||
"initial_flips_done": list(state.initial_flips_done),
|
||||
"players_with_final_turn": list(state.players_with_final_turn),
|
||||
}
|
||||
|
||||
def _dict_to_state(self, d: dict) -> RebuiltGameState:
|
||||
"""
|
||||
Convert dict back to RebuiltGameState.
|
||||
|
||||
Args:
|
||||
d: Cached state dict.
|
||||
|
||||
Returns:
|
||||
Reconstructed game state.
|
||||
"""
|
||||
from models.game_state import GamePhase, PlayerState
|
||||
|
||||
state = RebuiltGameState(game_id=d["game_id"])
|
||||
state.room_code = d.get("room_code", "")
|
||||
state.phase = GamePhase(d.get("phase", "waiting"))
|
||||
state.current_round = d.get("current_round", 0)
|
||||
state.total_rounds = d.get("total_rounds", 1)
|
||||
state.current_player_idx = d.get("current_player_idx", 0)
|
||||
state.player_order = d.get("player_order", [])
|
||||
state.deck_remaining = d.get("deck_remaining", 0)
|
||||
state.options = d.get("options", {})
|
||||
state.sequence_num = d.get("sequence_num", 0)
|
||||
state.finisher_id = d.get("finisher_id")
|
||||
state.host_id = d.get("host_id")
|
||||
state.initial_flips_done = set(d.get("initial_flips_done", []))
|
||||
state.players_with_final_turn = set(d.get("players_with_final_turn", []))
|
||||
state.drawn_from_discard = d.get("drawn_from_discard", False)
|
||||
|
||||
# Rebuild players
|
||||
players_data = d.get("players", {})
|
||||
for pid, pdata in players_data.items():
|
||||
player = PlayerState(
|
||||
id=pdata["id"],
|
||||
name=pdata["name"],
|
||||
is_cpu=pdata.get("is_cpu", False),
|
||||
cpu_profile=pdata.get("cpu_profile"),
|
||||
score=pdata.get("score", 0),
|
||||
total_score=pdata.get("total_score", 0),
|
||||
rounds_won=pdata.get("rounds_won", 0),
|
||||
)
|
||||
player.cards = [CardState.from_dict(c) for c in pdata.get("cards", [])]
|
||||
state.players[pid] = player
|
||||
|
||||
# Rebuild discard pile
|
||||
discard_data = d.get("discard_pile", [])
|
||||
state.discard_pile = [CardState.from_dict(c) for c in discard_data]
|
||||
|
||||
# Rebuild drawn card
|
||||
drawn = d.get("drawn_card")
|
||||
if drawn:
|
||||
state.drawn_card = CardState.from_dict(drawn)
|
||||
|
||||
return state
|
||||
583
server/services/replay_service.py
Normal file
583
server/services/replay_service.py
Normal file
@@ -0,0 +1,583 @@
|
||||
"""
|
||||
Replay service for Golf game.
|
||||
|
||||
Provides game replay functionality, share link generation, and game export/import.
|
||||
Leverages the event-sourced architecture for perfect game reconstruction.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional, List
|
||||
|
||||
import asyncpg
|
||||
|
||||
from stores.event_store import EventStore
|
||||
from models.events import GameEvent, EventType
|
||||
from models.game_state import rebuild_state, RebuiltGameState, CardState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# SQL schema for replay/sharing tables
|
||||
REPLAY_SCHEMA_SQL = """
|
||||
-- Public share links for completed games
|
||||
CREATE TABLE IF NOT EXISTS shared_games (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
game_id UUID NOT NULL,
|
||||
share_code VARCHAR(12) UNIQUE NOT NULL,
|
||||
created_by VARCHAR(50),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
expires_at TIMESTAMPTZ,
|
||||
view_count INTEGER DEFAULT 0,
|
||||
is_public BOOLEAN DEFAULT true,
|
||||
title VARCHAR(100),
|
||||
description TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_shared_games_code ON shared_games(share_code);
|
||||
CREATE INDEX IF NOT EXISTS idx_shared_games_game ON shared_games(game_id);
|
||||
|
||||
-- Track replay views for analytics
|
||||
CREATE TABLE IF NOT EXISTS replay_views (
|
||||
id SERIAL PRIMARY KEY,
|
||||
shared_game_id UUID REFERENCES shared_games(id),
|
||||
viewer_id VARCHAR(50),
|
||||
viewed_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
ip_hash VARCHAR(64),
|
||||
watch_duration_seconds INTEGER
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_replay_views_shared ON replay_views(shared_game_id);
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayFrame:
|
||||
"""Single frame in a replay."""
|
||||
event_index: int
|
||||
event_type: str
|
||||
event_data: dict
|
||||
game_state: dict
|
||||
timestamp: float # Seconds from start
|
||||
player_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GameReplay:
|
||||
"""Complete replay of a game."""
|
||||
game_id: str
|
||||
frames: List[ReplayFrame]
|
||||
total_duration_seconds: float
|
||||
player_names: List[str]
|
||||
final_scores: dict
|
||||
winner: Optional[str]
|
||||
options: dict
|
||||
room_code: str
|
||||
total_rounds: int
|
||||
|
||||
|
||||
class ReplayService:
|
||||
"""
|
||||
Service for game replay, export, and sharing.
|
||||
|
||||
Provides:
|
||||
- Replay building from event store
|
||||
- Share link creation and retrieval
|
||||
- Game export/import
|
||||
"""
|
||||
|
||||
EXPORT_VERSION = "1.0"
|
||||
|
||||
def __init__(self, pool: asyncpg.Pool, event_store: EventStore):
|
||||
"""
|
||||
Initialize replay service.
|
||||
|
||||
Args:
|
||||
pool: asyncpg connection pool.
|
||||
event_store: Event store for retrieving game events.
|
||||
"""
|
||||
self.pool = pool
|
||||
self.event_store = event_store
|
||||
|
||||
async def initialize_schema(self) -> None:
|
||||
"""Create replay tables if they don't exist."""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute(REPLAY_SCHEMA_SQL)
|
||||
logger.info("Replay schema initialized")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Replay Building
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def build_replay(self, game_id: str) -> GameReplay:
|
||||
"""
|
||||
Build complete replay from event store.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
GameReplay with all frames and metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If no events found for game.
|
||||
"""
|
||||
events = await self.event_store.get_events(game_id)
|
||||
if not events:
|
||||
raise ValueError(f"No events found for game {game_id}")
|
||||
|
||||
frames = []
|
||||
state = RebuiltGameState(game_id=game_id)
|
||||
start_time = None
|
||||
|
||||
for i, event in enumerate(events):
|
||||
if start_time is None:
|
||||
start_time = event.timestamp
|
||||
|
||||
# Apply event to get state
|
||||
state.apply(event)
|
||||
|
||||
# Calculate timestamp relative to start
|
||||
elapsed = (event.timestamp - start_time).total_seconds()
|
||||
|
||||
frames.append(ReplayFrame(
|
||||
event_index=i,
|
||||
event_type=event.event_type.value,
|
||||
event_data=event.data,
|
||||
game_state=self._state_to_dict(state),
|
||||
timestamp=elapsed,
|
||||
player_id=event.player_id,
|
||||
))
|
||||
|
||||
# Extract final game info
|
||||
player_names = [p.name for p in state.players.values()]
|
||||
final_scores = {p.name: p.total_score for p in state.players.values()}
|
||||
|
||||
# Determine winner (lowest total score)
|
||||
winner = None
|
||||
if state.phase.value == "game_over" and state.players:
|
||||
winner_player = min(state.players.values(), key=lambda p: p.total_score)
|
||||
winner = winner_player.name
|
||||
|
||||
return GameReplay(
|
||||
game_id=game_id,
|
||||
frames=frames,
|
||||
total_duration_seconds=frames[-1].timestamp if frames else 0,
|
||||
player_names=player_names,
|
||||
final_scores=final_scores,
|
||||
winner=winner,
|
||||
options=state.options,
|
||||
room_code=state.room_code,
|
||||
total_rounds=state.total_rounds,
|
||||
)
|
||||
|
||||
async def get_replay_frame(
|
||||
self,
|
||||
game_id: str,
|
||||
frame_index: int
|
||||
) -> Optional[ReplayFrame]:
|
||||
"""
|
||||
Get a specific frame from a replay.
|
||||
|
||||
Useful for seeking to a specific point without loading entire replay.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
frame_index: Index of frame to retrieve (0-based).
|
||||
|
||||
Returns:
|
||||
ReplayFrame or None if index out of range.
|
||||
"""
|
||||
events = await self.event_store.get_events(
|
||||
game_id,
|
||||
from_sequence=1,
|
||||
to_sequence=frame_index + 1
|
||||
)
|
||||
|
||||
if not events or len(events) <= frame_index:
|
||||
return None
|
||||
|
||||
state = RebuiltGameState(game_id=game_id)
|
||||
start_time = events[0].timestamp if events else None
|
||||
|
||||
for event in events:
|
||||
state.apply(event)
|
||||
|
||||
last_event = events[-1]
|
||||
elapsed = (last_event.timestamp - start_time).total_seconds() if start_time else 0
|
||||
|
||||
return ReplayFrame(
|
||||
event_index=frame_index,
|
||||
event_type=last_event.event_type.value,
|
||||
event_data=last_event.data,
|
||||
game_state=self._state_to_dict(state),
|
||||
timestamp=elapsed,
|
||||
player_id=last_event.player_id,
|
||||
)
|
||||
|
||||
def _state_to_dict(self, state: RebuiltGameState) -> dict:
|
||||
"""Convert RebuiltGameState to serializable dict."""
|
||||
players = []
|
||||
for pid in state.player_order:
|
||||
if pid in state.players:
|
||||
p = state.players[pid]
|
||||
players.append({
|
||||
"id": p.id,
|
||||
"name": p.name,
|
||||
"cards": [c.to_dict() for c in p.cards],
|
||||
"score": p.score,
|
||||
"total_score": p.total_score,
|
||||
"rounds_won": p.rounds_won,
|
||||
"is_cpu": p.is_cpu,
|
||||
"all_face_up": p.all_face_up(),
|
||||
})
|
||||
|
||||
return {
|
||||
"phase": state.phase.value,
|
||||
"players": players,
|
||||
"current_player_idx": state.current_player_idx,
|
||||
"current_player_id": state.player_order[state.current_player_idx] if state.player_order else None,
|
||||
"deck_remaining": state.deck_remaining,
|
||||
"discard_pile": [c.to_dict() for c in state.discard_pile],
|
||||
"discard_top": state.discard_pile[-1].to_dict() if state.discard_pile else None,
|
||||
"drawn_card": state.drawn_card.to_dict() if state.drawn_card else None,
|
||||
"current_round": state.current_round,
|
||||
"total_rounds": state.total_rounds,
|
||||
"finisher_id": state.finisher_id,
|
||||
"options": state.options,
|
||||
}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Share Links
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def create_share_link(
|
||||
self,
|
||||
game_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
expires_days: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate shareable link for a game.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
user_id: ID of user creating the share.
|
||||
title: Optional custom title.
|
||||
description: Optional description.
|
||||
expires_days: Days until link expires (None = never).
|
||||
|
||||
Returns:
|
||||
12-character share code.
|
||||
"""
|
||||
share_code = secrets.token_urlsafe(9)[:12]
|
||||
|
||||
expires_at = None
|
||||
if expires_days:
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=expires_days)
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO shared_games
|
||||
(game_id, share_code, created_by, title, description, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
""", game_id, share_code, user_id, title, description, expires_at)
|
||||
|
||||
logger.info(f"Created share link {share_code} for game {game_id}")
|
||||
return share_code
|
||||
|
||||
async def get_shared_game(self, share_code: str) -> Optional[dict]:
|
||||
"""
|
||||
Retrieve shared game by code.
|
||||
|
||||
Args:
|
||||
share_code: 12-character share code.
|
||||
|
||||
Returns:
|
||||
Shared game metadata dict, or None if not found/expired.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT sg.*, g.room_code, g.completed_at, g.num_players, g.num_rounds
|
||||
FROM shared_games sg
|
||||
JOIN games_v2 g ON sg.game_id = g.id
|
||||
WHERE sg.share_code = $1
|
||||
AND sg.is_public = true
|
||||
AND (sg.expires_at IS NULL OR sg.expires_at > NOW())
|
||||
""", share_code)
|
||||
|
||||
if row:
|
||||
# Increment view count
|
||||
await conn.execute("""
|
||||
UPDATE shared_games SET view_count = view_count + 1
|
||||
WHERE share_code = $1
|
||||
""", share_code)
|
||||
|
||||
return dict(row)
|
||||
return None
|
||||
|
||||
async def record_replay_view(
|
||||
self,
|
||||
shared_game_id: str,
|
||||
viewer_id: Optional[str] = None,
|
||||
ip_hash: Optional[str] = None,
|
||||
duration_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Record a replay view for analytics.
|
||||
|
||||
Args:
|
||||
shared_game_id: UUID of the shared_games record.
|
||||
viewer_id: Optional user ID of viewer.
|
||||
ip_hash: Optional hashed IP for rate limiting.
|
||||
duration_seconds: Optional watch duration.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO replay_views
|
||||
(shared_game_id, viewer_id, ip_hash, watch_duration_seconds)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", shared_game_id, viewer_id, ip_hash, duration_seconds)
|
||||
|
||||
async def get_user_shared_games(self, user_id: str) -> List[dict]:
|
||||
"""
|
||||
Get all shared games created by a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
|
||||
Returns:
|
||||
List of shared game metadata dicts.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT sg.*, g.room_code, g.completed_at
|
||||
FROM shared_games sg
|
||||
JOIN games_v2 g ON sg.game_id = g.id
|
||||
WHERE sg.created_by = $1
|
||||
ORDER BY sg.created_at DESC
|
||||
""", user_id)
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def delete_share_link(self, share_code: str, user_id: str) -> bool:
|
||||
"""
|
||||
Delete a share link.
|
||||
|
||||
Args:
|
||||
share_code: Share code to delete.
|
||||
user_id: User requesting deletion (must be creator).
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found or not authorized.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
result = await conn.execute("""
|
||||
DELETE FROM shared_games
|
||||
WHERE share_code = $1 AND created_by = $2
|
||||
""", share_code, user_id)
|
||||
return result == "DELETE 1"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Export/Import
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def export_game(self, game_id: str) -> dict:
|
||||
"""
|
||||
Export game as portable JSON format.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
Export data dict suitable for JSON serialization.
|
||||
"""
|
||||
replay = await self.build_replay(game_id)
|
||||
|
||||
# Get raw events for export
|
||||
events = await self.event_store.get_events(game_id)
|
||||
start_time = events[0].timestamp if events else datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"version": self.EXPORT_VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"game": {
|
||||
"id": replay.game_id,
|
||||
"room_code": replay.room_code,
|
||||
"players": replay.player_names,
|
||||
"winner": replay.winner,
|
||||
"final_scores": replay.final_scores,
|
||||
"duration_seconds": replay.total_duration_seconds,
|
||||
"total_rounds": replay.total_rounds,
|
||||
"options": replay.options,
|
||||
},
|
||||
"events": [
|
||||
{
|
||||
"type": event.event_type.value,
|
||||
"sequence": event.sequence_num,
|
||||
"player_id": event.player_id,
|
||||
"data": event.data,
|
||||
"timestamp": (event.timestamp - start_time).total_seconds(),
|
||||
}
|
||||
for event in events
|
||||
],
|
||||
}
|
||||
|
||||
async def import_game(self, export_data: dict, user_id: str) -> str:
|
||||
"""
|
||||
Import a game from exported JSON.
|
||||
|
||||
Creates a new game record with the imported events.
|
||||
|
||||
Args:
|
||||
export_data: Exported game data.
|
||||
user_id: User performing the import.
|
||||
|
||||
Returns:
|
||||
New game ID.
|
||||
|
||||
Raises:
|
||||
ValueError: If export format is invalid.
|
||||
"""
|
||||
version = export_data.get("version")
|
||||
if version != self.EXPORT_VERSION:
|
||||
raise ValueError(f"Unsupported export version: {version}")
|
||||
|
||||
if "events" not in export_data or not export_data["events"]:
|
||||
raise ValueError("Export contains no events")
|
||||
|
||||
# Generate new game ID
|
||||
import uuid
|
||||
new_game_id = str(uuid.uuid4())
|
||||
|
||||
# Calculate base timestamp
|
||||
base_time = datetime.now(timezone.utc)
|
||||
|
||||
# Import events with new game ID
|
||||
events = []
|
||||
for event_data in export_data["events"]:
|
||||
event = GameEvent(
|
||||
event_type=EventType(event_data["type"]),
|
||||
game_id=new_game_id,
|
||||
sequence_num=event_data["sequence"],
|
||||
player_id=event_data.get("player_id"),
|
||||
data=event_data["data"],
|
||||
timestamp=base_time + timedelta(seconds=event_data.get("timestamp", 0)),
|
||||
)
|
||||
events.append(event)
|
||||
|
||||
# Batch insert events
|
||||
await self.event_store.append_batch(events)
|
||||
|
||||
# Create game metadata record
|
||||
game_info = export_data.get("game", {})
|
||||
async with self.pool.acquire() as conn:
|
||||
await conn.execute("""
|
||||
INSERT INTO games_v2
|
||||
(id, room_code, status, num_rounds, options, completed_at)
|
||||
VALUES ($1, $2, 'imported', $3, $4, NOW())
|
||||
""",
|
||||
new_game_id,
|
||||
f"IMP-{secrets.token_hex(2).upper()}", # Generate room code for imported games
|
||||
game_info.get("total_rounds", 1),
|
||||
json.dumps(game_info.get("options", {})),
|
||||
)
|
||||
|
||||
logger.info(f"Imported game as {new_game_id} by user {user_id}")
|
||||
return new_game_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Game History Queries
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_user_game_history(
|
||||
self,
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get game history for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID.
|
||||
limit: Max games to return.
|
||||
offset: Pagination offset.
|
||||
|
||||
Returns:
|
||||
List of game summary dicts.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT g.id, g.room_code, g.status, g.completed_at,
|
||||
g.num_players, g.num_rounds, g.winner_id,
|
||||
$1 = ANY(g.player_ids) as participated
|
||||
FROM games_v2 g
|
||||
WHERE $1 = ANY(g.player_ids)
|
||||
AND g.status IN ('completed', 'imported')
|
||||
ORDER BY g.completed_at DESC NULLS LAST
|
||||
LIMIT $2 OFFSET $3
|
||||
""", user_id, limit, offset)
|
||||
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
async def can_view_game(self, user_id: Optional[str], game_id: str) -> bool:
|
||||
"""
|
||||
Check if user can view a game replay.
|
||||
|
||||
Users can view games they played in or games that are shared publicly.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None for anonymous).
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
True if user can view the game.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
# Check if user played in the game
|
||||
if user_id:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT 1 FROM games_v2
|
||||
WHERE id = $1 AND $2 = ANY(player_ids)
|
||||
""", game_id, user_id)
|
||||
if row:
|
||||
return True
|
||||
|
||||
# Check if game has a public share link
|
||||
row = await conn.fetchrow("""
|
||||
SELECT 1 FROM shared_games
|
||||
WHERE game_id = $1
|
||||
AND is_public = true
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
""", game_id)
|
||||
return row is not None
|
||||
|
||||
|
||||
# Global instance
|
||||
_replay_service: Optional[ReplayService] = None
|
||||
|
||||
|
||||
async def get_replay_service(pool: asyncpg.Pool, event_store: EventStore) -> ReplayService:
|
||||
"""Get or create the replay service instance."""
|
||||
global _replay_service
|
||||
if _replay_service is None:
|
||||
_replay_service = ReplayService(pool, event_store)
|
||||
await _replay_service.initialize_schema()
|
||||
return _replay_service
|
||||
|
||||
|
||||
def set_replay_service(service: ReplayService) -> None:
|
||||
"""Set the global replay service instance."""
|
||||
global _replay_service
|
||||
_replay_service = service
|
||||
|
||||
|
||||
def close_replay_service() -> None:
|
||||
"""Close the replay service."""
|
||||
global _replay_service
|
||||
_replay_service = None
|
||||
265
server/services/spectator.py
Normal file
265
server/services/spectator.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Spectator manager for Golf game.
|
||||
|
||||
Enables spectators to watch live games in progress via WebSocket connections.
|
||||
Spectators receive game state updates but cannot interact with the game.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum spectators per game to prevent resource exhaustion
|
||||
MAX_SPECTATORS_PER_GAME = 50
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpectatorInfo:
|
||||
"""Information about a spectator connection."""
|
||||
websocket: WebSocket
|
||||
joined_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
user_id: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class SpectatorManager:
|
||||
"""
|
||||
Manage spectators watching live games.
|
||||
|
||||
Spectators can join any active game and receive real-time updates.
|
||||
They see the same state as players but cannot take actions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# game_id -> list of SpectatorInfo
|
||||
self._spectators: Dict[str, List[SpectatorInfo]] = {}
|
||||
# websocket -> game_id (for reverse lookup on disconnect)
|
||||
self._ws_to_game: Dict[WebSocket, str] = {}
|
||||
|
||||
async def add_spectator(
|
||||
self,
|
||||
game_id: str,
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Add spectator to a game.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
websocket: Spectator's WebSocket connection.
|
||||
user_id: Optional user ID.
|
||||
username: Optional display name.
|
||||
|
||||
Returns:
|
||||
True if added, False if game is at spectator limit.
|
||||
"""
|
||||
if game_id not in self._spectators:
|
||||
self._spectators[game_id] = []
|
||||
|
||||
# Check spectator limit
|
||||
if len(self._spectators[game_id]) >= MAX_SPECTATORS_PER_GAME:
|
||||
logger.warning(f"Game {game_id} at spectator limit ({MAX_SPECTATORS_PER_GAME})")
|
||||
return False
|
||||
|
||||
info = SpectatorInfo(
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
username=username or "Spectator",
|
||||
)
|
||||
self._spectators[game_id].append(info)
|
||||
self._ws_to_game[websocket] = game_id
|
||||
|
||||
logger.info(f"Spectator joined game {game_id} (total: {len(self._spectators[game_id])})")
|
||||
return True
|
||||
|
||||
async def remove_spectator(self, game_id: str, websocket: WebSocket) -> None:
|
||||
"""
|
||||
Remove spectator from a game.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
websocket: Spectator's WebSocket connection.
|
||||
"""
|
||||
if game_id in self._spectators:
|
||||
# Find and remove the spectator
|
||||
self._spectators[game_id] = [
|
||||
info for info in self._spectators[game_id]
|
||||
if info.websocket != websocket
|
||||
]
|
||||
logger.info(f"Spectator left game {game_id} (remaining: {len(self._spectators[game_id])})")
|
||||
|
||||
# Clean up empty games
|
||||
if not self._spectators[game_id]:
|
||||
del self._spectators[game_id]
|
||||
|
||||
# Clean up reverse lookup
|
||||
self._ws_to_game.pop(websocket, None)
|
||||
|
||||
async def remove_spectator_by_ws(self, websocket: WebSocket) -> None:
|
||||
"""
|
||||
Remove spectator by WebSocket (for disconnect handling).
|
||||
|
||||
Args:
|
||||
websocket: Spectator's WebSocket connection.
|
||||
"""
|
||||
game_id = self._ws_to_game.get(websocket)
|
||||
if game_id:
|
||||
await self.remove_spectator(game_id, websocket)
|
||||
|
||||
async def broadcast_to_spectators(self, game_id: str, message: dict) -> None:
|
||||
"""
|
||||
Send update to all spectators of a game.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
message: Message to broadcast.
|
||||
"""
|
||||
if game_id not in self._spectators:
|
||||
return
|
||||
|
||||
dead_connections: List[SpectatorInfo] = []
|
||||
|
||||
for info in self._spectators[game_id]:
|
||||
try:
|
||||
await info.websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to send to spectator: {e}")
|
||||
dead_connections.append(info)
|
||||
|
||||
# Clean up dead connections
|
||||
for info in dead_connections:
|
||||
self._spectators[game_id] = [
|
||||
s for s in self._spectators[game_id]
|
||||
if s.websocket != info.websocket
|
||||
]
|
||||
self._ws_to_game.pop(info.websocket, None)
|
||||
|
||||
# Clean up empty games
|
||||
if game_id in self._spectators and not self._spectators[game_id]:
|
||||
del self._spectators[game_id]
|
||||
|
||||
async def send_game_state(
|
||||
self,
|
||||
game_id: str,
|
||||
game_state: dict,
|
||||
event_type: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send current game state to all spectators.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
game_state: Current game state dict.
|
||||
event_type: Optional event type that triggered this update.
|
||||
"""
|
||||
message = {
|
||||
"type": "game_state",
|
||||
"game_state": game_state,
|
||||
"spectator_count": self.get_spectator_count(game_id),
|
||||
}
|
||||
if event_type:
|
||||
message["event_type"] = event_type
|
||||
|
||||
await self.broadcast_to_spectators(game_id, message)
|
||||
|
||||
def get_spectator_count(self, game_id: str) -> int:
|
||||
"""
|
||||
Get number of spectators for a game.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
Spectator count.
|
||||
"""
|
||||
return len(self._spectators.get(game_id, []))
|
||||
|
||||
def get_spectator_usernames(self, game_id: str) -> list[str]:
|
||||
"""
|
||||
Get list of spectator usernames.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
List of spectator usernames.
|
||||
"""
|
||||
if game_id not in self._spectators:
|
||||
return []
|
||||
return [
|
||||
info.username or "Anonymous"
|
||||
for info in self._spectators[game_id]
|
||||
]
|
||||
|
||||
def get_games_with_spectators(self) -> dict[str, int]:
|
||||
"""
|
||||
Get all games that have spectators.
|
||||
|
||||
Returns:
|
||||
Dict of game_id -> spectator count.
|
||||
"""
|
||||
return {
|
||||
game_id: len(spectators)
|
||||
for game_id, spectators in self._spectators.items()
|
||||
if spectators
|
||||
}
|
||||
|
||||
async def notify_game_ended(self, game_id: str, final_state: dict) -> None:
|
||||
"""
|
||||
Notify spectators that a game has ended.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
final_state: Final game state with scores.
|
||||
"""
|
||||
await self.broadcast_to_spectators(game_id, {
|
||||
"type": "game_ended",
|
||||
"final_state": final_state,
|
||||
})
|
||||
|
||||
async def close_all_for_game(self, game_id: str) -> None:
|
||||
"""
|
||||
Close all spectator connections for a game.
|
||||
|
||||
Use when a game is being cleaned up.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
"""
|
||||
if game_id not in self._spectators:
|
||||
return
|
||||
|
||||
for info in list(self._spectators[game_id]):
|
||||
try:
|
||||
await info.websocket.close(code=1000, reason="Game ended")
|
||||
except Exception:
|
||||
pass
|
||||
self._ws_to_game.pop(info.websocket, None)
|
||||
|
||||
del self._spectators[game_id]
|
||||
logger.info(f"Closed all spectators for game {game_id}")
|
||||
|
||||
|
||||
# Global instance
|
||||
_spectator_manager: Optional[SpectatorManager] = None
|
||||
|
||||
|
||||
def get_spectator_manager() -> SpectatorManager:
|
||||
"""Get the global spectator manager instance."""
|
||||
global _spectator_manager
|
||||
if _spectator_manager is None:
|
||||
_spectator_manager = SpectatorManager()
|
||||
return _spectator_manager
|
||||
|
||||
|
||||
def close_spectator_manager() -> None:
|
||||
"""Close the spectator manager."""
|
||||
global _spectator_manager
|
||||
_spectator_manager = None
|
||||
977
server/services/stats_service.py
Normal file
977
server/services/stats_service.py
Normal file
@@ -0,0 +1,977 @@
|
||||
"""
|
||||
Stats service for Golf game leaderboards and achievements.
|
||||
|
||||
Provides player statistics aggregation, leaderboard queries, and achievement tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
|
||||
import asyncpg
|
||||
|
||||
from stores.event_store import EventStore
|
||||
from models.events import EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlayerStats:
|
||||
"""Full player statistics."""
|
||||
user_id: str
|
||||
username: str
|
||||
games_played: int = 0
|
||||
games_won: int = 0
|
||||
win_rate: float = 0.0
|
||||
rounds_played: int = 0
|
||||
rounds_won: int = 0
|
||||
avg_score: float = 0.0
|
||||
best_round_score: Optional[int] = None
|
||||
worst_round_score: Optional[int] = None
|
||||
knockouts: int = 0
|
||||
perfect_rounds: int = 0
|
||||
wolfpacks: int = 0
|
||||
current_win_streak: int = 0
|
||||
best_win_streak: int = 0
|
||||
first_game_at: Optional[datetime] = None
|
||||
last_game_at: Optional[datetime] = None
|
||||
achievements: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LeaderboardEntry:
|
||||
"""Single entry on a leaderboard."""
|
||||
rank: int
|
||||
user_id: str
|
||||
username: str
|
||||
value: float
|
||||
games_played: int
|
||||
secondary_value: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Achievement:
|
||||
"""Achievement definition."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
icon: str
|
||||
category: str
|
||||
threshold: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserAchievement:
|
||||
"""Achievement earned by a user."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
icon: str
|
||||
earned_at: datetime
|
||||
game_id: Optional[str] = None
|
||||
|
||||
|
||||
class StatsService:
|
||||
"""
|
||||
Player statistics and leaderboards service.
|
||||
|
||||
Provides methods for:
|
||||
- Querying player stats
|
||||
- Fetching leaderboards by various metrics
|
||||
- Processing game completion for stats aggregation
|
||||
- Achievement checking and awarding
|
||||
"""
|
||||
|
||||
def __init__(self, pool: asyncpg.Pool, event_store: Optional[EventStore] = None):
|
||||
"""
|
||||
Initialize stats service.
|
||||
|
||||
Args:
|
||||
pool: asyncpg connection pool.
|
||||
event_store: Optional EventStore for event-based stats processing.
|
||||
"""
|
||||
self.pool = pool
|
||||
self.event_store = event_store
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stats Queries
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_player_stats(self, user_id: str) -> Optional[PlayerStats]:
|
||||
"""
|
||||
Get full stats for a specific player.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
|
||||
Returns:
|
||||
PlayerStats or None if player not found.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT s.*, u.username,
|
||||
ROUND(s.games_won::numeric / NULLIF(s.games_played, 0) * 100, 1) as win_rate,
|
||||
ROUND(s.total_points::numeric / NULLIF(s.total_rounds, 0), 1) as avg_score_calc
|
||||
FROM player_stats s
|
||||
JOIN users_v2 u ON s.user_id = u.id
|
||||
WHERE s.user_id = $1
|
||||
""", user_id)
|
||||
|
||||
if not row:
|
||||
# Check if user exists but has no stats
|
||||
user_row = await conn.fetchrow(
|
||||
"SELECT username FROM users_v2 WHERE id = $1",
|
||||
user_id
|
||||
)
|
||||
if user_row:
|
||||
return PlayerStats(
|
||||
user_id=user_id,
|
||||
username=user_row["username"],
|
||||
)
|
||||
return None
|
||||
|
||||
# Get achievements
|
||||
achievements = await conn.fetch("""
|
||||
SELECT achievement_id FROM user_achievements
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
return PlayerStats(
|
||||
user_id=str(row["user_id"]),
|
||||
username=row["username"],
|
||||
games_played=row["games_played"] or 0,
|
||||
games_won=row["games_won"] or 0,
|
||||
win_rate=float(row["win_rate"] or 0),
|
||||
rounds_played=row["total_rounds"] or 0,
|
||||
rounds_won=row["rounds_won"] or 0,
|
||||
avg_score=float(row["avg_score_calc"] or 0),
|
||||
best_round_score=row["best_score"],
|
||||
worst_round_score=row["worst_score"],
|
||||
knockouts=row["knockouts"] or 0,
|
||||
perfect_rounds=row["perfect_rounds"] or 0,
|
||||
wolfpacks=row["wolfpacks"] or 0,
|
||||
current_win_streak=row["current_win_streak"] or 0,
|
||||
best_win_streak=row["best_win_streak"] or 0,
|
||||
first_game_at=row["first_game_at"].replace(tzinfo=timezone.utc) if row["first_game_at"] else None,
|
||||
last_game_at=row["last_game_at"].replace(tzinfo=timezone.utc) if row["last_game_at"] else None,
|
||||
achievements=[a["achievement_id"] for a in achievements],
|
||||
)
|
||||
|
||||
async def get_leaderboard(
|
||||
self,
|
||||
metric: str = "wins",
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> List[LeaderboardEntry]:
|
||||
"""
|
||||
Get leaderboard by metric.
|
||||
|
||||
Args:
|
||||
metric: Ranking metric - wins, win_rate, avg_score, knockouts, streak.
|
||||
limit: Maximum entries to return.
|
||||
offset: Pagination offset.
|
||||
|
||||
Returns:
|
||||
List of LeaderboardEntry sorted by metric.
|
||||
"""
|
||||
order_map = {
|
||||
"wins": ("games_won", "DESC"),
|
||||
"win_rate": ("win_rate", "DESC"),
|
||||
"avg_score": ("avg_score", "ASC"), # Lower is better
|
||||
"knockouts": ("knockouts", "DESC"),
|
||||
"streak": ("best_win_streak", "DESC"),
|
||||
}
|
||||
|
||||
if metric not in order_map:
|
||||
metric = "wins"
|
||||
|
||||
column, direction = order_map[metric]
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
# Check if materialized view exists
|
||||
view_exists = await conn.fetchval(
|
||||
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
|
||||
)
|
||||
|
||||
if view_exists:
|
||||
# Use materialized view for performance
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT
|
||||
user_id, username, games_played, games_won,
|
||||
win_rate, avg_score, knockouts, best_win_streak,
|
||||
ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
|
||||
FROM leaderboard_overall
|
||||
ORDER BY {column} {direction}
|
||||
LIMIT $1 OFFSET $2
|
||||
""", limit, offset)
|
||||
else:
|
||||
# Fall back to direct query
|
||||
rows = await conn.fetch(f"""
|
||||
SELECT
|
||||
s.user_id, u.username, s.games_played, s.games_won,
|
||||
ROUND(s.games_won::numeric / NULLIF(s.games_played, 0) * 100, 1) as win_rate,
|
||||
ROUND(s.total_points::numeric / NULLIF(s.total_rounds, 0), 1) as avg_score,
|
||||
s.knockouts, s.best_win_streak,
|
||||
ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
|
||||
FROM player_stats s
|
||||
JOIN users_v2 u ON s.user_id = u.id
|
||||
WHERE s.games_played >= 5
|
||||
AND u.deleted_at IS NULL
|
||||
AND (u.is_banned = false OR u.is_banned IS NULL)
|
||||
ORDER BY {column} {direction}
|
||||
LIMIT $1 OFFSET $2
|
||||
""", limit, offset)
|
||||
|
||||
return [
|
||||
LeaderboardEntry(
|
||||
rank=row["rank"],
|
||||
user_id=str(row["user_id"]),
|
||||
username=row["username"],
|
||||
value=float(row[column] or 0),
|
||||
games_played=row["games_played"],
|
||||
secondary_value=float(row["win_rate"] or 0) if metric != "win_rate" else None,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_player_rank(self, user_id: str, metric: str = "wins") -> Optional[int]:
|
||||
"""
|
||||
Get a player's rank on a leaderboard.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
metric: Ranking metric.
|
||||
|
||||
Returns:
|
||||
Rank number or None if not ranked (< 5 games or not found).
|
||||
"""
|
||||
order_map = {
|
||||
"wins": ("games_won", "DESC"),
|
||||
"win_rate": ("win_rate", "DESC"),
|
||||
"avg_score": ("avg_score", "ASC"),
|
||||
"knockouts": ("knockouts", "DESC"),
|
||||
"streak": ("best_win_streak", "DESC"),
|
||||
}
|
||||
|
||||
if metric not in order_map:
|
||||
return None
|
||||
|
||||
column, direction = order_map[metric]
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
# Check if user qualifies (5+ games)
|
||||
games = await conn.fetchval(
|
||||
"SELECT games_played FROM player_stats WHERE user_id = $1",
|
||||
user_id
|
||||
)
|
||||
if not games or games < 5:
|
||||
return None
|
||||
|
||||
view_exists = await conn.fetchval(
|
||||
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
|
||||
)
|
||||
|
||||
if view_exists:
|
||||
row = await conn.fetchrow(f"""
|
||||
SELECT rank FROM (
|
||||
SELECT user_id, ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
|
||||
FROM leaderboard_overall
|
||||
) ranked
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
else:
|
||||
row = await conn.fetchrow(f"""
|
||||
SELECT rank FROM (
|
||||
SELECT s.user_id, ROW_NUMBER() OVER (ORDER BY {column} {direction}) as rank
|
||||
FROM player_stats s
|
||||
JOIN users_v2 u ON s.user_id = u.id
|
||||
WHERE s.games_played >= 5
|
||||
AND u.deleted_at IS NULL
|
||||
AND (u.is_banned = false OR u.is_banned IS NULL)
|
||||
) ranked
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
return row["rank"] if row else None
|
||||
|
||||
async def refresh_leaderboard(self) -> bool:
|
||||
"""
|
||||
Refresh the materialized leaderboard view.
|
||||
|
||||
Returns:
|
||||
True if refresh succeeded.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
try:
|
||||
# Check if view exists
|
||||
view_exists = await conn.fetchval(
|
||||
"SELECT 1 FROM pg_matviews WHERE matviewname = 'leaderboard_overall'"
|
||||
)
|
||||
if view_exists:
|
||||
await conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY leaderboard_overall")
|
||||
logger.info("Refreshed leaderboard materialized view")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh leaderboard: {e}")
|
||||
return False
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Achievement Queries
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def get_achievements(self) -> List[Achievement]:
|
||||
"""Get all available achievements."""
|
||||
async with self.pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT id, name, description, icon, category, threshold
|
||||
FROM achievements
|
||||
ORDER BY sort_order
|
||||
""")
|
||||
|
||||
return [
|
||||
Achievement(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"] or "",
|
||||
icon=row["icon"] or "",
|
||||
category=row["category"] or "",
|
||||
threshold=row["threshold"] or 0,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_user_achievements(self, user_id: str) -> List[UserAchievement]:
|
||||
"""
|
||||
Get achievements earned by a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
|
||||
Returns:
|
||||
List of earned achievements.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
rows = await conn.fetch("""
|
||||
SELECT a.id, a.name, a.description, a.icon, ua.earned_at, ua.game_id
|
||||
FROM user_achievements ua
|
||||
JOIN achievements a ON ua.achievement_id = a.id
|
||||
WHERE ua.user_id = $1
|
||||
ORDER BY ua.earned_at DESC
|
||||
""", user_id)
|
||||
|
||||
return [
|
||||
UserAchievement(
|
||||
id=row["id"],
|
||||
name=row["name"],
|
||||
description=row["description"] or "",
|
||||
icon=row["icon"] or "",
|
||||
earned_at=row["earned_at"].replace(tzinfo=timezone.utc) if row["earned_at"] else datetime.now(timezone.utc),
|
||||
game_id=str(row["game_id"]) if row["game_id"] else None,
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stats Processing (Game Completion)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def process_game_end(self, game_id: str) -> List[str]:
|
||||
"""
|
||||
Process a completed game and update player stats.
|
||||
|
||||
Extracts game data from events and updates player_stats table.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
List of newly awarded achievement IDs.
|
||||
"""
|
||||
if not self.event_store:
|
||||
logger.warning("No event store configured, skipping stats processing")
|
||||
return []
|
||||
|
||||
# Get game events
|
||||
try:
|
||||
events = await self.event_store.get_events(game_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get events for game {game_id}: {e}")
|
||||
return []
|
||||
|
||||
if not events:
|
||||
logger.warning(f"No events found for game {game_id}")
|
||||
return []
|
||||
|
||||
# Extract game data from events
|
||||
game_data = self._extract_game_data(events)
|
||||
|
||||
if not game_data:
|
||||
logger.warning(f"Could not extract game data from events for {game_id}")
|
||||
return []
|
||||
|
||||
all_new_achievements = []
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
for player_id, player_data in game_data["players"].items():
|
||||
# Skip CPU players (they don't have user accounts)
|
||||
if player_data.get("is_cpu"):
|
||||
continue
|
||||
|
||||
# Check if this is a valid user UUID
|
||||
try:
|
||||
UUID(player_id)
|
||||
except (ValueError, TypeError):
|
||||
# Not a UUID - likely a websocket session ID, skip
|
||||
continue
|
||||
|
||||
# Ensure stats row exists
|
||||
await conn.execute("""
|
||||
INSERT INTO player_stats (user_id)
|
||||
VALUES ($1)
|
||||
ON CONFLICT (user_id) DO NOTHING
|
||||
""", player_id)
|
||||
|
||||
# Calculate values
|
||||
is_winner = player_id == game_data["winner_id"]
|
||||
total_score = player_data["total_score"]
|
||||
rounds_won = player_data["rounds_won"]
|
||||
num_rounds = game_data["num_rounds"]
|
||||
knockouts = player_data.get("knockouts", 0)
|
||||
best_round = player_data.get("best_round")
|
||||
worst_round = player_data.get("worst_round")
|
||||
perfect_rounds = player_data.get("perfect_rounds", 0)
|
||||
wolfpacks = player_data.get("wolfpacks", 0)
|
||||
has_human_opponents = game_data.get("has_human_opponents", False)
|
||||
|
||||
# Update stats
|
||||
await conn.execute("""
|
||||
UPDATE player_stats SET
|
||||
games_played = games_played + 1,
|
||||
games_won = games_won + $2,
|
||||
total_rounds = total_rounds + $3,
|
||||
rounds_won = rounds_won + $4,
|
||||
total_points = total_points + $5,
|
||||
knockouts = knockouts + $6,
|
||||
perfect_rounds = perfect_rounds + $7,
|
||||
wolfpacks = wolfpacks + $8,
|
||||
best_score = CASE
|
||||
WHEN best_score IS NULL THEN $9
|
||||
WHEN $9 IS NOT NULL AND $9 < best_score THEN $9
|
||||
ELSE best_score
|
||||
END,
|
||||
worst_score = CASE
|
||||
WHEN worst_score IS NULL THEN $10
|
||||
WHEN $10 IS NOT NULL AND $10 > worst_score THEN $10
|
||||
ELSE worst_score
|
||||
END,
|
||||
current_win_streak = CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE 0 END,
|
||||
best_win_streak = GREATEST(best_win_streak,
|
||||
CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE best_win_streak END),
|
||||
first_game_at = COALESCE(first_game_at, NOW()),
|
||||
last_game_at = NOW(),
|
||||
games_vs_humans = games_vs_humans + $11,
|
||||
games_won_vs_humans = games_won_vs_humans + $12,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
""",
|
||||
player_id,
|
||||
1 if is_winner else 0,
|
||||
num_rounds,
|
||||
rounds_won,
|
||||
total_score,
|
||||
knockouts,
|
||||
perfect_rounds,
|
||||
wolfpacks,
|
||||
best_round,
|
||||
worst_round,
|
||||
1 if has_human_opponents else 0,
|
||||
1 if is_winner and has_human_opponents else 0,
|
||||
)
|
||||
|
||||
# Check for new achievements
|
||||
new_achievements = await self._check_achievements(
|
||||
conn, player_id, game_id, player_data, is_winner
|
||||
)
|
||||
all_new_achievements.extend(new_achievements)
|
||||
|
||||
logger.info(f"Processed stats for game {game_id}, awarded {len(all_new_achievements)} achievements")
|
||||
return all_new_achievements
|
||||
|
||||
def _extract_game_data(self, events) -> Optional[dict]:
|
||||
"""
|
||||
Extract game statistics from event stream.
|
||||
|
||||
Args:
|
||||
events: List of GameEvent objects.
|
||||
|
||||
Returns:
|
||||
Dict with players, num_rounds, winner_id, etc.
|
||||
"""
|
||||
data = {
|
||||
"players": {},
|
||||
"num_rounds": 0,
|
||||
"winner_id": None,
|
||||
"has_human_opponents": False,
|
||||
}
|
||||
|
||||
human_count = 0
|
||||
|
||||
for event in events:
|
||||
if event.event_type == EventType.PLAYER_JOINED:
|
||||
is_cpu = event.data.get("is_cpu", False)
|
||||
if not is_cpu:
|
||||
human_count += 1
|
||||
|
||||
data["players"][event.player_id] = {
|
||||
"is_cpu": is_cpu,
|
||||
"total_score": 0,
|
||||
"rounds_won": 0,
|
||||
"knockouts": 0,
|
||||
"perfect_rounds": 0,
|
||||
"wolfpacks": 0,
|
||||
"best_round": None,
|
||||
"worst_round": None,
|
||||
}
|
||||
|
||||
elif event.event_type == EventType.ROUND_ENDED:
|
||||
data["num_rounds"] += 1
|
||||
scores = event.data.get("scores", {})
|
||||
finisher_id = event.data.get("finisher_id")
|
||||
|
||||
# Track who went out first (knockout)
|
||||
if finisher_id and finisher_id in data["players"]:
|
||||
data["players"][finisher_id]["knockouts"] += 1
|
||||
|
||||
# Find round winner (lowest score)
|
||||
if scores:
|
||||
min_score = min(scores.values())
|
||||
for pid, score in scores.items():
|
||||
if pid in data["players"]:
|
||||
p = data["players"][pid]
|
||||
p["total_score"] += score
|
||||
|
||||
# Track best/worst rounds
|
||||
if p["best_round"] is None or score < p["best_round"]:
|
||||
p["best_round"] = score
|
||||
if p["worst_round"] is None or score > p["worst_round"]:
|
||||
p["worst_round"] = score
|
||||
|
||||
# Check for perfect round (score <= 0)
|
||||
if score <= 0:
|
||||
p["perfect_rounds"] += 1
|
||||
|
||||
# Award round win
|
||||
if score == min_score:
|
||||
p["rounds_won"] += 1
|
||||
|
||||
# Check for wolfpack (4 Jacks) in final hands
|
||||
final_hands = event.data.get("final_hands", {})
|
||||
for pid, hand in final_hands.items():
|
||||
if pid in data["players"]:
|
||||
jack_count = sum(1 for card in hand if card.get("rank") == "J")
|
||||
if jack_count >= 4:
|
||||
data["players"][pid]["wolfpacks"] += 1
|
||||
|
||||
elif event.event_type == EventType.GAME_ENDED:
|
||||
data["winner_id"] = event.data.get("winner_id")
|
||||
|
||||
# Mark if there were human opponents
|
||||
data["has_human_opponents"] = human_count > 1
|
||||
|
||||
return data if data["num_rounds"] > 0 else None
|
||||
|
||||
async def _check_achievements(
|
||||
self,
|
||||
conn: asyncpg.Connection,
|
||||
user_id: str,
|
||||
game_id: str,
|
||||
player_data: dict,
|
||||
is_winner: bool,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Check and award new achievements to a player.
|
||||
|
||||
Args:
|
||||
conn: Database connection (within transaction).
|
||||
user_id: Player's user ID.
|
||||
game_id: Current game ID.
|
||||
player_data: Player's data from this game.
|
||||
is_winner: Whether player won the game.
|
||||
|
||||
Returns:
|
||||
List of newly awarded achievement IDs.
|
||||
"""
|
||||
new_achievements = []
|
||||
|
||||
# Get current stats (after update)
|
||||
stats = await conn.fetchrow("""
|
||||
SELECT games_won, knockouts, best_win_streak, current_win_streak, perfect_rounds, wolfpacks
|
||||
FROM player_stats
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
if not stats:
|
||||
return []
|
||||
|
||||
# Get already earned achievements
|
||||
earned = await conn.fetch("""
|
||||
SELECT achievement_id FROM user_achievements WHERE user_id = $1
|
||||
""", user_id)
|
||||
earned_ids = {e["achievement_id"] for e in earned}
|
||||
|
||||
# Check win milestones
|
||||
wins = stats["games_won"]
|
||||
if wins >= 1 and "first_win" not in earned_ids:
|
||||
new_achievements.append("first_win")
|
||||
if wins >= 10 and "win_10" not in earned_ids:
|
||||
new_achievements.append("win_10")
|
||||
if wins >= 50 and "win_50" not in earned_ids:
|
||||
new_achievements.append("win_50")
|
||||
if wins >= 100 and "win_100" not in earned_ids:
|
||||
new_achievements.append("win_100")
|
||||
|
||||
# Check streak achievements
|
||||
streak = stats["current_win_streak"]
|
||||
if streak >= 5 and "streak_5" not in earned_ids:
|
||||
new_achievements.append("streak_5")
|
||||
if streak >= 10 and "streak_10" not in earned_ids:
|
||||
new_achievements.append("streak_10")
|
||||
|
||||
# Check knockout achievements
|
||||
if stats["knockouts"] >= 10 and "knockout_10" not in earned_ids:
|
||||
new_achievements.append("knockout_10")
|
||||
|
||||
# Check round-specific achievements from this game
|
||||
best_round = player_data.get("best_round")
|
||||
if best_round is not None:
|
||||
if best_round <= 0 and "perfect_round" not in earned_ids:
|
||||
new_achievements.append("perfect_round")
|
||||
if best_round < 0 and "negative_round" not in earned_ids:
|
||||
new_achievements.append("negative_round")
|
||||
|
||||
# Check wolfpack
|
||||
if player_data.get("wolfpacks", 0) > 0 and "wolfpack" not in earned_ids:
|
||||
new_achievements.append("wolfpack")
|
||||
|
||||
# Award new achievements
|
||||
for achievement_id in new_achievements:
|
||||
try:
|
||||
await conn.execute("""
|
||||
INSERT INTO user_achievements (user_id, achievement_id, game_id)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT DO NOTHING
|
||||
""", user_id, achievement_id, game_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to award achievement {achievement_id}: {e}")
|
||||
|
||||
return new_achievements
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Direct Game State Processing (for legacy games without event sourcing)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def process_game_from_state(
|
||||
self,
|
||||
players: list,
|
||||
winner_id: Optional[str],
|
||||
num_rounds: int,
|
||||
player_user_ids: dict[str, str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Process game stats directly from game state (for legacy games).
|
||||
|
||||
This is used when games don't have event sourcing. Stats are updated
|
||||
based on final game state.
|
||||
|
||||
Args:
|
||||
players: List of game.Player objects with final scores.
|
||||
winner_id: Player ID of the winner.
|
||||
num_rounds: Total rounds played.
|
||||
player_user_ids: Optional mapping of player_id to user_id (for authenticated players).
|
||||
|
||||
Returns:
|
||||
List of newly awarded achievement IDs.
|
||||
"""
|
||||
if not players:
|
||||
return []
|
||||
|
||||
# Count human players for has_human_opponents calculation
|
||||
# For legacy games, we assume all players are human unless otherwise indicated
|
||||
human_count = len(players)
|
||||
has_human_opponents = human_count > 1
|
||||
|
||||
all_new_achievements = []
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
for player in players:
|
||||
# Get user_id - could be the player_id itself if it's a UUID,
|
||||
# or mapped via player_user_ids
|
||||
user_id = None
|
||||
if player_user_ids and player.id in player_user_ids:
|
||||
user_id = player_user_ids[player.id]
|
||||
else:
|
||||
# Try to use player.id as user_id if it looks like a UUID
|
||||
try:
|
||||
UUID(player.id)
|
||||
user_id = player.id
|
||||
except (ValueError, TypeError):
|
||||
# Not a UUID, skip this player
|
||||
continue
|
||||
|
||||
if not user_id:
|
||||
continue
|
||||
|
||||
# Ensure stats row exists
|
||||
await conn.execute("""
|
||||
INSERT INTO player_stats (user_id)
|
||||
VALUES ($1)
|
||||
ON CONFLICT (user_id) DO NOTHING
|
||||
""", user_id)
|
||||
|
||||
is_winner = player.id == winner_id
|
||||
total_score = player.total_score
|
||||
rounds_won = player.rounds_won
|
||||
|
||||
# We don't have per-round data in legacy mode, so some stats are limited
|
||||
# Use total_score / num_rounds as an approximation for avg round score
|
||||
avg_round_score = total_score / num_rounds if num_rounds > 0 else None
|
||||
|
||||
# Update stats
|
||||
await conn.execute("""
|
||||
UPDATE player_stats SET
|
||||
games_played = games_played + 1,
|
||||
games_won = games_won + $2,
|
||||
total_rounds = total_rounds + $3,
|
||||
rounds_won = rounds_won + $4,
|
||||
total_points = total_points + $5,
|
||||
best_score = CASE
|
||||
WHEN best_score IS NULL THEN $6
|
||||
WHEN $6 IS NOT NULL AND $6 < best_score THEN $6
|
||||
ELSE best_score
|
||||
END,
|
||||
worst_score = CASE
|
||||
WHEN worst_score IS NULL THEN $7
|
||||
WHEN $7 IS NOT NULL AND $7 > worst_score THEN $7
|
||||
ELSE worst_score
|
||||
END,
|
||||
current_win_streak = CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE 0 END,
|
||||
best_win_streak = GREATEST(best_win_streak,
|
||||
CASE WHEN $2 = 1 THEN current_win_streak + 1 ELSE best_win_streak END),
|
||||
first_game_at = COALESCE(first_game_at, NOW()),
|
||||
last_game_at = NOW(),
|
||||
games_vs_humans = games_vs_humans + $8,
|
||||
games_won_vs_humans = games_won_vs_humans + $9,
|
||||
updated_at = NOW()
|
||||
WHERE user_id = $1
|
||||
""",
|
||||
user_id,
|
||||
1 if is_winner else 0,
|
||||
num_rounds,
|
||||
rounds_won,
|
||||
total_score,
|
||||
avg_round_score, # Approximation for best_score
|
||||
avg_round_score, # Approximation for worst_score
|
||||
1 if has_human_opponents else 0,
|
||||
1 if is_winner and has_human_opponents else 0,
|
||||
)
|
||||
|
||||
# Check achievements (limited data in legacy mode)
|
||||
new_achievements = await self._check_achievements_legacy(
|
||||
conn, user_id, is_winner
|
||||
)
|
||||
all_new_achievements.extend(new_achievements)
|
||||
|
||||
logger.info(f"Processed stats for legacy game with {len(players)} players")
|
||||
return all_new_achievements
|
||||
|
||||
async def _check_achievements_legacy(
|
||||
self,
|
||||
conn: asyncpg.Connection,
|
||||
user_id: str,
|
||||
is_winner: bool,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Check and award achievements for legacy games (limited data).
|
||||
|
||||
Only checks win-based achievements since we don't have round-level data.
|
||||
"""
|
||||
new_achievements = []
|
||||
|
||||
# Get current stats
|
||||
stats = await conn.fetchrow("""
|
||||
SELECT games_won, current_win_streak FROM player_stats
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
if not stats:
|
||||
return []
|
||||
|
||||
# Get already earned achievements
|
||||
earned = await conn.fetch("""
|
||||
SELECT achievement_id FROM user_achievements WHERE user_id = $1
|
||||
""", user_id)
|
||||
earned_ids = {e["achievement_id"] for e in earned}
|
||||
|
||||
# Check win milestones
|
||||
wins = stats["games_won"]
|
||||
if wins >= 1 and "first_win" not in earned_ids:
|
||||
new_achievements.append("first_win")
|
||||
if wins >= 10 and "win_10" not in earned_ids:
|
||||
new_achievements.append("win_10")
|
||||
if wins >= 50 and "win_50" not in earned_ids:
|
||||
new_achievements.append("win_50")
|
||||
if wins >= 100 and "win_100" not in earned_ids:
|
||||
new_achievements.append("win_100")
|
||||
|
||||
# Check streak achievements
|
||||
streak = stats["current_win_streak"]
|
||||
if streak >= 5 and "streak_5" not in earned_ids:
|
||||
new_achievements.append("streak_5")
|
||||
if streak >= 10 and "streak_10" not in earned_ids:
|
||||
new_achievements.append("streak_10")
|
||||
|
||||
# Award new achievements
|
||||
for achievement_id in new_achievements:
|
||||
try:
|
||||
await conn.execute("""
|
||||
INSERT INTO user_achievements (user_id, achievement_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
""", user_id, achievement_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to award achievement {achievement_id}: {e}")
|
||||
|
||||
return new_achievements
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stats Queue Management
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def queue_game_for_processing(self, game_id: str) -> int:
|
||||
"""
|
||||
Add a game to the stats processing queue.
|
||||
|
||||
Args:
|
||||
game_id: Game UUID.
|
||||
|
||||
Returns:
|
||||
Queue entry ID.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
row = await conn.fetchrow("""
|
||||
INSERT INTO stats_queue (game_id)
|
||||
VALUES ($1)
|
||||
RETURNING id
|
||||
""", game_id)
|
||||
return row["id"]
|
||||
|
||||
async def process_pending_queue(self, limit: int = 100) -> int:
|
||||
"""
|
||||
Process pending games in the stats queue.
|
||||
|
||||
Args:
|
||||
limit: Maximum games to process.
|
||||
|
||||
Returns:
|
||||
Number of games processed.
|
||||
"""
|
||||
processed = 0
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
# Get pending games
|
||||
games = await conn.fetch("""
|
||||
SELECT id, game_id FROM stats_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at
|
||||
LIMIT $1
|
||||
""", limit)
|
||||
|
||||
for game in games:
|
||||
try:
|
||||
# Mark as processing
|
||||
await conn.execute("""
|
||||
UPDATE stats_queue SET status = 'processing' WHERE id = $1
|
||||
""", game["id"])
|
||||
|
||||
# Process
|
||||
await self.process_game_end(str(game["game_id"]))
|
||||
|
||||
# Mark complete
|
||||
await conn.execute("""
|
||||
UPDATE stats_queue
|
||||
SET status = 'completed', processed_at = NOW()
|
||||
WHERE id = $1
|
||||
""", game["id"])
|
||||
|
||||
processed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process game {game['game_id']}: {e}")
|
||||
# Mark failed
|
||||
await conn.execute("""
|
||||
UPDATE stats_queue
|
||||
SET status = 'failed', error_message = $2, processed_at = NOW()
|
||||
WHERE id = $1
|
||||
""", game["id"], str(e))
|
||||
|
||||
return processed
|
||||
|
||||
async def cleanup_old_queue_entries(self, days: int = 7) -> int:
|
||||
"""
|
||||
Clean up old completed/failed queue entries.
|
||||
|
||||
Args:
|
||||
days: Delete entries older than this many days.
|
||||
|
||||
Returns:
|
||||
Number of entries deleted.
|
||||
"""
|
||||
async with self.pool.acquire() as conn:
|
||||
result = await conn.execute("""
|
||||
DELETE FROM stats_queue
|
||||
WHERE status IN ('completed', 'failed')
|
||||
AND processed_at < NOW() - INTERVAL '1 day' * $1
|
||||
""", days)
|
||||
# Parse "DELETE N" result
|
||||
return int(result.split()[1]) if result else 0
|
||||
|
||||
|
||||
# Global stats service instance
|
||||
_stats_service: Optional[StatsService] = None
|
||||
|
||||
|
||||
async def get_stats_service(
|
||||
pool: asyncpg.Pool,
|
||||
event_store: Optional[EventStore] = None,
|
||||
) -> StatsService:
|
||||
"""
|
||||
Get or create the global stats service instance.
|
||||
|
||||
Args:
|
||||
pool: asyncpg connection pool.
|
||||
event_store: Optional EventStore.
|
||||
|
||||
Returns:
|
||||
StatsService instance.
|
||||
"""
|
||||
global _stats_service
|
||||
if _stats_service is None:
|
||||
_stats_service = StatsService(pool, event_store)
|
||||
return _stats_service
|
||||
|
||||
|
||||
def set_stats_service(service: StatsService) -> None:
|
||||
"""Set the global stats service instance."""
|
||||
global _stats_service
|
||||
_stats_service = service
|
||||
|
||||
|
||||
def close_stats_service() -> None:
|
||||
"""Close the global stats service."""
|
||||
global _stats_service
|
||||
_stats_service = None
|
||||
Reference in New Issue
Block a user