603 lines
21 KiB
Python
603 lines
21 KiB
Python
"""
|
|
Authentication and user management for Golf game.
|
|
|
|
Features:
|
|
- User accounts stored in SQLite
|
|
- Admin accounts can manage other users
|
|
- Invite codes (room codes) allow new user registration
|
|
- Session-based authentication via tokens
|
|
"""
|
|
|
|
import contextlib
import hashlib
import secrets
import sqlite3
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Optional

from config import config
|
|
|
|
|
|
class UserRole(Enum):
    """Closed set of account roles used for access control.

    Each member's value is the plain string form of the role.
    """

    # Regular account with no management rights.
    USER = "user"
    # Account allowed to manage other users.
    ADMIN = "admin"
|
|
|
|
|
|
@dataclass
class User:
    """A single user account record."""

    id: str
    username: str
    email: Optional[str]
    password_hash: str
    role: UserRole
    created_at: datetime
    last_login: Optional[datetime]
    is_active: bool
    invited_by: Optional[str]  # Username of who invited them

    def is_admin(self) -> bool:
        """Return True when this account carries the ADMIN role."""
        return UserRole.ADMIN == self.role

    def to_dict(self, include_sensitive: bool = False) -> dict:
        """Serialise the account into a plain dict for API responses.

        Timestamps are rendered as ISO-8601 strings (or None when unset) and
        the role as its string value.  The password hash is only included
        when ``include_sensitive`` is true.
        """

        def iso(moment):
            # Missing timestamps are emitted as None rather than raising.
            return moment.isoformat() if moment else None

        payload = dict(
            id=self.id,
            username=self.username,
            email=self.email,
            role=self.role.value,
            created_at=iso(self.created_at),
            last_login=iso(self.last_login),
            is_active=self.is_active,
            invited_by=self.invited_by,
        )
        if not include_sensitive:
            return payload

        payload["password_hash"] = self.password_hash
        return payload
|
|
|
|
|
|
@dataclass
class Session:
    """A login session: an opaque token bound to a user for a limited time."""

    token: str
    user_id: str
    created_at: datetime
    expires_at: datetime

    def is_expired(self) -> bool:
        """Return True once the current local time is past ``expires_at``."""
        return self.expires_at < datetime.now()
|
|
|
|
|
|
@dataclass
class InviteCode:
    """An invite code that gates new-user registration."""

    code: str
    created_by: str  # User ID who created the invite
    created_at: datetime
    expires_at: Optional[datetime]
    max_uses: int
    use_count: int
    is_active: bool

    def is_valid(self) -> bool:
        """Return True while the code may still be redeemed.

        A code is rejected when it has been deactivated, when it has an
        expiry time that is already in the past, or when it has a positive
        ``max_uses`` that ``use_count`` has reached.  A ``max_uses`` of zero
        or less therefore places no limit on the number of uses.
        """
        # The three rejection reasons are checked lazily, in this order.
        return not (
            not self.is_active
            or (self.expires_at and datetime.now() > self.expires_at)
            or (self.max_uses > 0 and self.use_count >= self.max_uses)
        )
|
|
|
|
|
|
class AuthManager:
    """Manages user authentication and authorization.

    All state lives in a SQLite database at ``db_path``: user accounts,
    session tokens and invite codes.  Every operation opens a short-lived
    connection that is committed (or rolled back on error) and then closed.

    Password hashes are stored as
    ``pbkdf2_sha256$<iterations>$<salt>$<hex digest>``.  Hashes in the older
    ``<salt>:<sha256 hex digest>`` format are still accepted and are upgraded
    the next time their owner logs in successfully.  An empty hash marks an
    account whose password has not been set yet.
    """

    # Marker and work factor for newly created password hashes.  The
    # iteration count is stored inside each hash, so it can be raised later
    # without invalidating existing passwords.
    _PBKDF2_PREFIX = "pbkdf2_sha256"
    _PBKDF2_ITERATIONS = 600_000

    # How many random invite codes to try before giving up on a collision.
    _INVITE_CODE_ATTEMPTS = 5

    def __init__(self, db_path: str = "games.db"):
        """Open (creating if needed) the auth database and seed an admin.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self._init_db()
        self._ensure_admin()

    # =========================================================================
    # Internal helpers
    # =========================================================================

    @contextlib.contextmanager
    def _connect(self, use_row_factory: bool = False):
        """Yield a connection that is committed/rolled back and then closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction
        and leaves the connection open, so the close is done explicitly here.

        Args:
            use_row_factory: When true, rows are returned as ``sqlite3.Row``
                so columns can be read by name.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            if use_row_factory:
                conn.row_factory = sqlite3.Row
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _to_db_timestamp(value: Optional[datetime]) -> Optional[str]:
        """Render a datetime as the text form stored in the database.

        Uses ``isoformat(" ")``, the same text the sqlite3 default datetime
        adapter produced, without relying on that adapter (deprecated since
        Python 3.12).  ``None`` is passed through unchanged.
        """
        return None if value is None else value.isoformat(sep=" ")

    @staticmethod
    def _is_legacy_hash(stored_hash: str) -> bool:
        """Return True when ``stored_hash`` is not in the PBKDF2 format."""
        return not stored_hash.startswith(AuthManager._PBKDF2_PREFIX + "$")

    def _fetch_user(self, query: str, params: tuple) -> Optional[User]:
        """Run a single-row user query and convert the result, if any."""
        with self._connect(use_row_factory=True) as conn:
            row = conn.execute(query, params).fetchone()
        return self._row_to_user(row) if row else None

    @staticmethod
    def _row_to_invite(row: sqlite3.Row) -> InviteCode:
        """Convert an ``invite_codes`` row to an InviteCode object."""
        expires_raw = row["expires_at"]
        return InviteCode(
            code=row["code"],
            created_by=row["created_by"],
            created_at=datetime.fromisoformat(row["created_at"]),
            expires_at=datetime.fromisoformat(expires_raw) if expires_raw else None,
            max_uses=row["max_uses"],
            use_count=row["use_count"],
            is_active=bool(row["is_active"]),
        )

    # =========================================================================
    # Setup
    # =========================================================================

    def _init_db(self):
        """Initialize auth database schema (tables and indexes)."""
        # NOTE(review): column defaults use CURRENT_TIMESTAMP (UTC in SQLite)
        # while timestamps written by this class come from datetime.now()
        # (local time) - confirm whether the mix matters for this deployment.
        with self._connect() as conn:
            conn.executescript("""
                -- Users table
                CREATE TABLE IF NOT EXISTS users (
                    id TEXT PRIMARY KEY,
                    username TEXT UNIQUE NOT NULL,
                    email TEXT UNIQUE,
                    password_hash TEXT NOT NULL,
                    role TEXT DEFAULT 'user',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_login TIMESTAMP,
                    is_active BOOLEAN DEFAULT 1,
                    invited_by TEXT
                );

                -- Sessions table
                CREATE TABLE IF NOT EXISTS sessions (
                    token TEXT PRIMARY KEY,
                    user_id TEXT REFERENCES users(id),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP NOT NULL
                );

                -- Invite codes table
                CREATE TABLE IF NOT EXISTS invite_codes (
                    code TEXT PRIMARY KEY,
                    created_by TEXT REFERENCES users(id),
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    expires_at TIMESTAMP,
                    max_uses INTEGER DEFAULT 1,
                    use_count INTEGER DEFAULT 0,
                    is_active BOOLEAN DEFAULT 1
                );

                -- Indexes
                CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
                CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
                CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
                CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);
                CREATE INDEX IF NOT EXISTS idx_invite_codes_active ON invite_codes(is_active);
            """)

    def _ensure_admin(self):
        """Ensure at least one admin account exists (without password - must be set on first login)."""
        with self._connect() as conn:
            admin_count = conn.execute(
                "SELECT COUNT(*) FROM users WHERE role = ?",
                (UserRole.ADMIN.value,),
            ).fetchone()[0]

        if admin_count != 0:
            return

        if config.ADMIN_EMAILS:
            # Create admin accounts for configured emails (no password yet)
            for email in config.ADMIN_EMAILS:
                username = email.split("@")[0]
                created_id = self._create_user_without_password(
                    username=username,
                    email=email,
                    role=UserRole.ADMIN,
                )
                # Two emails can share a local part; only report accounts
                # that were actually inserted.
                if created_id is not None:
                    print(f"Created admin account: {username} - password must be set on first login")
                else:
                    print(f"Could not create admin account: {username} - username or email already exists")
        else:
            # Create default admin if no admins exist (no password yet)
            self._create_user_without_password(
                username="admin",
                role=UserRole.ADMIN,
            )
            print("Created default admin account - password must be set on first login")
            print("Set ADMIN_EMAILS in .env to configure admin accounts.")

    def _create_user_without_password(
        self,
        username: str,
        email: Optional[str] = None,
        role: UserRole = UserRole.USER,
    ) -> Optional[str]:
        """Create a user without a password (for first-time setup).

        Returns:
            The new user's ID, or None when the username or email is taken.
        """
        user_id = secrets.token_hex(16)
        # Empty password_hash indicates password needs to be set
        password_hash = ""

        try:
            with self._connect() as conn:
                conn.execute(
                    """
                    INSERT INTO users (id, username, email, password_hash, role)
                    VALUES (?, ?, ?, ?, ?)
                    """,
                    (user_id, username, email, password_hash, role.value),
                )
        except sqlite3.IntegrityError:
            return None
        return user_id

    def needs_password_setup(self, username: str) -> bool:
        """Check if user needs to set up their password (first login).

        Returns False for unknown usernames.
        """
        user = self.get_user_by_username(username)
        if not user:
            return False
        return user.password_hash == ""

    def setup_password(self, username: str, new_password: str) -> Optional[User]:
        """Set password for first-time setup. Only works if password is not yet set.

        Returns:
            The refreshed user, or None when the user does not exist or
            already has a password.
        """
        user = self.get_user_by_username(username)
        if not user or user.password_hash != "":
            return None

        password_hash = self._hash_password(new_password)
        with self._connect() as conn:
            # The empty-hash condition is repeated in the UPDATE so that two
            # concurrent setup attempts cannot both succeed.
            cursor = conn.execute(
                "UPDATE users SET password_hash = ?, last_login = ? "
                "WHERE id = ? AND password_hash = ''",
                (password_hash, self._to_db_timestamp(datetime.now()), user.id),
            )
            was_set = cursor.rowcount > 0

        if not was_set:
            return None  # Password was set by someone else in the meantime
        return self.get_user_by_id(user.id)

    # =========================================================================
    # Passwords
    # =========================================================================

    @staticmethod
    def _hash_password(password: str) -> str:
        """Hash a password with PBKDF2-HMAC-SHA256 and a random salt.

        Returns:
            ``pbkdf2_sha256$<iterations>$<salt>$<hex digest>``.
        """
        salt = secrets.token_hex(16)
        iterations = AuthManager._PBKDF2_ITERATIONS
        digest = hashlib.pbkdf2_hmac(
            "sha256", password.encode(), salt.encode(), iterations
        ).hex()
        return f"{AuthManager._PBKDF2_PREFIX}${iterations}${salt}${digest}"

    @staticmethod
    def _verify_password(password: str, stored_hash: str) -> bool:
        """Verify a password against its stored hash.

        Accepts both the PBKDF2 format and the legacy ``salt:sha256`` format.
        Malformed or empty hashes never verify.
        """
        if not AuthManager._is_legacy_hash(stored_hash):
            try:
                _, iterations, salt, expected = stored_hash.split("$")
                computed = hashlib.pbkdf2_hmac(
                    "sha256", password.encode(), salt.encode(), int(iterations)
                ).hex()
            except (ValueError, OverflowError):
                return False
        else:
            # Legacy format: sha256("<salt>:<password>") stored as "<salt>:<hex>".
            try:
                salt, expected = stored_hash.split(":")
            except ValueError:
                return False
            computed = hashlib.sha256(f"{salt}:{password}".encode()).hexdigest()

        # Compare as bytes: compare_digest rejects non-ASCII str input.
        return secrets.compare_digest(computed.encode(), expected.encode())

    # =========================================================================
    # Users
    # =========================================================================

    def create_user(
        self,
        username: str,
        password: str,
        email: Optional[str] = None,
        role: UserRole = UserRole.USER,
        invited_by: Optional[str] = None,
    ) -> Optional[User]:
        """Create a new user account.

        Returns:
            The created user, or None when the username or email already
            exists.
        """
        user_id = secrets.token_hex(16)
        password_hash = self._hash_password(password)

        try:
            with self._connect() as conn:
                conn.execute(
                    """
                    INSERT INTO users (id, username, email, password_hash, role, invited_by)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (user_id, username, email, password_hash, role.value, invited_by),
                )
        except sqlite3.IntegrityError:
            return None  # Username or email already exists
        return self.get_user_by_id(user_id)

    def get_user_by_id(self, user_id: str) -> Optional[User]:
        """Get user by ID, or None when no such user exists."""
        return self._fetch_user("SELECT * FROM users WHERE id = ?", (user_id,))

    def get_user_by_username(self, username: str) -> Optional[User]:
        """Get user by username, or None when no such user exists."""
        return self._fetch_user("SELECT * FROM users WHERE username = ?", (username,))

    def _row_to_user(self, row: sqlite3.Row) -> User:
        """Convert database row to User object."""
        created_raw = row["created_at"]
        login_raw = row["last_login"]
        return User(
            id=row["id"],
            username=row["username"],
            email=row["email"],
            password_hash=row["password_hash"],
            role=UserRole(row["role"]),
            created_at=datetime.fromisoformat(created_raw) if created_raw else None,
            last_login=datetime.fromisoformat(login_raw) if login_raw else None,
            is_active=bool(row["is_active"]),
            invited_by=row["invited_by"],
        )

    def authenticate(self, username: str, password: str) -> Optional[User]:
        """Authenticate user with username and password.

        On success the user's ``last_login`` is updated and, if the stored
        hash is in the legacy format, it is replaced with a PBKDF2 hash.

        Returns:
            The user as loaded before the update, or None when the user is
            unknown, inactive, or the password does not match.
        """
        user = self.get_user_by_username(username)
        if not user or not user.is_active:
            return None
        if not self._verify_password(password, user.password_hash):
            return None

        # Hash outside the connection so the database is not held open
        # during the expensive key derivation.
        upgraded_hash = (
            self._hash_password(password)
            if self._is_legacy_hash(user.password_hash)
            else None
        )

        with self._connect() as conn:
            conn.execute(
                "UPDATE users SET last_login = ? WHERE id = ?",
                (self._to_db_timestamp(datetime.now()), user.id),
            )
            if upgraded_hash is not None:
                # Only upgrade if the hash is still the one just verified,
                # so a concurrent password change is not overwritten.
                conn.execute(
                    "UPDATE users SET password_hash = ? WHERE id = ? AND password_hash = ?",
                    (upgraded_hash, user.id, user.password_hash),
                )

        return user

    # =========================================================================
    # Sessions
    # =========================================================================

    def create_session(self, user: User, duration_hours: int = 24) -> Session:
        """Create a new session for a user.

        Args:
            user: Account the session belongs to.
            duration_hours: Lifetime of the session from now.
        """
        token = secrets.token_urlsafe(32)
        created_at = datetime.now()
        expires_at = created_at + timedelta(hours=duration_hours)

        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO sessions (token, user_id, created_at, expires_at)
                VALUES (?, ?, ?, ?)
                """,
                (
                    token,
                    user.id,
                    self._to_db_timestamp(created_at),
                    self._to_db_timestamp(expires_at),
                ),
            )

        return Session(
            token=token,
            user_id=user.id,
            created_at=created_at,
            expires_at=expires_at,
        )

    def get_session(self, token: str) -> Optional[Session]:
        """Get session by token.

        Expired sessions are deleted and reported as missing.
        """
        with self._connect(use_row_factory=True) as conn:
            row = conn.execute(
                "SELECT * FROM sessions WHERE token = ?",
                (token,),
            ).fetchone()

        if row is None:
            return None

        session = Session(
            token=row["token"],
            user_id=row["user_id"],
            created_at=datetime.fromisoformat(row["created_at"]),
            expires_at=datetime.fromisoformat(row["expires_at"]),
        )
        if session.is_expired():
            # Clean up expired session (after the read connection is closed)
            self.invalidate_session(token)
            return None
        return session

    def get_user_from_session(self, token: str) -> Optional[User]:
        """Get user from session token, or None for a missing/expired session."""
        session = self.get_session(token)
        if session:
            return self.get_user_by_id(session.user_id)
        return None

    def invalidate_session(self, token: str):
        """Invalidate a session."""
        with self._connect() as conn:
            conn.execute("DELETE FROM sessions WHERE token = ?", (token,))

    def invalidate_user_sessions(self, user_id: str):
        """Invalidate all sessions for a user."""
        with self._connect() as conn:
            conn.execute("DELETE FROM sessions WHERE user_id = ?", (user_id,))

    # =========================================================================
    # Invite Codes
    # =========================================================================

    def create_invite_code(
        self,
        created_by: str,
        max_uses: int = 1,
        expires_in_days: Optional[int] = 7,
    ) -> InviteCode:
        """Create a new invite code.

        Args:
            created_by: ID of the user creating the invite.
            max_uses: Number of redemptions allowed; zero or less means
                unlimited (see ``InviteCode.is_valid``).
            expires_in_days: Lifetime in days; ``None`` or ``0`` creates a
                code that never expires.

        Raises:
            sqlite3.IntegrityError: If no unused code could be generated.
        """
        created_at = datetime.now()
        expires_at = created_at + timedelta(days=expires_in_days) if expires_in_days else None

        # Codes are short and upper-cased, so a collision with an existing
        # code is possible; retry with a fresh code instead of failing.
        for _ in range(self._INVITE_CODE_ATTEMPTS):
            code = secrets.token_urlsafe(8).upper()[:8]  # 8 character code
            try:
                with self._connect() as conn:
                    conn.execute(
                        """
                        INSERT INTO invite_codes (code, created_by, created_at, expires_at, max_uses)
                        VALUES (?, ?, ?, ?, ?)
                        """,
                        (
                            code,
                            created_by,
                            self._to_db_timestamp(created_at),
                            self._to_db_timestamp(expires_at),
                            max_uses,
                        ),
                    )
            except sqlite3.IntegrityError:
                continue
            break
        else:
            raise sqlite3.IntegrityError("could not generate a unique invite code")

        return InviteCode(
            code=code,
            created_by=created_by,
            created_at=created_at,
            expires_at=expires_at,
            max_uses=max_uses,
            use_count=0,
            is_active=True,
        )

    def get_invite_code(self, code: str) -> Optional[InviteCode]:
        """Get invite code by code string (matched case-insensitively)."""
        with self._connect(use_row_factory=True) as conn:
            row = conn.execute(
                "SELECT * FROM invite_codes WHERE code = ?",
                (code.upper(),),
            ).fetchone()
        return self._row_to_invite(row) if row else None

    def use_invite_code(self, code: str) -> bool:
        """Mark an invite code as used. Returns False if invalid."""
        invite = self.get_invite_code(code)
        if not invite or not invite.is_valid():
            return False

        with self._connect() as conn:
            # Re-check activity and remaining uses inside the UPDATE so that
            # concurrent redemptions cannot push use_count past max_uses.
            cursor = conn.execute(
                """
                UPDATE invite_codes
                SET use_count = use_count + 1
                WHERE code = ?
                  AND is_active = 1
                  AND (max_uses <= 0 OR use_count < max_uses)
                """,
                (code.upper(),),
            )
            return cursor.rowcount > 0

    def validate_room_code_as_invite(self, room_code: str) -> bool:
        """
        Check if a room code is valid for registration.

        Room codes from active games act as invite codes, but only explicit
        invite codes are checked here.
        """
        # First check if it's an explicit invite code
        invite = self.get_invite_code(room_code)
        if invite and invite.is_valid():
            return True

        # Check if it's an active room code (from room manager)
        # This will be checked by the caller since we don't have room_manager here
        return False

    # =========================================================================
    # Admin Functions
    # =========================================================================

    def list_users(self, include_inactive: bool = False) -> list[User]:
        """List all users, newest first (admin function).

        Args:
            include_inactive: Also return deactivated accounts.
        """
        if include_inactive:
            query = "SELECT * FROM users ORDER BY created_at DESC"
        else:
            query = "SELECT * FROM users WHERE is_active = 1 ORDER BY created_at DESC"

        with self._connect(use_row_factory=True) as conn:
            rows = conn.execute(query).fetchall()
        return [self._row_to_user(row) for row in rows]

    def update_user(
        self,
        user_id: str,
        username: Optional[str] = None,
        email: Optional[str] = None,
        role: Optional[UserRole] = None,
        is_active: Optional[bool] = None,
    ) -> Optional[User]:
        """Update user details (admin function).

        Arguments left as ``None`` are not changed, so this method cannot
        reset a column (for example ``email``) to NULL.

        Returns:
            The updated user, or None when the user does not exist or the
            new username/email conflicts with another account.
        """
        requested = (
            ("username = ?", username),
            ("email = ?", email),
            ("role = ?", role.value if role is not None else None),
            ("is_active = ?", is_active),
        )
        updates = [clause for clause, value in requested if value is not None]
        params = [value for _, value in requested if value is not None]

        if not updates:
            return self.get_user_by_id(user_id)

        params.append(user_id)

        try:
            with self._connect() as conn:
                # Only the fixed column clauses above are interpolated; all
                # values are bound as parameters.
                conn.execute(
                    f"UPDATE users SET {', '.join(updates)} WHERE id = ?",
                    params,
                )
        except sqlite3.IntegrityError:
            return None
        return self.get_user_by_id(user_id)

    def change_password(self, user_id: str, new_password: str) -> bool:
        """Change user password. Returns False when the user does not exist."""
        password_hash = self._hash_password(new_password)
        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE users SET password_hash = ? WHERE id = ?",
                (password_hash, user_id),
            )
            return cursor.rowcount > 0

    def delete_user(self, user_id: str) -> bool:
        """Delete a user (admin function). Actually just deactivates.

        Returns False when the user does not exist.
        """
        # Invalidate all sessions first
        self.invalidate_user_sessions(user_id)

        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE users SET is_active = 0 WHERE id = ?",
                (user_id,),
            )
            return cursor.rowcount > 0

    def list_invite_codes(self, created_by: Optional[str] = None) -> list[InviteCode]:
        """List invite codes, newest first (admin function).

        Args:
            created_by: When given, only codes created by this user ID.
        """
        with self._connect(use_row_factory=True) as conn:
            if created_by:
                cursor = conn.execute(
                    "SELECT * FROM invite_codes WHERE created_by = ? ORDER BY created_at DESC",
                    (created_by,),
                )
            else:
                cursor = conn.execute(
                    "SELECT * FROM invite_codes ORDER BY created_at DESC"
                )
            rows = cursor.fetchall()
        return [self._row_to_invite(row) for row in rows]

    def deactivate_invite_code(self, code: str) -> bool:
        """Deactivate an invite code (admin function).

        Returns False when no such code exists.
        """
        with self._connect() as conn:
            cursor = conn.execute(
                "UPDATE invite_codes SET is_active = 0 WHERE code = ?",
                (code.upper(),),
            )
            return cursor.rowcount > 0

    def cleanup_expired_sessions(self):
        """Remove expired sessions from database."""
        with self._connect() as conn:
            conn.execute(
                "DELETE FROM sessions WHERE expires_at < ?",
                (self._to_db_timestamp(datetime.now()),),
            )
|
|
|
|
|
|
# Module-wide AuthManager instance; stays None until first requested.
_auth_manager: Optional[AuthManager] = None


def get_auth_manager() -> AuthManager:
    """Return the shared AuthManager, constructing it on the first call."""
    global _auth_manager
    if _auth_manager is not None:
        return _auth_manager

    # First use: build the manager with its default database path.
    _auth_manager = AuthManager()
    return _auth_manager
|