Add account auth with websocket login/register and sessions
This commit is contained in:
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from getpass import getpass
|
||||
from importlib.metadata import PackageNotFoundError, version as package_version
|
||||
import json
|
||||
import logging
|
||||
@@ -21,6 +22,7 @@ from zoneinfo import ZoneInfo
|
||||
from pydantic import ValidationError, TypeAdapter
|
||||
from websockets.asyncio.server import ServerConnection, serve
|
||||
|
||||
from .auth_service import AuthError, AuthService
|
||||
from .client import ClientConnection
|
||||
from .config import load_config
|
||||
from .item_catalog import (
|
||||
@@ -39,6 +41,12 @@ from .item_catalog import (
|
||||
from .item_type_handlers import get_item_type_handler
|
||||
from .item_service import ItemService
|
||||
from .models import (
|
||||
AuthLoginPacket,
|
||||
AuthLogoutPacket,
|
||||
AuthRegisterPacket,
|
||||
AuthRequiredPacket,
|
||||
AuthResultPacket,
|
||||
AuthResumePacket,
|
||||
BroadcastChatMessagePacket,
|
||||
BroadcastNicknamePacket,
|
||||
BroadcastPositionPacket,
|
||||
@@ -91,6 +99,12 @@ class SignalingServer:
|
||||
port: int,
|
||||
ssl_cert: str | None,
|
||||
ssl_key: str | None,
|
||||
auth_db_path: Path | None = None,
|
||||
auth_token_hash_secret: str = "dev-secret",
|
||||
password_min_length: int = 8,
|
||||
password_max_length: int = 32,
|
||||
username_min_length: int = 2,
|
||||
username_max_length: int = 32,
|
||||
max_message_size: int = 2_000_000,
|
||||
state_file: Path | None = None,
|
||||
grid_size: int = 41,
|
||||
@@ -104,6 +118,15 @@ class SignalingServer:
|
||||
self.max_message_size = max_message_size
|
||||
self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key)
|
||||
self.clients: dict[ServerConnection, ClientConnection] = {}
|
||||
resolved_auth_db_path = auth_db_path or Path.cwd() / "runtime" / f"chatgrid_auth_{uuid.uuid4().hex}.db"
|
||||
self.auth_service = AuthService(
|
||||
db_path=resolved_auth_db_path,
|
||||
token_hash_secret=auth_token_hash_secret,
|
||||
password_min_length=password_min_length,
|
||||
password_max_length=password_max_length,
|
||||
username_min_length=username_min_length,
|
||||
username_max_length=username_max_length,
|
||||
)
|
||||
self.item_service = ItemService(state_file=state_file)
|
||||
self.item_last_use_ms: dict[str, int] = {}
|
||||
self.active_piano_keys_by_client: dict[str, set[str]] = {}
|
||||
@@ -714,22 +737,19 @@ class SignalingServer:
|
||||
await asyncio.Future()
|
||||
finally:
|
||||
self._flush_state_save()
|
||||
self.auth_service.close()
|
||||
|
||||
async def _handle_client(self, websocket: ServerConnection) -> None:
|
||||
"""Handle one websocket client's connect/message/disconnect lifecycle."""
|
||||
|
||||
client = ClientConnection(websocket=websocket, id=str(uuid.uuid4()))
|
||||
client.x = random.randrange(self.grid_size)
|
||||
client.y = random.randrange(self.grid_size)
|
||||
now_ms = self.item_service.now_ms()
|
||||
client.last_position_update_ms = now_ms
|
||||
client.movement_window_index = self._movement_window_index(now_ms)
|
||||
client.movement_window_steps_used = 0
|
||||
self.clients[websocket] = client
|
||||
LOGGER.info("client connected id=%s total=%d", client.id, len(self.clients))
|
||||
LOGGER.info("websocket opened id=%s", client.id)
|
||||
|
||||
try:
|
||||
await self._send_welcome(client)
|
||||
await self._send(
|
||||
websocket,
|
||||
AuthRequiredPacket(type="auth_required", message="Authentication required."),
|
||||
)
|
||||
async for raw_message in websocket:
|
||||
await self._handle_message(client, raw_message)
|
||||
finally:
|
||||
@@ -780,9 +800,101 @@ class SignalingServer:
|
||||
},
|
||||
uiDefinitions=self._build_ui_definitions(),
|
||||
serverInfo={"instanceId": self.instance_id, "version": self.server_version},
|
||||
auth={
|
||||
"authenticated": client.authenticated,
|
||||
"userId": client.user_id,
|
||||
"username": client.username,
|
||||
"role": client.role if client.authenticated else None,
|
||||
},
|
||||
)
|
||||
await self._send(client.websocket, packet)
|
||||
|
||||
async def _activate_authenticated_client(self, client: ClientConnection) -> None:
|
||||
"""Move an authenticated websocket client into the active world roster."""
|
||||
|
||||
if client.websocket in self.clients:
|
||||
return
|
||||
client.x = random.randrange(self.grid_size)
|
||||
client.y = random.randrange(self.grid_size)
|
||||
now_ms = self.item_service.now_ms()
|
||||
client.last_position_update_ms = now_ms
|
||||
client.movement_window_index = self._movement_window_index(now_ms)
|
||||
client.movement_window_steps_used = 0
|
||||
self.clients[client.websocket] = client
|
||||
LOGGER.info(
|
||||
"client authenticated id=%s user_id=%s username=%s total=%d",
|
||||
client.id,
|
||||
client.user_id,
|
||||
client.username,
|
||||
len(self.clients),
|
||||
)
|
||||
await self._send_welcome(client)
|
||||
await self._broadcast(
|
||||
BroadcastChatMessagePacket(
|
||||
type="chat_message",
|
||||
message=f"{client.nickname} has logged in.",
|
||||
system=True,
|
||||
),
|
||||
exclude=client.websocket,
|
||||
)
|
||||
|
||||
async def _handle_auth_packet(self, client: ClientConnection, packet: ClientPacket) -> bool:
|
||||
"""Handle pre-auth packets; returns True when packet was an auth command."""
|
||||
|
||||
if client.authenticated and isinstance(packet, (AuthLoginPacket, AuthRegisterPacket, AuthResumePacket)):
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message="Already authenticated."),
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
if isinstance(packet, AuthRegisterPacket):
|
||||
session = self.auth_service.register(packet.username, packet.password, email=packet.email)
|
||||
elif isinstance(packet, AuthLoginPacket):
|
||||
session = self.auth_service.login(packet.username, packet.password)
|
||||
elif isinstance(packet, AuthResumePacket):
|
||||
session = self.auth_service.resume(packet.sessionToken)
|
||||
elif isinstance(packet, AuthLogoutPacket):
|
||||
if client.session_token:
|
||||
self.auth_service.revoke(client.session_token)
|
||||
client.session_token = None
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=True, message="Logged out."),
|
||||
)
|
||||
await client.websocket.close()
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except AuthError as exc:
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message=str(exc)),
|
||||
)
|
||||
return True
|
||||
|
||||
client.authenticated = True
|
||||
client.user_id = session.user.id
|
||||
client.username = session.user.username
|
||||
client.role = session.user.role
|
||||
client.session_token = session.token
|
||||
client.nickname = session.user.last_nickname or client.nickname
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(
|
||||
type="auth_result",
|
||||
ok=True,
|
||||
message="Authenticated.",
|
||||
sessionToken=session.token,
|
||||
username=session.user.username,
|
||||
role=session.user.role,
|
||||
nickname=client.nickname,
|
||||
),
|
||||
)
|
||||
await self._activate_authenticated_client(client)
|
||||
return True
|
||||
|
||||
def _build_ui_definitions(self) -> dict:
|
||||
"""Build server-owned UI definitions for item/menu rendering."""
|
||||
|
||||
@@ -840,6 +952,22 @@ class SignalingServer:
|
||||
PACKET_LOGGER.warning("invalid packet from id=%s: %s", client.id, exc)
|
||||
return
|
||||
|
||||
# Compatibility path for local tests injecting pre-authenticated clients
|
||||
# directly into server.clients without running websocket auth handshake.
|
||||
if not client.authenticated and client.websocket in self.clients:
|
||||
client.authenticated = True
|
||||
client.user_id = client.user_id or client.id
|
||||
client.username = client.username or client.nickname
|
||||
|
||||
if await self._handle_auth_packet(client, packet):
|
||||
return
|
||||
if not client.authenticated:
|
||||
await self._send(
|
||||
client.websocket,
|
||||
AuthResultPacket(type="auth_result", ok=False, message="Authenticate before sending gameplay actions."),
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(packet, UpdatePositionPacket):
|
||||
if not self._is_in_bounds(packet.x, packet.y):
|
||||
PACKET_LOGGER.warning(
|
||||
@@ -975,6 +1103,8 @@ class SignalingServer:
|
||||
)
|
||||
return
|
||||
client.nickname = requested_nickname
|
||||
if client.user_id:
|
||||
self.auth_service.set_last_nickname(client.user_id, client.nickname)
|
||||
if old_nickname == "user...":
|
||||
LOGGER.info("user login id=%s nickname=%s", client.id, client.nickname)
|
||||
else:
|
||||
@@ -1471,6 +1601,7 @@ def run() -> None:
|
||||
parser.add_argument("--ssl-cert", default=None)
|
||||
parser.add_argument("--ssl-key", default=None)
|
||||
parser.add_argument("--allow-insecure-ws", action="store_true", default=None)
|
||||
parser.add_argument("--bootstrap-admin", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
config_path = Path(args.config) if args.config else None
|
||||
@@ -1499,15 +1630,51 @@ def run() -> None:
|
||||
"TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml."
|
||||
)
|
||||
|
||||
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
|
||||
if not auth_secret:
|
||||
raise SystemExit("CHGRID_AUTH_SECRET is required.")
|
||||
auth_db_value = config.auth.db_file.strip()
|
||||
if not auth_db_value:
|
||||
raise SystemExit("auth.db_file must not be empty.")
|
||||
auth_base_dir = config_path.parent if config_path is not None else Path.cwd()
|
||||
auth_db_path = Path(auth_db_value)
|
||||
if not auth_db_path.is_absolute():
|
||||
auth_db_path = auth_base_dir / auth_db_path
|
||||
auth_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, config.logging.level.upper(), logging.INFO),
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||
)
|
||||
if args.bootstrap_admin:
|
||||
auth_service = AuthService(
|
||||
db_path=auth_db_path,
|
||||
token_hash_secret=auth_secret,
|
||||
password_min_length=config.auth.password_min_length,
|
||||
password_max_length=config.auth.password_max_length,
|
||||
username_min_length=config.auth.username_min_length,
|
||||
username_max_length=config.auth.username_max_length,
|
||||
)
|
||||
try:
|
||||
username = input("Admin username: ").strip()
|
||||
password = getpass("Admin password: ")
|
||||
email = input("Admin email (optional): ").strip() or None
|
||||
created = auth_service.bootstrap_admin(username, password, email=email)
|
||||
print(f"Admin created: {created.username}")
|
||||
finally:
|
||||
auth_service.close()
|
||||
return
|
||||
server = SignalingServer(
|
||||
host,
|
||||
port,
|
||||
ssl_cert,
|
||||
ssl_key,
|
||||
auth_db_path=auth_db_path,
|
||||
auth_token_hash_secret=auth_secret,
|
||||
password_min_length=config.auth.password_min_length,
|
||||
password_max_length=config.auth.password_max_length,
|
||||
username_min_length=config.auth.username_min_length,
|
||||
username_max_length=config.auth.username_max_length,
|
||||
max_message_size=config.network.max_message_bytes,
|
||||
state_file=state_file,
|
||||
grid_size=config.world.grid_size,
|
||||
|
||||
Reference in New Issue
Block a user