diff --git a/internal/acpwrap/convert.go b/internal/acpwrap/convert.go index 6a94c70..9b4cfa5 100644 --- a/internal/acpwrap/convert.go +++ b/internal/acpwrap/convert.go @@ -6,9 +6,9 @@ import ( "encoding/json" "fmt" "strings" - "unicode" "github.com/BakeLens/crust/internal/rules" + "mvdan.cc/sh/v3/syntax" ) // ACP parameter types @@ -32,20 +32,14 @@ type terminalCreateParams struct { Cwd string `json:"cwd,omitempty"` } -// shellSafe is the set of characters that don't need quoting in shell arguments. -const shellSafe = "-_./:=+," - -// shellQuote quotes a shell argument if it contains special characters. +// shellQuote quotes a shell argument using the shell parser's own Quote function. +// Falls back to single-quoting on error (e.g., null bytes). func shellQuote(s string) string { - if s == "" { - return "''" - } - if strings.ContainsFunc(s, func(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsDigit(c) && !strings.ContainsRune(shellSafe, c) - }) { + q, err := syntax.Quote(s, syntax.LangBash) + if err != nil { return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" } - return s + return q } // ACPMethodToToolCall converts an ACP JSON-RPC method + params into a rules.ToolCall. diff --git a/internal/acpwrap/convert_test.go b/internal/acpwrap/convert_test.go index a99d2d1..b548612 100644 --- a/internal/acpwrap/convert_test.go +++ b/internal/acpwrap/convert_test.go @@ -77,58 +77,3 @@ func TestAcpMethodToToolCall_MalformedParams(t *testing.T) { } } } - -// --- shellQuote --- - -func TestShellQuote(t *testing.T) { - tests := []struct { - input, want string - }{ - {"", "''"}, - {"simple", "simple"}, - {"-rf", "-rf"}, - {"/etc/passwd", "/etc/passwd"}, - {"file with spaces", "'file with spaces'"}, - {"; rm -rf /", "'; rm -rf /'"}, - {"$(whoami)", "'$(whoami)'"}, - {"`id`", "'`id`'"}, - {"it's", "'it'\"'\"'s'"}, - {"a&b", "'a&b'"}, - {"a|b", "'a|b'"}, - {"a>b", "'a>b'"}, - {"aAgent"}, - Outbound: jsonrpc.PipeConfig{Label: "Agent->IDE", Protocol: "ACP", Convert: ACPMethodToToolCall}, - }) - }() - - select { - case code := <-done: - if code != 0 { - t.Errorf("exit code = %d, want 0", code) - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung — IDE stdin not closed after agent exit") - } - }) - - t.Run("propagates_exit_code", func(t *testing.T) { - if _, err := exec.LookPath("false"); err != nil { - t.Skip("'false' not found in PATH") - } - engine := newTestEngine(t) - stdinR, stdinW := io.Pipe() - defer stdinW.Close() - - done := make(chan int, 1) - go func() { - done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ - Log: testLog, - ProcessLabel: "Agent", - Inbound: jsonrpc.PipeConfig{Label: "IDE->Agent"}, - Outbound: jsonrpc.PipeConfig{Label: "Agent->IDE", Protocol: "ACP", Convert: ACPMethodToToolCall}, - }) - }() - - select { - case code := <-done: - if code == 0 { - t.Error("expected non-zero exit code from 'false'") - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung") - } - }) -} - -// --- PipeInspect + ACPMethodToToolCall integration --- +// --- ACPMethodToToolCall converter edge cases --- +// Path-based blocking (.env, .ssh) and passthrough are covered by +// jsonrpc/proxy_test.go (unit) and mcpgateway/e2e_test.go (E2E). +// These tests verify ACP-specific converter error handling. -func TestPipeAgentToIDE_Blocks(t *testing.T) { +func TestPipeAgentToIDE_BlocksConverterEdgeCases(t *testing.T) { tests := []struct { name string msg string }{ - {"env_read", `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`}, - {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/home/user/.ssh/id_rsa"}}`}, - {"env_write", `{"jsonrpc":"2.0","id":3,"method":"fs/write_text_file","params":{"sessionId":"s1","path":"/app/.env","content":"API_KEY=secret"}}`}, {"malformed_read_params", `{"jsonrpc":"2.0","id":4,"method":"fs/read_text_file","params":"not-an-object"}`}, {"malformed_terminal_params", `{"jsonrpc":"2.0","id":5,"method":"terminal/create","params":42}`}, {"null_params", `{"jsonrpc":"2.0","id":6,"method":"fs/read_text_file","params":null}`}, @@ -127,69 +51,3 @@ func TestPipeAgentToIDE_Blocks(t *testing.T) { }) } } - -func TestPipeAgentToIDE_BlocksEnvRead_ErrorShape(t *testing.T) { - fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`+"\n") - if fwd != "" { - t.Errorf("IDE should not receive blocked request, got: %s", fwd) - } - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error, got: %s", errOut) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} - -func TestPipeAgentToIDE_Passes(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/src/main.go"}}`}, - {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"session/prompt","params":{"text":"hello"}}`}, - {"notification", `{"jsonrpc":"2.0","method":"session/update","params":{"status":"working"}}`}, - {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"file data"}}`}, - {"invalid_json", `not valid json at all`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runPipe(t, tt.msg+"\n") - if fwd != tt.msg+"\n" { - t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") - } - if errOut != "" { - t.Errorf("agent should not receive errors, got: %s", errOut) - } - }) - } -} - -func TestPipeAgentToIDE_EmptyLine(t *testing.T) { - fwd, _ := runPipe(t, "\n") - if fwd != "\n" { - t.Errorf("empty line should pass through, got: %q", fwd) - } -} - -func TestPipeAgentToIDE_MultipleMessages(t *testing.T) { - msgs := strings.Join([]string{ - `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`, - `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/main.go"}}`, - `{"jsonrpc":"2.0","id":3,"method":"session/prompt","params":{"text":"hi"}}`, - }, "\n") + "\n" - - fwd, errOut := runPipe(t, msgs) - - fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") - if len(fwdLines) != 2 { - t.Errorf("expected 2 IDE messages (main.go + prompt), got %d: %v", len(fwdLines), fwdLines) - } - if errOut == "" { - t.Error("agent should receive error for .env read") - } -} diff --git a/internal/autowrap/run_test.go b/internal/autowrap/run_test.go index cc69961..0c7e483 100644 --- a/internal/autowrap/run_test.go +++ b/internal/autowrap/run_test.go @@ -2,37 +2,21 @@ package autowrap import ( "bytes" - "encoding/json" - "io" - "os/exec" "strings" "testing" - "time" "github.com/BakeLens/crust/internal/jsonrpc" "github.com/BakeLens/crust/internal/logger" "github.com/BakeLens/crust/internal/mcpgateway" - "github.com/BakeLens/crust/internal/rules" + "github.com/BakeLens/crust/internal/testutil" ) var testLog = logger.New("wrap-test") -func newTestEngine(t *testing.T) *rules.Engine { - t.Helper() - engine, err := rules.NewEngine(rules.EngineConfig{ - UserRulesDir: t.TempDir(), - DisableBuiltin: false, - }) - if err != nil { - t.Fatalf("Failed to create engine: %v", err) - } - return engine -} - // runInboundPipe runs PipeInspect with MCPMethodToToolCall (inbound direction). func runInboundPipe(t *testing.T, input string) (fwd, errOut string) { t.Helper() - engine := newTestEngine(t) + engine := testutil.NewEngine(t) var fwdBuf, errBuf bytes.Buffer fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) errWriter := jsonrpc.NewLockedWriter(&errBuf) @@ -44,7 +28,7 @@ func runInboundPipe(t *testing.T, input string) (fwd, errOut string) { // runOutboundPipe runs PipeInspect with BothMethodToToolCall (outbound direction). func runOutboundPipe(t *testing.T, input string) (fwd, errOut string) { t.Helper() - engine := newTestEngine(t) + engine := testutil.NewEngine(t) var fwdBuf, errBuf bytes.Buffer fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) errWriter := jsonrpc.NewLockedWriter(&errBuf) @@ -53,134 +37,9 @@ func runOutboundPipe(t *testing.T, input string) (fwd, errOut string) { return fwdBuf.String(), errBuf.String() } -// --- Inbound (MCP) direction --- - -func TestPipeInbound_BlocksMCP(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"env_read", `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`}, - {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/home/user/.ssh/id_rsa"}}}`}, - {"resource_env_read", `{"jsonrpc":"2.0","id":3,"method":"resources/read","params":{"uri":"file:///app/.env"}}`}, - {"malformed_params", `{"jsonrpc":"2.0","id":4,"method":"tools/call","params":"not-an-object"}`}, - {"null_params", `{"jsonrpc":"2.0","id":5,"method":"tools/call","params":null}`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runInboundPipe(t, tt.msg+"\n") - if fwd != "" { - t.Errorf("subprocess should not receive blocked request, got: %s", fwd) - } - if errOut == "" { - t.Error("client should receive an error response") - } - }) - } -} - -func TestPipeInbound_PassesMCP(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/src/main.go"}}}`}, - {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"initialize","params":{"capabilities":{}}}`}, - {"tools_list", `{"jsonrpc":"2.0","id":30,"method":"tools/list","params":{}}`}, - {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"data"}}`}, - {"invalid_json", `not valid json at all`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runInboundPipe(t, tt.msg+"\n") - if fwd != tt.msg+"\n" { - t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") - } - if errOut != "" { - t.Errorf("client should not receive errors, got: %s", errOut) - } - }) - } -} - -func TestPipeInbound_ErrorShape(t *testing.T) { - _, errOut := runInboundPipe(t, `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`+"\n") - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error, got: %s", errOut) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} - -// --- Outbound (ACP) direction --- - -func TestPipeOutbound_BlocksACP(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"env_read", `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`}, - {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/home/user/.ssh/id_rsa"}}`}, - {"env_write", `{"jsonrpc":"2.0","id":3,"method":"fs/write_text_file","params":{"sessionId":"s1","path":"/app/.env","content":"SECRET=abc"}}`}, - {"malformed_params", `{"jsonrpc":"2.0","id":4,"method":"fs/read_text_file","params":"not-an-object"}`}, - {"null_params", `{"jsonrpc":"2.0","id":5,"method":"fs/read_text_file","params":null}`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runOutboundPipe(t, tt.msg+"\n") - if fwd != "" { - t.Errorf("client should not receive blocked request, got: %s", fwd) - } - if errOut == "" { - t.Error("subprocess should receive an error response") - } - }) - } -} - -func TestPipeOutbound_PassesACP(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/src/main.go"}}`}, - {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"session/prompt","params":{"text":"hello"}}`}, - {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"data"}}`}, - {"invalid_json", `not valid json at all`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runOutboundPipe(t, tt.msg+"\n") - if fwd != tt.msg+"\n" { - t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") - } - if errOut != "" { - t.Errorf("subprocess should not receive errors, got: %s", errOut) - } - }) - } -} - -func TestPipeOutbound_ErrorShape(t *testing.T) { - _, errOut := runOutboundPipe(t, `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`+"\n") - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error, got: %s", errOut) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} - -// --- Cross-protocol --- +// --- Cross-protocol tests --- +// These are unique to autowrap: they verify that the inbound MCP converter +// ignores ACP methods, and the outbound BothMethodToToolCall catches both. func TestPipeInbound_IgnoresACPMethods(t *testing.T) { // Inbound uses MCPMethodToToolCall only — ACP methods pass through unexamined. @@ -205,197 +64,3 @@ func TestPipeOutbound_BlocksMCPMethods(t *testing.T) { t.Error("subprocess should receive an error response for blocked MCP method") } } - -// --- Response DLP --- - -func TestPipeOutbound_ResponseDLP_BlocksSecrets(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"aws_key", `{"jsonrpc":"2.0","id":1,"result":{"content":"key=AKIAIOSFODNN7EXAMPLE"}}`}, - {"github_token", `{"jsonrpc":"2.0","id":2,"result":{"text":"ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm"}}`}, - {"openai_key", `{"jsonrpc":"2.0","id":3,"result":{"config":"sk-proj-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}}`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, _ := runOutboundPipe(t, tt.msg+"\n") - // Response should NOT be forwarded (DLP blocks it) - if strings.Contains(fwd, "AKIA") || strings.Contains(fwd, "ghp_") || strings.Contains(fwd, "sk-proj-") { - t.Errorf("response with secret should be blocked by DLP, got forwarded: %s", fwd) - } - }) - } -} - -func TestPipeOutbound_ResponseDLP_PassesClean(t *testing.T) { - msg := `{"jsonrpc":"2.0","id":1,"result":{"content":"safe data, no secrets here"}}` - fwd, _ := runOutboundPipe(t, msg+"\n") - if fwd != msg+"\n" { - t.Errorf("clean response should pass through\ngot: %q\nwant: %q", fwd, msg+"\n") - } -} - -func TestPipeOutbound_ResponseDLP_BlocksErrorFieldSecrets(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"aws_key_in_error", `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"failed to read config: AKIAIOSFODNN7EXAMPLE"}}`}, - {"github_token_in_error", `{"jsonrpc":"2.0","id":2,"error":{"code":-32000,"message":"auth failed: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm"}}`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, _ := runOutboundPipe(t, tt.msg+"\n") - if strings.Contains(fwd, "AKIA") || strings.Contains(fwd, "ghp_") { - t.Errorf("error response with secret should be blocked by DLP, got forwarded: %s", fwd) - } - // fwd should contain a replacement JSON-RPC error from Crust - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(fwd)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error in fwd, got: %q", fwd) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - }) - } -} - -func TestPipeOutbound_ResponseDLP_PassesCleanError(t *testing.T) { - msg := `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"file not found"}}` - fwd, _ := runOutboundPipe(t, msg+"\n") - if fwd != msg+"\n" { - t.Errorf("clean error response should pass through\ngot: %q\nwant: %q", fwd, msg+"\n") - } -} - -func TestPipeOutbound_ResponseDLP_ErrorResponseShape(t *testing.T) { - msg := `{"jsonrpc":"2.0","id":1,"result":{"content":"key=AKIAIOSFODNN7EXAMPLE"}}` - fwd, _ := runOutboundPipe(t, msg+"\n") - // fwd should contain a JSON-RPC error (sent to client via fwdWriter) - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(fwd)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error in fwd, got: %q", fwd) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} - -// --- Empty lines --- - -func TestPipeInbound_EmptyLine(t *testing.T) { - fwd, _ := runInboundPipe(t, "\n") - if fwd != "\n" { - t.Errorf("empty line should pass through, got: %q", fwd) - } -} - -func TestPipeOutbound_EmptyLine(t *testing.T) { - fwd, _ := runOutboundPipe(t, "\n") - if fwd != "\n" { - t.Errorf("empty line should pass through, got: %q", fwd) - } -} - -// --- RunProxy (hang / exit code) --- - -func TestRunProxy(t *testing.T) { - t.Run("no_hang_on_exit", func(t *testing.T) { - if _, err := exec.LookPath("true"); err != nil { - t.Skip("'true' not found in PATH") - } - engine := newTestEngine(t) - stdinR, stdinW := io.Pipe() - defer stdinW.Close() - - done := make(chan int, 1) - go func() { - done <- jsonrpc.RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ - Log: testLog, - ProcessLabel: "Subprocess", - Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "Stdio", Convert: BothMethodToToolCall}, - }) - }() - - select { - case code := <-done: - if code != 0 { - t.Errorf("exit code = %d, want 0", code) - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung — stdin not closed after subprocess exit") - } - }) - - t.Run("propagates_exit_code", func(t *testing.T) { - if _, err := exec.LookPath("false"); err != nil { - t.Skip("'false' not found in PATH") - } - engine := newTestEngine(t) - stdinR, stdinW := io.Pipe() - defer stdinW.Close() - - done := make(chan int, 1) - go func() { - done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ - Log: testLog, - ProcessLabel: "Subprocess", - Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "Stdio", Convert: BothMethodToToolCall}, - }) - }() - - select { - case code := <-done: - if code == 0 { - t.Error("expected non-zero exit code from 'false'") - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung") - } - }) -} - -// --- Multiple messages --- - -func TestPipeInbound_MultipleMessages(t *testing.T) { - msgs := strings.Join([]string{ - `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`, - `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/main.go"}}}`, - `{"jsonrpc":"2.0","id":3,"method":"initialize","params":{"capabilities":{}}}`, - }, "\n") + "\n" - - fwd, errOut := runInboundPipe(t, msgs) - - fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") - if len(fwdLines) != 2 { - t.Errorf("expected 2 subprocess messages (main.go + initialize), got %d: %v", len(fwdLines), fwdLines) - } - if errOut == "" { - t.Error("client should receive error for .env read") - } -} - -func TestPipeOutbound_MultipleMessages(t *testing.T) { - msgs := strings.Join([]string{ - `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`, - `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/main.go"}}`, - `{"jsonrpc":"2.0","id":3,"method":"session/prompt","params":{"text":"hi"}}`, - }, "\n") + "\n" - - fwd, errOut := runOutboundPipe(t, msgs) - - fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") - if len(fwdLines) != 2 { - t.Errorf("expected 2 client messages (main.go + prompt), got %d: %v", len(fwdLines), fwdLines) - } - if errOut == "" { - t.Error("subprocess should receive error for .env read") - } -} diff --git a/internal/jsonrpc/pipe.go b/internal/jsonrpc/pipe.go index f40ebdf..edee225 100644 --- a/internal/jsonrpc/pipe.go +++ b/internal/jsonrpc/pipe.go @@ -42,6 +42,139 @@ func PipePassthrough(log *logger.Logger, src io.Reader, dst *LockedWriter, label } } +// processResult indicates what happened when processing a single message. +type processResult int + +const ( + resultForwarded processResult = iota // message was forwarded + resultBlocked // message was blocked + resultWriteErr // write error; caller should abort +) + +// scanDLP checks a JSON-RPC field for leaked secrets. Returns true if blocked. +// If id is non-empty, sends a JSON-RPC error response to errWriter. +// For notifications (no ID), the message is dropped silently. +func scanDLP(log *logger.Logger, engine *rules.Engine, data json.RawMessage, + id json.RawMessage, errWriter *LockedWriter, protocol, logLabel string) bool { + if len(data) == 0 { + return false + } + dlpResult := engine.ScanDLP(string(data)) + if dlpResult == nil { + return false + } + log.Warn("Blocked %s %s (DLP): rule=%s message=%s", + protocol, logLabel, dlpResult.RuleName, dlpResult.Message) + if len(id) > 0 { + SendBlockError(log, errWriter, id, + fmt.Sprintf("[Crust] Blocked by rule %q: %s", dlpResult.RuleName, dlpResult.Message)) + } + return true +} + +// forwardLine writes data to w and returns resultForwarded, or resultWriteErr on failure. +func forwardLine(log *logger.Logger, w *LockedWriter, data []byte, label string) processResult { + if err := w.WriteLine(data); err != nil { + log.Debug("%s write error: %v", label, err) + return resultWriteErr + } + return resultForwarded +} + +// processMessage inspects a single JSON-RPC message and either forwards or blocks it. +// This is the core inspection logic, reused by both the main loop and batch handler. +func processMessage(log *logger.Logger, engine *rules.Engine, line []byte, msg *Message, + fwdWriter, errWriter *LockedWriter, convert MethodConverter, protocol, label string) processResult { + + // Response (no method): DLP-scan only. + if !msg.IsRequest() && !msg.IsNotification() { + if scanDLP(log, engine, msg.Result, msg.ID, fwdWriter, protocol, "response") || + scanDLP(log, engine, msg.Error, msg.ID, fwdWriter, protocol, "error response") { + return resultBlocked + } + return forwardLine(log, fwdWriter, line, label) + } + + // Notification: DLP-scan params for leaked secrets, then fall through + // to converter + rule evaluation (notifications with security-relevant + // methods like tools/call must still be inspected). + if msg.IsNotification() { + if scanDLP(log, engine, msg.Params, nil, fwdWriter, protocol, + "notification method="+msg.Method) { + return resultBlocked + } + } + + // Request or notification with a method: convert and evaluate rules. + toolCall, err := convert(msg.Method, msg.Params) + if toolCall == nil && err == nil { + return forwardLine(log, fwdWriter, line, label) + } + if err != nil { + log.Warn("Blocked %s %s: %v", protocol, msg.Method, err) + if msg.IsRequest() { + SendBlockError(log, errWriter, msg.ID, "[Crust] Blocked: malformed params for "+msg.Method) + } + return resultBlocked + } + + result := engine.Evaluate(*toolCall) + + if result.Matched && result.Action == rules.ActionBlock { + log.Warn("Blocked %s %s (tool=%s): rule=%s message=%s", + protocol, msg.Method, toolCall.Name, result.RuleName, result.Message) + if msg.IsRequest() { + SendBlockError(log, errWriter, msg.ID, + fmt.Sprintf("[Crust] Blocked by rule %q: %s", result.RuleName, result.Message)) + } + return resultBlocked + } + + if result.Matched && result.Action == rules.ActionLog { + log.Info("Logged %s %s (tool=%s): rule=%s", + protocol, msg.Method, toolCall.Name, result.RuleName) + } + + return forwardLine(log, fwdWriter, line, label) +} + +// processBatch handles a JSON-RPC batch array by inspecting each element +// individually. Each allowed element is forwarded as an individual JSONL line. +// MCP stdio transport is JSONL, so splitting batches is correct behavior. +func processBatch(log *logger.Logger, engine *rules.Engine, line []byte, + fwdWriter, errWriter *LockedWriter, convert MethodConverter, protocol, label string) processResult { + + var batch []json.RawMessage + if err := json.Unmarshal(line, &batch); err != nil { + // Not a valid JSON array — forward as-is (same as invalid JSON) + log.Debug("%s batch parse error: %v", label, err) + return forwardLine(log, fwdWriter, line, label) + } + + if len(batch) == 0 { + return forwardLine(log, fwdWriter, line, label) + } + + log.Debug("%s processing batch of %d messages", label, len(batch)) + + for _, raw := range batch { + var msg Message + if err := json.Unmarshal(raw, &msg); err != nil { + // Element is not valid JSON-RPC — forward individually + if forwardLine(log, fwdWriter, raw, label) == resultWriteErr { + return resultWriteErr + } + continue + } + + if processMessage(log, engine, raw, &msg, fwdWriter, errWriter, convert, protocol, label) == resultWriteErr { + return resultWriteErr + } + } + + return resultForwarded +} + // PipeInspect reads JSONL from src, runs security-relevant messages through // the converter and rule engine, and either forwards or blocks them. // @@ -67,75 +200,26 @@ func PipeInspect(log *logger.Logger, engine *rules.Engine, src io.Reader, continue } - var msg Message - if err := json.Unmarshal(line, &msg); err != nil { - if err := fwdWriter.WriteLine(line); err != nil { - log.Debug("%s write error: %v", label, err) - return - } - continue - } - - if !msg.IsRequest() { - // Response DLP: scan responses for leaked secrets. - // Errors go to fwdWriter (client) because the client is waiting - // for this response — the server doesn't need to know. - if len(msg.Result) > 0 { - if dlpResult := engine.ScanDLP(string(msg.Result)); dlpResult != nil { - log.Warn("Blocked %s response (DLP): rule=%s message=%s", - protocol, dlpResult.RuleName, dlpResult.Message) - SendBlockError(log, fwdWriter, msg.ID, - fmt.Sprintf("[Crust] Blocked by rule %q: %s", dlpResult.RuleName, dlpResult.Message)) - continue - } - } - if len(msg.Error) > 0 { - if dlpResult := engine.ScanDLP(string(msg.Error)); dlpResult != nil { - log.Warn("Blocked %s error response (DLP): rule=%s message=%s", - protocol, dlpResult.RuleName, dlpResult.Message) - SendBlockError(log, fwdWriter, msg.ID, - fmt.Sprintf("[Crust] Blocked by rule %q: %s", dlpResult.RuleName, dlpResult.Message)) - continue - } - } - if err := fwdWriter.WriteLine(line); err != nil { - log.Debug("%s write error: %v", label, err) + // Detect JSON-RPC batch arrays. Per JSON-RPC 2.0 spec, batch requests + // are JSON arrays. Without this check, arrays fail to unmarshal into + // the Message struct and are forwarded unexamined (security bypass). + if line[0] == '[' { + if processBatch(log, engine, line, fwdWriter, errWriter, convert, protocol, label) == resultWriteErr { return } continue } - toolCall, err := convert(msg.Method, msg.Params) - if toolCall == nil && err == nil { + var msg Message + if err := json.Unmarshal(line, &msg); err != nil { if err := fwdWriter.WriteLine(line); err != nil { log.Debug("%s write error: %v", label, err) return } continue } - if err != nil { - log.Warn("Blocked %s %s: %v", protocol, msg.Method, err) - SendBlockError(log, errWriter, msg.ID, "[Crust] Blocked: malformed params for "+msg.Method) - continue - } - result := engine.Evaluate(*toolCall) - - if result.Matched && result.Action == rules.ActionBlock { - log.Warn("Blocked %s %s (tool=%s): rule=%s message=%s", - protocol, msg.Method, toolCall.Name, result.RuleName, result.Message) - SendBlockError(log, errWriter, msg.ID, - fmt.Sprintf("[Crust] Blocked by rule %q: %s", result.RuleName, result.Message)) - continue - } - - if result.Matched && result.Action == rules.ActionLog { - log.Info("Logged %s %s (tool=%s): rule=%s", - protocol, msg.Method, toolCall.Name, result.RuleName) - } - - if err := fwdWriter.WriteLine(line); err != nil { - log.Debug("%s write error: %v", label, err) + if processMessage(log, engine, line, &msg, fwdWriter, errWriter, convert, protocol, label) == resultWriteErr { return } } diff --git a/internal/jsonrpc/proxy_test.go b/internal/jsonrpc/proxy_test.go index 12b62c2..957f12d 100644 --- a/internal/jsonrpc/proxy_test.go +++ b/internal/jsonrpc/proxy_test.go @@ -12,22 +12,11 @@ import ( "github.com/BakeLens/crust/internal/logger" "github.com/BakeLens/crust/internal/rules" + "github.com/BakeLens/crust/internal/testutil" ) var testLog = logger.New("test") -func newTestEngine(t *testing.T) *rules.Engine { - t.Helper() - engine, err := rules.NewEngine(rules.EngineConfig{ - UserRulesDir: t.TempDir(), - DisableBuiltin: false, - }) - if err != nil { - t.Fatalf("Failed to create engine: %v", err) - } - return engine -} - // blockAllConverter is a test converter that treats "security/call" as security-relevant // and maps it to a tool call that will be blocked by the built-in rules. func blockAllConverter(method string, params json.RawMessage) (*rules.ToolCall, error) { @@ -132,7 +121,7 @@ func TestPipePassthrough_EmptyInput(t *testing.T) { func runInspect(t *testing.T, input string, convert MethodConverter) (fwd, errOut string) { t.Helper() - engine := newTestEngine(t) + engine := testutil.NewEngine(t) var fwdBuf, errBuf bytes.Buffer fwdWriter := NewLockedWriter(&fwdBuf) errWriter := NewLockedWriter(&errBuf) @@ -250,7 +239,7 @@ func TestRunProxy_NoHang(t *testing.T) { if _, err := exec.LookPath("true"); err != nil { t.Skip("'true' not found in PATH") } - engine := newTestEngine(t) + engine := testutil.NewEngine(t) stdinR, stdinW := io.Pipe() defer stdinW.Close() // keep write end OPEN to expose the hang bug @@ -278,7 +267,7 @@ func TestRunProxy_ExitCode(t *testing.T) { if _, err := exec.LookPath("false"); err != nil { t.Skip("'false' not found in PATH") } - engine := newTestEngine(t) + engine := testutil.NewEngine(t) stdinR, stdinW := io.Pipe() defer stdinW.Close() @@ -306,7 +295,7 @@ func TestRunProxy_WithInspect(t *testing.T) { if _, err := exec.LookPath("cat"); err != nil { t.Skip("'cat' not found in PATH") } - engine := newTestEngine(t) + engine := testutil.NewEngine(t) // cat echoes stdin to stdout — so we send a message and check it comes through input := `{"jsonrpc":"2.0","id":1,"method":"other/call","params":{}}` + "\n" @@ -353,7 +342,148 @@ func TestForwardSignals_StopSignals(t *testing.T) { } } -// --- IsRequest --- +// --- Batch handling --- + +func TestPipeInspect_BatchBlocksSecurityMessages(t *testing.T) { + // Before the fix, arrays fail to unmarshal into Message and are forwarded unexamined. + batch := `[{"jsonrpc":"2.0","id":1,"method":"security/call","params":{}}]` + "\n" + fwd, errOut := runInspect(t, batch, blockAllConverter) + if strings.Contains(fwd, "security/call") { + t.Errorf("batch element should be blocked, but was forwarded: %s", fwd) + } + if errOut == "" { + t.Error("expected error response for blocked batch element") + } +} + +func TestPipeInspect_BatchMixed(t *testing.T) { + batch := `[{"jsonrpc":"2.0","id":1,"method":"security/call","params":{}},{"jsonrpc":"2.0","id":2,"method":"other/call","params":{}}]` + "\n" + fwd, errOut := runInspect(t, batch, blockAllConverter) + if !strings.Contains(fwd, "other/call") { + t.Errorf("allowed batch element should be forwarded, got: %s", fwd) + } + if strings.Contains(fwd, "security/call") { + t.Errorf("blocked batch element should not be forwarded") + } + if errOut == "" { + t.Error("expected error response for blocked batch element") + } +} + +func TestPipeInspect_BatchInvalidFallthrough(t *testing.T) { + msg := `[broken json` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("invalid batch should pass through, got: %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error: %s", errOut) + } +} + +func TestPipeInspect_BatchResponseDLP(t *testing.T) { + // Batch containing a response with DLP-triggering content + batch := `[{"jsonrpc":"2.0","id":1,"result":{"key":"AKIAIOSFODNN7EXAMPLE"}}]` + "\n" + fwd, _ := runInspect(t, batch, blockAllConverter) + if strings.Contains(fwd, "AKIA") { + t.Errorf("response with AWS key in batch should be blocked by DLP, got: %s", fwd) + } +} + +// --- Notification params DLP --- + +func TestPipeInspect_NotificationParamsDLP(t *testing.T) { + msg := `{"jsonrpc":"2.0","method":"update","params":{"token":"ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm"}}` + "\n" + fwd, _ := runInspect(t, msg, blockAllConverter) + if strings.Contains(fwd, "ghp_") { + t.Errorf("notification with GitHub token in params should be blocked by DLP, got: %s", fwd) + } +} + +func TestPipeInspect_NotificationCleanParamsPassthrough(t *testing.T) { + msg := `{"jsonrpc":"2.0","method":"update","params":{"progress":50}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("clean notification should pass through, got: %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error: %s", errOut) + } +} + +// --- Notification security bypass --- +// Before the fix, notifications (method + no id) bypassed the converter +// and rule engine entirely — only DLP scanning ran. + +func TestPipeInspect_NotificationBlocksSecurityMethod(t *testing.T) { + // "security/call" as a notification (no id) must still be blocked. + msg := `{"jsonrpc":"2.0","method":"security/call","params":{}}` + "\n" + fwd, _ := runInspect(t, msg, blockAllConverter) + if strings.Contains(fwd, "security/call") { + t.Errorf("notification with security method should be blocked, got forwarded: %s", fwd) + } +} + +func TestPipeInspect_NotificationPassesNonSecurityMethod(t *testing.T) { + // Non-security notifications must still pass through. + msg := `{"jsonrpc":"2.0","method":"other/call","params":{}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("non-security notification should pass through, got %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error: %s", errOut) + } +} + +func TestPipeInspect_NotificationMalformedParamsBlocked(t *testing.T) { + // "malformed/call" as a notification — converter returns error, must be blocked. + msg := `{"jsonrpc":"2.0","method":"malformed/call","params":{}}` + "\n" + fwd, _ := runInspect(t, msg, blockAllConverter) + if fwd != "" { + t.Errorf("notification with malformed params should be blocked, got: %s", fwd) + } +} + +func TestPipeInspect_NotificationNoErrorResponseSent(t *testing.T) { + // Blocked notifications should NOT generate error responses (no id to reply to). + msg := `{"jsonrpc":"2.0","method":"security/call","params":{}}` + "\n" + _, errOut := runInspect(t, msg, blockAllConverter) + if errOut != "" { + t.Errorf("blocked notification should not generate error response, got: %s", errOut) + } +} + +func TestPipeInspect_BatchNotificationBlocked(t *testing.T) { + // Batch element as notification must also be blocked. + batch := `[{"jsonrpc":"2.0","method":"security/call","params":{}}]` + "\n" + fwd, _ := runInspect(t, batch, blockAllConverter) + if strings.Contains(fwd, "security/call") { + t.Errorf("batch notification should be blocked, got: %s", fwd) + } +} + +// --- IsRequest / IsNotification --- + +func TestMessage_IsNotification(t *testing.T) { + tests := []struct { + name string + msg Message + want bool + }{ + {"notification", Message{Method: "foo"}, true}, + {"request", Message{Method: "foo", ID: json.RawMessage(`1`)}, false}, + {"response", Message{ID: json.RawMessage(`1`), Result: json.RawMessage(`{}`)}, false}, + {"empty", Message{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.msg.IsNotification(); got != tt.want { + t.Errorf("IsNotification() = %v, want %v", got, tt.want) + } + }) + } +} func TestMessage_IsRequest(t *testing.T) { tests := []struct { diff --git a/internal/jsonrpc/types.go b/internal/jsonrpc/types.go index d778cdb..6c91990 100644 --- a/internal/jsonrpc/types.go +++ b/internal/jsonrpc/types.go @@ -31,6 +31,11 @@ func (m *Message) IsRequest() bool { return m.Method != "" && len(m.ID) > 0 } +// IsNotification returns true if this is a JSON-RPC notification (has method but no id). +func (m *Message) IsNotification() bool { + return m.Method != "" && len(m.ID) == 0 +} + // ErrorResponse is a JSON-RPC 2.0 error response. type ErrorResponse struct { JSONRPC string `json:"jsonrpc"` diff --git a/internal/mcpgateway/convert.go b/internal/mcpgateway/convert.go index fb58d6f..5e853b9 100644 --- a/internal/mcpgateway/convert.go +++ b/internal/mcpgateway/convert.go @@ -26,6 +26,19 @@ type resourcesReadParams struct { URI string `json:"uri"` } +// samplingCreateMessageParams represents the params of a MCP sampling/createMessage request. +// The server asks the client's LLM to process messages (potentially instructing tool use). +type samplingCreateMessageParams struct { + Messages json.RawMessage `json:"messages"` + MaxTokens int `json:"maxTokens"` +} + +// elicitationCreateParams represents the params of a MCP elicitation/create request. +// The server presents a form or URL to the user (phishing vector). +type elicitationCreateParams struct { + Message string `json:"message"` +} + // MCPMethodToToolCall converts an MCP JSON-RPC method + params into a rules.ToolCall. // // Returns: @@ -36,7 +49,7 @@ func MCPMethodToToolCall(method string, params json.RawMessage) (*rules.ToolCall // Reject nil/null params on security-relevant methods (json.Unmarshal silently // zero-initializes the struct, which would produce an empty name and bypass rules). switch method { - case "tools/call", "resources/read": + case "tools/call", "resources/read", "sampling/createMessage", "elicitation/create": if len(params) == 0 || string(params) == "null" { return nil, fmt.Errorf("nil params for security method %s", method) } @@ -95,6 +108,26 @@ func MCPMethodToToolCall(method string, params json.RawMessage) (*rules.ToolCall return &rules.ToolCall{Name: "mcp_resource_read", Arguments: args}, nil } + case "sampling/createMessage": + var p samplingCreateMessageParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + // Pass full params as arguments so DLP and content-based rules + // can inspect the embedded messages and system prompt. + return &rules.ToolCall{Name: "mcp_sampling", Arguments: params}, nil + + case "elicitation/create": + var p elicitationCreateParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + if p.Message == "" { + return nil, fmt.Errorf("empty message in %s", method) + } + // Pass full params so rule engine can inspect the message and schema. + return &rules.ToolCall{Name: "mcp_elicitation", Arguments: params}, nil + default: return nil, nil } diff --git a/internal/mcpgateway/convert_test.go b/internal/mcpgateway/convert_test.go index 74c112a..3589be0 100644 --- a/internal/mcpgateway/convert_test.go +++ b/internal/mcpgateway/convert_test.go @@ -104,8 +104,65 @@ func TestMcpMethodToToolCall_Unknown(t *testing.T) { } } +func TestMcpMethodToToolCall_SamplingCreateMessage(t *testing.T) { + params := `{"messages":[{"role":"user","content":{"type":"text","text":"hello"}}],"maxTokens":100}` + tc, err := MCPMethodToToolCall("sampling/createMessage", json.RawMessage(params)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc == nil { + t.Fatal("expected non-nil ToolCall") + } + if tc.Name != "mcp_sampling" { + t.Errorf("name = %s, want mcp_sampling", tc.Name) + } + // Full params should be passed as Arguments for DLP inspection. + if string(tc.Arguments) != params { + t.Errorf("arguments = %s, want %s", string(tc.Arguments), params) + } +} + +func TestMcpMethodToToolCall_ElicitationCreate(t *testing.T) { + params := `{"message":"Please enter your API key","schema":{"type":"object"}}` + tc, err := MCPMethodToToolCall("elicitation/create", json.RawMessage(params)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc == nil { + t.Fatal("expected non-nil ToolCall") + } + if tc.Name != "mcp_elicitation" { + t.Errorf("name = %s, want mcp_elicitation", tc.Name) + } + if string(tc.Arguments) != params { + t.Errorf("arguments = %s, want %s", string(tc.Arguments), params) + } +} + +func TestMcpMethodToToolCall_SamplingNilParams(t *testing.T) { + for _, input := range []json.RawMessage{nil, json.RawMessage(`null`), json.RawMessage(``)} { + tc, err := MCPMethodToToolCall("sampling/createMessage", input) + if tc != nil { + t.Errorf("expected nil ToolCall for nil/null params, got %+v", tc) + } + if err == nil { + t.Error("expected error for nil/null params on security method") + } + } +} + +func TestMcpMethodToToolCall_ElicitationEmptyMessage(t *testing.T) { + tc, err := MCPMethodToToolCall("elicitation/create", json.RawMessage(`{"message":""}`)) + if tc != nil { + t.Error("expected nil ToolCall for empty message") + } + if err == nil { + t.Error("expected error for empty message in elicitation/create") + } +} + func TestMcpMethodToToolCall_MalformedParams(t *testing.T) { - methods := []string{"tools/call", "resources/read"} + methods := []string{"tools/call", "resources/read", "sampling/createMessage", "elicitation/create"} badInputs := []json.RawMessage{ json.RawMessage(`{broken`), json.RawMessage(`"just a string"`), diff --git a/internal/mcpgateway/e2e_test.go b/internal/mcpgateway/e2e_test.go index acf5610..314f98f 100644 --- a/internal/mcpgateway/e2e_test.go +++ b/internal/mcpgateway/e2e_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/testutil" ) // skipE2E skips if -short or npx not available. @@ -73,7 +74,7 @@ type e2eResponse struct { // all JSON-RPC responses received by the client. func runMCPE2E(t *testing.T, dir string, messages []string) []e2eResponse { t.Helper() - engine := newTestEngine(t) + engine := testutil.NewEngine(t) input := strings.Join(messages, "\n") + "\n" stdinR := io.NopCloser(strings.NewReader(input)) var stdout strings.Builder diff --git a/internal/mcpgateway/run_test.go b/internal/mcpgateway/run_test.go index 7f8b27a..1a49452 100644 --- a/internal/mcpgateway/run_test.go +++ b/internal/mcpgateway/run_test.go @@ -2,34 +2,21 @@ package mcpgateway import ( "bytes" - "encoding/json" "strings" "testing" "github.com/BakeLens/crust/internal/jsonrpc" "github.com/BakeLens/crust/internal/logger" - "github.com/BakeLens/crust/internal/rules" + "github.com/BakeLens/crust/internal/testutil" ) var testLog = logger.New("mcp-test") -func newTestEngine(t *testing.T) *rules.Engine { - t.Helper() - engine, err := rules.NewEngine(rules.EngineConfig{ - UserRulesDir: t.TempDir(), - DisableBuiltin: false, - }) - if err != nil { - t.Fatalf("Failed to create engine: %v", err) - } - return engine -} - // runPipe runs PipeInspect with MCPMethodToToolCall and returns what was // forwarded and what error responses were generated. func runPipe(t *testing.T, input string) (fwd, errOut string) { t.Helper() - engine := newTestEngine(t) + engine := testutil.NewEngine(t) var fwdBuf, errBuf bytes.Buffer fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) errWriter := jsonrpc.NewLockedWriter(&errBuf) @@ -39,7 +26,9 @@ func runPipe(t *testing.T, input string) (fwd, errOut string) { } // --- Edge-case blocking (malformed inputs, resources/read) --- -// Security blocking of .env, .ssh, etc. is covered by E2E tests (e2e_test.go). +// Path-based blocking, passthrough, batch handling, error shapes, and DLP are +// covered by jsonrpc/proxy_test.go (unit) and e2e_test.go (real MCP server). +// These tests verify MCP-specific converter edge cases in the full pipeline. func TestPipeClientToServer_BlocksEdgeCases(t *testing.T) { tests := []struct { @@ -63,47 +52,3 @@ func TestPipeClientToServer_BlocksEdgeCases(t *testing.T) { }) } } - -// --- Passthrough edge cases --- -// Passthrough of initialize, tools/list, and allowed tool calls is covered by E2E tests. - -func TestPipeClientToServer_PassesEdgeCases(t *testing.T) { - tests := []struct { - name string - msg string - }{ - {"notification", `{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}`}, //nolint:misspell // MCP protocol uses "cancelled" - {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"file data"}}`}, - {"invalid_json", `not valid json at all`}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fwd, errOut := runPipe(t, tt.msg+"\n") - if fwd != tt.msg+"\n" { - t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") - } - if errOut != "" { - t.Errorf("client should not receive errors, got: %s", errOut) - } - }) - } -} - -// --- resources/read error response shape --- - -func TestPipeClientToServer_ResourceReadErrorShape(t *testing.T) { - fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"resources/read","params":{"uri":"file:///app/.env"}}`+"\n") - if fwd != "" { - t.Errorf("server should not receive blocked request, got: %s", fwd) - } - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error, got: %s", errOut) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} diff --git a/internal/proxy/sse_buffer.go b/internal/proxy/sse_buffer.go index 35945d7..01f0f67 100644 --- a/internal/proxy/sse_buffer.go +++ b/internal/proxy/sse_buffer.go @@ -14,6 +14,7 @@ import ( "github.com/BakeLens/crust/internal/security" "github.com/BakeLens/crust/internal/telemetry" "github.com/BakeLens/crust/internal/types" + "mvdan.cc/sh/v3/syntax" ) const blockedToolSuffix = " Please inform the user and try a different approach." @@ -92,9 +93,13 @@ func NewBufferedSSEWriter(w http.ResponseWriter, maxSize int, timeout time.Durat // shellToolNames lists tool names that can execute shell commands (in priority order) var shellToolNames = []string{"Bash", "bash", "Shell", "shell", "Execute", "execute", "Exec", "exec", "RunCommand", "run_command", "Terminal", "terminal", "Cmd", "cmd"} -// shellQuote wraps s in single quotes with proper escaping for shell. +// shellQuote quotes a string for shell using the shell parser's own Quote function. func shellQuote(s string) string { - return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" + q, err := syntax.Quote(s, syntax.LangBash) + if err != nil { + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" + } + return q } // buildBlockedReplacement constructs the replacement command input for a blocked tool call. diff --git a/internal/proxy/sse_buffer_test.go b/internal/proxy/sse_buffer_test.go index 3fb5a78..2797aec 100644 --- a/internal/proxy/sse_buffer_test.go +++ b/internal/proxy/sse_buffer_test.go @@ -185,18 +185,6 @@ func TestBuildBlockedReplacement_WithoutMessage(t *testing.T) { } } -func TestBuildBlockedReplacement_ShellQuoting(t *testing.T) { - result := buildBlockedReplacement("Bash", rules.MatchResult{ - Message: "Can't do that", - }) - - cmd := result["command"] - // Single quote in "Can't" should be escaped for shell safety - if !strings.Contains(cmd, `'\''`) { - t.Errorf("single quote not escaped in command: %q", cmd) - } -} - func TestBufferedSSEWriter_ReplaceModeNoShellTool_FallsBackToRemove(t *testing.T) { w := httptest.NewRecorder() // Create buffer with NO shell tool (only non-shell tools) diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..d6967c6 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,22 @@ +// Package testutil provides shared test helpers for Crust's proxy packages. +package testutil + +import ( + "testing" + + "github.com/BakeLens/crust/internal/rules" +) + +// NewEngine creates a rules.Engine with built-in rules and a temp user rules dir. +// Accepts *testing.T, *testing.F, or *testing.B. +func NewEngine(tb testing.TB) *rules.Engine { + tb.Helper() + engine, err := rules.NewEngine(rules.EngineConfig{ + UserRulesDir: tb.TempDir(), + DisableBuiltin: false, + }) + if err != nil { + tb.Fatalf("Failed to create engine: %v", err) + } + return engine +}