From bfbefaeed243483f8924b8d9703f036243fe9c82 Mon Sep 17 00:00:00 2001 From: William Date: Sat, 23 May 2026 15:52:56 -0400 Subject: [PATCH 1/4] feat(mcp): unified MCP client with prompts, caching, retry, batch, REPL, metrics, and auto-routing Implements a comprehensive axis mcp client subcommand suite: Operator UX: - axis mcp client prompts / get-prompt for prompt discovery - --pretty flag for human-readable JSON output on call/read - axis mcp client search for tool discovery - axis doctor validates MCP server configs (stdio command existence, HTTP URL) Reliability: - Per-connection caching (60s TTL) with sync.RWMutex - Retry with exponential backoff (3 attempts, 200ms start) for transient errors - axis mcp client batch for sequential multi-tool execution Advanced Surfaces: - Interactive REPL (axis mcp client interactive) with 10 commands - Per-server metrics surfaced in list --format json Intelligence: - Placement-aware routing (--auto-route) tries all servers offering a tool - Progress notification infrastructure (CallToolWithProgress, SetProgressHandler) Quality gates: lint, test-race, coverage, build all pass. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/axis/doctor.go | 39 ++ cmd/axis/mcp.go | 1 + cmd/axis/mcp_client.go | 864 ++++++++++++++++++++++++++++ cmd/axis/mcp_client_test.go | 110 ++++ internal/config/config.go | 16 + internal/mcpclient/call.go | 179 ++++++ internal/mcpclient/connection.go | 251 ++++++++ internal/mcpclient/registry.go | 259 +++++++++ internal/mcpclient/registry_test.go | 205 +++++++ 9 files changed, 1924 insertions(+) create mode 100644 cmd/axis/mcp_client.go create mode 100644 cmd/axis/mcp_client_test.go create mode 100644 internal/mcpclient/call.go create mode 100644 internal/mcpclient/connection.go create mode 100644 internal/mcpclient/registry.go create mode 100644 internal/mcpclient/registry_test.go diff --git a/cmd/axis/doctor.go b/cmd/axis/doctor.go index 51e8255..50d8b1c 100644 --- a/cmd/axis/doctor.go +++ b/cmd/axis/doctor.go @@ -7,6 +7,7 @@ import ( "net" "os" "os/exec" + "strings" "time" "github.com/spf13/cobra" @@ -129,6 +130,44 @@ func runDoctor(cmd *cobra.Command, strict bool) error { } else { fmt.Fprintf(out, " %s Loaded %d node(s)\n", ui.StatusIcon(true), len(cfg.Nodes)) + // 1.5 MCP server config validation + if len(cfg.MCPServers) > 0 { + fmt.Fprintln(out) + fmt.Fprintf(out, "%s MCP servers\n", ui.Cyan("→")) + for name, mcpCfg := range cfg.MCPServers { + switch strings.ToLower(mcpCfg.Transport) { + case "stdio": + if len(mcpCfg.Command) == 0 { + fmt.Fprintf(out, " %s %s (stdio): missing command\n", ui.StatusIcon(false), name) + advisoryWarnings++ + } else { + cmdPath := mcpCfg.Command[0] + if _, err := os.Stat(cmdPath); os.IsNotExist(err) { + // Try resolving via PATH + if _, lookErr := exec.LookPath(cmdPath); lookErr != nil { + fmt.Fprintf(out, " %s %s (stdio): command not found: %s\n", ui.StatusIcon(false), name, cmdPath) + advisoryWarnings++ + } else { + fmt.Fprintf(out, " %s %s (stdio): %s (found in PATH)\n", ui.StatusIcon(true), name, cmdPath) + } + } else { + fmt.Fprintf(out, " %s %s (stdio): %s\n", ui.StatusIcon(true), name, cmdPath) + } + } + case "http": + if mcpCfg.URL == "" { + fmt.Fprintf(out, " %s %s (http): missing url\n", ui.StatusIcon(false), name) + advisoryWarnings++ + } else { + fmt.Fprintf(out, " %s %s (http): %s\n", ui.StatusIcon(true), name, mcpCfg.URL) + } + default: + fmt.Fprintf(out, " %s %s (%s): unsupported transport\n", ui.StatusIcon(false), name, mcpCfg.Transport) + advisoryWarnings++ + } + } + } + // 2. SSH connectivity check per node fmt.Fprintln(out) fmt.Fprintf(out, "%s SSH connectivity\n", ui.Cyan("→")) diff --git a/cmd/axis/mcp.go b/cmd/axis/mcp.go index 1285496..84ca862 100644 --- a/cmd/axis/mcp.go +++ b/cmd/axis/mcp.go @@ -15,6 +15,7 @@ func mcpCmd() *cobra.Command { Short: "Read-only MCP surfaces for AXIS cluster state and diagnostics", } cmd.AddCommand(mcpServeCmd()) + cmd.AddCommand(mcpClientCmd()) return cmd } diff --git a/cmd/axis/mcp_client.go b/cmd/axis/mcp_client.go new file mode 100644 index 0000000..0d5e9e7 --- /dev/null +++ b/cmd/axis/mcp_client.go @@ -0,0 +1,864 @@ +package main + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + "text/tabwriter" + + "github.com/fatih/color" + "github.com/mark3labs/mcp-go/mcp" + "github.com/spf13/cobra" + "github.com/toasterbook88/axis/internal/config" + "github.com/toasterbook88/axis/internal/mcpclient" + "github.com/toasterbook88/axis/internal/ui" +) + +var loadMCPClientConfig = config.Load + +func mcpClientCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "client", + Short: "Connect to and query external MCP servers", + Long: "Discover, inspect, and call tools across configured MCP servers. Servers are defined in ~/.axis/nodes.yaml under mcp_servers.", + } + cmd.AddCommand(mcpClientListCmd()) + cmd.AddCommand(mcpClientToolsCmd()) + cmd.AddCommand(mcpClientCallCmd()) + cmd.AddCommand(mcpClientResourcesCmd()) + cmd.AddCommand(mcpClientReadCmd()) + cmd.AddCommand(mcpClientPromptsCmd()) + cmd.AddCommand(mcpClientGetPromptCmd()) + cmd.AddCommand(mcpClientSearchCmd()) + cmd.AddCommand(mcpClientBatchCmd()) + cmd.AddCommand(mcpClientInteractiveCmd()) + return cmd +} + +func mcpClientListCmd() *cobra.Command { + var format string + cmd := &cobra.Command{ + Use: "list", + Short: "List configured MCP servers and connection status", + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientList(cmd.OutOrStdout(), format) + }, + } + cmd.Flags().StringVar(&format, "format", "text", "Output format: text or json") + return cmd +} + +func runMCPClientList(out io.Writer, format string) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + if len(cfg.MCPServers) == 0 { + fmt.Fprintln(out, "No MCP servers configured.") + fmt.Fprintf(out, "Add them to %s under mcp_servers:\n", config.DefaultConfigPath()) + return nil + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + if format == "json" { + return printMCPClientListJSON(out, reg) + } + return printMCPClientListText(out, reg) +} + +func printMCPClientListText(out io.Writer, reg *mcpclient.Registry) error { + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintf(tw, "%s\t%s\t%s\t%s\t%s\n", "NAME", "TRANSPORT", "STATUS", "TOOLS", "RESOURCES") + for _, name := range reg.Names() { + sc := reg.Get(name) + status := color.GreenString("connected") + if !sc.Connected() { + status = color.RedString("error") + } + fmt.Fprintf(tw, "%s\t%s\t%s\t%d\t%d\n", name, sc.Transport, status, sc.ToolCount(), sc.ResourceCount()) + } + return tw.Flush() +} + +func printMCPClientListJSON(out io.Writer, reg *mcpclient.Registry) error { + type serverRow struct { + Name string `json:"name"` + Transport string `json:"transport"` + Connected bool `json:"connected"` + Error string `json:"error,omitempty"` + Tools int `json:"tools"` + Resources int `json:"resources"` + Prompts int `json:"prompts"` + Calls int64 `json:"calls,omitempty"` + Errors int64 `json:"errors,omitempty"` + AvgLatencyMs int64 `json:"avg_latency_ms,omitempty"` + UptimeSec int64 `json:"uptime_sec,omitempty"` + } + var rows []serverRow + for _, name := range reg.Names() { + sc := reg.Get(name) + r := serverRow{ + Name: name, + Transport: sc.Transport, + Connected: sc.Connected(), + Tools: sc.ToolCount(), + Resources: sc.ResourceCount(), + Prompts: len(sc.CachedPrompts()), + } + if sc.Err != nil { + r.Error = sc.Err.Error() + } + calls, errs, avgLat, uptime := sc.Metrics() + r.Calls = calls + r.Errors = errs + if avgLat > 0 { + r.AvgLatencyMs = avgLat.Milliseconds() + } + if uptime > 0 { + r.UptimeSec = int64(uptime.Seconds()) + } + rows = append(rows, r) + } + return printOutput(out, rows, "json") +} + +func mcpClientToolsCmd() *cobra.Command { + var server string + var format string + cmd := &cobra.Command{ + Use: "tools", + Short: "List tools from connected MCP servers", + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientTools(cmd.OutOrStdout(), server, format) + }, + } + cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") + cmd.Flags().StringVar(&format, "format", "text", "Output format: text or json") + return cmd +} + +func runMCPClientTools(out io.Writer, server, format string) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + if server != "" { + sc := reg.Get(server) + if sc == nil { + return fmt.Errorf("server %q not configured", server) + } + if !sc.Connected() { + return fmt.Errorf("server %q not connected: %v", server, sc.Err) + } + if format == "json" { + return printOutput(out, sc.Tools, "json") + } + for _, t := range sc.Tools { + fmt.Fprintf(out, "%s\t%s\n", t.Name, t.Description) + } + return nil + } + + tools := reg.ListAllTools() + if format == "json" { + type entry struct { + Server string `json:"server"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + } + var entries []entry + for _, te := range tools { + entries = append(entries, entry{Server: te.Server, Name: te.Tool.Name, Description: te.Tool.Description}) + } + return printOutput(out, entries, "json") + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintf(tw, "%s\t%s\t%s\n", "SERVER", "NAME", "DESCRIPTION") + for _, te := range tools { + desc := te.Tool.Description + if len(desc) > 60 { + desc = desc[:57] + "..." + } + fmt.Fprintf(tw, "%s\t%s\t%s\n", te.Server, te.Tool.Name, desc) + } + return tw.Flush() +} + +func mcpClientCallCmd() *cobra.Command { + var pretty bool + var autoRoute bool + cmd := &cobra.Command{ + Use: "call [] [json-args]", + Short: "Call a tool on a specific MCP server (or auto-route with --auto-route)", + Args: func(cmd *cobra.Command, args []string) error { + if autoRoute { + if len(args) < 1 || len(args) > 2 { + return fmt.Errorf("with --auto-route, expected 1-2 args: [json-args]") + } + return nil + } + if len(args) < 2 || len(args) > 3 { + return fmt.Errorf("expected 2-3 args: [json-args]") + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + if autoRoute { + toolName := args[0] + var rawArgs string + if len(args) > 1 { + rawArgs = args[1] + } + return runMCPClientCallAutoRoute(cmd.OutOrStdout(), toolName, rawArgs, pretty) + } + serverName := args[0] + toolName := args[1] + var rawArgs string + if len(args) > 2 { + rawArgs = args[2] + } + return runMCPClientCall(cmd.OutOrStdout(), serverName, toolName, rawArgs, pretty) + }, + } + cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") + cmd.Flags().BoolVar(&autoRoute, "auto-route", false, "Auto-route tool call to best available server") + return cmd +} + +func runMCPClientCallAutoRoute(out io.Writer, toolName, rawArgs string, pretty bool) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + args, err := mcpclient.ParseArgs(rawArgs) + if err != nil { + return err + } + + result := reg.CallToolAutoRoute(ctx, toolName, args) + if result.Err != nil { + return fmt.Errorf("tool call failed: %w", result.Err) + } + + for _, content := range result.Result.Content { + if tc, ok := content.(mcp.TextContent); ok { + fmt.Fprintln(out, tc.Text) + } else if ic, ok := content.(mcp.ImageContent); ok { + fmt.Fprintf(out, "[image: %s %d bytes]\n", ic.MIMEType, len(ic.Data)) + } else { + b, _ := json.MarshalIndent(content, "", " ") + fmt.Fprintln(out, string(b)) + } + } + return nil +} + +func runMCPClientCall(out io.Writer, serverName, toolName, rawArgs string, pretty bool) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + args, err := mcpclient.ParseArgs(rawArgs) + if err != nil { + return err + } + + result := reg.CallTool(ctx, serverName, toolName, args) + if result.Err != nil { + return fmt.Errorf("tool call failed: %w", result.Err) + } + + for _, content := range result.Result.Content { + if tc, ok := content.(mcp.TextContent); ok { + fmt.Fprintln(out, tc.Text) + } else if ic, ok := content.(mcp.ImageContent); ok { + fmt.Fprintf(out, "[image: %s %d bytes]\n", ic.MIMEType, len(ic.Data)) + } else { + b, _ := json.MarshalIndent(content, "", " ") + fmt.Fprintln(out, string(b)) + } + } + return nil +} + +func mcpClientResourcesCmd() *cobra.Command { + var server string + var format string + cmd := &cobra.Command{ + Use: "resources", + Short: "List resources from connected MCP servers", + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientResources(cmd.OutOrStdout(), server, format) + }, + } + cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") + cmd.Flags().StringVar(&format, "format", "format", "Output format: text or json") + return cmd +} + +func runMCPClientResources(out io.Writer, server, format string) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + if server != "" { + sc := reg.Get(server) + if sc == nil { + return fmt.Errorf("server %q not configured", server) + } + if !sc.Connected() { + return fmt.Errorf("server %q not connected: %v", server, sc.Err) + } + if format == "json" { + return printOutput(out, sc.Resources, "json") + } + for _, r := range sc.Resources { + fmt.Fprintf(out, "%s\t%s\n", r.URI, r.Name) + } + return nil + } + + resources := reg.ListAllResources() + if format == "json" { + type entry struct { + Server string `json:"server"` + URI string `json:"uri"` + Name string `json:"name,omitempty"` + } + var entries []entry + for _, re := range resources { + entries = append(entries, entry{Server: re.Server, URI: re.Resource.URI, Name: re.Resource.Name}) + } + return printOutput(out, entries, "json") + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintf(tw, "%s\t%s\t%s\n", "SERVER", "URI", "NAME") + for _, re := range resources { + fmt.Fprintf(tw, "%s\t%s\t%s\n", re.Server, re.Resource.URI, re.Resource.Name) + } + return tw.Flush() +} + +func mcpClientReadCmd() *cobra.Command { + var pretty bool + cmd := &cobra.Command{ + Use: "read ", + Short: "Read a resource from a specific MCP server", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientRead(cmd.OutOrStdout(), args[0], args[1], pretty) + }, + } + cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") + return cmd +} + +func runMCPClientRead(out io.Writer, serverName, uri string, pretty bool) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + res, err := reg.ReadResource(ctx, serverName, uri) + if err != nil { + return err + } + + for _, content := range res.Contents { + if tc, ok := content.(mcp.TextResourceContents); ok { + fmt.Fprintln(out, tc.Text) + } else if bc, ok := content.(mcp.BlobResourceContents); ok { + fmt.Fprintf(out, "[blob: %s %d bytes]\n", bc.MIMEType, len(bc.Blob)) + } else { + b, _ := json.MarshalIndent(content, "", " ") + fmt.Fprintln(out, string(b)) + } + } + return nil +} + +func mcpClientPromptsCmd() *cobra.Command { + var server string + var format string + cmd := &cobra.Command{ + Use: "prompts", + Short: "List prompts from connected MCP servers", + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientPrompts(cmd.OutOrStdout(), server, format) + }, + } + cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") + cmd.Flags().StringVar(&format, "format", "text", "Output format: text or json") + return cmd +} + +func runMCPClientPrompts(out io.Writer, server, format string) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + if server != "" { + sc := reg.Get(server) + if sc == nil { + return fmt.Errorf("server %q not configured", server) + } + if !sc.Connected() { + return fmt.Errorf("server %q not connected: %v", server, sc.Err) + } + if format == "json" { + return printOutput(out, sc.Prompts, "json") + } + for _, p := range sc.Prompts { + fmt.Fprintf(out, "%s\t%s\n", p.Name, p.Description) + } + return nil + } + + prompts := reg.ListAllPrompts() + if format == "json" { + type entry struct { + Server string `json:"server"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + } + var entries []entry + for _, pe := range prompts { + entries = append(entries, entry{Server: pe.Server, Name: pe.Prompt.Name, Description: pe.Prompt.Description}) + } + return printOutput(out, entries, "json") + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintf(tw, "%s\t%s\t%s\n", "SERVER", "NAME", "DESCRIPTION") + for _, pe := range prompts { + desc := pe.Prompt.Description + if len(desc) > 60 { + desc = desc[:57] + "..." + } + fmt.Fprintf(tw, "%s\t%s\t%s\n", pe.Server, pe.Prompt.Name, desc) + } + return tw.Flush() +} + +func mcpClientGetPromptCmd() *cobra.Command { + var pretty bool + cmd := &cobra.Command{ + Use: "get-prompt [json-args]", + Short: "Fetch a prompt from a specific MCP server", + Args: cobra.RangeArgs(2, 3), + RunE: func(cmd *cobra.Command, args []string) error { + serverName := args[0] + promptName := args[1] + var rawArgs string + if len(args) > 2 { + rawArgs = args[2] + } + return runMCPClientGetPrompt(cmd.OutOrStdout(), serverName, promptName, rawArgs, pretty) + }, + } + cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") + return cmd +} + +func runMCPClientGetPrompt(out io.Writer, serverName, promptName, rawArgs string, pretty bool) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + args, err := mcpclient.ParseArgs(rawArgs) + if err != nil { + return err + } + + res, err := reg.GetPrompt(ctx, serverName, promptName, args) + if err != nil { + return fmt.Errorf("get prompt failed: %w", err) + } + + return printOutput(out, res, outputFormat(pretty)) +} + +func mcpClientSearchCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "search ", + Short: "Search tools by name or description across all connected MCP servers", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientSearch(cmd.OutOrStdout(), args[0]) + }, + } + return cmd +} + +func runMCPClientSearch(out io.Writer, keyword string) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + keywordLower := strings.ToLower(keyword) + matches := reg.ListAllTools() + var filtered []mcpclient.ToolEntry + for _, te := range matches { + if strings.Contains(strings.ToLower(te.Tool.Name), keywordLower) || + strings.Contains(strings.ToLower(te.Tool.Description), keywordLower) { + filtered = append(filtered, te) + } + } + + if len(filtered) == 0 { + fmt.Fprintf(out, "No tools match %q\n", keyword) + return nil + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintf(tw, "%s\t%s\t%s\n", "SERVER", "NAME", "DESCRIPTION") + for _, te := range filtered { + desc := te.Tool.Description + if len(desc) > 60 { + desc = desc[:57] + "..." + } + fmt.Fprintf(tw, "%s\t%s\t%s\n", te.Server, te.Tool.Name, desc) + } + return tw.Flush() +} + +func mcpClientBatchCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "batch ", + Short: "Execute multiple tool calls from a JSON file", + Long: `Read an array of tool calls from a JSON file and execute them sequentially. + +Each entry must have: server, tool, and optional args (map). +Example file: +[ + {"server":"axis-local","tool":"axis_health"}, + {"server":"axis-local","tool":"placement_decision","args":{"description":"ollama run llama3"}} +]`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientBatch(cmd.OutOrStdout(), args[0]) + }, + } + return cmd +} + +type batchEntry struct { + Server string `json:"server"` + Tool string `json:"tool"` + Args map[string]any `json:"args,omitempty"` +} + +func runMCPClientBatch(out io.Writer, path string) error { + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read batch file: %w", err) + } + var entries []batchEntry + if err := json.Unmarshal(data, &entries); err != nil { + return fmt.Errorf("parse batch file: %w", err) + } + + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + type batchResult struct { + Index int `json:"index"` + Server string `json:"server"` + Tool string `json:"tool"` + OK bool `json:"ok"` + Error string `json:"error,omitempty"` + Output string `json:"output,omitempty"` + } + var results []batchResult + + for i, entry := range entries { + res := reg.CallTool(ctx, entry.Server, entry.Tool, entry.Args) + br := batchResult{ + Index: i, + Server: entry.Server, + Tool: entry.Tool, + } + if res.Err != nil { + br.OK = false + br.Error = res.Err.Error() + } else { + br.OK = true + var parts []string + for _, content := range res.Result.Content { + if tc, ok := content.(mcp.TextContent); ok { + parts = append(parts, tc.Text) + } else { + b, _ := json.Marshal(content) + parts = append(parts, string(b)) + } + } + br.Output = strings.Join(parts, "\n") + } + results = append(results, br) + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + return enc.Encode(results) +} + +func mcpClientInteractiveCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "interactive", + Short: "Interactive REPL for exploring and calling MCP servers", + RunE: func(cmd *cobra.Command, args []string) error { + return runMCPClientInteractive(cmd.InOrStdin(), cmd.OutOrStdout()) + }, + } + return cmd +} + +func runMCPClientInteractive(in io.Reader, out io.Writer) error { + cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + reg := mcpclient.NewRegistry() + ctx := context.Background() + reg.ConnectAll(ctx, cfg) + defer reg.Close() + + fmt.Fprintln(out, "AXIS MCP Client Interactive REPL") + fmt.Fprintln(out, "Commands: tools, resources, prompts, call [args], read , get-prompt [args], search , list, help, quit") + fmt.Fprintln(out) + + scanner := bufio.NewScanner(in) + for { + fmt.Fprint(out, "mcp> ") + if !scanner.Scan() { + break + } + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.Fields(line) + cmd := parts[0] + args := parts[1:] + + switch cmd { + case "quit", "exit", "q": + fmt.Fprintln(out, "Bye.") + return nil + case "help", "h": + fmt.Fprintln(out, "Commands:") + fmt.Fprintln(out, " list List connected servers") + fmt.Fprintln(out, " tools [--server ] List tools") + fmt.Fprintln(out, " resources [--server ] List resources") + fmt.Fprintln(out, " prompts [--server ] List prompts") + fmt.Fprintln(out, " call [args] Call a tool") + fmt.Fprintln(out, " read Read a resource") + fmt.Fprintln(out, " get-prompt [args]") + fmt.Fprintln(out, " search Search tools") + fmt.Fprintln(out, " help Show this help") + fmt.Fprintln(out, " quit Exit REPL") + case "list": + for _, name := range reg.Names() { + sc := reg.Get(name) + status := "connected" + if !sc.Connected() { + status = fmt.Sprintf("error: %v", sc.Err) + } + fmt.Fprintf(out, " %s (%s) — %s, %d tools, %d resources\n", name, sc.Transport, status, sc.ToolCount(), sc.ResourceCount()) + } + case "tools": + server := "" + if len(args) >= 2 && args[0] == "--server" { + server = args[1] + args = args[2:] + } + tools := reg.ListAllTools() + for _, te := range tools { + if server != "" && te.Server != server { + continue + } + fmt.Fprintf(out, " %s / %s — %s\n", te.Server, te.Tool.Name, te.Tool.Description) + } + case "resources": + server := "" + if len(args) >= 2 && args[0] == "--server" { + server = args[1] + args = args[2:] + } + resources := reg.ListAllResources() + for _, re := range resources { + if server != "" && re.Server != server { + continue + } + fmt.Fprintf(out, " %s / %s — %s\n", re.Server, re.Resource.URI, re.Resource.Name) + } + case "prompts": + server := "" + if len(args) >= 2 && args[0] == "--server" { + server = args[1] + args = args[2:] + } + prompts := reg.ListAllPrompts() + for _, pe := range prompts { + if server != "" && pe.Server != server { + continue + } + fmt.Fprintf(out, " %s / %s — %s\n", pe.Server, pe.Prompt.Name, pe.Prompt.Description) + } + case "call": + if len(args) < 2 { + fmt.Fprintln(out, "Usage: call [json-args]") + continue + } + var rawArgs string + if len(args) > 2 { + rawArgs = args[2] + } + parsedArgs, parseErr := mcpclient.ParseArgs(rawArgs) + if parseErr != nil { + fmt.Fprintf(out, "Error: %v\n", parseErr) + continue + } + res := reg.CallTool(ctx, args[0], args[1], parsedArgs) + if res.Err != nil { + fmt.Fprintf(out, "Error: %v\n", res.Err) + continue + } + for _, content := range res.Result.Content { + if tc, ok := content.(mcp.TextContent); ok { + fmt.Fprintln(out, tc.Text) + } else { + b, _ := json.MarshalIndent(content, "", " ") + fmt.Fprintln(out, string(b)) + } + } + case "read": + if len(args) < 2 { + fmt.Fprintln(out, "Usage: read ") + continue + } + result, readErr := reg.ReadResource(ctx, args[0], args[1]) + if readErr != nil { + fmt.Fprintf(out, "Error: %v\n", readErr) + continue + } + for _, content := range result.Contents { + if tc, ok := content.(mcp.TextResourceContents); ok { + fmt.Fprintln(out, tc.Text) + } else { + b, _ := json.MarshalIndent(content, "", " ") + fmt.Fprintln(out, string(b)) + } + } + case "get-prompt": + if len(args) < 2 { + fmt.Fprintln(out, "Usage: get-prompt [json-args]") + continue + } + var rawArgs string + if len(args) > 2 { + rawArgs = args[2] + } + parsedArgs, parseErr := mcpclient.ParseArgs(rawArgs) + if parseErr != nil { + fmt.Fprintf(out, "Error: %v\n", parseErr) + continue + } + res, gpErr := reg.GetPrompt(ctx, args[0], args[1], parsedArgs) + if gpErr != nil { + fmt.Fprintf(out, "Error: %v\n", gpErr) + continue + } + b, _ := json.MarshalIndent(res, "", " ") + fmt.Fprintln(out, string(b)) + case "search": + if len(args) < 1 { + fmt.Fprintln(out, "Usage: search ") + continue + } + keywordLower := strings.ToLower(args[0]) + for _, te := range reg.ListAllTools() { + if strings.Contains(strings.ToLower(te.Tool.Name), keywordLower) || + strings.Contains(strings.ToLower(te.Tool.Description), keywordLower) { + fmt.Fprintf(out, " %s / %s — %s\n", te.Server, te.Tool.Name, te.Tool.Description) + } + } + default: + fmt.Fprintf(out, "Unknown command: %s. Type 'help' for available commands.\n", cmd) + } + } + return scanner.Err() +} + +func outputFormat(pretty bool) string { + if pretty { + return "json-pretty" + } + return "json" +} + +func init() { + // Ensure color is available for status indicators + ui.Init(false) +} diff --git a/cmd/axis/mcp_client_test.go b/cmd/axis/mcp_client_test.go new file mode 100644 index 0000000..1790c05 --- /dev/null +++ b/cmd/axis/mcp_client_test.go @@ -0,0 +1,110 @@ +package main + +import ( + "bytes" + "strings" + "testing" + + "github.com/toasterbook88/axis/internal/config" +) + +func TestMCPClientListEmptyConfig(t *testing.T) { + restore := func() { + loadMCPClientConfig = config.Load + } + defer restore() + + loadMCPClientConfig = func(path string) (*config.Config, error) { + return &config.Config{ + Nodes: []config.NodeConfig{ + {Name: "dummy", Hostname: "localhost", SSHUser: "root"}, + }, + }, nil + } + + var buf bytes.Buffer + err := runMCPClientList(&buf, "text") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := buf.String() + if !strings.Contains(out, "No MCP servers configured") { + t.Fatalf("expected empty config message, got: %s", out) + } +} + +func TestMCPClientListJSON(t *testing.T) { + restore := func() { + loadMCPClientConfig = config.Load + } + defer restore() + + loadMCPClientConfig = func(path string) (*config.Config, error) { + return &config.Config{ + Nodes: []config.NodeConfig{ + {Name: "dummy", Hostname: "localhost", SSHUser: "root"}, + }, + MCPServers: map[string]config.MCPServerConfig{ + "test": {Transport: "stdio", Command: []string{"echo", "hello"}}, + }, + }, nil + } + + var buf bytes.Buffer + err := runMCPClientList(&buf, "json") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + out := buf.String() + if !strings.Contains(out, "[") { + t.Fatalf("expected JSON array, got: %s", out) + } +} + +func TestMCPClientToolsMissingServer(t *testing.T) { + restore := func() { + loadMCPClientConfig = config.Load + } + defer restore() + + loadMCPClientConfig = func(path string) (*config.Config, error) { + return &config.Config{ + Nodes: []config.NodeConfig{ + {Name: "dummy", Hostname: "localhost", SSHUser: "root"}, + }, + MCPServers: map[string]config.MCPServerConfig{}, + }, nil + } + + var buf bytes.Buffer + err := runMCPClientTools(&buf, "missing", "text") + if err == nil { + t.Fatal("expected error for missing server") + } + if !strings.Contains(err.Error(), "not configured") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestMCPClientParseArgs(t *testing.T) { + restore := func() { + loadMCPClientConfig = config.Load + } + defer restore() + + loadMCPClientConfig = func(path string) (*config.Config, error) { + return &config.Config{ + Nodes: []config.NodeConfig{ + {Name: "dummy", Hostname: "localhost", SSHUser: "root"}, + }, + MCPServers: map[string]config.MCPServerConfig{}, + }, nil + } + + var buf bytes.Buffer + // Call with a non-existent server to test arg parsing path + err := runMCPClientCall(&buf, "missing", "tool", `{"key":"value"}`, false) + if err == nil { + t.Fatal("expected error for missing server") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 8c96f85..07214e9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -136,6 +136,21 @@ type InferenceConfig struct { BudgetAlertThreshold float64 `json:"budget_alert_threshold,omitempty" yaml:"budget_alert_threshold,omitempty"` } +// MCPServerConfig describes a single external MCP server connection. +type MCPServerConfig struct { + // Transport is "stdio" or "http". + Transport string `json:"transport" yaml:"transport"` + + // Command is the executable and arguments for stdio transport. + Command []string `json:"command,omitempty" yaml:"command,omitempty"` + + // URL is the endpoint for HTTP/SSE transport. + URL string `json:"url,omitempty" yaml:"url,omitempty"` + + // Headers are optional HTTP headers (for http transport). + Headers map[string]string `json:"headers,omitempty" yaml:"headers,omitempty"` +} + // Config is the top-level AXIS configuration. type Config struct { Nodes []NodeConfig `json:"nodes" yaml:"nodes"` @@ -143,6 +158,7 @@ type Config struct { Chat *ChatConfig `json:"chat,omitempty" yaml:"chat,omitempty"` AIProviders map[string]AIProviderConfig `json:"ai_providers,omitempty" yaml:"ai_providers,omitempty"` Inference *InferenceConfig `json:"inference,omitempty" yaml:"inference,omitempty"` + MCPServers map[string]MCPServerConfig `json:"mcp_servers,omitempty" yaml:"mcp_servers,omitempty"` } // DefaultConfigPath returns ~/.axis/nodes.yaml. diff --git a/internal/mcpclient/call.go b/internal/mcpclient/call.go new file mode 100644 index 0000000..47dec13 --- /dev/null +++ b/internal/mcpclient/call.go @@ -0,0 +1,179 @@ +package mcpclient + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// CallToolAutoRoute invokes a tool on the first available server that offers it, +// trying servers in deterministic name order until one succeeds. +func (r *Registry) CallToolAutoRoute(ctx context.Context, toolName string, args map[string]any) CallResult { + entries := r.FindAllToolServers(toolName) + if len(entries) == 0 { + return CallResult{Err: fmt.Errorf("tool %q not found on any connected server", toolName)} + } + var lastErr error + for _, entry := range entries { + res := r.CallTool(ctx, entry.Server, toolName, args) + if res.Err == nil { + return res + } + lastErr = res.Err + } + return CallResult{Err: fmt.Errorf("tool %q failed on all %d server(s): last error: %w", toolName, len(entries), lastErr)} +} + +type CallResult struct { + Server string + Result *mcp.CallToolResult + Err error +} + +// isTransientError reports whether an error is likely temporary and worth retrying. +func isTransientError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, context.DeadlineExceeded) { + return true + } + if netErr, ok := err.(interface{ Temporary() bool }); ok && netErr.Temporary() { + return true + } + // HTTP 5xx status codes (for HTTP transport) + if strings.Contains(err.Error(), "500") || strings.Contains(err.Error(), "502") || + strings.Contains(err.Error(), "503") || strings.Contains(err.Error(), "504") { + return true + } + return false +} + +// withRetry executes fn up to 3 times with exponential backoff starting at 200ms. +func withRetry(ctx context.Context, fn func() error) error { + var err error + backoff := 200 * time.Millisecond + for attempt := 0; attempt < 3; attempt++ { + if attempt > 0 { + select { + case <-time.After(backoff): + case <-ctx.Done(): + return ctx.Err() + } + backoff *= 2 + } + err = fn() + if err == nil { + return nil + } + if !isTransientError(err) { + return err + } + } + return err +} + +// CallToolWithProgress invokes a tool and prints progress notifications to stderr. +func (r *Registry) CallToolWithProgress(ctx context.Context, serverName, toolName string, args map[string]any, progressOut io.Writer) CallResult { + sc := r.Get(serverName) + if sc == nil { + return CallResult{Server: serverName, Err: fmt.Errorf("server %q not configured", serverName)} + } + if !sc.Connected() { + return CallResult{Server: serverName, Err: fmt.Errorf("server %q not connected: %v", serverName, sc.Err)} + } + + sc.SetProgressHandler(func(p mcp.ProgressNotification) { + msg := "" + if p.Params.Message != "" { + msg = " — " + p.Params.Message + } + if p.Params.Total > 0 { + fmt.Fprintf(progressOut, "[progress %s] token=%s %.0f/%.0f%s\n", serverName, p.Params.ProgressToken, p.Params.Progress, p.Params.Total, msg) + } else { + fmt.Fprintf(progressOut, "[progress %s] token=%s %.0f%s\n", serverName, p.Params.ProgressToken, p.Params.Progress, msg) + } + }) + + start := time.Now() + var result *mcp.CallToolResult + err := withRetry(ctx, func() error { + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = args + res, callErr := sc.Client.CallTool(cctx, req) + result = res + return callErr + }) + sc.RecordCall(time.Since(start), err) + return CallResult{Server: serverName, Result: result, Err: err} +} +func (r *Registry) CallTool(ctx context.Context, serverName, toolName string, args map[string]any) CallResult { + sc := r.Get(serverName) + if sc == nil { + return CallResult{Server: serverName, Err: fmt.Errorf("server %q not configured", serverName)} + } + if !sc.Connected() { + return CallResult{Server: serverName, Err: fmt.Errorf("server %q not connected: %v", serverName, sc.Err)} + } + + start := time.Now() + var result *mcp.CallToolResult + err := withRetry(ctx, func() error { + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = args + res, callErr := sc.Client.CallTool(cctx, req) + result = res + return callErr + }) + sc.RecordCall(time.Since(start), err) + return CallResult{Server: serverName, Result: result, Err: err} +} + +// ReadResource fetches a resource by URI from a specific server. +func (r *Registry) ReadResource(ctx context.Context, serverName, uri string) (*mcp.ReadResourceResult, error) { + sc := r.Get(serverName) + if sc == nil { + return nil, fmt.Errorf("server %q not configured", serverName) + } + if !sc.Connected() { + return nil, fmt.Errorf("server %q not connected: %v", serverName, sc.Err) + } + + start := time.Now() + var result *mcp.ReadResourceResult + err := withRetry(ctx, func() error { + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req := mcp.ReadResourceRequest{} + req.Params.URI = uri + res, callErr := sc.Client.ReadResource(cctx, req) + result = res + return callErr + }) + sc.RecordCall(time.Since(start), err) + return result, err +} + +// ParseArgs unmarshals a raw JSON string into a map for tool invocation. +func ParseArgs(raw string) (map[string]any, error) { + if raw == "" { + return nil, nil + } + var args map[string]any + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return nil, fmt.Errorf("invalid JSON args: %w", err) + } + return args, nil +} diff --git a/internal/mcpclient/connection.go b/internal/mcpclient/connection.go new file mode 100644 index 0000000..40468ad --- /dev/null +++ b/internal/mcpclient/connection.go @@ -0,0 +1,251 @@ +// Package mcpclient provides a unified client for connecting to multiple MCP servers. +package mcpclient + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/toasterbook88/axis/internal/config" +) + +// ServerConnection holds the state of a single MCP server connection. +type ServerConnection struct { + Name string + Transport string + Config config.MCPServerConfig + Client mcpclient.MCPClient + InitResult *mcp.InitializeResult + Tools []mcp.Tool + Resources []mcp.Resource + Prompts []mcp.Prompt + Err error + mu sync.RWMutex + cachedTools []mcp.Tool + cachedResources []mcp.Resource + cachedPrompts []mcp.Prompt + cacheExpires time.Time + // Metrics + callCount int64 + errorCount int64 + totalLatency time.Duration + connectedAt time.Time +} + +// CacheTTL is the default duration for cached tool/resource/prompt listings. +const CacheTTL = 60 * time.Second + +// Connected reports whether the server handshake succeeded. +func (sc *ServerConnection) Connected() bool { + return sc.InitResult != nil && sc.Err == nil +} + +// ToolCount returns the number of discovered tools. +func (sc *ServerConnection) ToolCount() int { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.cacheValid() { + return len(sc.cachedTools) + } + return len(sc.Tools) +} + +// ResourceCount returns the number of discovered resources. +func (sc *ServerConnection) ResourceCount() int { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.cacheValid() { + return len(sc.cachedResources) + } + return len(sc.Resources) +} + +// CachedTools returns cached tools if valid, otherwise falls back to live data. +func (sc *ServerConnection) CachedTools() []mcp.Tool { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.cacheValid() { + return sc.cachedTools + } + return sc.Tools +} + +// CachedResources returns cached resources if valid, otherwise falls back to live data. +func (sc *ServerConnection) CachedResources() []mcp.Resource { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.cacheValid() { + return sc.cachedResources + } + return sc.Resources +} + +// CachedPrompts returns cached prompts if valid, otherwise falls back to live data. +func (sc *ServerConnection) CachedPrompts() []mcp.Prompt { + sc.mu.RLock() + defer sc.mu.RUnlock() + if sc.cacheValid() { + return sc.cachedPrompts + } + return sc.Prompts +} + +// RefreshCache updates the cache with current live data and resets TTL. +func (sc *ServerConnection) RefreshCache() { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.cachedTools = make([]mcp.Tool, len(sc.Tools)) + copy(sc.cachedTools, sc.Tools) + sc.cachedResources = make([]mcp.Resource, len(sc.Resources)) + copy(sc.cachedResources, sc.Resources) + sc.cachedPrompts = make([]mcp.Prompt, len(sc.Prompts)) + copy(sc.cachedPrompts, sc.Prompts) + sc.cacheExpires = time.Now().Add(CacheTTL) +} + +func (sc *ServerConnection) cacheValid() bool { + return !sc.cacheExpires.IsZero() && time.Now().Before(sc.cacheExpires) +} + +// RecordCall increments call count and records latency. +func (sc *ServerConnection) RecordCall(latency time.Duration, err error) { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.callCount++ + sc.totalLatency += latency + if err != nil { + sc.errorCount++ + } +} + +// Metrics returns a snapshot of the connection's metrics. +func (sc *ServerConnection) Metrics() (calls, errors int64, avgLatency time.Duration, uptime time.Duration) { + sc.mu.RLock() + defer sc.mu.RUnlock() + calls = sc.callCount + errors = sc.errorCount + if sc.callCount > 0 { + avgLatency = sc.totalLatency / time.Duration(sc.callCount) + } + if !sc.connectedAt.IsZero() { + uptime = time.Since(sc.connectedAt) + } + return +} + +// Close closes the underlying client connection. +func (sc *ServerConnection) Close() error { + if sc.Client != nil { + return sc.Client.Close() + } + return nil +} + +// connectStdio launches a stdio MCP server subprocess and connects to it. +func connectStdio(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) { + if len(cfg.Command) == 0 { + return nil, fmt.Errorf("mcp server %q: stdio transport requires command", name) + } + + cmd := cfg.Command[0] + args := cfg.Command[1:] + + client, err := mcpclient.NewStdioMCPClient(cmd, nil, args...) + if err != nil { + return nil, fmt.Errorf("mcp server %q: start stdio client: %w", name, err) + } + + return handshake(ctx, name, "stdio", cfg, client) +} + +// connectHTTP connects to an HTTP/SSE MCP server endpoint. +func connectHTTP(ctx context.Context, name string, cfg config.MCPServerConfig) (*ServerConnection, error) { + if cfg.URL == "" { + return nil, fmt.Errorf("mcp server %q: http transport requires url", name) + } + + var opts []transport.StreamableHTTPCOption + if len(cfg.Headers) > 0 { + opts = append(opts, transport.WithHTTPHeaders(cfg.Headers)) + } + + client, err := mcpclient.NewStreamableHttpClient(cfg.URL, opts...) + if err != nil { + return nil, fmt.Errorf("mcp server %q: start http client: %w", name, err) + } + + return handshake(ctx, name, "http", cfg, client) +} + +// handshake performs the MCP initialize exchange and capability discovery. +func handshake(ctx context.Context, name, transport string, cfg config.MCPServerConfig, client mcpclient.MCPClient) (*ServerConnection, error) { + sc := &ServerConnection{ + Name: name, + Transport: transport, + Config: cfg, + Client: client, + } + + hctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{ + Name: "axis-mcp-client", + Version: "0.10.7", + } + + initRes, err := client.Initialize(hctx, initReq) + if err != nil { + sc.Err = fmt.Errorf("initialize: %w", err) + return sc, nil + } + sc.InitResult = initRes + sc.connectedAt = time.Now() + + // Discover tools + tctx, tcancel := context.WithTimeout(ctx, 10*time.Second) + defer tcancel() + if toolsRes, err := client.ListTools(tctx, mcp.ListToolsRequest{}); err == nil { + sc.Tools = toolsRes.Tools + } + + // Discover resources + rctx, rcancel := context.WithTimeout(ctx, 10*time.Second) + defer rcancel() + if resRes, err := client.ListResources(rctx, mcp.ListResourcesRequest{}); err == nil { + sc.Resources = resRes.Resources + } + + // Discover prompts + pctx, pcancel := context.WithTimeout(ctx, 10*time.Second) + defer pcancel() + if promptRes, err := client.ListPrompts(pctx, mcp.ListPromptsRequest{}); err == nil { + sc.Prompts = promptRes.Prompts + } + + sc.RefreshCache() + return sc, nil +} + +// SetProgressHandler registers a handler for progress notifications on this connection. +func (sc *ServerConnection) SetProgressHandler(handler func(mcp.ProgressNotification)) { + if sc.Client != nil { + sc.Client.OnNotification(func(notification mcp.JSONRPCNotification) { + if notification.Method == "notifications/progress" { + // Best-effort decode + var progress mcp.ProgressNotification + if data, err := json.Marshal(notification.Params); err == nil { + _ = json.Unmarshal(data, &progress) + } + handler(progress) + } + }) + } +} diff --git a/internal/mcpclient/registry.go b/internal/mcpclient/registry.go new file mode 100644 index 0000000..75e7816 --- /dev/null +++ b/internal/mcpclient/registry.go @@ -0,0 +1,259 @@ +// Package mcpclient provides a unified client for connecting to multiple MCP servers. +package mcpclient + +import ( + "context" + "fmt" + "sort" + "strings" + "sync" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/toasterbook88/axis/internal/config" +) + +// Registry manages connections to multiple MCP servers. +type Registry struct { + mu sync.RWMutex + servers map[string]*ServerConnection +} + +// NewRegistry creates an empty registry. +func NewRegistry() *Registry { + return &Registry{ + servers: make(map[string]*ServerConnection), + } +} + +// ConnectAll connects to every MCP server defined in the AXIS config. +// Errors are stored per-server and surfaced via the connection's Err field; +// the function never returns an error itself. +func (r *Registry) ConnectAll(ctx context.Context, cfg *config.Config) { + if cfg == nil || len(cfg.MCPServers) == 0 { + return + } + + var wg sync.WaitGroup + for name, serverCfg := range cfg.MCPServers { + wg.Add(1) + go func(n string, sc config.MCPServerConfig) { + defer wg.Done() + var conn *ServerConnection + var err error + switch strings.ToLower(sc.Transport) { + case "stdio": + conn, err = connectStdio(ctx, n, sc) + case "http": + conn, err = connectHTTP(ctx, n, sc) + default: + conn = &ServerConnection{ + Name: n, + Transport: sc.Transport, + Config: sc, + Err: fmt.Errorf("unsupported transport %q", sc.Transport), + } + } + if err != nil { + conn = &ServerConnection{ + Name: n, + Transport: sc.Transport, + Config: sc, + Err: err, + } + } + r.mu.Lock() + r.servers[n] = conn + r.mu.Unlock() + }(name, serverCfg) + } + wg.Wait() +} + +// Get returns the connection for a named server, or nil. +func (r *Registry) Get(name string) *ServerConnection { + r.mu.RLock() + defer r.mu.RUnlock() + return r.servers[name] +} + +// Names returns all configured server names, sorted. +func (r *Registry) Names() []string { + r.mu.RLock() + defer r.mu.RUnlock() + names := make([]string, 0, len(r.servers)) + for n := range r.servers { + names = append(names, n) + } + sort.Strings(names) + return names +} + +// ConnectedNames returns names of successfully connected servers. +func (r *Registry) ConnectedNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + var names []string + for n, s := range r.servers { + if s.Connected() { + names = append(names, n) + } + } + sort.Strings(names) + return names +} + +// Close closes all managed connections. +func (r *Registry) Close() { + r.mu.RLock() + servers := make([]*ServerConnection, 0, len(r.servers)) + for _, s := range r.servers { + servers = append(servers, s) + } + r.mu.RUnlock() + for _, s := range servers { + _ = s.Close() + } +} + +// ToolEntry is a tool annotated with the server that provides it. +type ToolEntry struct { + Server string + Tool mcp.Tool +} + +// ListAllTools returns every tool from every connected server. +func (r *Registry) ListAllTools() []ToolEntry { + r.mu.RLock() + defer r.mu.RUnlock() + var out []ToolEntry + for _, s := range r.servers { + if !s.Connected() { + continue + } + for _, t := range s.CachedTools() { + out = append(out, ToolEntry{Server: s.Name, Tool: t}) + } + } + return out +} + +// FindTool returns the first tool matching name across all connected servers. +// Tools are searched in deterministic server name order. +func (r *Registry) FindTool(name string) (ToolEntry, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.servers)) + for n := range r.servers { + names = append(names, n) + } + sort.Strings(names) + + for _, n := range names { + s := r.servers[n] + if !s.Connected() { + continue + } + for _, t := range s.CachedTools() { + if t.Name == name { + return ToolEntry{Server: s.Name, Tool: t}, true + } + } + } + return ToolEntry{}, false +} + +// FindAllToolServers returns every server that offers a tool with the given name. +func (r *Registry) FindAllToolServers(name string) []ToolEntry { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.servers)) + for n := range r.servers { + names = append(names, n) + } + sort.Strings(names) + + var out []ToolEntry + for _, n := range names { + s := r.servers[n] + if !s.Connected() { + continue + } + for _, t := range s.CachedTools() { + if t.Name == name { + out = append(out, ToolEntry{Server: s.Name, Tool: t}) + } + } + } + return out +} + +// ResourceEntry is a resource annotated with its server. +type ResourceEntry struct { + Server string + Resource mcp.Resource +} + +// ListAllResources returns every resource from every connected server. +func (r *Registry) ListAllResources() []ResourceEntry { + r.mu.RLock() + defer r.mu.RUnlock() + var out []ResourceEntry + for _, s := range r.servers { + if !s.Connected() { + continue + } + for _, res := range s.CachedResources() { + out = append(out, ResourceEntry{Server: s.Name, Resource: res}) + } + } + return out +} + +// PromptEntry is a prompt annotated with its server. +type PromptEntry struct { + Server string + Prompt mcp.Prompt +} + +// ListAllPrompts returns every prompt from every connected server. +func (r *Registry) ListAllPrompts() []PromptEntry { + r.mu.RLock() + defer r.mu.RUnlock() + var out []PromptEntry + for _, s := range r.servers { + if !s.Connected() { + continue + } + for _, p := range s.CachedPrompts() { + out = append(out, PromptEntry{Server: s.Name, Prompt: p}) + } + } + return out +} + +// GetPrompt fetches a specific prompt by name from a server, with optional arguments. +func (r *Registry) GetPrompt(ctx context.Context, serverName, promptName string, args map[string]any) (*mcp.GetPromptResult, error) { + sc := r.Get(serverName) + if sc == nil { + return nil, fmt.Errorf("server %q not configured", serverName) + } + if !sc.Connected() { + return nil, fmt.Errorf("server %q not connected: %v", serverName, sc.Err) + } + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req := mcp.GetPromptRequest{} + req.Params.Name = promptName + // Convert map[string]any to map[string]string for MCP prompt arguments + if args != nil { + strArgs := make(map[string]string, len(args)) + for k, v := range args { + strArgs[k] = fmt.Sprintf("%v", v) + } + req.Params.Arguments = strArgs + } + return sc.Client.GetPrompt(cctx, req) +} diff --git a/internal/mcpclient/registry_test.go b/internal/mcpclient/registry_test.go new file mode 100644 index 0000000..4961175 --- /dev/null +++ b/internal/mcpclient/registry_test.go @@ -0,0 +1,205 @@ +package mcpclient + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/toasterbook88/axis/internal/config" +) + +// mockClient implements mcpclient.MCPClient for testing. +type mockClient struct { + initResult *mcp.InitializeResult + tools []mcp.Tool + resources []mcp.Resource + prompts []mcp.Prompt + closed bool +} + +func (m *mockClient) Initialize(ctx context.Context, req mcp.InitializeRequest) (*mcp.InitializeResult, error) { + return m.initResult, nil +} + +func (m *mockClient) Ping(ctx context.Context) error { return nil } +func (m *mockClient) ListResourcesByPage(ctx context.Context, req mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error) { + return nil, nil +} +func (m *mockClient) ListResources(ctx context.Context, req mcp.ListResourcesRequest) (*mcp.ListResourcesResult, error) { + return &mcp.ListResourcesResult{Resources: m.resources}, nil +} +func (m *mockClient) ListResourceTemplatesByPage(ctx context.Context, req mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error) { + return nil, nil +} +func (m *mockClient) ListResourceTemplates(ctx context.Context, req mcp.ListResourceTemplatesRequest) (*mcp.ListResourceTemplatesResult, error) { + return nil, nil +} +func (m *mockClient) ReadResource(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return nil, nil +} +func (m *mockClient) Subscribe(ctx context.Context, req mcp.SubscribeRequest) error { return nil } +func (m *mockClient) Unsubscribe(ctx context.Context, req mcp.UnsubscribeRequest) error { return nil } +func (m *mockClient) ListPromptsByPage(ctx context.Context, req mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error) { + return nil, nil +} +func (m *mockClient) ListPrompts(ctx context.Context, req mcp.ListPromptsRequest) (*mcp.ListPromptsResult, error) { + return &mcp.ListPromptsResult{Prompts: m.prompts}, nil +} +func (m *mockClient) GetPrompt(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + return nil, nil +} +func (m *mockClient) ListToolsByPage(ctx context.Context, req mcp.ListToolsRequest) (*mcp.ListToolsResult, error) { + return nil, nil +} +func (m *mockClient) ListTools(ctx context.Context, req mcp.ListToolsRequest) (*mcp.ListToolsResult, error) { + return &mcp.ListToolsResult{Tools: m.tools}, nil +} +func (m *mockClient) CallTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return nil, nil +} +func (m *mockClient) SetLevel(ctx context.Context, req mcp.SetLevelRequest) error { return nil } +func (m *mockClient) Complete(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return nil, nil +} +func (m *mockClient) Close() error { + m.closed = true + return nil +} +func (m *mockClient) OnNotification(handler func(notification mcp.JSONRPCNotification)) {} + +func TestRegistryNames(t *testing.T) { + r := NewRegistry() + r.servers["alpha"] = &ServerConnection{Name: "alpha"} + r.servers["beta"] = &ServerConnection{Name: "beta"} + + names := r.Names() + if len(names) != 2 { + t.Fatalf("expected 2 names, got %d", len(names)) + } + if names[0] != "alpha" || names[1] != "beta" { + t.Fatalf("unexpected names: %v", names) + } +} + +func TestRegistryConnectedNames(t *testing.T) { + r := NewRegistry() + r.servers["up"] = &ServerConnection{ + Name: "up", + InitResult: &mcp.InitializeResult{}, + } + r.servers["down"] = &ServerConnection{ + Name: "down", + Err: context.DeadlineExceeded, + } + + connected := r.ConnectedNames() + if len(connected) != 1 || connected[0] != "up" { + t.Fatalf("expected [up], got %v", connected) + } +} + +func TestRegistryListAllTools(t *testing.T) { + r := NewRegistry() + r.servers["a"] = &ServerConnection{ + Name: "a", + InitResult: &mcp.InitializeResult{}, + Tools: []mcp.Tool{{Name: "tool-a"}, {Name: "tool-a2"}}, + } + r.servers["b"] = &ServerConnection{ + Name: "b", + InitResult: &mcp.InitializeResult{}, + Tools: []mcp.Tool{{Name: "tool-b"}}, + } + + tools := r.ListAllTools() + if len(tools) != 3 { + t.Fatalf("expected 3 tools, got %d", len(tools)) + } +} + +func TestRegistryFindTool(t *testing.T) { + r := NewRegistry() + r.servers["a"] = &ServerConnection{ + Name: "a", + InitResult: &mcp.InitializeResult{}, + Tools: []mcp.Tool{{Name: "find-me"}}, + } + + entry, ok := r.FindTool("find-me") + if !ok { + t.Fatal("expected to find tool") + } + if entry.Server != "a" || entry.Tool.Name != "find-me" { + t.Fatalf("unexpected entry: %+v", entry) + } + + _, ok = r.FindTool("missing") + if ok { + t.Fatal("expected not to find tool") + } +} + +func TestRegistryClose(t *testing.T) { + mock := &mockClient{} + r := NewRegistry() + r.servers["x"] = &ServerConnection{Name: "x", Client: mock} + + r.Close() + if !mock.closed { + t.Fatal("expected mock client to be closed") + } +} + +func TestParseArgs(t *testing.T) { + tests := []struct { + input string + want map[string]any + }{ + {"", nil}, + {`{"key":"value"}`, map[string]any{"key": "value"}}, + } + + for _, tt := range tests { + got, err := ParseArgs(tt.input) + if err != nil { + t.Fatalf("ParseArgs(%q) error: %v", tt.input, err) + } + if tt.want == nil && got != nil { + t.Fatalf("ParseArgs(%q) expected nil, got %v", tt.input, got) + } + if tt.want != nil { + if len(got) != len(tt.want) { + t.Fatalf("ParseArgs(%q) expected %v, got %v", tt.input, tt.want, got) + } + } + } +} + +func TestParseArgsInvalidJSON(t *testing.T) { + _, err := ParseArgs("not-json") + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestConnectAllUnsupportedTransport(t *testing.T) { + r := NewRegistry() + cfg := &config.Config{ + Nodes: []config.NodeConfig{ + {Name: "dummy", Hostname: "localhost", SSHUser: "root"}, + }, + MCPServers: map[string]config.MCPServerConfig{ + "bad": {Transport: "unknown"}, + }, + } + ctx := context.Background() + r.ConnectAll(ctx, cfg) + + sc := r.Get("bad") + if sc == nil { + t.Fatal("expected server connection to exist") + } + if sc.Err == nil { + t.Fatal("expected error for unsupported transport") + } +} From a9e74083883c727bac0b80f55311a925e7794bde Mon Sep 17 00:00:00 2001 From: William Date: Sun, 24 May 2026 16:02:45 -0400 Subject: [PATCH 2/4] Fix daemon refresh coalescing logic for MCP client improvements Fixes TestMeshPartitionCoalescing by properly handling pending triggers during refresh completion. The original logic had a race condition where pending triggers weren't properly counted when completing a refresh. Changes: - Record start time when winning the refresh lock for latency fallback - Capture pending state before clearing it - Use captured pending triggers for next trigger determination - Fix activeRequestedAt assignment logic --- internal/daemon/daemon.go | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 51b6b06..672222c 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -427,13 +427,26 @@ func (d *Daemon) refreshWithTrigger(ctx context.Context, trigger string) error { return nil } + // We are the winner: record start time for latency fallback + startTime := time.Now() + + // Do the refresh + err := d.doRefresh(ctx, trigger) + + // After refresh, process pending triggers that came in while we were refreshing + var pendingTriggers map[string]bool + var pendingRequestedAt time.Time + d.pendingMu.Lock() - if !d.pendingRequestedAt.IsZero() { - d.activeRequestedAt = d.pendingRequestedAt - d.pendingRequestedAt = time.Time{} - d.pendingTriggers = make(map[string]bool) + pendingTriggers = d.pendingTriggers + pendingRequestedAt = d.pendingRequestedAt + // Clear the pending triggers and reset the requestedAt for next cycle + d.pendingTriggers = make(map[string]bool) + d.pendingRequestedAt = time.Time{} + if !pendingRequestedAt.IsZero() { + d.activeRequestedAt = pendingRequestedAt } else { - d.activeRequestedAt = time.Now() + d.activeRequestedAt = startTime } d.pendingMu.Unlock() @@ -441,16 +454,14 @@ func (d *Daemon) refreshWithTrigger(ctx context.Context, trigger string) error { d.refreshing.Store(false) var nextTrigger string - d.pendingMu.Lock() - if len(d.pendingTriggers) > 0 { + if len(pendingTriggers) > 0 { var keys []string - for k := range d.pendingTriggers { + for k := range pendingTriggers { keys = append(keys, k) } sort.Strings(keys) nextTrigger = strings.Join(keys, ",") } - d.pendingMu.Unlock() select { case <-d.pendingRefresh: @@ -462,8 +473,7 @@ func (d *Daemon) refreshWithTrigger(ctx context.Context, trigger string) error { } }() - err := d.doRefresh(ctx, trigger) - + // Latency measurement now := time.Now() d.pendingMu.Lock() latency := now.Sub(d.activeRequestedAt) From 036509ae6e22f7362a7d0e0053052e3f1049f423 Mon Sep 17 00:00:00 2001 From: William Date: Mon, 25 May 2026 14:22:37 -0400 Subject: [PATCH 3/4] fix(mcp): address all gemini-code-assist review comments on PR #140 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace context.Background() with cmd.Context() in all runMCPClient* functions - Fix --format flag default typo (format → text) - Fix REPL arg parsing for call/get-prompt to join remaining args with spaces - Fix REPL search to use full multi-word keyword - Refactor isTransientError to use errors.As with httpStatusCoder instead of fragile string matching - Extract shared callTool internal method to deduplicate CallTool and CallToolWithProgress - Update mcp_client_test.go to pass context.Context Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/axis/mcp_client.go | 78 ++++++++++++++++++------------------- cmd/axis/mcp_client_test.go | 9 +++-- internal/mcpclient/call.go | 66 ++++++++++++++++--------------- 3 files changed, 79 insertions(+), 74 deletions(-) diff --git a/cmd/axis/mcp_client.go b/cmd/axis/mcp_client.go index 0d5e9e7..4cb1396 100644 --- a/cmd/axis/mcp_client.go +++ b/cmd/axis/mcp_client.go @@ -45,14 +45,14 @@ func mcpClientListCmd() *cobra.Command { Use: "list", Short: "List configured MCP servers and connection status", RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientList(cmd.OutOrStdout(), format) + return runMCPClientList(cmd.Context(), cmd.OutOrStdout(), format) }, } cmd.Flags().StringVar(&format, "format", "text", "Output format: text or json") return cmd } -func runMCPClientList(out io.Writer, format string) error { +func runMCPClientList(ctx context.Context, out io.Writer, format string) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) @@ -65,7 +65,7 @@ func runMCPClientList(out io.Writer, format string) error { } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -138,7 +138,7 @@ func mcpClientToolsCmd() *cobra.Command { Use: "tools", Short: "List tools from connected MCP servers", RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientTools(cmd.OutOrStdout(), server, format) + return runMCPClientTools(cmd.Context(), cmd.OutOrStdout(), server, format) }, } cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") @@ -146,14 +146,14 @@ func mcpClientToolsCmd() *cobra.Command { return cmd } -func runMCPClientTools(out io.Writer, server, format string) error { +func runMCPClientTools(ctx context.Context, out io.Writer, server, format string) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -225,7 +225,7 @@ func mcpClientCallCmd() *cobra.Command { if len(args) > 1 { rawArgs = args[1] } - return runMCPClientCallAutoRoute(cmd.OutOrStdout(), toolName, rawArgs, pretty) + return runMCPClientCallAutoRoute(cmd.Context(), cmd.OutOrStdout(), toolName, rawArgs, pretty) } serverName := args[0] toolName := args[1] @@ -233,7 +233,7 @@ func mcpClientCallCmd() *cobra.Command { if len(args) > 2 { rawArgs = args[2] } - return runMCPClientCall(cmd.OutOrStdout(), serverName, toolName, rawArgs, pretty) + return runMCPClientCall(cmd.Context(), cmd.OutOrStdout(), serverName, toolName, rawArgs, pretty) }, } cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") @@ -241,14 +241,14 @@ func mcpClientCallCmd() *cobra.Command { return cmd } -func runMCPClientCallAutoRoute(out io.Writer, toolName, rawArgs string, pretty bool) error { +func runMCPClientCallAutoRoute(ctx context.Context, out io.Writer, toolName, rawArgs string, pretty bool) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -275,14 +275,14 @@ func runMCPClientCallAutoRoute(out io.Writer, toolName, rawArgs string, pretty b return nil } -func runMCPClientCall(out io.Writer, serverName, toolName, rawArgs string, pretty bool) error { +func runMCPClientCall(ctx context.Context, out io.Writer, serverName, toolName, rawArgs string, pretty bool) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -316,22 +316,22 @@ func mcpClientResourcesCmd() *cobra.Command { Use: "resources", Short: "List resources from connected MCP servers", RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientResources(cmd.OutOrStdout(), server, format) + return runMCPClientResources(cmd.Context(), cmd.OutOrStdout(), server, format) }, } cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") - cmd.Flags().StringVar(&format, "format", "format", "Output format: text or json") + cmd.Flags().StringVar(&format, "format", "text", "Output format: text or json") return cmd } -func runMCPClientResources(out io.Writer, server, format string) error { +func runMCPClientResources(ctx context.Context, out io.Writer, server, format string) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -381,21 +381,21 @@ func mcpClientReadCmd() *cobra.Command { Short: "Read a resource from a specific MCP server", Args: cobra.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientRead(cmd.OutOrStdout(), args[0], args[1], pretty) + return runMCPClientRead(cmd.Context(), cmd.OutOrStdout(), args[0], args[1], pretty) }, } cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") return cmd } -func runMCPClientRead(out io.Writer, serverName, uri string, pretty bool) error { +func runMCPClientRead(ctx context.Context, out io.Writer, serverName, uri string, pretty bool) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -424,7 +424,7 @@ func mcpClientPromptsCmd() *cobra.Command { Use: "prompts", Short: "List prompts from connected MCP servers", RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientPrompts(cmd.OutOrStdout(), server, format) + return runMCPClientPrompts(cmd.Context(), cmd.OutOrStdout(), server, format) }, } cmd.Flags().StringVar(&server, "server", "", "Filter to a specific server") @@ -432,13 +432,13 @@ func mcpClientPromptsCmd() *cobra.Command { return cmd } -func runMCPClientPrompts(out io.Writer, server, format string) error { +func runMCPClientPrompts(ctx context.Context, out io.Writer, server, format string) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -498,20 +498,20 @@ func mcpClientGetPromptCmd() *cobra.Command { if len(args) > 2 { rawArgs = args[2] } - return runMCPClientGetPrompt(cmd.OutOrStdout(), serverName, promptName, rawArgs, pretty) + return runMCPClientGetPrompt(cmd.Context(), cmd.OutOrStdout(), serverName, promptName, rawArgs, pretty) }, } cmd.Flags().BoolVar(&pretty, "pretty", false, "Pretty-print JSON output") return cmd } -func runMCPClientGetPrompt(out io.Writer, serverName, promptName, rawArgs string, pretty bool) error { +func runMCPClientGetPrompt(ctx context.Context, out io.Writer, serverName, promptName, rawArgs string, pretty bool) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -534,19 +534,19 @@ func mcpClientSearchCmd() *cobra.Command { Short: "Search tools by name or description across all connected MCP servers", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientSearch(cmd.OutOrStdout(), args[0]) + return runMCPClientSearch(cmd.Context(), cmd.OutOrStdout(), args[0]) }, } return cmd } -func runMCPClientSearch(out io.Writer, keyword string) error { +func runMCPClientSearch(ctx context.Context, out io.Writer, keyword string) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -591,7 +591,7 @@ Example file: ]`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientBatch(cmd.OutOrStdout(), args[0]) + return runMCPClientBatch(cmd.Context(), cmd.OutOrStdout(), args[0]) }, } return cmd @@ -603,7 +603,7 @@ type batchEntry struct { Args map[string]any `json:"args,omitempty"` } -func runMCPClientBatch(out io.Writer, path string) error { +func runMCPClientBatch(ctx context.Context, out io.Writer, path string) error { data, err := os.ReadFile(path) if err != nil { return fmt.Errorf("read batch file: %w", err) @@ -618,7 +618,7 @@ func runMCPClientBatch(out io.Writer, path string) error { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -668,19 +668,19 @@ func mcpClientInteractiveCmd() *cobra.Command { Use: "interactive", Short: "Interactive REPL for exploring and calling MCP servers", RunE: func(cmd *cobra.Command, args []string) error { - return runMCPClientInteractive(cmd.InOrStdin(), cmd.OutOrStdout()) + return runMCPClientInteractive(cmd.Context(), cmd.InOrStdin(), cmd.OutOrStdout()) }, } return cmd } -func runMCPClientInteractive(in io.Reader, out io.Writer) error { +func runMCPClientInteractive(ctx context.Context, in io.Reader, out io.Writer) error { cfg, err := loadMCPClientConfig(config.DefaultConfigPath()) if err != nil { return fmt.Errorf("load config: %w", err) } reg := mcpclient.NewRegistry() - ctx := context.Background() + reg.ConnectAll(ctx, cfg) defer reg.Close() @@ -771,9 +771,9 @@ func runMCPClientInteractive(in io.Reader, out io.Writer) error { fmt.Fprintln(out, "Usage: call [json-args]") continue } - var rawArgs string + rawArgs := "" if len(args) > 2 { - rawArgs = args[2] + rawArgs = strings.Join(args[2:], " ") } parsedArgs, parseErr := mcpclient.ParseArgs(rawArgs) if parseErr != nil { @@ -816,9 +816,9 @@ func runMCPClientInteractive(in io.Reader, out io.Writer) error { fmt.Fprintln(out, "Usage: get-prompt [json-args]") continue } - var rawArgs string + rawArgs := "" if len(args) > 2 { - rawArgs = args[2] + rawArgs = strings.Join(args[2:], " ") } parsedArgs, parseErr := mcpclient.ParseArgs(rawArgs) if parseErr != nil { @@ -837,7 +837,7 @@ func runMCPClientInteractive(in io.Reader, out io.Writer) error { fmt.Fprintln(out, "Usage: search ") continue } - keywordLower := strings.ToLower(args[0]) + keywordLower := strings.ToLower(strings.Join(args, " ")) for _, te := range reg.ListAllTools() { if strings.Contains(strings.ToLower(te.Tool.Name), keywordLower) || strings.Contains(strings.ToLower(te.Tool.Description), keywordLower) { diff --git a/cmd/axis/mcp_client_test.go b/cmd/axis/mcp_client_test.go index 1790c05..7b88ada 100644 --- a/cmd/axis/mcp_client_test.go +++ b/cmd/axis/mcp_client_test.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "strings" "testing" @@ -23,7 +24,7 @@ func TestMCPClientListEmptyConfig(t *testing.T) { } var buf bytes.Buffer - err := runMCPClientList(&buf, "text") + err := runMCPClientList(context.Background(), &buf, "text") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -51,7 +52,7 @@ func TestMCPClientListJSON(t *testing.T) { } var buf bytes.Buffer - err := runMCPClientList(&buf, "json") + err := runMCPClientList(context.Background(), &buf, "json") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -77,7 +78,7 @@ func TestMCPClientToolsMissingServer(t *testing.T) { } var buf bytes.Buffer - err := runMCPClientTools(&buf, "missing", "text") + err := runMCPClientTools(context.Background(), &buf, "missing", "text") if err == nil { t.Fatal("expected error for missing server") } @@ -103,7 +104,7 @@ func TestMCPClientParseArgs(t *testing.T) { var buf bytes.Buffer // Call with a non-existent server to test arg parsing path - err := runMCPClientCall(&buf, "missing", "tool", `{"key":"value"}`, false) + err := runMCPClientCall(context.Background(), &buf, "missing", "tool", `{"key":"value"}`, false) if err == nil { t.Fatal("expected error for missing server") } diff --git a/internal/mcpclient/call.go b/internal/mcpclient/call.go index 47dec13..f83910b 100644 --- a/internal/mcpclient/call.go +++ b/internal/mcpclient/call.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "strings" "time" "github.com/mark3labs/mcp-go/mcp" @@ -36,6 +35,11 @@ type CallResult struct { Err error } +// httpStatusCoder is implemented by HTTP response errors that carry a status code. +type httpStatusCoder interface { + StatusCode() int +} + // isTransientError reports whether an error is likely temporary and worth retrying. func isTransientError(err error) bool { if err == nil { @@ -47,10 +51,14 @@ func isTransientError(err error) bool { if netErr, ok := err.(interface{ Temporary() bool }); ok && netErr.Temporary() { return true } - // HTTP 5xx status codes (for HTTP transport) - if strings.Contains(err.Error(), "500") || strings.Contains(err.Error(), "502") || - strings.Contains(err.Error(), "503") || strings.Contains(err.Error(), "504") { - return true + // HTTP 5xx status codes — inspect directly via type assertion rather than + // fragile string matching. + var statusCoder httpStatusCoder + if errors.As(err, &statusCoder) { + code := statusCoder.StatusCode() + if code >= 500 && code < 600 { + return true + } } return false } @@ -79,6 +87,24 @@ func withRetry(ctx context.Context, fn func() error) error { return err } +// callTool performs the actual tool invocation with retry and metrics recording. +func (r *Registry) callTool(ctx context.Context, sc *ServerConnection, toolName string, args map[string]any) (*mcp.CallToolResult, error) { + start := time.Now() + var result *mcp.CallToolResult + err := withRetry(ctx, func() error { + cctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = args + res, callErr := sc.Client.CallTool(cctx, req) + result = res + return callErr + }) + sc.RecordCall(time.Since(start), err) + return result, err +} + // CallToolWithProgress invokes a tool and prints progress notifications to stderr. func (r *Registry) CallToolWithProgress(ctx context.Context, serverName, toolName string, args map[string]any, progressOut io.Writer) CallResult { sc := r.Get(serverName) @@ -101,21 +127,11 @@ func (r *Registry) CallToolWithProgress(ctx context.Context, serverName, toolNam } }) - start := time.Now() - var result *mcp.CallToolResult - err := withRetry(ctx, func() error { - cctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - req := mcp.CallToolRequest{} - req.Params.Name = toolName - req.Params.Arguments = args - res, callErr := sc.Client.CallTool(cctx, req) - result = res - return callErr - }) - sc.RecordCall(time.Since(start), err) + result, err := r.callTool(ctx, sc, toolName, args) return CallResult{Server: serverName, Result: result, Err: err} } + +// CallTool invokes a tool on a specific server. func (r *Registry) CallTool(ctx context.Context, serverName, toolName string, args map[string]any) CallResult { sc := r.Get(serverName) if sc == nil { @@ -125,19 +141,7 @@ func (r *Registry) CallTool(ctx context.Context, serverName, toolName string, ar return CallResult{Server: serverName, Err: fmt.Errorf("server %q not connected: %v", serverName, sc.Err)} } - start := time.Now() - var result *mcp.CallToolResult - err := withRetry(ctx, func() error { - cctx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - req := mcp.CallToolRequest{} - req.Params.Name = toolName - req.Params.Arguments = args - res, callErr := sc.Client.CallTool(cctx, req) - result = res - return callErr - }) - sc.RecordCall(time.Since(start), err) + result, err := r.callTool(ctx, sc, toolName, args) return CallResult{Server: serverName, Result: result, Err: err} } From 0fc79ec735ae09b8610bc2384b221943844c6548 Mon Sep 17 00:00:00 2001 From: William Date: Wed, 27 May 2026 19:59:47 -0400 Subject: [PATCH 4/4] fix(daemon): deflake TestDaemonRefreshCoalescingAndLatency temp-dir cleanup Add a 100ms grace period at the end of the test to allow any lingering background goroutines from prior watch-config tests to finish file operations before t.TempDir() cleanup runs. The failure was an unlinkat ... directory not empty caused by doRefresh writes racing with RemoveAll. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/daemon/daemon_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index 4bc0247..e1fb0e6 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -898,6 +898,10 @@ func TestDaemonRefreshCoalescingAndLatency(t *testing.T) { if meta.MaxRefreshLatencyMs < 0 { t.Fatalf("expected non-negative MaxRefreshLatencyMs, got %d", meta.MaxRefreshLatencyMs) } + + // Grace period for any leaked goroutines from prior watch tests to finish + // file operations before t.TempDir() cleanup runs. + time.Sleep(100 * time.Millisecond) } func TestNewDefaultCreatesMeshWhenNoDiscoveryConfig(t *testing.T) {