166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
|
|
"""Helpers for browser-origin policy and SSRF-safe outbound URL handling."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import ipaddress
|
||
|
|
import socket
|
||
|
|
from typing import Iterable
|
||
|
|
from urllib.error import HTTPError
|
||
|
|
from urllib.parse import urljoin, urlsplit, urlunsplit
|
||
|
|
from urllib.request import HTTPRedirectHandler, Request, build_opener
|
||
|
|
|
||
|
|
# Either concrete address class from the ipaddress module; used as the element
# type in this module's annotations for resolved and validated addresses.
IpAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
|
||
|
|
|
||
|
|
|
||
|
|
class _NoRedirectHandler(HTTPRedirectHandler):
|
||
|
|
"""Disable automatic redirects so each hop can be revalidated."""
|
||
|
|
|
||
|
|
def redirect_request(self, req, fp, code, msg, headers, newurl): # type: ignore[override]
|
||
|
|
"""Return None so urllib surfaces redirects as HTTPError objects."""
|
||
|
|
|
||
|
|
return None
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level opener built with _NoRedirectHandler, so requests made through
# it do not follow redirects automatically.
_NO_REDIRECT_OPENER = build_opener(_NoRedirectHandler)
|
||
|
|
|
||
|
|
|
||
|
|
def _format_host(host: str) -> str:
|
||
|
|
"""Return one hostname/IP suitable for URL netloc reconstruction."""
|
||
|
|
|
||
|
|
if ":" in host and not host.startswith("["):
|
||
|
|
return f"[{host}]"
|
||
|
|
return host
|
||
|
|
|
||
|
|
|
||
|
|
def _normalize_netloc(parts) -> str:
|
||
|
|
"""Rebuild one normalized netloc from parsed URL parts."""
|
||
|
|
|
||
|
|
if not parts.hostname:
|
||
|
|
raise ValueError("host is required")
|
||
|
|
netloc = _format_host(parts.hostname.lower())
|
||
|
|
if parts.port is not None:
|
||
|
|
netloc = f"{netloc}:{parts.port}"
|
||
|
|
return netloc
|
||
|
|
|
||
|
|
|
||
|
|
def normalize_origin(value: str, *, field_name: str = "origin") -> str:
    """Validate and normalize one browser origin string.

    The result has the form ``scheme://host[:port]`` with a lower-cased
    scheme and host. A port equal to the scheme's default (80 for http,
    443 for https) is dropped, because a serialized origin omits the default
    port; keeping it would make ``https://example.com:443`` differ from the
    ``https://example.com`` that browsers actually send.

    Args:
        value: The origin text to validate; surrounding whitespace is ignored.
        field_name: Name used as the prefix of every error message.

    Returns:
        The normalized origin string.

    Raises:
        ValueError: if the value is empty, cannot be parsed, lacks a host,
            uses a scheme other than http/https, carries credentials, or
            includes a path (other than ``/``), query, or fragment.
    """
    text = value.strip()
    if not text:
        raise ValueError(f"{field_name} must not be empty.")
    try:
        parts = urlsplit(text)
        # _normalize_netloc also reads parts.port, so an invalid port is
        # reported through the same wrapped error.
        netloc = _normalize_netloc(parts)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https origin.") from exc

    scheme = parts.scheme.lower()
    if scheme not in {"http", "https"}:
        raise ValueError(f"{field_name} must use http or https.")
    if parts.username is not None or parts.password is not None:
        raise ValueError(f"{field_name} must not include credentials.")
    if parts.path not in {"", "/"} or parts.query or parts.fragment:
        raise ValueError(f"{field_name} must not include path, query, or fragment.")

    # Strip an explicit default port so equivalent origins normalize to the
    # same string. The netloc always ends with ":<port>" when a port is set.
    default_port = 443 if scheme == "https" else 80
    if parts.port == default_port:
        netloc = netloc.removesuffix(f":{default_port}")
    return urlunsplit((scheme, netloc, "", "", ""))
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_host_ips(host: str) -> set[IpAddress]:
|
||
|
|
"""Resolve one hostname or IP literal to concrete IP addresses."""
|
||
|
|
|
||
|
|
try:
|
||
|
|
return {ipaddress.ip_address(host)}
|
||
|
|
except ValueError:
|
||
|
|
pass
|
||
|
|
|
||
|
|
resolved: set[IpAddress] = set()
|
||
|
|
try:
|
||
|
|
infos = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
|
||
|
|
except socket.gaierror as exc:
|
||
|
|
raise ValueError("DNS resolution failed.") from exc
|
||
|
|
for family, _type, _proto, _canonname, sockaddr in infos:
|
||
|
|
if family == socket.AF_INET:
|
||
|
|
resolved.add(ipaddress.ip_address(sockaddr[0]))
|
||
|
|
elif family == socket.AF_INET6:
|
||
|
|
resolved.add(ipaddress.ip_address(sockaddr[0]))
|
||
|
|
if not resolved:
|
||
|
|
raise ValueError("DNS resolution failed.")
|
||
|
|
return resolved
|
||
|
|
|
||
|
|
|
||
|
|
def _ensure_public_ips(addresses: Iterable[IpAddress], *, field_name: str) -> None:
|
||
|
|
"""Reject non-public IP addresses for SSRF-sensitive outbound requests."""
|
||
|
|
|
||
|
|
for address in addresses:
|
||
|
|
if not address.is_global:
|
||
|
|
raise ValueError(f"{field_name} must resolve to a public IP address.")
|
||
|
|
|
||
|
|
|
||
|
|
def validate_public_media_url(value: str, *, field_name: str = "url") -> str:
    """Check that ``value`` is an http/https URL pointing at a public host.

    Blank input is passed through as an empty string. Otherwise the URL is
    parsed, its scheme and host are normalized, credentials are refused, and
    the host must resolve exclusively to public IP addresses.

    Args:
        value: URL text; surrounding whitespace is ignored.
        field_name: Name used as the prefix of the error messages raised here.

    Returns:
        The URL rebuilt with the normalized scheme and netloc and the
        original path, query, and fragment, or ``""`` for blank input.

    Raises:
        ValueError: if the URL is malformed, lacks a host, uses another
            scheme, contains credentials, cannot be resolved, or resolves to
            a non-public address.
    """
    candidate = value.strip()
    if candidate == "":
        return ""

    try:
        split_url = urlsplit(candidate)
        host_and_port = _normalize_netloc(split_url)
    except ValueError as exc:
        raise ValueError(f"{field_name} must be a valid http/https URL.") from exc

    normalized_scheme = split_url.scheme.lower()
    if normalized_scheme != "http" and normalized_scheme != "https":
        raise ValueError(f"{field_name} must use http or https.")

    has_credentials = not (split_url.username is None and split_url.password is None)
    if has_credentials:
        raise ValueError(f"{field_name} must not include credentials.")

    # Resolve last, so the cheap syntactic checks fail first.
    resolved_ips = _resolve_host_ips(split_url.hostname or "")
    _ensure_public_ips(resolved_ips, field_name=field_name)

    return urlunsplit(
        (
            normalized_scheme,
            host_and_port,
            split_url.path,
            split_url.query,
            split_url.fragment,
        )
    )
|
||
|
|
|
||
|
|
|
||
|
|
def validate_media_reference(value: str, *, field_name: str = "url") -> str:
    """Validate one media reference as either a public URL or a site-relative path.

    Blank input is passed through as an empty string. A value with a scheme
    is delegated to ``validate_public_media_url``. Anything else must be a
    path that starts with a single ``/`` and names no host.

    Args:
        value: Reference text; surrounding whitespace is ignored.
        field_name: Name used as the prefix of every error message.

    Returns:
        The validated public URL, the stripped site-relative path, or ``""``
        for blank input.

    Raises:
        ValueError: if the value cannot be parsed, names a host without a
            scheme, is not rooted at ``/``, or starts with ``/\\``.
    """
    text = value.strip()
    if not text:
        return ""

    invalid_reference = (
        f"{field_name} must be an absolute http/https URL or site-relative path."
    )
    try:
        parts = urlsplit(text)
    except ValueError as exc:
        # urlsplit rejects malformed netlocs (e.g. unbalanced brackets);
        # report that with the field-specific message like every other check.
        raise ValueError(invalid_reference) from exc

    if parts.scheme:
        return validate_public_media_url(text, field_name=field_name)
    if parts.netloc:
        raise ValueError(f"{field_name} must use http or https when specifying a host.")
    if not text.startswith("/"):
        raise ValueError(invalid_reference)
    # Browsers treat a backslash like a slash in http(s) URLs, so "/\host/x"
    # resolves to the external "//host/x" rather than to a path on this site.
    # The parsed path is checked so stripped control characters cannot hide it.
    if parts.path.startswith("/\\"):
        raise ValueError(invalid_reference)
    return text
|
||
|
|
|
||
|
|
|
||
|
|
def open_validated_public_url(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    timeout: float = 6.0,
    max_redirects: int = 5,
):
    """Fetch a public media URL, re-checking every redirect hop before following it.

    The starting URL and each redirect target are passed through
    ``validate_public_media_url``. Redirects are followed manually using the
    module's non-redirecting opener, and the same headers are sent on every
    hop.

    Args:
        url: The URL to open.
        headers: Optional request headers; an empty mapping is used when
            omitted or empty.
        timeout: Timeout handed to the opener for each request.
        max_redirects: Maximum number of redirects to follow.

    Returns:
        The response object returned by the opener for the final URL.

    Raises:
        ValueError: if a URL fails validation, a 3xx response carries no
            ``Location`` header, or the redirect limit is exceeded.
        HTTPError: re-raised for any non-3xx HTTP error (already closed).

    NOTE(review): the host is resolved during validation and presumably again
    by the opener when it connects; confirm whether DNS answers changing
    between the two lookups is a concern here.
    """
    target = validate_public_media_url(url)
    outgoing_headers = headers if headers else {}

    for hops_taken in range(max_redirects + 1):
        request = Request(target, headers=outgoing_headers)
        try:
            return _NO_REDIRECT_OPENER.open(request, timeout=timeout)
        except HTTPError as http_error:
            # Every path out of this handler closes the error response.
            try:
                is_redirect = 300 <= http_error.code < 400
                if not is_redirect:
                    raise
                if hops_taken >= max_redirects:
                    raise ValueError("Too many redirects.")
                raw_location = http_error.headers.get("Location") or ""
                location = str(raw_location).strip()
                if not location:
                    raise ValueError("Redirect location missing or invalid.")
                # Resolve relative targets against the current URL, then
                # validate the result before the next request is made.
                target = validate_public_media_url(urljoin(target, location))
            finally:
                http_error.close()

    # Only reached when the loop body never ran (negative max_redirects).
    raise ValueError("Too many redirects.")
|