Fix server bounds validation, cooldown timing, and broadcast fanout

This commit is contained in:
Jage9
2026-02-21 17:19:27 -05:00
parent fe32cd28f2
commit 3027ea04b9
6 changed files with 154 additions and 5 deletions

View File

@@ -41,6 +41,12 @@ class StorageConfigSection(BaseModel):
state_file: str = "runtime/items.json"
class WorldConfigSection(BaseModel):
"""Authoritative world geometry options."""
grid_size: int = Field(default=41, ge=1)
class AppConfig(BaseModel):
"""Top-level application configuration document."""
@@ -49,6 +55,7 @@ class AppConfig(BaseModel):
tls: TlsConfigSection = TlsConfigSection()
logging: LoggingConfigSection = LoggingConfigSection()
storage: StorageConfigSection = StorageConfigSection()
world: WorldConfigSection = WorldConfigSection()
def load_config(path: Path | None) -> AppConfig:

View File

@@ -67,6 +67,7 @@ class SignalingServer:
ssl_key: str | None,
max_message_size: int = 2_000_000,
state_file: Path | None = None,
grid_size: int = 41,
):
"""Initialize runtime state, TLS context, and item service."""
@@ -77,6 +78,7 @@ class SignalingServer:
self.clients: dict[ServerConnection, ClientConnection] = {}
self.item_service = ItemService(state_file=state_file)
self.item_last_use_ms: dict[str, int] = {}
self.grid_size = max(1, grid_size)
@property
def items(self) -> dict[str, WorldItem]:
@@ -106,6 +108,11 @@ class SignalingServer:
return "radio" if item.type == "radio_station" else item.type
def _is_in_bounds(self, x: int, y: int) -> bool:
"""Check whether a coordinate is inside server-authoritative world bounds."""
return 0 <= x < self.grid_size and 0 <= y < self.grid_size
@staticmethod
def _normalize_clock_timezone(value: object) -> str:
"""Normalize timezone input to one of supported clock zones."""
@@ -270,6 +277,15 @@ class SignalingServer:
return
if isinstance(packet, UpdatePositionPacket):
if not self._is_in_bounds(packet.x, packet.y):
PACKET_LOGGER.warning(
"out-of-bounds position ignored id=%s x=%d y=%d grid_size=%d",
client.id,
packet.x,
packet.y,
self.grid_size,
)
return
client.x = packet.x
client.y = packet.y
await self._broadcast(
@@ -456,6 +472,9 @@ class SignalingServer:
if item.carrierId != client.id:
await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id)
return
if not self._is_in_bounds(packet.x, packet.y):
await self._send_item_result(client, False, "drop", "Drop position is out of bounds.", item.id)
return
item.carrierId = None
item.x = packet.x
item.y = packet.y
@@ -518,7 +537,6 @@ class SignalingServer:
item.id,
)
return
self.item_last_use_ms[item.id] = now_ms
delayed_wheel_self_result: str | None = None
delayed_wheel_others_result: str | None = None
if item.type == "radio_station":
@@ -578,6 +596,7 @@ class SignalingServer:
display_time = self._format_clock_display_time(item.params)
others_message = f"{client.nickname} checks {item.title}. {item.title} says {display_time}."
self_message = f"{item.title} says {display_time}."
self.item_last_use_ms[item.id] = now_ms
await self._broadcast(
BroadcastChatMessagePacket(type="chat_message", message=others_message, system=True),
exclude=client.websocket,
@@ -790,10 +809,10 @@ class SignalingServer:
async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None:
"""Broadcast one packet to all clients except an optional websocket."""
for websocket in list(self.clients.keys()):
if websocket is exclude:
continue
await self._send(websocket, packet)
recipients = [websocket for websocket in self.clients if websocket is not exclude]
if not recipients:
return
await asyncio.gather(*(self._send(websocket, packet) for websocket in recipients))
async def _send(self, websocket: ServerConnection, packet: object) -> None:
"""Send one packet to one websocket, swallowing per-client send failures."""
@@ -875,5 +894,6 @@ def run() -> None:
ssl_key,
max_message_size=config.network.max_message_bytes,
state_file=state_file,
grid_size=config.world.grid_size,
)
asyncio.run(server.start())

View File

@@ -21,3 +21,7 @@ level = "INFO"
[storage]
# Item persistence file. Relative paths are resolved from this config file directory.
state_file = "runtime/items.json"
[world]
# Grid width/height in cells. Valid coordinates are 0..grid_size-1.
grid_size = 41

View File

@@ -10,6 +10,7 @@ def test_load_config_defaults_when_path_none() -> None:
assert cfg.server.bind_ip == "127.0.0.1"
assert cfg.network.allow_insecure_ws is True
assert cfg.storage.state_file == "runtime/items.json"
assert cfg.world.grid_size == 41
def test_load_config_requires_tls_when_insecure_disabled(tmp_path: Path) -> None:

View File

@@ -182,3 +182,35 @@ async def test_clock_timezone_update_validates(monkeypatch: pytest.MonkeyPatch)
)
assert send_payloads[-1].ok is False
assert "timezone must be one of" in send_payloads[-1].message.lower()
@pytest.mark.asyncio
async def test_failed_wheel_use_does_not_consume_cooldown(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None)
ws = _fake_ws()
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
server.clients[ws] = client
item = server.item_service.default_item(client, "wheel")
item.params["spaces"] = ",,,"
server.item_service.add_item(item)
send_payloads: list[object] = []
now_ms = 40_000
async def fake_send(websocket: ServerConnection, packet: object) -> None:
send_payloads.append(packet)
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
return
monkeypatch.setattr(server, "_send", fake_send)
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
monkeypatch.setattr(server.item_service, "now_ms", lambda: now_ms)
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
assert send_payloads[-1].ok is False
assert "spaces" in send_payloads[-1].message.lower()
item.params["spaces"] = "a,b,c"
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
assert send_payloads[-1].ok is True

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
import asyncio
import json
from time import monotonic
from typing import cast
import pytest
from websockets.asyncio.server import ServerConnection
from app.client import ClientConnection
from app.server import SignalingServer
def _fake_ws() -> ServerConnection:
return cast(ServerConnection, object())
@pytest.mark.asyncio
async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
ws = _fake_ws()
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
server.clients[ws] = client
broadcast_payloads: list[object] = []
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
broadcast_payloads.append(packet)
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
await server._handle_message(client, json.dumps({"type": "update_position", "x": 200, "y": -5}))
assert client.x == 5
assert client.y == 6
assert broadcast_payloads == []
@pytest.mark.asyncio
async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
ws = _fake_ws()
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
server.clients[ws] = client
item = server.item_service.default_item(client, "dice")
item.carrierId = client.id
server.item_service.add_item(item)
send_payloads: list[object] = []
async def fake_send(websocket: ServerConnection, packet: object) -> None:
send_payloads.append(packet)
monkeypatch.setattr(server, "_send", fake_send)
await server._handle_message(client, json.dumps({"type": "item_drop", "itemId": item.id, "x": 999, "y": 999}))
assert item.carrierId == client.id
assert send_payloads[-1].ok is False
assert "out of bounds" in send_payloads[-1].message.lower()
@pytest.mark.asyncio
async def test_broadcast_fanout_is_concurrent(monkeypatch: pytest.MonkeyPatch) -> None:
server = SignalingServer("127.0.0.1", 8765, None, None)
ws1 = _fake_ws()
ws2 = _fake_ws()
server.clients[ws1] = ClientConnection(websocket=ws1, id="u1")
server.clients[ws2] = ClientConnection(websocket=ws2, id="u2")
send_started_at: dict[ServerConnection, float] = {}
async def fake_send(websocket: ServerConnection, packet: object) -> None:
send_started_at[websocket] = monotonic()
if websocket is ws1:
await asyncio.sleep(0.05)
monkeypatch.setattr(server, "_send", fake_send)
await server._broadcast({"type": "noop"})
assert ws1 in send_started_at
assert ws2 in send_started_at
assert abs(send_started_at[ws1] - send_started_at[ws2]) < 0.02