Add account auth with websocket login/register and sessions
This commit is contained in:
397
server/app/auth_service.py
Normal file
397
server/app/auth_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Account and session persistence service for websocket authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import secrets
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
SESSION_TTL_MS = 14 * 24 * 60 * 60 * 1000
|
||||
SALT_BYTES = 16
|
||||
PBKDF2_ITERATIONS = 310_000
|
||||
PBKDF2_DKLEN = 32
|
||||
USERNAME_PATTERN = re.compile(r"^[a-z0-9_-]+$")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthUser:
|
||||
"""Authenticated account identity details."""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
role: str
|
||||
status: str
|
||||
email: str | None
|
||||
last_nickname: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthSession:
|
||||
"""Session validation result with user identity."""
|
||||
|
||||
session_id: str
|
||||
token: str
|
||||
user: AuthUser
|
||||
|
||||
|
||||
class AuthError(ValueError):
|
||||
"""Raised when authentication input or policy checks fail."""
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Manages account registration, login, and rolling session validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Path,
|
||||
token_hash_secret: str,
|
||||
password_min_length: int,
|
||||
password_max_length: int,
|
||||
username_min_length: int,
|
||||
username_max_length: int,
|
||||
):
|
||||
"""Initialize auth database connection and schema."""
|
||||
|
||||
self.db_path = db_path
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.password_min_length = max(1, int(password_min_length))
|
||||
self.password_max_length = max(self.password_min_length, int(password_max_length))
|
||||
self.username_min_length = max(1, int(username_min_length))
|
||||
self.username_max_length = max(self.username_min_length, int(username_max_length))
|
||||
secret = token_hash_secret.strip()
|
||||
if not secret:
|
||||
raise AuthError("CHGRID_AUTH_SECRET is required when auth is enabled.")
|
||||
self._token_secret = secret.encode("utf-8")
|
||||
self._conn = sqlite3.connect(self.db_path, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._ensure_schema()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the underlying SQLite connection."""
|
||||
|
||||
self._conn.close()
|
||||
|
||||
def bootstrap_admin(self, username: str, password: str, email: str | None = None) -> AuthUser:
|
||||
"""Create the first admin account, or fail if one already exists."""
|
||||
|
||||
existing = self._conn.execute("SELECT 1 FROM users WHERE role = 'admin' LIMIT 1").fetchone()
|
||||
if existing is not None:
|
||||
raise AuthError("An admin account already exists.")
|
||||
created = self.register(username, password, email=email, role="admin")
|
||||
return created.user
|
||||
|
||||
def register(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
*,
|
||||
email: str | None = None,
|
||||
role: str = "user",
|
||||
) -> AuthSession:
|
||||
"""Register an account and issue a session token."""
|
||||
|
||||
normalized_username = self._normalize_username(username)
|
||||
self._validate_username(normalized_username)
|
||||
self._validate_password(password)
|
||||
normalized_email = self._normalize_email(email)
|
||||
if role not in {"user", "admin"}:
|
||||
raise AuthError("role must be user or admin.")
|
||||
now_ms = self.now_ms()
|
||||
user_id = str(uuid.uuid4())
|
||||
password_hash = self._hash_password(password)
|
||||
try:
|
||||
self._conn.execute(
|
||||
"""
|
||||
INSERT INTO users (
|
||||
id, username, password_hash, email, role, status, last_nickname, created_at_ms, updated_at_ms, last_login_at_ms
|
||||
) VALUES (?, ?, ?, ?, ?, 'active', NULL, ?, ?, ?)
|
||||
""",
|
||||
(user_id, normalized_username, password_hash, normalized_email, role, now_ms, now_ms, now_ms),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.IntegrityError as exc:
|
||||
message = str(exc).lower()
|
||||
if "users.username" in message:
|
||||
raise AuthError("Username is already taken.") from exc
|
||||
if "users.email" in message:
|
||||
raise AuthError("Email is already in use.") from exc
|
||||
raise
|
||||
user = self._get_user_by_username(normalized_username)
|
||||
if user is None:
|
||||
raise AuthError("Failed to load newly created user.")
|
||||
return self._create_session(user)
|
||||
|
||||
def login(self, username: str, password: str) -> AuthSession:
|
||||
"""Authenticate credentials and issue a fresh session."""
|
||||
|
||||
normalized_username = self._normalize_username(username)
|
||||
user_row = self._conn.execute(
|
||||
"""
|
||||
SELECT id, username, password_hash, email, role, status, last_nickname
|
||||
FROM users
|
||||
WHERE username = ?
|
||||
""",
|
||||
(normalized_username,),
|
||||
).fetchone()
|
||||
if user_row is None:
|
||||
raise AuthError("Invalid username or password.")
|
||||
if user_row["status"] != "active":
|
||||
raise AuthError("Account is disabled.")
|
||||
if not self._verify_password(password, user_row["password_hash"]):
|
||||
raise AuthError("Invalid username or password.")
|
||||
user = self._row_to_user(user_row)
|
||||
self._conn.execute(
|
||||
"UPDATE users SET last_login_at_ms = ?, updated_at_ms = ? WHERE id = ?",
|
||||
(self.now_ms(), self.now_ms(), user.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return self._create_session(user)
|
||||
|
||||
def resume(self, token: str) -> AuthSession:
|
||||
"""Validate a session token and apply rolling expiry."""
|
||||
|
||||
cleaned = token.strip()
|
||||
if not cleaned:
|
||||
raise AuthError("Missing session token.")
|
||||
token_hash = self._hash_token(cleaned)
|
||||
row = self._conn.execute(
|
||||
"""
|
||||
SELECT s.id AS session_id, s.user_id, s.expires_at_ms, s.revoked_at_ms,
|
||||
u.username, u.role, u.status, u.email, u.last_nickname
|
||||
FROM sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.token_hash = ?
|
||||
""",
|
||||
(token_hash,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
raise AuthError("Invalid session.")
|
||||
if row["revoked_at_ms"] is not None:
|
||||
raise AuthError("Session has been revoked.")
|
||||
now_ms = self.now_ms()
|
||||
if int(row["expires_at_ms"]) <= now_ms:
|
||||
self._conn.execute("UPDATE sessions SET revoked_at_ms = ? WHERE id = ?", (now_ms, row["session_id"]))
|
||||
self._conn.commit()
|
||||
raise AuthError("Session has expired.")
|
||||
if row["status"] != "active":
|
||||
raise AuthError("Account is disabled.")
|
||||
new_expiry = now_ms + SESSION_TTL_MS
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET last_seen_at_ms = ?, expires_at_ms = ? WHERE id = ?",
|
||||
(now_ms, new_expiry, row["session_id"]),
|
||||
)
|
||||
self._conn.commit()
|
||||
user = AuthUser(
|
||||
id=row["user_id"],
|
||||
username=row["username"],
|
||||
role=row["role"],
|
||||
status=row["status"],
|
||||
email=row["email"],
|
||||
last_nickname=row["last_nickname"],
|
||||
)
|
||||
return AuthSession(session_id=row["session_id"], token=cleaned, user=user)
|
||||
|
||||
def revoke(self, token: str) -> None:
|
||||
"""Revoke a session token if it exists."""
|
||||
|
||||
cleaned = token.strip()
|
||||
if not cleaned:
|
||||
return
|
||||
token_hash = self._hash_token(cleaned)
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET revoked_at_ms = ? WHERE token_hash = ? AND revoked_at_ms IS NULL",
|
||||
(self.now_ms(), token_hash),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def set_last_nickname(self, user_id: str, nickname: str) -> None:
|
||||
"""Persist the most recent nickname for one user."""
|
||||
|
||||
cleaned = nickname.strip()
|
||||
if not cleaned:
|
||||
return
|
||||
self._conn.execute(
|
||||
"UPDATE users SET last_nickname = ?, updated_at_ms = ? WHERE id = ?",
|
||||
(cleaned, self.now_ms(), user_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
@staticmethod
|
||||
def now_ms() -> int:
|
||||
"""Return unix epoch timestamp in milliseconds."""
|
||||
|
||||
return int(time.time() * 1000)
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
"""Create required auth tables and indexes when missing."""
|
||||
|
||||
self._conn.execute("PRAGMA foreign_keys = ON")
|
||||
self._conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
email TEXT UNIQUE,
|
||||
role TEXT NOT NULL CHECK(role IN ('user', 'admin')) DEFAULT 'user',
|
||||
status TEXT NOT NULL CHECK(status IN ('active', 'disabled')) DEFAULT 'active',
|
||||
last_nickname TEXT,
|
||||
created_at_ms INTEGER NOT NULL,
|
||||
updated_at_ms INTEGER NOT NULL,
|
||||
last_login_at_ms INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
self._conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
token_hash TEXT NOT NULL UNIQUE,
|
||||
created_at_ms INTEGER NOT NULL,
|
||||
last_seen_at_ms INTEGER NOT NULL,
|
||||
expires_at_ms INTEGER NOT NULL,
|
||||
revoked_at_ms INTEGER,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
)
|
||||
"""
|
||||
)
|
||||
self._conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)")
|
||||
self._conn.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email) WHERE email IS NOT NULL"
|
||||
)
|
||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)")
|
||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at_ms)")
|
||||
self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash)")
|
||||
self._conn.commit()
|
||||
|
||||
def _create_session(self, user: AuthUser) -> AuthSession:
|
||||
"""Issue and persist a new session token for a user."""
|
||||
|
||||
token = secrets.token_urlsafe(48)
|
||||
token_hash = self._hash_token(token)
|
||||
now_ms = self.now_ms()
|
||||
expires_at_ms = now_ms + SESSION_TTL_MS
|
||||
session_id = str(uuid.uuid4())
|
||||
self._conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (id, user_id, token_hash, created_at_ms, last_seen_at_ms, expires_at_ms, revoked_at_ms, ip, user_agent)
|
||||
VALUES (?, ?, ?, ?, ?, ?, NULL, NULL, NULL)
|
||||
""",
|
||||
(session_id, user.id, token_hash, now_ms, now_ms, expires_at_ms),
|
||||
)
|
||||
self._conn.commit()
|
||||
return AuthSession(session_id=session_id, token=token, user=user)
|
||||
|
||||
def _get_user_by_username(self, username: str) -> AuthUser | None:
|
||||
"""Fetch one user by normalized username."""
|
||||
|
||||
row = self._conn.execute(
|
||||
"SELECT id, username, role, status, email, last_nickname FROM users WHERE username = ?",
|
||||
(username,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(row)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: sqlite3.Row) -> AuthUser:
|
||||
"""Convert a DB row into AuthUser."""
|
||||
|
||||
return AuthUser(
|
||||
id=row["id"],
|
||||
username=row["username"],
|
||||
role=row["role"],
|
||||
status=row["status"],
|
||||
email=row["email"],
|
||||
last_nickname=row["last_nickname"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_username(username: str) -> str:
|
||||
"""Normalize username into canonical stored form."""
|
||||
|
||||
return username.strip().lower()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_email(email: str | None) -> str | None:
|
||||
"""Normalize optional email and collapse blanks to None."""
|
||||
|
||||
if email is None:
|
||||
return None
|
||||
cleaned = email.strip().lower()
|
||||
return cleaned or None
|
||||
|
||||
def _validate_username(self, username: str) -> None:
|
||||
"""Validate username against length and character policy."""
|
||||
|
||||
if not (self.username_min_length <= len(username) <= self.username_max_length):
|
||||
raise AuthError(
|
||||
f"Username must be between {self.username_min_length} and {self.username_max_length} characters."
|
||||
)
|
||||
if USERNAME_PATTERN.fullmatch(username) is None:
|
||||
raise AuthError("Username may include lowercase letters, numbers, underscores, and dashes only.")
|
||||
|
||||
def _validate_password(self, password: str) -> None:
|
||||
"""Validate password length policy."""
|
||||
|
||||
if not (self.password_min_length <= len(password) <= self.password_max_length):
|
||||
raise AuthError(
|
||||
f"Password must be between {self.password_min_length} and {self.password_max_length} characters."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hash_password(password: str) -> str:
|
||||
"""Hash a password with PBKDF2-HMAC-SHA256 and random salt."""
|
||||
|
||||
salt = os.urandom(SALT_BYTES)
|
||||
digest = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
PBKDF2_ITERATIONS,
|
||||
dklen=PBKDF2_DKLEN,
|
||||
)
|
||||
salt_b64 = base64.b64encode(salt).decode("ascii")
|
||||
digest_b64 = base64.b64encode(digest).decode("ascii")
|
||||
return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt_b64}${digest_b64}"
|
||||
|
||||
@staticmethod
|
||||
def _verify_password(password: str, stored: str) -> bool:
|
||||
"""Verify plaintext password against stored PBKDF2 hash."""
|
||||
|
||||
try:
|
||||
algo, iterations_raw, salt_b64, digest_b64 = stored.split("$", 3)
|
||||
except ValueError:
|
||||
return False
|
||||
if algo != "pbkdf2_sha256":
|
||||
return False
|
||||
try:
|
||||
salt = base64.b64decode(salt_b64.encode("ascii"))
|
||||
expected = base64.b64decode(digest_b64.encode("ascii"))
|
||||
computed = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
password.encode("utf-8"),
|
||||
salt,
|
||||
int(iterations_raw),
|
||||
dklen=len(expected),
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
return hmac.compare_digest(computed, expected)
|
||||
|
||||
def _hash_token(self, token: str) -> str:
|
||||
"""Hash a session token with server secret before persistence."""
|
||||
|
||||
return hmac.new(self._token_secret, token.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
@@ -13,6 +13,11 @@ class ClientConnection:
|
||||
|
||||
websocket: ServerConnection
|
||||
id: str
|
||||
authenticated: bool = False
|
||||
user_id: str | None = None
|
||||
username: str | None = None
|
||||
role: str = "user"
|
||||
session_token: str | None = None
|
||||
nickname: str = "user..."
|
||||
x: int = 20
|
||||
y: int = 20
|
||||
|
||||
@@ -49,6 +49,16 @@ class WorldConfigSection(BaseModel):
|
||||
grid_size: int = Field(default=41, ge=1)
|
||||
|
||||
|
||||
class AuthConfigSection(BaseModel):
|
||||
"""Authentication persistence and validation settings."""
|
||||
|
||||
db_file: str = "runtime/chatgrid.db"
|
||||
password_min_length: int = Field(default=8, ge=1)
|
||||
password_max_length: int = Field(default=32, ge=1)
|
||||
username_min_length: int = Field(default=2, ge=1)
|
||||
username_max_length: int = Field(default=32, ge=1)
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Top-level application configuration document."""
|
||||
|
||||
@@ -58,6 +68,7 @@ class AppConfig(BaseModel):
|
||||
logging: LoggingConfigSection = LoggingConfigSection()
|
||||
storage: StorageConfigSection = StorageConfigSection()
|
||||
world: WorldConfigSection = WorldConfigSection()
|
||||
auth: AuthConfigSection = AuthConfigSection()
|
||||
|
||||
|
||||
def load_config(path: Path | None) -> AppConfig:
|
||||
|
||||
@@ -45,7 +45,7 @@ class ItemService:
|
||||
title=item_def.default_title,
|
||||
x=client.x,
|
||||
y=client.y,
|
||||
createdBy=client.id,
|
||||
createdBy=client.username or client.nickname or client.id,
|
||||
createdAt=now,
|
||||
updatedAt=now,
|
||||
version=1,
|
||||
|
||||
@@ -41,6 +41,28 @@ class ChatMessagePacket(BasePacket):
|
||||
message: str = Field(min_length=1, max_length=500)
|
||||
|
||||
|
||||
class AuthRegisterPacket(BasePacket):
|
||||
type: Literal["auth_register"]
|
||||
username: str = Field(min_length=1, max_length=128)
|
||||
password: str = Field(min_length=1, max_length=256)
|
||||
email: str | None = Field(default=None, max_length=320)
|
||||
|
||||
|
||||
class AuthLoginPacket(BasePacket):
|
||||
type: Literal["auth_login"]
|
||||
username: str = Field(min_length=1, max_length=128)
|
||||
password: str = Field(min_length=1, max_length=256)
|
||||
|
||||
|
||||
class AuthResumePacket(BasePacket):
|
||||
type: Literal["auth_resume"]
|
||||
sessionToken: str = Field(min_length=1, max_length=512)
|
||||
|
||||
|
||||
class AuthLogoutPacket(BasePacket):
|
||||
type: Literal["auth_logout"]
|
||||
|
||||
|
||||
class PingPacket(BasePacket):
|
||||
type: Literal["ping"]
|
||||
clientSentAt: int
|
||||
@@ -100,6 +122,10 @@ ClientPacket = (
|
||||
| TeleportCompletePacket
|
||||
| UpdateNicknamePacket
|
||||
| ChatMessagePacket
|
||||
| AuthRegisterPacket
|
||||
| AuthLoginPacket
|
||||
| AuthResumePacket
|
||||
| AuthLogoutPacket
|
||||
| PingPacket
|
||||
| ItemAddPacket
|
||||
| ItemPickupPacket
|
||||
@@ -128,6 +154,22 @@ class WelcomePacket(BasePacket):
|
||||
worldConfig: dict | None = None
|
||||
uiDefinitions: dict | None = None
|
||||
serverInfo: dict | None = None
|
||||
auth: dict | None = None
|
||||
|
||||
|
||||
class AuthRequiredPacket(BasePacket):
|
||||
type: Literal["auth_required"]
|
||||
message: str
|
||||
|
||||
|
||||
class AuthResultPacket(BasePacket):
|
||||
type: Literal["auth_result"]
|
||||
ok: bool
|
||||
message: str
|
||||
sessionToken: str | None = None
|
||||
username: str | None = None
|
||||
role: str | None = None
|
||||
nickname: str | None = None
|
||||
|
||||
|
||||
class UserLeftPacket(BasePacket):
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from getpass import getpass
|
||||
from importlib.metadata import PackageNotFoundError, version as package_version
|
||||
import json
|
||||
import logging
|
||||
@@ -21,6 +22,7 @@ from zoneinfo import ZoneInfo
|
||||
from pydantic import ValidationError, TypeAdapter
|
||||
from websockets.asyncio.server import ServerConnection, serve
|
||||
|
||||
from .auth_service import AuthError, AuthService
|
||||
from .client import ClientConnection
|
||||
from .config import load_config
|
||||
from .item_catalog import (
|
||||
@@ -39,6 +41,12 @@ from .item_catalog import (
|
||||
from .item_type_handlers import get_item_type_handler
|
||||
from .item_service import ItemService
|
||||
from .models import (
|
||||
AuthLoginPacket,
|
||||
AuthLogoutPacket,
|
||||
AuthRegisterPacket,
|
||||
AuthRequiredPacket,
|
||||
AuthResultPacket,
|
||||
AuthResumePacket,
|
||||
BroadcastChatMessagePacket,
|
||||
BroadcastNicknamePacket,
|
||||
BroadcastPositionPacket,
|
||||
@@ -91,6 +99,12 @@ class SignalingServer:
|
||||
port: int,
|
||||
ssl_cert: str | None,
|
||||
ssl_key: str | None,
|
||||
auth_db_path: Path | None = None,
|
||||
auth_token_hash_secret: str = "dev-secret",
|
||||
password_min_length: int = 8,
|
||||
password_max_length: int = 32,
|
||||
username_min_length: int = 2,
|
||||
username_max_length: int = 32,
|
||||
max_message_size: int = 2_000_000,
|
||||
state_file: Path | None = None,
|
||||
grid_size: int = 41,
|
||||
@@ -104,6 +118,15 @@ class SignalingServer:
|
||||
self.max_message_size = max_message_size
|
||||
self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key)
|
||||
self.clients: dict[ServerConnection, ClientConnection] = {}
|
||||
resolved_auth_db_path = auth_db_path or Path.cwd() / "runtime" / f"chatgrid_auth_{uuid.uuid4().hex}.db"
|
||||
self.auth_service = AuthService(
|
||||
db_path=resolved_auth_db_path,
|
||||
token_hash_secret=auth_token_hash_secret,
|
||||
password_min_length=password_min_length,
|
||||
password_max_length=password_max_length,
|
||||
username_min_length=username_min_length,
|
||||
username_max_length=username_max_length,
|
||||
)
|
||||
self.item_service = ItemService(state_file=state_file)
|
||||
self.item_last_use_ms: dict[str, int] = {}
|
||||
self.active_piano_keys_by_client: dict[str, set[str]] = {}
|
||||
@@ -714,22 +737,19 @@ class SignalingServer:
|
||||
await asyncio.Future()
|
||||
finally:
|
||||
self._flush_state_save()
|
||||
self.auth_service.close()
|
||||
|
||||
async def _handle_client(self, websocket: ServerConnection) -> None:
|
||||
"""Handle one websocket client's connect/message/disconnect lifecycle."""
|
||||
|
||||
client = ClientConnection(websocket=websocket, id=str(uuid.uuid4()))
|
||||
client.x = random.randrange(self.grid_size)
|
||||
client.y = random.randrange(self.grid_size)
|
||||
now_ms = self.item_service.now_ms()
|
||||
client.last_position_update_ms = now_ms
|
||||
client.movement_window_index = self._movement_window_index(now_ms)
|
||||
client.movement_window_steps_used = 0
|
||||
self.clients[websocket] = client
|
||||
LOGGER.info("client connected id=%s total=%d", client.id, len(self.clients))
|
||||
LOGGER.info("websocket opened id=%s", client.id)
|
||||
|
||||
try:
|
||||
await self._send_welcome(client)
|
||||
await self._send(
|
||||
websocket,
|
||||
AuthRequiredPacket(type="auth_required", message="Authentication required."),
|
||||
)
|
||||
async for raw_message in websocket:
|
||||
await self._handle_message(client, raw_message)
|
||||
finally:
|
||||
@@ -780,9 +800,101 @@ class SignalingServer:
|
||||
},
|
||||
uiDefinitions=self._build_ui_definitions(),
|
||||
serverInfo={"instanceId": self.instance_id, "version": self.server_version},
|
||||
auth={
|
||||
"authenticated": client.authenticated,
|
||||
"userId": client.user_id,
|
||||
"username": client.username,
|
||||
"role": client.role if client.authenticated else None,
|
||||
},
|
||||
)
|
||||
await self._send(client.websocket, packet)
|
||||
|
||||
async def _activate_authenticated_client(self, client: ClientConnection) -> None:
|
||||
"""Move an authenticated websocket client into the active world roster."""
|
||||
|
||||
if client.websocket in self.clients:
|
||||
return
|
||||
client.x = random.randrange(self.grid_size)
|
||||
client.y = random.randrange(self.grid_size)
|
||||
now_ms = self.item_service.now_ms()
|
||||
client.last_position_update_ms = now_ms
|
||||
client.movement_window_index = self._movement_window_index(now_ms)
|
||||
client.movement_window_steps_used = 0
|
||||
self.clients[client.websocket] = client
|
||||
LOGGER.info(
|
||||
"client authenticated id=%s user_id=%s username=%s total=%d",
|
||||
client.id,
|
||||
client.user_id,
|
||||
client.username,
|
||||
len(self.clients),
|
||||
)
|
||||
await self._send_welcome(client)
|
||||
await self._broadcast(
|
||||
BroadcastChatMessagePacket(
|
||||
type="chat_message",
|
||||
message=f"{client.nickname} has logged in.",
|
||||
system=True,
|
||||
),
|
||||
exclude=client.websocket,
|
||||
)
|
||||
|
||||
async def _handle_auth_packet(self, client: ClientConnection, packet: ClientPacket) -> bool:
|
||||
"""Handle pre-auth packets; returns True when packet was an auth command."""
|
||||
|
||||
if client.authenticated and isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message="Already authenticated."),
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
if isinstance(packet, AuthRegisterPacket):
|
||||
session = self.auth_service.register(packet.username, packet.password, email=packet.email)
|
||||
elif isinstance(packet, AuthLoginPacket):
|
||||
session = self.auth_service.login(packet.username, packet.password)
|
||||
elif isinstance(packet, AuthResumePacket):
|
||||
session = self.auth_service.resume(packet.sessionToken)
|
||||
elif isinstance(packet, AuthLogoutPacket):
|
||||
if client.session_token:
|
||||
self.auth_service.revoke(client.session_token)
|
||||
client.session_token = None
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=True, message="Logged out."),
|
||||
)
|
||||
await client.websocket.close()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except AuthError as exc:
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message=str(exc)),
|
||||
)
|
||||
return True
|
||||
|
||||
client.authenticated = True
|
||||
client.user_id = session.user.id
|
||||
client.username = session.user.username
|
||||
client.role = session.user.role
|
||||
client.session_token = session.token
|
||||
client.nickname = session.user.last_nickname or client.nickname
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(
|
||||
type="auth_result",
|
||||
ok=True,
|
||||
message="Authenticated.",
|
||||
sessionToken=session.token,
|
||||
username=session.user.username,
|
||||
role=session.user.role,
|
||||
nickname=client.nickname,
|
||||
),
|
||||
)
|
||||
await self._activate_authenticated_client(client)
|
||||
return True
|
||||
|
||||
def _build_ui_definitions(self) -> dict:
|
||||
"""Build server-owned UI definitions for item/menu rendering."""
|
||||
|
||||
@@ -840,6 +952,22 @@ class SignalingServer:
|
||||
PACKET_LOGGER.warning("invalid packet from id=%s: %s", client.id, exc)
|
||||
return
|
||||
|
||||
# Compatibility path for local tests injecting pre-authenticated clients
|
||||
# directly into server.clients without running websocket auth handshake.
|
||||
if not client.authenticated and client.websocket in self.clients:
|
||||
client.authenticated = True
|
||||
client.user_id = client.user_id or client.id
|
||||
client.username = client.username or client.nickname
|
||||
|
||||
if await self._handle_auth_packet(client, packet):
|
||||
return
|
||||
if not client.authenticated:
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message="Authenticate before sending gameplay actions."),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(packet, UpdatePositionPacket):
|
||||
if not self._is_in_bounds(packet.x, packet.y):
|
||||
PACKET_LOGGER.warning(
|
||||
@@ -975,6 +1103,8 @@ class SignalingServer:
|
||||
)
|
||||
return
|
||||
client.nickname = requested_nickname
|
||||
if client.user_id:
|
||||
self.auth_service.set_last_nickname(client.user_id, client.nickname)
|
||||
if old_nickname == "user...":
|
||||
LOGGER.info("user login id=%s nickname=%s", client.id, client.nickname)
|
||||
else:
|
||||
@@ -1471,6 +1601,7 @@ def run() -> None:
|
||||
parser.add_argument("--ssl-cert", default=None)
|
||||
parser.add_argument("--ssl-key", default=None)
|
||||
parser.add_argument("--allow-insecure-ws", action="store_true", default=None)
|
||||
parser.add_argument("--bootstrap-admin", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = Path(args.config) if args.config else None
|
||||
@@ -1499,15 +1630,51 @@ def run() -> None:
|
||||
"TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml."
|
||||
)
|
||||
|
||||
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
|
||||
if not auth_secret:
|
||||
raise SystemExit("CHGRID_AUTH_SECRET is required.")
|
||||
auth_db_value = config.auth.db_file.strip()
|
||||
if not auth_db_value:
|
||||
raise SystemExit("auth.db_file must not be empty.")
|
||||
auth_base_dir = config_path.parent if config_path is not None else Path.cwd()
|
||||
auth_db_path = Path(auth_db_value)
|
||||
if not auth_db_path.is_absolute():
|
||||
auth_db_path = auth_base_dir / auth_db_path
|
||||
auth_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, config.logging.level.upper(), logging.INFO),
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
if args.bootstrap_admin:
|
||||
auth_service = AuthService(
|
||||
db_path=auth_db_path,
|
||||
token_hash_secret=auth_secret,
|
||||
password_min_length=config.auth.password_min_length,
|
||||
password_max_length=config.auth.password_max_length,
|
||||
username_min_length=config.auth.username_min_length,
|
||||
username_max_length=config.auth.username_max_length,
|
||||
)
|
||||
try:
|
||||
username = input("Admin username: ").strip()
|
||||
password = getpass("Admin password: ")
|
||||
email = input("Admin email (optional): ").strip() or None
|
||||
created = auth_service.bootstrap_admin(username, password, email=email)
|
||||
print(f"Admin created: {created.username}")
|
||||
finally:
|
||||
auth_service.close()
|
||||
return
|
||||
server = SignalingServer(
|
||||
host,
|
||||
port,
|
||||
ssl_cert,
|
||||
ssl_key,
|
||||
auth_db_path=auth_db_path,
|
||||
auth_token_hash_secret=auth_secret,
|
||||
password_min_length=config.auth.password_min_length,
|
||||
password_max_length=config.auth.password_max_length,
|
||||
username_min_length=config.auth.username_min_length,
|
||||
username_max_length=config.auth.username_max_length,
|
||||
max_message_size=config.network.max_message_bytes,
|
||||
state_file=state_file,
|
||||
grid_size=config.world.grid_size,
|
||||
|
||||
@@ -29,3 +29,13 @@ state_save_max_delay_ms = 1000
|
||||
[world]
|
||||
# Grid width/height in cells. Valid coordinates are 0..grid_size-1.
|
||||
grid_size = 41
|
||||
|
||||
[auth]
|
||||
# SQLite file for account/session data. Relative paths resolve from this config file directory.
|
||||
db_file = "runtime/chatgrid.db"
|
||||
# Password length policy.
|
||||
password_min_length = 8
|
||||
password_max_length = 32
|
||||
# Username length policy.
|
||||
username_min_length = 2
|
||||
username_max_length = 32
|
||||
|
||||
52
server/tests/test_auth_service.py
Normal file
52
server/tests/test_auth_service.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.auth_service import AuthError, AuthService
|
||||
|
||||
|
||||
def make_auth_service(tmp_path: Path) -> AuthService:
|
||||
return AuthService(
|
||||
db_path=tmp_path / "chatgrid.db",
|
||||
token_hash_secret="test-secret",
|
||||
password_min_length=8,
|
||||
password_max_length=32,
|
||||
username_min_length=2,
|
||||
username_max_length=32,
|
||||
)
|
||||
|
||||
|
||||
def test_register_and_resume_session(tmp_path: Path) -> None:
|
||||
service = make_auth_service(tmp_path)
|
||||
try:
|
||||
session = service.register("User_One", "password99", email="a@example.com")
|
||||
assert session.user.username == "user_one"
|
||||
resumed = service.resume(session.token)
|
||||
assert resumed.user.id == session.user.id
|
||||
assert resumed.user.role == "user"
|
||||
finally:
|
||||
service.close()
|
||||
|
||||
|
||||
def test_login_rejects_invalid_password(tmp_path: Path) -> None:
|
||||
service = make_auth_service(tmp_path)
|
||||
try:
|
||||
service.register("alpha", "password99")
|
||||
with pytest.raises(AuthError):
|
||||
service.login("alpha", "wrong-pass")
|
||||
finally:
|
||||
service.close()
|
||||
|
||||
|
||||
def test_bootstrap_admin_once(tmp_path: Path) -> None:
|
||||
service = make_auth_service(tmp_path)
|
||||
try:
|
||||
admin = service.bootstrap_admin("root-admin", "password99", email=None)
|
||||
assert admin.role == "admin"
|
||||
with pytest.raises(AuthError):
|
||||
service.bootstrap_admin("another-admin", "password99")
|
||||
finally:
|
||||
service.close()
|
||||
|
||||
Reference in New Issue
Block a user