# golfgame/server/auth.py
"""
Authentication and user management for Golf game.
Features:
- User accounts stored in SQLite
- Admin accounts can manage other users
- Invite codes (room codes) allow new user registration
- Session-based authentication via tokens
"""
import contextlib
import hashlib
import secrets
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Optional

from config import config
class UserRole(Enum):
    """Access-control role attached to a user account.

    Each member's ``value`` is the string stored in the ``users.role`` column.
    """

    USER = "user"  # regular account
    ADMIN = "admin"  # administrator account
@dataclass
class User:
    """A user account, mirroring one row of the ``users`` table.

    An empty ``password_hash`` marks an account whose password has not been
    set yet (first-time setup).
    """

    id: str
    username: str
    email: Optional[str]
    password_hash: str
    role: UserRole
    # Optional: rows loaded without a created_at value produce None here.
    created_at: Optional[datetime]
    last_login: Optional[datetime]
    is_active: bool
    invited_by: Optional[str]  # Username of who invited them

    def is_admin(self) -> bool:
        """Return True if the account has the ADMIN role."""
        return self.role == UserRole.ADMIN

    def to_dict(self, include_sensitive: bool = False) -> dict:
        """Convert to a dictionary for API responses.

        Timestamps are rendered with ``isoformat()`` (or None when unset) and
        the role as its string value. ``password_hash`` is included only when
        *include_sensitive* is True.
        """
        data = {
            "id": self.id,
            "username": self.username,
            "email": self.email,
            "role": self.role.value,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "last_login": self.last_login.isoformat() if self.last_login else None,
            "is_active": self.is_active,
            "invited_by": self.invited_by,
        }
        if include_sensitive:
            data["password_hash"] = self.password_hash
        return data
@dataclass
class Session:
    """A login session, keyed by its token string."""

    token: str
    user_id: str  # id of the owning user
    created_at: datetime
    expires_at: datetime

    def is_expired(self) -> bool:
        """Return True once the current local time is past ``expires_at``."""
        current_time = datetime.now()
        return current_time > self.expires_at
@dataclass
class InviteCode:
    """An invite code that permits registering a new account."""

    code: str
    created_by: str  # User ID who created the invite
    created_at: datetime
    expires_at: Optional[datetime]  # None: the code never expires
    max_uses: int  # values <= 0 mean unlimited uses
    use_count: int
    is_active: bool

    def is_valid(self) -> bool:
        """Return True if the code is active, not expired and not used up."""
        if not self.is_active:
            return False
        expired = bool(self.expires_at) and datetime.now() > self.expires_at
        used_up = self.max_uses > 0 and self.use_count >= self.max_uses
        return not (expired or used_up)
class AuthManager:
    """Manages user accounts, login sessions and invite codes stored in SQLite.

    Each operation opens its own short-lived connection to ``db_path`` via
    ``_connect``; no connection is kept open between calls.
    """

    # Scheme tag and work factor used for newly hashed passwords.
    # NOTE(review): 600k PBKDF2-HMAC-SHA256 iterations follows current OWASP
    # password-storage guidance; tune if login latency becomes a problem.
    _PBKDF2_PREFIX = "pbkdf2_sha256"
    _PBKDF2_ITERATIONS = 600_000

    def __init__(self, db_path: str = "games.db"):
        """Create the auth schema in *db_path* if needed and ensure an admin exists."""
        self.db_path = Path(db_path)
        self._init_db()
        self._ensure_admin()

    # =========================================================================
    # Database helpers
    # =========================================================================

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection that commits on success, rolls back on error, and is always closed.

        Using ``sqlite3.Connection`` itself as a context manager only manages
        the transaction; it never closes the connection, so it is closed here.
        Rows come back as ``sqlite3.Row`` (readable by column name or index).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _to_db_timestamp(value: Optional[datetime]) -> Optional[str]:
        """Serialize *value* exactly as sqlite3's default datetime adapter does.

        Passing ``datetime`` objects straight to sqlite3 relies on that adapter
        (``isoformat(" ")``), which is deprecated since Python 3.12. Producing
        the same string explicitly keeps the stored format unchanged.
        """
        return None if value is None else value.isoformat(" ")

    def _init_db(self):
        """Create the users, sessions and invite_codes tables and indexes if missing."""
        with self._connect() as conn:
            conn.executescript("""
                -- Users table
                CREATE TABLE IF NOT EXISTS users (
                    id TEXT PRIMARY KEY,
                    username TEXT UNIQUE NOT NULL,
                    email TEXT UNIQUE,
                    password_hash TEXT NOT NULL,
                    role TEXT DEFAULT 'user',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_login TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1,
                    invited_by TEXT
                );
                -- Sessions table
                CREATE TABLE IF NOT EXISTS sessions (
                    token TEXT PRIMARY KEY,
                    user_id TEXT REFERENCES users(id),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL
                );
                -- Invite codes table
                CREATE TABLE IF NOT EXISTS invite_codes (
                    code TEXT PRIMARY KEY,
                    created_by TEXT REFERENCES users(id),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP,
                    max_uses INTEGER DEFAULT 1,
                    use_count INTEGER DEFAULT 0,
                    is_active BOOLEAN DEFAULT 1
                );
                -- Indexes
                CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
                CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
                CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
                CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);
                CREATE INDEX IF NOT EXISTS idx_invite_codes_active ON invite_codes(is_active);
            """)

    def _ensure_admin(self):
        """Ensure at least one admin account exists.

        Admins are created without a password, which must be set on first
        login via ``setup_password``. One admin is created per address in
        ``config.ADMIN_EMAILS`` (username = the part before ``@``), or a single
        ``admin`` account when none are configured. Nothing happens if an admin
        already exists.
        """
        with self._connect() as conn:
            admin_count = conn.execute(
                "SELECT COUNT(*) FROM users WHERE role = ?",
                (UserRole.ADMIN.value,),
            ).fetchone()[0]
        if admin_count:
            return

        if config.ADMIN_EMAILS:
            # Create admin accounts for configured emails (no password yet)
            for email in config.ADMIN_EMAILS:
                username = email.split("@")[0]
                created = self._create_user_without_password(
                    username=username,
                    email=email,
                    role=UserRole.ADMIN,
                )
                if created:
                    print(f"Created admin account: {username} - password must be set on first login")
                else:
                    # Insert hit a UNIQUE constraint; don't claim success.
                    print(f"Could not create admin account {username!r}: username or email already in use")
        else:
            # Create default admin if no admins exist (no password yet)
            created = self._create_user_without_password(
                username="admin",
                role=UserRole.ADMIN,
            )
            if created:
                print("Created default admin account - password must be set on first login")
            else:
                print("Could not create default admin account: username 'admin' already in use")
            print("Set ADMIN_EMAILS in .env to configure admin accounts.")

    def _create_user_without_password(
        self,
        username: str,
        email: Optional[str] = None,
        role: UserRole = UserRole.USER,
    ) -> Optional[str]:
        """Create a user with no password (first-time setup).

        Returns:
            The new user id, or None if the username/email is already taken.
        """
        user_id = secrets.token_hex(16)
        # Empty password_hash indicates password needs to be set
        password_hash = ""
        try:
            with self._connect() as conn:
                conn.execute(
                    """
                    INSERT INTO users (id, username, email, password_hash, role)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id, username, email, password_hash, role.value),
                )
        except sqlite3.IntegrityError:
            return None
        return user_id

    def needs_password_setup(self, username: str) -> bool:
        """Return True if *username* exists and has no password set yet."""
        user = self.get_user_by_username(username)
        if not user:
            return False
        return user.password_hash == ""

    def setup_password(self, username: str, new_password: str) -> Optional[User]:
        """Set the first password for an account that has none yet.

        Returns:
            The updated user, or None if the user does not exist, is inactive,
            or already has a password (including when another request set it
            concurrently).
        """
        user = self.get_user_by_username(username)
        if not user or not user.is_active:
            return None
        if user.password_hash != "":
            return None  # Password already set
        password_hash = self._hash_password(new_password)
        with self._connect() as conn:
            # The password_hash = '' guard makes check-and-set atomic.
            cursor = conn.execute(
                "UPDATE users SET password_hash = ?, last_login = ? "
                "WHERE id = ? AND password_hash = ''",
                (password_hash, self._to_db_timestamp(datetime.now()), user.id),
            )
            if cursor.rowcount == 0:
                return None
        return self.get_user_by_id(user.id)

    # =========================================================================
    # Password hashing
    # =========================================================================

    @staticmethod
    def _hash_password(password: str) -> str:
        """Hash *password* with salted PBKDF2-HMAC-SHA256.

        Returns:
            ``"pbkdf2_sha256$<iterations>$<salt hex>$<hash hex>"``.
        """
        iterations = AuthManager._PBKDF2_ITERATIONS
        salt = secrets.token_hex(16)
        digest = hashlib.pbkdf2_hmac(
            "sha256", password.encode(), salt.encode(), iterations
        ).hex()
        return f"{AuthManager._PBKDF2_PREFIX}${iterations}${salt}${digest}"

    @staticmethod
    def _verify_password(password: str, stored_hash: str) -> bool:
        """Check *password* against *stored_hash* using a constant-time compare.

        Accepts the PBKDF2 format produced by ``_hash_password`` as well as the
        legacy ``"<salt>:<sha256 hex>"`` format. Malformed or empty hashes
        (e.g. accounts awaiting first-time setup) never match.
        """
        try:
            if stored_hash.startswith(AuthManager._PBKDF2_PREFIX + "$"):
                _scheme, iterations, salt, hash_value = stored_hash.split("$")
                computed_hash = hashlib.pbkdf2_hmac(
                    "sha256", password.encode(), salt.encode(), int(iterations)
                ).hex()
            else:
                # Legacy format: "<salt>:<sha256('<salt>:<password>') hex>".
                salt, hash_value = stored_hash.split(":")
                computed_hash = hashlib.sha256(f"{salt}:{password}".encode()).hexdigest()
            return secrets.compare_digest(computed_hash, hash_value)
        except ValueError:
            return False

    @staticmethod
    def _needs_rehash(stored_hash: str) -> bool:
        """Return True if *stored_hash* does not use the current scheme and work factor."""
        current = f"{AuthManager._PBKDF2_PREFIX}${AuthManager._PBKDF2_ITERATIONS}$"
        return not stored_hash.startswith(current)

    # =========================================================================
    # Users
    # =========================================================================

    def create_user(
        self,
        username: str,
        password: str,
        email: Optional[str] = None,
        role: UserRole = UserRole.USER,
        invited_by: Optional[str] = None,
    ) -> Optional[User]:
        """Create a new user account.

        Returns:
            The created user, or None if the username or email already exists.
        """
        user_id = secrets.token_hex(16)
        password_hash = self._hash_password(password)
        try:
            with self._connect() as conn:
                conn.execute(
                    """
                    INSERT INTO users (id, username, email, password_hash, role, invited_by)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (user_id, username, email, password_hash, role.value, invited_by),
                )
        except sqlite3.IntegrityError:
            return None  # Username or email already exists
        # Re-read only after the insert has been committed.
        return self.get_user_by_id(user_id)

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Return the user with id *user_id*, or None."""
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
            return self._row_to_user(row) if row else None

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Return the user named *username*, or None."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT * FROM users WHERE username = ?", (username,)
            ).fetchone()
            return self._row_to_user(row) if row else None

    def _row_to_user(self, row: sqlite3.Row) -> User:
        """Convert a ``users`` row to a User.

        NOTE(review): users.created_at is filled by SQLite's CURRENT_TIMESTAMP
        (UTC), whereas the other timestamps written by this class come from
        local ``datetime.now()``.
        """
        return User(
            id=row["id"],
            username=row["username"],
            email=row["email"],
            password_hash=row["password_hash"],
            role=UserRole(row["role"]),
            created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None,
            last_login=datetime.fromisoformat(row["last_login"]) if row["last_login"] else None,
            is_active=bool(row["is_active"]),
            invited_by=row["invited_by"],
        )

    def authenticate(self, username: str, password: str) -> Optional[User]:
        """Return the active user matching *username* and *password*, or None.

        On success ``last_login`` is updated (in the database and on the
        returned object), and a password stored in the legacy or an outdated
        format is re-hashed with the current scheme.
        """
        user = self.get_user_by_username(username)
        if not user or not user.is_active:
            return None
        if not self._verify_password(password, user.password_hash):
            return None

        now = datetime.now()
        # Hash before opening the transaction: PBKDF2 is deliberately slow.
        upgraded_hash = (
            self._hash_password(password) if self._needs_rehash(user.password_hash) else None
        )
        with self._connect() as conn:
            conn.execute(
                "UPDATE users SET last_login = ? WHERE id = ?",
                (self._to_db_timestamp(now), user.id),
            )
            if upgraded_hash is not None:
                # Replace only the hash we just verified, never one changed concurrently.
                cursor = conn.execute(
                    "UPDATE users SET password_hash = ? WHERE id = ? AND password_hash = ?",
                    (upgraded_hash, user.id, user.password_hash),
                )
                if cursor.rowcount > 0:
                    user.password_hash = upgraded_hash
        user.last_login = now
        return user

    # =========================================================================
    # Sessions
    # =========================================================================

    def create_session(self, user: User, duration_hours: int = 24) -> Session:
        """Create and store a new session for *user* lasting *duration_hours*."""
        token = secrets.token_urlsafe(32)
        created_at = datetime.now()
        expires_at = created_at + timedelta(hours=duration_hours)
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO sessions (token, user_id, created_at, expires_at)
                VALUES (?, ?, ?, ?)
                """,
                (
                    token,
                    user.id,
                    self._to_db_timestamp(created_at),
                    self._to_db_timestamp(expires_at),
                ),
            )
        return Session(
            token=token,
            user_id=user.id,
            created_at=created_at,
            expires_at=expires_at,
        )

    def get_session(self, token: str) -> Optional[Session]:
        """Return the unexpired session for *token*, or None.

        An expired session found in the table is deleted as a side effect.
        """
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM sessions WHERE token = ?", (token,)).fetchone()
            if row is None:
                return None
            session = Session(
                token=row["token"],
                user_id=row["user_id"],
                created_at=datetime.fromisoformat(row["created_at"]),
                expires_at=datetime.fromisoformat(row["expires_at"]),
            )
        if session.is_expired():
            # Clean up expired session
            self.invalidate_session(token)
            return None
        return session

    def get_user_from_session(self, token: str) -> Optional[User]:
        """Return the active user owning session *token*, or None.

        Deactivated users are rejected even if a session row still exists.
        """
        session = self.get_session(token)
        if not session:
            return None
        user = self.get_user_by_id(session.user_id)
        if user is None or not user.is_active:
            return None
        return user

    def invalidate_session(self, token: str):
        """Delete the session identified by *token* (no-op if absent)."""
        with self._connect() as conn:
            conn.execute("DELETE FROM sessions WHERE token = ?", (token,))

    def invalidate_user_sessions(self, user_id: str):
        """Delete every session belonging to *user_id*."""
        with self._connect() as conn:
            conn.execute("DELETE FROM sessions WHERE user_id = ?", (user_id,))

    # =========================================================================
    # Invite Codes
    # =========================================================================

    @staticmethod
    def _row_to_invite(row: sqlite3.Row) -> InviteCode:
        """Convert an ``invite_codes`` row to an InviteCode."""
        return InviteCode(
            code=row["code"],
            created_by=row["created_by"],
            created_at=datetime.fromisoformat(row["created_at"]),
            expires_at=datetime.fromisoformat(row["expires_at"]) if row["expires_at"] else None,
            max_uses=row["max_uses"],
            use_count=row["use_count"],
            is_active=bool(row["is_active"]),
        )

    def create_invite_code(
        self,
        created_by: str,
        max_uses: int = 1,
        expires_in_days: Optional[int] = 7,
    ) -> InviteCode:
        """Create a new 8-character invite code.

        Args:
            created_by: User ID of the creator.
            max_uses: Allowed registrations; values <= 0 mean unlimited.
            expires_in_days: Lifetime in days; None or 0 means no expiry.
        """
        code = secrets.token_urlsafe(8).upper()[:8]  # 8 character code
        created_at = datetime.now()
        expires_at = created_at + timedelta(days=expires_in_days) if expires_in_days else None
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO invite_codes (code, created_by, created_at, expires_at, max_uses)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    code,
                    created_by,
                    self._to_db_timestamp(created_at),
                    self._to_db_timestamp(expires_at),
                    max_uses,
                ),
            )
        return InviteCode(
            code=code,
            created_by=created_by,
            created_at=created_at,
            expires_at=expires_at,
            max_uses=max_uses,
            use_count=0,
            is_active=True,
        )

    def get_invite_code(self, code: str) -> Optional[InviteCode]:
        """Return the invite code matching *code* (case-insensitive), or None."""
        with self._connect() as conn:
            row = conn.execute(
                "SELECT * FROM invite_codes WHERE code = ?", (code.upper(),)
            ).fetchone()
            return self._row_to_invite(row) if row else None

    def use_invite_code(self, code: str) -> bool:
        """Consume one use of an invite code. Returns False if it is invalid or used up."""
        invite = self.get_invite_code(code)
        if not invite or not invite.is_valid():
            return False
        with self._connect() as conn:
            # Re-check activity and the use limit inside the UPDATE so two
            # concurrent registrations cannot exceed max_uses.
            cursor = conn.execute(
                """
                UPDATE invite_codes SET use_count = use_count + 1
                WHERE code = ? AND is_active = 1
                  AND (max_uses <= 0 OR use_count < max_uses)
                """,
                (invite.code,),
            )
            return cursor.rowcount > 0

    def validate_room_code_as_invite(self, room_code: str) -> bool:
        """
        Check if a room code is valid for registration.
        Room codes from active games act as invite codes.
        """
        # First check if it's an explicit invite code
        invite = self.get_invite_code(room_code)
        if invite and invite.is_valid():
            return True
        # Check if it's an active room code (from room manager)
        # This will be checked by the caller since we don't have room_manager here
        return False

    # =========================================================================
    # Admin Functions
    # =========================================================================

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """List users, newest first; inactive accounts only if *include_inactive*."""
        with self._connect() as conn:
            if include_inactive:
                cursor = conn.execute("SELECT * FROM users ORDER BY created_at DESC")
            else:
                cursor = conn.execute(
                    "SELECT * FROM users WHERE is_active = 1 ORDER BY created_at DESC"
                )
            return [self._row_to_user(row) for row in cursor.fetchall()]

    def update_user(
        self,
        user_id: str,
        username: Optional[str] = None,
        email: Optional[str] = None,
        role: Optional[UserRole] = None,
        is_active: Optional[bool] = None,
    ) -> Optional[User]:
        """Update the given fields of a user (admin function).

        Arguments left as None are not changed. Deactivating a user also
        invalidates all of their sessions.

        Returns:
            The updated user (None if it does not exist), or None when the new
            username/email collides with another account.
        """
        updates = []
        params = []
        if username is not None:
            updates.append("username = ?")
            params.append(username)
        if email is not None:
            updates.append("email = ?")
            params.append(email)
        if role is not None:
            updates.append("role = ?")
            params.append(role.value)
        if is_active is not None:
            updates.append("is_active = ?")
            params.append(is_active)
        if not updates:
            return self.get_user_by_id(user_id)
        params.append(user_id)
        try:
            with self._connect() as conn:
                # Column names come from the fixed list above; values are bound.
                conn.execute(
                    f"UPDATE users SET {', '.join(updates)} WHERE id = ?",
                    params,
                )
        except sqlite3.IntegrityError:
            return None
        if is_active is not None and not is_active:
            # A deactivated user must not keep using existing sessions.
            self.invalidate_user_sessions(user_id)
        return self.get_user_by_id(user_id)

    def change_password(self, user_id: str, new_password: str) -> bool:
        """Set a new password for *user_id*. Returns False if the user does not exist."""
        password_hash = self._hash_password(new_password)
        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE users SET password_hash = ? WHERE id = ?",
                (password_hash, user_id),
            )
            return cursor.rowcount > 0

    def delete_user(self, user_id: str) -> bool:
        """Delete a user (admin function). Actually just deactivates.

        All of the user's sessions are invalidated first. Returns False if the
        user does not exist.
        """
        self.invalidate_user_sessions(user_id)
        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE users SET is_active = 0 WHERE id = ?",
                (user_id,),
            )
            return cursor.rowcount > 0

    def list_invite_codes(self, created_by: Optional[str] = None) -> list[InviteCode]:
        """List invite codes, newest first, optionally only those by *created_by*."""
        with self._connect() as conn:
            if created_by:
                cursor = conn.execute(
                    "SELECT * FROM invite_codes WHERE created_by = ? ORDER BY created_at DESC",
                    (created_by,),
                )
            else:
                cursor = conn.execute(
                    "SELECT * FROM invite_codes ORDER BY created_at DESC"
                )
            return [self._row_to_invite(row) for row in cursor.fetchall()]

    def deactivate_invite_code(self, code: str) -> bool:
        """Deactivate an invite code (admin function). Returns False if not found."""
        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE invite_codes SET is_active = 0 WHERE code = ?",
                (code.upper(),),
            )
            return cursor.rowcount > 0

    def cleanup_expired_sessions(self) -> int:
        """Remove expired sessions from the database and return how many were deleted."""
        with self._connect() as conn:
            cursor = conn.execute(
                "DELETE FROM sessions WHERE expires_at < ?",
                (self._to_db_timestamp(datetime.now()),),
            )
            return cursor.rowcount
# Process-wide AuthManager, created lazily by get_auth_manager().
_auth_manager: Optional[AuthManager] = None


def get_auth_manager() -> AuthManager:
    """Return the shared AuthManager, constructing it on the first call.

    Subsequent calls return the same instance.
    """
    global _auth_manager
    if _auth_manager is not None:
        return _auth_manager
    _auth_manager = AuthManager()
    return _auth_manager