Files
chat_grid/server/app/server.py

642 lines
26 KiB
Python
Raw Normal View History

2026-02-20 08:16:43 -05:00
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import random
import ssl
import uuid
from pathlib import Path
from typing import Literal
from pydantic import ValidationError, TypeAdapter
from websockets.asyncio.server import ServerConnection, serve
from .client import ClientConnection
from .config import load_config
from .item_catalog import get_item_use_cooldown_ms
2026-02-20 08:16:43 -05:00
from .item_service import ItemService
from .models import (
BroadcastChatMessagePacket,
BroadcastNicknamePacket,
BroadcastPositionPacket,
ChatMessagePacket,
ClientPacket,
ForwardSignalPacket,
ItemActionResultPacket,
ItemAddPacket,
ItemDeletePacket,
ItemDropPacket,
ItemPickupPacket,
ItemRemovePacket,
ItemUpdatePacket,
ItemUpsertPacket,
ItemUsePacket,
ItemUseSoundPacket,
NicknameResultPacket,
PingPacket,
PongPacket,
RemoteUser,
UpdateNicknamePacket,
UpdatePositionPacket,
UserLeftPacket,
WelcomePacket,
WorldItem,
)
# General server logger: connection lifecycle, startup and send failures.
LOGGER = logging.getLogger("chgrid.server")
# Separate logger for per-packet diagnostics (malformed/invalid packets,
# undeliverable signals) so it can be filtered independently.
PACKET_LOGGER = logging.getLogger("chgrid.server.packet")
# Built once at import time; used to validate every decoded client payload
# against the ClientPacket type.
CLIENT_PACKET_ADAPTER = TypeAdapter(ClientPacket)
# Accepted values for a radio item's "effect" parameter (compared after
# strip + lower-casing the client-supplied value).
RADIO_EFFECT_IDS = {"reverb", "echo", "flanger", "high_pass", "low_pass", "off"}
2026-02-20 08:16:43 -05:00
class SignalingServer:
    """WebSocket signaling server for connected clients and shared world items.

    Every connection is wrapped in a ``ClientConnection`` and registered in
    ``self.clients``. Incoming JSON packets are validated against
    ``ClientPacket`` and dispatched to one handler per packet type. Item
    state and persistence are delegated to ``ItemService``.
    """

    # Nickname that marks a client as not having picked a name yet: the first
    # rename away from it is announced as a login instead of a rename.
    # NOTE(review): presumably the ClientConnection default - confirm there.
    _PLACEHOLDER_NICKNAME = "user..."

    def __init__(
        self,
        host: str,
        port: int,
        ssl_cert: str | None,
        ssl_key: str | None,
        max_message_size: int = 2_000_000,
        state_file: Path | None = None,
    ):
        """Store the listen settings and create empty runtime state.

        Args:
            host: Interface/host name to bind.
            port: TCP port to bind.
            ssl_cert: Certificate chain path, or ``None``/empty for plain ws.
            ssl_key: Private key path, or ``None``/empty for plain ws.
            max_message_size: Passed to ``serve`` as ``max_size``.
            state_file: Forwarded to ``ItemService`` as its state file.
        """
        self.host = host
        self.port = port
        self.max_message_size = max_message_size
        self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key)
        self.clients: dict[ServerConnection, ClientConnection] = {}
        self.item_service = ItemService(state_file=state_file)
        # item id -> timestamp (ms) of the last successful "use".
        self.item_last_use_ms: dict[str, int] = {}

    @property
    def items(self) -> dict[str, WorldItem]:
        """Return the item mapping owned by the item service."""
        return self.item_service.items

    def _nickname_key(self, nickname: str) -> str:
        """Return the case-insensitive comparison key for a nickname."""
        return nickname.casefold()

    def _is_nickname_taken(self, nickname: str, exclude_client_id: str | None = None) -> bool:
        """Return True if another connected client already uses ``nickname``.

        The comparison is case-insensitive. The client whose id equals
        ``exclude_client_id`` (if given) is ignored.
        """
        wanted = self._nickname_key(nickname)
        return any(
            self._nickname_key(other.nickname) == wanted
            for other in self.clients.values()
            if exclude_client_id is None or other.id != exclude_client_id
        )

    @staticmethod
    def _item_type_label(item: WorldItem) -> str:
        """Return the user-facing type name of ``item``."""
        return "radio" if item.type == "radio_station" else item.type

    @staticmethod
    def _is_out_of_reach(item: WorldItem, client: ClientConnection) -> bool:
        """Return True if ``item`` lies on the ground on a different square."""
        return item.carrierId is None and (item.x != client.x or item.y != client.y)

    @staticmethod
    def _system_chat(message: str) -> BroadcastChatMessagePacket:
        """Build a system chat packet carrying ``message``."""
        return BroadcastChatMessagePacket(type="chat_message", message=message, system=True)

    async def _send_item_result(
        self,
        client: ClientConnection,
        ok: bool,
        action: Literal["add", "pickup", "drop", "delete", "use", "update"],
        message: str,
        item_id: str | None = None,
    ) -> None:
        """Send the outcome of an item action back to the requesting client."""
        await self._send(
            client.websocket,
            ItemActionResultPacket(
                type="item_action_result",
                ok=ok,
                action=action,
                message=message,
                itemId=item_id,
            ),
        )

    async def _send_nickname_result(
        self,
        client: ClientConnection,
        accepted: bool,
        requested: str,
        reason: str | None = None,
    ) -> None:
        """Send a nickname_result packet reflecting the client's current nickname."""
        fields: dict[str, object] = {
            "type": "nickname_result",
            "accepted": accepted,
            "requestedNickname": requested,
            "effectiveNickname": client.nickname,
        }
        # Only pass ``reason`` when there is one so the model default applies otherwise.
        if reason is not None:
            fields["reason"] = reason
        await self._send(client.websocket, NicknameResultPacket(**fields))

    async def _broadcast_item(self, item: WorldItem) -> None:
        """Broadcast the current state of ``item`` to every client."""
        await self._broadcast(ItemUpsertPacket(type="item_upsert", item=item))

    async def start(self) -> None:
        """Start listening and serve until the task is cancelled."""
        protocol = "wss" if self._ssl_context else "ws"
        LOGGER.info("starting signaling server on %s://%s:%d", protocol, self.host, self.port)
        async with serve(
            self._handle_client,
            self.host,
            self.port,
            ssl=self._ssl_context,
            max_size=self.max_message_size,
        ):
            # A future that never completes keeps the server context open.
            await asyncio.Future()

    async def _handle_client(self, websocket: ServerConnection) -> None:
        """Run the full lifecycle of one connection: register, serve, clean up."""
        client = ClientConnection(websocket=websocket, id=str(uuid.uuid4()))
        self.clients[websocket] = client
        LOGGER.info("client connected id=%s total=%d", client.id, len(self.clients))
        try:
            await self._send_welcome(client)
            async for raw_message in websocket:
                await self._handle_message(client, raw_message)
        finally:
            disconnected = self.clients.pop(websocket, None)
            if disconnected is not None:
                # Items the client was carrying are released and re-announced.
                for item in self.item_service.drop_carried_items_for_disconnect(disconnected):
                    await self._broadcast_item(item)
                self.item_service.save_state()
                LOGGER.info("client disconnected id=%s total=%d", disconnected.id, len(self.clients))
                await self._broadcast(UserLeftPacket(type="user_left", id=disconnected.id), exclude=websocket)
                await self._broadcast(
                    self._system_chat(f"{disconnected.nickname} has logged out."),
                    exclude=websocket,
                )

    async def _send_welcome(self, client: ClientConnection) -> None:
        """Send the new client its id, the other users and all known items."""
        users = [
            RemoteUser(id=other.id, nickname=other.nickname, x=other.x, y=other.y)
            for ws, other in self.clients.items()
            if ws is not client.websocket
        ]
        packet = WelcomePacket(
            type="welcome",
            id=client.id,
            users=users,
            items=[item.model_dump(exclude_none=True) for item in self.items.values()],
        )
        await self._send(client.websocket, packet)

    async def _handle_message(self, client: ClientConnection, raw_message: str) -> None:
        """Decode, validate and dispatch one raw message from ``client``.

        Undecodable or invalid packets are logged and dropped. Any validated
        packet that matches none of the listed types is treated as a signal
        to forward to ``packet.targetId``.
        """
        try:
            payload = json.loads(raw_message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # UnicodeDecodeError: json.loads on a binary frame that is not
            # valid UTF-8/16/32; drop it like any other non-JSON packet.
            PACKET_LOGGER.warning("non-json packet from id=%s", client.id)
            return
        try:
            packet = CLIENT_PACKET_ADAPTER.validate_python(payload)
        except ValidationError as exc:
            PACKET_LOGGER.warning("invalid packet from id=%s: %s", client.id, exc)
            return

        # Checked in order; the first matching type wins.
        handlers = (
            (UpdatePositionPacket, self._on_update_position),
            (UpdateNicknamePacket, self._on_update_nickname),
            (ChatMessagePacket, self._on_chat_message),
            (PingPacket, self._on_ping),
            (ItemAddPacket, self._on_item_add),
            (ItemPickupPacket, self._on_item_pickup),
            (ItemDropPacket, self._on_item_drop),
            (ItemDeletePacket, self._on_item_delete),
            (ItemUsePacket, self._on_item_use),
            (ItemUpdatePacket, self._on_item_update),
        )
        for packet_type, handler in handlers:
            if isinstance(packet, packet_type):
                await handler(client, packet)
                return
        await self._forward_signal(client, packet)

    async def _on_update_position(self, client: ClientConnection, packet: UpdatePositionPacket) -> None:
        """Move the client, tell the others, and drag along any carried item."""
        client.x = packet.x
        client.y = packet.y
        await self._broadcast(
            BroadcastPositionPacket(type="update_position", id=client.id, x=client.x, y=client.y),
            exclude=client.websocket,
        )
        carried = self.item_service.find_carried_item(client.id)
        if carried:
            carried.x = client.x
            carried.y = client.y
            carried.updatedAt = self.item_service.now_ms()
            await self._broadcast_item(carried)

    async def _on_update_nickname(self, client: ClientConnection, packet: UpdateNicknamePacket) -> None:
        """Validate and apply a nickname change, announcing it to everyone."""
        requested_nickname = packet.nickname.strip()
        if not requested_nickname:
            # Echo the raw (unstripped) request back on this path.
            await self._send_nickname_result(client, False, packet.nickname, "Nickname is required.")
            return
        old_nickname = client.nickname
        if self._is_nickname_taken(requested_nickname, exclude_client_id=client.id):
            await self._send_nickname_result(client, False, requested_nickname, "Nickname already in use.")
            return
        if requested_nickname == old_nickname:
            # Nothing changes, so acknowledge without any announcement.
            await self._send_nickname_result(client, True, requested_nickname)
            return

        client.nickname = requested_nickname
        await self._send_nickname_result(client, True, requested_nickname)
        await self._broadcast(
            BroadcastNicknamePacket(type="update_nickname", id=client.id, nickname=client.nickname),
            exclude=client.websocket,
        )
        if old_nickname == self._PLACEHOLDER_NICKNAME:
            others_message = f"{client.nickname} has logged in."
            self_message = f"Welcome. Logged in as {client.nickname}."
        else:
            others_message = f"{old_nickname} is now known as {client.nickname}."
            self_message = f"You are now known as {client.nickname}."
        await self._broadcast(self._system_chat(others_message), exclude=client.websocket)
        await self._send(client.websocket, self._system_chat(self_message))

    async def _on_chat_message(self, client: ClientConnection, packet: ChatMessagePacket) -> None:
        """Relay a chat message to every client, including the sender."""
        await self._broadcast(
            BroadcastChatMessagePacket(
                type="chat_message",
                message=packet.message,
                senderId=client.id,
                senderNickname=client.nickname,
                system=False,
            )
        )

    async def _on_ping(self, client: ClientConnection, packet: PingPacket) -> None:
        """Answer a ping by echoing the client's timestamp."""
        await self._send(
            client.websocket,
            PongPacket(type="pong", clientSentAt=packet.clientSentAt),
        )

    async def _on_item_add(self, client: ClientConnection, packet: ItemAddPacket) -> None:
        """Create a default item of the requested type and announce it."""
        item = self.item_service.default_item(client, packet.itemType)
        self.item_service.add_item(item)
        await self._broadcast_item(item)
        self.item_service.save_state()
        item_text = f"{item.title} ({self._item_type_label(item)})"
        await self._broadcast(
            self._system_chat(f"{client.nickname} placed {item_text} at {item.x}, {item.y}."),
            exclude=client.websocket,
        )
        await self._send_item_result(
            client,
            True,
            "add",
            f"You placed {item_text} at {item.x}, {item.y}.",
            item.id,
        )

    async def _on_item_pickup(self, client: ClientConnection, packet: ItemPickupPacket) -> None:
        """Let the client pick up an item lying on its own square."""
        item = self.items.get(packet.itemId)
        if not item:
            await self._send_item_result(client, False, "pickup", "Item not found.")
            return
        if item.carrierId and item.carrierId != client.id:
            await self._send_item_result(client, False, "pickup", "Item is already being carried.", item.id)
            return
        carried = self.item_service.find_carried_item(client.id)
        if carried and carried.id != item.id:
            await self._send_item_result(client, False, "pickup", "You are already carrying an item.", item.id)
            return
        if self._is_out_of_reach(item, client):
            await self._send_item_result(client, False, "pickup", "Item is not on your square.", item.id)
            return
        item.carrierId = client.id
        item.x = client.x
        item.y = client.y
        item.updatedAt = self.item_service.now_ms()
        await self._broadcast_item(item)
        self.item_service.save_state()
        await self._send_item_result(client, True, "pickup", f"Picked up {item.title}.", item.id)

    async def _on_item_drop(self, client: ClientConnection, packet: ItemDropPacket) -> None:
        """Drop an item the client carries at the coordinates in the packet."""
        item = self.items.get(packet.itemId)
        if not item:
            await self._send_item_result(client, False, "drop", "Item not found.")
            return
        if item.carrierId != client.id:
            await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id)
            return
        item.carrierId = None
        item.x = packet.x
        item.y = packet.y
        item.updatedAt = self.item_service.now_ms()
        await self._broadcast_item(item)
        self.item_service.save_state()
        await self._send_item_result(client, True, "drop", f"Dropped {item.title}.", item.id)

    async def _on_item_delete(self, client: ClientConnection, packet: ItemDeletePacket) -> None:
        """Delete an item the client carries or stands on."""
        item = self.items.get(packet.itemId)
        if not item:
            await self._send_item_result(client, False, "delete", "Item not found.")
            return
        if item.carrierId and item.carrierId != client.id:
            await self._send_item_result(client, False, "delete", "Item is being carried by another user.", item.id)
            return
        if self._is_out_of_reach(item, client):
            await self._send_item_result(client, False, "delete", "Item is not on your square.", item.id)
            return
        self.item_service.remove_item(item.id)
        # Forget the cooldown bookkeeping together with the item.
        self.item_last_use_ms.pop(item.id, None)
        await self._broadcast(ItemRemovePacket(type="item_remove", itemId=item.id))
        self.item_service.save_state()
        await self._send_item_result(client, True, "delete", f"Deleted {item.title}.", item.id)

    async def _on_item_use(self, client: ClientConnection, packet: ItemUsePacket) -> None:
        """Use an item; only dice are usable and they are rate-limited."""
        item = self.items.get(packet.itemId)
        if not item:
            await self._send_item_result(client, False, "use", "Item not found.")
            return
        if item.carrierId not in (None, client.id):
            await self._send_item_result(client, False, "use", "Item is not available.", item.id)
            return
        if self._is_out_of_reach(item, client):
            await self._send_item_result(client, False, "use", "Item is not on your square.", item.id)
            return
        if item.type != "dice":
            await self._send_item_result(client, False, "use", "This item cannot be used yet.", item.id)
            return

        now_ms = self.item_service.now_ms()
        cooldown_ms = get_item_use_cooldown_ms(item.type)
        last_use_ms = self.item_last_use_ms.get(item.id)
        if last_use_ms is not None and now_ms - last_use_ms < cooldown_ms:
            remaining_ms = cooldown_ms - (now_ms - last_use_ms)
            await self._send_item_result(
                client,
                False,
                "use",
                f"{item.title} is on cooldown for {max(1, remaining_ms)} ms.",
                item.id,
            )
            return
        self.item_last_use_ms[item.id] = now_ms

        # Stored params may be malformed; clamp to 1..100 and fall back to
        # two six-sided dice if either value cannot be read as an integer.
        try:
            sides = max(1, min(100, int(item.params.get("sides", 6))))
            number = max(1, min(100, int(item.params.get("number", 2))))
        except (TypeError, ValueError):
            sides = 6
            number = 2
        rolls = [random.randint(1, sides) for _ in range(number)]
        total = sum(rolls)
        rolled = ", ".join(str(value) for value in rolls)
        others_message = f"{client.nickname} rolled {item.title}: {rolled} (total {total})."
        self_message = f"You rolled {item.title}: {rolled} (total {total})."
        await self._broadcast(self._system_chat(others_message), exclude=client.websocket)
        if item.useSound:
            # The sound goes to everyone, the user included.
            await self._broadcast(
                ItemUseSoundPacket(
                    type="item_use_sound",
                    itemId=item.id,
                    sound=item.useSound,
                    x=item.x,
                    y=item.y,
                )
            )
        await self._send_item_result(client, True, "use", self_message, item.id)

    async def _on_item_update(self, client: ClientConnection, packet: ItemUpdatePacket) -> None:
        """Apply a title and/or params edit to an item, all-or-nothing.

        Every field is validated first; the item is only mutated once the
        whole request is known to be valid, so a rejected update never leaves
        a half-applied (and un-broadcast, un-saved) change behind.
        """
        item = self.items.get(packet.itemId)
        if not item:
            await self._send_item_result(client, False, "update", "Item not found.")
            return
        if item.carrierId not in (None, client.id):
            await self._send_item_result(client, False, "update", "Item is not available for editing.", item.id)
            return
        if self._is_out_of_reach(item, client):
            await self._send_item_result(client, False, "update", "Item is not on your square.", item.id)
            return

        new_title: str | None = None
        if packet.title is not None:
            stripped_title = packet.title.strip()
            if not stripped_title:
                await self._send_item_result(client, False, "update", "Title cannot be empty.", item.id)
                return
            new_title = stripped_title[:80]

        new_params: dict | None = None
        if packet.params:
            new_params, error = self._normalize_item_params(item.type, item.params, packet.params)
            if error is not None:
                await self._send_item_result(client, False, "update", error, item.id)
                return

        # Validation passed: commit the changes.
        if new_title is not None:
            item.title = new_title
        if new_params is not None:
            item.params = new_params
        item.updatedAt = self.item_service.now_ms()
        item.version += 1
        await self._broadcast_item(item)
        self.item_service.save_state()
        await self._send_item_result(client, True, "update", f"Updated {item.title}.", item.id)

    @classmethod
    def _normalize_item_params(
        cls,
        item_type: str,
        current: dict,
        updates: dict,
    ) -> tuple[dict | None, str | None]:
        """Merge ``updates`` over ``current`` and validate them for ``item_type``.

        Returns ``(merged_params, None)`` on success or ``(None, message)``
        when a value is rejected. Neither input mapping is modified.
        """
        merged = {**current, **updates}
        error: str | None = None
        if item_type == "dice":
            error = cls._normalize_dice_params(merged)
        if error is None and item_type == "radio_station":
            previous_stream_url = str(current.get("streamUrl", "")).strip()
            error = cls._normalize_radio_params(merged, previous_stream_url)
        if error is not None:
            return None, error
        return merged, None

    @staticmethod
    def _normalize_dice_params(params: dict) -> str | None:
        """Coerce dice ``sides``/``number`` to ints in place; return an error or None."""
        try:
            sides = int(params.get("sides", 6))
            number = int(params.get("number", 2))
        except (TypeError, ValueError):
            return "Dice values must be numbers."
        if not (1 <= sides <= 100 and 1 <= number <= 100):
            return "Dice sides and number must be between 1 and 100."
        params["sides"] = sides
        params["number"] = number
        return None

    @staticmethod
    def _normalize_radio_params(params: dict, previous_stream_url: str) -> str | None:
        """Coerce radio-station params in place; return an error or None.

        ``previous_stream_url`` is the stripped URL stored before this update
        and is used to detect a URL change.
        """
        invalid_enabled = "enabled must be true/false or on/off."
        stream_url = str(params.get("streamUrl", "")).strip()
        params["streamUrl"] = stream_url

        enabled_value = params.get("enabled", True)
        if isinstance(enabled_value, bool):
            enabled = enabled_value
        elif isinstance(enabled_value, (int, float)):
            enabled = bool(enabled_value)
        elif isinstance(enabled_value, str):
            token = enabled_value.strip().lower()
            if token in {"on", "true", "1", "yes"}:
                enabled = True
            elif token in {"off", "false", "0", "no"}:
                enabled = False
            else:
                return invalid_enabled
        else:
            return invalid_enabled
        # A new URL switches the station on; an empty URL switches it off.
        if stream_url and stream_url != previous_stream_url:
            enabled = True
        if not stream_url:
            enabled = False
        params["enabled"] = enabled

        try:
            volume = int(params.get("volume", 50))
        except (TypeError, ValueError):
            return "volume must be a number."
        if not (0 <= volume <= 100):
            return "volume must be between 0 and 100."
        params["volume"] = volume

        effect = str(params.get("effect", "off")).strip().lower()
        if effect not in RADIO_EFFECT_IDS:
            return "effect must be one of reverb, echo, flanger, high_pass, low_pass, off."
        params["effect"] = effect

        try:
            effect_value = int(params.get("effectValue", 50))
        except (TypeError, ValueError):
            return "effectValue must be a number."
        if not (0 <= effect_value <= 100):
            return "effectValue must be between 0 and 100."
        # Snap to the nearest multiple of 5.
        params["effectValue"] = round(effect_value / 5) * 5
        return None

    async def _forward_signal(self, client: ClientConnection, packet: object) -> None:
        """Forward a signaling packet (sdp/ice) to the client named by ``targetId``."""
        target = self._find_by_id(packet.targetId)
        if not target:
            PACKET_LOGGER.info("signal target not found sender=%s target=%s", client.id, packet.targetId)
            return
        await self._send(
            target.websocket,
            ForwardSignalPacket(
                type="signal",
                senderId=client.id,
                senderNickname=client.nickname,
                x=client.x,
                y=client.y,
                sdp=packet.sdp,
                ice=packet.ice,
            ),
        )

    async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None:
        """Send ``packet`` to every connected client except ``exclude``."""
        # Iterate over a snapshot: clients may connect/disconnect while we await.
        for websocket in list(self.clients):
            if websocket is exclude:
                continue
            await self._send(websocket, packet)

    async def _send(self, websocket: ServerConnection, packet: object) -> None:
        """Serialize ``packet`` to JSON and send it, logging (not raising) on failure."""
        try:
            if hasattr(packet, "model_dump"):
                data = packet.model_dump(exclude_none=True)
            else:
                data = packet
            await websocket.send(json.dumps(data))
        except Exception as exc:  # intentionally broad to keep server alive per client error
            LOGGER.debug("send failure: %s", exc)

    def _find_by_id(self, client_id: str) -> ClientConnection | None:
        """Return the connected client with id ``client_id``, or None."""
        for client in self.clients.values():
            if client.id == client_id:
                return client
        return None

    @staticmethod
    def _build_ssl_context(cert: str | None, key: str | None) -> ssl.SSLContext | None:
        """Build a server-side TLS context, or return None unless both paths are given."""
        if not cert or not key:
            return None
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile=Path(cert), keyfile=Path(key))
        return context
def run() -> None:
    """Parse CLI arguments, load the configuration and run the server.

    Command-line values take precedence over the configuration file.

    Raises:
        SystemExit: If insecure ws is not allowed and no TLS certificate/key
            pair is configured.
    """
    parser = argparse.ArgumentParser(description="chgrid signaling server")
    parser.add_argument("--config", default="config.toml")
    parser.add_argument("--host", default=None)
    parser.add_argument("--port", type=int, default=None)
    parser.add_argument("--ssl-cert", default=None)
    parser.add_argument("--ssl-key", default=None)
    parser.add_argument("--allow-insecure-ws", action="store_true", default=None)
    args = parser.parse_args()

    config_path = Path(args.config) if args.config else None
    # A missing *default* config file is not an error: load_config is called
    # with None instead. An explicitly named file is passed through as-is.
    if config_path and not config_path.exists() and args.config == "config.toml":
        config_path = None
    config = load_config(config_path)

    host = args.host or config.server.bind_ip
    # Compare against None (as done for the TLS options below) so that an
    # explicit ``--port 0`` is honoured instead of silently falling back to
    # the configured port.
    port = args.port if args.port is not None else config.server.port
    allow_insecure_ws = config.network.allow_insecure_ws
    if args.allow_insecure_ws is True:
        allow_insecure_ws = True
    ssl_cert = args.ssl_cert if args.ssl_cert is not None else config.tls.cert_file or None
    ssl_key = args.ssl_key if args.ssl_key is not None else config.tls.key_file or None

    # A relative state file is resolved against the config file's directory,
    # or the current working directory when no config file is in use.
    state_file_value = config.storage.state_file.strip()
    state_file: Path | None = None
    if state_file_value:
        base_dir = config_path.parent if config_path is not None else Path.cwd()
        state_file = Path(state_file_value)
        if not state_file.is_absolute():
            state_file = base_dir / state_file

    if not allow_insecure_ws and (not ssl_cert or not ssl_key):
        raise SystemExit(
            "TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml."
        )

    logging.basicConfig(
        # Unknown level names fall back to INFO.
        level=getattr(logging, config.logging.level.upper(), logging.INFO),
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    server = SignalingServer(
        host,
        port,
        ssl_cert,
        ssl_key,
        max_message_size=config.network.max_message_bytes,
        state_file=state_file,
    )
    asyncio.run(server.start())