From 3e0d5030178a24671ffeca2796d12c6af9f05f74 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 10:23:48 +0800 Subject: [PATCH 1/6] chore: replace old demo scripts with unified TUI demo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove demo.tape, demo-full.tape, and demo-agent.py (obsoleted by the new VHS-based TUI demo). Add demo-tui.tape, demo-acp.sh, and simplify demo-attack.sh (12 requests) and demo-mock.py (2 canned responses). All demo files use fake paths, fake credentials, and environment variables for API keys — no privacy-sensitive data. --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 8e6c689..990efe9 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ Thumbs.db # VHS recordings (keep only GIF) docs/*.webm +docs/*.mp4 # Agent test local config (contains API keys) tests/agents/config.local.yaml From 79de81713a842fead6a2d59a655a5aed106f3579 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 18:38:59 +0800 Subject: [PATCH 2/6] refactor: unify proxy code into internal/jsonrpc shared package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract ~750 lines of duplicated proxy infrastructure from acpwrap, mcpgateway, and autowrap into a shared internal/jsonrpc package with MethodConverter abstraction. Each protocol package is now a thin wrapper (~25 lines) plus its protocol-specific converter function. New internal/jsonrpc/ package: - types.go: Message, ErrorResponse, LockedWriter, MethodConverter - pipe.go: PipeInspect, PipePassthrough, SendBlockError - proxy.go: RunProxy (unified subprocess lifecycle) - signal_{unix,windows}.go: cross-platform signal forwarding Bug fixes: - Scanner buffer overflow now logged at WARN (was DEBUG) - Signal-forwarding goroutine tracked in WaitGroup - Tool name included in all block log messages - Stdin pipe closed if stdout pipe creation fails - Child stdout closed on Start() failure Also adds mock MCP server (cmd/mock-mcp-server) and golangci-lint exclusion for jsonrpc package name conflict with stdlib. Net: -747 lines deleted, +2266 added → ~43% reduction in proxy code. --- .golangci.yml | 5 + cmd/mock-mcp-server/main.go | 156 ++++++++++++ internal/acpwrap/convert.go | 117 +++++++++ internal/acpwrap/convert_test.go | 134 ++++++++++ internal/acpwrap/proxy.go | 352 -------------------------- internal/acpwrap/proxy_test.go | 340 ------------------------- internal/acpwrap/run.go | 22 ++ internal/acpwrap/run_test.go | 195 +++++++++++++++ internal/acpwrap/signal_unix.go | 20 -- internal/acpwrap/signal_windows.go | 19 -- internal/autowrap/run.go | 35 +++ internal/autowrap/run_test.go | 320 +++++++++++++++++++++++ internal/jsonrpc/pipe.go | 125 +++++++++ internal/jsonrpc/proxy.go | 139 ++++++++++ internal/jsonrpc/proxy_test.go | 376 ++++++++++++++++++++++++++++ internal/jsonrpc/signal_unix.go | 22 ++ internal/jsonrpc/signal_windows.go | 21 ++ internal/jsonrpc/types.go | 76 ++++++ internal/mcpgateway/convert.go | 101 ++++++++ internal/mcpgateway/convert_test.go | 147 +++++++++++ internal/mcpgateway/run.go | 22 ++ internal/mcpgateway/run_test.go | 197 +++++++++++++++ main.go | 72 ++++-- 23 files changed, 2266 insertions(+), 747 deletions(-) create mode 100644 cmd/mock-mcp-server/main.go create mode 100644 internal/acpwrap/convert.go create mode 100644 internal/acpwrap/convert_test.go delete mode 100644 internal/acpwrap/proxy.go delete mode 100644 internal/acpwrap/proxy_test.go create mode 100644 internal/acpwrap/run.go create mode 100644 internal/acpwrap/run_test.go delete mode 100644 internal/acpwrap/signal_unix.go delete mode 100644 internal/acpwrap/signal_windows.go create mode 100644 internal/autowrap/run.go create mode 100644 internal/autowrap/run_test.go create mode 100644 internal/jsonrpc/pipe.go create mode 100644 internal/jsonrpc/proxy.go create mode 100644 internal/jsonrpc/proxy_test.go create mode 100644 internal/jsonrpc/signal_unix.go create mode 100644 internal/jsonrpc/signal_windows.go create mode 100644 internal/jsonrpc/types.go create mode 100644 internal/mcpgateway/convert.go create mode 100644 internal/mcpgateway/convert_test.go create mode 100644 internal/mcpgateway/run.go create mode 100644 internal/mcpgateway/run_test.go diff --git a/.golangci.yml b/.golangci.yml index 217b542..327880e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -164,6 +164,11 @@ linters: linters: - revive text: "var-naming: avoid meaningless package names" + # "jsonrpc" conflicts with net/rpc/jsonrpc but is the clearest name for this package + - path: internal/jsonrpc/ + linters: + - revive + text: "var-naming: avoid package names that conflict" issues: max-issues-per-linter: 50 diff --git a/cmd/mock-mcp-server/main.go b/cmd/mock-mcp-server/main.go new file mode 100644 index 0000000..5bbc05d --- /dev/null +++ b/cmd/mock-mcp-server/main.go @@ -0,0 +1,156 @@ +// Package main implements a minimal mock MCP server for testing the Crust MCP gateway. +// It reads JSON-RPC 2.0 messages from stdin and writes responses to stdout. +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "os" +) + +type jsonRPCMessage struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result any `json:"result"` +} + +type jsonRPCError struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Error struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +type toolsCallParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +func main() { + scanner := bufio.NewScanner(os.Stdin) + scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var msg jsonRPCMessage + if err := json.Unmarshal(line, &msg); err != nil { + continue + } + + if msg.Method == "" || len(msg.ID) == 0 { + continue // not a request + } + + switch msg.Method { + case "initialize": + respond(msg.ID, map[string]any{ + "protocolVersion": "2024-11-05", + "capabilities": map[string]any{ + "tools": map[string]any{}, + "resources": map[string]any{}, + }, + "serverInfo": map[string]any{ + "name": "mock-mcp-server", + "version": "1.0.0", + }, + }) + + case "tools/list": + respond(msg.ID, map[string]any{ + "tools": []any{ + map[string]any{ + "name": "read_file", + "description": "Read a file", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + "required": []string{"path"}, + }, + }, + map[string]any{ + "name": "write_file", + "description": "Write a file", + "inputSchema": map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + "content": map[string]any{"type": "string"}, + }, + "required": []string{"path", "content"}, + }, + }, + }, + }) + + case "tools/call": + var p toolsCallParams + if err := json.Unmarshal(msg.Params, &p); err != nil { + respondError(msg.ID, -32602, "invalid params") + continue + } + respond(msg.ID, map[string]any{ + "content": []any{ + map[string]any{ + "type": "text", + "text": fmt.Sprintf("[mock] tool=%s executed successfully", p.Name), + }, + }, + }) + + case "resources/read": + respond(msg.ID, map[string]any{ + "contents": []any{ + map[string]any{ + "uri": "file:///mock", + "mimeType": "text/plain", + "text": "[mock] resource content", + }, + }, + }) + + default: + respondError(msg.ID, -32601, "method not found: "+msg.Method) + } + } +} + +func respond(id json.RawMessage, result any) { + resp, err := json.Marshal(jsonRPCResponse{ + JSONRPC: "2.0", + ID: id, + Result: result, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "marshal error: %v\n", err) + return + } + fmt.Fprintf(os.Stdout, "%s\n", resp) +} + +func respondError(id json.RawMessage, code int, msg string) { + resp := jsonRPCError{JSONRPC: "2.0", ID: id} + resp.Error.Code = code + resp.Error.Message = msg + data, err := json.Marshal(resp) + if err != nil { + fmt.Fprintf(os.Stderr, "marshal error: %v\n", err) + return + } + fmt.Fprintf(os.Stdout, "%s\n", data) +} diff --git a/internal/acpwrap/convert.go b/internal/acpwrap/convert.go new file mode 100644 index 0000000..6a94c70 --- /dev/null +++ b/internal/acpwrap/convert.go @@ -0,0 +1,117 @@ +// Package acpwrap implements a transparent stdio proxy for ACP (Agent Client Protocol) +// agents, intercepting security-relevant JSON-RPC messages using Crust's rule engine. +package acpwrap + +import ( + "encoding/json" + "fmt" + "strings" + "unicode" + + "github.com/BakeLens/crust/internal/rules" +) + +// ACP parameter types + +type fsReadParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` +} + +type fsWriteParams struct { + SessionID string `json:"sessionId"` + Path string `json:"path"` + Content string `json:"content"` +} + +type terminalCreateParams struct { + SessionID string `json:"sessionId"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + Cwd string `json:"cwd,omitempty"` +} + +// shellSafe is the set of characters that don't need quoting in shell arguments. +const shellSafe = "-_./:=+," + +// shellQuote quotes a shell argument if it contains special characters. +func shellQuote(s string) string { + if s == "" { + return "''" + } + if strings.ContainsFunc(s, func(c rune) bool { + return !unicode.IsLetter(c) && !unicode.IsDigit(c) && !strings.ContainsRune(shellSafe, c) + }) { + return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" + } + return s +} + +// ACPMethodToToolCall converts an ACP JSON-RPC method + params into a rules.ToolCall. +// +// Returns: +// - (*ToolCall, nil) for successfully parsed security-relevant methods +// - (nil, nil) for non-security methods (caller should pass through) +// - (nil, error) for security-relevant methods with malformed params (caller should block) +func ACPMethodToToolCall(method string, params json.RawMessage) (*rules.ToolCall, error) { + // Reject nil/null params on security-relevant methods (json.Unmarshal silently + // zero-initializes the struct, which would produce an empty path and bypass rules). + switch method { + case "fs/read_text_file", "fs/write_text_file", "terminal/create": + if len(params) == 0 || string(params) == "null" { + return nil, fmt.Errorf("nil params for security method %s", method) + } + default: + return nil, nil // not security-relevant + } + + switch method { + case "fs/read_text_file": + var p fsReadParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + args, err := json.Marshal(map[string]string{"path": p.Path}) + if err != nil { + return nil, fmt.Errorf("marshal error: %w", err) + } + return &rules.ToolCall{Name: "read_file", Arguments: args}, nil + + case "fs/write_text_file": + var p fsWriteParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + args, err := json.Marshal(map[string]any{ + "path": p.Path, + "content": p.Content, + }) + if err != nil { + return nil, fmt.Errorf("marshal error: %w", err) + } + return &rules.ToolCall{Name: "write_file", Arguments: args}, nil + + case "terminal/create": + var p terminalCreateParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + fullCmd := p.Command + if len(p.Args) > 0 { + quoted := make([]string, len(p.Args)) + for i, a := range p.Args { + quoted[i] = shellQuote(a) + } + fullCmd += " " + strings.Join(quoted, " ") + } + args, err := json.Marshal(map[string]string{"command": fullCmd}) + if err != nil { + return nil, fmt.Errorf("marshal error: %w", err) + } + return &rules.ToolCall{Name: "bash", Arguments: args}, nil + + default: + return nil, nil + } +} diff --git a/internal/acpwrap/convert_test.go b/internal/acpwrap/convert_test.go new file mode 100644 index 0000000..a99d2d1 --- /dev/null +++ b/internal/acpwrap/convert_test.go @@ -0,0 +1,134 @@ +package acpwrap + +import ( + "encoding/json" + "testing" +) + +// --- ACPMethodToToolCall --- + +func TestAcpMethodToToolCall(t *testing.T) { + tests := []struct { + name string + method string + params string + wantName string + wantKey string + wantVal string + }{ + {"fs_read", "fs/read_text_file", `{"sessionId":"s1","path":"/etc/passwd"}`, "read_file", "path", "/etc/passwd"}, + {"fs_write", "fs/write_text_file", `{"sessionId":"s1","path":"/home/user/.env","content":"SECRET=abc"}`, "write_file", "path", "/home/user/.env"}, + {"terminal", "terminal/create", `{"sessionId":"s1","command":"rm","args":["-rf","/"]}`, "bash", "command", "rm -rf /"}, + {"terminal_no_args", "terminal/create", `{"sessionId":"s1","command":"ls"}`, "bash", "command", "ls"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc, err := ACPMethodToToolCall(tt.method, json.RawMessage(tt.params)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc == nil { + t.Fatal("expected non-nil ToolCall") + } + if tc.Name != tt.wantName { + t.Errorf("name = %s, want %s", tc.Name, tt.wantName) + } + var args map[string]any + if err := json.Unmarshal(tc.Arguments, &args); err != nil { + t.Fatal(err) + } + if got := args[tt.wantKey]; got != tt.wantVal { + t.Errorf("%s = %v, want %s", tt.wantKey, got, tt.wantVal) + } + }) + } +} + +func TestAcpMethodToToolCall_Unknown(t *testing.T) { + for _, method := range []string{"session/prompt", "initialize", "fs/delete"} { + tc, err := ACPMethodToToolCall(method, nil) + if err != nil { + t.Fatalf("%s: unexpected error: %v", method, err) + } + if tc != nil { + t.Errorf("%s should not be security-relevant", method) + } + } +} + +func TestAcpMethodToToolCall_MalformedParams(t *testing.T) { + methods := []string{"fs/read_text_file", "fs/write_text_file", "terminal/create"} + badInputs := []json.RawMessage{ + json.RawMessage(`{broken`), + json.RawMessage(`"just a string"`), + json.RawMessage(`null`), + json.RawMessage(`42`), + json.RawMessage(``), + } + for _, method := range methods { + for _, input := range badInputs { + tc, err := ACPMethodToToolCall(method, input) + if tc != nil { + t.Errorf("%s with %q: expected nil ToolCall for malformed params", method, input) + } + if err == nil { + t.Errorf("%s with %q: expected error for malformed params", method, input) + } + } + } +} + +// --- shellQuote --- + +func TestShellQuote(t *testing.T) { + tests := []struct { + input, want string + }{ + {"", "''"}, + {"simple", "simple"}, + {"-rf", "-rf"}, + {"/etc/passwd", "/etc/passwd"}, + {"file with spaces", "'file with spaces'"}, + {"; rm -rf /", "'; rm -rf /'"}, + {"$(whoami)", "'$(whoami)'"}, + {"`id`", "'`id`'"}, + {"it's", "'it'\"'\"'s'"}, + {"a&b", "'a&b'"}, + {"a|b", "'a|b'"}, + {"a>b", "'a>b'"}, + {"a 0 -} - -// jsonRPCError is a JSON-RPC 2.0 error response. -type jsonRPCError struct { - JSONRPC string `json:"jsonrpc"` - ID json.RawMessage `json:"id"` - Error jsonRPCErrorObj `json:"error"` -} - -type jsonRPCErrorObj struct { - Code int `json:"code"` - Message string `json:"message"` -} - -// lockedWriter is a mutex-protected writer for agent stdin. -// Both the IDE→Agent goroutine and the blocking logic write to it. -type lockedWriter struct { - mu sync.Mutex - w io.Writer -} - -func (lw *lockedWriter) writeLine(data []byte) error { - lw.mu.Lock() - defer lw.mu.Unlock() - if _, err := lw.w.Write(data); err != nil { - return err - } - _, err := lw.w.Write([]byte{'\n'}) - return err -} - -// ACP parameter types - -type fsReadParams struct { - SessionID string `json:"sessionId"` - Path string `json:"path"` -} - -type fsWriteParams struct { - SessionID string `json:"sessionId"` - Path string `json:"path"` - Content string `json:"content"` -} - -type terminalCreateParams struct { - SessionID string `json:"sessionId"` - Command string `json:"command"` - Args []string `json:"args,omitempty"` - Env map[string]string `json:"env,omitempty"` - Cwd string `json:"cwd,omitempty"` -} - -// shellSafe is the set of characters that don't need quoting in shell arguments. -const shellSafe = "-_./:=+," - -// shellQuote quotes a shell argument if it contains special characters. -func shellQuote(s string) string { - if s == "" { - return "''" - } - if strings.ContainsFunc(s, func(c rune) bool { - return !unicode.IsLetter(c) && !unicode.IsDigit(c) && !strings.ContainsRune(shellSafe, c) - }) { - return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" - } - return s -} - -// acpMethodToToolCall converts an ACP JSON-RPC method + params into a rules.ToolCall -// that the existing rule engine can evaluate. -// Returns (toolCall, true, nil) for successfully parsed security-relevant methods, -// (_, false, nil) for non-security methods, and (_, true, err) when a security-relevant -// method has malformed params (caller should block). -func acpMethodToToolCall(method string, params json.RawMessage) (rules.ToolCall, bool, error) { - // Reject nil/null params on security-relevant methods (json.Unmarshal silently - // zero-initializes the struct, which would produce an empty path and bypass rules). - switch method { - case "fs/read_text_file", "fs/write_text_file", "terminal/create": - if len(params) == 0 || string(params) == "null" { - return rules.ToolCall{}, true, fmt.Errorf("nil params for security method %s", method) - } - } - - switch method { - case "fs/read_text_file": - var p fsReadParams - if err := json.Unmarshal(params, &p); err != nil { - return rules.ToolCall{}, true, fmt.Errorf("malformed %s params: %w", method, err) - } - args, err := json.Marshal(map[string]string{"path": p.Path}) - if err != nil { - return rules.ToolCall{}, true, fmt.Errorf("marshal error: %w", err) - } - return rules.ToolCall{Name: "read_file", Arguments: args}, true, nil - - case "fs/write_text_file": - var p fsWriteParams - if err := json.Unmarshal(params, &p); err != nil { - return rules.ToolCall{}, true, fmt.Errorf("malformed %s params: %w", method, err) - } - args, err := json.Marshal(map[string]any{ - "path": p.Path, - "content": p.Content, - }) - if err != nil { - return rules.ToolCall{}, true, fmt.Errorf("marshal error: %w", err) - } - return rules.ToolCall{Name: "write_file", Arguments: args}, true, nil - - case "terminal/create": - var p terminalCreateParams - if err := json.Unmarshal(params, &p); err != nil { - return rules.ToolCall{}, true, fmt.Errorf("malformed %s params: %w", method, err) - } - fullCmd := p.Command - if len(p.Args) > 0 { - quoted := make([]string, len(p.Args)) - for i, a := range p.Args { - quoted[i] = shellQuote(a) - } - fullCmd += " " + strings.Join(quoted, " ") - } - args, err := json.Marshal(map[string]string{"command": fullCmd}) - if err != nil { - return rules.ToolCall{}, true, fmt.Errorf("marshal error: %w", err) - } - return rules.ToolCall{Name: "bash", Arguments: args}, true, nil - - default: - return rules.ToolCall{}, false, nil - } -} - -// Run starts the ACP proxy. It spawns the agent subprocess, wires up stdio, -// and evaluates security-relevant messages. Returns the agent's exit code. -func Run(engine *rules.Engine, agentCmd []string) int { - return runProxy(engine, agentCmd, os.Stdin, os.Stdout) -} - -// runProxy is the internal implementation of Run, accepting explicit IO handles -// so that ideStdin can be closed after the agent exits (unblocking the scanner). -func runProxy(engine *rules.Engine, agentCmd []string, ideStdin io.ReadCloser, ideStdout io.Writer) int { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cmd := exec.CommandContext(ctx, agentCmd[0], agentCmd[1:]...) //nolint:gosec // agentCmd is user-specified by design - agentStdin, err := cmd.StdinPipe() - if err != nil { - log.Error("Failed to create agent stdin pipe: %v", err) - return 1 - } - agentStdout, err := cmd.StdoutPipe() - if err != nil { - log.Error("Failed to create agent stdout pipe: %v", err) - return 1 - } - cmd.Stderr = os.Stderr - - if err := cmd.Start(); err != nil { - log.Error("Failed to start agent: %v", err) - return 1 - } - - log.Info("Agent started: PID %d, command: %v", cmd.Process.Pid, agentCmd) - log.Info("Rule engine: %d rules loaded", engine.RuleCount()) - - writer := &lockedWriter{w: agentStdin} - - var wg sync.WaitGroup - - // Goroutine 1: IDE → Agent (pass-through, line-by-line with mutex) - wg.Add(1) - go func() { - defer wg.Done() - defer agentStdin.Close() - pipeIDEToAgent(ideStdin, writer) - }() - - // Goroutine 2: Agent → IDE (inspect security-relevant messages) - wg.Add(1) - go func() { - defer wg.Done() - pipeAgentToIDE(engine, agentStdout, ideStdout, writer) - }() - - // Goroutine 3: Forward signals to child - sigCh := forwardSignals() - go func() { - for sig := range sigCh { - if cmd.Process != nil { - _ = cmd.Process.Signal(sig) - } - } - }() - - // Wait for agent to exit - waitErr := cmd.Wait() - stopSignals(sigCh) - cancel() - - // Close IDE stdin to unblock the pipeIDEToAgent goroutine's scanner. - // In production this is os.Stdin; safe because crust exits immediately after. - if ideStdin != nil { - ideStdin.Close() - } - - // Wait for pipe goroutines to finish draining - wg.Wait() - - if waitErr != nil { - var exitErr *exec.ExitError - if errors.As(waitErr, &exitErr) { - return exitErr.ExitCode() - } - return 1 - } - return 0 -} - -// pipeIDEToAgent reads JSONL from the IDE (our stdin) and forwards each line -// to the agent's stdin through the lockedWriter. -func pipeIDEToAgent(ideStdin io.Reader, writer *lockedWriter) { - scanner := bufio.NewScanner(ideStdin) - scanner.Buffer(make([]byte, 0, 64*1024), maxScannerBuf) - - for scanner.Scan() { - if err := writer.writeLine(scanner.Bytes()); err != nil { - log.Debug("IDE→Agent write error: %v", err) - return - } - } - if err := scanner.Err(); err != nil { - log.Debug("IDE stdin scanner error: %v", err) - } -} - -// sendBlockError sends a JSON-RPC error response back to the agent's stdin. -func sendBlockError(writer *lockedWriter, id json.RawMessage, msg string) { - resp, err := json.Marshal(jsonRPCError{ - JSONRPC: "2.0", - ID: id, - Error: jsonRPCErrorObj{Code: jsonRPCBlockedError, Message: msg}, - }) - if err != nil { - log.Debug("Failed to marshal block response: %v", err) - return - } - if err := writer.writeLine(resp); err != nil { - log.Debug("Failed to send block response to agent: %v", err) - } -} - -// pipeAgentToIDE reads JSONL from the agent's stdout, inspects security-relevant -// messages, and either forwards them to the IDE or blocks them. -func pipeAgentToIDE(engine *rules.Engine, agentStdout io.Reader, ideStdout io.Writer, agentWriter *lockedWriter) { - scanner := bufio.NewScanner(agentStdout) - scanner.Buffer(make([]byte, 0, 64*1024), maxScannerBuf) - - for scanner.Scan() { - line := scanner.Bytes() - if len(line) == 0 { - fmt.Fprintln(ideStdout) - continue - } - - var msg jsonRPCMessage - if err := json.Unmarshal(line, &msg); err != nil { - fmt.Fprintf(ideStdout, "%s\n", line) //nolint:gosec // not valid JSON — forward as-is - continue - } - - if !msg.isRequest() { - fmt.Fprintf(ideStdout, "%s\n", line) //nolint:gosec // stdio pipe, not HTTP - continue - } - - toolCall, isSecurityRelevant, parseErr := acpMethodToToolCall(msg.Method, msg.Params) - if !isSecurityRelevant { - fmt.Fprintf(ideStdout, "%s\n", line) //nolint:gosec // stdio pipe, not HTTP - continue - } - - if parseErr != nil { - log.Warn("Blocked ACP %s: %v", msg.Method, parseErr) - sendBlockError(agentWriter, msg.ID, "[Crust] Blocked: malformed params for "+msg.Method) - continue - } - - result := engine.Evaluate(toolCall) - - if result.Matched && result.Action == rules.ActionBlock { - log.Warn("Blocked ACP %s: rule=%s message=%s", msg.Method, result.RuleName, result.Message) - sendBlockError(agentWriter, msg.ID, fmt.Sprintf("[Crust] Blocked by rule %q: %s", result.RuleName, result.Message)) - continue - } - - if result.Matched && result.Action == rules.ActionLog { - log.Info("Logged ACP %s: rule=%s", msg.Method, result.RuleName) - } - - fmt.Fprintf(ideStdout, "%s\n", line) //nolint:gosec // stdio pipe, not HTTP - } - - if err := scanner.Err(); err != nil { - log.Debug("Agent stdout scanner error: %v", err) - } -} diff --git a/internal/acpwrap/proxy_test.go b/internal/acpwrap/proxy_test.go deleted file mode 100644 index bcbe4d5..0000000 --- a/internal/acpwrap/proxy_test.go +++ /dev/null @@ -1,340 +0,0 @@ -package acpwrap - -import ( - "bytes" - "encoding/json" - "io" - "os/exec" - "strings" - "testing" - "time" - - "github.com/BakeLens/crust/internal/rules" -) - -func newTestEngine(t *testing.T) *rules.Engine { - t.Helper() - engine, err := rules.NewEngine(rules.EngineConfig{ - UserRulesDir: t.TempDir(), - DisableBuiltin: false, - }) - if err != nil { - t.Fatalf("Failed to create engine: %v", err) - } - return engine -} - -// runPipe runs pipeAgentToIDE with the given input and returns what the IDE -// received and what error responses were sent back to the agent. -func runPipe(t *testing.T, input string) (ideOut, agentErr string) { - t.Helper() - engine := newTestEngine(t) - var ideStdout, agentStdinBuf bytes.Buffer - pipeAgentToIDE(engine, strings.NewReader(input), &ideStdout, &lockedWriter{w: &agentStdinBuf}) - return ideStdout.String(), agentStdinBuf.String() -} - -// --- acpMethodToToolCall --- - -func TestAcpMethodToToolCall(t *testing.T) { - tests := []struct { - name string - method string - params string - wantName string - wantKey string - wantVal string - }{ - {"fs_read", "fs/read_text_file", `{"sessionId":"s1","path":"/etc/passwd"}`, "read_file", "path", "/etc/passwd"}, - {"fs_write", "fs/write_text_file", `{"sessionId":"s1","path":"/home/user/.env","content":"SECRET=abc"}`, "write_file", "path", "/home/user/.env"}, - {"terminal", "terminal/create", `{"sessionId":"s1","command":"rm","args":["-rf","/"]}`, "bash", "command", "rm -rf /"}, - {"terminal_no_args", "terminal/create", `{"sessionId":"s1","command":"ls"}`, "bash", "command", "ls"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tc, ok, err := acpMethodToToolCall(tt.method, json.RawMessage(tt.params)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !ok { - t.Fatal("expected security-relevant") - } - if tc.Name != tt.wantName { - t.Errorf("name = %s, want %s", tc.Name, tt.wantName) - } - var args map[string]any - if err := json.Unmarshal(tc.Arguments, &args); err != nil { - t.Fatal(err) - } - if got := args[tt.wantKey]; got != tt.wantVal { - t.Errorf("%s = %v, want %s", tt.wantKey, got, tt.wantVal) - } - }) - } -} - -func TestAcpMethodToToolCall_Unknown(t *testing.T) { - for _, method := range []string{"session/prompt", "initialize", "fs/delete"} { - _, ok, err := acpMethodToToolCall(method, nil) - if err != nil { - t.Fatalf("%s: unexpected error: %v", method, err) - } - if ok { - t.Errorf("%s should not be security-relevant", method) - } - } -} - -func TestAcpMethodToToolCall_MalformedParams(t *testing.T) { - methods := []string{"fs/read_text_file", "fs/write_text_file", "terminal/create"} - badInputs := []json.RawMessage{ - json.RawMessage(`{broken`), - json.RawMessage(`"just a string"`), - json.RawMessage(`null`), - json.RawMessage(`42`), - json.RawMessage(``), - } - for _, method := range methods { - for _, input := range badInputs { - _, ok, err := acpMethodToToolCall(method, input) - if !ok { - t.Errorf("%s with %q: expected ok=true (security-relevant)", method, input) - } - if err == nil { - t.Errorf("%s with %q: expected error for malformed params", method, input) - } - } - } -} - -// --- shellQuote --- - -func TestShellQuote(t *testing.T) { - tests := []struct { - input, want string - }{ - {"", "''"}, - {"simple", "simple"}, - {"-rf", "-rf"}, - {"/etc/passwd", "/etc/passwd"}, - {"file with spaces", "'file with spaces'"}, - {"; rm -rf /", "'; rm -rf /'"}, - {"$(whoami)", "'$(whoami)'"}, - {"`id`", "'`id`'"}, - {"it's", "'it'\"'\"'s'"}, - {"a&b", "'a&b'"}, - {"a|b", "'a|b'"}, - {"a>b", "'a>b'"}, - {"aAgent"}, + Outbound: jsonrpc.PipeConfig{Label: "Agent->IDE", Protocol: "ACP", Convert: ACPMethodToToolCall}, + }) +} diff --git a/internal/acpwrap/run_test.go b/internal/acpwrap/run_test.go new file mode 100644 index 0000000..dd0207a --- /dev/null +++ b/internal/acpwrap/run_test.go @@ -0,0 +1,195 @@ +package acpwrap + +import ( + "bytes" + "encoding/json" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +var testLog = logger.New("acp-test") + +func newTestEngine(t *testing.T) *rules.Engine { + t.Helper() + engine, err := rules.NewEngine(rules.EngineConfig{ + UserRulesDir: t.TempDir(), + DisableBuiltin: false, + }) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + return engine +} + +// runPipe runs PipeInspect with ACPMethodToToolCall and returns what was +// forwarded and what error responses were generated. +func runPipe(t *testing.T, input string) (fwd, errOut string) { + t.Helper() + engine := newTestEngine(t) + var fwdBuf, errBuf bytes.Buffer + fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) + errWriter := jsonrpc.NewLockedWriter(&errBuf) + jsonrpc.PipeInspect(testLog, engine, strings.NewReader(input), + fwdWriter, errWriter, ACPMethodToToolCall, "ACP", "Agent->IDE") + return fwdBuf.String(), errBuf.String() +} + +// --- RunProxy (hang / exit code) --- + +func TestRunProxy(t *testing.T) { + t.Run("no_hang_on_agent_exit", func(t *testing.T) { + if _, err := exec.LookPath("true"); err != nil { + t.Skip("'true' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() // keep write end OPEN to expose the hang bug + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "Agent", + Inbound: jsonrpc.PipeConfig{Label: "IDE->Agent"}, + Outbound: jsonrpc.PipeConfig{Label: "Agent->IDE", Protocol: "ACP", Convert: ACPMethodToToolCall}, + }) + }() + + select { + case code := <-done: + if code != 0 { + t.Errorf("exit code = %d, want 0", code) + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung — IDE stdin not closed after agent exit") + } + }) + + t.Run("propagates_exit_code", func(t *testing.T) { + if _, err := exec.LookPath("false"); err != nil { + t.Skip("'false' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "Agent", + Inbound: jsonrpc.PipeConfig{Label: "IDE->Agent"}, + Outbound: jsonrpc.PipeConfig{Label: "Agent->IDE", Protocol: "ACP", Convert: ACPMethodToToolCall}, + }) + }() + + select { + case code := <-done: + if code == 0 { + t.Error("expected non-zero exit code from 'false'") + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung") + } + }) +} + +// --- PipeInspect + ACPMethodToToolCall integration --- + +func TestPipeAgentToIDE_Blocks(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"env_read", `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`}, + {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/home/user/.ssh/id_rsa"}}`}, + {"env_write", `{"jsonrpc":"2.0","id":3,"method":"fs/write_text_file","params":{"sessionId":"s1","path":"/app/.env","content":"API_KEY=secret"}}`}, + {"malformed_read_params", `{"jsonrpc":"2.0","id":4,"method":"fs/read_text_file","params":"not-an-object"}`}, + {"malformed_terminal_params", `{"jsonrpc":"2.0","id":5,"method":"terminal/create","params":42}`}, + {"null_params", `{"jsonrpc":"2.0","id":6,"method":"fs/read_text_file","params":null}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runPipe(t, tt.msg+"\n") + if fwd != "" { + t.Errorf("IDE should not receive blocked request, got: %s", fwd) + } + if errOut == "" { + t.Error("agent should receive an error response") + } + }) + } +} + +func TestPipeAgentToIDE_BlocksEnvRead_ErrorShape(t *testing.T) { + fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`+"\n") + if fwd != "" { + t.Errorf("IDE should not receive blocked request, got: %s", fwd) + } + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error, got: %s", errOut) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) + } +} + +func TestPipeAgentToIDE_Passes(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/src/main.go"}}`}, + {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"session/prompt","params":{"text":"hello"}}`}, + {"notification", `{"jsonrpc":"2.0","method":"session/update","params":{"status":"working"}}`}, + {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"file data"}}`}, + {"invalid_json", `not valid json at all`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runPipe(t, tt.msg+"\n") + if fwd != tt.msg+"\n" { + t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") + } + if errOut != "" { + t.Errorf("agent should not receive errors, got: %s", errOut) + } + }) + } +} + +func TestPipeAgentToIDE_EmptyLine(t *testing.T) { + fwd, _ := runPipe(t, "\n") + if fwd != "\n" { + t.Errorf("empty line should pass through, got: %q", fwd) + } +} + +func TestPipeAgentToIDE_MultipleMessages(t *testing.T) { + msgs := strings.Join([]string{ + `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`, + `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/main.go"}}`, + `{"jsonrpc":"2.0","id":3,"method":"session/prompt","params":{"text":"hi"}}`, + }, "\n") + "\n" + + fwd, errOut := runPipe(t, msgs) + + fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") + if len(fwdLines) != 2 { + t.Errorf("expected 2 IDE messages (main.go + prompt), got %d: %v", len(fwdLines), fwdLines) + } + if errOut == "" { + t.Error("agent should receive error for .env read") + } +} diff --git a/internal/acpwrap/signal_unix.go b/internal/acpwrap/signal_unix.go deleted file mode 100644 index 4b8cb92..0000000 --- a/internal/acpwrap/signal_unix.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !windows - -package acpwrap - -import ( - "os" - "os/signal" - "syscall" -) - -func forwardSignals() chan os.Signal { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - return ch -} - -func stopSignals(ch chan os.Signal) { - signal.Stop(ch) - close(ch) -} diff --git a/internal/acpwrap/signal_windows.go b/internal/acpwrap/signal_windows.go deleted file mode 100644 index fb40905..0000000 --- a/internal/acpwrap/signal_windows.go +++ /dev/null @@ -1,19 +0,0 @@ -//go:build windows - -package acpwrap - -import ( - "os" - "os/signal" -) - -func forwardSignals() chan os.Signal { - ch := make(chan os.Signal, 1) - signal.Notify(ch, os.Interrupt) - return ch -} - -func stopSignals(ch chan os.Signal) { - signal.Stop(ch) - close(ch) -} diff --git a/internal/autowrap/run.go b/internal/autowrap/run.go new file mode 100644 index 0000000..3ef80f3 --- /dev/null +++ b/internal/autowrap/run.go @@ -0,0 +1,35 @@ +// Package autowrap implements a transparent stdio proxy that auto-detects +// whether the wrapped subprocess speaks ACP or MCP protocol. +// +// It inspects both directions simultaneously: +// - Inbound (client/IDE -> subprocess): checks for MCP security methods +// - Outbound (subprocess -> client/IDE): checks for ACP security methods +// +// Method names between ACP and MCP are disjoint, so there is no ambiguity. +package autowrap + +import ( + "os" + + "github.com/BakeLens/crust/internal/acpwrap" + "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/mcpgateway" + "github.com/BakeLens/crust/internal/rules" +) + +var log = logger.New("wrap") + +// Run starts the auto-detecting proxy. It spawns the subprocess, wires up stdio, +// and evaluates security-relevant messages from both ACP and MCP protocols. +func Run(engine *rules.Engine, cmd []string) int { + return jsonrpc.RunProxy(engine, cmd, os.Stdin, os.Stdout, jsonrpc.ProxyConfig{ + Log: log, + ProcessLabel: "Subprocess", + Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + ExtraLogLines: []string{ + "Auto-detect mode: inspecting both ACP and MCP methods", + }, + }) +} diff --git a/internal/autowrap/run_test.go b/internal/autowrap/run_test.go new file mode 100644 index 0000000..5e5e099 --- /dev/null +++ b/internal/autowrap/run_test.go @@ -0,0 +1,320 @@ +package autowrap + +import ( + "bytes" + "encoding/json" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/BakeLens/crust/internal/acpwrap" + "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/mcpgateway" + "github.com/BakeLens/crust/internal/rules" +) + +var testLog = logger.New("wrap-test") + +func newTestEngine(t *testing.T) *rules.Engine { + t.Helper() + engine, err := rules.NewEngine(rules.EngineConfig{ + UserRulesDir: t.TempDir(), + DisableBuiltin: false, + }) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + return engine +} + +// runInboundPipe runs PipeInspect with MCPMethodToToolCall (inbound direction). +func runInboundPipe(t *testing.T, input string) (fwd, errOut string) { + t.Helper() + engine := newTestEngine(t) + var fwdBuf, errBuf bytes.Buffer + fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) + errWriter := jsonrpc.NewLockedWriter(&errBuf) + jsonrpc.PipeInspect(testLog, engine, strings.NewReader(input), + fwdWriter, errWriter, mcpgateway.MCPMethodToToolCall, "MCP", "Inbound") + return fwdBuf.String(), errBuf.String() +} + +// runOutboundPipe runs PipeInspect with ACPMethodToToolCall (outbound direction). +func runOutboundPipe(t *testing.T, input string) (fwd, errOut string) { + t.Helper() + engine := newTestEngine(t) + var fwdBuf, errBuf bytes.Buffer + fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) + errWriter := jsonrpc.NewLockedWriter(&errBuf) + jsonrpc.PipeInspect(testLog, engine, strings.NewReader(input), + fwdWriter, errWriter, acpwrap.ACPMethodToToolCall, "ACP", "Outbound") + return fwdBuf.String(), errBuf.String() +} + +// --- Inbound (MCP) direction --- + +func TestPipeInbound_BlocksMCP(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"env_read", `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`}, + {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/home/user/.ssh/id_rsa"}}}`}, + {"resource_env_read", `{"jsonrpc":"2.0","id":3,"method":"resources/read","params":{"uri":"file:///app/.env"}}`}, + {"malformed_params", `{"jsonrpc":"2.0","id":4,"method":"tools/call","params":"not-an-object"}`}, + {"null_params", `{"jsonrpc":"2.0","id":5,"method":"tools/call","params":null}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runInboundPipe(t, tt.msg+"\n") + if fwd != "" { + t.Errorf("subprocess should not receive blocked request, got: %s", fwd) + } + if errOut == "" { + t.Error("client should receive an error response") + } + }) + } +} + +func TestPipeInbound_PassesMCP(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/src/main.go"}}}`}, + {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"initialize","params":{"capabilities":{}}}`}, + {"tools_list", `{"jsonrpc":"2.0","id":30,"method":"tools/list","params":{}}`}, + {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"data"}}`}, + {"invalid_json", `not valid json at all`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runInboundPipe(t, tt.msg+"\n") + if fwd != tt.msg+"\n" { + t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") + } + if errOut != "" { + t.Errorf("client should not receive errors, got: %s", errOut) + } + }) + } +} + +func TestPipeInbound_ErrorShape(t *testing.T) { + _, errOut := runInboundPipe(t, `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`+"\n") + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error, got: %s", errOut) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) + } +} + +// --- Outbound (ACP) direction --- + +func TestPipeOutbound_BlocksACP(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"env_read", `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`}, + {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/home/user/.ssh/id_rsa"}}`}, + {"env_write", `{"jsonrpc":"2.0","id":3,"method":"fs/write_text_file","params":{"sessionId":"s1","path":"/app/.env","content":"SECRET=abc"}}`}, + {"malformed_params", `{"jsonrpc":"2.0","id":4,"method":"fs/read_text_file","params":"not-an-object"}`}, + {"null_params", `{"jsonrpc":"2.0","id":5,"method":"fs/read_text_file","params":null}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runOutboundPipe(t, tt.msg+"\n") + if fwd != "" { + t.Errorf("client should not receive blocked request, got: %s", fwd) + } + if errOut == "" { + t.Error("subprocess should receive an error response") + } + }) + } +} + +func TestPipeOutbound_PassesACP(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/src/main.go"}}`}, + {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"session/prompt","params":{"text":"hello"}}`}, + {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"data"}}`}, + {"invalid_json", `not valid json at all`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runOutboundPipe(t, tt.msg+"\n") + if fwd != tt.msg+"\n" { + t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") + } + if errOut != "" { + t.Errorf("subprocess should not receive errors, got: %s", errOut) + } + }) + } +} + +func TestPipeOutbound_ErrorShape(t *testing.T) { + _, errOut := runOutboundPipe(t, `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`+"\n") + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error, got: %s", errOut) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) + } +} + +// --- Cross-protocol: MCP is not blocked on outbound, ACP is not blocked on inbound --- + +func TestPipeInbound_IgnoresACPMethods(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}` + fwd, errOut := runInboundPipe(t, msg+"\n") + if fwd != msg+"\n" { + t.Errorf("ACP methods should pass through in inbound direction\ngot: %q\nwant: %q", fwd, msg+"\n") + } + if errOut != "" { + t.Errorf("should not generate errors for ACP methods in inbound direction, got: %s", errOut) + } +} + +func TestPipeOutbound_IgnoresMCPMethods(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}` + fwd, errOut := runOutboundPipe(t, msg+"\n") + if fwd != msg+"\n" { + t.Errorf("MCP methods should pass through in outbound direction\ngot: %q\nwant: %q", fwd, msg+"\n") + } + if errOut != "" { + t.Errorf("should not generate errors for MCP methods in outbound direction, got: %s", errOut) + } +} + +// --- Empty lines --- + +func TestPipeInbound_EmptyLine(t *testing.T) { + fwd, _ := runInboundPipe(t, "\n") + if fwd != "\n" { + t.Errorf("empty line should pass through, got: %q", fwd) + } +} + +func TestPipeOutbound_EmptyLine(t *testing.T) { + fwd, _ := runOutboundPipe(t, "\n") + if fwd != "\n" { + t.Errorf("empty line should pass through, got: %q", fwd) + } +} + +// --- RunProxy (hang / exit code) --- + +func TestRunProxy(t *testing.T) { + t.Run("no_hang_on_exit", func(t *testing.T) { + if _, err := exec.LookPath("true"); err != nil { + t.Skip("'true' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "Subprocess", + Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + }) + }() + + select { + case code := <-done: + if code != 0 { + t.Errorf("exit code = %d, want 0", code) + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung — stdin not closed after subprocess exit") + } + }) + + t.Run("propagates_exit_code", func(t *testing.T) { + if _, err := exec.LookPath("false"); err != nil { + t.Skip("'false' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "Subprocess", + Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + }) + }() + + select { + case code := <-done: + if code == 0 { + t.Error("expected non-zero exit code from 'false'") + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung") + } + }) +} + +// --- Multiple messages --- + +func TestPipeInbound_MultipleMessages(t *testing.T) { + msgs := strings.Join([]string{ + `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`, + `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/main.go"}}}`, + `{"jsonrpc":"2.0","id":3,"method":"initialize","params":{"capabilities":{}}}`, + }, "\n") + "\n" + + fwd, errOut := runInboundPipe(t, msgs) + + fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") + if len(fwdLines) != 2 { + t.Errorf("expected 2 subprocess messages (main.go + initialize), got %d: %v", len(fwdLines), fwdLines) + } + if errOut == "" { + t.Error("client should receive error for .env read") + } +} + +func TestPipeOutbound_MultipleMessages(t *testing.T) { + msgs := strings.Join([]string{ + `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}`, + `{"jsonrpc":"2.0","id":2,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/main.go"}}`, + `{"jsonrpc":"2.0","id":3,"method":"session/prompt","params":{"text":"hi"}}`, + }, "\n") + "\n" + + fwd, errOut := runOutboundPipe(t, msgs) + + fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") + if len(fwdLines) != 2 { + t.Errorf("expected 2 client messages (main.go + prompt), got %d: %v", len(fwdLines), fwdLines) + } + if errOut == "" { + t.Error("subprocess should receive error for .env read") + } +} diff --git a/internal/jsonrpc/pipe.go b/internal/jsonrpc/pipe.go new file mode 100644 index 0000000..1f5400f --- /dev/null +++ b/internal/jsonrpc/pipe.go @@ -0,0 +1,125 @@ +package jsonrpc + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +// SendBlockError sends a JSON-RPC error response back through the writer. +func SendBlockError(log *logger.Logger, writer *LockedWriter, id json.RawMessage, msg string) { + resp, err := json.Marshal(ErrorResponse{ + JSONRPC: "2.0", + ID: id, + Error: ErrorObj{Code: BlockedError, Message: msg}, + }) + if err != nil { + log.Debug("Failed to marshal block response: %v", err) + return + } + if err := writer.WriteLine(resp); err != nil { + log.Debug("Failed to send block response: %v", err) + } +} + +// PipePassthrough reads JSONL from src and forwards each line to dst. +func PipePassthrough(log *logger.Logger, src io.Reader, dst *LockedWriter, label string) { + scanner := bufio.NewScanner(src) + scanner.Buffer(make([]byte, 0, 64*1024), MaxScannerBuf) + + for scanner.Scan() { + if err := dst.WriteLine(scanner.Bytes()); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + } + if err := scanner.Err(); err != nil { + log.Warn("%s scanner error: %v", label, err) + } +} + +// PipeInspect reads JSONL from src, runs security-relevant messages through +// the converter and rule engine, and either forwards or blocks them. +// +// Parameters: +// - fwdWriter: where allowed messages are forwarded +// - errWriter: where JSON-RPC error responses for blocked messages are sent +// - convert: the protocol-specific method converter +// - protocol: "ACP" or "MCP" (for log messages) +// - label: direction label for debug logs (e.g., "Agent->IDE") +func PipeInspect(log *logger.Logger, engine *rules.Engine, src io.Reader, + fwdWriter, errWriter *LockedWriter, convert MethodConverter, protocol, label string) { + + scanner := bufio.NewScanner(src) + scanner.Buffer(make([]byte, 0, 64*1024), MaxScannerBuf) + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + if err := fwdWriter.WriteLine(line); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + continue + } + + var msg Message + if err := json.Unmarshal(line, &msg); err != nil { + if err := fwdWriter.WriteLine(line); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + continue + } + + if !msg.IsRequest() { + if err := fwdWriter.WriteLine(line); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + continue + } + + toolCall, err := convert(msg.Method, msg.Params) + if toolCall == nil && err == nil { + if err := fwdWriter.WriteLine(line); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + continue + } + if err != nil { + log.Warn("Blocked %s %s: %v", protocol, msg.Method, err) + SendBlockError(log, errWriter, msg.ID, "[Crust] Blocked: malformed params for "+msg.Method) + continue + } + + result := engine.Evaluate(*toolCall) + + if result.Matched && result.Action == rules.ActionBlock { + log.Warn("Blocked %s %s (tool=%s): rule=%s message=%s", + protocol, msg.Method, toolCall.Name, result.RuleName, result.Message) + SendBlockError(log, errWriter, msg.ID, + fmt.Sprintf("[Crust] Blocked by rule %q: %s", result.RuleName, result.Message)) + continue + } + + if result.Matched && result.Action == rules.ActionLog { + log.Info("Logged %s %s (tool=%s): rule=%s", + protocol, msg.Method, toolCall.Name, result.RuleName) + } + + if err := fwdWriter.WriteLine(line); err != nil { + log.Debug("%s write error: %v", label, err) + return + } + } + + if err := scanner.Err(); err != nil { + log.Warn("%s scanner error: %v", label, err) + } +} diff --git a/internal/jsonrpc/proxy.go b/internal/jsonrpc/proxy.go new file mode 100644 index 0000000..fa597da --- /dev/null +++ b/internal/jsonrpc/proxy.go @@ -0,0 +1,139 @@ +package jsonrpc + +import ( + "context" + "errors" + "io" + "os" + "os/exec" + "sync" + + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +// PipeConfig describes one direction of the proxy pipeline. +type PipeConfig struct { + // Label for log messages (e.g., "IDE->Agent", "Client->Server"). + Label string + // Protocol name for block/log messages ("ACP" or "MCP"). + Protocol string + // Convert is the method converter. If nil, this direction is passthrough-only. + Convert MethodConverter +} + +// ProxyConfig describes how to run the stdio proxy. +type ProxyConfig struct { + // Log is the logger to use. Each caller passes its own prefixed logger. + Log *logger.Logger + // ProcessLabel is the human name for the child process (e.g., "Agent", "MCP server"). + ProcessLabel string + // Inbound describes client/IDE -> subprocess direction. + Inbound PipeConfig + // Outbound describes subprocess -> client/IDE direction. + Outbound PipeConfig + // ExtraLogLines are additional log lines to emit at startup. + ExtraLogLines []string +} + +// RunProxy starts the stdio proxy. It spawns cmd, wires up stdio pipes, +// runs the configured inspection/passthrough pipes, and returns the child's exit code. +func RunProxy(engine *rules.Engine, cmd []string, stdin io.ReadCloser, stdout io.Writer, cfg ProxyConfig) int { + log := cfg.Log + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + child := exec.CommandContext(ctx, cmd[0], cmd[1:]...) //nolint:gosec // user-specified by design + childStdin, err := child.StdinPipe() + if err != nil { + log.Error("Failed to create %s stdin pipe: %v", cfg.ProcessLabel, err) + return 1 + } + childStdout, err := child.StdoutPipe() + if err != nil { + childStdin.Close() // Bug fix #4: close already-created pipe + log.Error("Failed to create %s stdout pipe: %v", cfg.ProcessLabel, err) + return 1 + } + child.Stderr = os.Stderr + + if err := child.Start(); err != nil { + childStdin.Close() + childStdout.Close() + log.Error("Failed to start %s: %v", cfg.ProcessLabel, err) + return 1 + } + + log.Info("%s started: PID %d, command: %v", cfg.ProcessLabel, child.Process.Pid, cmd) + log.Info("Rule engine: %d rules loaded", engine.RuleCount()) + for _, line := range cfg.ExtraLogLines { + log.Info("%s", line) + } + + clientWriter := NewLockedWriter(stdout) + childWriter := NewLockedWriter(childStdin) + + var wg sync.WaitGroup + + // Goroutine 1: Inbound (client/IDE -> child subprocess) + wg.Add(1) + go func() { + defer wg.Done() + defer childStdin.Close() + if cfg.Inbound.Convert != nil { + PipeInspect(log, engine, stdin, childWriter, clientWriter, + cfg.Inbound.Convert, cfg.Inbound.Protocol, cfg.Inbound.Label) + } else { + PipePassthrough(log, stdin, childWriter, cfg.Inbound.Label) + } + }() + + // Goroutine 2: Outbound (child subprocess -> client/IDE) + wg.Add(1) + go func() { + defer wg.Done() + if cfg.Outbound.Convert != nil { + PipeInspect(log, engine, childStdout, clientWriter, childWriter, + cfg.Outbound.Convert, cfg.Outbound.Protocol, cfg.Outbound.Label) + } else { + PipePassthrough(log, childStdout, clientWriter, cfg.Outbound.Label) + } + }() + + // Goroutine 3: Forward signals to child + // Bug fix #2: signal goroutine tracked in WaitGroup + sigCh := ForwardSignals() + wg.Add(1) + go func() { + defer wg.Done() + for sig := range sigCh { + if child.Process != nil { + if err := child.Process.Signal(sig); err != nil { + log.Debug("Signal %v to %s: %v", sig, cfg.ProcessLabel, err) + } + } + } + }() + + waitErr := child.Wait() + StopSignals(sigCh) + cancel() + + // Close client stdin to unblock the inbound goroutine's scanner. + if stdin != nil { + if err := stdin.Close(); err != nil { + log.Debug("Close stdin: %v", err) + } + } + + wg.Wait() + + if waitErr != nil { + var exitErr *exec.ExitError + if errors.As(waitErr, &exitErr) { + return exitErr.ExitCode() + } + return 1 + } + return 0 +} diff --git a/internal/jsonrpc/proxy_test.go b/internal/jsonrpc/proxy_test.go new file mode 100644 index 0000000..12b62c2 --- /dev/null +++ b/internal/jsonrpc/proxy_test.go @@ -0,0 +1,376 @@ +package jsonrpc + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +var testLog = logger.New("test") + +func newTestEngine(t *testing.T) *rules.Engine { + t.Helper() + engine, err := rules.NewEngine(rules.EngineConfig{ + UserRulesDir: t.TempDir(), + DisableBuiltin: false, + }) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + return engine +} + +// blockAllConverter is a test converter that treats "security/call" as security-relevant +// and maps it to a tool call that will be blocked by the built-in rules. +func blockAllConverter(method string, params json.RawMessage) (*rules.ToolCall, error) { + if method == "security/call" { + args, _ := json.Marshal(map[string]string{"path": "/etc/shadow"}) + return &rules.ToolCall{Name: "read_file", Arguments: args}, nil + } + if method == "malformed/call" { + return nil, fmt.Errorf("malformed params") + } + return nil, nil // not security-relevant +} + +// passthroughConverter returns nil, nil for everything (all passthrough). +func passthroughConverter(method string, params json.RawMessage) (*rules.ToolCall, error) { + return nil, nil +} + +// --- WriteLine --- + +func TestWriteLine(t *testing.T) { + t.Run("does_not_mutate_input", func(t *testing.T) { + var buf bytes.Buffer + lw := NewLockedWriter(&buf) + + backing := make([]byte, 6, 100) + copy(backing, "helloX") + data := backing[:5] + + if err := lw.WriteLine(data); err != nil { + t.Fatal(err) + } + if backing[5] != 'X' { + t.Errorf("WriteLine mutated caller's buffer: byte after data is %q, want 'X'", backing[5]) + } + if buf.String() != "hello\n" { + t.Errorf("output = %q, want %q", buf.String(), "hello\n") + } + }) + + t.Run("consecutive_writes", func(t *testing.T) { + var buf bytes.Buffer + lw := NewLockedWriter(&buf) + + for _, line := range []string{"first", "second", "third"} { + if err := lw.WriteLine([]byte(line)); err != nil { + t.Fatal(err) + } + } + if got, want := buf.String(), "first\nsecond\nthird\n"; got != want { + t.Errorf("got %q, want %q", got, want) + } + }) +} + +// --- SendBlockError --- + +func TestSendBlockError(t *testing.T) { + var buf bytes.Buffer + lw := NewLockedWriter(&buf) + SendBlockError(testLog, lw, json.RawMessage(`1`), "test error") + + var resp ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &resp); err != nil { + t.Fatalf("failed to parse error response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Errorf("jsonrpc = %q, want %q", resp.JSONRPC, "2.0") + } + if resp.Error.Code != BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, BlockedError) + } + if resp.Error.Message != "test error" { + t.Errorf("error message = %q, want %q", resp.Error.Message, "test error") + } +} + +// --- PipePassthrough --- + +func TestPipePassthrough(t *testing.T) { + input := "line1\nline2\nline3\n" + var buf bytes.Buffer + dst := NewLockedWriter(&buf) + PipePassthrough(testLog, strings.NewReader(input), dst, "test") + + if got := buf.String(); got != input { + t.Errorf("got %q, want %q", got, input) + } +} + +func TestPipePassthrough_EmptyInput(t *testing.T) { + var buf bytes.Buffer + dst := NewLockedWriter(&buf) + PipePassthrough(testLog, strings.NewReader(""), dst, "test") + + if buf.Len() != 0 { + t.Errorf("expected empty output, got %q", buf.String()) + } +} + +// --- PipeInspect --- + +func runInspect(t *testing.T, input string, convert MethodConverter) (fwd, errOut string) { + t.Helper() + engine := newTestEngine(t) + var fwdBuf, errBuf bytes.Buffer + fwdWriter := NewLockedWriter(&fwdBuf) + errWriter := NewLockedWriter(&errBuf) + PipeInspect(testLog, engine, strings.NewReader(input), + fwdWriter, errWriter, convert, "TEST", "test-label") + return fwdBuf.String(), errBuf.String() +} + +func TestPipeInspect_PassesNonSecurityRequest(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"method":"other/call","params":{}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("expected passthrough, got %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error response: %s", errOut) + } +} + +func TestPipeInspect_PassesNotification(t *testing.T) { + msg := `{"jsonrpc":"2.0","method":"update","params":{}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("expected passthrough, got %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error response: %s", errOut) + } +} + +func TestPipeInspect_PassesResponse(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"result":{"data":"ok"}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("expected passthrough, got %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error response: %s", errOut) + } +} + +func TestPipeInspect_PassesInvalidJSON(t *testing.T) { + msg := "not valid json\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != msg { + t.Errorf("expected passthrough, got %q", fwd) + } + if errOut != "" { + t.Errorf("unexpected error response: %s", errOut) + } +} + +func TestPipeInspect_PassesEmptyLine(t *testing.T) { + fwd, _ := runInspect(t, "\n", blockAllConverter) + if fwd != "\n" { + t.Errorf("empty line should pass through, got %q", fwd) + } +} + +func TestPipeInspect_BlocksSecurityRequest(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"method":"security/call","params":{}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != "" { + t.Errorf("blocked request should not be forwarded, got %q", fwd) + } + if errOut == "" { + t.Error("expected error response for blocked request") + } + var resp ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("failed to parse error: %v", err) + } + if resp.Error.Code != BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message should contain [Crust]: %s", resp.Error.Message) + } +} + +func TestPipeInspect_BlocksMalformedParams(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"method":"malformed/call","params":{}}` + "\n" + fwd, errOut := runInspect(t, msg, blockAllConverter) + if fwd != "" { + t.Errorf("malformed request should not be forwarded, got %q", fwd) + } + if errOut == "" { + t.Error("expected error response for malformed params") + } +} + +func TestPipeInspect_MultipleMessages(t *testing.T) { + msgs := strings.Join([]string{ + `{"jsonrpc":"2.0","id":1,"method":"security/call","params":{}}`, + `{"jsonrpc":"2.0","id":2,"method":"other/call","params":{}}`, + `{"jsonrpc":"2.0","id":3,"method":"malformed/call","params":{}}`, + }, "\n") + "\n" + + fwd, errOut := runInspect(t, msgs, blockAllConverter) + + fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") + if len(fwdLines) != 1 { + t.Errorf("expected 1 forwarded message, got %d: %v", len(fwdLines), fwdLines) + } + + errLines := strings.Split(strings.TrimRight(errOut, "\n"), "\n") + if len(errLines) != 2 { + t.Errorf("expected 2 error responses (blocked + malformed), got %d: %v", len(errLines), errLines) + } +} + +// --- RunProxy --- + +func TestRunProxy_NoHang(t *testing.T) { + if _, err := exec.LookPath("true"); err != nil { + t.Skip("'true' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() // keep write end OPEN to expose the hang bug + + done := make(chan int, 1) + go func() { + done <- RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, ProxyConfig{ + Log: testLog, + ProcessLabel: "test-true", + Inbound: PipeConfig{Label: "in"}, + Outbound: PipeConfig{Label: "out"}, + }) + }() + + select { + case code := <-done: + if code != 0 { + t.Errorf("exit code = %d, want 0", code) + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung — stdin not closed after child exit") + } +} + +func TestRunProxy_ExitCode(t *testing.T) { + if _, err := exec.LookPath("false"); err != nil { + t.Skip("'false' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, ProxyConfig{ + Log: testLog, + ProcessLabel: "test-false", + Inbound: PipeConfig{Label: "in"}, + Outbound: PipeConfig{Label: "out"}, + }) + }() + + select { + case code := <-done: + if code == 0 { + t.Error("expected non-zero exit code from 'false'") + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung") + } +} + +func TestRunProxy_WithInspect(t *testing.T) { + if _, err := exec.LookPath("cat"); err != nil { + t.Skip("'cat' not found in PATH") + } + engine := newTestEngine(t) + + // cat echoes stdin to stdout — so we send a message and check it comes through + input := `{"jsonrpc":"2.0","id":1,"method":"other/call","params":{}}` + "\n" + stdinR := io.NopCloser(strings.NewReader(input)) + var stdout bytes.Buffer + + done := make(chan int, 1) + go func() { + done <- RunProxy(engine, []string{"cat"}, stdinR, &stdout, ProxyConfig{ + Log: testLog, + ProcessLabel: "test-cat", + Inbound: PipeConfig{Label: "in", Protocol: "TEST", Convert: passthroughConverter}, + Outbound: PipeConfig{Label: "out", Protocol: "TEST", Convert: passthroughConverter}, + }) + }() + + select { + case code := <-done: + if code != 0 { + t.Errorf("exit code = %d, want 0", code) + } + if !strings.Contains(stdout.String(), "other/call") { + t.Errorf("expected message to pass through cat, got %q", stdout.String()) + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung") + } +} + +// --- Signal helpers --- + +func TestForwardSignals_StopSignals(t *testing.T) { + ch := ForwardSignals() + if ch == nil { + t.Fatal("ForwardSignals returned nil channel") + } + // StopSignals should close without panic + StopSignals(ch) + + // Verify channel is closed + _, ok := <-ch + if ok { + t.Error("channel should be closed after StopSignals") + } +} + +// --- IsRequest --- + +func TestMessage_IsRequest(t *testing.T) { + tests := []struct { + name string + msg Message + want bool + }{ + {"request", Message{Method: "foo", ID: json.RawMessage(`1`)}, true}, + {"notification", Message{Method: "foo"}, false}, + {"response", Message{ID: json.RawMessage(`1`), Result: json.RawMessage(`{}`)}, false}, + {"empty", Message{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.msg.IsRequest(); got != tt.want { + t.Errorf("IsRequest() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/jsonrpc/signal_unix.go b/internal/jsonrpc/signal_unix.go new file mode 100644 index 0000000..1fd6207 --- /dev/null +++ b/internal/jsonrpc/signal_unix.go @@ -0,0 +1,22 @@ +//go:build !windows + +package jsonrpc + +import ( + "os" + "os/signal" + "syscall" +) + +// ForwardSignals registers for SIGINT, SIGTERM, and SIGHUP and returns the channel. +func ForwardSignals() chan os.Signal { + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + return ch +} + +// StopSignals deregisters and closes the signal channel. +func StopSignals(ch chan os.Signal) { + signal.Stop(ch) + close(ch) +} diff --git a/internal/jsonrpc/signal_windows.go b/internal/jsonrpc/signal_windows.go new file mode 100644 index 0000000..eb4edfb --- /dev/null +++ b/internal/jsonrpc/signal_windows.go @@ -0,0 +1,21 @@ +//go:build windows + +package jsonrpc + +import ( + "os" + "os/signal" +) + +// ForwardSignals registers for os.Interrupt and returns the channel. +func ForwardSignals() chan os.Signal { + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + return ch +} + +// StopSignals deregisters and closes the signal channel. +func StopSignals(ch chan os.Signal) { + signal.Stop(ch) + close(ch) +} diff --git a/internal/jsonrpc/types.go b/internal/jsonrpc/types.go new file mode 100644 index 0000000..d778cdb --- /dev/null +++ b/internal/jsonrpc/types.go @@ -0,0 +1,76 @@ +// Package jsonrpc provides shared types and utilities for Crust's stdio proxy +// implementations (ACP wrap, MCP gateway, auto-detect wrap). +package jsonrpc + +import ( + "encoding/json" + "io" + "sync" + + "github.com/BakeLens/crust/internal/rules" +) + +// BlockedError is the JSON-RPC error code for requests blocked by a security rule. +const BlockedError = -32001 + +// MaxScannerBuf is the maximum size of a single JSONL message (10MB). +const MaxScannerBuf = 10 * 1024 * 1024 + +// Message represents a minimal JSON-RPC 2.0 message. +type Message struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` +} + +// IsRequest returns true if this is a JSON-RPC request (has method + id). +func (m *Message) IsRequest() bool { + return m.Method != "" && len(m.ID) > 0 +} + +// ErrorResponse is a JSON-RPC 2.0 error response. +type ErrorResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Error ErrorObj `json:"error"` +} + +// ErrorObj is the error object within a JSON-RPC error response. +type ErrorObj struct { + Code int `json:"code"` + Message string `json:"message"` +} + +// LockedWriter is a mutex-protected writer for safe concurrent line writes. +// Both the pass-through goroutine and the blocking logic write through it. +type LockedWriter struct { + mu sync.Mutex + w io.Writer +} + +// NewLockedWriter creates a new LockedWriter wrapping w. +func NewLockedWriter(w io.Writer) *LockedWriter { + return &LockedWriter{w: w} +} + +// WriteLine writes data followed by a newline, under the mutex. +func (lw *LockedWriter) WriteLine(data []byte) error { + lw.mu.Lock() + defer lw.mu.Unlock() + if _, err := lw.w.Write(data); err != nil { + return err + } + _, err := lw.w.Write([]byte{'\n'}) + return err +} + +// MethodConverter converts a JSON-RPC method + params into a rules.ToolCall. +// +// Returns: +// - (*ToolCall, nil) for successfully parsed security-relevant methods +// - (nil, nil) for non-security methods (caller should pass through) +// - (nil, error) for security-relevant methods with malformed params (caller should block) +type MethodConverter func(method string, params json.RawMessage) (*rules.ToolCall, error) diff --git a/internal/mcpgateway/convert.go b/internal/mcpgateway/convert.go new file mode 100644 index 0000000..fb58d6f --- /dev/null +++ b/internal/mcpgateway/convert.go @@ -0,0 +1,101 @@ +// Package mcpgateway implements a transparent stdio proxy for MCP (Model Context Protocol) +// servers, intercepting security-relevant JSON-RPC messages using Crust's rule engine. +// +// Unlike ACP wrap (which inspects agent->IDE direction), the MCP gateway inspects the +// client->server direction because MCP clients send tool calls TO the server. +package mcpgateway + +import ( + "encoding/json" + "fmt" + "net/url" + + "github.com/BakeLens/crust/internal/rules" +) + +// MCP parameter types + +// toolsCallParams represents the params of a MCP tools/call request. +type toolsCallParams struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` +} + +// resourcesReadParams represents the params of a MCP resources/read request. +type resourcesReadParams struct { + URI string `json:"uri"` +} + +// MCPMethodToToolCall converts an MCP JSON-RPC method + params into a rules.ToolCall. +// +// Returns: +// - (*ToolCall, nil) for successfully parsed security-relevant methods +// - (nil, nil) for non-security methods (caller should pass through) +// - (nil, error) for security-relevant methods with malformed params (caller should block) +func MCPMethodToToolCall(method string, params json.RawMessage) (*rules.ToolCall, error) { + // Reject nil/null params on security-relevant methods (json.Unmarshal silently + // zero-initializes the struct, which would produce an empty name and bypass rules). + switch method { + case "tools/call", "resources/read": + if len(params) == 0 || string(params) == "null" { + return nil, fmt.Errorf("nil params for security method %s", method) + } + default: + return nil, nil // not security-relevant + } + + switch method { + case "tools/call": + var p toolsCallParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + if p.Name == "" { + return nil, fmt.Errorf("empty tool name in %s", method) + } + // MCP tool names are dynamic — pass through directly. + // The rule engine's shape-based extraction handles argument parsing. + args := p.Arguments + if len(args) == 0 { + args = json.RawMessage("{}") + } + return &rules.ToolCall{Name: p.Name, Arguments: args}, nil + + case "resources/read": + var p resourcesReadParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, fmt.Errorf("malformed %s params: %w", method, err) + } + if p.URI == "" { + return nil, fmt.Errorf("empty URI in %s", method) + } + + parsed, err := url.Parse(p.URI) + if err != nil { + return nil, fmt.Errorf("invalid URI in %s: %w", method, err) + } + + switch parsed.Scheme { + case "file", "": + path := parsed.Path + if path == "" { + path = p.URI + } + args, err := json.Marshal(map[string]string{"path": path}) + if err != nil { + return nil, fmt.Errorf("marshal error: %w", err) + } + return &rules.ToolCall{Name: "read_file", Arguments: args}, nil + + default: + args, err := json.Marshal(map[string]string{"url": p.URI}) + if err != nil { + return nil, fmt.Errorf("marshal error: %w", err) + } + return &rules.ToolCall{Name: "mcp_resource_read", Arguments: args}, nil + } + + default: + return nil, nil + } +} diff --git a/internal/mcpgateway/convert_test.go b/internal/mcpgateway/convert_test.go new file mode 100644 index 0000000..74c112a --- /dev/null +++ b/internal/mcpgateway/convert_test.go @@ -0,0 +1,147 @@ +package mcpgateway + +import ( + "encoding/json" + "testing" +) + +// --- MCPMethodToToolCall --- + +func TestMcpMethodToToolCall(t *testing.T) { + tests := []struct { + name string + method string + params string + wantName string + wantKey string + wantVal string + }{ + { + "tools_call_read_file", + "tools/call", + `{"name":"read_file","arguments":{"path":"/etc/passwd"}}`, + "read_file", "path", "/etc/passwd", + }, + { + "tools_call_write_file", + "tools/call", + `{"name":"write_file","arguments":{"path":"/tmp/out.txt","content":"hello"}}`, + "write_file", "path", "/tmp/out.txt", + }, + { + "tools_call_bash", + "tools/call", + `{"name":"bash","arguments":{"command":"ls -la"}}`, + "bash", "command", "ls -la", + }, + { + "tools_call_custom_tool", + "tools/call", + `{"name":"my_custom_tool","arguments":{"query":"SELECT * FROM users"}}`, + "my_custom_tool", "query", "SELECT * FROM users", + }, + { + "resources_read_file", + "resources/read", + `{"uri":"file:///etc/passwd"}`, + "read_file", "path", "/etc/passwd", + }, + { + "resources_read_http", + "resources/read", + `{"uri":"https://evil.com/data"}`, + "mcp_resource_read", "url", "https://evil.com/data", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc, err := MCPMethodToToolCall(tt.method, json.RawMessage(tt.params)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc == nil { + t.Fatal("expected non-nil ToolCall") + } + if tc.Name != tt.wantName { + t.Errorf("name = %s, want %s", tc.Name, tt.wantName) + } + var args map[string]any + if err := json.Unmarshal(tc.Arguments, &args); err != nil { + t.Fatal(err) + } + if got := args[tt.wantKey]; got != tt.wantVal { + t.Errorf("%s = %v, want %s", tt.wantKey, got, tt.wantVal) + } + }) + } +} + +func TestMcpMethodToToolCall_EmptyArguments(t *testing.T) { + tc, err := MCPMethodToToolCall("tools/call", json.RawMessage(`{"name":"ping"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tc == nil { + t.Fatal("expected non-nil ToolCall") + } + if tc.Name != "ping" { + t.Errorf("name = %s, want ping", tc.Name) + } + if string(tc.Arguments) != "{}" { + t.Errorf("arguments = %s, want {}", string(tc.Arguments)) + } +} + +func TestMcpMethodToToolCall_Unknown(t *testing.T) { + for _, method := range []string{"initialize", "tools/list", "prompts/get", "notifications/cancelled"} { //nolint:misspell // MCP protocol uses "cancelled" + tc, err := MCPMethodToToolCall(method, nil) + if err != nil { + t.Fatalf("%s: unexpected error: %v", method, err) + } + if tc != nil { + t.Errorf("%s should not be security-relevant", method) + } + } +} + +func TestMcpMethodToToolCall_MalformedParams(t *testing.T) { + methods := []string{"tools/call", "resources/read"} + badInputs := []json.RawMessage{ + json.RawMessage(`{broken`), + json.RawMessage(`"just a string"`), + json.RawMessage(`null`), + json.RawMessage(`42`), + json.RawMessage(``), + } + for _, method := range methods { + for _, input := range badInputs { + tc, err := MCPMethodToToolCall(method, input) + if tc != nil { + t.Errorf("%s with %q: expected nil ToolCall for malformed params", method, input) + } + if err == nil { + t.Errorf("%s with %q: expected error for malformed params", method, input) + } + } + } +} + +func TestMcpMethodToToolCall_EmptyName(t *testing.T) { + tc, err := MCPMethodToToolCall("tools/call", json.RawMessage(`{"name":"","arguments":{}}`)) + if tc != nil { + t.Error("expected nil ToolCall for empty name") + } + if err == nil { + t.Error("expected error for empty tool name") + } +} + +func TestMcpMethodToToolCall_EmptyURI(t *testing.T) { + tc, err := MCPMethodToToolCall("resources/read", json.RawMessage(`{"uri":""}`)) + if tc != nil { + t.Error("expected nil ToolCall for empty URI") + } + if err == nil { + t.Error("expected error for empty URI") + } +} diff --git a/internal/mcpgateway/run.go b/internal/mcpgateway/run.go new file mode 100644 index 0000000..c688bf9 --- /dev/null +++ b/internal/mcpgateway/run.go @@ -0,0 +1,22 @@ +package mcpgateway + +import ( + "os" + + "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +var log = logger.New("mcp") + +// Run starts the MCP gateway proxy. It spawns the MCP server subprocess, wires up +// stdio, and evaluates security-relevant messages. Returns the server's exit code. +func Run(engine *rules.Engine, serverCmd []string) int { + return jsonrpc.RunProxy(engine, serverCmd, os.Stdin, os.Stdout, jsonrpc.ProxyConfig{ + Log: log, + ProcessLabel: "MCP server", + Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + }) +} diff --git a/internal/mcpgateway/run_test.go b/internal/mcpgateway/run_test.go new file mode 100644 index 0000000..5ef5d7b --- /dev/null +++ b/internal/mcpgateway/run_test.go @@ -0,0 +1,197 @@ +package mcpgateway + +import ( + "bytes" + "encoding/json" + "io" + "os/exec" + "strings" + "testing" + "time" + + "github.com/BakeLens/crust/internal/jsonrpc" + "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/rules" +) + +var testLog = logger.New("mcp-test") + +func newTestEngine(t *testing.T) *rules.Engine { + t.Helper() + engine, err := rules.NewEngine(rules.EngineConfig{ + UserRulesDir: t.TempDir(), + DisableBuiltin: false, + }) + if err != nil { + t.Fatalf("Failed to create engine: %v", err) + } + return engine +} + +// runPipe runs PipeInspect with MCPMethodToToolCall and returns what was +// forwarded and what error responses were generated. +func runPipe(t *testing.T, input string) (fwd, errOut string) { + t.Helper() + engine := newTestEngine(t) + var fwdBuf, errBuf bytes.Buffer + fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) + errWriter := jsonrpc.NewLockedWriter(&errBuf) + jsonrpc.PipeInspect(testLog, engine, strings.NewReader(input), + fwdWriter, errWriter, MCPMethodToToolCall, "MCP", "Client->Server") + return fwdBuf.String(), errBuf.String() +} + +// --- RunProxy (hang / exit code) --- + +func TestRunProxy(t *testing.T) { + t.Run("no_hang_on_server_exit", func(t *testing.T) { + if _, err := exec.LookPath("true"); err != nil { + t.Skip("'true' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "MCP server", + Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + }) + }() + + select { + case code := <-done: + if code != 0 { + t.Errorf("exit code = %d, want 0", code) + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung — client stdin not closed after server exit") + } + }) + + t.Run("propagates_exit_code", func(t *testing.T) { + if _, err := exec.LookPath("false"); err != nil { + t.Skip("'false' not found in PATH") + } + engine := newTestEngine(t) + stdinR, stdinW := io.Pipe() + defer stdinW.Close() + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "MCP server", + Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + }) + }() + + select { + case code := <-done: + if code == 0 { + t.Error("expected non-zero exit code from 'false'") + } + case <-time.After(5 * time.Second): + t.Fatal("RunProxy hung") + } + }) +} + +// --- PipeInspect + MCPMethodToToolCall integration --- + +func TestPipeClientToServer_Blocks(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"env_read", `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`}, + {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/home/user/.ssh/id_rsa"}}}`}, + {"env_write", `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"/app/.env","content":"API_KEY=secret"}}}`}, + {"resource_env_read", `{"jsonrpc":"2.0","id":4,"method":"resources/read","params":{"uri":"file:///app/.env"}}`}, + {"malformed_tools_call", `{"jsonrpc":"2.0","id":5,"method":"tools/call","params":"not-an-object"}`}, + {"null_params", `{"jsonrpc":"2.0","id":6,"method":"tools/call","params":null}`}, + {"empty_tool_name", `{"jsonrpc":"2.0","id":7,"method":"tools/call","params":{"name":"","arguments":{}}}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runPipe(t, tt.msg+"\n") + if fwd != "" { + t.Errorf("server should not receive blocked request, got: %s", fwd) + } + if errOut == "" { + t.Error("client should receive an error response") + } + }) + } +} + +func TestPipeClientToServer_BlocksEnvRead_ErrorShape(t *testing.T) { + fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`+"\n") + if fwd != "" { + t.Errorf("server should not receive blocked request, got: %s", fwd) + } + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error, got: %s", errOut) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) + } +} + +func TestPipeClientToServer_Passes(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/src/main.go"}}}`}, + {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"initialize","params":{"capabilities":{}}}`}, + {"tools_list", `{"jsonrpc":"2.0","id":30,"method":"tools/list","params":{}}`}, + {"notification", `{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}`}, //nolint:misspell // MCP protocol uses "cancelled" + {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"file data"}}`}, + {"invalid_json", `not valid json at all`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, errOut := runPipe(t, tt.msg+"\n") + if fwd != tt.msg+"\n" { + t.Errorf("message should pass through unchanged\ngot: %q\nwant: %q", fwd, tt.msg+"\n") + } + if errOut != "" { + t.Errorf("client should not receive errors, got: %s", errOut) + } + }) + } +} + +func TestPipeClientToServer_EmptyLine(t *testing.T) { + fwd, _ := runPipe(t, "\n") + if fwd != "\n" { + t.Errorf("empty line should pass through, got: %q", fwd) + } +} + +func TestPipeClientToServer_MultipleMessages(t *testing.T) { + msgs := strings.Join([]string{ + `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`, + `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/main.go"}}}`, + `{"jsonrpc":"2.0","id":3,"method":"initialize","params":{"capabilities":{}}}`, + }, "\n") + "\n" + + fwd, errOut := runPipe(t, msgs) + + fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") + if len(fwdLines) != 2 { + t.Errorf("expected 2 server messages (main.go + initialize), got %d: %v", len(fwdLines), fwdLines) + } + if errOut == "" { + t.Error("client should receive error for .env read") + } +} diff --git a/main.go b/main.go index 4e72f79..723887f 100644 --- a/main.go +++ b/main.go @@ -19,11 +19,13 @@ import ( "github.com/BakeLens/crust/internal/earlyinit" // side-effect import: init() runs before bubbletea's via dependency order + lexicographic tie-breaking "github.com/BakeLens/crust/internal/acpwrap" + "github.com/BakeLens/crust/internal/autowrap" "github.com/BakeLens/crust/internal/cli" "github.com/BakeLens/crust/internal/completion" "github.com/BakeLens/crust/internal/config" "github.com/BakeLens/crust/internal/daemon" "github.com/BakeLens/crust/internal/logger" + "github.com/BakeLens/crust/internal/mcpgateway" "github.com/BakeLens/crust/internal/proxy" "github.com/BakeLens/crust/internal/rules" "github.com/BakeLens/crust/internal/security" @@ -96,6 +98,12 @@ func main() { case "acp-wrap": runAcpWrap(os.Args[2:]) return + case "mcp-gateway": + runMcpGateway(os.Args[2:]) + return + case "wrap": + runWrap(os.Args[2:]) + return case "uninstall": runUninstall() return @@ -904,21 +912,30 @@ func runReloadRules(_ []string) { } // runAcpWrap handles the acp-wrap subcommand -func runAcpWrap(args []string) { - wrapFlags := flag.NewFlagSet("acp-wrap", flag.ExitOnError) - configPath := wrapFlags.String("config", config.DefaultConfigPath(), "Path to configuration file") - logLevel := wrapFlags.String("log-level", "warn", "Log level: trace, debug, info, warn, error") - rulesDir := wrapFlags.String("rules-dir", "", "Override rules directory") - disableBuiltin := wrapFlags.Bool("disable-builtin", false, "Disable builtin security rules") - _ = wrapFlags.Parse(args) - - agentCmd := wrapFlags.Args() - if len(agentCmd) == 0 { - fmt.Fprintf(os.Stderr, "Usage: crust acp-wrap [flags] -- [args...]\n") +// proxyRunConfig describes a proxy subcommand entry point. +type proxyRunConfig struct { + name string // subcommand name (e.g., "acp-wrap") + usage string // usage line (e.g., "acp-wrap [flags] -- [args...]") + run func(engine *rules.Engine, cmd []string) int +} + +// runProxyCommand implements the shared flag parsing, config loading, engine +// init, and subprocess launch for all proxy subcommands (acp-wrap, mcp-gateway, wrap). +func runProxyCommand(pcfg proxyRunConfig, args []string) { + fs := flag.NewFlagSet(pcfg.name, flag.ExitOnError) + configPath := fs.String("config", config.DefaultConfigPath(), "Path to configuration file") + logLevel := fs.String("log-level", "warn", "Log level: trace, debug, info, warn, error") + rulesDir := fs.String("rules-dir", "", "Override rules directory") + disableBuiltin := fs.Bool("disable-builtin", false, "Disable builtin security rules") + _ = fs.Parse(args) + + subCmd := fs.Args() + if len(subCmd) == 0 { + fmt.Fprintf(os.Stderr, "Usage: crust %s\n", pcfg.usage) os.Exit(1) } - // Logger to stderr only — stdout is the ACP pipe + // Logger to stderr only — stdout is the JSON-RPC pipe logger.SetColored(false) if *logLevel != "" { logger.SetGlobalLevelFromString(*logLevel) @@ -937,9 +954,8 @@ func runAcpWrap(args []string) { dir = rules.DefaultUserRulesDir() } - // Ensure user rules directory exists so the engine can load if err := os.MkdirAll(dir, 0o700); err != nil { - fmt.Fprintf(os.Stderr, "crust acp-wrap: failed to create rules dir: %v\n", err) + fmt.Fprintf(os.Stderr, "crust %s: failed to create rules dir: %v\n", pcfg.name, err) os.Exit(1) } @@ -949,11 +965,35 @@ func runAcpWrap(args []string) { SubprocessIsolation: true, }) if err != nil { - fmt.Fprintf(os.Stderr, "crust acp-wrap: failed to init rules: %v\n", err) + fmt.Fprintf(os.Stderr, "crust %s: failed to init rules: %v\n", pcfg.name, err) os.Exit(1) } - os.Exit(acpwrap.Run(engine, agentCmd)) + os.Exit(pcfg.run(engine, subCmd)) +} + +func runAcpWrap(args []string) { + runProxyCommand(proxyRunConfig{ + name: "acp-wrap", + usage: "acp-wrap [flags] -- [args...]", + run: acpwrap.Run, + }, args) +} + +func runMcpGateway(args []string) { + runProxyCommand(proxyRunConfig{ + name: "mcp-gateway", + usage: "mcp-gateway [flags] -- [args...]", + run: mcpgateway.Run, + }, args) +} + +func runWrap(args []string) { + runProxyCommand(proxyRunConfig{ + name: "wrap", + usage: "wrap [flags] -- [args...]", + run: autowrap.Run, + }, args) } // runUninstall handles the uninstall subcommand From b3f64cdba50a5be9fdb0f511c9356d08186581be Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 18:57:11 +0800 Subject: [PATCH 3/6] test: add E2E tests against real MCP filesystem server MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add E2E tests that exercise the full proxy stack against @modelcontextprotocol/server-filesystem — a real MCP server doing actual file I/O. Tests verify initialize handshake, tools/list, allowed reads/writes returning real content, and blocked .env/.ssh access returning Crust error responses. - 8 E2E tests in internal/mcpgateway/e2e_test.go - Guarded with -short skip + npx availability check - New CI job (E2E Tests) with Node.js 22 + pre-installed server - Trimmed redundant unit tests now covered by E2E --- .github/workflows/ci.yml | 20 ++ internal/mcpgateway/e2e_test.go | 381 ++++++++++++++++++++++++++++++++ internal/mcpgateway/run_test.go | 128 ++--------- 3 files changed, 421 insertions(+), 108 deletions(-) create mode 100644 internal/mcpgateway/e2e_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6b95aac..c647e72 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,6 +60,26 @@ jobs: - name: Unit tests (with race detector) run: go test -race ./... -short + e2e: + name: E2E Tests (MCP) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: '1.24.13' + + - uses: actions/setup-node@v4 + with: + node-version: '22' + + - name: Pre-install MCP filesystem server + run: npm install -g @modelcontextprotocol/server-filesystem + + - name: E2E tests + run: go test ./internal/mcpgateway/... -v -run E2E -timeout 120s + docker-test: name: Docker runs-on: ubuntu-latest diff --git a/internal/mcpgateway/e2e_test.go b/internal/mcpgateway/e2e_test.go new file mode 100644 index 0000000..623ce24 --- /dev/null +++ b/internal/mcpgateway/e2e_test.go @@ -0,0 +1,381 @@ +package mcpgateway + +import ( + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/BakeLens/crust/internal/jsonrpc" +) + +// skipE2E skips if -short or npx not available. +func skipE2E(t *testing.T) { + t.Helper() + if testing.Short() { + t.Skip("E2E: skipped in -short mode") + } + if _, err := exec.LookPath("npx"); err != nil { + t.Skip("E2E: npx not found in PATH") + } +} + +// setupTestDir creates a temp directory with test files for the filesystem server. +// It resolves symlinks so paths match on macOS (/var → /private/var). +func setupTestDir(t *testing.T) string { + t.Helper() + raw := t.TempDir() + dir, err := filepath.EvalSymlinks(raw) + if err != nil { + t.Fatalf("failed to resolve symlinks for %s: %v", raw, err) + } + + // Safe files + os.WriteFile(filepath.Join(dir, "safe.txt"), []byte("hello world"), 0o644) + os.MkdirAll(filepath.Join(dir, "subdir"), 0o755) + os.WriteFile(filepath.Join(dir, "subdir", "code.go"), []byte("package main"), 0o644) + + // Sensitive files (should be blocked by Crust) + os.WriteFile(filepath.Join(dir, ".env"), []byte("SECRET_KEY=sk-1234"), 0o644) + os.MkdirAll(filepath.Join(dir, ".ssh"), 0o700) + os.WriteFile(filepath.Join(dir, ".ssh", "id_rsa"), []byte("fake-private-key"), 0o600) + + return dir +} + +// e2eResponse represents a parsed JSON-RPC response from the proxy output. +type e2eResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// runMCPE2E runs the MCP proxy against the real filesystem server and returns +// all JSON-RPC responses received by the client. +func runMCPE2E(t *testing.T, dir string, messages []string) []e2eResponse { + t.Helper() + engine := newTestEngine(t) + input := strings.Join(messages, "\n") + "\n" + stdinR := io.NopCloser(strings.NewReader(input)) + var stdout strings.Builder + + done := make(chan int, 1) + go func() { + done <- jsonrpc.RunProxy(engine, + []string{"npx", "-y", "@modelcontextprotocol/server-filesystem", dir}, + stdinR, &stdout, jsonrpc.ProxyConfig{ + Log: testLog, + ProcessLabel: "MCP server", + Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + }) + }() + + select { + case <-done: + return parseE2EResponses(t, stdout.String()) + case <-time.After(30 * time.Second): + t.Fatal("E2E test timed out (30s)") + return nil + } +} + +// parseE2EResponses parses JSONL output into e2eResponse structs. +func parseE2EResponses(t *testing.T, output string) []e2eResponse { + t.Helper() + var responses []e2eResponse + for line := range strings.SplitSeq(output, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + var resp e2eResponse + if err := json.Unmarshal([]byte(line), &resp); err != nil { + t.Logf("skipping non-JSON line: %s", line) + continue + } + responses = append(responses, resp) + } + return responses +} + +// findByID finds a response with the given integer ID. +func findByID(responses []e2eResponse, id int) *e2eResponse { + target := fmt.Sprintf("%d", id) + for i := range responses { + if string(responses[i].ID) == target { + return &responses[i] + } + } + return nil +} + +// initMessages returns the standard MCP handshake messages (initialize + initialized notification). +func initMessages() []string { + return []string{ + `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"crust-e2e","version":"1.0.0"}}}`, + `{"jsonrpc":"2.0","method":"notifications/initialized","params":{}}`, + } +} + +// --- E2E Tests --- + +func TestE2E_Initialize(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + responses := runMCPE2E(t, dir, initMessages()[:1]) // just initialize, no notification + + resp := findByID(responses, 1) + if resp == nil { + t.Fatal("no response for initialize (id=1)") + } + if resp.Error != nil { + t.Fatalf("initialize returned error: %s", resp.Error.Message) + } + + // Verify response has protocolVersion + var result map[string]any + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("failed to parse init result: %v", err) + } + if _, ok := result["protocolVersion"]; !ok { + t.Error("initialize response missing protocolVersion") + } + if _, ok := result["capabilities"]; !ok { + t.Error("initialize response missing capabilities") + } + if _, ok := result["serverInfo"]; !ok { + t.Error("initialize response missing serverInfo") + } +} + +func TestE2E_ToolsList(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + `{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`, + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 2) + if resp == nil { + t.Fatal("no response for tools/list (id=2)") + } + if resp.Error != nil { + t.Fatalf("tools/list returned error: %s", resp.Error.Message) + } + + var result map[string]any + if err := json.Unmarshal(resp.Result, &result); err != nil { + t.Fatalf("failed to parse tools/list result: %v", err) + } + tools, ok := result["tools"].([]any) + if !ok || len(tools) == 0 { + t.Fatal("tools/list returned no tools") + } + + // Verify known tools exist + toolNames := make(map[string]bool) + for _, tool := range tools { + if m, ok := tool.(map[string]any); ok { + if name, ok := m["name"].(string); ok { + toolNames[name] = true + } + } + } + for _, want := range []string{"read_text_file", "write_file"} { + if !toolNames[want] { + t.Errorf("tools/list missing tool %q, got: %v", want, toolNames) + } + } +} + +func TestE2E_ReadAllowed(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/safe.txt"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for read_text_file safe.txt (id=3)") + } + if resp.Error != nil { + t.Fatalf("read_text_file returned error: code=%d msg=%s", resp.Error.Code, resp.Error.Message) + } + + // Verify the response contains the actual file content + if !strings.Contains(string(resp.Result), "hello world") { + t.Errorf("expected file content 'hello world' in response, got: %s", string(resp.Result)) + } +} + +func TestE2E_ReadBlocked_Env(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/.env"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for blocked .env read (id=3)") + } + if resp.Error == nil { + t.Fatalf("expected Crust block error for .env read, got success: %s", string(resp.Result)) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust] prefix: %s", resp.Error.Message) + } +} + +func TestE2E_ReadBlocked_SSHKey(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/.ssh/id_rsa"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for blocked SSH key read (id=3)") + } + if resp.Error == nil { + t.Fatalf("expected Crust block error for SSH key read, got success: %s", string(resp.Result)) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } +} + +func TestE2E_WriteAllowed(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + outFile := filepath.Join(dir, "output.txt") + + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"%s","content":"written by e2e test"}}}`, outFile), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for write_file (id=3)") + } + if resp.Error != nil { + t.Fatalf("write_file returned error: code=%d msg=%s", resp.Error.Code, resp.Error.Message) + } + + // Verify the file was actually written + content, err := os.ReadFile(outFile) + if err != nil { + t.Fatalf("failed to read written file: %v", err) + } + if string(content) != "written by e2e test" { + t.Errorf("file content = %q, want %q", string(content), "written by e2e test") + } +} + +func TestE2E_WriteBlocked_Env(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + envContent, _ := os.ReadFile(filepath.Join(dir, ".env")) + + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"%s/.env","content":"STOLEN=true"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for blocked .env write (id=3)") + } + if resp.Error == nil { + t.Fatalf("expected Crust block error for .env write, got success: %s", string(resp.Result)) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + + // Verify the .env file was NOT modified + after, _ := os.ReadFile(filepath.Join(dir, ".env")) + if string(after) != string(envContent) { + t.Errorf(".env was modified despite being blocked: %q → %q", envContent, after) + } +} + +func TestE2E_MixedStream(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + // id=2: tools/list (allowed — not tools/call) + `{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`, + // id=3: read .env (BLOCKED) + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/.env"}}}`, dir), + // id=4: read safe.txt (allowed) + fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/safe.txt"}}}`, dir), + // id=5: write .env (BLOCKED) + fmt.Sprintf(`{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"%s/.env","content":"STOLEN"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + // id=1: initialize — should succeed + if r := findByID(responses, 1); r == nil || r.Error != nil { + t.Error("initialize (id=1) should succeed") + } + + // id=2: tools/list — should succeed + if r := findByID(responses, 2); r == nil || r.Error != nil { + t.Error("tools/list (id=2) should succeed") + } + + // id=3: read .env — should be blocked + if r := findByID(responses, 3); r == nil { + t.Error("expected response for blocked .env read (id=3)") + } else if r.Error == nil { + t.Error("read .env (id=3) should be blocked") + } else if r.Error.Code != jsonrpc.BlockedError { + t.Errorf("read .env error code = %d, want %d", r.Error.Code, jsonrpc.BlockedError) + } + + // id=4: read safe.txt — should succeed with content + if r := findByID(responses, 4); r == nil { + t.Error("expected response for safe.txt read (id=4)") + } else if r.Error != nil { + t.Errorf("read safe.txt (id=4) should succeed, got error: %s", r.Error.Message) + } else if !strings.Contains(string(r.Result), "hello world") { + t.Errorf("read safe.txt (id=4) missing content, got: %s", string(r.Result)) + } + + // id=5: write .env — should be blocked + if r := findByID(responses, 5); r == nil { + t.Error("expected response for blocked .env write (id=5)") + } else if r.Error == nil { + t.Error("write .env (id=5) should be blocked") + } else if r.Error.Code != jsonrpc.BlockedError { + t.Errorf("write .env error code = %d, want %d", r.Error.Code, jsonrpc.BlockedError) + } +} diff --git a/internal/mcpgateway/run_test.go b/internal/mcpgateway/run_test.go index 5ef5d7b..7f8b27a 100644 --- a/internal/mcpgateway/run_test.go +++ b/internal/mcpgateway/run_test.go @@ -3,11 +3,8 @@ package mcpgateway import ( "bytes" "encoding/json" - "io" - "os/exec" "strings" "testing" - "time" "github.com/BakeLens/crust/internal/jsonrpc" "github.com/BakeLens/crust/internal/logger" @@ -41,76 +38,14 @@ func runPipe(t *testing.T, input string) (fwd, errOut string) { return fwdBuf.String(), errBuf.String() } -// --- RunProxy (hang / exit code) --- +// --- Edge-case blocking (malformed inputs, resources/read) --- +// Security blocking of .env, .ssh, etc. is covered by E2E tests (e2e_test.go). -func TestRunProxy(t *testing.T) { - t.Run("no_hang_on_server_exit", func(t *testing.T) { - if _, err := exec.LookPath("true"); err != nil { - t.Skip("'true' not found in PATH") - } - engine := newTestEngine(t) - stdinR, stdinW := io.Pipe() - defer stdinW.Close() - - done := make(chan int, 1) - go func() { - done <- jsonrpc.RunProxy(engine, []string{"true"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ - Log: testLog, - ProcessLabel: "MCP server", - Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, - }) - }() - - select { - case code := <-done: - if code != 0 { - t.Errorf("exit code = %d, want 0", code) - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung — client stdin not closed after server exit") - } - }) - - t.Run("propagates_exit_code", func(t *testing.T) { - if _, err := exec.LookPath("false"); err != nil { - t.Skip("'false' not found in PATH") - } - engine := newTestEngine(t) - stdinR, stdinW := io.Pipe() - defer stdinW.Close() - - done := make(chan int, 1) - go func() { - done <- jsonrpc.RunProxy(engine, []string{"false"}, stdinR, &bytes.Buffer{}, jsonrpc.ProxyConfig{ - Log: testLog, - ProcessLabel: "MCP server", - Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, - }) - }() - - select { - case code := <-done: - if code == 0 { - t.Error("expected non-zero exit code from 'false'") - } - case <-time.After(5 * time.Second): - t.Fatal("RunProxy hung") - } - }) -} - -// --- PipeInspect + MCPMethodToToolCall integration --- - -func TestPipeClientToServer_Blocks(t *testing.T) { +func TestPipeClientToServer_BlocksEdgeCases(t *testing.T) { tests := []struct { name string msg string }{ - {"env_read", `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`}, - {"ssh_key_read", `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/home/user/.ssh/id_rsa"}}}`}, - {"env_write", `{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"/app/.env","content":"API_KEY=secret"}}}`}, {"resource_env_read", `{"jsonrpc":"2.0","id":4,"method":"resources/read","params":{"uri":"file:///app/.env"}}`}, {"malformed_tools_call", `{"jsonrpc":"2.0","id":5,"method":"tools/call","params":"not-an-object"}`}, {"null_params", `{"jsonrpc":"2.0","id":6,"method":"tools/call","params":null}`}, @@ -129,31 +64,14 @@ func TestPipeClientToServer_Blocks(t *testing.T) { } } -func TestPipeClientToServer_BlocksEnvRead_ErrorShape(t *testing.T) { - fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`+"\n") - if fwd != "" { - t.Errorf("server should not receive blocked request, got: %s", fwd) - } - var resp jsonrpc.ErrorResponse - if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { - t.Fatalf("expected JSON-RPC error, got: %s", errOut) - } - if resp.Error.Code != jsonrpc.BlockedError { - t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) - } - if !strings.Contains(resp.Error.Message, "[Crust]") { - t.Errorf("error message missing [Crust]: %s", resp.Error.Message) - } -} +// --- Passthrough edge cases --- +// Passthrough of initialize, tools/list, and allowed tool calls is covered by E2E tests. -func TestPipeClientToServer_Passes(t *testing.T) { +func TestPipeClientToServer_PassesEdgeCases(t *testing.T) { tests := []struct { name string msg string }{ - {"normal_read", `{"jsonrpc":"2.0","id":10,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/src/main.go"}}}`}, - {"non_security_method", `{"jsonrpc":"2.0","id":20,"method":"initialize","params":{"capabilities":{}}}`}, - {"tools_list", `{"jsonrpc":"2.0","id":30,"method":"tools/list","params":{}}`}, {"notification", `{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}`}, //nolint:misspell // MCP protocol uses "cancelled" {"response", `{"jsonrpc":"2.0","id":5,"result":{"content":"file data"}}`}, {"invalid_json", `not valid json at all`}, @@ -171,27 +89,21 @@ func TestPipeClientToServer_Passes(t *testing.T) { } } -func TestPipeClientToServer_EmptyLine(t *testing.T) { - fwd, _ := runPipe(t, "\n") - if fwd != "\n" { - t.Errorf("empty line should pass through, got: %q", fwd) - } -} - -func TestPipeClientToServer_MultipleMessages(t *testing.T) { - msgs := strings.Join([]string{ - `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}`, - `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/main.go"}}}`, - `{"jsonrpc":"2.0","id":3,"method":"initialize","params":{"capabilities":{}}}`, - }, "\n") + "\n" +// --- resources/read error response shape --- - fwd, errOut := runPipe(t, msgs) - - fwdLines := strings.Split(strings.TrimRight(fwd, "\n"), "\n") - if len(fwdLines) != 2 { - t.Errorf("expected 2 server messages (main.go + initialize), got %d: %v", len(fwdLines), fwdLines) +func TestPipeClientToServer_ResourceReadErrorShape(t *testing.T) { + fwd, errOut := runPipe(t, `{"jsonrpc":"2.0","id":1,"method":"resources/read","params":{"uri":"file:///app/.env"}}`+"\n") + if fwd != "" { + t.Errorf("server should not receive blocked request, got: %s", fwd) + } + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(errOut)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error, got: %s", errOut) } - if errOut == "" { - t.Error("client should receive error for .env read") + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) } } From 96ca43fd66de99cff35c9dc5bafe1380450d6911 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 19:09:40 +0800 Subject: [PATCH 4/6] docs: add MCP gateway documentation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add docs/mcp.md with detailed MCP gateway guide — architecture diagram, supported servers, Claude Desktop config, auto-detect mode, blocking examples, and CLI flags. Update README with MCP Gateway section, grouped documentation table (Setup / Reference), and MCP column in how-it-works blocking matrix. --- README.md | 35 +++++++++++---- docs/how-it-works.md | 53 ++++++++-------------- docs/mcp.md | 103 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 42 deletions(-) create mode 100644 docs/mcp.md diff --git a/README.md b/README.md index a073b37..c9474da 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@

WebsiteQuick Start • + MCPACPProtectionHow It Works • @@ -104,6 +105,16 @@ crust doctor # Diagnose provider endpoints crust stop # Stop crust ``` +## MCP Gateway + +For [MCP](https://modelcontextprotocol.io) servers, Crust intercepts `tools/call` and `resources/read` requests before they reach the server. + +```bash +crust mcp-gateway -- npx -y @modelcontextprotocol/server-filesystem /path/to/dir +``` + +Works with any MCP server. See the [MCP setup guide](docs/mcp.md) for details and examples. + ## ACP Integration For IDEs that use the [Agent Client Protocol](https://agentclientprotocol.com) (ACP), Crust can wrap any ACP agent as a transparent stdio proxy — intercepting file reads, writes, and terminal commands before the IDE executes them. No changes to the agent or IDE required. @@ -170,23 +181,31 @@ Crust inspects tool calls at multiple layers: 1. **Layer 0 (Request Scan)**: Scans tool calls in conversation history before they reach the LLM — catches agents replaying dangerous actions. 2. **Layer 1 (Response Scan)**: Scans tool calls in the LLM's response before they execute — blocks new dangerous actions in real-time. -3. **ACP Mode**: Wraps ACP agents as a stdio proxy, intercepting JSON-RPC file/terminal requests before the IDE executes them. +3. **Stdio Proxy** ([MCP](docs/mcp.md) / [ACP](docs/acp.md)): Wraps MCP servers or ACP agents as a stdio proxy, intercepting security-relevant JSON-RPC messages. -Layers 0–1 apply a [10-step evaluation pipeline](docs/how-it-works.md) — input sanitization, Unicode normalization, obfuscation detection, DLP secret scanning, path-based rules, and fallback content matching — each step in microseconds. ACP mode reuses the same rule engine. +All modes apply a [10-step evaluation pipeline](docs/how-it-works.md) — input sanitization, Unicode normalization, obfuscation detection, DLP secret scanning, path-based rules, and fallback content matching — each step in microseconds. All activity is logged locally to encrypted storage. ## Documentation +**Setup** + +| Guide | Description | +|-------|-------------| +| [Configuration](docs/configuration.md) | Providers, auto mode, block modes | +| [MCP Gateway](docs/mcp.md) | Stdio proxy for [MCP](https://modelcontextprotocol.io) servers — Claude Desktop, custom servers | +| [ACP Integration](docs/acp.md) | Stdio proxy for [ACP](https://agentclientprotocol.com) agents — JetBrains, VS Code | +| [Docker](docs/docker.md) | Dockerfile, docker-compose, container setup | + +**Reference** + | Guide | Description | |-------|-------------| -| [Configuration](docs/configuration.md) | `config.yaml`, providers, auto mode, block modes | | [CLI Reference](docs/cli.md) | Commands, flags, environment variables | -| [How It Works](docs/how-it-works.md) | Architecture, rule schema, protection categories | -| [Docker](docs/docker.md) | Dockerfile, docker-compose, TUI in containers | -| [Shell Parsing](docs/shell-parsing.md) | How Bash commands are parsed for path and command extraction | -| [Migration](docs/migration.md) | Upgrade guides for breaking changes between versions | -| [TUI Design](docs/tui.md) | Terminal UI internals, plain mode, Docker behavior | +| [How It Works](docs/how-it-works.md) | Architecture, rule engine, evaluation pipeline | +| [Shell Parsing](docs/shell-parsing.md) | Bash command parsing for path/command extraction | +| [Migration](docs/migration.md) | Upgrade guides for breaking changes | ## Build from Source diff --git a/docs/how-it-works.md b/docs/how-it-works.md index 27d8788..fa4d73f 100644 --- a/docs/how-it-works.md +++ b/docs/how-it-works.md @@ -28,28 +28,11 @@ Layer 1 Rule Evaluation Order: **Layer 1 (Response Rules):** Scans LLM-generated tool_calls in responses. Fast pattern matching with friendly error messages. -**ACP Mode (`crust acp-wrap`):** For IDEs using the [Agent Client Protocol](https://agentclientprotocol.com), Crust wraps the agent as a transparent stdio proxy. Supports JetBrains IDEs and other ACP-compatible editors. Security-relevant JSON-RPC messages (`fs/read_text_file`, `fs/write_text_file`, `terminal/create`) are intercepted and evaluated by the same rule engine. Blocked requests never reach the IDE — the agent receives a JSON-RPC error response instead. See [ACP setup guide](acp.md) for configuration details. +**[MCP Gateway](mcp.md) (`crust mcp-gateway`):** Wraps [MCP](https://modelcontextprotocol.io) servers as a transparent stdio proxy. Intercepts `tools/call` and `resources/read` requests. Works with any MCP server (filesystem, database, custom). -```text -IDE (JetBrains / any ACP-compatible editor) - │ stdin/stdout (JSON-RPC 2.0) - ▼ -┌──────────────────────────────────────┐ -│ crust acp-wrap │ -│ │ -│ Agent→IDE: inspect each request │ -│ ├─ fs/read_text_file → Evaluate │ -│ ├─ fs/write_text_file → Evaluate │ -│ ├─ terminal/create → Evaluate │ -│ └─ everything else → pass │ -│ │ -│ BLOCKED → JSON-RPC error to agent │ -│ ALLOWED → forward to IDE unchanged │ -└──────────────────────────────────────┘ - │ stdin/stdout - ▼ -Real ACP Agent (Goose, Gemini CLI, etc.) -``` +**[ACP Mode](acp.md) (`crust acp-wrap`):** Wraps [ACP](https://agentclientprotocol.com) agents as a transparent stdio proxy. Intercepts `fs/read_text_file`, `fs/write_text_file`, and `terminal/create` requests. Supports JetBrains IDEs and other ACP-compatible editors. + +**Auto-detect (`crust wrap`):** Inspects both MCP and ACP methods simultaneously. Method names are disjoint — no conflict. --- @@ -91,19 +74,21 @@ Real ACP Agent (Goose, Gemini CLI, etc.) ## When Each Layer Blocks -| Attack | Layer 0 | Layer 1 | ACP Mode | -|--------|---------|---------|----------| -| Bad agent with secrets in history | ✅ Blocked | - | - | -| Poisoned conversation replay | ✅ Blocked | - | - | -| LLM generates `cat .env` | - | ✅ Blocked | - | -| LLM generates `rm -rf /etc` | - | ✅ Blocked | - | -| `$(cat .env)` obfuscation | - | ✅ Blocked | - | -| Symlink bypass | - | ✅ Blocked (composite) | - | -| Leaking real API keys/tokens | - | ✅ Blocked (DLP) | ✅ Blocked (DLP) | -| MCP plugin (e.g. Playwright) | - | ✅ Blocked (content-only) | - | -| ACP agent reads `.env` via IDE | - | - | ✅ Blocked | -| ACP agent reads SSH keys via IDE | - | - | ✅ Blocked | -| ACP agent runs `cat /etc/shadow` | - | - | ✅ Blocked | +| Attack | Layer 0 | Layer 1 | MCP Gateway | ACP Mode | +|--------|---------|---------|-------------|----------| +| Bad agent with secrets in history | ✅ Blocked | - | - | - | +| Poisoned conversation replay | ✅ Blocked | - | - | - | +| LLM generates `cat .env` | - | ✅ Blocked | - | - | +| LLM generates `rm -rf /etc` | - | ✅ Blocked | - | - | +| `$(cat .env)` obfuscation | - | ✅ Blocked | - | - | +| Symlink bypass | - | ✅ Blocked (composite) | - | - | +| Leaking real API keys/tokens | - | ✅ Blocked (DLP) | ✅ Blocked (DLP) | ✅ Blocked (DLP) | +| MCP server reads `.env` | - | - | ✅ Blocked | - | +| MCP server reads SSH keys | - | - | ✅ Blocked | - | +| MCP `resources/read file:///etc/shadow` | - | - | ✅ Blocked | - | +| ACP agent reads `.env` via IDE | - | - | - | ✅ Blocked | +| ACP agent reads SSH keys via IDE | - | - | - | ✅ Blocked | +| ACP agent runs `cat /etc/shadow` | - | - | - | ✅ Blocked | --- diff --git a/docs/mcp.md b/docs/mcp.md new file mode 100644 index 0000000..d0bee53 --- /dev/null +++ b/docs/mcp.md @@ -0,0 +1,103 @@ +# MCP Gateway + +Crust can wrap any [MCP (Model Context Protocol)](https://modelcontextprotocol.io) server as a transparent stdio proxy — intercepting `tools/call` and `resources/read` requests before they reach the server. + +```bash +crust mcp-gateway -- npx -y @modelcontextprotocol/server-filesystem /path/to/dir +``` + +## How It Works + +```text +MCP Client (Claude Desktop, IDE, etc.) + │ stdin/stdout (JSON-RPC 2.0) + ▼ +┌──────────────────────────────────────┐ +│ crust mcp-gateway │ +│ │ +│ Client→Server: inspect each request │ +│ ├─ tools/call → Evaluate │ +│ ├─ resources/read → Evaluate │ +│ └─ everything else → pass │ +│ │ +│ BLOCKED → JSON-RPC error to client │ +│ ALLOWED → forward to server │ +└──────────────────────────────────────┘ + │ stdin/stdout + ▼ +Real MCP Server (filesystem, database, etc.) +``` + +Crust evaluates every `tools/call` and `resources/read` request against the rule engine. Tool arguments (paths, commands, content) are extracted using **shape-based detection** — any tool with a `path` field is treated as file access, regardless of the tool name. This means Crust protects against novel tools without needing explicit configuration. + +Allowed requests pass through byte-for-byte unchanged. Blocked requests receive a JSON-RPC error response with code `-32001` and a `[Crust]`-prefixed message explaining the block. + +## Prerequisites + +1. **Crust** installed and on your `PATH` +2. **An MCP server** — any server that speaks [MCP](https://modelcontextprotocol.io) over stdio + +## Supported MCP Servers + +Any MCP server works. Common examples: + +| Server | Install | Command | +|--------|---------|---------| +| [Filesystem](https://github.com/modelcontextprotocol/servers/tree/main/src/filesystem) | `npm i -g @modelcontextprotocol/server-filesystem` | `npx @modelcontextprotocol/server-filesystem /path` | +| [Everything](https://github.com/modelcontextprotocol/servers/tree/main/src/everything) | `npm i -g @modelcontextprotocol/server-everything` | `npx @modelcontextprotocol/server-everything` | +| [PostgreSQL](https://github.com/modelcontextprotocol/servers/tree/main/src/postgres) | `npm i -g @modelcontextprotocol/server-postgres` | `npx @modelcontextprotocol/server-postgres $DATABASE_URL` | +| Custom server | — | Any command that speaks MCP over stdio | + +## Claude Desktop + +Add Crust as a wrapper in your Claude Desktop MCP config (`~/Library/Application Support/Claude/claude_desktop_config.json` on macOS): + +```json +{ + "mcpServers": { + "filesystem": { + "command": "crust", + "args": ["mcp-gateway", "--", "npx", "-y", "@modelcontextprotocol/server-filesystem", "/Users/you/projects"] + } + } +} +``` + +## Auto-detect Mode + +If you don't know whether a subprocess speaks MCP or ACP, use `crust wrap`: + +```bash +crust wrap -- npx -y @modelcontextprotocol/server-filesystem /path/to/dir +``` + +This inspects both MCP (inbound) and ACP (outbound) methods simultaneously. Since the method names are disjoint, there is no conflict. + +## What Gets Blocked + +The same rules apply as the HTTP gateway and ACP modes. Security-relevant tool calls are evaluated against path rules, DLP patterns, and content matching: + +| Scenario | Result | +|----------|--------| +| `tools/call read_text_file /app/main.go` | Allowed | +| `tools/call read_text_file /app/.env` | Blocked — `.env` files contain secrets | +| `tools/call write_file /app/.env` | Blocked — cannot write to `.env` | +| `tools/call read_text_file ~/.ssh/id_rsa` | Blocked — SSH private keys | +| `resources/read file:///etc/shadow` | Blocked — system auth files | +| `tools/call list_directory /app/src` | Allowed | +| `initialize`, `tools/list`, notifications | Passed through unchanged | + +## CLI Reference + +```bash +crust mcp-gateway [flags] -- [args...] +``` + +| Flag | Default | Description | +|------|---------|-------------| +| `--config` | `~/.crust/config.yaml` | Path to configuration file | +| `--rules-dir` | `~/.crust/rules/` | Directory for custom rules | +| `--log-level` | `info` | Log level (`debug`, `info`, `warn`, `error`) | +| `--disable-builtin` | `false` | Disable built-in security rules | + +Logs go to stderr so they don't interfere with the JSON-RPC stdio stream. From 53a1c3177df71b34d813267b3c1896360d35cbb3 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 19:36:35 +0800 Subject: [PATCH 5/6] feat: add response DLP scanning and bidirectional stdio proxy inspection Close security gaps where server responses could leak secrets undetected: - Add Engine.ScanDLP() for standalone DLP scanning of arbitrary content - PipeInspect now scans response Result fields for DLP patterns before forwarding to the client (errors sent to client via fwdWriter) - MCP gateway outbound now uses MCPMethodToToolCall (was passthrough) - New BothMethodToToolCall combined converter for crust wrap outbound inspects both MCP and ACP methods in both directions - E2E tests against real @modelcontextprotocol/server-filesystem verify that files with embedded AWS keys and GitHub tokens are blocked by response DLP even when the file path passes inbound rules --- README.md | 2 +- docs/how-it-works.md | 14 ++- docs/mcp.md | 31 +++++-- internal/autowrap/convert.go | 19 ++++ internal/autowrap/run.go | 3 +- internal/autowrap/run_test.go | 67 ++++++++++++-- internal/jsonrpc/pipe.go | 12 +++ internal/mcpgateway/e2e_test.go | 158 +++++++++++++++++++++++++++++++- internal/mcpgateway/run.go | 2 +- internal/rules/engine.go | 37 ++++++++ 10 files changed, 316 insertions(+), 29 deletions(-) create mode 100644 internal/autowrap/convert.go diff --git a/README.md b/README.md index c9474da..935ad2c 100644 --- a/README.md +++ b/README.md @@ -181,7 +181,7 @@ Crust inspects tool calls at multiple layers: 1. **Layer 0 (Request Scan)**: Scans tool calls in conversation history before they reach the LLM — catches agents replaying dangerous actions. 2. **Layer 1 (Response Scan)**: Scans tool calls in the LLM's response before they execute — blocks new dangerous actions in real-time. -3. **Stdio Proxy** ([MCP](docs/mcp.md) / [ACP](docs/acp.md)): Wraps MCP servers or ACP agents as a stdio proxy, intercepting security-relevant JSON-RPC messages. +3. **Stdio Proxy** ([MCP](docs/mcp.md) / [ACP](docs/acp.md)): Wraps MCP servers or ACP agents as a stdio proxy, intercepting security-relevant JSON-RPC messages in both directions — including DLP scanning of server responses for leaked secrets. All modes apply a [10-step evaluation pipeline](docs/how-it-works.md) — input sanitization, Unicode normalization, obfuscation detection, DLP secret scanning, path-based rules, and fallback content matching — each step in microseconds. diff --git a/docs/how-it-works.md b/docs/how-it-works.md index fa4d73f..5f0e793 100644 --- a/docs/how-it-works.md +++ b/docs/how-it-works.md @@ -28,11 +28,11 @@ Layer 1 Rule Evaluation Order: **Layer 1 (Response Rules):** Scans LLM-generated tool_calls in responses. Fast pattern matching with friendly error messages. -**[MCP Gateway](mcp.md) (`crust mcp-gateway`):** Wraps [MCP](https://modelcontextprotocol.io) servers as a transparent stdio proxy. Intercepts `tools/call` and `resources/read` requests. Works with any MCP server (filesystem, database, custom). +**[MCP Gateway](mcp.md) (`crust mcp-gateway`):** Wraps [MCP](https://modelcontextprotocol.io) servers as a transparent stdio proxy. Inspects both directions — client→server requests (`tools/call`, `resources/read`) and server→client responses (DLP secret scanning). Works with any MCP server (filesystem, database, custom). **[ACP Mode](acp.md) (`crust acp-wrap`):** Wraps [ACP](https://agentclientprotocol.com) agents as a transparent stdio proxy. Intercepts `fs/read_text_file`, `fs/write_text_file`, and `terminal/create` requests. Supports JetBrains IDEs and other ACP-compatible editors. -**Auto-detect (`crust wrap`):** Inspects both MCP and ACP methods simultaneously. Method names are disjoint — no conflict. +**Auto-detect (`crust wrap`):** Inspects both MCP and ACP methods in both directions. Response DLP scans all server responses for leaked secrets. Method names are disjoint — no conflict. --- @@ -83,9 +83,11 @@ Layer 1 Rule Evaluation Order: | `$(cat .env)` obfuscation | - | ✅ Blocked | - | - | | Symlink bypass | - | ✅ Blocked (composite) | - | - | | Leaking real API keys/tokens | - | ✅ Blocked (DLP) | ✅ Blocked (DLP) | ✅ Blocked (DLP) | -| MCP server reads `.env` | - | - | ✅ Blocked | - | -| MCP server reads SSH keys | - | - | ✅ Blocked | - | -| MCP `resources/read file:///etc/shadow` | - | - | ✅ Blocked | - | +| MCP client reads `.env` | - | - | ✅ Blocked (inbound) | - | +| MCP client reads SSH keys | - | - | ✅ Blocked (inbound) | - | +| MCP `resources/read file:///etc/shadow` | - | - | ✅ Blocked (inbound) | - | +| MCP server returns API keys in results | - | - | ✅ Blocked (response DLP) | - | +| MCP server returns tokens in results | - | - | ✅ Blocked (response DLP) | - | | ACP agent reads `.env` via IDE | - | - | - | ✅ Blocked | | ACP agent reads SSH keys via IDE | - | - | - | ✅ Blocked | | ACP agent runs `cat /etc/shadow` | - | - | - | ✅ Blocked | @@ -96,6 +98,8 @@ Layer 1 Rule Evaluation Order: Step 7 of the evaluation pipeline runs hardcoded DLP (Data Loss Prevention) patterns against all operations. These patterns detect real API keys and tokens by their format, regardless of file path or tool name. +In stdio proxy modes (MCP Gateway, ACP Wrap, Auto-detect), DLP also scans **server/agent responses** before they reach the client. This catches secrets leaked by the subprocess — for example, an MCP server returning file content that contains an AWS access key. The response is replaced with a JSON-RPC error so the secret never reaches the client. + | Provider | Pattern | |----------|---------| | AWS | Access key IDs (`AKIA...`, `ASIA...`) | diff --git a/docs/mcp.md b/docs/mcp.md index d0bee53..be19c4a 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -1,6 +1,6 @@ # MCP Gateway -Crust can wrap any [MCP (Model Context Protocol)](https://modelcontextprotocol.io) server as a transparent stdio proxy — intercepting `tools/call` and `resources/read` requests before they reach the server. +Crust can wrap any [MCP (Model Context Protocol)](https://modelcontextprotocol.io) server as a transparent stdio proxy — intercepting requests in both directions and scanning responses for leaked secrets. ```bash crust mcp-gateway -- npx -y @modelcontextprotocol/server-filesystem /path/to/dir @@ -15,22 +15,30 @@ MCP Client (Claude Desktop, IDE, etc.) ┌──────────────────────────────────────┐ │ crust mcp-gateway │ │ │ -│ Client→Server: inspect each request │ +│ Client→Server (inbound): │ │ ├─ tools/call → Evaluate │ │ ├─ resources/read → Evaluate │ │ └─ everything else → pass │ │ │ -│ BLOCKED → JSON-RPC error to client │ -│ ALLOWED → forward to server │ +│ Server→Client (outbound): │ +│ ├─ responses → DLP scan │ +│ ├─ server requests → Evaluate │ +│ └─ everything else → pass │ +│ │ +│ BLOCKED → JSON-RPC error │ +│ ALLOWED → forward unchanged │ └──────────────────────────────────────┘ │ stdin/stdout ▼ Real MCP Server (filesystem, database, etc.) ``` -Crust evaluates every `tools/call` and `resources/read` request against the rule engine. Tool arguments (paths, commands, content) are extracted using **shape-based detection** — any tool with a `path` field is treated as file access, regardless of the tool name. This means Crust protects against novel tools without needing explicit configuration. +Crust inspects both directions: -Allowed requests pass through byte-for-byte unchanged. Blocked requests receive a JSON-RPC error response with code `-32001` and a `[Crust]`-prefixed message explaining the block. +- **Inbound (Client→Server):** Evaluates `tools/call` and `resources/read` requests against path rules, DLP patterns, and content matching. Tool arguments are extracted using **shape-based detection** — any tool with a `path` field is treated as file access, regardless of the tool name. +- **Outbound (Server→Client):** Scans server responses for leaked secrets using DLP patterns. If a server returns file content containing API keys or tokens, the response is blocked before it reaches the client. + +Allowed messages pass through byte-for-byte unchanged. Blocked messages receive a JSON-RPC error response with code `-32001` and a `[Crust]`-prefixed message explaining the block. ## Prerequisites @@ -77,6 +85,8 @@ This inspects both MCP (inbound) and ACP (outbound) methods simultaneously. Sinc The same rules apply as the HTTP gateway and ACP modes. Security-relevant tool calls are evaluated against path rules, DLP patterns, and content matching: +**Inbound (Client→Server):** + | Scenario | Result | |----------|--------| | `tools/call read_text_file /app/main.go` | Allowed | @@ -87,6 +97,15 @@ The same rules apply as the HTTP gateway and ACP modes. Security-relevant tool c | `tools/call list_directory /app/src` | Allowed | | `initialize`, `tools/list`, notifications | Passed through unchanged | +**Outbound (Server→Client) — Response DLP:** + +| Scenario | Result | +|----------|--------| +| Server returns file content with no secrets | Allowed | +| Server returns content with AWS key (`AKIA...`) | Blocked — DLP detects API key | +| Server returns content with GitHub token (`ghp_...`) | Blocked — DLP detects token | +| Server returns content with Stripe key (`sk_live_...`) | Blocked — DLP detects secret | + ## CLI Reference ```bash diff --git a/internal/autowrap/convert.go b/internal/autowrap/convert.go new file mode 100644 index 0000000..5bf3cc7 --- /dev/null +++ b/internal/autowrap/convert.go @@ -0,0 +1,19 @@ +package autowrap + +import ( + "encoding/json" + + "github.com/BakeLens/crust/internal/acpwrap" + "github.com/BakeLens/crust/internal/mcpgateway" + "github.com/BakeLens/crust/internal/rules" +) + +// BothMethodToToolCall tries MCP conversion first, then ACP. This is used for +// the outbound direction in crust wrap — a malicious subprocess could speak +// either protocol, so we check both. Method names are disjoint (no conflict). +func BothMethodToToolCall(method string, params json.RawMessage) (*rules.ToolCall, error) { + if tc, err := mcpgateway.MCPMethodToToolCall(method, params); tc != nil || err != nil { + return tc, err + } + return acpwrap.ACPMethodToToolCall(method, params) +} diff --git a/internal/autowrap/run.go b/internal/autowrap/run.go index 3ef80f3..5450077 100644 --- a/internal/autowrap/run.go +++ b/internal/autowrap/run.go @@ -11,7 +11,6 @@ package autowrap import ( "os" - "github.com/BakeLens/crust/internal/acpwrap" "github.com/BakeLens/crust/internal/jsonrpc" "github.com/BakeLens/crust/internal/logger" "github.com/BakeLens/crust/internal/mcpgateway" @@ -27,7 +26,7 @@ func Run(engine *rules.Engine, cmd []string) int { Log: log, ProcessLabel: "Subprocess", Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "Stdio", Convert: BothMethodToToolCall}, ExtraLogLines: []string{ "Auto-detect mode: inspecting both ACP and MCP methods", }, diff --git a/internal/autowrap/run_test.go b/internal/autowrap/run_test.go index 5e5e099..f7c7ec0 100644 --- a/internal/autowrap/run_test.go +++ b/internal/autowrap/run_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/BakeLens/crust/internal/acpwrap" "github.com/BakeLens/crust/internal/jsonrpc" "github.com/BakeLens/crust/internal/logger" "github.com/BakeLens/crust/internal/mcpgateway" @@ -42,7 +41,7 @@ func runInboundPipe(t *testing.T, input string) (fwd, errOut string) { return fwdBuf.String(), errBuf.String() } -// runOutboundPipe runs PipeInspect with ACPMethodToToolCall (outbound direction). +// runOutboundPipe runs PipeInspect with BothMethodToToolCall (outbound direction). func runOutboundPipe(t *testing.T, input string) (fwd, errOut string) { t.Helper() engine := newTestEngine(t) @@ -50,7 +49,7 @@ func runOutboundPipe(t *testing.T, input string) (fwd, errOut string) { fwdWriter := jsonrpc.NewLockedWriter(&fwdBuf) errWriter := jsonrpc.NewLockedWriter(&errBuf) jsonrpc.PipeInspect(testLog, engine, strings.NewReader(input), - fwdWriter, errWriter, acpwrap.ACPMethodToToolCall, "ACP", "Outbound") + fwdWriter, errWriter, BothMethodToToolCall, "Stdio", "Outbound") return fwdBuf.String(), errBuf.String() } @@ -181,9 +180,10 @@ func TestPipeOutbound_ErrorShape(t *testing.T) { } } -// --- Cross-protocol: MCP is not blocked on outbound, ACP is not blocked on inbound --- +// --- Cross-protocol --- func TestPipeInbound_IgnoresACPMethods(t *testing.T) { + // Inbound uses MCPMethodToToolCall only — ACP methods pass through unexamined. msg := `{"jsonrpc":"2.0","id":1,"method":"fs/read_text_file","params":{"sessionId":"s1","path":"/app/.env"}}` fwd, errOut := runInboundPipe(t, msg+"\n") if fwd != msg+"\n" { @@ -194,14 +194,61 @@ func TestPipeInbound_IgnoresACPMethods(t *testing.T) { } } -func TestPipeOutbound_IgnoresMCPMethods(t *testing.T) { +func TestPipeOutbound_BlocksMCPMethods(t *testing.T) { + // Outbound uses BothMethodToToolCall — MCP methods with sensitive paths are blocked. msg := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"read_file","arguments":{"path":"/app/.env"}}}` fwd, errOut := runOutboundPipe(t, msg+"\n") + if fwd != "" { + t.Errorf("MCP methods with .env should be blocked on outbound, got forwarded: %s", fwd) + } + if errOut == "" { + t.Error("subprocess should receive an error response for blocked MCP method") + } +} + +// --- Response DLP --- + +func TestPipeOutbound_ResponseDLP_BlocksSecrets(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"aws_key", `{"jsonrpc":"2.0","id":1,"result":{"content":"key=AKIAIOSFODNN7EXAMPLE"}}`}, + {"github_token", `{"jsonrpc":"2.0","id":2,"result":{"text":"ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm"}}`}, + {"openai_key", `{"jsonrpc":"2.0","id":3,"result":{"config":"sk-proj-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"}}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, _ := runOutboundPipe(t, tt.msg+"\n") + // Response should NOT be forwarded (DLP blocks it) + if strings.Contains(fwd, "AKIA") || strings.Contains(fwd, "ghp_") || strings.Contains(fwd, "sk-proj-") { + t.Errorf("response with secret should be blocked by DLP, got forwarded: %s", fwd) + } + }) + } +} + +func TestPipeOutbound_ResponseDLP_PassesClean(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"result":{"content":"safe data, no secrets here"}}` + fwd, _ := runOutboundPipe(t, msg+"\n") if fwd != msg+"\n" { - t.Errorf("MCP methods should pass through in outbound direction\ngot: %q\nwant: %q", fwd, msg+"\n") + t.Errorf("clean response should pass through\ngot: %q\nwant: %q", fwd, msg+"\n") } - if errOut != "" { - t.Errorf("should not generate errors for MCP methods in outbound direction, got: %s", errOut) +} + +func TestPipeOutbound_ResponseDLP_ErrorResponseShape(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"result":{"content":"key=AKIAIOSFODNN7EXAMPLE"}}` + fwd, _ := runOutboundPipe(t, msg+"\n") + // fwd should contain a JSON-RPC error (sent to client via fwdWriter) + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(fwd)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error in fwd, got: %q", fwd) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) } } @@ -238,7 +285,7 @@ func TestRunProxy(t *testing.T) { Log: testLog, ProcessLabel: "Subprocess", Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "Stdio", Convert: BothMethodToToolCall}, }) }() @@ -266,7 +313,7 @@ func TestRunProxy(t *testing.T) { Log: testLog, ProcessLabel: "Subprocess", Inbound: jsonrpc.PipeConfig{Label: "Inbound", Protocol: "MCP", Convert: mcpgateway.MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "ACP", Convert: acpwrap.ACPMethodToToolCall}, + Outbound: jsonrpc.PipeConfig{Label: "Outbound", Protocol: "Stdio", Convert: BothMethodToToolCall}, }) }() diff --git a/internal/jsonrpc/pipe.go b/internal/jsonrpc/pipe.go index 1f5400f..d0962d4 100644 --- a/internal/jsonrpc/pipe.go +++ b/internal/jsonrpc/pipe.go @@ -77,6 +77,18 @@ func PipeInspect(log *logger.Logger, engine *rules.Engine, src io.Reader, } if !msg.IsRequest() { + // Response DLP: scan successful responses for leaked secrets. + // Errors go to fwdWriter (client) because the client is waiting + // for this response — the server doesn't need to know. + if len(msg.Result) > 0 { + if dlpResult := engine.ScanDLP(string(msg.Result)); dlpResult != nil { + log.Warn("Blocked %s response (DLP): rule=%s message=%s", + protocol, dlpResult.RuleName, dlpResult.Message) + SendBlockError(log, fwdWriter, msg.ID, + fmt.Sprintf("[Crust] Blocked by rule %q: %s", dlpResult.RuleName, dlpResult.Message)) + continue + } + } if err := fwdWriter.WriteLine(line); err != nil { log.Debug("%s write error: %v", label, err) return diff --git a/internal/mcpgateway/e2e_test.go b/internal/mcpgateway/e2e_test.go index 623ce24..acf5610 100644 --- a/internal/mcpgateway/e2e_test.go +++ b/internal/mcpgateway/e2e_test.go @@ -40,11 +40,20 @@ func setupTestDir(t *testing.T) string { os.MkdirAll(filepath.Join(dir, "subdir"), 0o755) os.WriteFile(filepath.Join(dir, "subdir", "code.go"), []byte("package main"), 0o644) - // Sensitive files (should be blocked by Crust) + // Sensitive files (should be blocked by Crust path rules) os.WriteFile(filepath.Join(dir, ".env"), []byte("SECRET_KEY=sk-1234"), 0o644) os.MkdirAll(filepath.Join(dir, ".ssh"), 0o700) os.WriteFile(filepath.Join(dir, ".ssh", "id_rsa"), []byte("fake-private-key"), 0o600) + // Files with embedded secrets (should be blocked by response DLP) + // These files have innocent names but contain real API key patterns. + os.WriteFile(filepath.Join(dir, "config.txt"), + []byte("# App config\nAWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE\nREGION=us-east-1"), 0o644) + os.WriteFile(filepath.Join(dir, "tokens.txt"), + []byte("github_token=ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm\n"), 0o644) + os.WriteFile(filepath.Join(dir, "notes.txt"), + []byte("TODO: refactor auth module\nno secrets here"), 0o644) + return dir } @@ -77,7 +86,7 @@ func runMCPE2E(t *testing.T, dir string, messages []string) []e2eResponse { Log: testLog, ProcessLabel: "MCP server", Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client", Protocol: "MCP", Convert: MCPMethodToToolCall}, }) }() @@ -333,11 +342,11 @@ func TestE2E_MixedStream(t *testing.T) { messages := append(initMessages(), // id=2: tools/list (allowed — not tools/call) `{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}`, - // id=3: read .env (BLOCKED) + // id=3: read .env (BLOCKED by inbound path rules) fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/.env"}}}`, dir), // id=4: read safe.txt (allowed) fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/safe.txt"}}}`, dir), - // id=5: write .env (BLOCKED) + // id=5: write .env (BLOCKED by inbound path rules) fmt.Sprintf(`{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"write_file","arguments":{"path":"%s/.env","content":"STOLEN"}}}`, dir), ) responses := runMCPE2E(t, dir, messages) @@ -379,3 +388,144 @@ func TestE2E_MixedStream(t *testing.T) { t.Errorf("write .env error code = %d, want %d", r.Error.Code, jsonrpc.BlockedError) } } + +// --- Response DLP E2E Tests --- +// These test the outbound direction: the REAL MCP server reads files with +// innocent names but secret content. Crust's response DLP blocks the response +// before it reaches the client. + +func TestE2E_ResponseDLP_AWSKey(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + // config.txt contains an AWS access key but is NOT a .env file + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/config.txt"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for config.txt read (id=3)") + } + // Response DLP should block: the server returned the file content which contains an AWS key + if resp.Error == nil { + t.Fatalf("expected DLP block for AWS key in config.txt, got success: %s", string(resp.Result)) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "[Crust]") { + t.Errorf("error message missing [Crust]: %s", resp.Error.Message) + } + if !strings.Contains(resp.Error.Message, "AWS") { + t.Errorf("error message should mention AWS: %s", resp.Error.Message) + } +} + +func TestE2E_ResponseDLP_GitHubToken(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + // tokens.txt contains a GitHub personal access token + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/tokens.txt"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for tokens.txt read (id=3)") + } + if resp.Error == nil { + t.Fatalf("expected DLP block for GitHub token in tokens.txt, got success: %s", string(resp.Result)) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + if !strings.Contains(resp.Error.Message, "GitHub") { + t.Errorf("error message should mention GitHub: %s", resp.Error.Message) + } +} + +func TestE2E_ResponseDLP_CleanFile(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + // notes.txt has no secrets — should pass through + messages := append(initMessages(), + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/notes.txt"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + resp := findByID(responses, 3) + if resp == nil { + t.Fatal("no response for notes.txt read (id=3)") + } + if resp.Error != nil { + t.Fatalf("notes.txt should pass DLP (no secrets), got error: code=%d msg=%s", resp.Error.Code, resp.Error.Message) + } + if !strings.Contains(string(resp.Result), "no secrets here") { + t.Errorf("expected file content in response, got: %s", string(resp.Result)) + } +} + +func TestE2E_ResponseDLP_MixedStream(t *testing.T) { + skipE2E(t) + dir := setupTestDir(t) + + messages := append(initMessages(), + // id=2: read clean file (ALLOWED — passes both inbound and response DLP) + fmt.Sprintf(`{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/notes.txt"}}}`, dir), + // id=3: read file with AWS key (BLOCKED by response DLP) + fmt.Sprintf(`{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/config.txt"}}}`, dir), + // id=4: read .env (BLOCKED by inbound path rules — never reaches server) + fmt.Sprintf(`{"jsonrpc":"2.0","id":4,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/.env"}}}`, dir), + // id=5: read file with GitHub token (BLOCKED by response DLP) + fmt.Sprintf(`{"jsonrpc":"2.0","id":5,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/tokens.txt"}}}`, dir), + // id=6: read safe file (ALLOWED) + fmt.Sprintf(`{"jsonrpc":"2.0","id":6,"method":"tools/call","params":{"name":"read_text_file","arguments":{"path":"%s/safe.txt"}}}`, dir), + ) + responses := runMCPE2E(t, dir, messages) + + // id=2: clean file — should pass + if r := findByID(responses, 2); r == nil { + t.Error("expected response for notes.txt (id=2)") + } else if r.Error != nil { + t.Errorf("notes.txt (id=2) should pass, got error: %s", r.Error.Message) + } + + // id=3: AWS key in response — blocked by response DLP + if r := findByID(responses, 3); r == nil { + t.Error("expected response for config.txt (id=3)") + } else if r.Error == nil { + t.Error("config.txt (id=3) should be blocked by response DLP") + } else if r.Error.Code != jsonrpc.BlockedError { + t.Errorf("config.txt error code = %d, want %d", r.Error.Code, jsonrpc.BlockedError) + } + + // id=4: .env — blocked by inbound path rules + if r := findByID(responses, 4); r == nil { + t.Error("expected response for .env (id=4)") + } else if r.Error == nil { + t.Error(".env (id=4) should be blocked by inbound rules") + } + + // id=5: GitHub token in response — blocked by response DLP + if r := findByID(responses, 5); r == nil { + t.Error("expected response for tokens.txt (id=5)") + } else if r.Error == nil { + t.Error("tokens.txt (id=5) should be blocked by response DLP") + } else if r.Error.Code != jsonrpc.BlockedError { + t.Errorf("tokens.txt error code = %d, want %d", r.Error.Code, jsonrpc.BlockedError) + } + + // id=6: safe file — should pass + if r := findByID(responses, 6); r == nil { + t.Error("expected response for safe.txt (id=6)") + } else if r.Error != nil { + t.Errorf("safe.txt (id=6) should pass, got error: %s", r.Error.Message) + } else if !strings.Contains(string(r.Result), "hello world") { + t.Errorf("safe.txt (id=6) missing content, got: %s", string(r.Result)) + } +} diff --git a/internal/mcpgateway/run.go b/internal/mcpgateway/run.go index c688bf9..66c02dc 100644 --- a/internal/mcpgateway/run.go +++ b/internal/mcpgateway/run.go @@ -17,6 +17,6 @@ func Run(engine *rules.Engine, serverCmd []string) int { Log: log, ProcessLabel: "MCP server", Inbound: jsonrpc.PipeConfig{Label: "Client->Server", Protocol: "MCP", Convert: MCPMethodToToolCall}, - Outbound: jsonrpc.PipeConfig{Label: "Server->Client"}, + Outbound: jsonrpc.PipeConfig{Label: "Server->Client", Protocol: "MCP", Convert: MCPMethodToToolCall}, }) } diff --git a/internal/rules/engine.go b/internal/rules/engine.go index e24bf16..66c794f 100644 --- a/internal/rules/engine.go +++ b/internal/rules/engine.go @@ -793,6 +793,43 @@ func (e *Engine) GetLoader() *Loader { return e.loader } +// ScanDLP runs DLP (Data Loss Prevention) patterns against content. +// Returns a non-nil MatchResult if a secret is detected, nil if clean. +// This is used by PipeInspect to scan server responses for leaked secrets. +func (e *Engine) ScanDLP(content string) *MatchResult { + if content == "" { + return nil + } + // Tier 1: hardcoded patterns (fast, always available) + for _, pat := range dlpPatterns { + if pat.re.MatchString(content) { + return &MatchResult{ + Matched: true, + RuleName: pat.name, + Severity: SeverityCritical, + Action: ActionBlock, + Message: pat.message, + } + } + } + // Tier 2: gitleaks (if available) + if findings := e.dlpScanner.Scan(content); len(findings) > 0 { + f := findings[0] + msg := "Blocked secret — " + f.Description + if len(findings) > 1 { + msg += fmt.Sprintf(" (and %d more)", len(findings)-1) + } + return &MatchResult{ + Matched: true, + RuleName: "builtin:dlp-gitleaks-" + f.RuleID, + Severity: SeverityHigh, + Action: ActionBlock, + Message: msg, + } + } + return nil +} + // RuleValidationResult holds per-rule validation results. type RuleValidationResult struct { Name string `json:"name"` From 6004af5e7a3582073344f3ba907fb7249b93cad7 Mon Sep 17 00:00:00 2001 From: cyy Date: Fri, 27 Feb 2026 19:48:20 +0800 Subject: [PATCH 6/6] feat: add DLP scanning for error response field Scan JSON-RPC error field for leaked secrets alongside the existing result field scanning. A malicious server could embed API keys in error messages to exfiltrate them. --- internal/autowrap/run_test.go | 34 ++++++++++++++++++++++++++++++++++ internal/jsonrpc/pipe.go | 11 ++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/internal/autowrap/run_test.go b/internal/autowrap/run_test.go index f7c7ec0..cc69961 100644 --- a/internal/autowrap/run_test.go +++ b/internal/autowrap/run_test.go @@ -236,6 +236,40 @@ func TestPipeOutbound_ResponseDLP_PassesClean(t *testing.T) { } } +func TestPipeOutbound_ResponseDLP_BlocksErrorFieldSecrets(t *testing.T) { + tests := []struct { + name string + msg string + }{ + {"aws_key_in_error", `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"failed to read config: AKIAIOSFODNN7EXAMPLE"}}`}, + {"github_token_in_error", `{"jsonrpc":"2.0","id":2,"error":{"code":-32000,"message":"auth failed: ghp_ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklm"}}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fwd, _ := runOutboundPipe(t, tt.msg+"\n") + if strings.Contains(fwd, "AKIA") || strings.Contains(fwd, "ghp_") { + t.Errorf("error response with secret should be blocked by DLP, got forwarded: %s", fwd) + } + // fwd should contain a replacement JSON-RPC error from Crust + var resp jsonrpc.ErrorResponse + if err := json.Unmarshal(bytes.TrimSpace([]byte(fwd)), &resp); err != nil { + t.Fatalf("expected JSON-RPC error in fwd, got: %q", fwd) + } + if resp.Error.Code != jsonrpc.BlockedError { + t.Errorf("error code = %d, want %d", resp.Error.Code, jsonrpc.BlockedError) + } + }) + } +} + +func TestPipeOutbound_ResponseDLP_PassesCleanError(t *testing.T) { + msg := `{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"file not found"}}` + fwd, _ := runOutboundPipe(t, msg+"\n") + if fwd != msg+"\n" { + t.Errorf("clean error response should pass through\ngot: %q\nwant: %q", fwd, msg+"\n") + } +} + func TestPipeOutbound_ResponseDLP_ErrorResponseShape(t *testing.T) { msg := `{"jsonrpc":"2.0","id":1,"result":{"content":"key=AKIAIOSFODNN7EXAMPLE"}}` fwd, _ := runOutboundPipe(t, msg+"\n") diff --git a/internal/jsonrpc/pipe.go b/internal/jsonrpc/pipe.go index d0962d4..f40ebdf 100644 --- a/internal/jsonrpc/pipe.go +++ b/internal/jsonrpc/pipe.go @@ -77,7 +77,7 @@ func PipeInspect(log *logger.Logger, engine *rules.Engine, src io.Reader, } if !msg.IsRequest() { - // Response DLP: scan successful responses for leaked secrets. + // Response DLP: scan responses for leaked secrets. // Errors go to fwdWriter (client) because the client is waiting // for this response — the server doesn't need to know. if len(msg.Result) > 0 { @@ -89,6 +89,15 @@ func PipeInspect(log *logger.Logger, engine *rules.Engine, src io.Reader, continue } } + if len(msg.Error) > 0 { + if dlpResult := engine.ScanDLP(string(msg.Error)); dlpResult != nil { + log.Warn("Blocked %s error response (DLP): rule=%s message=%s", + protocol, dlpResult.RuleName, dlpResult.Message) + SendBlockError(log, fwdWriter, msg.ID, + fmt.Sprintf("[Crust] Blocked by rule %q: %s", dlpResult.RuleName, dlpResult.Message)) + continue + } + } if err := fwdWriter.WriteLine(line); err != nil { log.Debug("%s write error: %v", label, err) return