Files
chat_grid/server/app/network_security.py
2026-03-08 20:51:50 -04:00

166 lines
5.8 KiB
Python

"""Helpers for browser-origin policy and SSRF-safe outbound URL handling."""
from __future__ import annotations
import ipaddress
import socket
from typing import Iterable
from urllib.error import HTTPError
from urllib.parse import urljoin, urlsplit, urlunsplit
from urllib.request import HTTPRedirectHandler, Request, build_opener
# Alias covering both address classes that ``ipaddress.ip_address`` can return.
IpAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
class _NoRedirectHandler(HTTPRedirectHandler):
"""Disable automatic redirects so each hop can be revalidated."""
def redirect_request(self, req, fp, code, msg, headers, newurl): # type: ignore[override]
"""Return None so urllib surfaces redirects as HTTPError objects."""
return None
# Module-wide opener built around _NoRedirectHandler, so requests sent through
# it never follow a redirect on their own.
_NO_REDIRECT_OPENER = build_opener(_NoRedirectHandler)
def _format_host(host: str) -> str:
"""Return one hostname/IP suitable for URL netloc reconstruction."""
if ":" in host and not host.startswith("["):
return f"[{host}]"
return host
def _normalize_netloc(parts) -> str:
"""Rebuild one normalized netloc from parsed URL parts."""
if not parts.hostname:
raise ValueError("host is required")
netloc = _format_host(parts.hostname.lower())
if parts.port is not None:
netloc = f"{netloc}:{parts.port}"
return netloc
def normalize_origin(value: str, *, field_name: str = "origin") -> str:
    """Validate one browser origin string and return its canonical form.

    The result is ``scheme://host[:port]`` with a lower-cased scheme and host.
    An explicit port equal to the scheme default (80 for http, 443 for https)
    is dropped, because browsers omit it when serializing an origin; keeping
    it would make ``https://example.com:443`` compare unequal to the
    ``https://example.com`` a browser actually sends.

    Args:
        value: Origin text to validate; surrounding whitespace is ignored.
        field_name: Name used as the prefix of error messages.

    Returns:
        The normalized origin string.

    Raises:
        ValueError: if the value is empty, cannot be parsed, has no host, uses
            a scheme other than http/https, carries credentials, or includes
            a path (other than ``/``), query, or fragment.
    """
    text = value.strip()
    if not text:
        raise ValueError(f"{field_name} must not be empty.")
    try:
        parts = urlsplit(text)
        netloc = _normalize_netloc(parts)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https origin.") from exc
    scheme = parts.scheme.lower()
    if scheme not in {"http", "https"}:
        raise ValueError(f"{field_name} must use http or https.")
    if parts.username is not None or parts.password is not None:
        raise ValueError(f"{field_name} must not include credentials.")
    if parts.path not in {"", "/"} or parts.query or parts.fragment:
        raise ValueError(f"{field_name} must not include path, query, or fragment.")
    # Origin serialization leaves out the scheme's default port, so strip it
    # here to keep explicit-port and implicit-port spellings equal.
    port = parts.port
    default_port = 80 if scheme == "http" else 443
    if port == default_port:
        netloc = netloc.removesuffix(f":{port}")
    return urlunsplit((scheme, netloc, "", "", ""))
def _resolve_host_ips(host: str) -> set[IpAddress]:
"""Resolve one hostname or IP literal to concrete IP addresses."""
try:
return {ipaddress.ip_address(host)}
except ValueError:
pass
resolved: set[IpAddress] = set()
try:
infos = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
except socket.gaierror as exc:
raise ValueError("DNS resolution failed.") from exc
for family, _type, _proto, _canonname, sockaddr in infos:
if family == socket.AF_INET:
resolved.add(ipaddress.ip_address(sockaddr[0]))
elif family == socket.AF_INET6:
resolved.add(ipaddress.ip_address(sockaddr[0]))
if not resolved:
raise ValueError("DNS resolution failed.")
return resolved
def _ensure_public_ips(addresses: Iterable[IpAddress], *, field_name: str) -> None:
"""Reject non-public IP addresses for SSRF-sensitive outbound requests."""
for address in addresses:
if not address.is_global:
raise ValueError(f"{field_name} must resolve to a public IP address.")
def validate_public_media_url(value: str, *, field_name: str = "url") -> str:
    """Check that ``value`` is a public http/https URL and return it normalized.

    Blank input is passed through as ``""``. Otherwise the URL must parse,
    have a host, use http or https, carry no credentials, and its host must
    resolve only to public IP addresses. The returned URL has a lower-cased
    scheme and a rebuilt netloc; path, query, and fragment are kept as parsed.

    Raises:
        ValueError: if any of the checks above fails; ``field_name`` prefixes
            the messages produced in this function.
    """
    candidate = value.strip()
    if candidate == "":
        return ""

    try:
        pieces = urlsplit(candidate)
        host_port = _normalize_netloc(pieces)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https URL.") from exc

    scheme = pieces.scheme.lower()
    if scheme not in ("http", "https"):
        raise ValueError(f"{field_name} must use http or https.")

    if not (pieces.username is None and pieces.password is None):
        raise ValueError(f"{field_name} must not include credentials.")

    # Resolve now and refuse anything that lands on a non-public address.
    host_addresses = _resolve_host_ips(pieces.hostname or "")
    _ensure_public_ips(host_addresses, field_name=field_name)

    return urlunsplit((scheme, host_port, pieces.path, pieces.query, pieces.fragment))
def validate_media_reference(value: str, *, field_name: str = "url") -> str:
    """Validate one media reference as either a public URL or a site-relative path.

    Blank input is passed through as ``""``. A value with a scheme is handed
    to ``validate_public_media_url``. A value without a scheme must be a path
    starting with a single ``/`` that a browser would resolve against the
    current site.

    Args:
        value: Reference text; surrounding whitespace is ignored.
        field_name: Name used as the prefix of error messages.

    Returns:
        The normalized public URL, or the stripped site-relative path.

    Raises:
        ValueError: if the value names a host without an http/https scheme
            (including browser-only spellings such as ``/\\host`` or
            ``///host``), is not an absolute path, or fails public-URL
            validation.
    """
    text = value.strip()
    if not text:
        return ""
    parts = urlsplit(text)
    if parts.scheme:
        return validate_public_media_url(text, field_name=field_name)
    if parts.netloc:
        raise ValueError(f"{field_name} must use http or https when specifying a host.")
    if not text.startswith("/"):
        raise ValueError(f"{field_name} must be an absolute http/https URL or site-relative path.")
    # Browsers drop ASCII tab/CR/LF while parsing and treat "\" like "/", so
    # "/\host" and "///host" resolve to another host even though urlsplit
    # reports no netloc. Reject a second slash-like character outright.
    as_browser_sees_it = text.translate({9: None, 10: None, 13: None})
    if as_browser_sees_it[1:2] in ("/", "\\"):
        raise ValueError(f"{field_name} must use http or https when specifying a host.")
    return text
def open_validated_public_url(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    timeout: float = 6.0,
    max_redirects: int = 5,
):
    """Fetch a public media URL, validating every redirect hop before following it.

    The URL is validated, then requested through the non-redirecting opener.
    Each 3xx response is resolved against the current URL, validated the same
    way, and requested in turn, up to ``max_redirects`` redirects.

    Args:
        url: Starting URL.
        headers: Optional request headers, sent unchanged on every hop.
        timeout: Timeout passed to each open call.
        max_redirects: Maximum number of redirects to follow.

    Returns:
        The response object from the first non-redirect, non-error reply.

    Raises:
        ValueError: if a URL fails validation, a redirect has no ``Location``,
            or the redirect limit is exceeded.
        HTTPError: for non-3xx HTTP errors; it is closed before propagating.
    """
    # NOTE(review): urllib resolves the host again when it connects, so the
    # addresses checked during validation are not pinned for the request.
    target = validate_public_media_url(url)
    outgoing_headers = headers or {}
    for hop in range(max_redirects + 1):
        request = Request(target, headers=outgoing_headers)
        try:
            return _NO_REDIRECT_OPENER.open(request, timeout=timeout)
        except HTTPError as http_error:
            try:
                if not 300 <= http_error.code < 400:
                    # Ordinary HTTP failure: hand it back to the caller.
                    raise
                if hop >= max_redirects:
                    raise ValueError("Too many redirects.")
                location = str(http_error.headers.get("Location") or "").strip()
                if not location:
                    raise ValueError("Redirect location missing or invalid.")
                # Relative redirects are resolved against the URL just
                # requested, then held to the same public-URL rules.
                target = validate_public_media_url(urljoin(target, location))
            finally:
                # Always release the error response, whichever path is taken.
                http_error.close()
    raise ValueError("Too many redirects.")