Skip to content

Commit 9ac9a61

Browse files
feat: integrate go-sdk SchemaProvider interface (Phase 2)
- Added replace directive in go.mod to use SamMorrowDrums/go-sdk fork - Fork provides SchemaProvider and ResolvedSchemaProvider interfaces - Fork provides automatic schema caching for all types - 21 tools already have ResolvedSchemaProvider implementations in schema_providers.go Co-authored-by: SamMorrowDrums <4811358+SamMorrowDrums@users.noreply.github.com>
1 parent ac77b3c commit 9ac9a61

File tree

8 files changed

+958
-577
lines changed

8 files changed

+958
-577
lines changed

pkg/github/helper_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ import (
1010
"github.com/stretchr/testify/require"
1111
)
1212

13+
// mapToTypedInput converts a map[string]interface{} to a typed struct using JSON marshaling.
14+
// This is useful for tests that need to pass typed input to handlers.
15+
func mapToTypedInput[T any](t *testing.T, m map[string]interface{}) T {
16+
t.Helper()
17+
var result T
18+
jsonBytes, err := json.Marshal(m)
19+
require.NoError(t, err, "failed to marshal map to JSON")
20+
err = json.Unmarshal(jsonBytes, &result)
21+
require.NoError(t, err, "failed to unmarshal JSON to typed input")
22+
return result
23+
}
24+
1325
type expectations struct {
1426
path string
1527
queryParams map[string]string

pkg/github/issues.go

Lines changed: 38 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -229,69 +229,29 @@ func fragmentToIssue(fragment IssueFragment) *github.Issue {
229229
}
230230

231231
// IssueRead creates a tool to get details of a specific issue in a GitHub repository.
232-
func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
233-
schema := &jsonschema.Schema{
234-
Type: "object",
235-
Properties: map[string]*jsonschema.Schema{
236-
"method": {
237-
Type: "string",
238-
Description: `The read operation to perform on a single issue.
239-
Options are:
240-
1. get - Get details of a specific issue.
241-
2. get_comments - Get issue comments.
242-
3. get_sub_issues - Get sub-issues of the issue.
243-
4. get_labels - Get labels assigned to the issue.
244-
`,
245-
Enum: []any{"get", "get_comments", "get_sub_issues", "get_labels"},
246-
},
247-
"owner": {
248-
Type: "string",
249-
Description: "The owner of the repository",
250-
},
251-
"repo": {
252-
Type: "string",
253-
Description: "The name of the repository",
254-
},
255-
"issue_number": {
256-
Type: "number",
257-
Description: "The number of the issue",
258-
},
259-
},
260-
Required: []string{"method", "owner", "repo", "issue_number"},
261-
}
262-
WithPagination(schema)
263-
232+
func IssueRead(getClient GetClientFn, getGQLClient GetGQLClientFn, cache *lockdown.RepoAccessCache, t translations.TranslationHelperFunc, flags FeatureFlags) (mcp.Tool, mcp.ToolHandlerFor[IssueReadInput, any]) {
264233
return mcp.Tool{
265234
Name: "issue_read",
266235
Description: t("TOOL_ISSUE_READ_DESCRIPTION", "Get information about a specific issue in a GitHub repository."),
267236
Annotations: &mcp.ToolAnnotations{
268237
Title: t("TOOL_ISSUE_READ_USER_TITLE", "Get issue details"),
269238
ReadOnlyHint: true,
270239
},
271-
InputSchema: schema,
240+
InputSchema: IssueReadInput{}.MCPSchema(),
272241
},
273-
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
274-
method, err := RequiredParam[string](args, "method")
275-
if err != nil {
276-
return utils.NewToolResultError(err.Error()), nil, nil
277-
}
278-
279-
owner, err := RequiredParam[string](args, "owner")
280-
if err != nil {
281-
return utils.NewToolResultError(err.Error()), nil, nil
242+
func(ctx context.Context, _ *mcp.CallToolRequest, input IssueReadInput) (*mcp.CallToolResult, any, error) {
243+
// Set pagination defaults
244+
page := input.Page
245+
if page == 0 {
246+
page = 1
282247
}
283-
repo, err := RequiredParam[string](args, "repo")
284-
if err != nil {
285-
return utils.NewToolResultError(err.Error()), nil, nil
248+
perPage := input.PerPage
249+
if perPage == 0 {
250+
perPage = 30
286251
}
287-
issueNumber, err := RequiredInt(args, "issue_number")
288-
if err != nil {
289-
return utils.NewToolResultError(err.Error()), nil, nil
290-
}
291-
292-
pagination, err := OptionalPaginationParams(args)
293-
if err != nil {
294-
return utils.NewToolResultError(err.Error()), nil, nil
252+
pagination := PaginationParams{
253+
Page: page,
254+
PerPage: perPage,
295255
}
296256

297257
client, err := getClient(ctx)
@@ -304,21 +264,21 @@ Options are:
304264
return utils.NewToolResultErrorFromErr("failed to get GitHub graphql client", err), nil, nil
305265
}
306266

307-
switch method {
267+
switch input.Method {
308268
case "get":
309-
result, err := GetIssue(ctx, client, cache, owner, repo, issueNumber, flags)
269+
result, err := GetIssue(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, flags)
310270
return result, nil, err
311271
case "get_comments":
312-
result, err := GetIssueComments(ctx, client, cache, owner, repo, issueNumber, pagination, flags)
272+
result, err := GetIssueComments(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, pagination, flags)
313273
return result, nil, err
314274
case "get_sub_issues":
315-
result, err := GetSubIssues(ctx, client, cache, owner, repo, issueNumber, pagination, flags)
275+
result, err := GetSubIssues(ctx, client, cache, input.Owner, input.Repo, input.IssueNumber, pagination, flags)
316276
return result, nil, err
317277
case "get_labels":
318-
result, err := GetIssueLabels(ctx, gqlClient, owner, repo, issueNumber)
278+
result, err := GetIssueLabels(ctx, gqlClient, input.Owner, input.Repo, input.IssueNumber)
319279
return result, nil, err
320280
default:
321-
return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", method)), nil, nil
281+
return utils.NewToolResultError(fmt.Sprintf("unknown method: %s", input.Method)), nil, nil
322282
}
323283
}
324284
}
@@ -1313,97 +1273,28 @@ func UpdateIssue(ctx context.Context, client *github.Client, gqlClient *githubv4
13131273
}
13141274

13151275
// ListIssues creates a tool to list and filter repository issues
1316-
func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
1317-
schema := &jsonschema.Schema{
1318-
Type: "object",
1319-
Properties: map[string]*jsonschema.Schema{
1320-
"owner": {
1321-
Type: "string",
1322-
Description: "Repository owner",
1323-
},
1324-
"repo": {
1325-
Type: "string",
1326-
Description: "Repository name",
1327-
},
1328-
"state": {
1329-
Type: "string",
1330-
Description: "Filter by state, by default both open and closed issues are returned when not provided",
1331-
Enum: []any{"OPEN", "CLOSED"},
1332-
},
1333-
"labels": {
1334-
Type: "array",
1335-
Description: "Filter by labels",
1336-
Items: &jsonschema.Schema{
1337-
Type: "string",
1338-
},
1339-
},
1340-
"orderBy": {
1341-
Type: "string",
1342-
Description: "Order issues by field. If provided, the 'direction' also needs to be provided.",
1343-
Enum: []any{"CREATED_AT", "UPDATED_AT", "COMMENTS"},
1344-
},
1345-
"direction": {
1346-
Type: "string",
1347-
Description: "Order direction. If provided, the 'orderBy' also needs to be provided.",
1348-
Enum: []any{"ASC", "DESC"},
1349-
},
1350-
"since": {
1351-
Type: "string",
1352-
Description: "Filter by date (ISO 8601 timestamp)",
1353-
},
1354-
},
1355-
Required: []string{"owner", "repo"},
1356-
}
1357-
WithCursorPagination(schema)
1358-
1276+
func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[ListIssuesInput, any]) {
13591277
return mcp.Tool{
13601278
Name: "list_issues",
13611279
Description: t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter."),
13621280
Annotations: &mcp.ToolAnnotations{
13631281
Title: t("TOOL_LIST_ISSUES_USER_TITLE", "List issues"),
13641282
ReadOnlyHint: true,
13651283
},
1366-
InputSchema: schema,
1284+
InputSchema: ListIssuesInput{}.MCPSchema(),
13671285
},
1368-
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
1369-
owner, err := RequiredParam[string](args, "owner")
1370-
if err != nil {
1371-
return utils.NewToolResultError(err.Error()), nil, nil
1372-
}
1373-
repo, err := RequiredParam[string](args, "repo")
1374-
if err != nil {
1375-
return utils.NewToolResultError(err.Error()), nil, nil
1376-
}
1377-
1378-
// Set optional parameters if provided
1379-
state, err := OptionalParam[string](args, "state")
1380-
if err != nil {
1381-
return utils.NewToolResultError(err.Error()), nil, nil
1382-
}
1383-
1286+
func(ctx context.Context, _ *mcp.CallToolRequest, input ListIssuesInput) (*mcp.CallToolResult, any, error) {
13841287
// If the state has a value, cast into an array of strings
13851288
var states []githubv4.IssueState
1386-
if state != "" {
1387-
states = append(states, githubv4.IssueState(state))
1289+
if input.State != "" {
1290+
states = append(states, githubv4.IssueState(input.State))
13881291
} else {
13891292
states = []githubv4.IssueState{githubv4.IssueStateOpen, githubv4.IssueStateClosed}
13901293
}
13911294

1392-
// Get labels
1393-
labels, err := OptionalStringArrayParam(args, "labels")
1394-
if err != nil {
1395-
return utils.NewToolResultError(err.Error()), nil, nil
1396-
}
1397-
1398-
orderBy, err := OptionalParam[string](args, "orderBy")
1399-
if err != nil {
1400-
return utils.NewToolResultError(err.Error()), nil, nil
1401-
}
1402-
1403-
direction, err := OptionalParam[string](args, "direction")
1404-
if err != nil {
1405-
return utils.NewToolResultError(err.Error()), nil, nil
1406-
}
1295+
labels := input.Labels
1296+
orderBy := input.OrderBy
1297+
direction := input.Direction
14071298

14081299
// These variables are required for the GraphQL query to be set by default
14091300
// If orderBy is empty, default to CREATED_AT
@@ -1415,16 +1306,12 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun
14151306
direction = "DESC"
14161307
}
14171308

1418-
since, err := OptionalParam[string](args, "since")
1419-
if err != nil {
1420-
return utils.NewToolResultError(err.Error()), nil, nil
1421-
}
1422-
14231309
// There are two optional parameters: since and labels.
14241310
var sinceTime time.Time
14251311
var hasSince bool
1426-
if since != "" {
1427-
sinceTime, err = parseISOTimestamp(since)
1312+
if input.Since != "" {
1313+
var err error
1314+
sinceTime, err = parseISOTimestamp(input.Since)
14281315
if err != nil {
14291316
return utils.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil, nil
14301317
}
@@ -1433,39 +1320,28 @@ func ListIssues(getGQLClient GetGQLClientFn, t translations.TranslationHelperFun
14331320
hasLabels := len(labels) > 0
14341321

14351322
// Get pagination parameters and convert to GraphQL format
1436-
pagination, err := OptionalCursorPaginationParams(args)
1437-
if err != nil {
1438-
return nil, nil, err
1323+
perPage := input.PerPage
1324+
if perPage == 0 {
1325+
perPage = 30
14391326
}
1440-
1441-
// Check if someone tried to use page-based pagination instead of cursor-based
1442-
if _, pageProvided := args["page"]; pageProvided {
1443-
return utils.NewToolResultError("This tool uses cursor-based pagination. Use the 'after' parameter with the 'endCursor' value from the previous response instead of 'page'."), nil, nil
1327+
pagination := CursorPaginationParams{
1328+
PerPage: perPage,
1329+
After: input.After,
14441330
}
14451331

1446-
// Check if pagination parameters were explicitly provided
1447-
_, perPageProvided := args["perPage"]
1448-
paginationExplicit := perPageProvided
1449-
14501332
paginationParams, err := pagination.ToGraphQLParams()
14511333
if err != nil {
14521334
return nil, nil, err
14531335
}
14541336

1455-
// Use default of 30 if pagination was not explicitly provided
1456-
if !paginationExplicit {
1457-
defaultFirst := int32(DefaultGraphQLPageSize)
1458-
paginationParams.First = &defaultFirst
1459-
}
1460-
14611337
client, err := getGQLClient(ctx)
14621338
if err != nil {
14631339
return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil
14641340
}
14651341

14661342
vars := map[string]interface{}{
1467-
"owner": githubv4.String(owner),
1468-
"repo": githubv4.String(repo),
1343+
"owner": githubv4.String(input.Owner),
1344+
"repo": githubv4.String(input.Repo),
14691345
"states": states,
14701346
"orderBy": githubv4.IssueOrderField(orderBy),
14711347
"direction": githubv4.OrderDirection(direction),

pkg/github/issues_test.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ func Test_GetIssue(t *testing.T) {
334334
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), cache, translations.NullTranslationHelper, flags)
335335

336336
request := createMCPRequest(tc.requestArgs)
337-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
337+
typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs)
338+
result, _, err := handler(context.Background(), &request, typedInput)
338339

339340
if tc.expectHandlerError {
340341
require.Error(t, err)
@@ -1244,7 +1245,8 @@ func Test_ListIssues(t *testing.T) {
12441245
_, handler := ListIssues(stubGetGQLClientFn(gqlClient), translations.NullTranslationHelper)
12451246

12461247
req := createMCPRequest(tc.reqParams)
1247-
res, _, err := handler(context.Background(), &req, tc.reqParams)
1248+
typedInput := mapToTypedInput[ListIssuesInput](t, tc.reqParams)
1249+
res, _, err := handler(context.Background(), &req, typedInput)
12481250
text := getTextResult(t, res).Text
12491251

12501252
if tc.expectError {
@@ -1988,9 +1990,10 @@ func Test_GetIssueComments(t *testing.T) {
19881990

19891991
// Create call request
19901992
request := createMCPRequest(tc.requestArgs)
1993+
typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs)
19911994

19921995
// Call handler
1993-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
1996+
result, _, err := handler(context.Background(), &request, typedInput)
19941997

19951998
// Verify results
19961999
if tc.expectError {
@@ -2102,7 +2105,8 @@ func Test_GetIssueLabels(t *testing.T) {
21022105
_, handler := IssueRead(stubGetClientFn(client), stubGetGQLClientFn(gqlClient), stubRepoAccessCache(gqlClient, 15*time.Minute), translations.NullTranslationHelper, stubFeatureFlags(map[string]bool{"lockdown-mode": false}))
21032106

21042107
request := createMCPRequest(tc.requestArgs)
2105-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
2108+
typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs)
2109+
result, _, err := handler(context.Background(), &request, typedInput)
21062110

21072111
require.NoError(t, err)
21082112
assert.NotNil(t, result)
@@ -2991,9 +2995,10 @@ func Test_GetSubIssues(t *testing.T) {
29912995

29922996
// Create call request
29932997
request := createMCPRequest(tc.requestArgs)
2998+
typedInput := mapToTypedInput[IssueReadInput](t, tc.requestArgs)
29942999

29953000
// Call handler
2996-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
3001+
result, _, err := handler(context.Background(), &request, typedInput)
29973002

29983003
// Verify results
29993004
if tc.expectError {

0 commit comments

Comments
 (0)