Skip to content

Commit 24b9cbb

Browse files
refactor(security): migrate code_scanning, secret_scanning, dependabot to NewTool pattern
Co-authored-by: Adam Holt <omgitsads@users.noreply.github.com>
1 parent d29e73b commit 24b9cbb

File tree

7 files changed

+407
-376
lines changed

7 files changed

+407
-376
lines changed

pkg/github/code_scanning.go

Lines changed: 99 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@ import (
88
"net/http"
99

1010
ghErrors "github.com/github/github-mcp-server/pkg/errors"
11+
"github.com/github/github-mcp-server/pkg/toolsets"
1112
"github.com/github/github-mcp-server/pkg/translations"
1213
"github.com/github/github-mcp-server/pkg/utils"
1314
"github.com/google/go-github/v79/github"
1415
"github.com/google/jsonschema-go/jsonschema"
1516
"github.com/modelcontextprotocol/go-sdk/mcp"
1617
)
1718

18-
func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
19-
return mcp.Tool{
19+
func GetCodeScanningAlert(t translations.TranslationHelperFunc) toolsets.ServerTool {
20+
return NewTool(
21+
mcp.Tool{
2022
Name: "get_code_scanning_alert",
2123
Description: t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository."),
2224
Annotations: &mcp.ToolAnnotations{
@@ -42,54 +44,58 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe
4244
Required: []string{"owner", "repo", "alertNumber"},
4345
},
4446
},
45-
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
46-
owner, err := RequiredParam[string](args, "owner")
47-
if err != nil {
48-
return utils.NewToolResultError(err.Error()), nil, nil
49-
}
50-
repo, err := RequiredParam[string](args, "repo")
51-
if err != nil {
52-
return utils.NewToolResultError(err.Error()), nil, nil
53-
}
54-
alertNumber, err := RequiredInt(args, "alertNumber")
55-
if err != nil {
56-
return utils.NewToolResultError(err.Error()), nil, nil
57-
}
47+
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
48+
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
49+
owner, err := RequiredParam[string](args, "owner")
50+
if err != nil {
51+
return utils.NewToolResultError(err.Error()), nil, nil
52+
}
53+
repo, err := RequiredParam[string](args, "repo")
54+
if err != nil {
55+
return utils.NewToolResultError(err.Error()), nil, nil
56+
}
57+
alertNumber, err := RequiredInt(args, "alertNumber")
58+
if err != nil {
59+
return utils.NewToolResultError(err.Error()), nil, nil
60+
}
5861

59-
client, err := getClient(ctx)
60-
if err != nil {
61-
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
62-
}
62+
client, err := deps.GetClient(ctx)
63+
if err != nil {
64+
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
65+
}
6366

64-
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
65-
if err != nil {
66-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
67-
"failed to get alert",
68-
resp,
69-
err,
70-
), nil, nil
71-
}
72-
defer func() { _ = resp.Body.Close() }()
67+
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
68+
if err != nil {
69+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
70+
"failed to get alert",
71+
resp,
72+
err,
73+
), nil, nil
74+
}
75+
defer func() { _ = resp.Body.Close() }()
76+
77+
if resp.StatusCode != http.StatusOK {
78+
body, err := io.ReadAll(resp.Body)
79+
if err != nil {
80+
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
81+
}
82+
return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil
83+
}
7384

74-
if resp.StatusCode != http.StatusOK {
75-
body, err := io.ReadAll(resp.Body)
85+
r, err := json.Marshal(alert)
7686
if err != nil {
77-
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
87+
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
7888
}
79-
return utils.NewToolResultError(fmt.Sprintf("failed to get alert: %s", string(body))), nil, nil
80-
}
8189

82-
r, err := json.Marshal(alert)
83-
if err != nil {
84-
return utils.NewToolResultErrorFromErr("failed to marshal alert", err), nil, nil
90+
return utils.NewToolResultText(string(r)), nil, nil
8591
}
86-
87-
return utils.NewToolResultText(string(r)), nil, nil
88-
}
92+
},
93+
)
8994
}
9095

91-
func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.Tool, mcp.ToolHandlerFor[map[string]any, any]) {
92-
return mcp.Tool{
96+
func ListCodeScanningAlerts(t translations.TranslationHelperFunc) toolsets.ServerTool {
97+
return NewTool(
98+
mcp.Tool{
9399
Name: "list_code_scanning_alerts",
94100
Description: t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository."),
95101
Annotations: &mcp.ToolAnnotations{
@@ -130,59 +136,62 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel
130136
Required: []string{"owner", "repo"},
131137
},
132138
},
133-
func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
134-
owner, err := RequiredParam[string](args, "owner")
135-
if err != nil {
136-
return utils.NewToolResultError(err.Error()), nil, nil
137-
}
138-
repo, err := RequiredParam[string](args, "repo")
139-
if err != nil {
140-
return utils.NewToolResultError(err.Error()), nil, nil
141-
}
142-
ref, err := OptionalParam[string](args, "ref")
143-
if err != nil {
144-
return utils.NewToolResultError(err.Error()), nil, nil
145-
}
146-
state, err := OptionalParam[string](args, "state")
147-
if err != nil {
148-
return utils.NewToolResultError(err.Error()), nil, nil
149-
}
150-
severity, err := OptionalParam[string](args, "severity")
151-
if err != nil {
152-
return utils.NewToolResultError(err.Error()), nil, nil
153-
}
154-
toolName, err := OptionalParam[string](args, "tool_name")
155-
if err != nil {
156-
return utils.NewToolResultError(err.Error()), nil, nil
157-
}
139+
func(deps ToolDependencies) mcp.ToolHandlerFor[map[string]any, any] {
140+
return func(ctx context.Context, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
141+
owner, err := RequiredParam[string](args, "owner")
142+
if err != nil {
143+
return utils.NewToolResultError(err.Error()), nil, nil
144+
}
145+
repo, err := RequiredParam[string](args, "repo")
146+
if err != nil {
147+
return utils.NewToolResultError(err.Error()), nil, nil
148+
}
149+
ref, err := OptionalParam[string](args, "ref")
150+
if err != nil {
151+
return utils.NewToolResultError(err.Error()), nil, nil
152+
}
153+
state, err := OptionalParam[string](args, "state")
154+
if err != nil {
155+
return utils.NewToolResultError(err.Error()), nil, nil
156+
}
157+
severity, err := OptionalParam[string](args, "severity")
158+
if err != nil {
159+
return utils.NewToolResultError(err.Error()), nil, nil
160+
}
161+
toolName, err := OptionalParam[string](args, "tool_name")
162+
if err != nil {
163+
return utils.NewToolResultError(err.Error()), nil, nil
164+
}
158165

159-
client, err := getClient(ctx)
160-
if err != nil {
161-
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
162-
}
163-
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
164-
if err != nil {
165-
return ghErrors.NewGitHubAPIErrorResponse(ctx,
166-
"failed to list alerts",
167-
resp,
168-
err,
169-
), nil, nil
170-
}
171-
defer func() { _ = resp.Body.Close() }()
166+
client, err := deps.GetClient(ctx)
167+
if err != nil {
168+
return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil
169+
}
170+
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity, ToolName: toolName})
171+
if err != nil {
172+
return ghErrors.NewGitHubAPIErrorResponse(ctx,
173+
"failed to list alerts",
174+
resp,
175+
err,
176+
), nil, nil
177+
}
178+
defer func() { _ = resp.Body.Close() }()
179+
180+
if resp.StatusCode != http.StatusOK {
181+
body, err := io.ReadAll(resp.Body)
182+
if err != nil {
183+
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
184+
}
185+
return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil
186+
}
172187

173-
if resp.StatusCode != http.StatusOK {
174-
body, err := io.ReadAll(resp.Body)
188+
r, err := json.Marshal(alerts)
175189
if err != nil {
176-
return utils.NewToolResultErrorFromErr("failed to read response body", err), nil, nil
190+
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
177191
}
178-
return utils.NewToolResultError(fmt.Sprintf("failed to list alerts: %s", string(body))), nil, nil
179-
}
180192

181-
r, err := json.Marshal(alerts)
182-
if err != nil {
183-
return utils.NewToolResultErrorFromErr("failed to marshal alerts", err), nil, nil
193+
return utils.NewToolResultText(string(r)), nil, nil
184194
}
185-
186-
return utils.NewToolResultText(string(r)), nil, nil
187-
}
195+
},
196+
)
188197
}

pkg/github/code_scanning_test.go

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@ import (
1717

1818
func Test_GetCodeScanningAlert(t *testing.T) {
1919
// Verify tool definition once
20-
mockClient := github.NewClient(nil)
21-
tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper)
22-
require.NoError(t, toolsnaps.Test(tool.Name, tool))
20+
toolDef := GetCodeScanningAlert(translations.NullTranslationHelper)
21+
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))
2322

24-
assert.Equal(t, "get_code_scanning_alert", tool.Name)
25-
assert.NotEmpty(t, tool.Description)
23+
assert.Equal(t, "get_code_scanning_alert", toolDef.Tool.Name)
24+
assert.NotEmpty(t, toolDef.Tool.Description)
2625

2726
// InputSchema is of type any, need to cast to *jsonschema.Schema
28-
schema, ok := tool.InputSchema.(*jsonschema.Schema)
27+
schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
2928
require.True(t, ok, "InputSchema should be *jsonschema.Schema")
3029
assert.Contains(t, schema.Properties, "owner")
3130
assert.Contains(t, schema.Properties, "repo")
@@ -89,13 +88,16 @@ func Test_GetCodeScanningAlert(t *testing.T) {
8988
t.Run(tc.name, func(t *testing.T) {
9089
// Setup client with mock
9190
client := github.NewClient(tc.mockedClient)
92-
_, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper)
91+
deps := ToolDependencies{
92+
GetClient: stubGetClientFn(client),
93+
}
94+
handler := toolDef.Handler(deps)
9395

9496
// Create call request
9597
request := createMCPRequest(tc.requestArgs)
9698

9799
// Call handler with new signature
98-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
100+
result, err := handler(context.Background(), &request)
99101

100102
// Verify results
101103
if tc.expectError {
@@ -127,15 +129,14 @@ func Test_GetCodeScanningAlert(t *testing.T) {
127129

128130
func Test_ListCodeScanningAlerts(t *testing.T) {
129131
// Verify tool definition once
130-
mockClient := github.NewClient(nil)
131-
tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper)
132-
require.NoError(t, toolsnaps.Test(tool.Name, tool))
132+
toolDef := ListCodeScanningAlerts(translations.NullTranslationHelper)
133+
require.NoError(t, toolsnaps.Test(toolDef.Tool.Name, toolDef.Tool))
133134

134-
assert.Equal(t, "list_code_scanning_alerts", tool.Name)
135-
assert.NotEmpty(t, tool.Description)
135+
assert.Equal(t, "list_code_scanning_alerts", toolDef.Tool.Name)
136+
assert.NotEmpty(t, toolDef.Tool.Description)
136137

137138
// InputSchema is of type any, need to cast to *jsonschema.Schema
138-
schema, ok := tool.InputSchema.(*jsonschema.Schema)
139+
schema, ok := toolDef.Tool.InputSchema.(*jsonschema.Schema)
139140
require.True(t, ok, "InputSchema should be *jsonschema.Schema")
140141
assert.Contains(t, schema.Properties, "owner")
141142
assert.Contains(t, schema.Properties, "repo")
@@ -219,13 +220,16 @@ func Test_ListCodeScanningAlerts(t *testing.T) {
219220
t.Run(tc.name, func(t *testing.T) {
220221
// Setup client with mock
221222
client := github.NewClient(tc.mockedClient)
222-
_, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper)
223+
deps := ToolDependencies{
224+
GetClient: stubGetClientFn(client),
225+
}
226+
handler := toolDef.Handler(deps)
223227

224228
// Create call request
225229
request := createMCPRequest(tc.requestArgs)
226230

227231
// Call handler with new signature
228-
result, _, err := handler(context.Background(), &request, tc.requestArgs)
232+
result, err := handler(context.Background(), &request)
229233

230234
// Verify results
231235
if tc.expectError {

0 commit comments

Comments
 (0)