Harden auth flow against timing and event-loop blocking

This commit is contained in:
Jage9
2026-02-25 00:17:05 -05:00
parent 54232acd87
commit e7d3b41782
5 changed files with 460 additions and 171 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import argparse
import asyncio
from collections import deque
from datetime import datetime
from getpass import getpass
from importlib.metadata import PackageNotFoundError, version as package_version
@@ -89,6 +90,12 @@ PIANO_RECORDING_MAX_EVENTS = 4096
MOVEMENT_TICK_MS = 200
MOVEMENT_MAX_STEPS_PER_TICK = 1
POSITION_PERSIST_DEBOUNCE_MS = 5_000
AUTH_HASH_MAX_CONCURRENCY = 8
AUTH_RATE_LIMIT_WINDOW_S = 30.0
AUTH_RATE_LIMIT_PER_IP = 20
AUTH_RATE_LIMIT_PER_IDENTITY = 8
AUTH_FAILURE_JITTER_MIN_MS = 0.02
AUTH_FAILURE_JITTER_MAX_MS = 0.08
class SignalingServer:
@@ -143,6 +150,9 @@ class SignalingServer:
self._pending_state_save_handle: asyncio.TimerHandle | None = None
self._pending_state_save_started_at: float | None = None
self._last_position_persist_ms_by_user: dict[str, int] = {}
self._auth_hash_semaphore = asyncio.Semaphore(AUTH_HASH_MAX_CONCURRENCY)
self._auth_failures_by_ip: dict[str, deque[float]] = {}
self._auth_failures_by_identity: dict[str, deque[float]] = {}
@staticmethod
def _resolve_server_version() -> str:
@@ -245,6 +255,81 @@ class SignalingServer:
return "radio" if item.type == "radio_station" else item.type
@staticmethod
def _client_ip(client: ClientConnection) -> str:
"""Extract best-effort remote IP string for audit logs and auth throttling."""
address = getattr(client.websocket, "remote_address", None)
if isinstance(address, tuple) and address:
return str(address[0])
if isinstance(address, str):
return address
return "unknown"
@staticmethod
def _prune_failure_window(bucket: deque[float], now_s: float) -> None:
"""Drop expired auth-failure timestamps outside the active limit window."""
threshold = now_s - AUTH_RATE_LIMIT_WINDOW_S
while bucket and bucket[0] < threshold:
bucket.popleft()
def _auth_identity_key(self, client: ClientConnection, packet: ClientPacket) -> str:
"""Build username/IP scoped key used for auth failure throttling."""
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket)):
username = packet.username.strip().lower()
elif isinstance(packet, AuthResumePacket):
username = "resume"
else:
username = "unknown"
return f"{self._client_ip(client)}::{username}"
def _is_auth_rate_limited(self, client: ClientConnection, packet: ClientPacket) -> bool:
"""Return True when recent auth failures exceed IP or identity thresholds."""
now_s = time.monotonic()
ip_key = self._client_ip(client)
identity_key = self._auth_identity_key(client, packet)
ip_bucket = self._auth_failures_by_ip.setdefault(ip_key, deque())
identity_bucket = self._auth_failures_by_identity.setdefault(identity_key, deque())
self._prune_failure_window(ip_bucket, now_s)
self._prune_failure_window(identity_bucket, now_s)
return len(ip_bucket) >= AUTH_RATE_LIMIT_PER_IP or len(identity_bucket) >= AUTH_RATE_LIMIT_PER_IDENTITY
def _record_auth_failure(self, client: ClientConnection, packet: ClientPacket) -> None:
"""Record a failed auth attempt for IP and identity-scoped throttling."""
now_s = time.monotonic()
ip_key = self._client_ip(client)
identity_key = self._auth_identity_key(client, packet)
self._auth_failures_by_ip.setdefault(ip_key, deque()).append(now_s)
self._auth_failures_by_identity.setdefault(identity_key, deque()).append(now_s)
def _clear_auth_failures(self, client: ClientConnection, packet: ClientPacket) -> None:
"""Clear identity-scoped auth failures after a successful authentication."""
now_s = time.monotonic()
identity_key = self._auth_identity_key(client, packet)
bucket = self._auth_failures_by_identity.get(identity_key)
if not bucket:
return
bucket.clear()
self._prune_failure_window(bucket, now_s)
async def _sleep_auth_failure_jitter(self) -> None:
"""Apply small randomized delay to reduce high-resolution auth timing probes."""
await asyncio.sleep(random.uniform(AUTH_FAILURE_JITTER_MIN_MS, AUTH_FAILURE_JITTER_MAX_MS))
async def _run_auth_hash_task(self, func, /, *args, **kwargs):
"""Run auth service call in a worker thread behind bounded hash concurrency."""
async with self._auth_hash_semaphore:
return await asyncio.to_thread(func, *args, **kwargs)
@staticmethod
def _resolve_item_use_sound(item: WorldItem) -> str | None:
"""Resolve one-shot use sound, preferring per-item param override."""
@@ -892,17 +977,65 @@ class SignalingServer:
)
return True
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)) and self._is_auth_rate_limited(
client, packet
):
LOGGER.warning(
"auth rate limited id=%s ip=%s packet=%s",
client.id,
self._client_ip(client),
packet.type,
)
await self._sleep_auth_failure_jitter()
await self._send(
client.websocket,
AuthResultPacket(
type="auth_result",
ok=False,
message="Too many authentication attempts. Try again shortly.",
authPolicy=self._auth_policy(),
),
)
return True
try:
if isinstance(packet, AuthRegisterPacket):
session = self.auth_service.register(packet.username, packet.password, email=packet.email)
session = await self._run_auth_hash_task(
self.auth_service.register,
packet.username,
packet.password,
email=packet.email,
)
LOGGER.info(
"auth register success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthLoginPacket):
session = self.auth_service.login(packet.username, packet.password)
session = await self._run_auth_hash_task(self.auth_service.login, packet.username, packet.password)
LOGGER.info(
"auth login success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthResumePacket):
session = self.auth_service.resume(packet.sessionToken)
LOGGER.info(
"auth resume success id=%s ip=%s username=%s user_id=%s",
client.id,
self._client_ip(client),
session.user.username,
session.user.id,
)
elif isinstance(packet, AuthLogoutPacket):
if client.session_token:
self.auth_service.revoke(client.session_token)
client.session_token = None
LOGGER.info("auth logout id=%s ip=%s username=%s", client.id, self._client_ip(client), client.username)
await self._send(
client.websocket,
AuthResultPacket(
@@ -917,6 +1050,16 @@ class SignalingServer:
else:
return False
except AuthError as exc:
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
self._record_auth_failure(client, packet)
await self._sleep_auth_failure_jitter()
LOGGER.warning(
"auth failure id=%s ip=%s packet=%s reason=%s",
client.id,
self._client_ip(client),
packet.type,
str(exc),
)
await self._send(
client.websocket,
AuthResultPacket(
@@ -928,6 +1071,9 @@ class SignalingServer:
)
return True
if isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
self._clear_auth_failures(client, packet)
client.authenticated = True
client.user_id = session.user.id
client.username = session.user.username