Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions runtime/hub/core/git_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import re
from urllib.parse import urlparse, urlunparse


def validate_and_sanitize_repo_url(url: str, allowed_providers: list[str]) -> tuple[bool, str, str]:
if not url or not str(url).strip():
return True, "", ""

url = str(url).strip()
if "://" not in url:
url = "https://" + url

try:
parsed = urlparse(url)
if parsed.scheme not in ["http", "https"]:
return False, "Only HTTP/HTTPS URLs supported", ""
if not parsed.netloc:
return False, "Invalid URL format", ""

path = parsed.path
tree_match = re.match(r"^(/[^/]+/[^/]+)/tree/.+$", path)
if tree_match:
path = tree_match.group(1)
if path.endswith(".git"):
path = path[:-4]

sanitized = urlunparse((parsed.scheme, parsed.netloc.lower(), path, "", "", ""))
hostname = parsed.netloc.lower()
if not any(hostname == provider or hostname.endswith("." + provider) for provider in allowed_providers):
return False, f"Repository host '{hostname}' not authorized", ""
except Exception as e:
return False, f"URL parsing error: {e}", ""

dangerous_patterns = [";", "||", "&&", "$(", "`", "\n", "\r"]
if any(pat in sanitized for pat in dangerous_patterns):
return False, "URL contains suspicious characters", ""

return True, "", sanitized
34 changes: 21 additions & 13 deletions runtime/hub/core/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from tornado import web

from core.authenticators import CustomFirstUseAuthenticator
from core.git_validation import validate_and_sanitize_repo_url
from core.quota import (
BatchQuotaRequest,
QuotaAction,
Expand Down Expand Up @@ -921,17 +922,11 @@ async def get(self, repo_path: str):
config = HubConfig.get()
allowed_providers = list(config.git_clone.allowedProviders)

repo_url = f"https://{repo_path.rstrip('/')}"

try:
parsed = urlparse(repo_url)
hostname = parsed.netloc.lower()
except Exception as e:
raise web.HTTPError(400, "Invalid repository URL") from e

is_allowed = any(hostname == p or hostname.endswith("." + p) for p in allowed_providers)
if not is_allowed:
raise web.HTTPError(403, f"Repository host '{hostname}' is not allowed")
is_valid, error, repo_url = validate_and_sanitize_repo_url(repo_path.rstrip("/"), allowed_providers)
if not is_valid:
if "not authorized" in error:
raise web.HTTPError(403, error)
raise web.HTTPError(400, error)

params: list[tuple[str, str]] = [("repo_url", repo_url)]
if self.get_argument("autostart", ""):
Expand Down Expand Up @@ -1064,7 +1059,14 @@ async def _validate(self, url: str, branch: str, token: str) -> dict:

@web.authenticated
async def post(self):
body = json.loads(self.request.body)
try:
body = json.loads(self.request.body.decode("utf-8"))
except json.JSONDecodeError:
self.set_status(400)
self.set_header("Content-Type", "application/json")
self.finish(json.dumps({"error": "Invalid JSON"}))
return

url = (body.get("url") or "").strip()
branch = (body.get("branch") or "").strip()

Expand All @@ -1086,7 +1088,13 @@ async def post(self):

result = {"valid": False, "error": "URL is required"}
if url:
result = await self._validate(url, branch, access_token)
is_valid, error, sanitized_url = validate_and_sanitize_repo_url(
url, list(config.git_clone.allowedProviders)
)
if not is_valid:
result = {"valid": False, "error": error}
else:
result = await self._validate(sanitized_url, branch, access_token)
self.set_header("Content-Type", "application/json")
self.finish(json.dumps(result))

Expand Down
173 changes: 173 additions & 0 deletions runtime/hub/tests/test_validate_repo_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import importlib.util
import json
import sys
import types
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
CORE = ROOT / "core"

if "jupyterhub.apihandlers" not in sys.modules:
jupyterhub_module = types.ModuleType("jupyterhub")
apihandlers_module = types.ModuleType("jupyterhub.apihandlers")
handlers_module = types.ModuleType("jupyterhub.handlers")
apihandlers_module.APIHandler = type("APIHandler", (), {})
handlers_module.BaseHandler = type("BaseHandler", (), {})
sys.modules["jupyterhub"] = jupyterhub_module
sys.modules["jupyterhub.apihandlers"] = apihandlers_module
sys.modules["jupyterhub.handlers"] = handlers_module

if "multiauthenticator" not in sys.modules:
multiauthenticator_module = types.ModuleType("multiauthenticator")
multiauthenticator_module.MultiAuthenticator = type("MultiAuthenticator", (), {})
sys.modules["multiauthenticator"] = multiauthenticator_module

if "core" not in sys.modules:
core_module = types.ModuleType("core")
core_module.__path__ = [str(CORE)]
sys.modules["core"] = core_module

if "core.authenticators" not in sys.modules:
auth_module = types.ModuleType("core.authenticators")
auth_module.CustomFirstUseAuthenticator = type("CustomFirstUseAuthenticator", (), {})
sys.modules["core.authenticators"] = auth_module

if "core.quota" not in sys.modules:
quota_module = types.ModuleType("core.quota")
quota_module.BatchQuotaRequest = type("BatchQuotaRequest", (), {})
quota_module.QuotaAction = type("QuotaAction", (), {})
quota_module.QuotaModifyRequest = type("QuotaModifyRequest", (), {})
quota_module.QuotaRefreshRequest = type("QuotaRefreshRequest", (), {})
quota_module.get_quota_manager = lambda: None
sys.modules["core.quota"] = quota_module

if "core.stats_handlers" not in sys.modules:
stats_module = types.ModuleType("core.stats_handlers")
for name in [
"StatsActiveSSEHandler",
"StatsDistributionHandler",
"StatsHourlyHandler",
"StatsMyUsageHandler",
"StatsOverviewHandler",
"StatsUsageHandler",
"StatsUserHandler",
]:
setattr(stats_module, name, type(name, (), {}))
sys.modules["core.stats_handlers"] = stats_module


def load_module(name: str, path: Path):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
assert spec.loader is not None
spec.loader.exec_module(module)
return module


git_validation = load_module("core.git_validation", CORE / "git_validation.py")
handlers = load_module("core.handlers", CORE / "handlers.py")
validate_and_sanitize_repo_url = git_validation.validate_and_sanitize_repo_url
ValidateRepoHandler = handlers.ValidateRepoHandler


class DummyUser:
def __init__(self, auth_state=None):
self._auth_state = auth_state or {}

async def get_auth_state(self):
return self._auth_state


class DummyGitClone:
def __init__(self, allowed_providers, github_app_name="", default_access_token=""):
self.allowedProviders = allowed_providers
self.githubAppName = github_app_name
self.defaultAccessToken = default_access_token


class DummyConfig:
def __init__(self, allowed_providers, github_app_name="", default_access_token=""):
self.git_clone = DummyGitClone(allowed_providers, github_app_name, default_access_token)


def test_validate_repo_url_adds_https_and_strips_tree_and_dot_git():
ok, error, sanitized = validate_and_sanitize_repo_url(
"github.com/example/project.git/tree/main",
["github.com"],
)
assert ok is True
assert error == ""
assert sanitized == "https://github.com/example/project"


def test_validate_repo_url_rejects_disallowed_host():
ok, error, sanitized = validate_and_sanitize_repo_url(
"https://evil.example.com/org/repo",
["github.com", "gitlab.com"],
)
assert ok is False
assert "not authorized" in error
assert sanitized == ""


def test_validate_repo_post_returns_400_for_invalid_json(monkeypatch):
monkeypatch.setitem(
sys.modules,
"core.config",
types.SimpleNamespace(
HubConfig=type("HubConfig", (), {"get": staticmethod(lambda: DummyConfig(["github.com"]))})
),
)

handler = object.__new__(ValidateRepoHandler)
handler.request = types.SimpleNamespace(body=b"{not-json")
handler.current_user = DummyUser()

captured = {}
handler.set_status = lambda status: captured.setdefault("status", status)
handler.set_header = lambda key, value: captured.setdefault("headers", {}).__setitem__(key, value)
handler.finish = lambda payload: captured.setdefault("body", payload)

import asyncio

asyncio.run(handler.post())

assert captured["status"] == 400
assert captured["headers"]["Content-Type"] == "application/json"
assert json.loads(captured["body"]) == {"error": "Invalid JSON"}


def test_validate_repo_post_rejects_disallowed_provider_before_remote_call(monkeypatch):
monkeypatch.setitem(
sys.modules,
"core.config",
types.SimpleNamespace(
HubConfig=type("HubConfig", (), {"get": staticmethod(lambda: DummyConfig(["github.com"]))})
),
)

handler = object.__new__(ValidateRepoHandler)
handler.request = types.SimpleNamespace(
body=json.dumps({"url": "https://evil.example.com/org/repo", "branch": "main"}).encode("utf-8")
)
handler.current_user = DummyUser()

called = {"value": False}

async def fake_validate(url, branch, token):
called["value"] = True
return {"valid": True, "error": ""}

handler._validate = fake_validate
handler.set_header = lambda key, value: None
result = {}
handler.finish = lambda payload: result.setdefault("payload", payload)

import asyncio

asyncio.run(handler.post())

assert called["value"] is False
assert json.loads(result["payload"])["valid"] is False
assert "not authorized" in json.loads(result["payload"])["error"]
Loading