Numerous WebUI animations, improvements, AI fixes, opporitunity cost-based decision logic, etc.
This commit is contained in:
602
server/auth.py
Normal file
602
server/auth.py
Normal file
@@ -0,0 +1,602 @@
|
||||
"""
|
||||
Authentication and user management for Golf game.
|
||||
|
||||
Features:
|
||||
- User accounts stored in SQLite
|
||||
- Admin accounts can manage other users
|
||||
- Invite codes (room codes) allow new user registration
|
||||
- Session-based authentication via tokens
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import secrets
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from config import config
|
||||
|
||||
|
||||
class UserRole(Enum):
|
||||
"""User roles for access control."""
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
"""User account."""
|
||||
id: str
|
||||
username: str
|
||||
email: Optional[str]
|
||||
password_hash: str
|
||||
role: UserRole
|
||||
created_at: datetime
|
||||
last_login: Optional[datetime]
|
||||
is_active: bool
|
||||
invited_by: Optional[str] # Username of who invited them
|
||||
|
||||
def is_admin(self) -> bool:
|
||||
return self.role == UserRole.ADMIN
|
||||
|
||||
def to_dict(self, include_sensitive: bool = False) -> dict:
|
||||
"""Convert to dictionary for API responses."""
|
||||
data = {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
"email": self.email,
|
||||
"role": self.role.value,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_login": self.last_login.isoformat() if self.last_login else None,
|
||||
"is_active": self.is_active,
|
||||
"invited_by": self.invited_by,
|
||||
}
|
||||
if include_sensitive:
|
||||
data["password_hash"] = self.password_hash
|
||||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""User session."""
|
||||
token: str
|
||||
user_id: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return datetime.now() > self.expires_at
|
||||
|
||||
|
||||
@dataclass
|
||||
class InviteCode:
|
||||
"""Invite code for user registration."""
|
||||
code: str
|
||||
created_by: str # User ID who created the invite
|
||||
created_at: datetime
|
||||
expires_at: Optional[datetime]
|
||||
max_uses: int
|
||||
use_count: int
|
||||
is_active: bool
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
if not self.is_active:
|
||||
return False
|
||||
if self.expires_at and datetime.now() > self.expires_at:
|
||||
return False
|
||||
if self.max_uses > 0 and self.use_count >= self.max_uses:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class AuthManager:
|
||||
"""Manages user authentication and authorization."""
|
||||
|
||||
def __init__(self, db_path: str = "games.db"):
|
||||
self.db_path = Path(db_path)
|
||||
self._init_db()
|
||||
self._ensure_admin()
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize auth database schema."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.executescript("""
|
||||
-- Users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
email TEXT UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT DEFAULT 'user',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT 1,
|
||||
invited_by TEXT
|
||||
);
|
||||
|
||||
-- Sessions table
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
token TEXT PRIMARY KEY,
|
||||
user_id TEXT REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL
|
||||
);
|
||||
|
||||
-- Invite codes table
|
||||
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||
code TEXT PRIMARY KEY,
|
||||
created_by TEXT REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP,
|
||||
max_uses INTEGER DEFAULT 1,
|
||||
use_count INTEGER DEFAULT 0,
|
||||
is_active BOOLEAN DEFAULT 1
|
||||
);
|
||||
|
||||
-- Indexes
|
||||
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_invite_codes_active ON invite_codes(is_active);
|
||||
""")
|
||||
|
||||
def _ensure_admin(self):
|
||||
"""Ensure at least one admin account exists (without password - must be set on first login)."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT COUNT(*) FROM users WHERE role = ?",
|
||||
(UserRole.ADMIN.value,)
|
||||
)
|
||||
admin_count = cursor.fetchone()[0]
|
||||
|
||||
if admin_count == 0:
|
||||
# Check if admin emails are configured
|
||||
if config.ADMIN_EMAILS:
|
||||
# Create admin accounts for configured emails (no password yet)
|
||||
for email in config.ADMIN_EMAILS:
|
||||
username = email.split("@")[0]
|
||||
self._create_user_without_password(
|
||||
username=username,
|
||||
email=email,
|
||||
role=UserRole.ADMIN,
|
||||
)
|
||||
print(f"Created admin account: {username} - password must be set on first login")
|
||||
else:
|
||||
# Create default admin if no admins exist (no password yet)
|
||||
self._create_user_without_password(
|
||||
username="admin",
|
||||
role=UserRole.ADMIN,
|
||||
)
|
||||
print("Created default admin account - password must be set on first login")
|
||||
print("Set ADMIN_EMAILS in .env to configure admin accounts.")
|
||||
|
||||
def _create_user_without_password(
|
||||
self,
|
||||
username: str,
|
||||
email: Optional[str] = None,
|
||||
role: UserRole = UserRole.USER,
|
||||
) -> Optional[str]:
|
||||
"""Create a user without a password (for first-time setup)."""
|
||||
user_id = secrets.token_hex(16)
|
||||
# Empty password_hash indicates password needs to be set
|
||||
password_hash = ""
|
||||
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, username, email, password_hash, role)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, username, email, password_hash, role.value),
|
||||
)
|
||||
return user_id
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
||||
|
||||
def needs_password_setup(self, username: str) -> bool:
|
||||
"""Check if user needs to set up their password (first login)."""
|
||||
user = self.get_user_by_username(username)
|
||||
if not user:
|
||||
return False
|
||||
return user.password_hash == ""
|
||||
|
||||
def setup_password(self, username: str, new_password: str) -> Optional[User]:
|
||||
"""Set password for first-time setup. Only works if password is not yet set."""
|
||||
user = self.get_user_by_username(username)
|
||||
if not user:
|
||||
return None
|
||||
if user.password_hash != "":
|
||||
return None # Password already set
|
||||
|
||||
password_hash = self._hash_password(new_password)
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE users SET password_hash = ?, last_login = ? WHERE id = ?",
|
||||
(password_hash, datetime.now(), user.id)
|
||||
)
|
||||
|
||||
return self.get_user_by_id(user.id)
|
||||
|
||||
@staticmethod
|
||||
def _hash_password(password: str) -> str:
|
||||
"""Hash a password using SHA-256 with salt."""
|
||||
salt = secrets.token_hex(16)
|
||||
hash_input = f"{salt}:{password}".encode()
|
||||
password_hash = hashlib.sha256(hash_input).hexdigest()
|
||||
return f"{salt}:{password_hash}"
|
||||
|
||||
@staticmethod
|
||||
def _verify_password(password: str, stored_hash: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
try:
|
||||
salt, hash_value = stored_hash.split(":")
|
||||
hash_input = f"{salt}:{password}".encode()
|
||||
computed_hash = hashlib.sha256(hash_input).hexdigest()
|
||||
return secrets.compare_digest(computed_hash, hash_value)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
email: Optional[str] = None,
|
||||
role: UserRole = UserRole.USER,
|
||||
invited_by: Optional[str] = None,
|
||||
) -> Optional[User]:
|
||||
"""Create a new user account."""
|
||||
user_id = secrets.token_hex(16)
|
||||
password_hash = self._hash_password(password)
|
||||
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, username, email, password_hash, role, invited_by)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, username, email, password_hash, role.value, invited_by),
|
||||
)
|
||||
return self.get_user_by_id(user_id)
|
||||
except sqlite3.IntegrityError:
|
||||
return None # Username or email already exists
|
||||
|
||||
def get_user_by_id(self, user_id: str) -> Optional[User]:
|
||||
"""Get user by ID."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE id = ?",
|
||||
(user_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_user(row)
|
||||
return None
|
||||
|
||||
def get_user_by_username(self, username: str) -> Optional[User]:
|
||||
"""Get user by username."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE username = ?",
|
||||
(username,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return self._row_to_user(row)
|
||||
return None
|
||||
|
||||
def _row_to_user(self, row: sqlite3.Row) -> User:
|
||||
"""Convert database row to User object."""
|
||||
return User(
|
||||
id=row["id"],
|
||||
username=row["username"],
|
||||
email=row["email"],
|
||||
password_hash=row["password_hash"],
|
||||
role=UserRole(row["role"]),
|
||||
created_at=datetime.fromisoformat(row["created_at"]) if row["created_at"] else None,
|
||||
last_login=datetime.fromisoformat(row["last_login"]) if row["last_login"] else None,
|
||||
is_active=bool(row["is_active"]),
|
||||
invited_by=row["invited_by"],
|
||||
)
|
||||
|
||||
def authenticate(self, username: str, password: str) -> Optional[User]:
|
||||
"""Authenticate user with username and password."""
|
||||
user = self.get_user_by_username(username)
|
||||
if not user:
|
||||
return None
|
||||
if not user.is_active:
|
||||
return None
|
||||
if not self._verify_password(password, user.password_hash):
|
||||
return None
|
||||
|
||||
# Update last login
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE users SET last_login = ? WHERE id = ?",
|
||||
(datetime.now(), user.id)
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
def create_session(self, user: User, duration_hours: int = 24) -> Session:
|
||||
"""Create a new session for a user."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(hours=duration_hours)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (token, user_id, created_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(token, user.id, created_at, expires_at)
|
||||
)
|
||||
|
||||
return Session(
|
||||
token=token,
|
||||
user_id=user.id,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
def get_session(self, token: str) -> Optional[Session]:
|
||||
"""Get session by token."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM sessions WHERE token = ?",
|
||||
(token,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
session = Session(
|
||||
token=row["token"],
|
||||
user_id=row["user_id"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
expires_at=datetime.fromisoformat(row["expires_at"]),
|
||||
)
|
||||
if not session.is_expired():
|
||||
return session
|
||||
# Clean up expired session
|
||||
self.invalidate_session(token)
|
||||
return None
|
||||
|
||||
def get_user_from_session(self, token: str) -> Optional[User]:
|
||||
"""Get user from session token."""
|
||||
session = self.get_session(token)
|
||||
if session:
|
||||
return self.get_user_by_id(session.user_id)
|
||||
return None
|
||||
|
||||
def invalidate_session(self, token: str):
|
||||
"""Invalidate a session."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("DELETE FROM sessions WHERE token = ?", (token,))
|
||||
|
||||
def invalidate_user_sessions(self, user_id: str):
|
||||
"""Invalidate all sessions for a user."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("DELETE FROM sessions WHERE user_id = ?", (user_id,))
|
||||
|
||||
# =========================================================================
|
||||
# Invite Codes
|
||||
# =========================================================================
|
||||
|
||||
def create_invite_code(
|
||||
self,
|
||||
created_by: str,
|
||||
max_uses: int = 1,
|
||||
expires_in_days: Optional[int] = 7,
|
||||
) -> InviteCode:
|
||||
"""Create a new invite code."""
|
||||
code = secrets.token_urlsafe(8).upper()[:8] # 8 character code
|
||||
created_at = datetime.now()
|
||||
expires_at = created_at + timedelta(days=expires_in_days) if expires_in_days else None
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO invite_codes (code, created_by, created_at, expires_at, max_uses)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(code, created_by, created_at, expires_at, max_uses)
|
||||
)
|
||||
|
||||
return InviteCode(
|
||||
code=code,
|
||||
created_by=created_by,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
max_uses=max_uses,
|
||||
use_count=0,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
def get_invite_code(self, code: str) -> Optional[InviteCode]:
|
||||
"""Get invite code by code string."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM invite_codes WHERE code = ?",
|
||||
(code.upper(),)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
return InviteCode(
|
||||
code=row["code"],
|
||||
created_by=row["created_by"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
expires_at=datetime.fromisoformat(row["expires_at"]) if row["expires_at"] else None,
|
||||
max_uses=row["max_uses"],
|
||||
use_count=row["use_count"],
|
||||
is_active=bool(row["is_active"]),
|
||||
)
|
||||
return None
|
||||
|
||||
def use_invite_code(self, code: str) -> bool:
|
||||
"""Mark an invite code as used. Returns False if invalid."""
|
||||
invite = self.get_invite_code(code)
|
||||
if not invite or not invite.is_valid():
|
||||
return False
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"UPDATE invite_codes SET use_count = use_count + 1 WHERE code = ?",
|
||||
(code.upper(),)
|
||||
)
|
||||
return True
|
||||
|
||||
def validate_room_code_as_invite(self, room_code: str) -> bool:
|
||||
"""
|
||||
Check if a room code is valid for registration.
|
||||
Room codes from active games act as invite codes.
|
||||
"""
|
||||
# First check if it's an explicit invite code
|
||||
invite = self.get_invite_code(room_code)
|
||||
if invite and invite.is_valid():
|
||||
return True
|
||||
|
||||
# Check if it's an active room code (from room manager)
|
||||
# This will be checked by the caller since we don't have room_manager here
|
||||
return False
|
||||
|
||||
# =========================================================================
|
||||
# Admin Functions
|
||||
# =========================================================================
|
||||
|
||||
def list_users(self, include_inactive: bool = False) -> list[User]:
|
||||
"""List all users (admin function)."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
if include_inactive:
|
||||
cursor = conn.execute("SELECT * FROM users ORDER BY created_at DESC")
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE is_active = 1 ORDER BY created_at DESC"
|
||||
)
|
||||
return [self._row_to_user(row) for row in cursor.fetchall()]
|
||||
|
||||
def update_user(
|
||||
self,
|
||||
user_id: str,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
role: Optional[UserRole] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
) -> Optional[User]:
|
||||
"""Update user details (admin function)."""
|
||||
updates = []
|
||||
params = []
|
||||
|
||||
if username is not None:
|
||||
updates.append("username = ?")
|
||||
params.append(username)
|
||||
if email is not None:
|
||||
updates.append("email = ?")
|
||||
params.append(email)
|
||||
if role is not None:
|
||||
updates.append("role = ?")
|
||||
params.append(role.value)
|
||||
if is_active is not None:
|
||||
updates.append("is_active = ?")
|
||||
params.append(is_active)
|
||||
|
||||
if not updates:
|
||||
return self.get_user_by_id(user_id)
|
||||
|
||||
params.append(user_id)
|
||||
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
f"UPDATE users SET {', '.join(updates)} WHERE id = ?",
|
||||
params
|
||||
)
|
||||
return self.get_user_by_id(user_id)
|
||||
except sqlite3.IntegrityError:
|
||||
return None
|
||||
|
||||
def change_password(self, user_id: str, new_password: str) -> bool:
|
||||
"""Change user password."""
|
||||
password_hash = self._hash_password(new_password)
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
"UPDATE users SET password_hash = ? WHERE id = ?",
|
||||
(password_hash, user_id)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def delete_user(self, user_id: str) -> bool:
|
||||
"""Delete a user (admin function). Actually just deactivates."""
|
||||
# Invalidate all sessions first
|
||||
self.invalidate_user_sessions(user_id)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
"UPDATE users SET is_active = 0 WHERE id = ?",
|
||||
(user_id,)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def list_invite_codes(self, created_by: Optional[str] = None) -> list[InviteCode]:
|
||||
"""List invite codes (admin function)."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.row_factory = sqlite3.Row
|
||||
if created_by:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM invite_codes WHERE created_by = ? ORDER BY created_at DESC",
|
||||
(created_by,)
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM invite_codes ORDER BY created_at DESC"
|
||||
)
|
||||
return [
|
||||
InviteCode(
|
||||
code=row["code"],
|
||||
created_by=row["created_by"],
|
||||
created_at=datetime.fromisoformat(row["created_at"]),
|
||||
expires_at=datetime.fromisoformat(row["expires_at"]) if row["expires_at"] else None,
|
||||
max_uses=row["max_uses"],
|
||||
use_count=row["use_count"],
|
||||
is_active=bool(row["is_active"]),
|
||||
)
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def deactivate_invite_code(self, code: str) -> bool:
|
||||
"""Deactivate an invite code (admin function)."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.execute(
|
||||
"UPDATE invite_codes SET is_active = 0 WHERE code = ?",
|
||||
(code.upper(),)
|
||||
)
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def cleanup_expired_sessions(self):
|
||||
"""Remove expired sessions from database."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE expires_at < ?",
|
||||
(datetime.now(),)
|
||||
)
|
||||
|
||||
|
||||
# Global auth manager instance (lazy initialization)
|
||||
_auth_manager: Optional[AuthManager] = None
|
||||
|
||||
|
||||
def get_auth_manager() -> AuthManager:
|
||||
"""Get or create the global auth manager instance."""
|
||||
global _auth_manager
|
||||
if _auth_manager is None:
|
||||
_auth_manager = AuthManager()
|
||||
return _auth_manager
|
||||
Reference in New Issue
Block a user