Fix server bounds validation, cooldown timing, and broadcast fanout
This commit is contained in:
@@ -41,6 +41,12 @@ class StorageConfigSection(BaseModel):
|
|||||||
state_file: str = "runtime/items.json"
|
state_file: str = "runtime/items.json"
|
||||||
|
|
||||||
|
|
||||||
|
class WorldConfigSection(BaseModel):
|
||||||
|
"""Authoritative world geometry options."""
|
||||||
|
|
||||||
|
grid_size: int = Field(default=41, ge=1)
|
||||||
|
|
||||||
|
|
||||||
class AppConfig(BaseModel):
|
class AppConfig(BaseModel):
|
||||||
"""Top-level application configuration document."""
|
"""Top-level application configuration document."""
|
||||||
|
|
||||||
@@ -49,6 +55,7 @@ class AppConfig(BaseModel):
|
|||||||
tls: TlsConfigSection = TlsConfigSection()
|
tls: TlsConfigSection = TlsConfigSection()
|
||||||
logging: LoggingConfigSection = LoggingConfigSection()
|
logging: LoggingConfigSection = LoggingConfigSection()
|
||||||
storage: StorageConfigSection = StorageConfigSection()
|
storage: StorageConfigSection = StorageConfigSection()
|
||||||
|
world: WorldConfigSection = WorldConfigSection()
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: Path | None) -> AppConfig:
|
def load_config(path: Path | None) -> AppConfig:
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class SignalingServer:
|
|||||||
ssl_key: str | None,
|
ssl_key: str | None,
|
||||||
max_message_size: int = 2_000_000,
|
max_message_size: int = 2_000_000,
|
||||||
state_file: Path | None = None,
|
state_file: Path | None = None,
|
||||||
|
grid_size: int = 41,
|
||||||
):
|
):
|
||||||
"""Initialize runtime state, TLS context, and item service."""
|
"""Initialize runtime state, TLS context, and item service."""
|
||||||
|
|
||||||
@@ -77,6 +78,7 @@ class SignalingServer:
|
|||||||
self.clients: dict[ServerConnection, ClientConnection] = {}
|
self.clients: dict[ServerConnection, ClientConnection] = {}
|
||||||
self.item_service = ItemService(state_file=state_file)
|
self.item_service = ItemService(state_file=state_file)
|
||||||
self.item_last_use_ms: dict[str, int] = {}
|
self.item_last_use_ms: dict[str, int] = {}
|
||||||
|
self.grid_size = max(1, grid_size)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def items(self) -> dict[str, WorldItem]:
|
def items(self) -> dict[str, WorldItem]:
|
||||||
@@ -106,6 +108,11 @@ class SignalingServer:
|
|||||||
|
|
||||||
return "radio" if item.type == "radio_station" else item.type
|
return "radio" if item.type == "radio_station" else item.type
|
||||||
|
|
||||||
|
def _is_in_bounds(self, x: int, y: int) -> bool:
|
||||||
|
"""Check whether a coordinate is inside server-authoritative world bounds."""
|
||||||
|
|
||||||
|
return 0 <= x < self.grid_size and 0 <= y < self.grid_size
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_clock_timezone(value: object) -> str:
|
def _normalize_clock_timezone(value: object) -> str:
|
||||||
"""Normalize timezone input to one of supported clock zones."""
|
"""Normalize timezone input to one of supported clock zones."""
|
||||||
@@ -270,6 +277,15 @@ class SignalingServer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(packet, UpdatePositionPacket):
|
if isinstance(packet, UpdatePositionPacket):
|
||||||
|
if not self._is_in_bounds(packet.x, packet.y):
|
||||||
|
PACKET_LOGGER.warning(
|
||||||
|
"out-of-bounds position ignored id=%s x=%d y=%d grid_size=%d",
|
||||||
|
client.id,
|
||||||
|
packet.x,
|
||||||
|
packet.y,
|
||||||
|
self.grid_size,
|
||||||
|
)
|
||||||
|
return
|
||||||
client.x = packet.x
|
client.x = packet.x
|
||||||
client.y = packet.y
|
client.y = packet.y
|
||||||
await self._broadcast(
|
await self._broadcast(
|
||||||
@@ -456,6 +472,9 @@ class SignalingServer:
|
|||||||
if item.carrierId != client.id:
|
if item.carrierId != client.id:
|
||||||
await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id)
|
await self._send_item_result(client, False, "drop", "You are not carrying that item.", item.id)
|
||||||
return
|
return
|
||||||
|
if not self._is_in_bounds(packet.x, packet.y):
|
||||||
|
await self._send_item_result(client, False, "drop", "Drop position is out of bounds.", item.id)
|
||||||
|
return
|
||||||
item.carrierId = None
|
item.carrierId = None
|
||||||
item.x = packet.x
|
item.x = packet.x
|
||||||
item.y = packet.y
|
item.y = packet.y
|
||||||
@@ -518,7 +537,6 @@ class SignalingServer:
|
|||||||
item.id,
|
item.id,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self.item_last_use_ms[item.id] = now_ms
|
|
||||||
delayed_wheel_self_result: str | None = None
|
delayed_wheel_self_result: str | None = None
|
||||||
delayed_wheel_others_result: str | None = None
|
delayed_wheel_others_result: str | None = None
|
||||||
if item.type == "radio_station":
|
if item.type == "radio_station":
|
||||||
@@ -578,6 +596,7 @@ class SignalingServer:
|
|||||||
display_time = self._format_clock_display_time(item.params)
|
display_time = self._format_clock_display_time(item.params)
|
||||||
others_message = f"{client.nickname} checks {item.title}. {item.title} says {display_time}."
|
others_message = f"{client.nickname} checks {item.title}. {item.title} says {display_time}."
|
||||||
self_message = f"{item.title} says {display_time}."
|
self_message = f"{item.title} says {display_time}."
|
||||||
|
self.item_last_use_ms[item.id] = now_ms
|
||||||
await self._broadcast(
|
await self._broadcast(
|
||||||
BroadcastChatMessagePacket(type="chat_message", message=others_message, system=True),
|
BroadcastChatMessagePacket(type="chat_message", message=others_message, system=True),
|
||||||
exclude=client.websocket,
|
exclude=client.websocket,
|
||||||
@@ -790,10 +809,10 @@ class SignalingServer:
|
|||||||
async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None:
|
async def _broadcast(self, packet: object, exclude: ServerConnection | None = None) -> None:
|
||||||
"""Broadcast one packet to all clients except an optional websocket."""
|
"""Broadcast one packet to all clients except an optional websocket."""
|
||||||
|
|
||||||
for websocket in list(self.clients.keys()):
|
recipients = [websocket for websocket in self.clients if websocket is not exclude]
|
||||||
if websocket is exclude:
|
if not recipients:
|
||||||
continue
|
return
|
||||||
await self._send(websocket, packet)
|
await asyncio.gather(*(self._send(websocket, packet) for websocket in recipients))
|
||||||
|
|
||||||
async def _send(self, websocket: ServerConnection, packet: object) -> None:
|
async def _send(self, websocket: ServerConnection, packet: object) -> None:
|
||||||
"""Send one packet to one websocket, swallowing per-client send failures."""
|
"""Send one packet to one websocket, swallowing per-client send failures."""
|
||||||
@@ -875,5 +894,6 @@ def run() -> None:
|
|||||||
ssl_key,
|
ssl_key,
|
||||||
max_message_size=config.network.max_message_bytes,
|
max_message_size=config.network.max_message_bytes,
|
||||||
state_file=state_file,
|
state_file=state_file,
|
||||||
|
grid_size=config.world.grid_size,
|
||||||
)
|
)
|
||||||
asyncio.run(server.start())
|
asyncio.run(server.start())
|
||||||
|
|||||||
@@ -21,3 +21,7 @@ level = "INFO"
|
|||||||
[storage]
|
[storage]
|
||||||
# Item persistence file. Relative paths are resolved from this config file directory.
|
# Item persistence file. Relative paths are resolved from this config file directory.
|
||||||
state_file = "runtime/items.json"
|
state_file = "runtime/items.json"
|
||||||
|
|
||||||
|
[world]
|
||||||
|
# Grid width/height in cells. Valid coordinates are 0..grid_size-1.
|
||||||
|
grid_size = 41
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ def test_load_config_defaults_when_path_none() -> None:
|
|||||||
assert cfg.server.bind_ip == "127.0.0.1"
|
assert cfg.server.bind_ip == "127.0.0.1"
|
||||||
assert cfg.network.allow_insecure_ws is True
|
assert cfg.network.allow_insecure_ws is True
|
||||||
assert cfg.storage.state_file == "runtime/items.json"
|
assert cfg.storage.state_file == "runtime/items.json"
|
||||||
|
assert cfg.world.grid_size == 41
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_requires_tls_when_insecure_disabled(tmp_path: Path) -> None:
|
def test_load_config_requires_tls_when_insecure_disabled(tmp_path: Path) -> None:
|
||||||
|
|||||||
@@ -182,3 +182,35 @@ async def test_clock_timezone_update_validates(monkeypatch: pytest.MonkeyPatch)
|
|||||||
)
|
)
|
||||||
assert send_payloads[-1].ok is False
|
assert send_payloads[-1].ok is False
|
||||||
assert "timezone must be one of" in send_payloads[-1].message.lower()
|
assert "timezone must be one of" in send_payloads[-1].message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failed_wheel_use_does_not_consume_cooldown(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
server = SignalingServer("127.0.0.1", 8765, None, None)
|
||||||
|
ws = _fake_ws()
|
||||||
|
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||||
|
server.clients[ws] = client
|
||||||
|
item = server.item_service.default_item(client, "wheel")
|
||||||
|
item.params["spaces"] = ",,,"
|
||||||
|
server.item_service.add_item(item)
|
||||||
|
|
||||||
|
send_payloads: list[object] = []
|
||||||
|
now_ms = 40_000
|
||||||
|
|
||||||
|
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||||
|
send_payloads.append(packet)
|
||||||
|
|
||||||
|
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
monkeypatch.setattr(server, "_send", fake_send)
|
||||||
|
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
|
||||||
|
monkeypatch.setattr(server.item_service, "now_ms", lambda: now_ms)
|
||||||
|
|
||||||
|
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
|
||||||
|
assert send_payloads[-1].ok is False
|
||||||
|
assert "spaces" in send_payloads[-1].message.lower()
|
||||||
|
|
||||||
|
item.params["spaces"] = "a,b,c"
|
||||||
|
await server._handle_message(client, json.dumps({"type": "item_use", "itemId": item.id}))
|
||||||
|
assert send_payloads[-1].ok is True
|
||||||
|
|||||||
85
server/tests/test_server_message_handling.py
Normal file
85
server/tests/test_server_message_handling.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from time import monotonic
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from websockets.asyncio.server import ServerConnection
|
||||||
|
|
||||||
|
from app.client import ClientConnection
|
||||||
|
from app.server import SignalingServer
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_ws() -> ServerConnection:
|
||||||
|
return cast(ServerConnection, object())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_position_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
|
||||||
|
ws = _fake_ws()
|
||||||
|
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||||
|
server.clients[ws] = client
|
||||||
|
|
||||||
|
broadcast_payloads: list[object] = []
|
||||||
|
|
||||||
|
async def fake_broadcast(packet: object, exclude: ServerConnection | None = None) -> None:
|
||||||
|
broadcast_payloads.append(packet)
|
||||||
|
|
||||||
|
monkeypatch.setattr(server, "_broadcast", fake_broadcast)
|
||||||
|
|
||||||
|
await server._handle_message(client, json.dumps({"type": "update_position", "x": 200, "y": -5}))
|
||||||
|
|
||||||
|
assert client.x == 5
|
||||||
|
assert client.y == 6
|
||||||
|
assert broadcast_payloads == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_item_drop_rejects_out_of_bounds(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
server = SignalingServer("127.0.0.1", 8765, None, None, grid_size=41)
|
||||||
|
ws = _fake_ws()
|
||||||
|
client = ClientConnection(websocket=ws, id="u1", nickname="tester", x=5, y=6)
|
||||||
|
server.clients[ws] = client
|
||||||
|
item = server.item_service.default_item(client, "dice")
|
||||||
|
item.carrierId = client.id
|
||||||
|
server.item_service.add_item(item)
|
||||||
|
|
||||||
|
send_payloads: list[object] = []
|
||||||
|
|
||||||
|
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||||
|
send_payloads.append(packet)
|
||||||
|
|
||||||
|
monkeypatch.setattr(server, "_send", fake_send)
|
||||||
|
|
||||||
|
await server._handle_message(client, json.dumps({"type": "item_drop", "itemId": item.id, "x": 999, "y": 999}))
|
||||||
|
|
||||||
|
assert item.carrierId == client.id
|
||||||
|
assert send_payloads[-1].ok is False
|
||||||
|
assert "out of bounds" in send_payloads[-1].message.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast_fanout_is_concurrent(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
server = SignalingServer("127.0.0.1", 8765, None, None)
|
||||||
|
ws1 = _fake_ws()
|
||||||
|
ws2 = _fake_ws()
|
||||||
|
server.clients[ws1] = ClientConnection(websocket=ws1, id="u1")
|
||||||
|
server.clients[ws2] = ClientConnection(websocket=ws2, id="u2")
|
||||||
|
|
||||||
|
send_started_at: dict[ServerConnection, float] = {}
|
||||||
|
|
||||||
|
async def fake_send(websocket: ServerConnection, packet: object) -> None:
|
||||||
|
send_started_at[websocket] = monotonic()
|
||||||
|
if websocket is ws1:
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
monkeypatch.setattr(server, "_send", fake_send)
|
||||||
|
|
||||||
|
await server._broadcast({"type": "noop"})
|
||||||
|
|
||||||
|
assert ws1 in send_started_at
|
||||||
|
assert ws2 in send_started_at
|
||||||
|
assert abs(send_started_at[ws1] - send_started_at[ws2]) < 0.02
|
||||||
Reference in New Issue
Block a user