Harden origin and media URL security

This commit is contained in:
Jage9
2026-03-08 20:51:50 -04:00
parent 3d69bbcea2
commit 78bc931cce
12 changed files with 378 additions and 14 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from ....network_security import validate_media_reference
from ....models import WorldItem
from ...sound_policy import enforce_max_length, normalize_media_reference
from ...helpers import keep_only_known_params
@@ -16,6 +17,7 @@ def validate_update(item: WorldItem, next_params: dict) -> dict:
max_length=2048,
field_name="streamUrl",
)
next_params["streamUrl"] = validate_media_reference(next_params["streamUrl"], field_name="streamUrl")
enabled_value = next_params.get("enabled", True)
if isinstance(enabled_value, bool):

View File

@@ -0,0 +1,165 @@
"""Helpers for browser-origin policy and SSRF-safe outbound URL handling."""
from __future__ import annotations
import ipaddress
import socket
from typing import Iterable
from urllib.error import HTTPError
from urllib.parse import urljoin, urlsplit, urlunsplit
from urllib.request import HTTPRedirectHandler, Request, build_opener
IpAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
class _NoRedirectHandler(HTTPRedirectHandler):
    """Redirect handler that refuses to follow any redirect automatically.

    Callers are expected to inspect the redirect themselves so the next hop
    can be validated before it is requested.
    """

    def redirect_request(self, req, fp, code, msg, headers, newurl):  # type: ignore[override]
        """Decline every redirect by returning None.

        With no replacement request supplied, urllib reports the 3xx
        response to the caller as an HTTPError.
        """
        return None


# Shared opener whose redirect handling is replaced by the refusing handler above.
_NO_REDIRECT_OPENER = build_opener(_NoRedirectHandler())
def _format_host(host: str) -> str:
    """Return *host* in the form it must take inside a URL netloc.

    A host containing ":" (an IPv6 literal) is wrapped in square brackets
    unless it already starts with one; anything else is returned unchanged.
    """
    needs_brackets = ":" in host and not host.startswith("[")
    return f"[{host}]" if needs_brackets else host
def _normalize_netloc(parts) -> str:
    """Build a normalized ``host[:port]`` netloc from a split URL.

    The hostname is lower-cased and bracketed when it is an IPv6 literal;
    credentials are never copied over. An explicit port is kept as given.

    Raises:
        ValueError: If the URL has no hostname (an invalid port also
            surfaces as ValueError when ``parts.port`` is read).
    """
    hostname = parts.hostname
    if not hostname:
        raise ValueError("host is required")
    host_part = _format_host(hostname.lower())
    port = parts.port
    return host_part if port is None else f"{host_part}:{port}"
def normalize_origin(value: str, *, field_name: str = "origin") -> str:
    """Check that *value* is a bare http/https origin and return it normalized.

    The result is ``scheme://host[:port]`` with scheme and host lower-cased.
    A lone trailing "/" is tolerated and dropped; an explicit port is kept.

    Args:
        value: Origin string to validate; surrounding whitespace is ignored.
        field_name: Label used at the start of every error message.

    Raises:
        ValueError: If the value is empty, unparsable, not http/https,
            carries credentials, or has a path, query, or fragment.
    """
    candidate = value.strip()
    if not candidate:
        raise ValueError(f"{field_name} must not be empty.")

    try:
        split = urlsplit(candidate)
        authority = _normalize_netloc(split)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https origin.") from exc

    lowered_scheme = split.scheme.lower()
    if lowered_scheme not in ("http", "https"):
        raise ValueError(f"{field_name} must use http or https.")

    has_credentials = not (split.username is None and split.password is None)
    if has_credentials:
        raise ValueError(f"{field_name} must not include credentials.")

    # An origin is scheme + host + port only; anything beyond "/" is extra.
    has_extra_components = split.path not in ("", "/") or bool(split.query) or bool(split.fragment)
    if has_extra_components:
        raise ValueError(f"{field_name} must not include path, query, or fragment.")

    return urlunsplit((lowered_scheme, authority, "", "", ""))
def _resolve_host_ips(host: str) -> set[IpAddress]:
    """Resolve one hostname or IP literal to concrete IP addresses.

    Args:
        host: Hostname or IPv4/IPv6 literal (without brackets).

    Returns:
        The set of IPv4/IPv6 addresses the host maps to. An IP literal is
        returned as a single-element set without any DNS lookup.

    Raises:
        ValueError: With the message "DNS resolution failed." when the name
            cannot be resolved, cannot be encoded for lookup, or yields no
            IPv4/IPv6 address.
    """
    try:
        # IP literals are already concrete; no lookup needed.
        return {ipaddress.ip_address(host)}
    except ValueError:
        pass

    try:
        infos = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
    except (socket.gaierror, UnicodeError) as exc:
        # Non-ASCII hostnames are IDNA-encoded before the lookup; a label that
        # cannot be encoded raises UnicodeError instead of gaierror. Report
        # both as the same resolution failure so callers see one message.
        raise ValueError("DNS resolution failed.") from exc

    # Only IPv4/IPv6 results carry an address in sockaddr[0] that we can check.
    resolved: set[IpAddress] = {
        ipaddress.ip_address(sockaddr[0])
        for family, _type, _proto, _canonname, sockaddr in infos
        if family in (socket.AF_INET, socket.AF_INET6)
    }
    if not resolved:
        raise ValueError("DNS resolution failed.")
    return resolved
def _ensure_public_ips(addresses: Iterable[IpAddress], *, field_name: str) -> None:
    """Raise unless every address is globally routable.

    Args:
        addresses: IP addresses to check; an empty iterable passes.
        field_name: Label used at the start of the error message.

    Raises:
        ValueError: On the first address whose ``is_global`` is false.
    """
    if any(not candidate.is_global for candidate in addresses):
        raise ValueError(f"{field_name} must resolve to a public IP address.")
def validate_public_media_url(value: str, *, field_name: str = "url") -> str:
    """Check that *value* is an http/https URL on a public host and normalize it.

    Scheme and host are lower-cased; path, query, and fragment are kept as
    parsed. The host is resolved and every resulting address must be public.

    Args:
        value: URL to validate; surrounding whitespace is ignored.
        field_name: Label used at the start of error messages.

    Returns:
        The normalized URL, or "" when *value* is blank.

    Raises:
        ValueError: If the URL is unparsable, lacks a host, is not
            http/https, carries credentials, fails DNS resolution, or
            resolves to a non-public address.
    """
    candidate = value.strip()
    if not candidate:
        return ""

    try:
        split = urlsplit(candidate)
        authority = _normalize_netloc(split)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https URL.") from exc

    lowered_scheme = split.scheme.lower()
    if lowered_scheme not in ("http", "https"):
        raise ValueError(f"{field_name} must use http or https.")

    if not (split.username is None and split.password is None):
        raise ValueError(f"{field_name} must not include credentials.")

    # Resolve first, then vet every address the name maps to.
    resolved_addresses = _resolve_host_ips(split.hostname or "")
    _ensure_public_ips(resolved_addresses, field_name=field_name)

    return urlunsplit((lowered_scheme, authority, split.path, split.query, split.fragment))
def validate_media_reference(value: str, *, field_name: str = "url") -> str:
    """Validate one media reference as either a public URL or a site-relative path.

    Args:
        value: Reference to validate; surrounding whitespace is ignored.
        field_name: Label used at the start of error messages.

    Returns:
        "" for a blank value, the normalized URL when a scheme is present
        (as produced by ``validate_public_media_url``), or the stripped
        text unchanged when it is a site-relative path.

    Raises:
        ValueError: If the value cannot be parsed, names a host without an
            http/https scheme, is not rooted at "/", or is a "/"-rooted
            path that a browser would resolve to a different host.
    """
    text = value.strip()
    if not text:
        return ""

    not_a_reference = f"{field_name} must be an absolute http/https URL or site-relative path."
    try:
        parts = urlsplit(text)
    except ValueError as exc:
        # urlsplit rejects malformed authorities (for example an unbalanced
        # "["); report that with the field name like every other failure.
        raise ValueError(not_a_reference) from exc

    if parts.scheme:
        return validate_public_media_url(text, field_name=field_name)
    if parts.netloc:
        raise ValueError(f"{field_name} must use http or https when specifying a host.")
    if not text.startswith("/"):
        raise ValueError(not_a_reference)

    # Browsers strip tab/CR/LF from URLs and, for http(s), treat "\" like "/"
    # and skip repeated slashes before the authority. So "/\host/x" and
    # "///host/x" load from another host even though they parse here as
    # paths with no netloc. Reject them so a "site-relative" value stays
    # on this site.
    as_browser_sees_it = text.translate({9: None, 10: None, 13: None})
    if as_browser_sees_it[1:2] in ("/", "\\"):
        raise ValueError(not_a_reference)
    return text
def open_validated_public_url(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    timeout: float = 6.0,
    max_redirects: int = 5,
):
    """Open a public media URL, validating the start URL and each redirect hop.

    Redirects are not followed automatically: each 3xx response is caught,
    its Location is resolved against the current URL, validated with
    ``validate_public_media_url``, and only then requested.

    Args:
        url: Starting URL; must pass ``validate_public_media_url``.
        headers: Request headers sent on every hop.
        timeout: Timeout in seconds passed to each open call.
        max_redirects: Maximum number of redirects to follow.

    Returns:
        The response object from the opener for the final URL.

    Raises:
        ValueError: If any URL fails validation, a redirect has no
            Location, or the redirect limit is exceeded.
        HTTPError: For non-redirect HTTP error responses (re-raised; the
            error's response is closed before it propagates).

    NOTE(review): the hostname is resolved once during validation and again
    when the opener connects, so the two lookups could disagree — confirm
    whether that gap matters for this deployment.
    """
    target = validate_public_media_url(url)
    outgoing_headers = headers or {}
    hops_followed = 0

    while hops_followed <= max_redirects:
        request = Request(target, headers=outgoing_headers)
        try:
            return _NO_REDIRECT_OPENER.open(request, timeout=timeout)
        except HTTPError as http_error:
            try:
                if not 300 <= http_error.code < 400:
                    # Not a redirect: hand the original error to the caller.
                    raise
                if hops_followed >= max_redirects:
                    raise ValueError("Too many redirects.")
                redirect_to = str(http_error.headers.get("Location") or "").strip()
                if not redirect_to:
                    raise ValueError("Redirect location missing or invalid.")
                # Location may be relative; resolve it, then vet the new hop.
                target = validate_public_media_url(urljoin(target, redirect_to))
            finally:
                # Always release the error response, whichever way we leave.
                http_error.close()
        hops_followed += 1

    raise ValueError("Too many redirects.")

View File

@@ -22,7 +22,6 @@ import uuid
from pathlib import Path
from typing import Literal
from urllib.error import URLError
from urllib.request import Request, urlopen
from zoneinfo import ZoneInfo
from pydantic import ValidationError, TypeAdapter
@@ -107,6 +106,7 @@ from .models import (
WelcomePacket,
WorldItem,
)
from .network_security import normalize_origin, open_validated_public_url
from .ui_metadata import (
ADMIN_MENU_ACTION_DEFINITIONS,
ITEM_MANAGEMENT_ACTION_DEFINITIONS,
@@ -160,6 +160,7 @@ class SignalingServer:
grid_size: int = 41,
state_save_debounce_ms: int = 200,
state_save_max_delay_ms: int = 1000,
host_origin: str | None = None,
):
"""Initialize runtime state, TLS context, and item service."""
@@ -187,6 +188,7 @@ class SignalingServer:
self.movement_max_steps_per_tick = MOVEMENT_MAX_STEPS_PER_TICK
self.instance_id = str(uuid.uuid4())
self.server_version = self._resolve_server_version()
self.host_origin = normalize_origin(host_origin, field_name="host origin") if host_origin else None
self.state_save_debounce_ms = max(1, int(state_save_debounce_ms))
self.state_save_max_delay_ms = max(self.state_save_debounce_ms, int(state_save_max_delay_ms))
self._pending_state_save_handle: asyncio.TimerHandle | None = None
@@ -676,11 +678,11 @@ class SignalingServer:
if not stream_url:
return "", ""
try:
request = Request(
with open_validated_public_url(
stream_url,
headers={"Icy-MetaData": "1", "User-Agent": "ChatGrid"},
)
with urlopen(request, timeout=RADIO_METADATA_TIMEOUT_S) as response:
timeout=RADIO_METADATA_TIMEOUT_S,
) as response:
station = str(response.headers.get("icy-name") or response.headers.get("ice-name") or "").strip()
title = ""
metaint_raw = response.headers.get("icy-metaint")
@@ -1355,6 +1357,7 @@ class SignalingServer:
self.port,
ssl=self._ssl_context,
max_size=self.max_message_size,
origins=[self.host_origin] if self.host_origin else None,
process_request=self._process_http_request,
):
await asyncio.Future()
@@ -3078,6 +3081,13 @@ def run() -> None:
auth_secret = os.getenv("CHGRID_AUTH_SECRET", "").strip()
if not auth_secret:
raise SystemExit("CHGRID_AUTH_SECRET is required.")
host_origin = os.getenv("CHGRID_HOST_ORIGIN", "").strip()
if not host_origin:
raise SystemExit("CHGRID_HOST_ORIGIN is required.")
try:
host_origin = normalize_origin(host_origin, field_name="CHGRID_HOST_ORIGIN")
except ValueError as exc:
raise SystemExit(str(exc)) from exc
auth_db_value = config.auth.db_file.strip()
if not auth_db_value:
raise SystemExit("auth.db_file must not be empty.")
@@ -3206,6 +3216,6 @@ def run() -> None:
grid_size=config.world.grid_size,
state_save_debounce_ms=config.storage.state_save_debounce_ms,
state_save_max_delay_ms=config.storage.state_save_max_delay_ms,
host_origin=host_origin,
)
asyncio.run(server.start())
ItemClockAnnouncePacket,

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import socket
import pytest
from app.network_security import normalize_origin, validate_media_reference, validate_public_media_url
def test_normalize_origin_rejects_paths() -> None:
    """An origin that carries a path component is refused."""
    origin_with_path = "https://example.com/chgrid"
    with pytest.raises(ValueError):
        normalize_origin(origin_with_path)
def test_normalize_origin_normalizes_case_and_trailing_slash() -> None:
    """Scheme and host are lower-cased and a lone trailing slash is dropped."""
    normalized = normalize_origin("HTTPS://Example.COM:443/")
    assert normalized == "https://example.com:443"
def test_validate_public_media_url_rejects_private_ip() -> None:
    """A URL whose host is a loopback IP literal is refused."""
    loopback_url = "http://127.0.0.1/audio"
    with pytest.raises(ValueError):
        validate_public_media_url(loopback_url)
def test_validate_public_media_url_resolves_hostname(monkeypatch: pytest.MonkeyPatch) -> None:
    """A hostname resolving to a public address passes and is lower-cased."""

    def fake_getaddrinfo(host: str, port, type: int = 0):
        # Stand-in resolver: checks the lower-cased host and answers with one public IPv4.
        assert host == "radio.example.com"
        return [(socket.AF_INET, type, 6, "", ("93.184.216.34", 0))]

    monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)

    normalized = validate_public_media_url("https://Radio.Example.com/live")
    assert normalized == "https://radio.example.com/live"
def test_validate_public_media_url_rejects_private_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
    """A hostname resolving to a private address is refused."""

    def fake_getaddrinfo(host: str, port, type: int = 0):
        # Stand-in resolver: answers with one RFC 1918 private IPv4 address.
        assert host == "radio.example.com"
        return [(socket.AF_INET, type, 6, "", ("10.0.0.5", 0))]

    monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)

    privately_hosted_url = "https://radio.example.com/live"
    with pytest.raises(ValueError):
        validate_public_media_url(privately_hosted_url)
def test_validate_media_reference_allows_site_relative_path() -> None:
    """A "/"-rooted site-relative reference is returned unchanged."""
    site_relative = "/chgrid/media_proxy.php?url=test"
    assert validate_media_reference(site_relative) == site_relative