diff --git a/pkg/router/handlers.go b/pkg/router/handlers.go index c95c221c..d86506c6 100644 --- a/pkg/router/handlers.go +++ b/pkg/router/handlers.go @@ -17,10 +17,12 @@ limitations under the License. package router import ( + "errors" "fmt" "net/http" "net/http/httputil" "net/url" + "strconv" "strings" "time" @@ -103,27 +105,54 @@ func (s *Server) handleGetSandboxError(c *gin.Context, err error) { c.JSON(http.StatusInternalServerError, gin.H{"error": "internal server error"}) } +// errNoEntryPoint is returned when a sandbox has no entry points configured. +var errNoEntryPoint = errors.New("no entry point found for sandbox") + func determineUpstreamURL(sandbox *types.SandboxInfo, path string) (*url.URL, error) { // prefer matched entrypoint by path for _, ep := range sandbox.EntryPoints { if strings.HasPrefix(path, ep.Path) { - return buildURL(ep.Protocol, ep.Endpoint), nil + return buildURL(ep.Protocol, ep.Endpoint) } } // fallback to first entrypoint if len(sandbox.EntryPoints) == 0 { - return nil, fmt.Errorf("no entry point found for sandbox") + return nil, errNoEntryPoint } ep := sandbox.EntryPoints[0] - return buildURL(ep.Protocol, ep.Endpoint), nil + return buildURL(ep.Protocol, ep.Endpoint) +} + +// validProxySchemes are the URL schemes accepted for reverse-proxy targets. +// ws/wss are included to support WebSocket upgrades. +var validProxySchemes = map[string]bool{ + "http": true, "https": true, "ws": true, "wss": true, } -func buildURL(protocol, endpoint string) *url.URL { +func buildURL(protocol, endpoint string) (*url.URL, error) { if protocol != "" && !strings.Contains(endpoint, "://") { - endpoint = (strings.ToLower(protocol) + "://" + endpoint) + endpoint = strings.ToLower(protocol) + "://" + endpoint + } + u, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("invalid endpoint URL %q: %w", endpoint, err) + } + if u.Scheme == "" { + return nil, fmt.Errorf("invalid endpoint URL %q: missing scheme", endpoint) } - url, _ := url.Parse(endpoint) - return url + if !validProxySchemes[u.Scheme] { + return nil, fmt.Errorf("invalid endpoint URL %q: unsupported scheme %q, must be http, https, ws, or wss", endpoint, u.Scheme) + } + if u.Host == "" { + return nil, fmt.Errorf("invalid endpoint URL %q: missing host", endpoint) + } + if portStr := u.Port(); portStr != "" { + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + return nil, fmt.Errorf("invalid endpoint URL %q: port %q out of range (1-65535)", endpoint, portStr) + } + } + return u, nil } // handleAgentInvoke handles agent invocation requests @@ -148,9 +177,11 @@ func (s *Server) forwardToSandbox(c *gin.Context, sandbox *types.SandboxInfo, pa targetURL, err := determineUpstreamURL(sandbox, path) if err != nil { klog.Errorf("Failed to get sandbox access address %s: %v", sandbox.SandboxID, err) - c.JSON(http.StatusNotFound, gin.H{ - "error": err.Error(), - }) + if errors.Is(err, errNoEntryPoint) { + c.JSON(http.StatusNotFound, gin.H{"error": err.Error()}) + } else { + c.JSON(http.StatusInternalServerError, gin.H{"error": "invalid sandbox endpoint configuration"}) + } return } diff --git a/pkg/router/handlers_test.go b/pkg/router/handlers_test.go index 84aa06a8..0d60087a 100644 --- a/pkg/router/handlers_test.go +++ b/pkg/router/handlers_test.go @@ -393,8 +393,63 @@ func TestHandleCodeInterpreterInvoke(t *testing.T) { } func TestForwardToSandbox_InvalidEndpoint(t *testing.T) { - setupEnv() - defer teardownEnv() + cases := []struct { + name string + endpoint string + }{ + {"malformed scheme", "://invalid-url"}, + {"empty endpoint", ""}, + {"scheme only no host", "http://"}, + {"no scheme", "localhost:8080"}, + {"unsupported scheme ftp", "ftp://host:21"}, + {"unsupported scheme file", "file:///etc/passwd"}, + {"port out of range", "http://host:99999"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv("REDIS_ADDR", "localhost:6379") + t.Setenv("REDIS_PASSWORD", "test-password") + t.Setenv("WORKLOAD_MANAGER_URL", "http://localhost:8080") + + config := &Config{Port: "8080"} + server, err := NewServer(config) + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + + server.sessionManager = &mockSessionManager{ + sandbox: &types.SandboxInfo{ + SandboxID: "test-sandbox", + SessionID: "test-session", + Name: "test-sandbox", + EntryPoints: []types.SandboxEntryPoint{ + {Endpoint: tc.endpoint, Path: "/test"}, + }, + }, + } + + routerServer := httptest.NewServer(server.engine) + defer routerServer.Close() + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Post(routerServer.URL+"/v1/namespaces/default/agent-runtimes/test-agent/invocations/test", "application/json", nil) + if err != nil { + t.Fatalf("Failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("[%s] expected %d, got %d", tc.name, http.StatusInternalServerError, resp.StatusCode) + } + }) + } +} + +func TestForwardToSandbox_NoEntryPoints(t *testing.T) { + t.Setenv("REDIS_ADDR", "localhost:6379") + t.Setenv("REDIS_PASSWORD", "test-password") + t.Setenv("WORKLOAD_MANAGER_URL", "http://localhost:8080") config := &Config{Port: "8080"} server, err := NewServer(config) @@ -404,16 +459,13 @@ func TestForwardToSandbox_InvalidEndpoint(t *testing.T) { server.sessionManager = &mockSessionManager{ sandbox: &types.SandboxInfo{ - SandboxID: "test-sandbox", - SessionID: "test-session", - Name: "test-sandbox", - EntryPoints: []types.SandboxEntryPoint{ - {Endpoint: "://invalid-url", Path: "/test"}, - }, + SandboxID: "test-sandbox", + SessionID: "test-session", + Name: "test-sandbox", + EntryPoints: []types.SandboxEntryPoint{}, }, } - // run via real server to avoid CloseNotifier panic routerServer := httptest.NewServer(server.engine) defer routerServer.Close() @@ -424,8 +476,8 @@ func TestForwardToSandbox_InvalidEndpoint(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusInternalServerError { - t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, resp.StatusCode) + if resp.StatusCode != http.StatusNotFound { + t.Errorf("Expected status code %d, got %d", http.StatusNotFound, resp.StatusCode) } } diff --git a/pkg/router/jwt_test.go b/pkg/router/jwt_test.go index aef69981..fdbe1f39 100644 --- a/pkg/router/jwt_test.go +++ b/pkg/router/jwt_test.go @@ -183,7 +183,16 @@ func TestGetPrivateKeyPEM(t *testing.T) { privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) assert.NotNil(t, privateKey) - assert.Equal(t, manager.privateKey, privateKey) + + // Compare key components instead of whole struct — ParsePKCS1PrivateKey may + // not precompute Dp/Dq/Qinv identically to the original key. + assert.Equal(t, 0, manager.privateKey.PublicKey.N.Cmp(privateKey.PublicKey.N), "Public key N should match") + assert.Equal(t, manager.privateKey.PublicKey.E, privateKey.PublicKey.E, "Public key E should match") + assert.Equal(t, 0, manager.privateKey.D.Cmp(privateKey.D), "Private exponent D should match") + assert.Equal(t, len(manager.privateKey.Primes), len(privateKey.Primes), "Number of primes should match") + for i := range manager.privateKey.Primes { + assert.Equal(t, 0, manager.privateKey.Primes[i].Cmp(privateKey.Primes[i]), "Prime %d should match", i) + } } func TestLoadPrivateKeyPEM(t *testing.T) {