Skip to content

Commit 7c1e4b0

Browse files
mcp-tls-support (#581)
Summary: - Basic TLS support for MCP server and MCP testing client. - Added robot test `Concurrent psql and Reverse Proxy MCP HTTPS Server Query Tool`.
1 parent c37839a commit 7c1e4b0

File tree

6 files changed

+297
-13
lines changed

6 files changed

+297
-13
lines changed

mcp_client/cmd/exec.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,15 @@ var execCmd = &cobra.Command{
4242
Long: `simple mcp client example
4343
`,
4444
Run: func(cmd *cobra.Command, args []string) {
45+
clientCfgMap := make(map[string]any)
46+
jsonErr := json.Unmarshal([]byte(clientCfgJSON), &clientCfgMap)
47+
if jsonErr != nil {
48+
panic(fmt.Sprintf("error unmarshaling client cfg json: %v", jsonErr))
49+
}
4550
client, setupErr := mcp_server.NewMCPClient(
4651
clientType,
4752
url,
53+
clientCfgMap,
4854
nil,
4955
)
5056
if setupErr != nil {
@@ -64,9 +70,9 @@ var execCmd = &cobra.Command{
6470
outputString = string(output)
6571
default:
6672
var args map[string]any
67-
jsonErr := json.Unmarshal([]byte(actionArgs), &args)
68-
if jsonErr != nil {
69-
panic(fmt.Sprintf("error unmarshaling action args: %v", jsonErr))
73+
jsonCfgErr := json.Unmarshal([]byte(actionArgs), &args)
74+
if jsonCfgErr != nil {
75+
panic(fmt.Sprintf("error unmarshaling action args: %v", jsonCfgErr))
7076
}
7177
rv, rvErr := client.CallToolText(actionName, args)
7278
if rvErr != nil {

mcp_client/cmd/root.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ import (
2424
)
2525

2626
var (
27-
clientType = "http"
28-
url = "127.0.0.1:9191"
27+
clientType = "http"
28+
url = "127.0.0.1:9191"
29+
clientCfgJSON = "{}"
2930
)
3031

3132
//nolint:revive,gochecknoglobals // explicit preferred
@@ -68,6 +69,7 @@ func init() {
6869

6970
rootCmd.PersistentFlags().StringVar(&clientType, "client-type", mcp_server.MCPClientTypeSTDIO, "MCP client type (http or stdio for now)")
7071
rootCmd.PersistentFlags().StringVar(&url, "url", "http://127.0.0.1:9876", "MCP server URL. Relevant for http and sse client types.")
72+
rootCmd.PersistentFlags().StringVar(&clientCfgJSON, "client-cfg", "{}", "MCP client configuration as JSON string")
7173

7274
// Here you will define your flags and configuration settings.
7375
// Cobra supports persistent flags, which, if defined here,

pkg/mcp_server/client.go

Lines changed: 225 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@ package mcp_server //nolint:revive // fine for now
33
// create an http client that can talk to the mcp server
44

55
import (
6+
"bytes"
67
"context"
8+
"crypto/sha256"
9+
"crypto/tls"
10+
"crypto/x509"
11+
"encoding/pem"
712
"fmt"
813
"net/http"
14+
"net/url"
15+
"os"
916

1017
"github.com/modelcontextprotocol/go-sdk/mcp"
1118
"github.com/sirupsen/logrus"
@@ -21,33 +28,247 @@ type MCPClient interface {
2128
CallToolText(toolName string, args map[string]any) (string, error)
2229
}
2330

24-
func NewMCPClient(clientType string, baseURL string, logger *logrus.Logger) (MCPClient, error) {
31+
func NewMCPClient(clientType string, baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
2532
switch clientType {
2633
case MCPClientTypeHTTP:
27-
return newHTTPMCPClient(baseURL, logger)
34+
return newHTTPMCPClient(baseURL, clientCfgMap, logger)
2835
case MCPClientTypeSTDIO:
2936
return newStdioMCPClient(logger)
3037
default:
3138
return nil, fmt.Errorf("unknown client type: %s", clientType)
3239
}
3340
}
3441

35-
func newHTTPMCPClient(baseURL string, logger *logrus.Logger) (MCPClient, error) {
42+
//nolint:nestif,gocognit,gocyclo,cyclop,funlen // complex but acceptable for now
43+
func getHTTPClient(logger *logrus.Logger, clientCfgMap map[string]any) (*http.Client, error) {
44+
if clientCfgMap != nil && clientCfgMap["ca_file"] != nil {
45+
logger.Infof("Configuring HTTP client with custom CA certificate")
46+
caFile, isString := clientCfgMap["ca_file"].(string)
47+
if !isString {
48+
return nil, fmt.Errorf("ca_file must be a string")
49+
}
50+
caBytes, err := os.ReadFile(caFile)
51+
if err != nil {
52+
return nil, fmt.Errorf("failed to read CA file '%s': %w", caFile, err)
53+
}
54+
logger.Infof("Read CA file '%s' (%d bytes)", caFile, len(caBytes))
55+
56+
// Start from system pool when possible
57+
var caCertPool *x509.CertPool
58+
if sysPool, sysErr := x509.SystemCertPool(); sysErr == nil && sysPool != nil {
59+
caCertPool = sysPool
60+
logger.Debug("Using system cert pool as base")
61+
} else {
62+
caCertPool = x509.NewCertPool()
63+
logger.Debug("System cert pool unavailable; using new pool")
64+
}
65+
66+
// Capture the first certificate (candidate leaf) for promote_leaf_to_ca.
67+
var firstCertRaw []byte
68+
{
69+
tmp := caBytes
70+
for {
71+
var blk *pem.Block
72+
blk, tmp = pem.Decode(tmp)
73+
if blk == nil {
74+
break
75+
}
76+
if blk.Type == "CERTIFICATE" {
77+
firstCertRaw = blk.Bytes
78+
break
79+
}
80+
}
81+
}
82+
83+
if ok := caCertPool.AppendCertsFromPEM(caBytes); !ok {
84+
// Fallback: manual decode to provide diagnostics
85+
logger.Warn("AppendCertsFromPEM returned false; attempting manual PEM decode for diagnostics")
86+
blockCount := 0
87+
validCerts := 0
88+
rest := caBytes
89+
for {
90+
var b *pem.Block
91+
b, rest = pem.Decode(rest)
92+
if b == nil {
93+
break
94+
}
95+
blockCount++
96+
if b.Type == "CERTIFICATE" {
97+
if _, perr := x509.ParseCertificate(b.Bytes); perr == nil {
98+
validCerts++
99+
} else {
100+
logger.Errorf("Failed to parse certificate PEM block %d: %v", blockCount, perr)
101+
}
102+
} else {
103+
logger.Debugf("Ignoring non-certificate PEM block type=%s", b.Type)
104+
}
105+
}
106+
return nil, fmt.Errorf("failed to append CA certificate '%s' into trust store: no valid CERTIFICATE PEM blocks found (blocks=%d, valid=%d)", caFile, blockCount, validCerts)
107+
} else {
108+
logger.Infof("Successfully appended custom CA(s) from '%s'", caFile)
109+
// Added: inspect for CA certificates
110+
rest := caBytes
111+
certBlockIdx := 0
112+
caCount := 0
113+
for {
114+
var b *pem.Block
115+
b, rest = pem.Decode(rest)
116+
if b == nil {
117+
break
118+
}
119+
if b.Type != "CERTIFICATE" {
120+
continue
121+
}
122+
certBlockIdx++
123+
parsed, perr := x509.ParseCertificate(b.Bytes)
124+
if perr != nil {
125+
logger.Debugf("Skipping unparsable certificate block %d: %v", certBlockIdx, perr)
126+
continue
127+
}
128+
if parsed.IsCA {
129+
caCount++
130+
}
131+
}
132+
if caCount == 0 {
133+
logger.Warnf("No CA certificates (IsCA=true) found in '%s'. If this file contains only the server leaf certificate it cannot establish standard trust. Supply the issuing CA (or chain) or enable 'promote_leaf_to_ca'.", caFile)
134+
} else {
135+
logger.Debugf("Detected %d CA certificate(s) in '%s'", caCount, caFile)
136+
}
137+
}
138+
139+
// Read optional flags
140+
promoteLeaf := false
141+
if v, ok := clientCfgMap["promote_leaf_to_ca"]; ok {
142+
b, okb := v.(bool)
143+
if !okb {
144+
return nil, fmt.Errorf("promote_leaf_to_ca must be a boolean")
145+
}
146+
promoteLeaf = b
147+
}
148+
149+
insecureSkipVerify := false
150+
if v, ok := clientCfgMap["insecure_skip_verify"]; ok {
151+
suppliedSkipVerify, isBool := v.(bool)
152+
if !isBool {
153+
return nil, fmt.Errorf("insecure_skip_verify must be a boolean")
154+
}
155+
insecureSkipVerify = suppliedSkipVerify
156+
}
157+
158+
var serverName string
159+
if v, ok := clientCfgMap["server_name"]; ok {
160+
if s, ok2 := v.(string); ok2 {
161+
serverName = s
162+
} else {
163+
return nil, fmt.Errorf("server_name must be a string")
164+
}
165+
}
166+
167+
// If server_name not supplied and base URL host differs from cert common name/SAN (common in IP usage),
168+
// user should supply server_name explicitly; we just log hint.
169+
if serverName == "" {
170+
if rawURL, ok := clientCfgMap["base_url"].(string); ok {
171+
if parsed, perr := url.Parse(rawURL); perr == nil && parsed.Hostname() != "" {
172+
// SNI will default to this hostname; log for clarity.
173+
logger.Debugf("Using implicit SNI server name '%s'", parsed.Hostname())
174+
}
175+
}
176+
} else {
177+
logger.Infof("Using explicit TLS server_name (SNI): %s", serverName)
178+
}
179+
180+
//nolint:gosec // testing client only
181+
tlsConfig := &tls.Config{
182+
RootCAs: caCertPool,
183+
MinVersion: tls.VersionTLS12,
184+
InsecureSkipVerify: insecureSkipVerify, // may be overridden below if promoting leaf
185+
ServerName: serverName,
186+
}
187+
188+
// If no CA certs and user wants to promote the leaf, install custom verifier.
189+
if promoteLeaf {
190+
if firstCertRaw == nil {
191+
return nil, fmt.Errorf("promote_leaf_to_ca enabled but no certificate PEM blocks found in '%s'", caFile)
192+
}
193+
// Re-parse to log fingerprint
194+
if leafCert, perr := x509.ParseCertificate(firstCertRaw); perr == nil {
195+
fp := sha256.Sum256(leafCert.Raw)
196+
logger.Warnf("Promoting leaf certificate (CN=%s, SHA256=%X) to trust anchor (non-CA). NOT recommended for production.", leafCert.Subject.CommonName, fp[:8])
197+
} else {
198+
logger.Warnf("Promoting leaf certificate (parse error for fingerprint: %v)", perr)
199+
}
200+
tlsConfig.InsecureSkipVerify = true // we will verify manually
201+
expected := make([]byte, len(firstCertRaw))
202+
copy(expected, firstCertRaw)
203+
204+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
205+
if len(rawCerts) == 0 {
206+
return fmt.Errorf("no server certificates presented")
207+
}
208+
if !bytes.Equal(rawCerts[0], expected) {
209+
return fmt.Errorf("server leaf certificate mismatch with promoted leaf")
210+
}
211+
// Optionally parse for additional sanity
212+
if cert, certParseErr := x509.ParseCertificate(rawCerts[0]); certParseErr == nil {
213+
if serverName != "" && serverName != cert.Subject.CommonName {
214+
// Do hostname verification if a serverName was forced.
215+
if verr := cert.VerifyHostname(serverName); verr != nil {
216+
return fmt.Errorf("hostname verification failed for promoted leaf: %w", verr)
217+
}
218+
}
219+
}
220+
return nil
221+
}
222+
}
223+
224+
tr := &http.Transport{
225+
TLSClientConfig: tlsConfig,
226+
}
227+
228+
// Optionally apply TLS config globally so libraries using http.DefaultClient inherit it.
229+
if v, ok := clientCfgMap["apply_tls_globally"]; ok {
230+
if b, okb := v.(bool); !okb {
231+
return nil, fmt.Errorf("apply_tls_globally must be a boolean")
232+
} else if b {
233+
if defTr, okd := http.DefaultTransport.(*http.Transport); okd {
234+
// Shallow clone to avoid races; copy keeps other fields (proxy, dialer, etc.)
235+
cloned := defTr.Clone()
236+
cloned.TLSClientConfig = tlsConfig
237+
http.DefaultTransport = cloned
238+
logger.Warn("Applied custom TLS config globally (http.DefaultTransport). This affects all outbound HTTP requests in this process.")
239+
} else {
240+
logger.Warn("apply_tls_globally requested but http.DefaultTransport is not *http.Transport; skipped")
241+
}
242+
}
243+
}
244+
245+
return &http.Client{Transport: tr}, nil
246+
}
247+
return http.DefaultClient, nil
248+
}
249+
250+
func newHTTPMCPClient(baseURL string, clientCfgMap map[string]any, logger *logrus.Logger) (MCPClient, error) {
36251
if logger == nil {
37252
logger = logrus.New()
38253
logger.SetLevel(logrus.InfoLevel)
39254
}
255+
httpClient, httpClientErr := getHTTPClient(logger, clientCfgMap)
256+
if httpClientErr != nil {
257+
return nil, fmt.Errorf("error creating HTTP client: %w", httpClientErr)
258+
}
40259
return &httpMCPClient{
41260
baseURL: baseURL,
42-
httpClient: http.DefaultClient,
261+
httpClient: httpClient,
43262
logger: logger,
263+
clientCfg: clientCfgMap,
44264
}, nil
45265
}
46266

47267
type httpMCPClient struct {
48268
baseURL string
49269
httpClient *http.Client
50270
logger *logrus.Logger
271+
clientCfg map[string]any
51272
}
52273

53274
func (c *httpMCPClient) connect() (*mcp.ClientSession, error) {

pkg/mcp_server/config.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ type ServerConfig struct {
5959
// Version is the server version advertised to clients.
6060
Version string `json:"version" yaml:"version"`
6161

62-
ConnectionCfg map[string]any `json:"connection_cfg,omitempty" yaml:"connection_cfg,omitempty"`
62+
TLSCertFile string `json:"tls_cert_file,omitempty" yaml:"tls_cert_file,omitempty"`
63+
TLSKeyFile string `json:"tls_key_file,omitempty" yaml:"tls_key_file,omitempty"`
64+
65+
TransportCfg map[string]any `json:"transport_cfg,omitempty" yaml:"transport_cfg,omitempty"`
6366

6467
// Description is a human-readable description of the server.
6568
Description string `json:"description" yaml:"description"`

pkg/mcp_server/server.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ type simpleMCPServer struct {
4343
servers []io.Closer // Track all running servers for cleanup
4444
}
4545

46-
func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, address string) error {
46+
func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, config *Config) error {
4747
// Create the streamable HTTP handler.
48+
address := config.GetServerAddress()
4849
handler := mcp.NewStreamableHTTPHandler(func(req *http.Request) *mcp.Server {
4950
return server
5051
}, nil)
@@ -56,6 +57,15 @@ func (s *simpleMCPServer) runHTTPServer(server *mcp.Server, address string) erro
5657

5758
// Start the HTTP server with logging handler.
5859
//nolint:gosec // TODO: find viable alternative to http.ListenAndServe
60+
if config.Server.TLSCertFile != "" && config.Server.TLSKeyFile != "" {
61+
s.logger.Infof("Starting HTTPS server on %s", address)
62+
if err := http.ListenAndServeTLS(address, config.Server.TLSCertFile, config.Server.TLSKeyFile, handlerWithLogging); err != nil {
63+
s.logger.Errorf("HTTPS Server failed: %v", err)
64+
return err
65+
}
66+
return nil
67+
}
68+
//nolint:gosec // TODO: find viable alternative to http.ListenAndServe
5969
if err := http.ListenAndServe(address, handlerWithLogging); err != nil {
6070
s.logger.Errorf("Server failed: %v", err)
6171
return err
@@ -478,9 +488,9 @@ func (s *simpleMCPServer) Start(ctx context.Context) error {
478488
func (s *simpleMCPServer) run(ctx context.Context) error {
479489
switch s.config.GetServerTransport() {
480490
case serverTransportHTTP:
481-
return s.runHTTPServer(s.server, s.config.GetServerAddress())
491+
return s.runHTTPServer(s.server, s.config)
482492
case serverTransportSSE:
483-
return fmt.Errorf("SSE transport not yet implemented")
493+
return fmt.Errorf("SSE transport obsoleted; use streamable HTTP transport instead")
484494
case serverTransportStdIO:
485495
// Default to stdio transport
486496
return s.server.Run(ctx, &mcp.StdioTransport{})

0 commit comments

Comments
 (0)