Skip to content

Commit 68cd59e

Browse files
- Updated automated testing.
1 parent eb04f54 commit 68cd59e

File tree

2 files changed

+197
-12
lines changed

2 files changed

+197
-12
lines changed

pkg/mcp_server/client.go

Lines changed: 194 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ package mcp_server //nolint:revive // fine for now
33
// create an http client that can talk to the mcp server
44

55
import (
6+
"bytes"
67
"context"
8+
"crypto/sha256"
79
"crypto/tls"
810
"crypto/x509"
11+
"encoding/pem"
912
"fmt"
1013
"net/http"
14+
"net/url"
1115
"os"
1216

1317
"github.com/modelcontextprotocol/go-sdk/mcp"
@@ -35,30 +39,209 @@ func NewMCPClient(clientType string, baseURL string, clientCfgMap map[string]any
3539
}
3640
}
3741

38-
func getHTTPClient(clientCfgMap map[string]any) (*http.Client, error) {
42+
//nolint:nestif,gocognit,gocyclo,cyclop,funlen // complex but acceptable for now
43+
func getHTTPClient(logger *logrus.Logger, clientCfgMap map[string]any) (*http.Client, error) {
3944
if clientCfgMap != nil && clientCfgMap["ca_file"] != nil {
45+
logger.Infof("Configuring HTTP client with custom CA certificate")
4046
caFile, isString := clientCfgMap["ca_file"].(string)
4147
if !isString {
4248
return nil, fmt.Errorf("ca_file must be a string")
4349
}
44-
caCert, err := os.ReadFile(caFile)
50+
caBytes, err := os.ReadFile(caFile)
4551
if err != nil {
46-
return nil, fmt.Errorf("failed to read CA file: %w", err)
52+
return nil, fmt.Errorf("failed to read CA file '%s': %w", caFile, err)
4753
}
48-
caCertPool := x509.NewCertPool()
49-
caCertPool.AppendCertsFromPEM(caCert)
54+
logger.Infof("Read CA file '%s' (%d bytes)", caFile, len(caBytes))
5055

51-
// Create a TLS configuration
56+
// Start from system pool when possible
57+
var caCertPool *x509.CertPool
58+
if sysPool, sysErr := x509.SystemCertPool(); sysErr == nil && sysPool != nil {
59+
caCertPool = sysPool
60+
logger.Debug("Using system cert pool as base")
61+
} else {
62+
caCertPool = x509.NewCertPool()
63+
logger.Debug("System cert pool unavailable; using new pool")
64+
}
65+
66+
// Capture the first certificate (candidate leaf) for promote_leaf_to_ca.
67+
var firstCertRaw []byte
68+
{
69+
tmp := caBytes
70+
for {
71+
var blk *pem.Block
72+
blk, tmp = pem.Decode(tmp)
73+
if blk == nil {
74+
break
75+
}
76+
if blk.Type == "CERTIFICATE" {
77+
firstCertRaw = blk.Bytes
78+
break
79+
}
80+
}
81+
}
82+
83+
if ok := caCertPool.AppendCertsFromPEM(caBytes); !ok {
84+
// Fallback: manual decode to provide diagnostics
85+
logger.Warn("AppendCertsFromPEM returned false; attempting manual PEM decode for diagnostics")
86+
blockCount := 0
87+
validCerts := 0
88+
rest := caBytes
89+
for {
90+
var b *pem.Block
91+
b, rest = pem.Decode(rest)
92+
if b == nil {
93+
break
94+
}
95+
blockCount++
96+
if b.Type == "CERTIFICATE" {
97+
if _, perr := x509.ParseCertificate(b.Bytes); perr == nil {
98+
validCerts++
99+
} else {
100+
logger.Errorf("Failed to parse certificate PEM block %d: %v", blockCount, perr)
101+
}
102+
} else {
103+
logger.Debugf("Ignoring non-certificate PEM block type=%s", b.Type)
104+
}
105+
}
106+
return nil, fmt.Errorf("failed to append CA certificate '%s' into trust store: no valid CERTIFICATE PEM blocks found (blocks=%d, valid=%d)", caFile, blockCount, validCerts)
107+
} else {
108+
logger.Infof("Successfully appended custom CA(s) from '%s'", caFile)
109+
// Added: inspect for CA certificates
110+
rest := caBytes
111+
certBlockIdx := 0
112+
caCount := 0
113+
for {
114+
var b *pem.Block
115+
b, rest = pem.Decode(rest)
116+
if b == nil {
117+
break
118+
}
119+
if b.Type != "CERTIFICATE" {
120+
continue
121+
}
122+
certBlockIdx++
123+
parsed, perr := x509.ParseCertificate(b.Bytes)
124+
if perr != nil {
125+
logger.Debugf("Skipping unparsable certificate block %d: %v", certBlockIdx, perr)
126+
continue
127+
}
128+
if parsed.IsCA {
129+
caCount++
130+
}
131+
}
132+
if caCount == 0 {
133+
logger.Warnf("No CA certificates (IsCA=true) found in '%s'. If this file contains only the server leaf certificate it cannot establish standard trust. Supply the issuing CA (or chain) or enable 'promote_leaf_to_ca'.", caFile)
134+
} else {
135+
logger.Debugf("Detected %d CA certificate(s) in '%s'", caCount, caFile)
136+
}
137+
}
138+
139+
// Read optional flags
140+
promoteLeaf := false
141+
if v, ok := clientCfgMap["promote_leaf_to_ca"]; ok {
142+
b, okb := v.(bool)
143+
if !okb {
144+
return nil, fmt.Errorf("promote_leaf_to_ca must be a boolean")
145+
}
146+
promoteLeaf = b
147+
}
148+
149+
insecureSkipVerify := false
150+
if v, ok := clientCfgMap["insecure_skip_verify"]; ok {
151+
suppliedSkipVerify, isBool := v.(bool)
152+
if !isBool {
153+
return nil, fmt.Errorf("insecure_skip_verify must be a boolean")
154+
}
155+
insecureSkipVerify = suppliedSkipVerify
156+
}
157+
158+
var serverName string
159+
if v, ok := clientCfgMap["server_name"]; ok {
160+
if s, ok2 := v.(string); ok2 {
161+
serverName = s
162+
} else {
163+
return nil, fmt.Errorf("server_name must be a string")
164+
}
165+
}
166+
167+
// If server_name not supplied and base URL host differs from cert common name/SAN (common in IP usage),
168+
// user should supply server_name explicitly; we just log hint.
169+
if serverName == "" {
170+
if rawURL, ok := clientCfgMap["base_url"].(string); ok {
171+
if parsed, perr := url.Parse(rawURL); perr == nil && parsed.Hostname() != "" {
172+
// SNI will default to this hostname; log for clarity.
173+
logger.Debugf("Using implicit SNI server name '%s'", parsed.Hostname())
174+
}
175+
}
176+
} else {
177+
logger.Infof("Using explicit TLS server_name (SNI): %s", serverName)
178+
}
179+
180+
//nolint:gosec // testing client only
52181
tlsConfig := &tls.Config{
53-
RootCAs: caCertPool, // Trust custom CA certificates
54-
MinVersion: tls.VersionTLS12, // Enforce minimum TLS version
55-
// InsecureSkipVerify: false, // Set to true to skip server certificate verification (NOT recommended for production)
182+
RootCAs: caCertPool,
183+
MinVersion: tls.VersionTLS12,
184+
InsecureSkipVerify: insecureSkipVerify, // may be overridden below if promoting leaf
185+
ServerName: serverName,
186+
}
187+
188+
// If no CA certs and user wants to promote the leaf, install custom verifier.
189+
if promoteLeaf {
190+
if firstCertRaw == nil {
191+
return nil, fmt.Errorf("promote_leaf_to_ca enabled but no certificate PEM blocks found in '%s'", caFile)
192+
}
193+
// Re-parse to log fingerprint
194+
if leafCert, perr := x509.ParseCertificate(firstCertRaw); perr == nil {
195+
fp := sha256.Sum256(leafCert.Raw)
196+
logger.Warnf("Promoting leaf certificate (CN=%s, SHA256=%X) to trust anchor (non-CA). NOT recommended for production.", leafCert.Subject.CommonName, fp[:8])
197+
} else {
198+
logger.Warnf("Promoting leaf certificate (parse error for fingerprint: %v)", perr)
199+
}
200+
tlsConfig.InsecureSkipVerify = true // we will verify manually
201+
expected := make([]byte, len(firstCertRaw))
202+
copy(expected, firstCertRaw)
203+
204+
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
205+
if len(rawCerts) == 0 {
206+
return fmt.Errorf("no server certificates presented")
207+
}
208+
if !bytes.Equal(rawCerts[0], expected) {
209+
return fmt.Errorf("server leaf certificate mismatch with promoted leaf")
210+
}
211+
// Optionally parse for additional sanity
212+
if cert, certParseErr := x509.ParseCertificate(rawCerts[0]); certParseErr == nil {
213+
if serverName != "" && serverName != cert.Subject.CommonName {
214+
// Do hostname verification if a serverName was forced.
215+
if verr := cert.VerifyHostname(serverName); verr != nil {
216+
return fmt.Errorf("hostname verification failed for promoted leaf: %w", verr)
217+
}
218+
}
219+
}
220+
return nil
221+
}
56222
}
57223

58-
// Create a custom HTTP transport
59224
tr := &http.Transport{
60225
TLSClientConfig: tlsConfig,
61226
}
227+
228+
// Optionally apply TLS config globally so libraries using http.DefaultClient inherit it.
229+
if v, ok := clientCfgMap["apply_tls_globally"]; ok {
230+
if b, okb := v.(bool); !okb {
231+
return nil, fmt.Errorf("apply_tls_globally must be a boolean")
232+
} else if b {
233+
if defTr, okd := http.DefaultTransport.(*http.Transport); okd {
234+
// Shallow clone to avoid races; copy keeps other fields (proxy, dialer, etc.)
235+
cloned := defTr.Clone()
236+
cloned.TLSClientConfig = tlsConfig
237+
http.DefaultTransport = cloned
238+
logger.Warn("Applied custom TLS config globally (http.DefaultTransport). This affects all outbound HTTP requests in this process.")
239+
} else {
240+
logger.Warn("apply_tls_globally requested but http.DefaultTransport is not *http.Transport; skipped")
241+
}
242+
}
243+
}
244+
62245
return &http.Client{Transport: tr}, nil
63246
}
64247
return http.DefaultClient, nil
@@ -69,7 +252,7 @@ func newHTTPMCPClient(baseURL string, clientCfgMap map[string]any, logger *logru
69252
logger = logrus.New()
70253
logger.SetLevel(logrus.InfoLevel)
71254
}
72-
httpClient, httpClientErr := getHTTPClient(clientCfgMap)
255+
httpClient, httpClientErr := getHTTPClient(logger, clientCfgMap)
73256
if httpClientErr != nil {
74257
return nil, fmt.Errorf("error creating HTTP client: %w", httpClientErr)
75258
}

test/robot/functional/mcp.robot

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ Start MCP Servers
5151
... \-\-tls.allowInsecure
5252
... \-\-pgsrv.port
5353
... 5446
54+
... stdout=${CURDIR}${/}tmp${/}Stackql-MCP-Server-HTTPS.txt
55+
... stderr=${CURDIR}${/}tmp${/}Stackql-MCP-Server-HTTPS-stderr.txt
5456
Sleep 5s
5557

5658
*** Settings ***
@@ -226,7 +228,7 @@ Concurrent psql and Reverse Proxy MCP HTTPS Server Query Tool
226228
Should Contain ${mcp_client_result.stdout} cloudkms.googleapis.com
227229
Should Be Equal As Integers ${mcp_client_result.rc} 0
228230
${posixInput} = Catenate
229-
... "${PSQL_EXE}" -d postgres://stackql:stackql@127.0.0.1:5445 -c
231+
... "${PSQL_EXE}" -d postgres://stackql:stackql@127.0.0.1:5446 -c
230232
... "SELECT assetType, count(*) as asset_count FROM google.cloudasset.assets WHERE parentType = 'projects' and parent = 'testing-project' GROUP BY assetType order by count(*) desc, assetType desc;"
231233
${windowsInput} = Catenate
232234
... & ${posixInput}

0 commit comments

Comments
 (0)