# Source: chat_grid/server/app/auth_service.py (577 lines, 21 KiB, Python)

"""Account and session persistence service for websocket authentication."""
from __future__ import annotations
from dataclasses import dataclass
import base64
import hashlib
import hmac
import os
from pathlib import Path
import re
import secrets
import sqlite3
import threading
import time
# Usernames are restricted to lowercase ASCII letters, digits, "_" and "-".
USERNAME_PATTERN = re.compile(r"^[a-z0-9_-]+$")

# PBKDF2-HMAC-SHA256 parameters used for every stored password hash.
PBKDF2_ITERATIONS = 310_000
PBKDF2_DKLEN = 32  # derived key length, in bytes
SALT_BYTES = 16  # random salt length, in bytes

# Session lifetime in milliseconds: 14 days (86_400_000 ms per day).
SESSION_TTL_MS = 14 * 86_400_000
def _build_dummy_password_hash() -> str:
    """Return a fixed PBKDF2 record in the stored-hash format.

    The salt and password are constants, so the result is deterministic. It is
    built with the same iteration count and key length as real password hashes
    so that verifying against it costs the same key-derivation work.
    """
    salt = b"chgrid_dummy_salt"
    derived = hashlib.pbkdf2_hmac(
        "sha256",
        b"chgrid_dummy_password",
        salt,
        PBKDF2_ITERATIONS,
        dklen=PBKDF2_DKLEN,
    )
    salt_b64, digest_b64 = (
        base64.b64encode(raw).decode("ascii") for raw in (salt, derived)
    )
    return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt_b64}${digest_b64}"


# Computed once at import time; used to equalize timing when a login names an
# unknown username.
DUMMY_PASSWORD_HASH = _build_dummy_password_hash()
@dataclass(frozen=True)
class AuthUser:
    """Immutable snapshot of one account plus its last persisted client state."""

    # Account identity and policy fields.
    id: str
    username: str
    role: str
    status: str
    email: str | None

    # Last persisted client state; each is None when nothing has been stored.
    last_nickname: str | None
    last_x: int | None
    last_y: int | None
@dataclass(frozen=True)
class AuthSession:
    """Immutable pairing of a session token with the user it belongs to."""

    session_id: str  # identifier of the persisted session row
    token: str  # plaintext token handed to the client
    user: AuthUser  # account the session authenticates
class AuthError(ValueError):
    """Signals rejected authentication input or a failed policy check."""
class AuthService:
    """Manages account registration, login, and rolling session validation.

    All database access goes through one shared SQLite connection guarded by a
    re-entrant lock. Every multi-statement write holds that lock from its first
    statement through commit (or rollback), so another thread using this
    service cannot commit or roll back half of it.
    """

    def __init__(
        self,
        db_path: Path,
        token_hash_secret: str,
        password_min_length: int,
        password_max_length: int,
        username_min_length: int,
        username_max_length: int,
    ):
        """Validate configuration, open the database, and ensure the schema.

        Args:
            db_path: Location of the SQLite file; missing parent directories
                are created.
            token_hash_secret: Server secret keyed into session-token hashes.
            password_min_length: Minimum password length (clamped to >= 1).
            password_max_length: Maximum password length (clamped to >= min).
            username_min_length: Minimum username length (clamped to >= 1).
            username_max_length: Maximum username length (clamped to >= min).

        Raises:
            AuthError: If the secret is missing or blank.
        """
        # Validate the secret first so a misconfigured service fails before it
        # touches the filesystem. A missing (None) secret is treated as blank.
        secret = (token_hash_secret or "").strip()
        if not secret:
            raise AuthError("CHGRID_AUTH_SECRET is required when auth is enabled.")
        self._token_secret = secret.encode("utf-8")

        self.password_min_length = max(1, int(password_min_length))
        self.password_max_length = max(self.password_min_length, int(password_max_length))
        self.username_min_length = max(1, int(username_min_length))
        self.username_max_length = max(self.username_min_length, int(username_max_length))

        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn_lock = threading.RLock()
        self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self._conn.row_factory = sqlite3.Row
        try:
            self._ensure_schema()
        except Exception:
            # Do not leak the connection when the schema cannot be created.
            self._conn.close()
            raise

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        with self._conn_lock:
            self._conn.close()

    def bootstrap_admin(self, username: str, password: str, email: str | None = None) -> AuthUser:
        """Create the first admin account, or fail if one already exists.

        Raises:
            AuthError: If an admin already exists or registration is rejected.
        """
        # Hold the lock across the check and the insert so two concurrent
        # bootstrap calls cannot both pass the has_admin() check.
        with self._conn_lock:
            if self.has_admin():
                raise AuthError("An admin account already exists.")
            return self.register(username, password, email=email, role="admin").user

    def has_admin(self) -> bool:
        """Return True when at least one admin account exists."""
        existing = self._db_fetchone("SELECT 1 FROM users WHERE role = 'admin' LIMIT 1")
        return existing is not None

    def register(
        self,
        username: str,
        password: str,
        *,
        email: str | None = None,
        role: str = "user",
    ) -> AuthSession:
        """Register an account and issue a session token.

        The user row, its state row, and the first session are written in one
        transaction; if any step fails nothing is persisted.

        Args:
            username: Requested username; normalized before validation.
            password: Plaintext password; checked against the length policy.
            email: Optional email; blank values are stored as NULL.
            role: Either "user" or "admin".

        Returns:
            A fresh session for the newly created account.

        Raises:
            AuthError: On invalid input, or when the username or email is taken.
        """
        normalized_username = self._normalize_username(username)
        self._validate_username(normalized_username)
        self._validate_password(password)
        normalized_email = self._normalize_email(email)
        if role not in {"user", "admin"}:
            raise AuthError("role must be user or admin.")

        # Key derivation is deliberately slow; do it before taking the
        # connection lock so other database calls are not stalled behind it.
        password_hash = self._hash_password(password)

        with self._conn_lock:
            now_ms = self.now_ms()
            try:
                self._db_execute(
                    """
                    INSERT INTO users (
                        username, password_hash, email, role, status, created_at_ms, updated_at_ms, last_login_at_ms
                    ) VALUES (?, ?, ?, ?, 'active', ?, ?, ?)
                    """,
                    (normalized_username, password_hash, normalized_email, role, now_ms, now_ms, now_ms),
                )
                user = self._get_user_by_username(normalized_username)
                if user is None:
                    raise AuthError("Failed to load newly created user.")
                self._db_execute(
                    """
                    INSERT OR IGNORE INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
                    VALUES (?, ?, NULL, NULL, ?)
                    """,
                    (int(user.id), user.username, now_ms),
                )
                # _create_session commits, which finalizes all three inserts.
                session = self._create_session(self._with_nickname(user, user.username))
            except sqlite3.IntegrityError as exc:
                self._db_rollback()
                message = str(exc).lower()
                if "users.username" in message:
                    raise AuthError("Username is already taken.") from exc
                if "users.email" in message:
                    raise AuthError("Email is already in use.") from exc
                raise
            except Exception:
                self._db_rollback()
                raise
            return session

    def login(self, username: str, password: str) -> AuthSession:
        """Authenticate credentials and issue a fresh session.

        The password is verified before the account status is inspected, so a
        caller who does not know the password always receives the same
        "Invalid username or password." answer whether the account is unknown,
        disabled, or simply given the wrong password.

        Raises:
            AuthError: On bad credentials or a disabled account.
        """
        normalized_username = self._normalize_username(username)
        user_row = self._db_fetchone(
            """
            SELECT
                u.id,
                u.username,
                u.password_hash,
                u.email,
                u.role,
                u.status,
                us.last_nickname,
                us.last_x,
                us.last_y
            FROM users u
            LEFT JOIN user_state us ON us.user_id = u.id
            WHERE u.username = ?
            """,
            (normalized_username,),
        )
        # Password verification runs outside the connection lock because the
        # key derivation is slow and needs no database access.
        if user_row is None:
            # Keep response timing aligned with existing-user password checks.
            self._verify_password(password, DUMMY_PASSWORD_HASH)
            raise AuthError("Invalid username or password.")
        if not self._verify_password(password, user_row["password_hash"]):
            raise AuthError("Invalid username or password.")
        if user_row["status"] != "active":
            raise AuthError("Account is disabled.")

        with self._conn_lock:
            user = self._ensure_default_nickname(self._row_to_user(user_row))
            now_ms = self.now_ms()
            try:
                self._db_execute(
                    "UPDATE users SET last_login_at_ms = ?, updated_at_ms = ? WHERE id = ?",
                    (now_ms, now_ms, user.id),
                )
                # _create_session commits the login timestamp with the session.
                return self._create_session(user)
            except Exception:
                self._db_rollback()
                raise

    def resume(self, token: str) -> AuthSession:
        """Validate a session token and apply rolling expiry.

        A valid session has its expiry pushed forward by SESSION_TTL_MS. An
        expired session is marked revoked before the error is raised.

        Raises:
            AuthError: If the token is blank, unknown, revoked, expired, or
                belongs to a non-active account.
        """
        cleaned = token.strip()
        if not cleaned:
            raise AuthError("Missing session token.")
        token_hash = self._hash_token(cleaned)
        with self._conn_lock:
            row = self._db_fetchone(
                """
                SELECT s.id AS session_id, s.user_id, s.expires_at_ms, s.revoked_at_ms,
                       u.username, u.role, u.status, u.email, us.last_nickname, us.last_x, us.last_y
                FROM sessions s
                JOIN users u ON u.id = s.user_id
                LEFT JOIN user_state us ON us.user_id = u.id
                WHERE s.token_hash = ?
                """,
                (token_hash,),
            )
            if row is None:
                raise AuthError("Invalid session.")
            if row["revoked_at_ms"] is not None:
                raise AuthError("Session has been revoked.")
            now_ms = self.now_ms()
            if int(row["expires_at_ms"]) <= now_ms:
                self._db_execute(
                    "UPDATE sessions SET revoked_at_ms = ? WHERE id = ?",
                    (now_ms, row["session_id"]),
                )
                self._db_commit()
                raise AuthError("Session has expired.")
            if row["status"] != "active":
                raise AuthError("Account is disabled.")
            self._db_execute(
                "UPDATE sessions SET last_seen_at_ms = ?, expires_at_ms = ? WHERE id = ?",
                (now_ms, now_ms + SESSION_TTL_MS, row["session_id"]),
            )
            self._db_commit()
            user = self._ensure_default_nickname(
                AuthUser(
                    id=str(row["user_id"]),
                    username=row["username"],
                    role=row["role"],
                    status=row["status"],
                    email=row["email"],
                    last_nickname=row["last_nickname"],
                    last_x=row["last_x"],
                    last_y=row["last_y"],
                )
            )
            # session_id is a string everywhere else (see _create_session), so
            # convert the integer row id here as well.
            return AuthSession(session_id=str(row["session_id"]), token=cleaned, user=user)

    def revoke(self, token: str) -> None:
        """Revoke a session token if it exists; blank tokens are ignored."""
        cleaned = token.strip()
        if not cleaned:
            return
        token_hash = self._hash_token(cleaned)
        with self._conn_lock:
            self._db_execute(
                "UPDATE sessions SET revoked_at_ms = ? WHERE token_hash = ? AND revoked_at_ms IS NULL",
                (self.now_ms(), token_hash),
            )
            self._db_commit()

    def set_last_nickname(self, user_id: str, nickname: str) -> None:
        """Persist the most recent nickname for one user.

        Best-effort: a blank nickname, a non-numeric user id, or an integrity
        failure on the write is silently ignored.
        """
        cleaned = nickname.strip()
        if not cleaned:
            return
        try:
            user_id_value = int(user_id)
        except (TypeError, ValueError):
            return
        with self._conn_lock:
            try:
                self._db_execute(
                    """
                    INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
                    VALUES (?, ?, NULL, NULL, ?)
                    ON CONFLICT(user_id) DO UPDATE SET
                        last_nickname = excluded.last_nickname,
                        updated_at_ms = excluded.updated_at_ms
                    """,
                    (user_id_value, cleaned, self.now_ms()),
                )
                self._db_commit()
            except sqlite3.IntegrityError:
                self._db_rollback()

    def set_last_position(self, user_id: str, x: int, y: int) -> None:
        """Persist last known world position for one user.

        Best-effort: a non-numeric user id or an integrity failure on the
        write is silently ignored.
        """
        try:
            user_id_value = int(user_id)
        except (TypeError, ValueError):
            return
        with self._conn_lock:
            try:
                self._db_execute(
                    """
                    INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
                    VALUES (?, NULL, ?, ?, ?)
                    ON CONFLICT(user_id) DO UPDATE SET
                        last_x = excluded.last_x,
                        last_y = excluded.last_y,
                        updated_at_ms = excluded.updated_at_ms
                    """,
                    (user_id_value, int(x), int(y), self.now_ms()),
                )
                self._db_commit()
            except sqlite3.IntegrityError:
                self._db_rollback()

    @staticmethod
    def now_ms() -> int:
        """Return unix epoch timestamp in milliseconds."""
        return int(time.time() * 1000)

    def _ensure_schema(self) -> None:
        """Create required auth tables and indexes when missing."""
        with self._conn_lock:
            self._db_execute("PRAGMA foreign_keys = ON")
            self._db_execute(
                """
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    username TEXT NOT NULL UNIQUE,
                    password_hash TEXT NOT NULL,
                    email TEXT UNIQUE,
                    role TEXT NOT NULL CHECK(role IN ('user', 'admin')) DEFAULT 'user',
                    status TEXT NOT NULL CHECK(status IN ('active', 'disabled')) DEFAULT 'active',
                    created_at_ms INTEGER NOT NULL,
                    updated_at_ms INTEGER NOT NULL,
                    last_login_at_ms INTEGER
                )
                """
            )
            self._db_execute(
                """
                CREATE TABLE IF NOT EXISTS sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    token_hash TEXT NOT NULL UNIQUE,
                    created_at_ms INTEGER NOT NULL,
                    last_seen_at_ms INTEGER NOT NULL,
                    expires_at_ms INTEGER NOT NULL,
                    revoked_at_ms INTEGER,
                    ip TEXT,
                    user_agent TEXT,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                )
                """
            )
            self._db_execute(
                """
                CREATE TABLE IF NOT EXISTS user_state (
                    user_id INTEGER PRIMARY KEY,
                    last_nickname TEXT,
                    last_x INTEGER,
                    last_y INTEGER,
                    updated_at_ms INTEGER NOT NULL,
                    FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
                )
                """
            )
            self._db_execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)")
            self._db_execute(
                "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email) WHERE email IS NOT NULL"
            )
            self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)")
            self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at_ms)")
            self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash)")
            self._db_execute("CREATE INDEX IF NOT EXISTS idx_user_state_updated ON user_state(updated_at_ms)")
            self._db_commit()

    def _create_session(self, user: AuthUser) -> AuthSession:
        """Issue and persist a new session token for a user, then commit.

        Only the keyed hash of the token is stored; the plaintext token exists
        solely in the returned AuthSession.

        Raises:
            AuthError: If the new session row id cannot be determined.
        """
        token = secrets.token_urlsafe(48)
        token_hash = self._hash_token(token)
        now_ms = self.now_ms()
        expires_at_ms = now_ms + SESSION_TTL_MS
        with self._conn_lock:
            cursor = self._db_execute(
                """
                INSERT INTO sessions (user_id, token_hash, created_at_ms, last_seen_at_ms, expires_at_ms, revoked_at_ms, ip, user_agent)
                VALUES (?, ?, ?, ?, ?, NULL, NULL, NULL)
                """,
                (user.id, token_hash, now_ms, now_ms, expires_at_ms),
            )
            new_row_id = cursor.lastrowid
            if new_row_id is None:
                raise AuthError("Failed to create session.")
            self._db_commit()
        return AuthSession(session_id=str(new_row_id), token=token, user=user)

    def _get_user_by_username(self, username: str) -> AuthUser | None:
        """Fetch one user by normalized username, or None when absent."""
        row = self._db_fetchone(
            """
            SELECT
                u.id,
                u.username,
                u.role,
                u.status,
                u.email,
                us.last_nickname,
                us.last_x,
                us.last_y
            FROM users u
            LEFT JOIN user_state us ON us.user_id = u.id
            WHERE u.username = ?
            """,
            (username,),
        )
        if row is None:
            return None
        return self._row_to_user(row)

    def _ensure_default_nickname(self, user: AuthUser) -> AuthUser:
        """Return *user*, defaulting a missing nickname to the username.

        When the nickname is missing the default is also persisted through
        set_last_nickname.
        """
        if user.last_nickname:
            return user
        self.set_last_nickname(user.id, user.username)
        return self._with_nickname(user, user.username)

    @staticmethod
    def _with_nickname(user: AuthUser, nickname: str) -> AuthUser:
        """Return a copy of *user* whose last_nickname is *nickname*."""
        return AuthUser(
            id=user.id,
            username=user.username,
            role=user.role,
            status=user.status,
            email=user.email,
            last_nickname=nickname,
            last_x=user.last_x,
            last_y=user.last_y,
        )

    def _db_execute(self, sql: str, params: tuple | None = None) -> sqlite3.Cursor:
        """Run one SQL statement with a thread-safe connection lock."""
        with self._conn_lock:
            return self._conn.execute(sql, params or ())

    def _db_fetchone(self, sql: str, params: tuple | None = None) -> sqlite3.Row | None:
        """Run one query and fetch a single row with connection locking."""
        with self._conn_lock:
            return self._conn.execute(sql, params or ()).fetchone()

    def _db_commit(self) -> None:
        """Commit pending DB writes with connection locking."""
        with self._conn_lock:
            self._conn.commit()

    def _db_rollback(self) -> None:
        """Rollback pending DB writes with connection locking."""
        with self._conn_lock:
            self._conn.rollback()

    @staticmethod
    def _row_to_user(row: sqlite3.Row) -> AuthUser:
        """Convert a DB row into AuthUser.

        The state columns are optional: a row from a query that did not select
        them yields None for the corresponding fields.
        """
        columns = row.keys()
        return AuthUser(
            id=str(row["id"]),
            username=row["username"],
            role=row["role"],
            status=row["status"],
            email=row["email"],
            last_nickname=row["last_nickname"] if "last_nickname" in columns else None,
            last_x=row["last_x"] if "last_x" in columns else None,
            last_y=row["last_y"] if "last_y" in columns else None,
        )

    @staticmethod
    def _normalize_username(username: str) -> str:
        """Normalize username into canonical stored form (trimmed, lowercase)."""
        return username.strip().lower()

    @staticmethod
    def _normalize_email(email: str | None) -> str | None:
        """Normalize optional email and collapse blanks to None."""
        if email is None:
            return None
        cleaned = email.strip().lower()
        return cleaned or None

    def _validate_username(self, username: str) -> None:
        """Validate username against length and character policy.

        Raises:
            AuthError: If the length is out of range or a character is not allowed.
        """
        if not (self.username_min_length <= len(username) <= self.username_max_length):
            raise AuthError(
                f"Username must be between {self.username_min_length} and {self.username_max_length} characters."
            )
        if USERNAME_PATTERN.fullmatch(username) is None:
            raise AuthError("Username may include lowercase letters, numbers, underscores, and dashes only.")

    def _validate_password(self, password: str) -> None:
        """Validate password length policy and UTF-8 encodability.

        Raises:
            AuthError: If the length is out of range, or the password cannot be
                encoded as UTF-8 (for example, it contains a lone surrogate).
        """
        if not (self.password_min_length <= len(password) <= self.password_max_length):
            raise AuthError(
                f"Password must be between {self.password_min_length} and {self.password_max_length} characters."
            )
        try:
            # Hashing encodes the password as UTF-8; reject input that cannot
            # be encoded here so callers get an AuthError, not a codec error.
            password.encode("utf-8")
        except UnicodeEncodeError as exc:
            raise AuthError("Password contains invalid characters.") from exc

    @staticmethod
    def _hash_password(password: str) -> str:
        """Hash a password with PBKDF2-HMAC-SHA256 and random salt.

        Returns:
            "pbkdf2_sha256$<iterations>$<salt b64>$<digest b64>".
        """
        salt = os.urandom(SALT_BYTES)
        digest = hashlib.pbkdf2_hmac(
            "sha256",
            password.encode("utf-8"),
            salt,
            PBKDF2_ITERATIONS,
            dklen=PBKDF2_DKLEN,
        )
        salt_b64 = base64.b64encode(salt).decode("ascii")
        digest_b64 = base64.b64encode(digest).decode("ascii")
        return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt_b64}${digest_b64}"

    @staticmethod
    def _verify_password(password: str, stored: str) -> bool:
        """Verify plaintext password against stored PBKDF2 hash.

        Returns False, rather than raising, when the stored record is malformed
        or names a different algorithm.
        """
        try:
            algo, iterations_raw, salt_b64, digest_b64 = stored.split("$", 3)
        except ValueError:
            return False
        if algo != "pbkdf2_sha256":
            return False
        try:
            salt = base64.b64decode(salt_b64.encode("ascii"))
            expected = base64.b64decode(digest_b64.encode("ascii"))
            computed = hashlib.pbkdf2_hmac(
                "sha256",
                password.encode("utf-8"),
                salt,
                int(iterations_raw),
                dklen=len(expected),
            )
        except (ValueError, TypeError, OverflowError):
            # Covers bad base64, non-numeric or out-of-range iteration counts,
            # an empty digest, and passwords that cannot be UTF-8 encoded.
            return False
        # Constant-time comparison avoids leaking how many bytes matched.
        return hmac.compare_digest(computed, expected)

    def _hash_token(self, token: str) -> str:
        """Hash a session token with server secret before persistence.

        Tokens arrive from clients, so lone surrogates are encoded with
        "surrogatepass" instead of raising; such a token simply hashes to a
        value that matches no stored session. Well-formed strings encode
        exactly as plain UTF-8, so existing hashes are unaffected.
        """
        token_bytes = token.encode("utf-8", "surrogatepass")
        return hmac.new(self._token_secret, token_bytes, hashlib.sha256).hexdigest()