Fix server bounds validation, cooldown timing, and broadcast fanout
This commit is contained in:
@@ -41,6 +41,12 @@ class StorageConfigSection(BaseModel):
|
||||
state_file: str = "runtime/items.json"
|
||||
|
||||
|
||||
class WorldConfigSection(BaseModel):
|
||||
"""Authoritative world geometry options."""
|
||||
|
||||
grid_size: int = Field(default=41, ge=1)
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Top-level application configuration document."""
|
||||
|
||||
@@ -49,6 +55,7 @@ class AppConfig(BaseModel):
|
||||
tls: TlsConfigSection = TlsConfigSection()
|
||||
logging: LoggingConfigSection = LoggingConfigSection()
|
||||
storage: StorageConfigSection = StorageConfigSection()
|
||||
world: WorldConfigSection = WorldConfigSection()
|
||||
|
||||
|
||||
def load_config(path: Path | None) -> AppConfig:
|
||||
|
||||
@@ -67,6 +67,7 @@ class SignalingServer:
|
||||
ssl_key: str | None,
|
||||
max_message_size: int = 2_000_000,
|
||||
state_file: Path | None = None,
|
||||
grid_size: int = 41,
|
||||
):
|
||||
"""Initialize runtime state, TLS context, and item service."""
|
||||
|
||||
@@ -77,6 +78,7 @@ class SignalingServer:
|
||||
self.clients: dict[ServerConnection, ClientConnection] = {}
|
||||
self.item_service = ItemService(state_file=state_file)
|
||||
self.item_last_use_ms: dict[str, int] = {}
|
||||
self.grid_size = max(1, grid_size)
|
||||
|
||||
@property
|
||||
def items(self) -> dict[str, WorldItem]:
|
||||
@@ -106,6 +108,11 @@ class SignalingServer:
|
||||
|
||||
return "radio" if item.type == "radio_station" else item.type
|
||||
|
||||
def _is_in_bounds(self, x: int, y: int) -> bool:
|
||||
"""Check whether a coordinate is inside server-authoritative world bounds."""
|
||||
|
||||
return 0 <= x < self.grid_size and 0 <= y < self.grid_size
|
||||
|
||||
@staticmethod
|
||||
def _normalize_clock_timezone(value: object) -> str:
|
||||
"""Normalize timezone input to one of supported clock zones."""
|
||||
@@ -270,6 +277,15 @@ class SignalingServer:
|
||||
return
|
||||
|
||||
if isinstance(packet, UpdatePositionPacket):
|
||||
if not self._is_in_bounds(packet.x, packet.y):
|
||||
PACKET_LOGGER.warning(
|
||||
"out-of-bounds position ignored id=%s x=%d y=%d grid_size=%d",
|
||||
client.id,
|
||||
packet.x,
|
||||
packet.y,
|
||||
self.grid_size,
|
||||
)
|
||||
return
|
||||
client.x = packet.x
|
||||
client.y = packet.y
|
||||
await self._broadcast(
|
||||
@@ -456,6 +472,9 @@ class SignalingServer:
|
||||
if item.carrierId != client.id:
|
||||
await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id)
|
||||
return
|
||||
if not self._is_in_bounds(packet.x, packet.y):
|
||||
await self._send_item_result(client, False, "drop", "Drop position is out of bounds.", item.id)
|
||||
return
|
||||
item.carrierId = None
|
||||
item.x = packet.x
|
||||
item.y = packet.y
|
||||
@@ -518,7 +537,6 @@ class SignalingServer:
|
||||
item.id,
|
||||
)
|
||||
return
|
||||
self.item_last_use_ms[item.id] = now_ms
|
||||
delayed_wheel_self_result: str | None = None
|
||||
delayed_wheel_others_result: str | None = None
|
||||
if item.type == "radio_station":
|
||||
@@ -578,6 +596,7 @@ class SignalingServer:
|
||||
display_time = self._format_clock_display_time(item.params)
|
||||
others_message = f"{client.nickname} checks {item.title}. {item.title} says {display_time}."
|
||||
self_message = f"{item.title} says {display_time}."
|
||||
self.item_last_use_ms[item.id] = now_ms
|
||||
await self._broadcast(
|
||||
BroadcastChatMessagePacket(type="chat_message", message=others_message, system=True),
|
||||
exclude=client.websocket,
|
||||
@@ -790,10 +809,10 @@ class SignalingServer:
|
||||
async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None:
|
||||
"""Broadcast one packet to all clients except an optional websocket."""
|
||||
|
||||
for websocket in list(self.clients.keys()):
|
||||
if websocket is exclude:
|
||||
continue
|
||||
await self._send(websocket, packet)
|
||||
recipients = [websocket for websocket in self.clients if websocket is not exclude]
|
||||
if not recipients:
|
||||
return
|
||||
await asyncio.gather(*(self._send(websocket, packet) for websocket in recipients))
|
||||
|
||||
async def _send(self, websocket: ServerConnection, packet: object) -> None:
|
||||
"""Send one packet to one websocket, swallowing per-client send failures."""
|
||||
@@ -875,5 +894,6 @@ def run() -> None:
|
||||
ssl_key,
|
||||
max_message_size=config.network.max_message_bytes,
|
||||
state_file=state_file,
|
||||
grid_size=config.world.grid_size,
|
||||
)
|
||||
asyncio.run(server.start())
|
||||
|
||||
@@ -21,3 +21,7 @@ level = "INFO"
|
||||
[storage]
|
||||
# Item persistence file. Relative paths are resolved from this config file directory.
|
||||
state_file = "runtime/items.json"
|
||||
|
||||
[world]
|
||||
# Grid width/height in cells. Valid coordinates are 0..grid_size-1.
|
||||
grid_size = 41
|
||||
|
||||
@@ -10,6 +10,7 @@ def test_load_config_defaults_when_path_none() -> None:
|
||||
assert cfg.server.bind_ip == "127.0.0.1"
|
||||
assert cfg.network.allow_insecure_ws is True
|
||||
assert cfg.storage.state_file == "runtime/items.json"
|
||||
assert cfg.world.grid_size == 41
|
||||
|
||||
|
||||
def test_load_config_requires_tls_when_insecure_disabled(tmp_path: Path) -> None:
|
||||
|
||||
@@ -182,3 +182,35 @@ async def test_clock_timezone_update_validates(monkeypatch: pytest.MonkeyPatch)
|
||||
)
|
||||
assert send_payloads[-1].ok is False
|
||||
assert "timezone must be one of" in send_payloads[-1].message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_wheel_use_does_not_consume_cooldown(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
server = SignalingServer("127.0.0.1", 8765, None, None)
|
||||
ws = _fake_ws()
|
||||
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||
server.clients[ws] = client
|
||||
item = server.item_service.default_item(client, "wheel")
|
||||
item.params["spaces"] = ",,,"
|
||||
server.item_service.add_item(item)
|
||||
|
||||
send_payloads: list[object] = []
|
||||
now_ms = 40_000
|
||||
|
||||
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||
send_payloads.append(packet)
|
||||
|
||||
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
|
||||
return
|
||||
|
||||
monkeypatch.setattr(server, "_send", fake_send)
|
||||
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
|
||||
monkeypatch.setattr(server.item_service, "now_ms", lambda: now_ms)
|
||||
|
||||
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
|
||||
assert send_payloads[-1].ok is False
|
||||
assert "spaces" in send_payloads[-1].message.lower()
|
||||
|
||||
item.params["spaces"] = "a,b,c"
|
||||
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
|
||||
assert send_payloads[-1].ok is True
|
||||
|
||||
85
server/tests/test_server_message_handling.py
Normal file
85
server/tests/test_server_message_handling.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from time import monotonic
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
|
||||
from app.client import ClientConnection
|
||||
from app.server import SignalingServer
|
||||
|
||||
|
||||
def _fake_ws() -> ServerConnection:
|
||||
return cast(ServerConnection, object())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
|
||||
ws = _fake_ws()
|
||||
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||
server.clients[ws] = client
|
||||
|
||||
broadcast_payloads: list[object] = []
|
||||
|
||||
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
|
||||
broadcast_payloads.append(packet)
|
||||
|
||||
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
|
||||
|
||||
await server._handle_message(client, json.dumps({"type": "update_position", "x": 200, "y": -5}))
|
||||
|
||||
assert client.x == 5
|
||||
assert client.y == 6
|
||||
assert broadcast_payloads == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
|
||||
ws = _fake_ws()
|
||||
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||
server.clients[ws] = client
|
||||
item = server.item_service.default_item(client, "dice")
|
||||
item.carrierId = client.id
|
||||
server.item_service.add_item(item)
|
||||
|
||||
send_payloads: list[object] = []
|
||||
|
||||
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||
send_payloads.append(packet)
|
||||
|
||||
monkeypatch.setattr(server, "_send", fake_send)
|
||||
|
||||
await server._handle_message(client, json.dumps({"type": "item_drop", "itemId": item.id, "x": 999, "y": 999}))
|
||||
|
||||
assert item.carrierId == client.id
|
||||
assert send_payloads[-1].ok is False
|
||||
assert "out of bounds" in send_payloads[-1].message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_fanout_is_concurrent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
server = SignalingServer("127.0.0.1", 8765, None, None)
|
||||
ws1 = _fake_ws()
|
||||
ws2 = _fake_ws()
|
||||
server.clients[ws1] = ClientConnection(websocket=ws1, id="u1")
|
||||
server.clients[ws2] = ClientConnection(websocket=ws2, id="u2")
|
||||
|
||||
send_started_at: dict[ServerConnection, float] = {}
|
||||
|
||||
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||
send_started_at[websocket] = monotonic()
|
||||
if websocket is ws1:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
monkeypatch.setattr(server, "_send", fake_send)
|
||||
|
||||
await server._broadcast({"type": "noop"})
|
||||
|
||||
assert ws1 in send_started_at
|
||||
assert ws2 in send_started_at
|
||||
assert abs(send_started_at[ws1] - send_started_at[ws2]) < 0.02
|
||||
Reference in New Issue
Block a user