Skip to content

Commit 687b754

Browse files
committed
internal/mcp: reorganize the tool API, keeping generics
CallToolParams.Arguments is any. Clients can send whatever they want. Tool authors who create their own ToolHandler (bypassing NewTool) will receive the arguments as a map[string]any which has been validated against the input schema. Tool authors who call NewTool can choose the type that the arguments unmarshal into. There is no double-unmarshaling and no downcasting. There is no way for tool authors to avoid the unmarshal, and no way to avoid the validation unless the omit an input schema, which is probably bad for LLMs. If tool authors want this optimization, we can provide it later. Change-Id: Ieb201e56c7fb8a23bb1a93b27e946630ec9e79ad Reviewed-on: https://go-review.googlesource.com/c/tools/+/679595 TryBot-Bypass: Jonathan Amsterdam <jba@google.com> Auto-Submit: Jonathan Amsterdam <jba@google.com> Reviewed-by: Robert Findley <rfindley@google.com>
1 parent c7d803c commit 687b754

File tree

22 files changed

+334
-237
lines changed

22 files changed

+334
-237
lines changed

gopls/internal/mcp/context.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type ContextParams struct {
3232
Location protocol.Location `json:"location"`
3333
}
3434

35-
func contextHandler(ctx context.Context, session *cache.Session, params *mcp.CallToolParams[ContextParams]) (*mcp.CallToolResult, error) {
35+
func contextHandler(ctx context.Context, session *cache.Session, params *mcp.CallToolParamsFor[ContextParams]) (*mcp.CallToolResultFor[struct{}], error) {
3636
fh, snapshot, release, err := session.FileOf(ctx, params.Arguments.Location.URI)
3737
if err != nil {
3838
return nil, err
@@ -137,7 +137,7 @@ func contextHandler(ctx context.Context, session *cache.Session, params *mcp.Cal
137137
}
138138
}
139139

140-
return &mcp.CallToolResult{
140+
return &mcp.CallToolResultFor[struct{}]{
141141
Content: []*mcp.Content{
142142
mcp.NewTextContent(result.String()),
143143
},

gopls/internal/mcp/mcp.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func newServer(session *cache.Session) *mcp.Server {
117117
mcp.NewTool(
118118
"context",
119119
"Provide context for a region within a Go file",
120-
func(ctx context.Context, _ *mcp.ServerSession, request *mcp.CallToolParams[ContextParams]) (*mcp.CallToolResult, error) {
120+
func(ctx context.Context, _ *mcp.ServerSession, request *mcp.CallToolParamsFor[ContextParams]) (*mcp.CallToolResultFor[struct{}], error) {
121121
return contextHandler(ctx, session, request)
122122
},
123123
mcp.Input(

gopls/internal/test/marker/marker_test.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ func Test(t *testing.T) {
297297
t.Errorf("formatTest: %v", err)
298298
} else if *update {
299299
filename := filepath.Join(dir, test.name)
300-
if err := os.WriteFile(filename, formatted, 0644); err != nil {
300+
if err := os.WriteFile(filename, formatted, 0o644); err != nil {
301301
t.Error(err)
302302
}
303303
} else if !t.Failed() {
@@ -1414,7 +1414,6 @@ func (sm stringMatcher) check(mark marker, got string) {
14141414
if !sm.pattern.MatchString(got) {
14151415
mark.errorf("got %q, does not match pattern %#q", got, sm.pattern)
14161416
}
1417-
14181417
} else if !strings.Contains(got, sm.substr) {
14191418
// Content must contain the expected substring.
14201419
mark.errorf("got %q, want substring %q", got, sm.substr)
@@ -1428,7 +1427,6 @@ func checkChangedFiles(mark marker, changed map[string][]byte, golden *Golden) {
14281427
if want, ok := golden.Get(mark.T(), filename, got); !ok {
14291428
mark.errorf("%s: unexpected change to file %s; got:\n%s",
14301429
mark.note.Name, filename, got)
1431-
14321430
} else if string(got) != string(want) {
14331431
mark.errorf("%s: wrong file content for %s: got:\n%s\nwant:\n%s\ndiff:\n%s",
14341432
mark.note.Name, filename, got, want,
@@ -1481,7 +1479,6 @@ func checkDiffs(mark marker, changed map[string][]byte, golden *Golden) {
14811479
if want, ok := golden.Get(mark.T(), filename, []byte(got)); !ok {
14821480
mark.errorf("%s: unexpected change to file %s; got diff:\n%s",
14831481
mark.note.Name, filename, got)
1484-
14851482
} else if got != string(want) {
14861483
mark.errorf("%s: wrong diff for %s:\n\ngot:\n%s\n\nwant:\n%s\n",
14871484
mark.note.Name, filename, got, want)
@@ -1680,7 +1677,6 @@ func acceptCompletionMarker(mark marker, src protocol.Location, label string, go
16801677
filename := mark.path()
16811678
mapper := mark.mapper()
16821679
patched, _, err := protocol.ApplyEdits(mapper, append([]protocol.TextEdit{edit}, selected.AdditionalTextEdits...))
1683-
16841680
if err != nil {
16851681
mark.errorf("ApplyProtocolEdits failed: %v", err)
16861682
return
@@ -1799,6 +1795,7 @@ func highlightLocationMarker(mark marker, loc protocol.Location, kindName expect
17991795
Kind: kind,
18001796
}
18011797
}
1798+
18021799
func sortDocumentHighlights(s []protocol.DocumentHighlight) {
18031800
sort.Slice(s, func(i, j int) bool {
18041801
return protocol.CompareRange(s[i].Range, s[j].Range) < 0
@@ -2452,7 +2449,7 @@ func mcpToolMarker(mark marker, tool string, rawArgs string, loc protocol.Locati
24522449
// TODO(hxjiang): Make the "location" key configurable.
24532450
args["location"] = loc
24542451

2455-
res, err := mcp.CallTool(mark.ctx(), mark.run.env.MCPSession, &mcp.CallToolParams[map[string]any]{
2452+
res, err := mark.run.env.MCPSession.CallTool(mark.ctx(), &mcp.CallToolParams{
24562453
Name: tool,
24572454
Arguments: args,
24582455
})

internal/mcp/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ func main() {
5050
}
5151
defer session.Close()
5252
// Call a tool on the server.
53-
params := &mcp.CallToolParams[map[string]any]{
53+
params := &mcp.CallToolParams{
5454
Name: "greet",
5555
Arguments: map[string]any{"name": "you"},
5656
}
57-
if res, err := mcp.CallTool(ctx, session, params); err != nil {
57+
if res, err := session.CallTool(ctx, params); err != nil {
5858
log.Printf("CallTool failed: %v", err)
5959
} else {
6060
if res.IsError {
@@ -82,8 +82,8 @@ type HiParams struct {
8282
Name string `json:"name"`
8383
}
8484

85-
func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[HiParams]) (*mcp.CallToolResult, error) {
86-
return &mcp.CallToolResult{
85+
func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[string], error) {
86+
return &mcp.CallToolResultFor[string]{
8787
Content: []*mcp.Content{mcp.NewTextContent("Hi " + params.Name)},
8888
}, nil
8989
}

internal/mcp/client.go

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ package mcp
66

77
import (
88
"context"
9-
"encoding/json"
10-
"fmt"
119
"iter"
1210
"slices"
1311
"sync"
@@ -288,34 +286,13 @@ func (cs *ClientSession) ListTools(ctx context.Context, params *ListToolsParams)
288286
}
289287

290288
// CallTool calls the tool with the given name and arguments.
291-
// Pass a [CallToolOptions] to provide additional request fields.
292-
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams[json.RawMessage]) (*CallToolResult, error) {
293-
return handleSend[*CallToolResult](ctx, cs, methodCallTool, params)
294-
}
295-
296-
// CallTool is a helper to call a tool with any argument type. It returns an
297-
// error if params.Arguments fails to marshal to JSON.
298-
func CallTool[TArgs any](ctx context.Context, cs *ClientSession, params *CallToolParams[TArgs]) (*CallToolResult, error) {
299-
wireParams, err := toWireParams(params)
300-
if err != nil {
301-
return nil, err
289+
// The arguments can be any value that marshals into a JSON object.
290+
func (cs *ClientSession) CallTool(ctx context.Context, params *CallToolParams) (*CallToolResult, error) {
291+
if params.Arguments == nil {
292+
// Avoid sending nil over the wire.
293+
params.Arguments = map[string]any{}
302294
}
303-
return cs.CallTool(ctx, wireParams)
304-
}
305-
306-
func toWireParams[TArgs any](params *CallToolParams[TArgs]) (*CallToolParams[json.RawMessage], error) {
307-
data, err := json.Marshal(params.Arguments)
308-
if err != nil {
309-
return nil, fmt.Errorf("failed to marshal arguments: %v", err)
310-
}
311-
// The field mapping here must be kept up to date with the CallToolParams.
312-
// This is partially enforced by TestToWireParams, which verifies that all
313-
// comparable fields are mapped.
314-
return &CallToolParams[json.RawMessage]{
315-
Meta: params.Meta,
316-
Name: params.Name,
317-
Arguments: data,
318-
}, nil
295+
return handleSend[*CallToolResult](ctx, cs, methodCallTool, params)
319296
}
320297

321298
func (cs *ClientSession) SetLevel(ctx context.Context, params *SetLevelParams) error {

internal/mcp/client_test.go

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package mcp
77
import (
88
"context"
99
"fmt"
10-
"reflect"
1110
"testing"
1211

1312
"github.com/google/go-cmp/cmp"
@@ -200,28 +199,3 @@ func TestClientPaginateVariousPageSizes(t *testing.T) {
200199
})
201200
}
202201
}
203-
204-
func TestToWireParams(t *testing.T) {
205-
// This test verifies that toWireParams maps all fields.
206-
// The Meta and Arguments fields are not comparable, so can't be checked by
207-
// this simple test. However, this test will fail if new fields are added,
208-
// and not handled by toWireParams.
209-
params := &CallToolParams[struct{}]{
210-
Name: "tool",
211-
}
212-
wireParams, err := toWireParams(params)
213-
if err != nil {
214-
t.Fatal(err)
215-
}
216-
v := reflect.ValueOf(wireParams).Elem()
217-
for i := range v.Type().NumField() {
218-
f := v.Type().Field(i)
219-
if f.Name == "Meta" || f.Name == "Arguments" {
220-
continue // not comparable
221-
}
222-
fv := v.Field(i)
223-
if fv.Interface() == reflect.Zero(f.Type).Interface() {
224-
t.Fatalf("toWireParams: unmapped field %q", f.Name)
225-
}
226-
}
227-
}

internal/mcp/cmd_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func TestCmdTransport(t *testing.T) {
5656
if err != nil {
5757
log.Fatal(err)
5858
}
59-
got, err := mcp.CallTool(ctx, session, &mcp.CallToolParams[map[string]any]{
59+
got, err := session.CallTool(ctx, &mcp.CallToolParams{
6060
Name: "greet",
6161
Arguments: map[string]any{"name": "user"},
6262
})

internal/mcp/examples/hello/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ type HiArgs struct {
2121
Name string `json:"name"`
2222
}
2323

24-
func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParams[HiArgs]) (*mcp.CallToolResult, error) {
25-
return &mcp.CallToolResult{
24+
func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[any], error) {
25+
return &mcp.CallToolResultFor[any]{
2626
Content: []*mcp.Content{
2727
mcp.NewTextContent("Hi " + params.Name),
2828
},

internal/mcp/examples/sse/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ type SayHiParams struct {
1919
Name string `json:"name"`
2020
}
2121

22-
func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParams[SayHiParams]) (*mcp.CallToolResult, error) {
23-
return &mcp.CallToolResult{
22+
func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[SayHiParams]) (*mcp.CallToolResultFor[any], error) {
23+
return &mcp.CallToolResultFor[any]{
2424
Content: []*mcp.Content{
2525
mcp.NewTextContent("Hi " + params.Name),
2626
},

internal/mcp/features_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ type SayHiParams struct {
1818
Name string `json:"name"`
1919
}
2020

21-
func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParams[SayHiParams]) (*CallToolResult, error) {
22-
return &CallToolResult{
21+
func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) {
22+
return &CallToolResultFor[any]{
2323
Content: []*Content{
2424
NewTextContent("Hi " + params.Name),
2525
},
@@ -71,7 +71,6 @@ func TestFeatureSetAbove(t *testing.T) {
7171
got := slices.Collect(fs.above(tc.above))
7272
if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
7373
t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff)
74-
7574
}
7675
}
7776
}

0 commit comments

Comments
 (0)