From 3027ea04b9535db8fd3abb6736d6a107f3f765e6 Mon Sep 17 00:00:00 2001 From: Jage9 Date: Sat, 21 Feb 2026 17:19:27 -0500 Subject: [PATCH] Fix server bounds validation, cooldown timing, and broadcast fanout --- server/app/config.py | 7 ++ server/app/server.py | 30 +++++-- server/config.example.toml | 4 + server/tests/test_config.py | 1 + server/tests/test_item_use_cooldown.py | 32 ++++++++ server/tests/test_server_message_handling.py | 85 ++++++++++++++++++++ 6 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 server/tests/test_server_message_handling.py diff --git a/server/app/config.py b/server/app/config.py index 4b660f7..018445f 100644 --- a/server/app/config.py +++ b/server/app/config.py @@ -41,6 +41,12 @@ class StorageConfigSection(BaseModel): state_file: str = "runtime/items.json" +class WorldConfigSection(BaseModel): + """Authoritative world geometry options.""" + + grid_size: int = Field(default=41, ge=1) + + class AppConfig(BaseModel): """Top-level application configuration document.""" @@ -49,6 +55,7 @@ class AppConfig(BaseModel): tls: TlsConfigSection = TlsConfigSection() logging: LoggingConfigSection = LoggingConfigSection() storage: StorageConfigSection = StorageConfigSection() + world: WorldConfigSection = WorldConfigSection() def load_config(path: Path | None) -> AppConfig: diff --git a/server/app/server.py b/server/app/server.py index 6d86e0d..7f40ddb 100644 --- a/server/app/server.py +++ b/server/app/server.py @@ -67,6 +67,7 @@ class SignalingServer: ssl_key: str | None, max_message_size: int = 2_000_000, state_file: Path | None = None, + grid_size: int = 41, ): """Initialize runtime state, TLS context, and item service.""" @@ -77,6 +78,7 @@ class SignalingServer: self.clients: dict[ServerConnection, ClientConnection] = {} self.item_service = ItemService(state_file=state_file) self.item_last_use_ms: dict[str, int] = {} + self.grid_size = max(1, grid_size) @property def items(self) -> dict[str, WorldItem]: @@ -106,6 +108,11 @@ class SignalingServer: return "radio" if item.type == "radio_station" else item.type + def _is_in_bounds(self, x: int, y: int) -> bool: + """Check whether a coordinate is inside server-authoritative world bounds.""" + + return 0 <= x < self.grid_size and 0 <= y < self.grid_size + @staticmethod def _normalize_clock_timezone(value: object) -> str: """Normalize timezone input to one of supported clock zones.""" @@ -270,6 +277,15 @@ class SignalingServer: return if isinstance(packet, UpdatePositionPacket): + if not self._is_in_bounds(packet.x, packet.y): + PACKET_LOGGER.warning( + "out-of-bounds position ignored id=%s x=%d y=%d grid_size=%d", + client.id, + packet.x, + packet.y, + self.grid_size, + ) + return client.x = packet.x client.y = packet.y await self._broadcast( @@ -456,6 +472,9 @@ class SignalingServer: if item.carrierId != client.id: await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id) return + if not self._is_in_bounds(packet.x, packet.y): + await self._send_item_result(client, False, "drop", "Drop position is out of bounds.", item.id) + return item.carrierId = None item.x = packet.x item.y = packet.y @@ -518,7 +537,6 @@ class SignalingServer: item.id, ) return - self.item_last_use_ms[item.id] = now_ms delayed_wheel_self_result: str | None = None delayed_wheel_others_result: str | None = None if item.type == "radio_station": @@ -578,6 +596,7 @@ class SignalingServer: display_time = self._format_clock_display_time(item.params) others_message = f"{client.nickname} checks {item.title}. {item.title} says {display_time}." self_message = f"{item.title} says {display_time}." + self.item_last_use_ms[item.id] = now_ms await self._broadcast( BroadcastChatMessagePacket(type="chat_message", message=others_message, system=True), exclude=client.websocket, @@ -790,10 +809,10 @@ class SignalingServer: async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None: """Broadcast one packet to all clients except an optional websocket.""" - for websocket in list(self.clients.keys()): - if websocket is exclude: - continue - await self._send(websocket, packet) + recipients = [websocket for websocket in self.clients if websocket is not exclude] + if not recipients: + return + await asyncio.gather(*(self._send(websocket, packet) for websocket in recipients)) async def _send(self, websocket: ServerConnection, packet: object) -> None: """Send one packet to one websocket, swallowing per-client send failures.""" @@ -875,5 +894,6 @@ def run() -> None: ssl_key, max_message_size=config.network.max_message_bytes, state_file=state_file, + grid_size=config.world.grid_size, ) asyncio.run(server.start()) diff --git a/server/config.example.toml b/server/config.example.toml index fe3f0e7..b3d4dd9 100644 --- a/server/config.example.toml +++ b/server/config.example.toml @@ -21,3 +21,7 @@ level = "INFO" [storage] # Item persistence file. Relative paths are resolved from this config file directory. state_file = "runtime/items.json" + +[world] +# Grid width/height in cells. Valid coordinates are 0..grid_size-1. +grid_size = 41 diff --git a/server/tests/test_config.py b/server/tests/test_config.py index fcf9611..0f97c70 100644 --- a/server/tests/test_config.py +++ b/server/tests/test_config.py @@ -10,6 +10,7 @@ def test_load_config_defaults_when_path_none() -> None: assert cfg.server.bind_ip == "127.0.0.1" assert cfg.network.allow_insecure_ws is True assert cfg.storage.state_file == "runtime/items.json" + assert cfg.world.grid_size == 41 def test_load_config_requires_tls_when_insecure_disabled(tmp_path: Path) -> None: diff --git a/server/tests/test_item_use_cooldown.py b/server/tests/test_item_use_cooldown.py index d6c89e6..0130a2f 100644 --- a/server/tests/test_item_use_cooldown.py +++ b/server/tests/test_item_use_cooldown.py @@ -182,3 +182,35 @@ async def test_clock_timezone_update_validates(monkeypatch: pytest.MonkeyPatch) ) assert send_payloads[-1].ok is False assert "timezone must be one of" in send_payloads[-1].message.lower() + + +@pytest.mark.asyncio +async def test_failed_wheel_use_does_not_consume_cooldown(monkeypatch: pytest.MonkeyPatch) -> None: + server = SignalingServer("127.0.0.1", 8765, None, None) + ws = _fake_ws() + client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6) + server.clients[ws] = client + item = server.item_service.default_item(client, "wheel") + item.params["spaces"] = ",,," + server.item_service.add_item(item) + + send_payloads: list[object] = [] + now_ms = 40_000 + + async def fake_send(websocket: ServerConnection, packet: object) -> None: + send_payloads.append(packet) + + async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None: + return + + monkeypatch.setattr(server, "_send", fake_send) + monkeypatch.setattr(server, "_broadcast", fake_broadcast) + monkeypatch.setattr(server.item_service, "now_ms", lambda: now_ms) + + await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id})) + assert send_payloads[-1].ok is False + assert "spaces" in send_payloads[-1].message.lower() + + item.params["spaces"] = "a,b,c" + await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id})) + assert send_payloads[-1].ok is True diff --git a/server/tests/test_server_message_handling.py b/server/tests/test_server_message_handling.py new file mode 100644 index 0000000..652d466 --- /dev/null +++ b/server/tests/test_server_message_handling.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import asyncio +import json +from time import monotonic +from typing import cast + +import pytest +from websockets.asyncio.server import ServerConnection + +from app.client import ClientConnection +from app.server import SignalingServer + + +def _fake_ws() -> ServerConnection: + return cast(ServerConnection, object()) + + +@pytest.mark.asyncio +async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41) + ws = _fake_ws() + client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6) + server.clients[ws] = client + + broadcast_payloads: list[object] = [] + + async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None: + broadcast_payloads.append(packet) + + monkeypatch.setattr(server, "_broadcast", fake_broadcast) + + await server._handle_message(client, json.dumps({"type": "update_position", "x": 200, "y": -5})) + + assert client.x == 5 + assert client.y == 6 + assert broadcast_payloads == [] + + +@pytest.mark.asyncio +async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None: + server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41) + ws = _fake_ws() + client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6) + server.clients[ws] = client + item = server.item_service.default_item(client, "dice") + item.carrierId = client.id + server.item_service.add_item(item) + + send_payloads: list[object] = [] + + async def fake_send(websocket: ServerConnection, packet: object) -> None: + send_payloads.append(packet) + + monkeypatch.setattr(server, "_send", fake_send) + + await server._handle_message(client, json.dumps({"type": "item_drop", "itemId": item.id, "x": 999, "y": 999})) + + assert item.carrierId == client.id + assert send_payloads[-1].ok is False + assert "out of bounds" in send_payloads[-1].message.lower() + + +@pytest.mark.asyncio +async def test_broadcast_fanout_is_concurrent(monkeypatch: pytest.MonkeyPatch) -> None: + server = SignalingServer("127.0.0.1", 8765, None, None) + ws1 = _fake_ws() + ws2 = _fake_ws() + server.clients[ws1] = ClientConnection(websocket=ws1, id="u1") + server.clients[ws2] = ClientConnection(websocket=ws2, id="u2") + + send_started_at: dict[ServerConnection, float] = {} + + async def fake_send(websocket: ServerConnection, packet: object) -> None: + send_started_at[websocket] = monotonic() + if websocket is ws1: + await asyncio.sleep(0.05) + + monkeypatch.setattr(server, "_send", fake_send) + + await server._broadcast({"type": "noop"}) + + assert ws1 in send_started_at + assert ws2 in send_started_at + assert abs(send_started_at[ws1] - send_started_at[ws2]) < 0.02