Enforce websocket origin allowlist with secure-mode config
This commit is contained in:
@@ -13,10 +13,13 @@ Key options:
|
||||
- `server.bind_ip`, `server.port`
|
||||
- `network.max_message_bytes`
|
||||
- `network.allow_insecure_ws`
|
||||
- `network.allowed_origins`
|
||||
- `tls.cert_file`, `tls.key_file`
|
||||
|
||||
If `network.allow_insecure_ws = false`, TLS cert/key are required and server runs as `wss://`.
|
||||
For local/dev without TLS, either set `network.allow_insecure_ws = true` or pass `--allow-insecure-ws`.
|
||||
When insecure ws is disabled, `network.allowed_origins` must list your deployed `https://` origins.
|
||||
When insecure ws is enabled and `network.allowed_origins` is empty, localhost dev origins are allowed automatically.
|
||||
|
||||
## Run
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ class NetworkConfigSection(BaseModel):
|
||||
|
||||
max_message_bytes: int = Field(default=2_000_000, gt=0)
|
||||
allow_insecure_ws: bool = False
|
||||
allowed_origins: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TlsConfigSection(BaseModel):
|
||||
|
||||
@@ -127,6 +127,7 @@ AUTH_SESSION_COOKIE_CLEAR_PATH = "/auth/session/clear"
|
||||
AUTH_SESSION_COOKIE_CLIENT_HEADER = "X-Chgrid-Auth-Client"
|
||||
AUTH_LOGIN_FAILURE_MESSAGE = "We couldn't log you in. Check your details and try again."
|
||||
AUTH_RESUME_FAILURE_MESSAGE = "We couldn't restore your session. Please log in again."
|
||||
LOCAL_DEV_ALLOWED_ORIGINS: tuple[str, ...] = ("http://localhost:5173", "http://127.0.0.1:5173")
|
||||
ADMIN_MENU_ACTION_DEFINITIONS: tuple[dict[str, str], ...] = (
|
||||
{"id": "manage_roles", "label": "Role management", "permission": "role.manage"},
|
||||
{"id": "change_user_role", "label": "Change user role", "permission": "user.change_role"},
|
||||
@@ -144,6 +145,7 @@ class SignalingServer:
|
||||
port: int,
|
||||
ssl_cert: str | None,
|
||||
ssl_key: str | None,
|
||||
allowed_origins: tuple[str, ...] | list[str] | None = None,
|
||||
auth_db_path: Path | None = None,
|
||||
auth_token_hash_secret: str = "dev-secret",
|
||||
password_min_length: int = 8,
|
||||
@@ -161,6 +163,7 @@ class SignalingServer:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.max_message_size = max_message_size
|
||||
self.allowed_origins = tuple(allowed_origins or ())
|
||||
self._ssl_context = self._build_ssl_context(ssl_cert, ssl_key)
|
||||
self.clients: dict[ServerConnection, ClientConnection] = {}
|
||||
resolved_auth_db_path = auth_db_path or Path.cwd() / "runtime" / "chatgrid.db"
|
||||
@@ -1302,6 +1305,7 @@ class SignalingServer:
|
||||
self._handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
origins=self.allowed_origins if self.allowed_origins else None,
|
||||
ssl=self._ssl_context,
|
||||
max_size=self.max_message_size,
|
||||
process_request=self._process_http_request,
|
||||
@@ -2860,6 +2864,10 @@ def run() -> None:
|
||||
raise SystemExit(
|
||||
"TLS is required when insecure ws is disabled. Set tls.cert_file/tls.key_file in config.toml."
|
||||
)
|
||||
try:
|
||||
allowed_origins = _resolve_allowed_origins(config.network.allowed_origins, allow_insecure_ws=allow_insecure_ws)
|
||||
except ValueError as exc:
|
||||
raise SystemExit(str(exc)) from exc
|
||||
|
||||
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
|
||||
if not auth_secret:
|
||||
@@ -2992,6 +3000,31 @@ def run() -> None:
|
||||
grid_size=config.world.grid_size,
|
||||
state_save_debounce_ms=config.storage.state_save_debounce_ms,
|
||||
state_save_max_delay_ms=config.storage.state_save_max_delay_ms,
|
||||
allowed_origins=allowed_origins,
|
||||
)
|
||||
asyncio.run(server.start())
|
||||
ItemClockAnnouncePacket,
|
||||
|
||||
|
||||
def _resolve_allowed_origins(raw_origins: list[str], *, allow_insecure_ws: bool) -> tuple[str, ...]:
|
||||
"""Resolve websocket Origin allowlist from config and transport mode."""
|
||||
|
||||
normalized: list[str] = []
|
||||
for origin in raw_origins:
|
||||
candidate = str(origin or "").strip()
|
||||
if not candidate or candidate in normalized:
|
||||
continue
|
||||
normalized.append(candidate)
|
||||
|
||||
if allow_insecure_ws:
|
||||
if normalized:
|
||||
return tuple(normalized)
|
||||
return LOCAL_DEV_ALLOWED_ORIGINS
|
||||
|
||||
if not normalized:
|
||||
raise ValueError(
|
||||
"network.allowed_origins must list your https web origin(s) when insecure ws is disabled."
|
||||
)
|
||||
non_https = [origin for origin in normalized if not origin.lower().startswith("https://")]
|
||||
if non_https:
|
||||
raise ValueError("network.allowed_origins must use https origins when insecure ws is disabled.")
|
||||
return tuple(normalized)
|
||||
|
||||
@@ -9,6 +9,10 @@ port = 8765
|
||||
max_message_bytes = 2000000
|
||||
# Secure-by-default: TLS is required unless you explicitly set this to true for local/dev.
|
||||
allow_insecure_ws = false
|
||||
# Allowed websocket request Origin values.
|
||||
# Production: list your deployed https web origins explicitly.
|
||||
# Local/dev: when allow_insecure_ws=true and this list is empty, localhost defaults are used.
|
||||
allowed_origins = ["https://bestmidi.com", "https://www.bestmidi.com"]
|
||||
|
||||
[tls]
|
||||
# Required when allow_insecure_ws = false.
|
||||
|
||||
@@ -9,6 +9,7 @@ def test_load_config_defaults_when_path_none() -> None:
|
||||
cfg = load_config(None)
|
||||
assert cfg.server.bind_ip == "127.0.0.1"
|
||||
assert cfg.network.allow_insecure_ws is False
|
||||
assert cfg.network.allowed_origins == []
|
||||
assert cfg.storage.state_file == "runtime/items.json"
|
||||
assert cfg.storage.state_save_debounce_ms == 200
|
||||
assert cfg.storage.state_save_max_delay_ms == 1000
|
||||
@@ -43,3 +44,16 @@ state_save_max_delay_ms = 900
|
||||
cfg = load_config(config_path)
|
||||
assert cfg.storage.state_save_debounce_ms == 150
|
||||
assert cfg.storage.state_save_max_delay_ms == 900
|
||||
|
||||
|
||||
def test_load_config_reads_allowed_origins(tmp_path: Path) -> None:
|
||||
config_path = tmp_path / "config.toml"
|
||||
config_path.write_text(
|
||||
"""
|
||||
[network]
|
||||
allow_insecure_ws = true
|
||||
allowed_origins = ["https://bestmidi.com", "https://www.bestmidi.com"]
|
||||
""".strip()
|
||||
)
|
||||
cfg = load_config(config_path)
|
||||
assert cfg.network.allowed_origins == ["https://bestmidi.com", "https://www.bestmidi.com"]
|
||||
|
||||
28
server/tests/test_origin_policy.py
Normal file
28
server/tests/test_origin_policy.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.server import LOCAL_DEV_ALLOWED_ORIGINS, _resolve_allowed_origins
|
||||
|
||||
|
||||
def test_resolve_allowed_origins_defaults_localhost_for_insecure_mode() -> None:
|
||||
origins = _resolve_allowed_origins([], allow_insecure_ws=True)
|
||||
assert origins == LOCAL_DEV_ALLOWED_ORIGINS
|
||||
|
||||
|
||||
def test_resolve_allowed_origins_requires_values_for_secure_mode() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_resolve_allowed_origins([], allow_insecure_ws=False)
|
||||
|
||||
|
||||
def test_resolve_allowed_origins_requires_https_in_secure_mode() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_resolve_allowed_origins(["http://localhost:5173"], allow_insecure_ws=False)
|
||||
|
||||
|
||||
def test_resolve_allowed_origins_normalizes_and_deduplicates() -> None:
|
||||
origins = _resolve_allowed_origins(
|
||||
[" https://bestmidi.com ", "https://bestmidi.com", "https://www.bestmidi.com"],
|
||||
allow_insecure_ws=False,
|
||||
)
|
||||
assert origins == ("https://bestmidi.com", "https://www.bestmidi.com")
|
||||
Reference in New Issue
Block a user