def validate_and_sanitize_repo_url(url: str, allowed_providers: list[str]) -> tuple[bool, str, str]:
    """Validate a repository URL against a host allow-list and normalise it.

    The URL is stripped, given an ``https://`` scheme when none is present,
    and reduced to ``scheme://host[:port]/path``: query, params and fragment
    are dropped, a ``/tree/<ref>`` suffix after ``/<owner>/<repo>`` is
    removed, trailing slashes are removed, and a trailing ``.git`` is removed.

    Args:
        url: Repository URL as supplied by the user. Empty or whitespace-only
            input is treated as "nothing to validate".
        allowed_providers: Host names that may be used. A host is accepted
            when it equals an entry or is a sub-domain of one. Entries are
            compared case-insensitively; blank or non-string entries are
            ignored.

    Returns:
        ``(is_valid, error, sanitized_url)``. On success ``error`` is ``""``.
        On failure ``sanitized_url`` is ``""`` and ``error`` describes the
        problem; host rejections always contain the phrase "not authorized".
        Empty input yields ``(True, "", "")``.
    """
    if not url or not str(url).strip():
        return True, "", ""

    url = str(url).strip()
    if "://" not in url:
        url = "https://" + url

    try:
        parsed = urlparse(url)
    except ValueError as e:
        return False, f"URL parsing error: {e}", ""

    if parsed.scheme not in ("http", "https"):
        return False, "Only HTTP/HTTPS URLs supported", ""
    if not parsed.netloc:
        return False, "Invalid URL format", ""

    # Never let embedded credentials flow into the sanitized URL, and never
    # let a userinfo component take part in the host allow-list decision.
    if "@" in parsed.netloc:
        return False, "Credentials in repository URL are not supported", ""

    try:
        # ``hostname`` is lower-cased and excludes the port, unlike ``netloc``;
        # ``port`` raises ValueError for a non-numeric or out-of-range port.
        hostname = parsed.hostname
        port = parsed.port
    except ValueError as e:
        return False, f"URL parsing error: {e}", ""
    if not hostname:
        return False, "Invalid URL format", ""

    # Strip trailing slashes first so "/org/repo.git/" still loses ".git".
    path = parsed.path.rstrip("/")
    tree_match = re.match(r"^(/[^/]+/[^/]+)/tree/.+$", path)
    if tree_match:
        path = tree_match.group(1)
    path = path.removesuffix(".git")

    providers = [
        p.strip().lower()
        for p in (allowed_providers or ())
        if isinstance(p, str) and p.strip()
    ]
    if not any(hostname == p or hostname.endswith("." + p) for p in providers):
        return False, f"Repository host '{hostname}' not authorized", ""

    # Rebuild the authority from the parsed pieces (IPv6 literals need their
    # brackets back) rather than reusing the raw netloc.
    host_part = f"[{hostname}]" if ":" in hostname else hostname
    netloc = host_part if port is None else f"{host_part}:{port}"
    sanitized = urlunparse((parsed.scheme, netloc, path, "", "", ""))

    # A backslash is included because some URL parsers treat it as "/",
    # which would let "evil.com\.github.com" resolve to a different host
    # than the one checked above.
    dangerous_patterns = (";", "||", "&&", "$(", "`", "\n", "\r", "\\")
    if any(pat in sanitized for pat in dangerous_patterns):
        return False, "URL contains suspicious characters", ""

    return True, "", sanitized
"""Tests for repository URL validation and the ValidateRepoHandler POST path.

The modules under test are loaded straight from ``core/`` by file path, after
lightweight stand-ins for their imports have been registered in
``sys.modules``. The stub registration below must therefore run before the
``load_module`` calls further down.
"""

import asyncio
import importlib.util
import json
import sys
import types
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
CORE = ROOT / "core"

# Each stub is only installed when no module of that name is registered yet,
# so a real package that is already imported is left untouched.
if "jupyterhub.apihandlers" not in sys.modules:
    jupyterhub_module = types.ModuleType("jupyterhub")
    apihandlers_module = types.ModuleType("jupyterhub.apihandlers")
    handlers_module = types.ModuleType("jupyterhub.handlers")
    apihandlers_module.APIHandler = type("APIHandler", (), {})
    handlers_module.BaseHandler = type("BaseHandler", (), {})
    sys.modules["jupyterhub"] = jupyterhub_module
    sys.modules["jupyterhub.apihandlers"] = apihandlers_module
    sys.modules["jupyterhub.handlers"] = handlers_module

if "multiauthenticator" not in sys.modules:
    multiauthenticator_module = types.ModuleType("multiauthenticator")
    multiauthenticator_module.MultiAuthenticator = type("MultiAuthenticator", (), {})
    sys.modules["multiauthenticator"] = multiauthenticator_module

# A bare "core" package whose __path__ points at the real directory, so that
# submodules which are not stubbed here can still be found on disk.
if "core" not in sys.modules:
    core_module = types.ModuleType("core")
    core_module.__path__ = [str(CORE)]
    sys.modules["core"] = core_module

if "core.authenticators" not in sys.modules:
    auth_module = types.ModuleType("core.authenticators")
    auth_module.CustomFirstUseAuthenticator = type("CustomFirstUseAuthenticator", (), {})
    sys.modules["core.authenticators"] = auth_module

if "core.quota" not in sys.modules:
    quota_module = types.ModuleType("core.quota")
    quota_module.BatchQuotaRequest = type("BatchQuotaRequest", (), {})
    quota_module.QuotaAction = type("QuotaAction", (), {})
    quota_module.QuotaModifyRequest = type("QuotaModifyRequest", (), {})
    quota_module.QuotaRefreshRequest = type("QuotaRefreshRequest", (), {})
    quota_module.get_quota_manager = lambda: None
    sys.modules["core.quota"] = quota_module

if "core.stats_handlers" not in sys.modules:
    stats_module = types.ModuleType("core.stats_handlers")
    for name in [
        "StatsActiveSSEHandler",
        "StatsDistributionHandler",
        "StatsHourlyHandler",
        "StatsMyUsageHandler",
        "StatsOverviewHandler",
        "StatsUsageHandler",
        "StatsUserHandler",
    ]:
        setattr(stats_module, name, type(name, (), {}))
    sys.modules["core.stats_handlers"] = stats_module


def load_module(name: str, path: Path):
    """Load the source file at *path* and register it in ``sys.modules`` as *name*.

    Args:
        name: Fully qualified module name to register.
        path: Location of the Python source file.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for *path*.
    """
    spec = importlib.util.spec_from_file_location(name, path)
    # Check before use: spec_from_file_location may return None, and an
    # assert would be stripped under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so imports of *name* made while the module
    # body runs resolve to this object.
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind for later imports.
        sys.modules.pop(name, None)
        raise
    return module


git_validation = load_module("core.git_validation", CORE / "git_validation.py")
handlers = load_module("core.handlers", CORE / "handlers.py")
validate_and_sanitize_repo_url = git_validation.validate_and_sanitize_repo_url
ValidateRepoHandler = handlers.ValidateRepoHandler


class DummyUser:
    """Minimal user stand-in exposing an awaitable ``get_auth_state``."""

    def __init__(self, auth_state=None):
        self._auth_state = auth_state or {}

    async def get_auth_state(self):
        return self._auth_state


class DummyGitClone:
    """Holds the ``git_clone`` settings used by these tests."""

    def __init__(self, allowed_providers, github_app_name="", default_access_token=""):
        self.allowedProviders = allowed_providers
        self.githubAppName = github_app_name
        self.defaultAccessToken = default_access_token


class DummyConfig:
    """Config stand-in exposing only a ``git_clone`` attribute."""

    def __init__(self, allowed_providers, github_app_name="", default_access_token=""):
        self.git_clone = DummyGitClone(allowed_providers, github_app_name, default_access_token)


def test_validate_repo_url_adds_https_and_strips_tree_and_dot_git():
    ok, error, sanitized = validate_and_sanitize_repo_url(
        "github.com/example/project.git/tree/main",
        ["github.com"],
    )
    assert ok is True
    assert error == ""
    assert sanitized == "https://github.com/example/project"


def test_validate_repo_url_rejects_disallowed_host():
    ok, error, sanitized = validate_and_sanitize_repo_url(
        "https://evil.example.com/org/repo",
        ["github.com", "gitlab.com"],
    )
    assert ok is False
    assert "not authorized" in error
    assert sanitized == ""


def test_validate_repo_post_returns_400_for_invalid_json(monkeypatch):
    monkeypatch.setitem(
        sys.modules,
        "core.config",
        types.SimpleNamespace(
            HubConfig=type("HubConfig", (), {"get": staticmethod(lambda: DummyConfig(["github.com"]))})
        ),
    )

    # Bypass __init__ so no real request/application objects are required.
    handler = object.__new__(ValidateRepoHandler)
    handler.request = types.SimpleNamespace(body=b"{not-json")
    handler.current_user = DummyUser()

    captured = {}
    handler.set_status = lambda status: captured.setdefault("status", status)
    handler.set_header = lambda key, value: captured.setdefault("headers", {}).__setitem__(key, value)
    handler.finish = lambda payload: captured.setdefault("body", payload)

    asyncio.run(handler.post())

    assert captured["status"] == 400
    assert captured["headers"]["Content-Type"] == "application/json"
    assert json.loads(captured["body"]) == {"error": "Invalid JSON"}


def test_validate_repo_post_rejects_disallowed_provider_before_remote_call(monkeypatch):
    monkeypatch.setitem(
        sys.modules,
        "core.config",
        types.SimpleNamespace(
            HubConfig=type("HubConfig", (), {"get": staticmethod(lambda: DummyConfig(["github.com"]))})
        ),
    )

    handler = object.__new__(ValidateRepoHandler)
    handler.request = types.SimpleNamespace(
        body=json.dumps({"url": "https://evil.example.com/org/repo", "branch": "main"}).encode("utf-8")
    )
    handler.current_user = DummyUser()

    called = {"value": False}

    # Records whether the handler reached the remote validation step.
    async def fake_validate(url, branch, token):
        called["value"] = True
        return {"valid": True, "error": ""}

    handler._validate = fake_validate
    handler.set_header = lambda key, value: None
    result = {}
    handler.finish = lambda payload: result.setdefault("payload", payload)

    asyncio.run(handler.post())

    assert called["value"] is False
    assert json.loads(result["payload"])["valid"] is False
    assert "not authorized" in json.loads(result["payload"])["error"]