Scope server routes by base path

This commit is contained in:
Jage9
2026-03-08 22:24:32 -04:00
parent bd0ec1b01e
commit 54a7a3085b
14 changed files with 113 additions and 47 deletions

View File

@@ -13,6 +13,7 @@ class ServerConfigSection(BaseModel):
bind_ip: str = "127.0.0.1"
port: int = 8765
base_path: str = "/"
class NetworkConfigSection(BaseModel):

View File

@@ -133,9 +133,10 @@ RADIO_METADATA_TIMEOUT_S = 6.0
CLOCK_ANNOUNCE_POLL_INTERVAL_S = 1.0
AUTH_SESSION_COOKIE_NAME = "chgrid_session_token"
AUTH_SESSION_COOKIE_MAX_AGE_SECONDS = 14 * 24 * 60 * 60
AUTH_SESSION_COOKIE_SET_PATH = "/auth/session/set"
AUTH_SESSION_COOKIE_CLEAR_PATH = "/auth/session/clear"
AUTH_SESSION_COOKIE_CHECK_PATH = "/auth/session/check"
AUTH_SESSION_COOKIE_SET_PATH = "auth/session/set"
AUTH_SESSION_COOKIE_CLEAR_PATH = "auth/session/clear"
AUTH_SESSION_COOKIE_CHECK_PATH = "auth/session/check"
WEBSOCKET_PATH = "ws"
AUTH_SESSION_COOKIE_CLIENT_HEADER = "X-Chgrid-Auth-Client"
AUTH_LOGIN_FAILURE_MESSAGE = "We couldn't log you in. Check your details and try again."
AUTH_RESUME_FAILURE_MESSAGE = "We couldn't restore your session. Please log in again."
@@ -162,6 +163,7 @@ class SignalingServer:
state_save_debounce_ms: int = 200,
state_save_max_delay_ms: int = 1000,
host_origin: str | None = None,
base_path: str = "/",
):
"""Initialize runtime state, TLS context, and item service."""
@@ -190,6 +192,11 @@ class SignalingServer:
self.instance_id = str(uuid.uuid4())
self.server_version = self._resolve_server_version()
self.host_origin = normalize_origin(host_origin, field_name="host origin") if host_origin else None
self.base_path = self._normalize_base_path(base_path)
self.websocket_path = self._base_path_join(WEBSOCKET_PATH)
self.auth_session_cookie_set_path = self._base_path_join(AUTH_SESSION_COOKIE_SET_PATH)
self.auth_session_cookie_clear_path = self._base_path_join(AUTH_SESSION_COOKIE_CLEAR_PATH)
self.auth_session_cookie_check_path = self._base_path_join(AUTH_SESSION_COOKIE_CHECK_PATH)
self.state_save_debounce_ms = max(1, int(state_save_debounce_ms))
self.state_save_max_delay_ms = max(self.state_save_debounce_ms, int(state_save_max_delay_ms))
self._pending_state_save_handle: asyncio.TimerHandle | None = None
@@ -263,6 +270,23 @@ class SignalingServer:
"passwordMaxLength": self.auth_service.password_max_length,
}
@staticmethod
def _normalize_base_path(value: str) -> str:
"""Normalize one instance base path to leading/trailing slash form."""
text = str(value).strip()
if not text or text == "/":
return "/"
return f"/{text.strip('/')}/"
def _base_path_join(self, suffix: str) -> str:
"""Join one instance-relative route suffix to the configured base path."""
token = suffix.lstrip("/")
if self.base_path == "/":
return f"/{token}"
return f"{self.base_path}{token}"
def _session_cookie_secure(self, request: HttpRequest | None = None) -> bool:
"""Return True when session cookies should be marked Secure."""
@@ -278,7 +302,7 @@ class SignalingServer:
secure = "; Secure" if self._session_cookie_secure(request) else ""
return (
f"{AUTH_SESSION_COOKIE_NAME}={token}; Path=/; HttpOnly; SameSite=Lax; "
f"{AUTH_SESSION_COOKIE_NAME}={token}; Path={self.base_path}; HttpOnly; SameSite=Lax; "
f"Max-Age={AUTH_SESSION_COOKIE_MAX_AGE_SECONDS}{secure}"
)
@@ -286,7 +310,7 @@ class SignalingServer:
"""Build Set-Cookie header value that expires the session cookie."""
secure = "; Secure" if self._session_cookie_secure(request) else ""
return f"{AUTH_SESSION_COOKIE_NAME}=; Path=/; HttpOnly; SameSite=Lax; Max-Age=0{secure}"
return f"{AUTH_SESSION_COOKIE_NAME}=; Path={self.base_path}; HttpOnly; SameSite=Lax; Max-Age=0{secure}"
def _origin_allowed(self, request: HttpRequest) -> bool:
"""Return whether one auth helper HTTP request comes from the configured app origin."""
@@ -316,8 +340,18 @@ class SignalingServer:
"""Handle lightweight same-origin auth cookie set/clear HTTP endpoints."""
path = request.path.split("?", 1)[0]
if path not in {AUTH_SESSION_COOKIE_SET_PATH, AUTH_SESSION_COOKIE_CLEAR_PATH, AUTH_SESSION_COOKIE_CHECK_PATH}:
auth_paths = {
self.auth_session_cookie_set_path,
self.auth_session_cookie_clear_path,
self.auth_session_cookie_check_path,
}
if path == self.websocket_path:
return None
if path not in auth_paths:
headers = Headers()
headers["Content-Type"] = "text/plain; charset=utf-8"
headers["Cache-Control"] = "no-store"
return HttpResponse(404, "Not Found", headers, b"not found")
headers = Headers()
headers["Content-Type"] = "text/plain; charset=utf-8"
@@ -328,7 +362,7 @@ class SignalingServer:
if not self._origin_allowed(request):
return HttpResponse(403, "Forbidden", headers, b"origin not allowed")
if path == AUTH_SESSION_COOKIE_CHECK_PATH:
if path == self.auth_session_cookie_check_path:
cookie_header = str(request.headers.get("Cookie", "")).strip()
token = self._cookie_value(cookie_header, AUTH_SESSION_COOKIE_NAME)
if not token:
@@ -339,7 +373,7 @@ class SignalingServer:
return HttpResponse(401, "Unauthorized", headers, b"invalid session")
return HttpResponse(204, "No Content", headers, b"")
if path == AUTH_SESSION_COOKIE_CLEAR_PATH:
if path == self.auth_session_cookie_clear_path:
headers["Set-Cookie"] = self._clear_session_cookie_header(request=request)
return HttpResponse(200, "OK", headers, b"cleared")
@@ -1415,6 +1449,11 @@ class SignalingServer:
LOGGER.info("websocket opened id=%s", client.id)
try:
request = getattr(websocket, "request", None)
request_path = str(getattr(request, "path", "")).split("?", 1)[0]
if request_path != self.websocket_path:
await websocket.close()
return
cookie_token = self._session_token_from_websocket_cookie(websocket)
if cookie_token:
await self._handle_auth_packet(
@@ -3245,5 +3284,6 @@ def run() -> None:
state_save_debounce_ms=config.storage.state_save_debounce_ms,
state_save_max_delay_ms=config.storage.state_save_max_delay_ms,
host_origin=host_origin,
base_path=config.server.base_path,
)
asyncio.run(server.start())

View File

@@ -3,6 +3,8 @@
bind_ip = "127.0.0.1"
# Listen port for signaling websocket server.
port = 8765
# Public base path for this grid instance. Examples: "/", "/chgrid/", "/ttgrid/".
base_path = "/"
[network]
# Maximum inbound websocket message size in bytes.

View File

@@ -8,6 +8,7 @@ from app.config import load_config
def test_load_config_defaults_when_path_none() -> None:
cfg = load_config(None)
assert cfg.server.bind_ip == "127.0.0.1"
assert cfg.server.base_path == "/"
assert cfg.network.allow_insecure_ws is False
assert cfg.storage.state_file == "runtime/items.json"
assert cfg.storage.state_save_debounce_ms == 200
@@ -43,3 +44,18 @@ state_save_max_delay_ms = 900
cfg = load_config(config_path)
assert cfg.storage.state_save_debounce_ms == 150
assert cfg.storage.state_save_max_delay_ms == 900
def test_load_config_reads_server_base_path(tmp_path: Path) -> None:
config_path = tmp_path / "config.toml"
config_path.write_text(
"""
[network]
allow_insecure_ws = true
[server]
base_path = "/ttgrid/"
""".strip()
)
cfg = load_config(config_path)
assert cfg.server.base_path == "/ttgrid/"

View File

@@ -24,13 +24,17 @@ def _request(path: str, headers: dict[str, str] | None = None) -> Request:
return Request(path=path, headers=values)
def _server() -> SignalingServer:
return SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com", base_path="/chgrid/")
@pytest.mark.asyncio
async def test_session_cookie_set_endpoint_sets_httponly_cookie() -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com")
server = _server()
username = f"user_{uuid.uuid4().hex[:8]}"
session = server.auth_service.register(username, "password99")
request = _request(
AUTH_SESSION_COOKIE_SET_PATH,
server.auth_session_cookie_set_path,
headers={
AUTH_SESSION_COOKIE_CLIENT_HEADER: "1",
"Authorization": f"Bearer {session.token}",
@@ -44,15 +48,16 @@ async def test_session_cookie_set_endpoint_sets_httponly_cookie() -> None:
assert response.status_code == 200
set_cookie = response.headers.get("Set-Cookie", "")
assert f"{AUTH_SESSION_COOKIE_NAME}=" in set_cookie
assert "Path=/chgrid/" in set_cookie
assert "HttpOnly" in set_cookie
assert "SameSite=Lax" in set_cookie
@pytest.mark.asyncio
async def test_session_cookie_clear_endpoint_expires_cookie() -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com")
server = _server()
request = _request(
AUTH_SESSION_COOKIE_CLEAR_PATH,
server.auth_session_cookie_clear_path,
headers={AUTH_SESSION_COOKIE_CLIENT_HEADER: "1", "Origin": "https://example.com"},
)
@@ -68,11 +73,11 @@ async def test_session_cookie_clear_endpoint_expires_cookie() -> None:
@pytest.mark.asyncio
async def test_session_cookie_check_endpoint_accepts_valid_cookie() -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com")
server = _server()
username = f"user_{uuid.uuid4().hex[:8]}"
session = server.auth_service.register(username, "password99")
request = _request(
AUTH_SESSION_COOKIE_CHECK_PATH,
server.auth_session_cookie_check_path,
headers={
AUTH_SESSION_COOKIE_CLIENT_HEADER: "1",
"Cookie": f"{AUTH_SESSION_COOKIE_NAME}={session.token}",
@@ -88,9 +93,9 @@ async def test_session_cookie_check_endpoint_accepts_valid_cookie() -> None:
@pytest.mark.asyncio
async def test_session_cookie_check_endpoint_rejects_missing_cookie() -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com")
server = _server()
request = _request(
AUTH_SESSION_COOKIE_CHECK_PATH,
server.auth_session_cookie_check_path,
headers={AUTH_SESSION_COOKIE_CLIENT_HEADER: "1", "Origin": "https://example.com"},
)
@@ -102,9 +107,9 @@ async def test_session_cookie_check_endpoint_rejects_missing_cookie() -> None:
@pytest.mark.asyncio
async def test_session_cookie_helpers_reject_wrong_origin() -> None:
server = SignalingServer("127.0.0.1", 8765, None, None, host_origin="https://example.com")
server = _server()
request = _request(
AUTH_SESSION_COOKIE_CLEAR_PATH,
server.auth_session_cookie_clear_path,
headers={AUTH_SESSION_COOKIE_CLIENT_HEADER: "1", "Origin": "https://evil.example.com"},
)