Skip to content

Commit 2cae60e

Browse files
committed
internal/mcp: iterators return pointers
Iterators return pointers instead of values, to be consistent. Move tests out of example file. Factor out iterator tests. Change-Id: Icb14edf99d738de9d9dc08c1169057cb3dc5894b Reviewed-on: https://go-review.googlesource.com/c/tools/+/680295 Reviewed-by: Robert Findley <rfindley@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Sam Thanawalla <samthanawalla@google.com>
1 parent 687b754 commit 2cae60e

File tree

4 files changed

+154
-139
lines changed

4 files changed

+154
-139
lines changed

internal/mcp/client.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,9 @@ func (c *Client) callLoggingHandler(ctx context.Context, cs *ClientSession, para
336336

337337
// Tools provides an iterator for all tools available on the server,
338338
// automatically fetching pages and managing cursors.
339-
// The `params` argument can set the initial cursor.
339+
// The params argument can set the initial cursor.
340340
// Iteration stops at the first encountered error, which will be yielded.
341-
func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[Tool, error] {
341+
func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) iter.Seq2[*Tool, error] {
342342
if params == nil {
343343
params = &ListToolsParams{}
344344
}
@@ -349,9 +349,9 @@ func (cs *ClientSession) Tools(ctx context.Context, params *ListToolsParams) ite
349349

350350
// Resources provides an iterator for all resources available on the server,
351351
// automatically fetching pages and managing cursors.
352-
// The `params` argument can set the initial cursor.
352+
// The params argument can set the initial cursor.
353353
// Iteration stops at the first encountered error, which will be yielded.
354-
func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[Resource, error] {
354+
func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesParams) iter.Seq2[*Resource, error] {
355355
if params == nil {
356356
params = &ListResourcesParams{}
357357
}
@@ -364,7 +364,7 @@ func (cs *ClientSession) Resources(ctx context.Context, params *ListResourcesPar
364364
// automatically fetching pages and managing cursors.
365365
// The `params` argument can set the initial cursor.
366366
// Iteration stops at the first encountered error, which will be yielded.
367-
func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) iter.Seq2[ResourceTemplate, error] {
367+
func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListResourceTemplatesParams) iter.Seq2[*ResourceTemplate, error] {
368368
if params == nil {
369369
params = &ListResourceTemplatesParams{}
370370
}
@@ -375,9 +375,9 @@ func (cs *ClientSession) ResourceTemplates(ctx context.Context, params *ListReso
375375

376376
// Prompts provides an iterator for all prompts available on the server,
377377
// automatically fetching pages and managing cursors.
378-
// The `params` argument can set the initial cursor.
378+
// The params argument can set the initial cursor.
379379
// Iteration stops at the first encountered error, which will be yielded.
380-
func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[Prompt, error] {
380+
func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams) iter.Seq2[*Prompt, error] {
381381
if params == nil {
382382
params = &ListPromptsParams{}
383383
}
@@ -387,17 +387,16 @@ func (cs *ClientSession) Prompts(ctx context.Context, params *ListPromptsParams)
387387
}
388388

389389
// paginate is a generic helper function to provide a paginated iterator.
390-
func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[T, error] {
391-
return func(yield func(T, error) bool) {
390+
func paginate[P listParams, R listResult[T], T any](ctx context.Context, params P, listFunc func(context.Context, P) (R, error), items func(R) []*T) iter.Seq2[*T, error] {
391+
return func(yield func(*T, error) bool) {
392392
for {
393393
res, err := listFunc(ctx, params)
394394
if err != nil {
395-
var zero T
396-
yield(zero, err)
395+
yield(nil, err)
397396
return
398397
}
399398
for _, r := range items(res) {
400-
if !yield(*r, nil) {
399+
if !yield(r, nil) {
401400
return
402401
}
403402
}

internal/mcp/client_list_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// Copyright 2025 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package mcp_test
6+
7+
import (
8+
"context"
9+
"iter"
10+
"testing"
11+
12+
"github.com/google/go-cmp/cmp"
13+
"github.com/google/go-cmp/cmp/cmpopts"
14+
"golang.org/x/tools/internal/mcp"
15+
"golang.org/x/tools/internal/mcp/jsonschema"
16+
)
17+
18+
func TestListTools(t *testing.T) {
19+
toolA := mcp.NewTool("apple", "apple tool", SayHi)
20+
toolB := mcp.NewTool("banana", "banana tool", SayHi)
21+
toolC := mcp.NewTool("cherry", "cherry tool", SayHi)
22+
tools := []*mcp.ServerTool{toolA, toolB, toolC}
23+
24+
wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}
25+
ctx := context.Background()
26+
clientSession, serverSession, server := createSessions(ctx)
27+
defer clientSession.Close()
28+
defer serverSession.Close()
29+
server.AddTools(tools...)
30+
t.Run("ListTools", func(t *testing.T) {
31+
res, err := clientSession.ListTools(ctx, nil)
32+
if err != nil {
33+
t.Fatal("ListTools() failed:", err)
34+
}
35+
if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
36+
t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff)
37+
}
38+
})
39+
t.Run("ToolsIterator", func(t *testing.T) {
40+
testIterator(ctx, t, clientSession.Tools(ctx, nil), wantTools)
41+
})
42+
}
43+
44+
func TestListResources(t *testing.T) {
45+
resourceA := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://apple"}}
46+
resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}}
47+
resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}}
48+
wantResources := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource}
49+
50+
resources := []*mcp.ServerResource{resourceA, resourceB, resourceC}
51+
ctx := context.Background()
52+
clientSession, serverSession, server := createSessions(ctx)
53+
defer clientSession.Close()
54+
defer serverSession.Close()
55+
server.AddResources(resources...)
56+
t.Run("ListResources", func(t *testing.T) {
57+
res, err := clientSession.ListResources(ctx, nil)
58+
if err != nil {
59+
t.Fatal("ListResources() failed:", err)
60+
}
61+
if diff := cmp.Diff(wantResources, res.Resources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
62+
t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff)
63+
}
64+
})
65+
t.Run("ResourcesIterator", func(t *testing.T) {
66+
testIterator(ctx, t, clientSession.Resources(ctx, nil), wantResources)
67+
})
68+
}
69+
70+
func TestListResourceTemplates(t *testing.T) {
71+
resourceTmplA := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://apple/{x}"}}
72+
resourceTmplB := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://banana/{x}"}}
73+
resourceTmplC := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://cherry/{x}"}}
74+
wantResourceTemplates := []*mcp.ResourceTemplate{
75+
resourceTmplA.ResourceTemplate, resourceTmplB.ResourceTemplate,
76+
resourceTmplC.ResourceTemplate,
77+
}
78+
resourceTemplates := []*mcp.ServerResourceTemplate{resourceTmplA, resourceTmplB, resourceTmplC}
79+
ctx := context.Background()
80+
clientSession, serverSession, server := createSessions(ctx)
81+
defer clientSession.Close()
82+
defer serverSession.Close()
83+
server.AddResourceTemplates(resourceTemplates...)
84+
t.Run("ListResourceTemplates", func(t *testing.T) {
85+
res, err := clientSession.ListResourceTemplates(ctx, nil)
86+
if err != nil {
87+
t.Fatal("ListResourceTemplates() failed:", err)
88+
}
89+
if diff := cmp.Diff(wantResourceTemplates, res.ResourceTemplates, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
90+
t.Fatalf("ListResourceTemplates() mismatch (-want +got):\n%s", diff)
91+
}
92+
})
93+
t.Run("ResourceTemplatesIterator", func(t *testing.T) {
94+
testIterator(ctx, t, clientSession.ResourceTemplates(ctx, nil), wantResourceTemplates)
95+
})
96+
}
97+
98+
func TestListPrompts(t *testing.T) {
99+
promptA := mcp.NewPrompt("apple", "apple prompt", testPromptHandler[struct{}])
100+
promptB := mcp.NewPrompt("banana", "banana prompt", testPromptHandler[struct{}])
101+
promptC := mcp.NewPrompt("cherry", "cherry prompt", testPromptHandler[struct{}])
102+
wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt}
103+
104+
prompts := []*mcp.ServerPrompt{promptA, promptB, promptC}
105+
ctx := context.Background()
106+
clientSession, serverSession, server := createSessions(ctx)
107+
defer clientSession.Close()
108+
defer serverSession.Close()
109+
server.AddPrompts(prompts...)
110+
t.Run("ListPrompts", func(t *testing.T) {
111+
res, err := clientSession.ListPrompts(ctx, nil)
112+
if err != nil {
113+
t.Fatal("ListPrompts() failed:", err)
114+
}
115+
if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
116+
t.Fatalf("ListPrompts() mismatch (-want +got):\n%s", diff)
117+
}
118+
})
119+
t.Run("PromptsIterator", func(t *testing.T) {
120+
testIterator(ctx, t, clientSession.Prompts(ctx, nil), wantPrompts)
121+
})
122+
}
123+
124+
func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, error], want []*T) {
125+
t.Helper()
126+
var got []*T
127+
for x, err := range seq {
128+
if err != nil {
129+
t.Fatalf("iteration failed: %v", err)
130+
}
131+
got = append(got, x)
132+
}
133+
if diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
134+
t.Fatalf("mismatch (-want +got):\n%s", diff)
135+
}
136+
}

internal/mcp/client_test.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,6 @@ var allItems = []*Item{
5050
{"kilo", "val-K"},
5151
}
5252

53-
func toItemValueSlice(ptrSlice []*Item) []Item {
54-
var valSlice []Item
55-
for _, ptr := range ptrSlice {
56-
valSlice = append(valSlice, *ptr)
57-
}
58-
return valSlice
59-
}
60-
6153
// generatePaginatedResults is a helper to create a sequence of mock responses for pagination.
6254
// It simulates a server returning items in pages based on a given page size.
6355
func generatePaginatedResults(all []*Item, pageSize int) []*ListTestResult {
@@ -88,18 +80,18 @@ func TestClientPaginateBasic(t *testing.T) {
8880
results []*ListTestResult
8981
mockError error
9082
initialParams *ListTestParams
91-
expected []Item
83+
expected []*Item
9284
expectError bool
9385
}{
9486
{
9587
name: "SinglePageAllItems",
9688
results: generatePaginatedResults(allItems, len(allItems)),
97-
expected: toItemValueSlice(allItems),
89+
expected: allItems,
9890
},
9991
{
10092
name: "MultiplePages",
10193
results: generatePaginatedResults(allItems, 3),
102-
expected: toItemValueSlice(allItems),
94+
expected: allItems,
10395
},
10496
{
10597
name: "EmptyResults",
@@ -117,7 +109,7 @@ func TestClientPaginateBasic(t *testing.T) {
117109
name: "InitialCursorProvided",
118110
initialParams: &ListTestParams{Cursor: "cursor_2"},
119111
results: generatePaginatedResults(allItems[2:], 3),
120-
expected: toItemValueSlice(allItems[2:]),
112+
expected: allItems[2:],
121113
},
122114
{
123115
name: "CursorBeyondAllItems",
@@ -148,7 +140,7 @@ func TestClientPaginateBasic(t *testing.T) {
148140
params = &ListTestParams{}
149141
}
150142

151-
var gotItems []Item
143+
var gotItems []*Item
152144
var iterationErr error
153145
seq := paginate(ctx, params, listFunc, func(r *ListTestResult) []*Item { return r.Items })
154146
for item, err := range seq {
@@ -185,15 +177,15 @@ func TestClientPaginateVariousPageSizes(t *testing.T) {
185177
results = results[1:]
186178
return res, nil
187179
}
188-
var gotItems []Item
180+
var gotItems []*Item
189181
seq := paginate(ctx, &ListTestParams{}, listFunc, func(r *ListTestResult) []*Item { return r.Items })
190182
for item, err := range seq {
191183
if err != nil {
192184
t.Fatalf("paginate() unexpected error during iteration: %v", err)
193185
}
194186
gotItems = append(gotItems, item)
195187
}
196-
if diff := cmp.Diff(toItemValueSlice(allItems), gotItems, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
188+
if diff := cmp.Diff(allItems, gotItems, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
197189
t.Fatalf("paginate() mismatch (-want +got):\n%s", diff)
198190
}
199191
})

0 commit comments

Comments
 (0)