Harden origin and media URL security
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ....network_security import validate_media_reference
|
||||
from ....models import WorldItem
|
||||
from ...sound_policy import enforce_max_length, normalize_media_reference
|
||||
from ...helpers import keep_only_known_params
|
||||
@@ -16,6 +17,7 @@ def validate_update(item: WorldItem, next_params: dict) -> dict:
|
||||
max_length=2048,
|
||||
field_name="streamUrl",
|
||||
)
|
||||
next_params["streamUrl"] = validate_media_reference(next_params["streamUrl"], field_name="streamUrl")
|
||||
|
||||
enabled_value = next_params.get("enabled", True)
|
||||
if isinstance(enabled_value, bool):
|
||||
|
||||
165
server/app/network_security.py
Normal file
165
server/app/network_security.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Helpers for browser-origin policy and SSRF-safe outbound URL handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from typing import Iterable
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import urljoin, urlsplit, urlunsplit
|
||||
from urllib.request import HTTPRedirectHandler, Request, build_opener
|
||||
|
||||
IpAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
|
||||
|
||||
|
||||
class _NoRedirectHandler(HTTPRedirectHandler):
|
||||
"""Disable automatic redirects so each hop can be revalidated."""
|
||||
|
||||
def redirect_request(self, req, fp, code, msg, headers, newurl): # type: ignore[override]
|
||||
"""Return None so urllib surfaces redirects as HTTPError objects."""
|
||||
|
||||
return None
|
||||
|
||||
|
||||
_NO_REDIRECT_OPENER = build_opener(_NoRedirectHandler)
|
||||
|
||||
|
||||
def _format_host(host: str) -> str:
|
||||
"""Return one hostname/IP suitable for URL netloc reconstruction."""
|
||||
|
||||
if ":" in host and not host.startswith("["):
|
||||
return f"[{host}]"
|
||||
return host
|
||||
|
||||
|
||||
def _normalize_netloc(parts) -> str:
|
||||
"""Rebuild one normalized netloc from parsed URL parts."""
|
||||
|
||||
if not parts.hostname:
|
||||
raise ValueError("host is required")
|
||||
netloc = _format_host(parts.hostname.lower())
|
||||
if parts.port is not None:
|
||||
netloc = f"{netloc}:{parts.port}"
|
||||
return netloc
|
||||
|
||||
|
||||
def normalize_origin(value: str, *, field_name: str = "origin") -> str:
|
||||
"""Validate and normalize one browser origin string."""
|
||||
|
||||
text = value.strip()
|
||||
if not text:
|
||||
raise ValueError(f"{field_name} must not be empty.")
|
||||
try:
|
||||
parts = urlsplit(text)
|
||||
netloc = _normalize_netloc(parts)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"{field_name} must be a valid http/https origin.") from exc
|
||||
|
||||
scheme = parts.scheme.lower()
|
||||
if scheme not in {"http", "https"}:
|
||||
raise ValueError(f"{field_name} must use http or https.")
|
||||
if parts.username is not None or parts.password is not None:
|
||||
raise ValueError(f"{field_name} must not include credentials.")
|
||||
if parts.path not in {"", "/"} or parts.query or parts.fragment:
|
||||
raise ValueError(f"{field_name} must not include path, query, or fragment.")
|
||||
return urlunsplit((scheme, netloc, "", "", ""))
|
||||
|
||||
|
||||
def _resolve_host_ips(host: str) -> set[IpAddress]:
|
||||
"""Resolve one hostname or IP literal to concrete IP addresses."""
|
||||
|
||||
try:
|
||||
return {ipaddress.ip_address(host)}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
resolved: set[IpAddress] = set()
|
||||
try:
|
||||
infos = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError("DNS resolution failed.") from exc
|
||||
for family, _type, _proto, _canonname, sockaddr in infos:
|
||||
if family == socket.AF_INET:
|
||||
resolved.add(ipaddress.ip_address(sockaddr[0]))
|
||||
elif family == socket.AF_INET6:
|
||||
resolved.add(ipaddress.ip_address(sockaddr[0]))
|
||||
if not resolved:
|
||||
raise ValueError("DNS resolution failed.")
|
||||
return resolved
|
||||
|
||||
|
||||
def _ensure_public_ips(addresses: Iterable[IpAddress], *, field_name: str) -> None:
|
||||
"""Reject non-public IP addresses for SSRF-sensitive outbound requests."""
|
||||
|
||||
for address in addresses:
|
||||
if not address.is_global:
|
||||
raise ValueError(f"{field_name} must resolve to a public IP address.")
|
||||
|
||||
|
||||
def validate_public_media_url(value: str, *, field_name: str = "url") -> str:
|
||||
"""Validate and normalize one public http/https media URL."""
|
||||
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
try:
|
||||
parts = urlsplit(text)
|
||||
netloc = _normalize_netloc(parts)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"{field_name} must be a valid http/https URL.") from exc
|
||||
|
||||
scheme = parts.scheme.lower()
|
||||
if scheme not in {"http", "https"}:
|
||||
raise ValueError(f"{field_name} must use http or https.")
|
||||
if parts.username is not None or parts.password is not None:
|
||||
raise ValueError(f"{field_name} must not include credentials.")
|
||||
_ensure_public_ips(_resolve_host_ips(parts.hostname or ""), field_name=field_name)
|
||||
return urlunsplit((scheme, netloc, parts.path, parts.query, parts.fragment))
|
||||
|
||||
|
||||
def validate_media_reference(value: str, *, field_name: str = "url") -> str:
|
||||
"""Validate one media reference as either a public URL or a site-relative path."""
|
||||
|
||||
text = value.strip()
|
||||
if not text:
|
||||
return ""
|
||||
parts = urlsplit(text)
|
||||
if parts.scheme:
|
||||
return validate_public_media_url(text, field_name=field_name)
|
||||
if parts.netloc:
|
||||
raise ValueError(f"{field_name} must use http or https when specifying a host.")
|
||||
if not text.startswith("/"):
|
||||
raise ValueError(f"{field_name} must be an absolute http/https URL or site-relative path.")
|
||||
return text
|
||||
|
||||
|
||||
def open_validated_public_url(
|
||||
url: str,
|
||||
*,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float = 6.0,
|
||||
max_redirects: int = 5,
|
||||
):
|
||||
"""Open one public media URL while revalidating each redirect target."""
|
||||
|
||||
current_url = validate_public_media_url(url)
|
||||
request_headers = headers or {}
|
||||
for redirect_count in range(max_redirects + 1):
|
||||
request = Request(current_url, headers=request_headers)
|
||||
try:
|
||||
return _NO_REDIRECT_OPENER.open(request, timeout=timeout)
|
||||
except HTTPError as exc:
|
||||
try:
|
||||
if 300 <= exc.code < 400:
|
||||
if redirect_count >= max_redirects:
|
||||
raise ValueError("Too many redirects.")
|
||||
location = str(exc.headers.get("Location") or "").strip()
|
||||
if not location:
|
||||
raise ValueError("Redirect location missing or invalid.")
|
||||
current_url = validate_public_media_url(urljoin(current_url, location))
|
||||
continue
|
||||
raise
|
||||
finally:
|
||||
exc.close()
|
||||
raise ValueError("Too many redirects.")
|
||||
@@ -22,7 +22,6 @@ import uuid
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from pydantic import ValidationError, TypeAdapter
|
||||
@@ -107,6 +106,7 @@ from .models import (
|
||||
WelcomePacket,
|
||||
WorldItem,
|
||||
)
|
||||
from .network_security import normalize_origin, open_validated_public_url
|
||||
from .ui_metadata import (
|
||||
ADMIN_MENU_ACTION_DEFINITIONS,
|
||||
ITEM_MANAGEMENT_ACTION_DEFINITIONS,
|
||||
@@ -160,6 +160,7 @@ class SignalingServer:
|
||||
grid_size: int = 41,
|
||||
state_save_debounce_ms: int = 200,
|
||||
state_save_max_delay_ms: int = 1000,
|
||||
host_origin: str | None = None,
|
||||
):
|
||||
"""Initialize runtime state, TLS context, and item service."""
|
||||
|
||||
@@ -187,6 +188,7 @@ class SignalingServer:
|
||||
self.movement_max_steps_per_tick = MOVEMENT_MAX_STEPS_PER_TICK
|
||||
self.instance_id = str(uuid.uuid4())
|
||||
self.server_version = self._resolve_server_version()
|
||||
self.host_origin = normalize_origin(host_origin, field_name="host origin") if host_origin else None
|
||||
self.state_save_debounce_ms = max(1, int(state_save_debounce_ms))
|
||||
self.state_save_max_delay_ms = max(self.state_save_debounce_ms, int(state_save_max_delay_ms))
|
||||
self._pending_state_save_handle: asyncio.TimerHandle | None = None
|
||||
@@ -676,11 +678,11 @@ class SignalingServer:
|
||||
if not stream_url:
|
||||
return "", ""
|
||||
try:
|
||||
request = Request(
|
||||
with open_validated_public_url(
|
||||
stream_url,
|
||||
headers={"Icy-MetaData": "1", "User-Agent": "ChatGrid"},
|
||||
)
|
||||
with urlopen(request, timeout=RADIO_METADATA_TIMEOUT_S) as response:
|
||||
timeout=RADIO_METADATA_TIMEOUT_S,
|
||||
) as response:
|
||||
station = str(response.headers.get("icy-name") or response.headers.get("ice-name") or "").strip()
|
||||
title = ""
|
||||
metaint_raw = response.headers.get("icy-metaint")
|
||||
@@ -1355,6 +1357,7 @@ class SignalingServer:
|
||||
self.port,
|
||||
ssl=self._ssl_context,
|
||||
max_size=self.max_message_size,
|
||||
origins=[self.host_origin] if self.host_origin else None,
|
||||
process_request=self._process_http_request,
|
||||
):
|
||||
await asyncio.Future()
|
||||
@@ -3078,6 +3081,13 @@ def run() -> None:
|
||||
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
|
||||
if not auth_secret:
|
||||
raise SystemExit("CHGRID_AUTH_SECRET is required.")
|
||||
host_origin = os.getenv("CHGRID_HOST_ORIGIN", "").strip()
|
||||
if not host_origin:
|
||||
raise SystemExit("CHGRID_HOST_ORIGIN is required.")
|
||||
try:
|
||||
host_origin = normalize_origin(host_origin, field_name="CHGRID_HOST_ORIGIN")
|
||||
except ValueError as exc:
|
||||
raise SystemExit(str(exc)) from exc
|
||||
auth_db_value = config.auth.db_file.strip()
|
||||
if not auth_db_value:
|
||||
raise SystemExit("auth.db_file must not be empty.")
|
||||
@@ -3206,6 +3216,6 @@ def run() -> None:
|
||||
grid_size=config.world.grid_size,
|
||||
state_save_debounce_ms=config.storage.state_save_debounce_ms,
|
||||
state_save_max_delay_ms=config.storage.state_save_max_delay_ms,
|
||||
host_origin=host_origin,
|
||||
)
|
||||
asyncio.run(server.start())
|
||||
ItemClockAnnouncePacket,
|
||||
|
||||
46
server/tests/test_network_security.py
Normal file
46
server/tests/test_network_security.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
|
||||
from app.network_security import normalize_origin, validate_media_reference, validate_public_media_url
|
||||
|
||||
|
||||
def test_normalize_origin_rejects_paths() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
normalize_origin("https://example.com/chgrid")
|
||||
|
||||
|
||||
def test_normalize_origin_normalizes_case_and_trailing_slash() -> None:
|
||||
assert normalize_origin("HTTPS://Example.COM:443/") == "https://example.com:443"
|
||||
|
||||
|
||||
def test_validate_public_media_url_rejects_private_ip() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
validate_public_media_url("http://127.0.0.1/audio")
|
||||
|
||||
|
||||
def test_validate_public_media_url_resolves_hostname(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def fake_getaddrinfo(host: str, port, type: int = 0):
|
||||
assert host == "radio.example.com"
|
||||
return [(socket.AF_INET, type, 6, "", ("93.184.216.34", 0))]
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
assert validate_public_media_url("https://Radio.Example.com/live") == "https://radio.example.com/live"
|
||||
|
||||
|
||||
def test_validate_public_media_url_rejects_private_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def fake_getaddrinfo(host: str, port, type: int = 0):
|
||||
assert host == "radio.example.com"
|
||||
return [(socket.AF_INET, type, 6, "", ("10.0.0.5", 0))]
|
||||
|
||||
monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_public_media_url("https://radio.example.com/live")
|
||||
|
||||
|
||||
def test_validate_media_reference_allows_site_relative_path() -> None:
|
||||
assert validate_media_reference("/chgrid/media_proxy.php?url=test") == "/chgrid/media_proxy.php?url=test"
|
||||
Reference in New Issue
Block a user