Enforce websocket origin allowlist with secure-mode config

This commit is contained in:
Jage9
2026-02-28 04:47:07 -05:00
parent 9f3cd1fbdc
commit cf30229b37
9 changed files with 87 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ def test_load_config_defaults_when_path_none() -> None:
cfg = load_config(None)
assert cfg.server.bind_ip == "127.0.0.1"
assert cfg.network.allow_insecure_ws is False
assert cfg.network.allowed_origins == []
assert cfg.storage.state_file == "runtime/items.json"
assert cfg.storage.state_save_debounce_ms == 200
assert cfg.storage.state_save_max_delay_ms == 1000
@@ -43,3 +44,16 @@ state_save_max_delay_ms = 900
cfg = load_config(config_path)
assert cfg.storage.state_save_debounce_ms == 150
assert cfg.storage.state_save_max_delay_ms == 900
def test_load_config_reads_allowed_origins(tmp_path: Path) -> None:
config_path = tmp_path / "config.toml"
config_path.write_text(
"""
[network]
allow_insecure_ws = true
allowed_origins = ["https://bestmidi.com", "https://www.bestmidi.com"]
""".strip()
)
cfg = load_config(config_path)
assert cfg.network.allowed_origins == ["https://bestmidi.com", "https://www.bestmidi.com"]

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import pytest
from app.server import LOCAL_DEV_ALLOWED_ORIGINS, _resolve_allowed_origins
def test_resolve_allowed_origins_defaults_localhost_for_insecure_mode() -> None:
origins = _resolve_allowed_origins([], allow_insecure_ws=True)
assert origins == LOCAL_DEV_ALLOWED_ORIGINS
def test_resolve_allowed_origins_requires_values_for_secure_mode() -> None:
with pytest.raises(ValueError):
_resolve_allowed_origins([], allow_insecure_ws=False)
def test_resolve_allowed_origins_requires_https_in_secure_mode() -> None:
with pytest.raises(ValueError):
_resolve_allowed_origins(["http://localhost:5173"], allow_insecure_ws=False)
def test_resolve_allowed_origins_normalizes_and_deduplicates() -> None:
origins = _resolve_allowed_origins(
[" https://bestmidi.com ", "https://bestmidi.com", "https://www.bestmidi.com"],
allow_insecure_ws=False,
)
assert origins == ("https://bestmidi.com", "https://www.bestmidi.com")