Harden auth flow against timing and event-loop blocking

This commit is contained in:
Jage9
2026-02-25 00:17:05 -05:00
parent 54232acd87
commit e7d3b41782
5 changed files with 460 additions and 171 deletions

View File

@@ -87,6 +87,10 @@ This is a behavior guide for packet semantics beyond raw schemas.
- Server is authoritative for all action validation and normalization. - Server is authoritative for all action validation and normalization.
- Server is authoritative for movement acceptance (bounds + rate/delta checks). - Server is authoritative for movement acceptance (bounds + rate/delta checks).
- Server persists account state (last nickname + last position) and restores spawn from that state on auth login/resume. - Server persists account state (last nickname + last position) and restores spawn from that state on auth login/resume.
- Server applies auth hardening before accepting login/register/resume:
- login/register PBKDF2 work runs off the event loop in bounded worker concurrency
- repeated auth failures are rate-limited by IP and IP+identity windows
- auth failures include small randomized response jitter to reduce high-resolution probing
- Client validates incoming packet shapes and applies runtime behavior. - Client validates incoming packet shapes and applies runtime behavior.
- Sound/media field normalization uses shared server policy helpers: - Sound/media field normalization uses shared server policy helpers:
- `none/off` normalize to empty values - `none/off` normalize to empty values

View File

@@ -11,6 +11,7 @@ from pathlib import Path
import re import re
import secrets import secrets
import sqlite3 import sqlite3
import threading
import time import time
@@ -21,6 +22,25 @@ PBKDF2_DKLEN = 32
USERNAME_PATTERN = re.compile(r"^[a-z0-9_-]+$") USERNAME_PATTERN = re.compile(r"^[a-z0-9_-]+$")
def _build_dummy_password_hash() -> str:
"""Build one deterministic PBKDF2 hash used to equalize login miss timing."""
salt = b"chgrid_dummy_salt"
digest = hashlib.pbkdf2_hmac(
"sha256",
b"chgrid_dummy_password",
salt,
PBKDF2_ITERATIONS,
dklen=PBKDF2_DKLEN,
)
salt_b64 = base64.b64encode(salt).decode("ascii")
digest_b64 = base64.b64encode(digest).decode("ascii")
return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt_b64}${digest_b64}"
DUMMY_PASSWORD_HASH = _build_dummy_password_hash()
@dataclass(frozen=True) @dataclass(frozen=True)
class AuthUser: class AuthUser:
"""Authenticated account identity details.""" """Authenticated account identity details."""
@@ -74,12 +94,14 @@ class AuthService:
self._token_secret = secret.encode("utf-8") self._token_secret = secret.encode("utf-8")
self._conn = sqlite3.connect(self.db_path, check_same_thread=False) self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
self._conn.row_factory = sqlite3.Row self._conn.row_factory = sqlite3.Row
self._conn_lock = threading.RLock()
self._ensure_schema() self._ensure_schema()
def close(self) -> None: def close(self) -> None:
"""Close the underlying SQLite connection.""" """Close the underlying SQLite connection."""
self._conn.close() with self._conn_lock:
self._conn.close()
def bootstrap_admin(self, username: str, password: str, email: str | None = None) -> AuthUser: def bootstrap_admin(self, username: str, password: str, email: str | None = None) -> AuthUser:
"""Create the first admin account, or fail if one already exists.""" """Create the first admin account, or fail if one already exists."""
@@ -92,7 +114,7 @@ class AuthService:
def has_admin(self) -> bool: def has_admin(self) -> bool:
"""Return True when at least one admin account exists.""" """Return True when at least one admin account exists."""
existing = self._conn.execute("SELECT 1 FROM users WHERE role = 'admin' LIMIT 1").fetchone() existing = self._db_fetchone("SELECT 1 FROM users WHERE role = 'admin' LIMIT 1")
return existing is not None return existing is not None
def register( def register(
@@ -105,160 +127,166 @@ class AuthService:
) -> AuthSession: ) -> AuthSession:
"""Register an account and issue a session token.""" """Register an account and issue a session token."""
normalized_username = self._normalize_username(username) with self._conn_lock:
self._validate_username(normalized_username) normalized_username = self._normalize_username(username)
self._validate_password(password) self._validate_username(normalized_username)
normalized_email = self._normalize_email(email) self._validate_password(password)
if role not in {"user", "admin"}: normalized_email = self._normalize_email(email)
raise AuthError("role must be user or admin.") if role not in {"user", "admin"}:
now_ms = self.now_ms() raise AuthError("role must be user or admin.")
password_hash = self._hash_password(password) now_ms = self.now_ms()
try: password_hash = self._hash_password(password)
self._conn.execute( try:
self._db_execute(
"""
INSERT INTO users (
username, password_hash, email, role, status, created_at_ms, updated_at_ms, last_login_at_ms
) VALUES (?, ?, ?, ?, 'active', ?, ?, ?)
""",
(normalized_username, password_hash, normalized_email, role, now_ms, now_ms, now_ms),
)
self._db_commit()
except sqlite3.IntegrityError as exc:
message = str(exc).lower()
if "users.username" in message:
raise AuthError("Username is already taken.") from exc
if "users.email" in message:
raise AuthError("Email is already in use.") from exc
raise
user = self._get_user_by_username(normalized_username)
if user is None:
raise AuthError("Failed to load newly created user.")
self._db_execute(
""" """
INSERT INTO users ( INSERT OR IGNORE INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
username, password_hash, email, role, status, created_at_ms, updated_at_ms, last_login_at_ms VALUES (?, ?, NULL, NULL, ?)
) VALUES (?, ?, ?, ?, 'active', ?, ?, ?)
""", """,
(normalized_username, password_hash, normalized_email, role, now_ms, now_ms, now_ms), (int(user.id), user.username, now_ms),
) )
self._conn.commit() self._db_commit()
except sqlite3.IntegrityError as exc: user = AuthUser(
message = str(exc).lower() id=user.id,
if "users.username" in message: username=user.username,
raise AuthError("Username is already taken.") from exc role=user.role,
if "users.email" in message: status=user.status,
raise AuthError("Email is already in use.") from exc email=user.email,
raise last_nickname=user.username,
user = self._get_user_by_username(normalized_username) last_x=user.last_x,
if user is None: last_y=user.last_y,
raise AuthError("Failed to load newly created user.") )
self._conn.execute( return self._create_session(user)
"""
INSERT OR IGNORE INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
VALUES (?, ?, NULL, NULL, ?)
""",
(int(user.id), user.username, now_ms),
)
self._conn.commit()
user = AuthUser(
id=user.id,
username=user.username,
role=user.role,
status=user.status,
email=user.email,
last_nickname=user.username,
last_x=user.last_x,
last_y=user.last_y,
)
return self._create_session(user)
def login(self, username: str, password: str) -> AuthSession: def login(self, username: str, password: str) -> AuthSession:
"""Authenticate credentials and issue a fresh session.""" """Authenticate credentials and issue a fresh session."""
normalized_username = self._normalize_username(username) with self._conn_lock:
user_row = self._conn.execute( normalized_username = self._normalize_username(username)
""" user_row = self._db_fetchone(
SELECT """
u.id, SELECT
u.username, u.id,
u.password_hash, u.username,
u.email, u.password_hash,
u.role, u.email,
u.status, u.role,
us.last_nickname, u.status,
us.last_x, us.last_nickname,
us.last_y us.last_x,
FROM users u us.last_y
LEFT JOIN user_state us ON us.user_id = u.id FROM users u
WHERE u.username = ? LEFT JOIN user_state us ON us.user_id = u.id
""", WHERE u.username = ?
(normalized_username,), """,
).fetchone() (normalized_username,),
if user_row is None:
raise AuthError("Invalid username or password.")
if user_row["status"] != "active":
raise AuthError("Account is disabled.")
if not self._verify_password(password, user_row["password_hash"]):
raise AuthError("Invalid username or password.")
user = self._row_to_user(user_row)
if not user.last_nickname:
self.set_last_nickname(user.id, user.username)
user = AuthUser(
id=user.id,
username=user.username,
role=user.role,
status=user.status,
email=user.email,
last_nickname=user.username,
last_x=user.last_x,
last_y=user.last_y,
) )
self._conn.execute( if user_row is None:
"UPDATE users SET last_login_at_ms = ?, updated_at_ms = ? WHERE id = ?", # Keep response timing aligned with existing-user password checks.
(self.now_ms(), self.now_ms(), user.id), self._verify_password(password, DUMMY_PASSWORD_HASH)
) raise AuthError("Invalid username or password.")
self._conn.commit() if user_row["status"] != "active":
return self._create_session(user) raise AuthError("Account is disabled.")
if not self._verify_password(password, user_row["password_hash"]):
raise AuthError("Invalid username or password.")
user = self._row_to_user(user_row)
if not user.last_nickname:
self.set_last_nickname(user.id, user.username)
user = AuthUser(
id=user.id,
username=user.username,
role=user.role,
status=user.status,
email=user.email,
last_nickname=user.username,
last_x=user.last_x,
last_y=user.last_y,
)
now_ms = self.now_ms()
self._db_execute(
"UPDATE users SET last_login_at_ms = ?, updated_at_ms = ? WHERE id = ?",
(now_ms, now_ms, user.id),
)
self._db_commit()
return self._create_session(user)
def resume(self, token: str) -> AuthSession: def resume(self, token: str) -> AuthSession:
"""Validate a session token and apply rolling expiry.""" """Validate a session token and apply rolling expiry."""
cleaned = token.strip() with self._conn_lock:
if not cleaned: cleaned = token.strip()
raise AuthError("Missing session token.") if not cleaned:
token_hash = self._hash_token(cleaned) raise AuthError("Missing session token.")
row = self._conn.execute( token_hash = self._hash_token(cleaned)
""" row = self._db_fetchone(
SELECT s.id AS session_id, s.user_id, s.expires_at_ms, s.revoked_at_ms, """
u.username, u.role, u.status, u.email, us.last_nickname, us.last_x, us.last_y SELECT s.id AS session_id, s.user_id, s.expires_at_ms, s.revoked_at_ms,
FROM sessions s u.username, u.role, u.status, u.email, us.last_nickname, us.last_x, us.last_y
JOIN users u ON u.id = s.user_id FROM sessions s
LEFT JOIN user_state us ON us.user_id = u.id JOIN users u ON u.id = s.user_id
WHERE s.token_hash = ? LEFT JOIN user_state us ON us.user_id = u.id
""", WHERE s.token_hash = ?
(token_hash,), """,
).fetchone() (token_hash,),
if row is None:
raise AuthError("Invalid session.")
if row["revoked_at_ms"] is not None:
raise AuthError("Session has been revoked.")
now_ms = self.now_ms()
if int(row["expires_at_ms"]) <= now_ms:
self._conn.execute("UPDATE sessions SET revoked_at_ms = ? WHERE id = ?", (now_ms, row["session_id"]))
self._conn.commit()
raise AuthError("Session has expired.")
if row["status"] != "active":
raise AuthError("Account is disabled.")
new_expiry = now_ms + SESSION_TTL_MS
self._conn.execute(
"UPDATE sessions SET last_seen_at_ms = ?, expires_at_ms = ? WHERE id = ?",
(now_ms, new_expiry, row["session_id"]),
)
self._conn.commit()
user = AuthUser(
id=str(row["user_id"]),
username=row["username"],
role=row["role"],
status=row["status"],
email=row["email"],
last_nickname=row["last_nickname"],
last_x=row["last_x"] if "last_x" in row.keys() else None,
last_y=row["last_y"] if "last_y" in row.keys() else None,
)
if not user.last_nickname:
self.set_last_nickname(user.id, user.username)
user = AuthUser(
id=user.id,
username=user.username,
role=user.role,
status=user.status,
email=user.email,
last_nickname=user.username,
last_x=user.last_x,
last_y=user.last_y,
) )
return AuthSession(session_id=row["session_id"], token=cleaned, user=user) if row is None:
raise AuthError("Invalid session.")
if row["revoked_at_ms"] is not None:
raise AuthError("Session has been revoked.")
now_ms = self.now_ms()
if int(row["expires_at_ms"]) <= now_ms:
self._db_execute("UPDATE sessions SET revoked_at_ms = ? WHERE id = ?", (now_ms, row["session_id"]))
self._db_commit()
raise AuthError("Session has expired.")
if row["status"] != "active":
raise AuthError("Account is disabled.")
new_expiry = now_ms + SESSION_TTL_MS
self._db_execute(
"UPDATE sessions SET last_seen_at_ms = ?, expires_at_ms = ? WHERE id = ?",
(now_ms, new_expiry, row["session_id"]),
)
self._db_commit()
user = AuthUser(
id=str(row["user_id"]),
username=row["username"],
role=row["role"],
status=row["status"],
email=row["email"],
last_nickname=row["last_nickname"],
last_x=row["last_x"] if "last_x" in row.keys() else None,
last_y=row["last_y"] if "last_y" in row.keys() else None,
)
if not user.last_nickname:
self.set_last_nickname(user.id, user.username)
user = AuthUser(
id=user.id,
username=user.username,
role=user.role,
status=user.status,
email=user.email,
last_nickname=user.username,
last_x=user.last_x,
last_y=user.last_y,
)
return AuthSession(session_id=row["session_id"], token=cleaned, user=user)
def revoke(self, token: str) -> None: def revoke(self, token: str) -> None:
"""Revoke a session token if it exists.""" """Revoke a session token if it exists."""
@@ -267,11 +295,11 @@ class AuthService:
if not cleaned: if not cleaned:
return return
token_hash = self._hash_token(cleaned) token_hash = self._hash_token(cleaned)
self._conn.execute( self._db_execute(
"UPDATE sessions SET revoked_at_ms = ? WHERE token_hash = ? AND revoked_at_ms IS NULL", "UPDATE sessions SET revoked_at_ms = ? WHERE token_hash = ? AND revoked_at_ms IS NULL",
(self.now_ms(), token_hash), (self.now_ms(), token_hash),
) )
self._conn.commit() self._db_commit()
def set_last_nickname(self, user_id: str, nickname: str) -> None: def set_last_nickname(self, user_id: str, nickname: str) -> None:
"""Persist the most recent nickname for one user.""" """Persist the most recent nickname for one user."""
@@ -284,7 +312,7 @@ class AuthService:
except (TypeError, ValueError): except (TypeError, ValueError):
return return
try: try:
self._conn.execute( self._db_execute(
""" """
INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms) INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
VALUES (?, ?, NULL, NULL, ?) VALUES (?, ?, NULL, NULL, ?)
@@ -294,9 +322,9 @@ class AuthService:
""", """,
(user_id_value, cleaned, self.now_ms()), (user_id_value, cleaned, self.now_ms()),
) )
self._conn.commit() self._db_commit()
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
self._conn.rollback() self._db_rollback()
def set_last_position(self, user_id: str, x: int, y: int) -> None: def set_last_position(self, user_id: str, x: int, y: int) -> None:
"""Persist last known world position for one user.""" """Persist last known world position for one user."""
@@ -306,7 +334,7 @@ class AuthService:
except (TypeError, ValueError): except (TypeError, ValueError):
return return
try: try:
self._conn.execute( self._db_execute(
""" """
INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms) INSERT INTO user_state (user_id, last_nickname, last_x, last_y, updated_at_ms)
VALUES (?, NULL, ?, ?, ?) VALUES (?, NULL, ?, ?, ?)
@@ -317,9 +345,9 @@ class AuthService:
""", """,
(user_id_value, int(x), int(y), self.now_ms()), (user_id_value, int(x), int(y), self.now_ms()),
) )
self._conn.commit() self._db_commit()
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
self._conn.rollback() self._db_rollback()
@staticmethod @staticmethod
def now_ms() -> int: def now_ms() -> int:
@@ -330,8 +358,9 @@ class AuthService:
def _ensure_schema(self) -> None: def _ensure_schema(self) -> None:
"""Create required auth tables and indexes when missing.""" """Create required auth tables and indexes when missing."""
self._conn.execute("PRAGMA foreign_keys = ON") with self._conn_lock:
self._conn.execute( self._db_execute("PRAGMA foreign_keys = ON")
self._db_execute(
""" """
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -346,7 +375,7 @@ class AuthService:
) )
""" """
) )
self._conn.execute( self._db_execute(
""" """
CREATE TABLE IF NOT EXISTS sessions ( CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -362,7 +391,7 @@ class AuthService:
) )
""" """
) )
self._conn.execute( self._db_execute(
""" """
CREATE TABLE IF NOT EXISTS user_state ( CREATE TABLE IF NOT EXISTS user_state (
user_id INTEGER PRIMARY KEY, user_id INTEGER PRIMARY KEY,
@@ -374,15 +403,15 @@ class AuthService:
) )
""" """
) )
self._conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)") self._db_execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)")
self._conn.execute( self._db_execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email) WHERE email IS NOT NULL" "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email) WHERE email IS NOT NULL"
) )
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)") self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at_ms)") self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at_ms)")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash)") self._db_execute("CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash)")
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_user_state_updated ON user_state(updated_at_ms)") self._db_execute("CREATE INDEX IF NOT EXISTS idx_user_state_updated ON user_state(updated_at_ms)")
self._conn.commit() self._db_commit()
def _create_session(self, user: AuthUser) -> AuthSession: def _create_session(self, user: AuthUser) -> AuthSession:
"""Issue and persist a new session token for a user.""" """Issue and persist a new session token for a user."""
@@ -391,21 +420,24 @@ class AuthService:
token_hash = self._hash_token(token) token_hash = self._hash_token(token)
now_ms = self.now_ms() now_ms = self.now_ms()
expires_at_ms = now_ms + SESSION_TTL_MS expires_at_ms = now_ms + SESSION_TTL_MS
self._conn.execute( self._db_execute(
""" """
INSERT INTO sessions (user_id, token_hash, created_at_ms, last_seen_at_ms, expires_at_ms, revoked_at_ms, ip, user_agent) INSERT INTO sessions (user_id, token_hash, created_at_ms, last_seen_at_ms, expires_at_ms, revoked_at_ms, ip, user_agent)
VALUES (?, ?, ?, ?, ?, NULL, NULL, NULL) VALUES (?, ?, ?, ?, ?, NULL, NULL, NULL)
""", """,
(user.id, token_hash, now_ms, now_ms, expires_at_ms), (user.id, token_hash, now_ms, now_ms, expires_at_ms),
) )
session_id = str(self._conn.execute("SELECT last_insert_rowid() AS id").fetchone()["id"]) row = self._db_fetchone("SELECT last_insert_rowid() AS id")
self._conn.commit() if row is None:
raise AuthError("Failed to create session.")
session_id = str(row["id"])
self._db_commit()
return AuthSession(session_id=session_id, token=token, user=user) return AuthSession(session_id=session_id, token=token, user=user)
def _get_user_by_username(self, username: str) -> AuthUser | None: def _get_user_by_username(self, username: str) -> AuthUser | None:
"""Fetch one user by normalized username.""" """Fetch one user by normalized username."""
row = self._conn.execute( row = self._db_fetchone(
""" """
SELECT SELECT
u.id, u.id,
@@ -421,11 +453,35 @@ class AuthService:
WHERE u.username = ? WHERE u.username = ?
""", """,
(username,), (username,),
).fetchone() )
if row is None: if row is None:
return None return None
return self._row_to_user(row) return self._row_to_user(row)
def _db_execute(self, sql: str, params: tuple | None = None) -> sqlite3.Cursor:
"""Run one SQL statement with a thread-safe connection lock."""
with self._conn_lock:
return self._conn.execute(sql, params or ())
def _db_fetchone(self, sql: str, params: tuple | None = None) -> sqlite3.Row | None:
"""Run one query and fetch a single row with connection locking."""
with self._conn_lock:
return self._conn.execute(sql, params or ()).fetchone()
def _db_commit(self) -> None:
"""Commit pending DB writes with connection locking."""
with self._conn_lock:
self._conn.commit()
def _db_rollback(self) -> None:
"""Rollback pending DB writes with connection locking."""
with self._conn_lock:
self._conn.rollback()
@staticmethod @staticmethod
def _row_to_user(row: sqlite3.Row) -> AuthUser: def _row_to_user(row: sqlite3.Row) -> AuthUser:
"""Convert a DB row into AuthUser.""" """Convert a DB row into AuthUser."""

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
from collections import deque
from datetime import datetime from datetime import datetime
from getpass import getpass from getpass import getpass
from importlib.metadata import PackageNotFoundError, version as package_version from importlib.metadata import PackageNotFoundError, version as package_version
@@ -89,6 +90,12 @@ PIANO_RECORDING_MAX_EVENTS = 4096
MOVEMENT_TICK_MS = 200 MOVEMENT_TICK_MS = 200
MOVEMENT_MAX_STEPS_PER_TICK = 1 MOVEMENT_MAX_STEPS_PER_TICK = 1
POSITION_PERSIST_DEBOUNCE_MS = 5_000 POSITION_PERSIST_DEBOUNCE_MS = 5_000
AUTH_HASH_MAX_CONCURRENCY = 8
AUTH_RATE_LIMIT_WINDOW_S = 30.0
AUTH_RATE_LIMIT_PER_IP = 20
AUTH_RATE_LIMIT_PER_IDENTITY = 8
AUTH_FAILURE_JITTER_MIN_MS = 0.02
AUTH_FAILURE_JITTER_MAX_MS = 0.08
class SignalingServer: class SignalingServer:
@@ -143,6 +150,9 @@ class SignalingServer:
self._pending_state_save_handle: asyncio.TimerHandle | None = None self._pending_state_save_handle: asyncio.TimerHandle | None = None
self._pending_state_save_started_at: float | None = None self._pending_state_save_started_at: float | None = None
self._last_position_persist_ms_by_user: dict[str, int] = {} self._last_position_persist_ms_by_user: dict[str, int] = {}
self._auth_hash_semaphore = asyncio.Semaphore(AUTH_HASH_MAX_CONCURRENCY)
self._auth_failures_by_ip: dict[str, deque[float]] = {}
self._auth_failures_by_identity: dict[str, deque[float]] = {}
@staticmethod @staticmethod
def _resolve_server_version() -> str: def _resolve_server_version() -> str:
@@ -245,6 +255,81 @@ class SignalingServer:
return "radio" if item.type == "radio_station" else item.type return "radio" if item.type == "radio_station" else item.type
@staticmethod
def _client_ip(client: ClientConnection) -> str:
"""Extract best-effort remote IP string for audit logs and auth throttling."""
address = getattr(client.websocket, "remote_address", None)
if isinstance(address, tuple) and address:
return str(address[0])
if isinstance(address, str):
return address
return "unknown"
@staticmethod
def _prune_failure_window(bucket: deque[float], now_s: float) -> None:
"""Drop expired auth-failure timestamps outside the active limit window."""
threshold = now_s - AUTH_RATE_LIMIT_WINDOW_S
while bucket and bucket[0] < threshold:
bucket.popleft()
def _auth_identity_key(self, client: ClientConnection, packet: ClientPacket) -> str:
"""Build username/IP scoped key used for auth failure throttling."""
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket)):
username = packet.username.strip().lower()
elif isinstance(packet, AuthResumePacket):
username = "resume"
else:
username = "unknown"
return f"{self._client_ip(client)}::{username}"
def _is_auth_rate_limited(self, client: ClientConnection, packet: ClientPacket) -> bool:
"""Return True when recent auth failures exceed IP or identity thresholds."""
now_s = time.monotonic()
ip_key = self._client_ip(client)
identity_key = self._auth_identity_key(client, packet)
ip_bucket = self._auth_failures_by_ip.setdefault(ip_key, deque())
identity_bucket = self._auth_failures_by_identity.setdefault(identity_key, deque())
self._prune_failure_window(ip_bucket, now_s)
self._prune_failure_window(identity_bucket, now_s)
return len(ip_bucket) >= AUTH_RATE_LIMIT_PER_IP or len(identity_bucket) >= AUTH_RATE_LIMIT_PER_IDENTITY
def _record_auth_failure(self, client: ClientConnection, packet: ClientPacket) -> None:
"""Record a failed auth attempt for IP and identity-scoped throttling."""
now_s = time.monotonic()
ip_key = self._client_ip(client)
identity_key = self._auth_identity_key(client, packet)
self._auth_failures_by_ip.setdefault(ip_key, deque()).append(now_s)
self._auth_failures_by_identity.setdefault(identity_key, deque()).append(now_s)
def _clear_auth_failures(self, client: ClientConnection, packet: ClientPacket) -> None:
"""Clear identity-scoped auth failures after a successful authentication."""
now_s = time.monotonic()
identity_key = self._auth_identity_key(client, packet)
bucket = self._auth_failures_by_identity.get(identity_key)
if not bucket:
return
bucket.clear()
self._prune_failure_window(bucket, now_s)
async def _sleep_auth_failure_jitter(self) -> None:
"""Apply small randomized delay to reduce high-resolution auth timing probes."""
await asyncio.sleep(random.uniform(AUTH_FAILURE_JITTER_MIN_MS, AUTH_FAILURE_JITTER_MAX_MS))
async def _run_auth_hash_task(self, func, /, *args, **kwargs):
"""Run auth service call in a worker thread behind bounded hash concurrency."""
async with self._auth_hash_semaphore:
return await asyncio.to_thread(func, *args, **kwargs)
@staticmethod @staticmethod
def _resolve_item_use_sound(item: WorldItem) -> str | None: def _resolve_item_use_sound(item: WorldItem) -> str | None:
"""Resolve one-shot use sound, preferring per-item param override.""" """Resolve one-shot use sound, preferring per-item param override."""
@@ -892,17 +977,65 @@ class SignalingServer:
) )
return True return True
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)) and self._is_auth_rate_limited(
client, packet
):
LOGGER.warning(
"auth rate limited id=%s ip=%s packet=%s",
client.id,
self._client_ip(client),
packet.type,
)
await self._sleep_auth_failure_jitter()
await self._send(
client.websocket,
AuthResultPacket(
type="auth_result",
ok=False,
message="Too many authentication attempts. Try again shortly.",
authPolicy=self._auth_policy(),
),
)
return True
try: try:
if isinstance(packet, AuthRegisterPacket): if isinstance(packet, AuthRegisterPacket):
session = self.auth_service.register(packet.username, packet.password, email=packet.email) session = await self._run_auth_hash_task(
self.auth_service.register,
packet.username,
packet.password,
email=packet.email,
)
LOGGER.info(
"auth register success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthLoginPacket): elif isinstance(packet, AuthLoginPacket):
session = self.auth_service.login(packet.username, packet.password) session = await self._run_auth_hash_task(self.auth_service.login, packet.username, packet.password)
LOGGER.info(
"auth login success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthResumePacket): elif isinstance(packet, AuthResumePacket):
session = self.auth_service.resume(packet.sessionToken) session = self.auth_service.resume(packet.sessionToken)
LOGGER.info(
"auth resume success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthLogoutPacket): elif isinstance(packet, AuthLogoutPacket):
if client.session_token: if client.session_token:
self.auth_service.revoke(client.session_token) self.auth_service.revoke(client.session_token)
client.session_token = None client.session_token = None
LOGGER.info("auth logout id=%s ip=%s username=%s", client.id, self._client_ip(client), client.username)
await self._send( await self._send(
client.websocket, client.websocket,
AuthResultPacket( AuthResultPacket(
@@ -917,6 +1050,16 @@ class SignalingServer:
else: else:
return False return False
except AuthError as exc: except AuthError as exc:
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
self._record_auth_failure(client, packet)
await self._sleep_auth_failure_jitter()
LOGGER.warning(
"auth failure id=%s ip=%s packet=%s reason=%s",
client.id,
self._client_ip(client),
packet.type,
str(exc),
)
await self._send( await self._send(
client.websocket, client.websocket,
AuthResultPacket( AuthResultPacket(
@@ -928,6 +1071,9 @@ class SignalingServer:
) )
return True return True
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
self._clear_auth_failures(client, packet)
client.authenticated = True client.authenticated = True
client.user_id = session.user.id client.user_id = session.user.id
client.username = session.user.username client.username = session.user.username

View File

@@ -50,3 +50,20 @@ def test_bootstrap_admin_once(tmp_path: Path) -> None:
finally: finally:
service.close() service.close()
def test_login_missing_user_runs_dummy_verify(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
service = make_auth_service(tmp_path)
try:
calls: list[tuple[str, str]] = []
def fake_verify(password: str, stored: str) -> bool:
calls.append((password, stored))
return False
monkeypatch.setattr(service, "_verify_password", fake_verify)
with pytest.raises(AuthError):
service.login("missing_user", "password99")
assert len(calls) == 1
assert calls[0][0] == "password99"
finally:
service.close()

View File

@@ -16,6 +16,10 @@ def _fake_ws() -> ServerConnection:
return cast(ServerConnection, object()) return cast(ServerConnection, object())
def _packet_types(payloads: list[object]) -> list[str]:
return [getattr(packet, "type", "") for packet in payloads]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None: async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41) server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
@@ -37,6 +41,68 @@ async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyP
assert broadcast_payloads == [] assert broadcast_payloads == []
@pytest.mark.asyncio
async def test_auth_login_uses_hash_offload(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None)
server.auth_service.register("alpha", "password99")
ws = _fake_ws()
client = ClientConnection(websocket=ws, id="u1", nickname="tester")
send_payloads: list[object] = []
offload_calls: list[str] = []
async def fake_send(websocket: ServerConnection, packet: object) -> None:
send_payloads.append(packet)
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
return None
async def fake_run_auth_hash_task(func, /, *args, **kwargs):
offload_calls.append(getattr(func, "__name__", "unknown"))
return func(*args, **kwargs)
monkeypatch.setattr(server, "_send", fake_send)
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
monkeypatch.setattr(server, "_run_auth_hash_task", fake_run_auth_hash_task)
await server._handle_message(client, json.dumps({"type": "auth_login", "username": "alpha", "password": "password99"}))
assert "login" in offload_calls
auth_results = [packet for packet in send_payloads if getattr(packet, "type", "") == "auth_result"]
assert auth_results
assert auth_results[-1].ok is True
@pytest.mark.asyncio
async def test_auth_rate_limit_blocks_before_hash(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None)
ws = _fake_ws()
client = ClientConnection(websocket=ws, id="u1", nickname="tester")
send_payloads: list[object] = []
called_login = False
async def fake_send(websocket: ServerConnection, packet: object) -> None:
send_payloads.append(packet)
def fake_login(username: str, password: str): # pragma: no cover - should never run
nonlocal called_login
called_login = True
raise RuntimeError("unexpected login call")
monkeypatch.setattr(server, "_send", fake_send)
monkeypatch.setattr(server, "_sleep_auth_failure_jitter", lambda: asyncio.sleep(0))
monkeypatch.setattr(server.auth_service, "login", fake_login)
monkeypatch.setattr(server, "_is_auth_rate_limited", lambda _client, _packet: True)
await server._handle_message(client, json.dumps({"type": "auth_login", "username": "alpha", "password": "wrongpass"}))
assert called_login is False
assert send_payloads
assert send_payloads[-1].ok is False
assert "too many" in send_payloads[-1].message.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None: async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41) server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)