Skip to content

Commit eb04f54

Browse files
mcp-tls-support
Summary: - Basic TLS support for MCP server and client. - Added robot test `Concurrent psql and Reverse Proxy MCP HTTPS Server Query Tool`.
1 parent c37839a commit eb04f54

File tree

6 files changed

+112
-13
lines changed

6 files changed

+112
-13
lines changed

mcp_client/cmd/exec.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ var execCmd = &cobra.Command{
4242
Long: `simple mcp client example
4343
`,
4444
Run: func(cmd *cobra.Command, args []string) {
45+
clientCfgMap := make(map[string]any)
46+
jsonErr := json.Unmarshal([]byte(clientCfgJSON), &clientCfgMap)
47+
if jsonErr != nil {
48+
panic(fmt.Sprintf("error unmarshaling client cfg json: %v", jsonErr))
49+
}
4550
client, setupErr := mcp_server.NewMCPClient(
4651
clientType,
4752
url,
53+
clientCfgMap,
4854
nil,
4955
)
5056
if setupErr != nil {
@@ -64,9 +70,9 @@ var execCmd = &cobra.Command{
6470
outputString = string(output)
6571
default:
6672
var args map[string]any
67-
jsonErr := json.Unmarshal([]byte(actionArgs), &args)
68-
if jsonErr != nil {
69-
panic(fmt.Sprintf("error unmarshaling action args: %v", jsonErr))
73+
jsonCfgErr := json.Unmarshal([]byte(actionArgs), &args)
74+
if jsonCfgErr != nil {
75+
panic(fmt.Sprintf("error unmarshaling action args: %v", jsonCfgErr))
7076
}
7177
rv, rvErr := client.CallToolText(actionName, args)
7278
if rvErr != nil {

mcp_client/cmd/root.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import (
2424
)
2525

2626
var (
27-
clientType = "http"
28-
url = "127.0.0.1:9191"
27+
clientType = "http"
28+
url = "127.0.0.1:9191"
29+
clientCfgJSON = "{}"
2930
)
3031

3132
//nolint:revive,gochecknoglobals // explicit preferred
@@ -68,6 +69,7 @@ func init() {
6869

6970
rootCmd.PersistentFlags().StringVar(&clientType, "client-type", mcp_server.MCPClientTypeSTDIO, "MCP client type (http or stdio for now)")
7071
rootCmd.PersistentFlags().StringVar(&url, "url", "http://127.0.0.1:9876", "MCP server URL. Relevant for http and sse client types.")
72+
rootCmd.PersistentFlags().StringVar(&clientCfgJSON, "client-cfg", "{}", "MCP client configuration as JSON string")
7173

7274
// Here you will define your flags and configuration settings.
7375
// Cobra supports persistent flags, which, if defined here,

pkg/mcp_server/client.go

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@ package mcp_server //nolint:revive // fine for now
44

55
import (
66
"context"
7+
"crypto/tls"
8+
"crypto/x509"
79
"fmt"
810
"net/http"
11+
"os"
912

1013
"github.com/modelcontextprotocol/go-sdk/mcp"
1114
"github.com/sirupsen/logrus"
@@ -21,33 +24,68 @@ type MCPClient interface {
2124
CallToolText(toolName string, args map[string]any) (string, error)
2225
}
2326

24-
func NewMCPClient(clientType string, baseURL string, logger *logrus.Logger) (MCPClient, error) {
27+
func NewMCPClient(clientType string, baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
2528
switch clientType {
2629
case MCPClientTypeHTTP:
27-
return newHTTPMCPClient(baseURL, logger)
30+
return newHTTPMCPClient(baseURL, clientCfgMap, logger)
2831
case MCPClientTypeSTDIO:
2932
return newStdioMCPClient(logger)
3033
default:
3134
return nil, fmt.Errorf("unknown client type: %s", clientType)
3235
}
3336
}
3437

35-
func newHTTPMCPClient(baseURL string, logger *logrus.Logger) (MCPClient, error) {
38+
func getHTTPClient(clientCfgMap map[string]any) (*http.Client, error) {
39+
if clientCfgMap != nil && clientCfgMap["ca_file"] != nil {
40+
caFile, isString := clientCfgMap["ca_file"].(string)
41+
if !isString {
42+
return nil, fmt.Errorf("ca_file must be a string")
43+
}
44+
caCert, err := os.ReadFile(caFile)
45+
if err != nil {
46+
return nil, fmt.Errorf("failed to read CA file: %w", err)
47+
}
48+
caCertPool := x509.NewCertPool()
49+
caCertPool.AppendCertsFromPEM(caCert)
50+
51+
// Create a TLS configuration
52+
tlsConfig := &tls.Config{
53+
RootCAs: caCertPool, // Trust custom CA certificates
54+
MinVersion: tls.VersionTLS12, // Enforce minimum TLS version
55+
// InsecureSkipVerify: false, // Set to true to skip server certificate verification (NOT recommended for production)
56+
}
57+
58+
// Create a custom HTTP transport
59+
tr := &http.Transport{
60+
TLSClientConfig: tlsConfig,
61+
}
62+
return &http.Client{Transport: tr}, nil
63+
}
64+
return http.DefaultClient, nil
65+
}
66+
67+
func newHTTPMCPClient(baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
3668
if logger == nil {
3769
logger = logrus.New()
3870
logger.SetLevel(logrus.InfoLevel)
3971
}
72+
httpClient, httpClientErr := getHTTPClient(clientCfgMap)
73+
if httpClientErr != nil {
74+
return nil, fmt.Errorf("error creating HTTP client: %w", httpClientErr)
75+
}
4076
return &httpMCPClient{
4177
baseURL: baseURL,
42-
httpClient: http.DefaultClient,
78+
httpClient: httpClient,
4379
logger: logger,
80+
clientCfg: clientCfgMap,
4481
}, nil
4582
}
4683

4784
type httpMCPClient struct {
4885
baseURL string
4986
httpClient *http.Client
5087
logger *logrus.Logger
88+
clientCfg map[string]any
5189
}
5290

5391
func (c *httpMCPClient) connect() (*mcp.ClientSession, error) {

pkg/mcp_server/config.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ type ServerConfig struct {
5959
// Version is the server version advertised to clients.
6060
Version string `json:"version" yaml:"version"`
6161

62-
ConnectionCfg map[string]any `json:"connection_cfg,omitempty" yaml:"connection_cfg,omitempty"`
62+
TLSCertFile string `json:"tls_cert_file,omitempty" yaml:"tls_cert_file,omitempty"`
63+
TLSKeyFile string `json:"tls_key_file,omitempty" yaml:"tls_key_file,omitempty"`
64+
65+
TransportCfg map[string]any `json:"transport_cfg,omitempty" yaml:"transport_cfg,omitempty"`
6366

6467
// Description is a human-readable description of the server.
6568
Description string `json:"description" yaml:"description"`

pkg/mcp_server/server.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ type simpleMCPServer struct {
4343
servers []io.Closer // Track all running servers for cleanup
4444
}
4545

46-
func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, address string) error {
46+
func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, config *Config) error {
4747
// Create the streamable HTTP handler.
48+
address := config.GetServerAddress()
4849
handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server {
4950
return server
5051
}, nil)
@@ -56,6 +57,15 @@ func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, address string) erro
5657

5758
// Start the HTTP server with logging handler.
5859
//nolint:gosec // TODO: find viable alternative to http.ListenAndServe
60+
if config.Server.TLSCertFile != "" && config.Server.TLSKeyFile != "" {
61+
s.logger.Infof("Starting HTTPS server on %s", address)
62+
if err := http.ListenAndServeTLS(address, config.Server.TLSCertFile, config.Server.TLSKeyFile, handlerWithLogging); err != nil {
63+
s.logger.Errorf("HTTPS Server failed: %v", err)
64+
return err
65+
}
66+
return nil
67+
}
68+
//nolint:gosec // TODO: find viable alternative to http.ListenAndServe
5969
if err := http.ListenAndServe(address, handlerWithLogging); err != nil {
6070
s.logger.Errorf("Server failed: %v", err)
6171
return err
@@ -478,9 +488,9 @@ func (s *simpleMCPServer) Start(ctx context.Context) error {
478488
func (s *simpleMCPServer) run(ctx context.Context) error {
479489
switch s.config.GetServerTransport() {
480490
case serverTransportHTTP:
481-
return s.runHTTPServer(s.server, s.config.GetServerAddress())
491+
return s.runHTTPServer(s.server, s.config)
482492
case serverTransportSSE:
483-
return fmt.Errorf("SSE transport not yet implemented")
493+
return fmt.Errorf("SSE transport obsoleted; use streamable HTTP transport instead")
484494
case serverTransportStdIO:
485495
// Default to stdio transport
486496
return s.server.Run(ctx, &mcp.StdioTransport{})

test/robot/functional/mcp.robot

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ Start MCP Servers
3939
... \-\-tls.allowInsecure
4040
... \-\-pgsrv.port
4141
... 5445
42+
Start Process ${STACKQL_EXE}
43+
... srv
44+
... \-\-mcp.server.type\=reverse_proxy
45+
... \-\-mcp.config
46+
... {"server": {"tls_cert_file": "test/server/mtls/credentials/pg_server_cert.pem", "tls_key_file": "test/server/mtls/credentials/pg_server_key.pem", "transport": "http", "address": "127.0.0.1:9004"}, "backend": {"dsn": "postgres:\/\/stackql:stackql@127.0.0.1:5446?default_query_exec_mode\=simple_protocol"} }
47+
... \-\-registry
48+
... ${REGISTRY_NO_VERIFY_CFG_JSON_STR}
49+
... \-\-auth
50+
... ${AUTH_CFG_STR}
51+
... \-\-tls.allowInsecure
52+
... \-\-pgsrv.port
53+
... 5446
4254
Sleep 5s
4355

4456
*** Settings ***
@@ -198,3 +210,31 @@ Concurrent psql and Reverse Proxy MCP HTTP Server Query Tool
198210
... stderr=${CURDIR}${/}tmp${/}Concurrent-psql-and-Reverse-Proxy-MCP-HTTP-Server-Query-Tool-psql-stderr.txt
199211
Should Contain ${psql_client_result.stdout} cloudkms.googleapis.com
200212
Should Be Equal As Integers ${psql_client_result.rc} 0
213+
214+
Concurrent psql and Reverse Proxy MCP HTTPS Server Query Tool
215+
Pass Execution If "%{IS_SKIP_MCP_TEST=false}" == "true" Some platforms do not have the MCP client available
216+
Sleep 5s
217+
${mcp_client_result}= Run Process ${STACKQL_MCP_CLIENT_EXE}
218+
... exec
219+
... \-\-client\-type\=http
220+
... \-\-url\=https://127.0.0.1:9004
221+
... \-\-client\-cfg { "ca_file": "test/server/mtls/credentials/pg_server_cert.pem" }
222+
... \-\-exec.action query_v2
223+
... \-\-exec.args {"sql": "SELECT assetType, count(*) as asset_count FROM google.cloudasset.assets WHERE parentType \= 'projects' and parent \= 'testing-project' GROUP BY assetType order by count(*) desc, assetType desc;"}
224+
... stdout=${CURDIR}${/}tmp${/}Concurrent-psql-and-Reverse-Proxy-MCP-HTTPS-Server-Query-Tool.txt
225+
... stderr=${CURDIR}${/}tmp${/}Concurrent-psql-and-Reverse-Proxy-MCP-HTTPS-Server-Query-Tool-stderr.txt
226+
Should Contain ${mcp_client_result.stdout} cloudkms.googleapis.com
227+
Should Be Equal As Integers ${mcp_client_result.rc} 0
228+
${posixInput} = Catenate
229+
... "${PSQL_EXE}" -d postgres://stackql:stackql@127.0.0.1:5445 -c
230+
... "SELECT assetType, count(*) as asset_count FROM google.cloudasset.assets WHERE parentType = 'projects' and parent = 'testing-project' GROUP BY assetType order by count(*) desc, assetType desc;"
231+
${windowsInput} = Catenate
232+
... & ${posixInput}
233+
${input} = Set Variable If "${IS_WINDOWS}" == "1" ${windowsInput} ${posixInput}
234+
${shellExe} = Set Variable If "${IS_WINDOWS}" == "1" powershell sh
235+
${psql_client_result}= Run Process
236+
... ${shellExe} \-c ${input}
237+
... stdout=${CURDIR}${/}tmp${/}Concurrent-psql-and-Reverse-Proxy-MCP-HTTPS-Server-Query-Tool-psql.txt
238+
... stderr=${CURDIR}${/}tmp${/}Concurrent-psql-and-Reverse-Proxy-MCP-HTTPS-Server-Query-Tool-psql-stderr.txt
239+
Should Contain ${psql_client_result.stdout} cloudkms.googleapis.com
240+
Should Be Equal As Integers ${psql_client_result.rc} 0

0 commit comments

Comments
 (0)