Add account auth with websocket login/register and sessions

This commit is contained in:
Jage9
2026-02-24 22:03:10 -05:00
parent 1938f239e6
commit bf3bc90f2a
21 changed files with 1053 additions and 24 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import argparse
import asyncio
from datetime import datetime
from getpass import getpass
from importlib.metadata import PackageNotFoundError, version as package_version
import json
import logging
@@ -21,6 +22,7 @@ from zoneinfo import ZoneInfo
from pydantic import ValidationError, TypeAdapter
from websockets.asyncio.server import ServerConnection, serve
from .auth_service import AuthError, AuthService
from .client import ClientConnection
from .config import load_config
from .item_catalog import (
@@ -39,6 +41,12 @@ from .item_catalog import (
from .item_type_handlers import get_item_type_handler
from .item_service import ItemService
from .models import (
AuthLoginPacket,
AuthLogoutPacket,
AuthRegisterPacket,
AuthRequiredPacket,
AuthResultPacket,
AuthResumePacket,
BroadcastChatMessagePacket,
BroadcastNicknamePacket,
BroadcastPositionPacket,
@@ -91,6 +99,12 @@ class SignalingServer:
port: int,
ssl_cert: str | None,
ssl_key: str | None,
auth_db_path: Path | None = None,
auth_token_hash_secret: str = "dev-secret",
password_min_length: int = 8,
password_max_length: int = 32,
username_min_length: int = 2,
username_max_length: int = 32,
max_message_size: int = 2_000_000,
state_file: Path | None = None,
grid_size: int = 41,
@@ -104,6 +118,15 @@ class SignalingServer:
self.max_message_size = max_message_size
self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key)
self.clients: dict[ServerConnection, ClientConnection] = {}
resolved_auth_db_path = auth_db_path or Path.cwd() / "runtime" / f"chatgrid_auth_{uuid.uuid4().hex}.db"
self.auth_service = AuthService(
db_path=resolved_auth_db_path,
token_hash_secret=auth_token_hash_secret,
password_min_length=password_min_length,
password_max_length=password_max_length,
username_min_length=username_min_length,
username_max_length=username_max_length,
)
self.item_service = ItemService(state_file=state_file)
self.item_last_use_ms: dict[str, int] = {}
self.active_piano_keys_by_client: dict[str, set[str]] = {}
@@ -714,22 +737,19 @@ class SignalingServer:
await asyncio.Future()
finally:
self._flush_state_save()
self.auth_service.close()
async def _handle_client(self, websocket: ServerConnection) -> None:
"""Handle one websocket client's connect/message/disconnect lifecycle."""
client = ClientConnection(websocket=websocket, id=str(uuid.uuid4()))
client.x = random.randrange(self.grid_size)
client.y = random.randrange(self.grid_size)
now_ms = self.item_service.now_ms()
client.last_position_update_ms = now_ms
client.movement_window_index = self._movement_window_index(now_ms)
client.movement_window_steps_used = 0
self.clients[websocket] = client
LOGGER.info("client connected id=%s total=%d", client.id, len(self.clients))
LOGGER.info("websocket opened id=%s", client.id)
try:
await self._send_welcome(client)
await self._send(
websocket,
AuthRequiredPacket(type="auth_required", message="Authentication required."),
)
async for raw_message in websocket:
await self._handle_message(client, raw_message)
finally:
@@ -780,9 +800,101 @@ class SignalingServer:
},
uiDefinitions=self._build_ui_definitions(),
serverInfo={"instanceId": self.instance_id, "version": self.server_version},
auth={
"authenticated": client.authenticated,
"userId": client.user_id,
"username": client.username,
"role": client.role if client.authenticated else None,
},
)
await self._send(client.websocket, packet)
async def _activate_authenticated_client(self, client: ClientConnection) -> None:
"""Move an authenticated websocket client into the active world roster."""
if client.websocket in self.clients:
return
client.x = random.randrange(self.grid_size)
client.y = random.randrange(self.grid_size)
now_ms = self.item_service.now_ms()
client.last_position_update_ms = now_ms
client.movement_window_index = self._movement_window_index(now_ms)
client.movement_window_steps_used = 0
self.clients[client.websocket] = client
LOGGER.info(
"client authenticated id=%s user_id=%s username=%s total=%d",
client.id,
client.user_id,
client.username,
len(self.clients),
)
await self._send_welcome(client)
await self._broadcast(
BroadcastChatMessagePacket(
type="chat_message",
message=f"{client.nickname} has logged in.",
system=True,
),
exclude=client.websocket,
)
async def _handle_auth_packet(self, client: ClientConnection, packet: ClientPacket) -> bool:
"""Handle pre-auth packets; returns True when packet was an auth command."""
if client.authenticated and isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
await self._send(
client.websocket,
AuthResultPacket(type="auth_result", ok=False, message="Already authenticated."),
)
return True
try:
if isinstance(packet, AuthRegisterPacket):
session = self.auth_service.register(packet.username, packet.password, email=packet.email)
elif isinstance(packet, AuthLoginPacket):
session = self.auth_service.login(packet.username, packet.password)
elif isinstance(packet, AuthResumePacket):
session = self.auth_service.resume(packet.sessionToken)
elif isinstance(packet, AuthLogoutPacket):
if client.session_token:
self.auth_service.revoke(client.session_token)
client.session_token = None
await self._send(
client.websocket,
AuthResultPacket(type="auth_result", ok=True, message="Logged out."),
)
await client.websocket.close()
return True
else:
return False
except AuthError as exc:
await self._send(
client.websocket,
AuthResultPacket(type="auth_result", ok=False, message=str(exc)),
)
return True
client.authenticated = True
client.user_id = session.user.id
client.username = session.user.username
client.role = session.user.role
client.session_token = session.token
client.nickname = session.user.last_nickname or client.nickname
await self._send(
client.websocket,
AuthResultPacket(
type="auth_result",
ok=True,
message="Authenticated.",
sessionToken=session.token,
username=session.user.username,
role=session.user.role,
nickname=client.nickname,
),
)
await self._activate_authenticated_client(client)
return True
def _build_ui_definitions(self) -> dict:
"""Build server-owned UI definitions for item/menu rendering."""
@@ -840,6 +952,22 @@ class SignalingServer:
PACKET_LOGGER.warning("invalid packet from id=%s: %s", client.id, exc)
return
# Compatibility path for local tests injecting pre-authenticated clients
# directly into server.clients without running websocket auth handshake.
if not client.authenticated and client.websocket in self.clients:
client.authenticated = True
client.user_id = client.user_id or client.id
client.username = client.username or client.nickname
if await self._handle_auth_packet(client, packet):
return
if not client.authenticated:
await self._send(
client.websocket,
AuthResultPacket(type="auth_result", ok=False, message="Authenticate before sending gameplay actions."),
)
return
if isinstance(packet, UpdatePositionPacket):
if not self._is_in_bounds(packet.x, packet.y):
PACKET_LOGGER.warning(
@@ -975,6 +1103,8 @@ class SignalingServer:
)
return
client.nickname = requested_nickname
if client.user_id:
self.auth_service.set_last_nickname(client.user_id, client.nickname)
if old_nickname == "user...":
LOGGER.info("user login id=%s nickname=%s", client.id, client.nickname)
else:
@@ -1471,6 +1601,7 @@ def run() -> None:
parser.add_argument("--ssl-cert", default=None)
parser.add_argument("--ssl-key", default=None)
parser.add_argument("--allow-insecure-ws", action="store_true", default=None)
parser.add_argument("--bootstrap-admin", action="store_true", default=False)
args = parser.parse_args()
config_path = Path(args.config) if args.config else None
@@ -1499,15 +1630,51 @@ def run() -> None:
"TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml."
)
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
if not auth_secret:
raise SystemExit("CHGRID_AUTH_SECRET is required.")
auth_db_value = config.auth.db_file.strip()
if not auth_db_value:
raise SystemExit("auth.db_file must not be empty.")
auth_base_dir = config_path.parent if config_path is not None else Path.cwd()
auth_db_path = Path(auth_db_value)
if not auth_db_path.is_absolute():
auth_db_path = auth_base_dir / auth_db_path
auth_db_path.parent.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=getattr(logging, config.logging.level.upper(), logging.INFO),
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
if args.bootstrap_admin:
auth_service = AuthService(
db_path=auth_db_path,
token_hash_secret=auth_secret,
password_min_length=config.auth.password_min_length,
password_max_length=config.auth.password_max_length,
username_min_length=config.auth.username_min_length,
username_max_length=config.auth.username_max_length,
)
try:
username = input("Admin username: ").strip()
password = getpass("Admin password: ")
email = input("Admin email (optional): ").strip() or None
created = auth_service.bootstrap_admin(username, password, email=email)
print(f"Admin created: {created.username}")
finally:
auth_service.close()
return
server = SignalingServer(
host,
port,
ssl_cert,
ssl_key,
auth_db_path=auth_db_path,
auth_token_hash_secret=auth_secret,
password_min_length=config.auth.password_min_length,
password_max_length=config.auth.password_max_length,
username_min_length=config.auth.username_min_length,
username_max_length=config.auth.username_max_length,
max_message_size=config.network.max_message_bytes,
state_file=state_file,
grid_size=config.world.grid_size,