From bf3bc90f2a613f3f5aec899dbc7da0fe7d622e72 Mon Sep 17 00:00:00 2001 From: Jage9 Date: Tue, 24 Feb 2026 22:03:10 -0500 Subject: [PATCH] Add account auth with websocket login/register and sessions --- client/index.html | 29 ++ client/public/version.js | 2 +- client/src/main.ts | 189 +++++++++++- client/src/network/messageHandlers.ts | 10 + client/src/network/protocol.ts | 29 ++ client/src/session/connectionFlow.ts | 13 +- client/src/settings/settingsStore.ts | 26 ++ client/src/styles.css | 38 +++ deploy/README.md | 2 + deploy/scripts/install_server.sh | 12 + deploy/systemd/chat-grid.service | 1 + docs/protocol-notes.md | 11 + docs/runtime-flow.md | 11 +- server/app/auth_service.py | 397 ++++++++++++++++++++++++++ server/app/client.py | 5 + server/app/config.py | 11 + server/app/item_service.py | 2 +- server/app/models.py | 42 +++ server/app/server.py | 185 +++++++++++- server/config.example.toml | 10 + server/tests/test_auth_service.py | 52 ++++ 21 files changed, 1053 insertions(+), 24 deletions(-) create mode 100644 server/app/auth_service.py create mode 100644 server/tests/test_auth_service.py diff --git a/client/index.html b/client/index.html index eeb72a6..3b8dd6f 100644 --- a/client/index.html +++ b/client/index.html @@ -9,12 +9,41 @@

Chat Grid

+
+

Login

+
+ + +
+
+ + +
+ +
+
+ diff --git a/client/public/version.js b/client/public/version.js index 75efdcc..ebbc482 100644 --- a/client/public/version.js +++ b/client/public/version.js @@ -1,5 +1,5 @@ // Maintainer-controlled web client version. // Format: YYYY.MM.DD Rn (example: 2026.02.20 R2) -window.CHGRID_WEB_VERSION = "2026.02.25 R242"; +window.CHGRID_WEB_VERSION = "2026.02.25 R243"; // Optional display timezone for timestamps. Falls back to America/Detroit if unset/invalid. window.CHGRID_TIME_ZONE = "America/Detroit"; diff --git a/client/src/main.ts b/client/src/main.ts index 2cdffa9..0f9dfcd 100644 --- a/client/src/main.ts +++ b/client/src/main.ts @@ -90,12 +90,22 @@ declare global { type Dom = { connectionStatus: HTMLElement; appVersion: HTMLElement; + loginView: HTMLElement; + registerView: HTMLElement; + authUsername: HTMLInputElement; + authPassword: HTMLInputElement; + registerUsername: HTMLInputElement; + registerPassword: HTMLInputElement; + registerEmail: HTMLInputElement; + showRegisterButton: HTMLButtonElement; + showLoginButton: HTMLButtonElement; updatesSection: HTMLElement; updatesToggle: HTMLButtonElement; updatesPanel: HTMLDivElement; nicknameContainer: HTMLDivElement; preconnectNickname: HTMLInputElement; connectButton: HTMLButtonElement; + logoutButton: HTMLButtonElement; disconnectButton: HTMLButtonElement; focusGridButton: HTMLButtonElement; settingsButton: HTMLButtonElement; @@ -113,12 +123,22 @@ type Dom = { const dom: Dom = { connectionStatus: requiredById('connectionStatus'), appVersion: requiredById('appVersion'), + loginView: requiredById('loginView'), + registerView: requiredById('registerView'), + authUsername: requiredById('authUsername'), + authPassword: requiredById('authPassword'), + registerUsername: requiredById('registerUsername'), + registerPassword: requiredById('registerPassword'), + registerEmail: requiredById('registerEmail'), + showRegisterButton: requiredById('showRegisterButton'), + showLoginButton: requiredById('showLoginButton'), updatesSection: requiredById('updatesSection'), updatesToggle: requiredById('updatesToggle'), updatesPanel: requiredById('updatesPanel'), nicknameContainer: requiredById('nicknameContainer'), preconnectNickname: requiredById('preconnectNickname'), connectButton: requiredById('connectButton'), + logoutButton: requiredById('logoutButton'), disconnectButton: requiredById('disconnectButton'), focusGridButton: requiredById('focusGridButton'), settingsButton: requiredById('settingsButton'), @@ -203,6 +223,10 @@ let lastFocusedElement: Element | null = null; let lastAnnouncementText = ''; let lastAnnouncementAt = 0; let outputMode = settings.loadOutputMode(); +let authMode: 'login' | 'register' = 'login'; +let authSessionToken = settings.loadAuthSessionToken(); +let authUsername = settings.loadAuthUsername(); +let pendingAuthRequest = false; const messageBuffer: string[] = []; let messageCursor = -1; const radioRuntime = new RadioStationRuntime(audio, getItemSpatialConfig); @@ -482,14 +506,33 @@ function sanitizeName(value: string): string { return value.replace(/[\u0000-\u001F\u007F<>]/g, '').trim().slice(0, NICKNAME_MAX_LENGTH); } +/** Normalizes auth username according to server policy. */ +function sanitizeAuthUsername(value: string): string { + return value + .trim() + .toLowerCase() + .replace(/[^a-z0-9_-]/g, '') + .slice(0, 32); +} + /** Enables/disables the connect button based on state and nickname validity. */ function updateConnectAvailability(): void { + dom.logoutButton.disabled = !authSessionToken.trim() && !state.running; if (state.running) { dom.connectButton.disabled = true; + dom.loginView.classList.add('hidden'); + dom.registerView.classList.add('hidden'); return; } - const hasNickname = sanitizeName(dom.preconnectNickname.value).length > 0; - dom.connectButton.disabled = mediaSession.isConnecting() || !hasNickname; + dom.loginView.classList.toggle('hidden', authMode !== 'login'); + dom.registerView.classList.toggle('hidden', authMode !== 'register'); + const hasSessionToken = authSessionToken.trim().length > 0; + const hasLoginCredentials = + sanitizeAuthUsername(dom.authUsername.value).length >= 2 && dom.authPassword.value.trim().length >= 8; + const hasRegisterCredentials = + sanitizeAuthUsername(dom.registerUsername.value).length >= 2 && dom.registerPassword.value.trim().length >= 8; + const authReady = hasSessionToken || (authMode === 'login' ? hasLoginCredentials : hasRegisterCredentials); + dom.connectButton.disabled = mediaSession.isConnecting() || !authReady; } /** Restores persisted outbound effect levels from local storage. */ @@ -1294,6 +1337,112 @@ async function reconnectWithRetry(reason: 'heartbeat' | 'socketClose'): Promise< reconnectInFlight = false; } +/** Switches pre-connect auth view between login and register modes. */ +function setAuthMode(mode: 'login' | 'register'): void { + authMode = mode; + dom.loginView.classList.toggle('hidden', mode !== 'login'); + dom.registerView.classList.toggle('hidden', mode !== 'register'); + updateConnectAvailability(); +} + +/** Builds outbound auth packet from local token or active auth form. */ +function buildAuthRequestPacket(): OutgoingMessage | null { + const token = authSessionToken.trim(); + if (token) { + return { type: 'auth_resume', sessionToken: token }; + } + if (authMode === 'register') { + const username = sanitizeAuthUsername(dom.registerUsername.value); + const password = dom.registerPassword.value; + const email = dom.registerEmail.value.trim(); + if (!username || !password) return null; + return { type: 'auth_register', username, password, ...(email ? { email } : {}) }; + } + const username = sanitizeAuthUsername(dom.authUsername.value); + const password = dom.authPassword.value; + if (!username || !password) return null; + return { type: 'auth_login', username, password }; +} + +/** Sends current auth request over signaling websocket after socket open. */ +function sendAuthRequest(): void { + const packet = buildAuthRequestPacket(); + if (!packet) { + updateStatus('Enter username and password.'); + audio.sfxUiCancel(); + mediaSession.setConnecting(false); + updateConnectAvailability(); + signaling.disconnect(); + return; + } + pendingAuthRequest = true; + setConnectionStatus('Authenticating...'); + signaling.send(packet); +} + +/** Handles server auth-required prompts prior to world welcome. */ +function handleAuthRequired(message: string): void { + setConnectionStatus('Authentication required.'); + updateStatus(message); +} + +/** Applies auth result state and terminates failed auth attempts quickly. */ +async function handleAuthResult(message: Extract): Promise { + pendingAuthRequest = false; + if (!message.ok) { + dom.authPassword.value = ''; + dom.registerPassword.value = ''; + if (message.message.toLowerCase().includes('session')) { + authSessionToken = ''; + settings.saveAuthSessionToken(''); + } + updateStatus(message.message); + audio.sfxUiCancel(); + mediaSession.setConnecting(false); + updateConnectAvailability(); + signaling.disconnect(); + setConnectionStatus('Authentication failed.'); + return; + } + + if (message.sessionToken) { + authSessionToken = message.sessionToken; + settings.saveAuthSessionToken(message.sessionToken); + } + if (message.username) { + authUsername = message.username; + settings.saveAuthUsername(message.username); + dom.authUsername.value = message.username; + dom.registerUsername.value = message.username; + } + if (message.nickname) { + const resolved = sanitizeName(message.nickname); + if (resolved) { + state.player.nickname = resolved; + dom.preconnectNickname.value = resolved; + settings.saveNickname(resolved); + } + } + dom.authPassword.value = ''; + dom.registerPassword.value = ''; + setConnectionStatus('Authenticated. Joining world...'); +} + +/** Clears stored auth session and returns UI to login mode. */ +function logOutAccount(): void { + authSessionToken = ''; + authUsername = ''; + settings.saveAuthSessionToken(''); + settings.saveAuthUsername(''); + if (state.running) { + signaling.send({ type: 'auth_logout' }); + disconnect(); + } + setAuthMode('login'); + updateStatus('Logged out.'); + updateConnectAvailability(); +} + /** Builds dependencies shared by connect/disconnect flow helpers. */ function getConnectionFlowDeps(): ConnectFlowDeps { return { @@ -1322,6 +1471,7 @@ function getConnectionFlowDeps(): ConnectFlowDeps { mediaDescribeError: (error) => describeMediaError(error), mediaStopLocalMedia: () => stopLocalMedia(), signalingConnect: (handler) => signaling.connect(handler as (message: IncomingMessage) => Promise), + signalingSendAuth: () => sendAuthRequest(), signalingDisconnect: () => signaling.disconnect(), onMessage: (message) => onSignalingMessage(message as IncomingMessage), persistPlayerPosition, @@ -1423,6 +1573,8 @@ const onAppMessage = createOnMessageHandler({ playIncomingItemUseSound: (url, x, y) => { void audio.playSpatialSample(url, { x, y }, { x: state.player.x, y: state.player.y }, 1); }, + handleAuthRequired, + handleAuthResult, }); /** Handles signaling packets with heartbeat/restart metadata before app-level dispatch. */ @@ -2429,20 +2581,51 @@ function setupUiHandlers(): void { persistPlayerPosition(); }, }); + dom.showRegisterButton.addEventListener('click', () => { + setAuthMode('register'); + dom.registerUsername.focus(); + }); + dom.showLoginButton.addEventListener('click', () => { + setAuthMode('login'); + dom.authUsername.focus(); + }); + dom.logoutButton.addEventListener('click', () => { + logOutAccount(); + }); + dom.authUsername.addEventListener('input', () => { + dom.authUsername.value = sanitizeAuthUsername(dom.authUsername.value); + updateConnectAvailability(); + }); + dom.authPassword.addEventListener('input', () => { + updateConnectAvailability(); + }); + dom.registerUsername.addEventListener('input', () => { + dom.registerUsername.value = sanitizeAuthUsername(dom.registerUsername.value); + updateConnectAvailability(); + }); + dom.registerPassword.addEventListener('input', () => { + updateConnectAvailability(); + }); + dom.registerEmail.addEventListener('input', () => { + updateConnectAvailability(); + }); } setupInputHandlers(); setupUiHandlers(); +dom.authUsername.value = sanitizeAuthUsername(authUsername); +dom.registerUsername.value = sanitizeAuthUsername(authUsername); const storedNickname = sanitizeName(settings.loadNickname()); dom.preconnectNickname.value = storedNickname; if (storedNickname) { state.player.nickname = storedNickname; } +setAuthMode('login'); updateConnectAvailability(); updateDeviceSummary(); updateStatus( isVersionReloadedSession() ? 'Client updated, please reconnect.' - : 'Welcome to the Chat Grid. Press the Settings button to configure your audio, then Connect to join the grid.', + : 'Welcome to the Chat Grid. Log in or register, configure audio if needed, then Connect to join the grid.', ); setConnectionStatus(isVersionReloadedSession() ? 'Client updated, please reconnect.' : 'Not connected.'); diff --git a/client/src/network/messageHandlers.ts b/client/src/network/messageHandlers.ts index 6cfe5bf..7a7ff78 100644 --- a/client/src/network/messageHandlers.ts +++ b/client/src/network/messageHandlers.ts @@ -73,6 +73,8 @@ type MessageHandlerDeps = { playLocateToneAt: (x: number, y: number) => void; resolveIncomingSoundUrl: (url: string) => string; playIncomingItemUseSound: (url: string, x: number, y: number) => void; + handleAuthRequired: (message: string) => void; + handleAuthResult: (message: Extract) => Promise; }; /** @@ -81,6 +83,14 @@ type MessageHandlerDeps = { export function createOnMessageHandler(deps: MessageHandlerDeps): (message: IncomingMessage) => Promise { return async function onMessage(message: IncomingMessage): Promise { switch (message.type) { + case 'auth_required': + deps.handleAuthRequired(message.message); + break; + + case 'auth_result': + await deps.handleAuthResult(message); + break; + case 'welcome': if (message.worldConfig?.gridSize && Number.isInteger(message.worldConfig.gridSize) && message.worldConfig.gridSize > 0) { deps.setWorldGridSize(message.worldConfig.gridSize); diff --git a/client/src/network/protocol.ts b/client/src/network/protocol.ts index 620583b..1a1290f 100644 --- a/client/src/network/protocol.ts +++ b/client/src/network/protocol.ts @@ -48,6 +48,14 @@ export const welcomeMessageSchema = z.object({ version: z.string().optional(), }) .optional(), + auth: z + .object({ + authenticated: z.boolean(), + userId: z.string().nullable().optional(), + username: z.string().nullable().optional(), + role: z.string().nullable().optional(), + }) + .optional(), uiDefinitions: z .object({ itemTypeOrder: z.array(z.string().min(1)), @@ -85,6 +93,21 @@ export const welcomeMessageSchema = z.object({ .optional(), }); +export const authRequiredSchema = z.object({ + type: z.literal('auth_required'), + message: z.string(), +}); + +export const authResultSchema = z.object({ + type: z.literal('auth_result'), + ok: z.boolean(), + message: z.string(), + sessionToken: z.string().optional(), + username: z.string().optional(), + role: z.string().optional(), + nickname: z.string().optional(), +}); + export const signalMessageSchema = z.object({ type: z.literal('signal'), senderId: z.string(), @@ -203,6 +226,8 @@ export const itemPianoStatusSchema = z.object({ }); export const incomingMessageSchema = z.discriminatedUnion('type', [ + authRequiredSchema, + authResultSchema, welcomeMessageSchema, signalMessageSchema, updatePositionSchema, @@ -223,6 +248,10 @@ export const incomingMessageSchema = z.discriminatedUnion('type', [ export type IncomingMessage = z.infer; export type OutgoingMessage = + | { type: 'auth_register'; username: string; password: string; email?: string } + | { type: 'auth_login'; username: string; password: string } + | { type: 'auth_resume'; sessionToken: string } + | { type: 'auth_logout' } | { type: 'signal'; targetId: string; sdp?: RTCSessionDescriptionInit; ice?: RTCIceCandidateInit } | { type: 'update_position'; x: number; y: number } | { type: 'teleport_complete'; x: number; y: number } diff --git a/client/src/session/connectionFlow.ts b/client/src/session/connectionFlow.ts index 577f9af..d33b244 100644 --- a/client/src/session/connectionFlow.ts +++ b/client/src/session/connectionFlow.ts @@ -29,6 +29,7 @@ export type ConnectFlowDeps = { mediaDescribeError: (error: unknown) => string; mediaStopLocalMedia: () => void; signalingConnect: (onMessage: (message: unknown) => Promise) => Promise; + signalingSendAuth: () => void; signalingDisconnect: () => void; onMessage: (message: unknown) => Promise; persistPlayerPosition: () => void; @@ -46,14 +47,11 @@ export async function runConnectFlow(deps: ConnectFlowDeps): Promise { return; } const nickname = deps.sanitizeName(deps.dom.preconnectNickname.value); - if (!nickname) { - deps.updateStatus('Nickname is required.'); - deps.updateConnectAvailability(); - return; - } - deps.state.player.nickname = nickname; + deps.state.player.nickname = nickname || deps.state.player.nickname; deps.dom.preconnectNickname.value = nickname; - deps.settingsSaveNickname(nickname); + if (nickname) { + deps.settingsSaveNickname(nickname); + } deps.mediaSetConnecting(true); deps.updateConnectAvailability(); @@ -85,6 +83,7 @@ export async function runConnectFlow(deps: ConnectFlowDeps): Promise { try { await deps.signalingConnect(deps.onMessage); + deps.signalingSendAuth(); window.setTimeout(() => { if (deps.state.running || !deps.mediaIsConnecting()) { return; diff --git a/client/src/settings/settingsStore.ts b/client/src/settings/settingsStore.ts index e5f9cd7..d248a9a 100644 --- a/client/src/settings/settingsStore.ts +++ b/client/src/settings/settingsStore.ts @@ -11,6 +11,8 @@ const MIC_INPUT_GAIN_STORAGE_KEY = 'chatGridMicInputGain'; const MASTER_VOLUME_STORAGE_KEY = 'chatGridMasterVolume'; const PEER_LISTEN_GAINS_STORAGE_KEY = 'chatGridPeerListenGains'; const NICKNAME_STORAGE_KEY = 'spatialChatNickname'; +const AUTH_SESSION_TOKEN_STORAGE_KEY = 'chatGridAuthSessionToken'; +const AUTH_USERNAME_STORAGE_KEY = 'chatGridAuthUsername'; type DevicePreference = { id: string; @@ -113,6 +115,30 @@ export class SettingsStore { localStorage.setItem(NICKNAME_STORAGE_KEY, value); } + loadAuthSessionToken(): string { + return localStorage.getItem(AUTH_SESSION_TOKEN_STORAGE_KEY) || ''; + } + + saveAuthSessionToken(token: string): void { + if (token) { + localStorage.setItem(AUTH_SESSION_TOKEN_STORAGE_KEY, token); + return; + } + localStorage.removeItem(AUTH_SESSION_TOKEN_STORAGE_KEY); + } + + loadAuthUsername(): string { + return localStorage.getItem(AUTH_USERNAME_STORAGE_KEY) || ''; + } + + saveAuthUsername(username: string): void { + if (username) { + localStorage.setItem(AUTH_USERNAME_STORAGE_KEY, username); + return; + } + localStorage.removeItem(AUTH_USERNAME_STORAGE_KEY); + } + loadOutputMode(): 'mono' | 'stereo' { return localStorage.getItem(AUDIO_OUTPUT_MODE_STORAGE_KEY) === 'mono' ? 'mono' : 'stereo'; } diff --git a/client/src/styles.css b/client/src/styles.css index 8374215..8993308 100644 --- a/client/src/styles.css +++ b/client/src/styles.css @@ -79,6 +79,44 @@ body { margin-bottom: 0.75rem; } +.auth-panel { + width: min(460px, 95vw); + margin: 0 auto 0.75rem; + padding: 0.65rem; + border: 1px solid #334155; + border-radius: 0.6rem; + background: rgb(17 24 39 / 60%); +} + +.auth-panel h2 { + margin: 0 0 0.5rem; + font-size: 1rem; + color: #cbd5e1; +} + +.auth-row { + display: flex; + align-items: center; + justify-content: center; + gap: 0.5rem; + margin: 0.35rem 0; +} + +.auth-row label { + color: #cbd5e1; + min-width: 135px; + text-align: right; +} + +.auth-row input { + background: #111827; + color: #e5e7eb; + border: 1px solid #334155; + border-radius: 0.5rem; + padding: 0.4rem 0.6rem; + width: min(280px, 60vw); +} + .nickname-row { display: flex; justify-content: center; diff --git a/deploy/README.md b/deploy/README.md index befabdd..f61f2f7 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -37,6 +37,7 @@ Notes: This creates: - `/home/bestmidi/chgrid/server/.venv` - `/home/bestmidi/chgrid/server/config.toml` (if missing) +- `/home/bestmidi/chgrid/server/.env` with `CHGRID_AUTH_SECRET` (if missing) Edit `/home/bestmidi/chgrid/server/config.toml`: - `server.bind_ip = "127.0.0.1"` @@ -45,6 +46,7 @@ Edit `/home/bestmidi/chgrid/server/config.toml`: - `tls.cert_file = ""` - `tls.key_file = ""` - `storage.state_file = "runtime/items.json"` +- `auth.db_file = "runtime/chatgrid.db"` ## 4) Build and publish client diff --git a/deploy/scripts/install_server.sh b/deploy/scripts/install_server.sh index fa2432b..f166f2b 100755 --- a/deploy/scripts/install_server.sh +++ b/deploy/scripts/install_server.sh @@ -49,5 +49,17 @@ fi mkdir -p runtime +if [[ ! -f .env ]]; then + AUTH_SECRET="$( + python3 - <<'PY' +import secrets +print(secrets.token_urlsafe(64)) +PY + )" + printf "CHGRID_AUTH_SECRET=%s\n" "$AUTH_SECRET" > .env + chmod 600 .env + echo "created $SERVER_DIR/.env with CHGRID_AUTH_SECRET" +fi + echo "server install complete" echo "next: edit $SERVER_DIR/config.toml (TLS, bind_ip, port)" diff --git a/deploy/systemd/chat-grid.service b/deploy/systemd/chat-grid.service index 0c0fcc5..df30cf2 100644 --- a/deploy/systemd/chat-grid.service +++ b/deploy/systemd/chat-grid.service @@ -8,6 +8,7 @@ User=bestmidi Group=bestmidi WorkingDirectory=/home/bestmidi/chgrid/server Environment=PATH=/home/bestmidi/chgrid/server/.venv/bin:/usr/bin:/bin +EnvironmentFile=-/home/bestmidi/chgrid/server/.env ExecStartPre=/usr/bin/mkdir -p /home/bestmidi/chgrid/server/runtime ExecStart=/home/bestmidi/chgrid/server/.venv/bin/python main.py --config /home/bestmidi/chgrid/server/config.toml StandardOutput=append:/home/bestmidi/chgrid/server/runtime/server.log diff --git a/docs/protocol-notes.md b/docs/protocol-notes.md index f270697..4da6cd6 100644 --- a/docs/protocol-notes.md +++ b/docs/protocol-notes.md @@ -10,6 +10,10 @@ This is a behavior guide for packet semantics beyond raw schemas. ## Client -> Server +- `auth_register`: create account with username/password and optional email. +- `auth_login`: authenticate with username/password. +- `auth_resume`: resume prior session via stored session token. +- `auth_logout`: revoke current session and disconnect. - `update_position`: client movement intent; server enforces world bounds and movement rate policy. - `teleport_complete`: client signals teleport landing; server rebroadcasts spatial landing cue. - `update_nickname`: nickname change request (server enforces uniqueness). @@ -21,6 +25,8 @@ This is a behavior guide for packet semantics beyond raw schemas. ## Server -> Client +- `auth_required`: authentication challenge after websocket connect. +- `auth_result`: auth success/failure and session/account metadata. - `welcome`: initial snapshot with users/items plus server UI/world metadata. - `signal`: forwarded WebRTC offer/answer/ICE. - `update_position`, `update_nickname`, `user_left`: presence updates. @@ -51,6 +57,11 @@ This is a behavior guide for packet semantics beyond raw schemas. ## Welcome Metadata +- `welcome.auth`: authenticated account identity: + - `authenticated` + - `userId` + - `username` + - `role` - `welcome.worldConfig.gridSize`: server-authoritative grid size used by clients for bounds/drawing. - `welcome.worldConfig.movementTickMs`: server movement-rate window used for client movement pacing. - `welcome.worldConfig.movementMaxStepsPerTick`: max allowed grid steps per movement window. diff --git a/docs/runtime-flow.md b/docs/runtime-flow.md index fa2703a..a99d15e 100644 --- a/docs/runtime-flow.md +++ b/docs/runtime-flow.md @@ -3,10 +3,13 @@ ## Connect Flow 1. User clicks connect. -2. Client validates nickname and sets up local media. +2. Client validates auth form/session token and sets up local media. 3. Client connects signaling websocket. -4. Server sends `welcome` with users/items snapshot. -5. Client: +4. Server sends `auth_required`. +5. Client sends `auth_login`, `auth_register`, or `auth_resume`. +6. Server sends `auth_result`. +7. Server sends `welcome` with users/items snapshot. +8. Client: - applies `welcome.worldConfig.gridSize` for authoritative grid bounds/rendering - applies `welcome.worldConfig.movementTickMs` as movement pacing guidance - applies `welcome.worldConfig.movementMaxStepsPerTick` for movement-rate parity @@ -38,6 +41,8 @@ Each frame: Core incoming message effects: - `signal`: WebRTC negotiation and ICE exchange. +- `auth_required`: prompt client to authenticate before gameplay messages. +- `auth_result`: auth success/failure with optional session token + account metadata. - `update_position`: update peer position; may play movement/teleport world sound. - `teleport_complete`: play peer teleport landing sound at final tile. - `update_nickname`: update peer display name. diff --git a/server/app/auth_service.py b/server/app/auth_service.py new file mode 100644 index 0000000..5a507cb --- /dev/null +++ b/server/app/auth_service.py @@ -0,0 +1,397 @@ +"""Account and session persistence service for websocket authentication.""" + +from __future__ import annotations + +from dataclasses import dataclass +import base64 +import hashlib +import hmac +import os +from pathlib import Path +import re +import secrets +import sqlite3 +import time +import uuid + + +SESSION_TTL_MS = 14 * 24 * 60 * 60 * 1000 +SALT_BYTES = 16 +PBKDF2_ITERATIONS = 310_000 +PBKDF2_DKLEN = 32 +USERNAME_PATTERN = re.compile(r"^[a-z0-9_-]+$") + + +@dataclass(frozen=True) +class AuthUser: + """Authenticated account identity details.""" + + id: str + username: str + role: str + status: str + email: str | None + last_nickname: str | None + + +@dataclass(frozen=True) +class AuthSession: + """Session validation result with user identity.""" + + session_id: str + token: str + user: AuthUser + + +class AuthError(ValueError): + """Raised when authentication input or policy checks fail.""" + + +class AuthService: + """Manages account registration, login, and rolling session validation.""" + + def __init__( + self, + db_path: Path, + token_hash_secret: str, + password_min_length: int, + password_max_length: int, + username_min_length: int, + username_max_length: int, + ): + """Initialize auth database connection and schema.""" + + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.password_min_length = max(1, int(password_min_length)) + self.password_max_length = max(self.password_min_length, int(password_max_length)) + self.username_min_length = max(1, int(username_min_length)) + self.username_max_length = max(self.username_min_length, int(username_max_length)) + secret = token_hash_secret.strip() + if not secret: + raise AuthError("CHGRID_AUTH_SECRET is required when auth is enabled.") + self._token_secret = secret.encode("utf-8") + self._conn = sqlite3.connect(self.db_path, check_same_thread=False) + self._conn.row_factory = sqlite3.Row + self._ensure_schema() + + def close(self) -> None: + """Close the underlying SQLite connection.""" + + self._conn.close() + + def bootstrap_admin(self, username: str, password: str, email: str | None = None) -> AuthUser: + """Create the first admin account, or fail if one already exists.""" + + existing = self._conn.execute("SELECT 1 FROM users WHERE role = 'admin' LIMIT 1").fetchone() + if existing is not None: + raise AuthError("An admin account already exists.") + created = self.register(username, password, email=email, role="admin") + return created.user + + def register( + self, + username: str, + password: str, + *, + email: str | None = None, + role: str = "user", + ) -> AuthSession: + """Register an account and issue a session token.""" + + normalized_username = self._normalize_username(username) + self._validate_username(normalized_username) + self._validate_password(password) + normalized_email = self._normalize_email(email) + if role not in {"user", "admin"}: + raise AuthError("role must be user or admin.") + now_ms = self.now_ms() + user_id = str(uuid.uuid4()) + password_hash = self._hash_password(password) + try: + self._conn.execute( + """ + INSERT INTO users ( + id, username, password_hash, email, role, status, last_nickname, created_at_ms, updated_at_ms, last_login_at_ms + ) VALUES (?, ?, ?, ?, ?, 'active', NULL, ?, ?, ?) + """, + (user_id, normalized_username, password_hash, normalized_email, role, now_ms, now_ms, now_ms), + ) + self._conn.commit() + except sqlite3.IntegrityError as exc: + message = str(exc).lower() + if "users.username" in message: + raise AuthError("Username is already taken.") from exc + if "users.email" in message: + raise AuthError("Email is already in use.") from exc + raise + user = self._get_user_by_username(normalized_username) + if user is None: + raise AuthError("Failed to load newly created user.") + return self._create_session(user) + + def login(self, username: str, password: str) -> AuthSession: + """Authenticate credentials and issue a fresh session.""" + + normalized_username = self._normalize_username(username) + user_row = self._conn.execute( + """ + SELECT id, username, password_hash, email, role, status, last_nickname + FROM users + WHERE username = ? + """, + (normalized_username,), + ).fetchone() + if user_row is None: + raise AuthError("Invalid username or password.") + if user_row["status"] != "active": + raise AuthError("Account is disabled.") + if not self._verify_password(password, user_row["password_hash"]): + raise AuthError("Invalid username or password.") + user = self._row_to_user(user_row) + self._conn.execute( + "UPDATE users SET last_login_at_ms = ?, updated_at_ms = ? WHERE id = ?", + (self.now_ms(), self.now_ms(), user.id), + ) + self._conn.commit() + return self._create_session(user) + + def resume(self, token: str) -> AuthSession: + """Validate a session token and apply rolling expiry.""" + + cleaned = token.strip() + if not cleaned: + raise AuthError("Missing session token.") + token_hash = self._hash_token(cleaned) + row = self._conn.execute( + """ + SELECT s.id AS session_id, s.user_id, s.expires_at_ms, s.revoked_at_ms, + u.username, u.role, u.status, u.email, u.last_nickname + FROM sessions s + JOIN users u ON u.id = s.user_id + WHERE s.token_hash = ? + """, + (token_hash,), + ).fetchone() + if row is None: + raise AuthError("Invalid session.") + if row["revoked_at_ms"] is not None: + raise AuthError("Session has been revoked.") + now_ms = self.now_ms() + if int(row["expires_at_ms"]) <= now_ms: + self._conn.execute("UPDATE sessions SET revoked_at_ms = ? WHERE id = ?", (now_ms, row["session_id"])) + self._conn.commit() + raise AuthError("Session has expired.") + if row["status"] != "active": + raise AuthError("Account is disabled.") + new_expiry = now_ms + SESSION_TTL_MS + self._conn.execute( + "UPDATE sessions SET last_seen_at_ms = ?, expires_at_ms = ? WHERE id = ?", + (now_ms, new_expiry, row["session_id"]), + ) + self._conn.commit() + user = AuthUser( + id=row["user_id"], + username=row["username"], + role=row["role"], + status=row["status"], + email=row["email"], + last_nickname=row["last_nickname"], + ) + return AuthSession(session_id=row["session_id"], token=cleaned, user=user) + + def revoke(self, token: str) -> None: + """Revoke a session token if it exists.""" + + cleaned = token.strip() + if not cleaned: + return + token_hash = self._hash_token(cleaned) + self._conn.execute( + "UPDATE sessions SET revoked_at_ms = ? WHERE token_hash = ? AND revoked_at_ms IS NULL", + (self.now_ms(), token_hash), + ) + self._conn.commit() + + def set_last_nickname(self, user_id: str, nickname: str) -> None: + """Persist the most recent nickname for one user.""" + + cleaned = nickname.strip() + if not cleaned: + return + self._conn.execute( + "UPDATE users SET last_nickname = ?, updated_at_ms = ? WHERE id = ?", + (cleaned, self.now_ms(), user_id), + ) + self._conn.commit() + + @staticmethod + def now_ms() -> int: + """Return unix epoch timestamp in milliseconds.""" + + return int(time.time() * 1000) + + def _ensure_schema(self) -> None: + """Create required auth tables and indexes when missing.""" + + self._conn.execute("PRAGMA foreign_keys = ON") + self._conn.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + email TEXT UNIQUE, + role TEXT NOT NULL CHECK(role IN ('user', 'admin')) DEFAULT 'user', + status TEXT NOT NULL CHECK(status IN ('active', 'disabled')) DEFAULT 'active', + last_nickname TEXT, + created_at_ms INTEGER NOT NULL, + updated_at_ms INTEGER NOT NULL, + last_login_at_ms INTEGER + ) + """ + ) + self._conn.execute( + """ + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + created_at_ms INTEGER NOT NULL, + last_seen_at_ms INTEGER NOT NULL, + expires_at_ms INTEGER NOT NULL, + revoked_at_ms INTEGER, + ip TEXT, + user_agent TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + ) + """ + ) + self._conn.execute("CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)") + self._conn.execute( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email) WHERE email IS NOT NULL" + ) + self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id)") + self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at_ms)") + self._conn.execute("CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash)") + self._conn.commit() + + def _create_session(self, user: AuthUser) -> AuthSession: + """Issue and persist a new session token for a user.""" + + token = secrets.token_urlsafe(48) + token_hash = self._hash_token(token) + now_ms = self.now_ms() + expires_at_ms = now_ms + SESSION_TTL_MS + session_id = str(uuid.uuid4()) + self._conn.execute( + """ + INSERT INTO sessions (id, user_id, token_hash, created_at_ms, last_seen_at_ms, expires_at_ms, revoked_at_ms, ip, user_agent) + VALUES (?, ?, ?, ?, ?, ?, NULL, NULL, NULL) + """, + (session_id, user.id, token_hash, now_ms, now_ms, expires_at_ms), + ) + self._conn.commit() + return AuthSession(session_id=session_id, token=token, user=user) + + def _get_user_by_username(self, username: str) -> AuthUser | None: + """Fetch one user by normalized username.""" + + row = self._conn.execute( + "SELECT id, username, role, status, email, last_nickname FROM users WHERE username = ?", + (username,), + ).fetchone() + if row is None: + return None + return self._row_to_user(row) + + @staticmethod + def _row_to_user(row: sqlite3.Row) -> AuthUser: + """Convert a DB row into AuthUser.""" + + return AuthUser( + id=row["id"], + username=row["username"], + role=row["role"], + status=row["status"], + email=row["email"], + last_nickname=row["last_nickname"], + ) + + @staticmethod + def _normalize_username(username: str) -> str: + """Normalize username into canonical stored form.""" + + return username.strip().lower() + + @staticmethod + def _normalize_email(email: str | None) -> str | None: + """Normalize optional email and collapse blanks to None.""" + + if email is None: + return None + cleaned = email.strip().lower() + return cleaned or None + + def _validate_username(self, username: str) -> None: + """Validate username against length and character policy.""" + + if not (self.username_min_length <= len(username) <= self.username_max_length): + raise AuthError( + f"Username must be between {self.username_min_length} and {self.username_max_length} characters." + ) + if USERNAME_PATTERN.fullmatch(username) is None: + raise AuthError("Username may include lowercase letters, numbers, underscores, and dashes only.") + + def _validate_password(self, password: str) -> None: + """Validate password length policy.""" + + if not (self.password_min_length <= len(password) <= self.password_max_length): + raise AuthError( + f"Password must be between {self.password_min_length} and {self.password_max_length} characters." + ) + + @staticmethod + def _hash_password(password: str) -> str: + """Hash a password with PBKDF2-HMAC-SHA256 and random salt.""" + + salt = os.urandom(SALT_BYTES) + digest = hashlib.pbkdf2_hmac( + "sha256", + password.encode("utf-8"), + salt, + PBKDF2_ITERATIONS, + dklen=PBKDF2_DKLEN, + ) + salt_b64 = base64.b64encode(salt).decode("ascii") + digest_b64 = base64.b64encode(digest).decode("ascii") + return f"pbkdf2_sha256${PBKDF2_ITERATIONS}${salt_b64}${digest_b64}" + + @staticmethod + def _verify_password(password: str, stored: str) -> bool: + """Verify plaintext password against stored PBKDF2 hash.""" + + try: + algo, iterations_raw, salt_b64, digest_b64 = stored.split("$", 3) + except ValueError: + return False + if algo != "pbkdf2_sha256": + return False + try: + salt = base64.b64decode(salt_b64.encode("ascii")) + expected = base64.b64decode(digest_b64.encode("ascii")) + computed = hashlib.pbkdf2_hmac( + "sha256", + password.encode("utf-8"), + salt, + int(iterations_raw), + dklen=len(expected), + ) + except (ValueError, TypeError): + return False + return hmac.compare_digest(computed, expected) + + def _hash_token(self, token: str) -> str: + """Hash a session token with server secret before persistence.""" + + return hmac.new(self._token_secret, token.encode("utf-8"), hashlib.sha256).hexdigest() diff --git a/server/app/client.py b/server/app/client.py index 2236f8a..8f3f897 100644 --- a/server/app/client.py +++ b/server/app/client.py @@ -13,6 +13,11 @@ class ClientConnection: websocket: ServerConnection id: str + authenticated: bool = False + user_id: str | None = None + username: str | None = None + role: str = "user" + session_token: str | None = None nickname: str = "user..." x: int = 20 y: int = 20 diff --git a/server/app/config.py b/server/app/config.py index b814114..ed4865b 100644 --- a/server/app/config.py +++ b/server/app/config.py @@ -49,6 +49,16 @@ class WorldConfigSection(BaseModel): grid_size: int = Field(default=41, ge=1) +class AuthConfigSection(BaseModel): + """Authentication persistence and validation settings.""" + + db_file: str = "runtime/chatgrid.db" + password_min_length: int = Field(default=8, ge=1) + password_max_length: int = Field(default=32, ge=1) + username_min_length: int = Field(default=2, ge=1) + username_max_length: int = Field(default=32, ge=1) + + class AppConfig(BaseModel): """Top-level application configuration document.""" @@ -58,6 +68,7 @@ class AppConfig(BaseModel): logging: LoggingConfigSection = LoggingConfigSection() storage: StorageConfigSection = StorageConfigSection() world: WorldConfigSection = WorldConfigSection() + auth: AuthConfigSection = AuthConfigSection() def load_config(path: Path | None) -> AppConfig: diff --git a/server/app/item_service.py b/server/app/item_service.py index 91f4b74..94e2027 100644 --- a/server/app/item_service.py +++ b/server/app/item_service.py @@ -45,7 +45,7 @@ class ItemService: title=item_def.default_title, x=client.x, y=client.y, - createdBy=client.id, + createdBy=client.username or client.nickname or client.id, createdAt=now, updatedAt=now, version=1, diff --git a/server/app/models.py b/server/app/models.py index e6d87ef..e651cbd 100644 --- a/server/app/models.py +++ b/server/app/models.py @@ -41,6 +41,28 @@ class ChatMessagePacket(BasePacket): message: str = Field(min_length=1, max_length=500) +class AuthRegisterPacket(BasePacket): + type: Literal["auth_register"] + username: str = Field(min_length=1, max_length=128) + password: str = Field(min_length=1, max_length=256) + email: str | None = Field(default=None, max_length=320) + + +class AuthLoginPacket(BasePacket): + type: Literal["auth_login"] + username: str = Field(min_length=1, max_length=128) + password: str = Field(min_length=1, max_length=256) + + +class AuthResumePacket(BasePacket): + type: Literal["auth_resume"] + sessionToken: str = Field(min_length=1, max_length=512) + + +class AuthLogoutPacket(BasePacket): + type: Literal["auth_logout"] + + class PingPacket(BasePacket): type: Literal["ping"] clientSentAt: int @@ -100,6 +122,10 @@ ClientPacket = ( | TeleportCompletePacket | UpdateNicknamePacket | ChatMessagePacket + | AuthRegisterPacket + | AuthLoginPacket + | AuthResumePacket + | AuthLogoutPacket | PingPacket | ItemAddPacket | ItemPickupPacket @@ -128,6 +154,22 @@ class WelcomePacket(BasePacket): worldConfig: dict | None = None uiDefinitions: dict | None = None serverInfo: dict | None = None + auth: dict | None = None + + +class AuthRequiredPacket(BasePacket): + type: Literal["auth_required"] + message: str + + +class AuthResultPacket(BasePacket): + type: Literal["auth_result"] + ok: bool + message: str + sessionToken: str | None = None + username: str | None = None + role: str | None = None + nickname: str | None = None class UserLeftPacket(BasePacket): diff --git a/server/app/server.py b/server/app/server.py index 7703af5..5f93373 100644 --- a/server/app/server.py +++ b/server/app/server.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse import asyncio from datetime import datetime +from getpass import getpass from importlib.metadata import PackageNotFoundError, version as package_version import json import logging @@ -21,6 +22,7 @@ from zoneinfo import ZoneInfo from pydantic import ValidationError, TypeAdapter from websockets.asyncio.server import ServerConnection, serve +from .auth_service import AuthError, AuthService from .client import ClientConnection from .config import load_config from .item_catalog import ( @@ -39,6 +41,12 @@ from .item_catalog import ( from .item_type_handlers import get_item_type_handler from .item_service import ItemService from .models import ( + AuthLoginPacket, + AuthLogoutPacket, + AuthRegisterPacket, + AuthRequiredPacket, + AuthResultPacket, + AuthResumePacket, BroadcastChatMessagePacket, BroadcastNicknamePacket, BroadcastPositionPacket, @@ -91,6 +99,12 @@ class SignalingServer: port: int, ssl_cert: str | None, ssl_key: str | None, + auth_db_path: Path | None = None, + auth_token_hash_secret: str = "dev-secret", + password_min_length: int = 8, + password_max_length: int = 32, + username_min_length: int = 2, + username_max_length: int = 32, max_message_size: int = 2_000_000, state_file: Path | None = None, grid_size: int = 41, @@ -104,6 +118,15 @@ class SignalingServer: self.max_message_size = max_message_size self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key) self.clients: dict[ServerConnection, ClientConnection] = {} + resolved_auth_db_path = auth_db_path or Path.cwd() / "runtime" / f"chatgrid_auth_{uuid.uuid4().hex}.db" + self.auth_service = AuthService( + db_path=resolved_auth_db_path, + token_hash_secret=auth_token_hash_secret, + password_min_length=password_min_length, + password_max_length=password_max_length, + username_min_length=username_min_length, + username_max_length=username_max_length, + ) self.item_service = ItemService(state_file=state_file) self.item_last_use_ms: dict[str, int] = {} self.active_piano_keys_by_client: dict[str, set[str]] = {} @@ -714,22 +737,19 @@ class SignalingServer: await asyncio.Future() finally: self._flush_state_save() + self.auth_service.close() async def _handle_client(self, websocket: ServerConnection) -> None: """Handle one websocket client's connect/message/disconnect lifecycle.""" client = ClientConnection(websocket=websocket, id=str(uuid.uuid4())) - client.x = random.randrange(self.grid_size) - client.y = random.randrange(self.grid_size) - now_ms = self.item_service.now_ms() - client.last_position_update_ms = now_ms - client.movement_window_index = self._movement_window_index(now_ms) - client.movement_window_steps_used = 0 - self.clients[websocket] = client - LOGGER.info("client connected id=%s total=%d", client.id, len(self.clients)) + LOGGER.info("websocket opened id=%s", client.id) try: - await self._send_welcome(client) + await self._send( + websocket, + AuthRequiredPacket(type="auth_required", message="Authentication required."), + ) async for raw_message in websocket: await self._handle_message(client, raw_message) finally: @@ -780,9 +800,101 @@ class SignalingServer: }, uiDefinitions=self._build_ui_definitions(), serverInfo={"instanceId": self.instance_id, "version": self.server_version}, + auth={ + "authenticated": client.authenticated, + "userId": client.user_id, + "username": client.username, + "role": client.role if client.authenticated else None, + }, ) await self._send(client.websocket, packet) + async def _activate_authenticated_client(self, client: ClientConnection) -> None: + """Move an authenticated websocket client into the active world roster.""" + + if client.websocket in self.clients: + return + client.x = random.randrange(self.grid_size) + client.y = random.randrange(self.grid_size) + now_ms = self.item_service.now_ms() + client.last_position_update_ms = now_ms + client.movement_window_index = self._movement_window_index(now_ms) + client.movement_window_steps_used = 0 + self.clients[client.websocket] = client + LOGGER.info( + "client authenticated id=%s user_id=%s username=%s total=%d", + client.id, + client.user_id, + client.username, + len(self.clients), + ) + await self._send_welcome(client) + await self._broadcast( + BroadcastChatMessagePacket( + type="chat_message", + message=f"{client.nickname} has logged in.", + system=True, + ), + exclude=client.websocket, + ) + + async def _handle_auth_packet(self, client: ClientConnection, packet: ClientPacket) -> bool: + """Handle pre-auth packets; returns True when packet was an auth command.""" + + if client.authenticated and isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)): + await self._send( + client.websocket, + AuthResultPacket(type="auth_result", ok=False, message="Already authenticated."), + ) + return True + + try: + if isinstance(packet, AuthRegisterPacket): + session = self.auth_service.register(packet.username, packet.password, email=packet.email) + elif isinstance(packet, AuthLoginPacket): + session = self.auth_service.login(packet.username, packet.password) + elif isinstance(packet, AuthResumePacket): + session = self.auth_service.resume(packet.sessionToken) + elif isinstance(packet, AuthLogoutPacket): + if client.session_token: + self.auth_service.revoke(client.session_token) + client.session_token = None + await self._send( + client.websocket, + AuthResultPacket(type="auth_result", ok=True, message="Logged out."), + ) + await client.websocket.close() + return True + else: + return False + except AuthError as exc: + await self._send( + client.websocket, + AuthResultPacket(type="auth_result", ok=False, message=str(exc)), + ) + return True + + client.authenticated = True + client.user_id = session.user.id + client.username = session.user.username + client.role = session.user.role + client.session_token = session.token + client.nickname = session.user.last_nickname or client.nickname + await self._send( + client.websocket, + AuthResultPacket( + type="auth_result", + ok=True, + message="Authenticated.", + sessionToken=session.token, + username=session.user.username, + role=session.user.role, + nickname=client.nickname, + ), + ) + await self._activate_authenticated_client(client) + return True + def _build_ui_definitions(self) -> dict: """Build server-owned UI definitions for item/menu rendering.""" @@ -840,6 +952,22 @@ class SignalingServer: PACKET_LOGGER.warning("invalid packet from id=%s: %s", client.id, exc) return + # Compatibility path for local tests injecting pre-authenticated clients + # directly into server.clients without running websocket auth handshake. + if not client.authenticated and client.websocket in self.clients: + client.authenticated = True + client.user_id = client.user_id or client.id + client.username = client.username or client.nickname + + if await self._handle_auth_packet(client, packet): + return + if not client.authenticated: + await self._send( + client.websocket, + AuthResultPacket(type="auth_result", ok=False, message="Authenticate before sending gameplay actions."), + ) + return + if isinstance(packet, UpdatePositionPacket): if not self._is_in_bounds(packet.x, packet.y): PACKET_LOGGER.warning( @@ -975,6 +1103,8 @@ class SignalingServer: ) return client.nickname = requested_nickname + if client.user_id: + self.auth_service.set_last_nickname(client.user_id, client.nickname) if old_nickname == "user...": LOGGER.info("user login id=%s nickname=%s", client.id, client.nickname) else: @@ -1471,6 +1601,7 @@ def run() -> None: parser.add_argument("--ssl-cert", default=None) parser.add_argument("--ssl-key", default=None) parser.add_argument("--allow-insecure-ws", action="store_true", default=None) + parser.add_argument("--bootstrap-admin", action="store_true", default=False) args = parser.parse_args() config_path = Path(args.config) if args.config else None @@ -1499,15 +1630,51 @@ def run() -> None: "TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml." ) + auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip() + if not auth_secret: + raise SystemExit("CHGRID_AUTH_SECRET is required.") + auth_db_value = config.auth.db_file.strip() + if not auth_db_value: + raise SystemExit("auth.db_file must not be empty.") + auth_base_dir = config_path.parent if config_path is not None else Path.cwd() + auth_db_path = Path(auth_db_value) + if not auth_db_path.is_absolute(): + auth_db_path = auth_base_dir / auth_db_path + auth_db_path.parent.mkdir(parents=True, exist_ok=True) + logging.basicConfig( level=getattr(logging, config.logging.level.upper(), logging.INFO), format="%(asctime)s %(levelname)s %(name)s %(message)s", ) + if args.bootstrap_admin: + auth_service = AuthService( + db_path=auth_db_path, + token_hash_secret=auth_secret, + password_min_length=config.auth.password_min_length, + password_max_length=config.auth.password_max_length, + username_min_length=config.auth.username_min_length, + username_max_length=config.auth.username_max_length, + ) + try: + username = input("Admin username: ").strip() + password = getpass("Admin password: ") + email = input("Admin email (optional): ").strip() or None + created = auth_service.bootstrap_admin(username, password, email=email) + print(f"Admin created: {created.username}") + finally: + auth_service.close() + return server = SignalingServer( host, port, ssl_cert, ssl_key, + auth_db_path=auth_db_path, + auth_token_hash_secret=auth_secret, + password_min_length=config.auth.password_min_length, + password_max_length=config.auth.password_max_length, + username_min_length=config.auth.username_min_length, + username_max_length=config.auth.username_max_length, max_message_size=config.network.max_message_bytes, state_file=state_file, grid_size=config.world.grid_size, diff --git a/server/config.example.toml b/server/config.example.toml index e4fa2fc..e26d0b6 100644 --- a/server/config.example.toml +++ b/server/config.example.toml @@ -29,3 +29,13 @@ state_save_max_delay_ms = 1000 [world] # Grid width/height in cells. Valid coordinates are 0..grid_size-1. grid_size = 41 + +[auth] +# SQLite file for account/session data. Relative paths resolve from this config file directory. +db_file = "runtime/chatgrid.db" +# Password length policy. +password_min_length = 8 +password_max_length = 32 +# Username length policy. +username_min_length = 2 +username_max_length = 32 diff --git a/server/tests/test_auth_service.py b/server/tests/test_auth_service.py new file mode 100644 index 0000000..d2bb506 --- /dev/null +++ b/server/tests/test_auth_service.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from app.auth_service import AuthError, AuthService + + +def make_auth_service(tmp_path: Path) -> AuthService: + return AuthService( + db_path=tmp_path / "chatgrid.db", + token_hash_secret="test-secret", + password_min_length=8, + password_max_length=32, + username_min_length=2, + username_max_length=32, + ) + + +def test_register_and_resume_session(tmp_path: Path) -> None: + service = make_auth_service(tmp_path) + try: + session = service.register("User_One", "password99", email="a@example.com") + assert session.user.username == "user_one" + resumed = service.resume(session.token) + assert resumed.user.id == session.user.id + assert resumed.user.role == "user" + finally: + service.close() + + +def test_login_rejects_invalid_password(tmp_path: Path) -> None: + service = make_auth_service(tmp_path) + try: + service.register("alpha", "password99") + with pytest.raises(AuthError): + service.login("alpha", "wrong-pass") + finally: + service.close() + + +def test_bootstrap_admin_once(tmp_path: Path) -> None: + service = make_auth_service(tmp_path) + try: + admin = service.bootstrap_admin("root-admin", "password99", email=None) + assert admin.role == "admin" + with pytest.raises(AuthError): + service.bootstrap_admin("another-admin", "password99") + finally: + service.close() +