From ec898fc358b22d6a37be49934bd4fa836fd710e0 Mon Sep 17 00:00:00 2001 From: Joris Scharp Date: Mon, 11 May 2026 18:51:30 +0200 Subject: [PATCH 1/2] refactor: centralise outbound URL validation in StrictHTTPClient Move core.ParsePublicURL from per-caller into StrictHTTPClient.Do so every outbound HTTP request gets HTTPS-only + no-IP + no-RFC2606-reserved-host validation in strict mode. Add CheckRedirect to re-validate every redirect target (10-hop cap matches net/http's default). Removes the duplicated validation in auth/openid4vci.Client, auth/client/iam.HTTPClient.ClientMetadata, and the scheme-only check in auth/services/oauth/relying_party.RequestRFC003AccessToken. All production callers route through *StrictHTTPClient. Behavior change: strict mode now rejects IP literals and RFC 2606 reserved hostnames on every outbound HTTP request, not just OpenID4VCI. Breaking API: oauth.NewRelyingParty no longer accepts a strictMode bool parameter; strict-mode behaviour is now controlled exclusively by the http/client.StrictMode flag set by the HTTP engine at startup. Closes #4244 --- auth/auth.go | 4 +- auth/client/iam/client.go | 6 - auth/openid4vci/client.go | 42 +----- auth/openid4vci/client_test.go | 35 ++--- auth/services/oauth/relying_party.go | 11 +- auth/services/oauth/relying_party_test.go | 37 ----- docs/pages/deployment/configuration.rst | 4 +- http/client/client.go | 35 +++-- http/client/client_test.go | 172 +++++++++++++++++++++- vcr/vcr_test.go | 4 +- 10 files changed, 225 insertions(+), 125 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 52528e4f0e..b3fd95d1cb 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -181,13 +181,13 @@ func (auth *Auth) Configure(config core.ServerConfig) error { // auth.http.config got deprecated in favor of httpclient.timeout auth.httpClientTimeout = config.HTTPClient.Timeout } - auth.openID4VCIClient = openid4vci.NewClient(httpclient.NewWithCache(auth.httpClientTimeout), auth.strictMode) + auth.openID4VCIClient = openid4vci.NewClient(httpclient.NewWithCache(auth.httpClientTimeout)) // V1 API related stuff accessTokenLifeSpan := time.Duration(auth.config.AccessTokenLifeSpan) * time.Second auth.authzServer = oauth.NewAuthorizationServer(auth.vdrInstance.Resolver(), auth.vcr, auth.vcr.Verifier(), auth.serviceResolver, auth.keyStore, auth.contractNotary, auth.jsonldManager, accessTokenLifeSpan) auth.relyingParty = oauth.NewRelyingParty(auth.vdrInstance.Resolver(), auth.serviceResolver, - auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, config.Strictmode, auth.pkiProvider) + auth.keyStore, auth.vcr.Wallet(), auth.httpClientTimeout, auth.tlsConfig, auth.pkiProvider) if err := auth.authzServer.Configure(auth.config.ClockSkew, config.Strictmode); err != nil { return err diff --git a/auth/client/iam/client.go b/auth/client/iam/client.go index 3112ede9bf..a9d441f862 100644 --- a/auth/client/iam/client.go +++ b/auth/client/iam/client.go @@ -85,12 +85,6 @@ func (hb HTTPClient) OAuthAuthorizationServerMetadata(ctx context.Context, oauth // ClientMetadata retrieves the client metadata from the client metadata endpoint given in the authorization request. // We use the AuthorizationServerMetadata struct since it overlaps greatly with the client metadata. func (hb HTTPClient) ClientMetadata(ctx context.Context, endpoint string) (*oauth.OAuthClientMetadata, error) { - _, err := core.ParsePublicURL(endpoint, hb.strictMode) - if err != nil { - return nil, err - } - - // create a GET request request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err diff --git a/auth/openid4vci/client.go b/auth/openid4vci/client.go index 23f34b70e0..ac74a8d08d 100644 --- a/auth/openid4vci/client.go +++ b/auth/openid4vci/client.go @@ -73,44 +73,18 @@ type Client interface { } // NewClient returns a Client backed by the provided HTTP request doer. -// In production callers should pass *httpclient.StrictHTTPClient so the -// shared transport policies apply (HTTPS-in-strict, body size limit, -// User-Agent header). -// -// When strictMode is true, target URLs are additionally validated via -// core.ParsePublicURL: HTTPS scheme, no IP hosts, no reserved hostnames. -func NewClient(httpClient core.HTTPRequestDoer, strictMode bool) Client { - return &client{httpClient: httpClient, strictMode: strictMode} +// Production callers should pass *httpclient.StrictHTTPClient so the shared +// transport policies apply: HTTPS-only in strict mode, no IP/reserved hosts +// in strict mode (via core.ParsePublicURL), body size limit, User-Agent. +func NewClient(httpClient core.HTTPRequestDoer) Client { + return &client{httpClient: httpClient} } type client struct { httpClient core.HTTPRequestDoer - strictMode bool -} - -// validateURL guards against SSRF by rejecting target URLs that fail -// core.ParsePublicURL (in strict mode: HTTPS only, no IP/reserved hosts). -// Called at the entry of every method that makes outbound HTTP. -// -// TODO: this validation belongs on httpclient.StrictHTTPClient so every -// outbound HTTP call (not just OpenID4VCI) gets the IP/reserved-host check, -// not only the HTTPS scheme check that StrictHTTPClient.Do enforces today. -// Placed here for now to preserve parity with master, where the equivalent -// caller (auth/client/iam.HTTPClient) validated via oauth.IssuerIdToWellKnown -// → core.ParsePublicURL before issuing the request, and to address a CodeQL -// SSRF finding on this PR. Tracked as a follow-up to consolidate the check -// in the shared HTTP client. -func (c *client) validateURL(name, target string) error { - if _, err := core.ParsePublicURL(target, c.strictMode); err != nil { - return fmt.Errorf("openid4vci: invalid %s URL: %w", name, err) - } - return nil } func (c *client) OpenIDCredentialIssuerMetadata(ctx context.Context, issuerURL string) (*OpenIDCredentialIssuerMetadata, error) { - if err := c.validateURL("issuer", issuerURL); err != nil { - return nil, err - } // Per §12.2.1, the Credential Issuer Identifier MUST NOT contain query // or fragment components. if parsed, _ := url.Parse(issuerURL); parsed != nil && (parsed.RawQuery != "" || parsed.Fragment != "") { @@ -145,9 +119,6 @@ func (c *client) OpenIDCredentialIssuerMetadata(ctx context.Context, issuerURL s } func (c *client) RequestNonce(ctx context.Context, nonceEndpoint string) (string, error) { - if err := c.validateURL("nonce endpoint", nonceEndpoint); err != nil { - return "", err - } req, err := http.NewRequestWithContext(ctx, http.MethodPost, nonceEndpoint, http.NoBody) if err != nil { return "", err @@ -171,9 +142,6 @@ func (c *client) RequestNonce(ctx context.Context, nonceEndpoint string) (string } func (c *client) RequestCredential(ctx context.Context, opts RequestCredentialOpts) (*CredentialResponse, error) { - if err := c.validateURL("credential endpoint", opts.CredentialEndpoint); err != nil { - return nil, err - } body := CredentialRequest{ Proofs: &CredentialRequestProofs{ JWT: []string{opts.ProofJWT}, diff --git a/auth/openid4vci/client_test.go b/auth/openid4vci/client_test.go index c88034d64e..a094eb0867 100644 --- a/auth/openid4vci/client_test.go +++ b/auth/openid4vci/client_test.go @@ -42,7 +42,7 @@ func TestClient_RequestNonce(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) nonce, err := client.RequestNonce(context.Background(), srv.URL) require.NoError(t, err) assert.Equal(t, "test-nonce-123", nonce) @@ -54,7 +54,7 @@ func TestClient_RequestNonce(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.RequestNonce(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "500") @@ -67,7 +67,7 @@ func TestClient_RequestNonce(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.RequestNonce(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "empty c_nonce") @@ -91,7 +91,7 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) metadata, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL) require.NoError(t, err) require.NotNil(t, metadata) @@ -110,7 +110,7 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL+"/oauth2/alice") require.NoError(t, err) assert.Equal(t, "/.well-known/openid-credential-issuer/oauth2/alice", capturedPath) @@ -126,7 +126,7 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL+"/foo%2Fbar") require.NoError(t, err) assert.Equal(t, "/.well-known/openid-credential-issuer/foo%2Fbar", capturedRawPath) @@ -141,7 +141,7 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "credential_issuer") @@ -154,21 +154,14 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "404") }) - t.Run("rejects non-https issuer URL in strict mode", func(t *testing.T) { - client := NewClient(http.DefaultClient, true) - _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), "http://issuer.example/") - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid issuer URL") - }) - t.Run("rejects issuer URL with query or fragment per §12.2.1", func(t *testing.T) { - client := NewClient(http.DefaultClient, false) + client := NewClient(http.DefaultClient) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), "https://issuer.example/?foo=bar") require.Error(t, err) assert.Contains(t, err.Error(), "query and fragment") @@ -185,7 +178,7 @@ func TestClient_OpenIDCredentialIssuerMetadata(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.OpenIDCredentialIssuerMetadata(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "decoding issuer metadata") @@ -216,7 +209,7 @@ func TestClient_RequestCredential(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) resp, err := client.RequestCredential(context.Background(), RequestCredentialOpts{ CredentialEndpoint: srv.URL, AccessToken: "test-token", @@ -240,7 +233,7 @@ func TestClient_RequestCredential(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.RequestCredential(context.Background(), RequestCredentialOpts{ CredentialEndpoint: srv.URL, AccessToken: "t", @@ -261,7 +254,7 @@ func TestClient_RequestCredential(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.RequestCredential(context.Background(), RequestCredentialOpts{ CredentialEndpoint: srv.URL, AccessToken: "test-token", @@ -279,7 +272,7 @@ func TestClient_RequestCredential(t *testing.T) { })) defer srv.Close() - client := NewClient(srv.Client(), false) + client := NewClient(srv.Client()) _, err := client.RequestCredential(context.Background(), RequestCredentialOpts{ CredentialEndpoint: srv.URL, AccessToken: "test-token", diff --git a/auth/services/oauth/relying_party.go b/auth/services/oauth/relying_party.go index e78160a86c..48d967e49d 100644 --- a/auth/services/oauth/relying_party.go +++ b/auth/services/oauth/relying_party.go @@ -24,7 +24,6 @@ import ( "fmt" "github.com/nuts-foundation/nuts-node/pki" "net/url" - "strings" "time" "github.com/lestrrat-go/jwx/v2/jwt" @@ -47,24 +46,23 @@ type relyingParty struct { keyResolver resolver.KeyResolver privateKeyStore nutsCrypto.KeyStore serviceResolver didman.CompoundServiceResolver - strictMode bool httpClientTimeout time.Duration httpClientTLS *tls.Config wallet holder.Wallet pkiValidator pki.Validator } -// NewRelyingParty returns an implementation of RelyingParty +// NewRelyingParty returns an implementation of RelyingParty. +// Strict-mode URL validation is centralised in http/client.StrictHTTPClient; this constructor no longer takes a strictMode flag. func NewRelyingParty( didResolver resolver.DIDResolver, serviceResolver didman.CompoundServiceResolver, privateKeyStore nutsCrypto.KeyStore, - wallet holder.Wallet, httpClientTimeout time.Duration, httpClientTLS *tls.Config, strictMode bool, pkiValidator pki.Validator) RelyingParty { + wallet holder.Wallet, httpClientTimeout time.Duration, httpClientTLS *tls.Config, pkiValidator pki.Validator) RelyingParty { return &relyingParty{ keyResolver: resolver.DIDKeyResolver{Resolver: didResolver}, serviceResolver: serviceResolver, privateKeyStore: privateKeyStore, httpClientTimeout: httpClientTimeout, httpClientTLS: httpClientTLS, - strictMode: strictMode, wallet: wallet, pkiValidator: pkiValidator, } @@ -110,9 +108,6 @@ func (s *relyingParty) CreateJwtGrant(ctx context.Context, request services.Crea } func (s *relyingParty) RequestRFC003AccessToken(ctx context.Context, jwtGrantToken string, authorizationServerEndpoint url.URL) (*oauth.TokenResponse, error) { - if s.strictMode && strings.ToLower(authorizationServerEndpoint.Scheme) != "https" { - return nil, fmt.Errorf("authorization server endpoint must be HTTPS when in strict mode: %s", authorizationServerEndpoint.String()) - } httpClient := strictHttp.NewWithTLSConfig(s.httpClientTimeout, s.httpClientTLS) authClient, err := client.NewHTTPClient("", s.httpClientTimeout, client.WithHTTPClient(httpClient), client.WithRequestEditorFn(core.UserAgentRequestEditor)) if err != nil { diff --git a/auth/services/oauth/relying_party_test.go b/auth/services/oauth/relying_party_test.go index 9a3d647ca7..057efbf0c2 100644 --- a/auth/services/oauth/relying_party_test.go +++ b/auth/services/oauth/relying_party_test.go @@ -22,7 +22,6 @@ import ( "context" "crypto/tls" "errors" - "fmt" "net/http" "net/http/httptest" "testing" @@ -76,42 +75,6 @@ func TestRelyingParty_RequestRFC003AccessToken(t *testing.T) { assert.EqualError(t, err, "remote server/nuts node returned error creating access token: server returned HTTP 502 (expected: 200)") }) - t.Run("endpoint security validation (only HTTPS in strict mode)", func(t *testing.T) { - ctx := createRPContext(t, nil) - httpServer := httptest.NewServer(&http2.Handler{ - StatusCode: http.StatusOK, - }) - httpsServer := httptest.NewTLSServer(&http2.Handler{ - StatusCode: http.StatusOK, - }) - t.Cleanup(httpServer.Close) - t.Cleanup(httpsServer.Close) - - t.Run("HTTPS in strict mode", func(t *testing.T) { - ctx.relyingParty.strictMode = true - - response, err := ctx.relyingParty.RequestRFC003AccessToken(context.Background(), bearerToken, *test.MustParseURL(httpsServer.URL)) - - assert.NoError(t, err) - assert.NotNil(t, response) - }) - t.Run("HTTP allowed in non-strict mode", func(t *testing.T) { - ctx.relyingParty.strictMode = false - - response, err := ctx.relyingParty.RequestRFC003AccessToken(context.Background(), bearerToken, *test.MustParseURL(httpServer.URL)) - - assert.NoError(t, err) - assert.NotNil(t, response) - }) - t.Run("HTTP not allowed in strict mode", func(t *testing.T) { - ctx.relyingParty.strictMode = true - - response, err := ctx.relyingParty.RequestRFC003AccessToken(context.Background(), bearerToken, *test.MustParseURL(httpServer.URL)) - - assert.EqualError(t, err, fmt.Sprintf("authorization server endpoint must be HTTPS when in strict mode: %s", httpServer.URL)) - assert.Nil(t, response) - }) - }) } func TestService_CreateJwtBearerToken(t *testing.T) { diff --git a/docs/pages/deployment/configuration.rst b/docs/pages/deployment/configuration.rst index 4385fcb20a..baa604fa4d 100644 --- a/docs/pages/deployment/configuration.rst +++ b/docs/pages/deployment/configuration.rst @@ -95,4 +95,6 @@ requesting an access token from another node on ``/n2n/auth/v1/accesstoken`` doe json-ld context can only be downloaded from trusted domains configured in ``jsonld.contexts.remoteallowlist``, and the ``internalratelimiter`` is always on. -Interacting with remote Nuts nodes requires HTTPS: it will refuse to connect to plain HTTP endpoints when in strict mode. \ No newline at end of file +Interacting with remote Nuts nodes requires HTTPS: it will refuse to connect to plain HTTP endpoints when in strict mode. +Strict mode additionally rejects outbound URLs whose host is an IP literal or an RFC 2606 reserved hostname/TLD (e.g. ``*.localhost``, ``*.test``, ``example.com/net/org``). +This applies to every outbound HTTP call made via the shared HTTP client (OpenID4VCI, OAuth relying-party, IAM, Discovery, did:web resolution, etc.) and to every redirect target along the way. \ No newline at end of file diff --git a/http/client/client.go b/http/client/client.go index 2c9e4308df..2be7a2b25f 100644 --- a/http/client/client.go +++ b/http/client/client.go @@ -21,7 +21,6 @@ package client import ( "bytes" "crypto/tls" - "errors" "fmt" "io" "net/http" @@ -32,6 +31,21 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) +// maxRedirects matches net/http's default redirect cap. +const maxRedirects = 10 + +// checkRedirect re-runs core.ParsePublicURL on every redirect target so a +// validated initial URL cannot be turned into an SSRF via a 3xx response. +func checkRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + if _, err := core.ParsePublicURL(req.URL.String(), StrictMode); err != nil { + return fmt.Errorf("httpclient: invalid redirect target: %w", err) + } + return nil +} + // SafeHttpTransport is a http.Transport that can be used as a default transport for HTTP clients. var SafeHttpTransport *http.Transport @@ -75,8 +89,9 @@ func New(timeout time.Duration) *StrictHTTPClient { transport := getTransport(SafeHttpTransport) return &StrictHTTPClient{ client: &http.Client{ - Transport: transport, - Timeout: timeout, + Transport: transport, + Timeout: timeout, + CheckRedirect: checkRedirect, }, } } @@ -98,8 +113,9 @@ func NewWithCache(timeout time.Duration) *StrictHTTPClient { transport := getTransport(DefaultCachingTransport) return &StrictHTTPClient{ client: &http.Client{ - Transport: transport, - Timeout: timeout, + Transport: transport, + Timeout: timeout, + CheckRedirect: checkRedirect, }, } } @@ -112,8 +128,9 @@ func NewWithTLSConfig(timeout time.Duration, tlsConfig *tls.Config) *StrictHTTPC transport.TLSClientConfig = tlsConfig return &StrictHTTPClient{ client: &http.Client{ - Transport: getTransport(transport), - Timeout: timeout, + Transport: getTransport(transport), + Timeout: timeout, + CheckRedirect: checkRedirect, }, } } @@ -123,8 +140,8 @@ type StrictHTTPClient struct { } func (s *StrictHTTPClient) Do(req *http.Request) (*http.Response, error) { - if StrictMode && req.URL.Scheme != "https" { - return nil, errors.New("strictmode is enabled, but request is not over HTTPS") + if _, err := core.ParsePublicURL(req.URL.String(), StrictMode); err != nil { + return nil, fmt.Errorf("httpclient: invalid target URL: %w", err) } req.Header.Set("User-Agent", core.UserAgent()) result, err := s.client.Do(req) diff --git a/http/client/client_test.go b/http/client/client_test.go index 1a1b01366c..5eb261255a 100644 --- a/http/client/client_test.go +++ b/http/client/client_test.go @@ -21,6 +21,7 @@ package client import ( "crypto/tls" "fmt" + "io" "net/http" "net/http/httptest" "strings" @@ -40,12 +41,13 @@ func TestStrictHTTPClient(t *testing.T) { rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true + t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) - assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.ErrorContains(t, err, "httpclient: invalid target URL") assert.Equal(t, 0, rt.invocations) }) t.Run("strict mode disabled", func(t *testing.T) { @@ -66,12 +68,13 @@ func TestStrictHTTPClient(t *testing.T) { rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true + t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) - assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.ErrorContains(t, err, "httpclient: invalid target URL") assert.Equal(t, 0, rt.invocations) }) t.Run("sets TLS config", func(t *testing.T) { @@ -89,14 +92,177 @@ func TestStrictHTTPClient(t *testing.T) { rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true + t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) _, err := client.Do(httpRequest) - assert.EqualError(t, err, "strictmode is enabled, but request is not over HTTPS") + assert.ErrorContains(t, err, "httpclient: invalid target URL") assert.Equal(t, 0, rt.invocations) }) + t.Run("strict mode rejects IP host", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + + client := NewWithCache(time.Second) + httpRequest, _ := http.NewRequest("GET", "https://127.0.0.1/foo", nil) + _, err := client.Do(httpRequest) + + assert.ErrorContains(t, err, "httpclient: invalid target URL") + assert.ErrorContains(t, err, "hostname is IP") + assert.Equal(t, 0, rt.invocations) + }) + t.Run("strict mode rejects RFC2606 reserved host", func(t *testing.T) { + rt := &stubRoundTripper{} + DefaultCachingTransport = rt + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + + client := NewWithCache(time.Second) + httpRequest, _ := http.NewRequest("GET", "https://service.localhost/foo", nil) + _, err := client.Do(httpRequest) + + assert.ErrorContains(t, err, "httpclient: invalid target URL") + assert.ErrorContains(t, err, "hostname is RFC2606 reserved") + assert.Equal(t, 0, rt.invocations) + }) +} + +func TestCheckRedirect(t *testing.T) { + makeReq := func(target string) *http.Request { + req, _ := http.NewRequest("GET", target, nil) + return req + } + t.Run("strict mode rejects http redirect", func(t *testing.T) { + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + err := checkRedirect(makeReq("http://example.org"), nil) + assert.ErrorContains(t, err, "invalid redirect target") + assert.ErrorContains(t, err, "scheme must be https") + }) + t.Run("strict mode rejects redirect to IP host", func(t *testing.T) { + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + err := checkRedirect(makeReq("https://127.0.0.1/x"), nil) + assert.ErrorContains(t, err, "invalid redirect target") + assert.ErrorContains(t, err, "hostname is IP") + }) + t.Run("strict mode rejects redirect to reserved host", func(t *testing.T) { + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + err := checkRedirect(makeReq("https://internal.localhost/x"), nil) + assert.ErrorContains(t, err, "invalid redirect target") + assert.ErrorContains(t, err, "hostname is RFC2606 reserved") + }) + t.Run("non-strict mode allows http redirect", func(t *testing.T) { + StrictMode = false + err := checkRedirect(makeReq("http://example.org"), nil) + assert.NoError(t, err) + }) + t.Run("non-strict mode allows redirect to IP host", func(t *testing.T) { + StrictMode = false + err := checkRedirect(makeReq("http://127.0.0.1/x"), nil) + assert.NoError(t, err) + }) + t.Run("redirect cap stops after 10 hops", func(t *testing.T) { + StrictMode = false + via := make([]*http.Request, maxRedirects) + err := checkRedirect(makeReq("http://example.org"), via) + assert.ErrorContains(t, err, "stopped after 10 redirects") + }) + t.Run("cap checked before URL validation", func(t *testing.T) { + // Even an invalid target should produce the cap error first when via is at the limit. + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + via := make([]*http.Request, maxRedirects) + err := checkRedirect(makeReq("http://example.org"), via) + assert.ErrorContains(t, err, "stopped after 10 redirects") + }) +} + +// redirectOnceTransport is a stub RoundTripper that responds with a 302 to redirectTo +// on the first request, then a 200 OK on subsequent requests. It also counts hops so +// tests can verify the redirected request was (or was not) sent. +type redirectOnceTransport struct { + redirectTo string + requests []*http.Request +} + +func (t *redirectOnceTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.requests = append(t.requests, req) + if len(t.requests) == 1 { + return &http.Response{ + StatusCode: http.StatusFound, + Header: http.Header{"Location": []string{t.redirectTo}}, + Body: io.NopCloser(strings.NewReader("")), + Request: req, + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("ok")), + Request: req, + }, nil +} + +// TestStrictHTTPClient_RedirectEndToEnd drives the full net/http redirect path +// through a stub transport so we prove the configured CheckRedirect actually +// fires and blocks the second HTTP request from being issued. +// +// Initial URLs use "nuts.nl" because strict mode rejects RFC 2606 reserved 2LDs +// (example.com/net/org), and the test needs the initial Do() to pass so the +// redirect can be exercised. The hostname is never resolved — the stub +// transport intercepts all requests. +func TestStrictHTTPClient_RedirectEndToEnd(t *testing.T) { + const initialURL = "https://nuts.nl/" + t.Run("strict mode blocks redirect to non-https target", func(t *testing.T) { + rt := &redirectOnceTransport{redirectTo: "http://example.com/x"} + DefaultCachingTransport = rt + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + + c := NewWithCache(time.Second) + req, _ := http.NewRequest("GET", initialURL, nil) + _, err := c.Do(req) + + require.Error(t, err) + assert.ErrorContains(t, err, "invalid redirect target") + assert.ErrorContains(t, err, "scheme must be https") + // only the initial request reached the transport; the redirect was blocked + assert.Len(t, rt.requests, 1, "second request must not be issued") + }) + t.Run("strict mode blocks redirect to IP host", func(t *testing.T) { + rt := &redirectOnceTransport{redirectTo: "https://10.0.0.1/x"} + DefaultCachingTransport = rt + StrictMode = true + t.Cleanup(func() { StrictMode = false }) + + c := NewWithCache(time.Second) + req, _ := http.NewRequest("GET", initialURL, nil) + _, err := c.Do(req) + + require.Error(t, err) + assert.ErrorContains(t, err, "invalid redirect target") + assert.ErrorContains(t, err, "hostname is IP") + assert.Len(t, rt.requests, 1) + }) + t.Run("non-strict mode follows http redirect", func(t *testing.T) { + rt := &redirectOnceTransport{redirectTo: "http://example.com/x"} + DefaultCachingTransport = rt + StrictMode = false + + c := NewWithCache(time.Second) + req, _ := http.NewRequest("GET", "http://example.com/", nil) + resp, err := c.Do(req) + + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Len(t, rt.requests, 2, "both initial and redirected request should be issued") + }) } func TestLimitedReadAll(t *testing.T) { diff --git a/vcr/vcr_test.go b/vcr/vcr_test.go index d4833943a9..231bc5da7d 100644 --- a/vcr/vcr_test.go +++ b/vcr/vcr_test.go @@ -87,6 +87,7 @@ func TestVCR_Configure(t *testing.T) { t.Run("strictmode passed to client APIs", func(t *testing.T) { ctx := newMockContext(t) client.StrictMode = true + t.Cleanup(func() { client.StrictMode = false }) testVC := test.ValidNutsOrganizationCredential(t) issuerDID := did.MustParseDID(testVC.Issuer.String()) testDirectory := io.TestDirectory(t) @@ -106,7 +107,8 @@ func TestVCR_Configure(t *testing.T) { require.NoError(t, err) err = issuer.OfferCredential(context.Background(), testVC, "http://example.com") - assert.ErrorContains(t, err, "http request error: strictmode is enabled, but request is not over HTTPS") + assert.ErrorContains(t, err, "httpclient: invalid target URL") + assert.ErrorContains(t, err, "scheme must be https") }) } From 3ae15497618a0aab5afb4d68790b67a5a7726f0c Mon Sep 17 00:00:00 2001 From: Joris Scharp Date: Mon, 11 May 2026 19:47:38 +0200 Subject: [PATCH 2/2] test: capture and restore package globals in client tests Apply the capture-and-restore t.Cleanup pattern to tests that mutate the package-level StrictMode and DefaultCachingTransport globals so they no longer leak state into subsequent tests. Adds a withClientGlobals helper in http/client/client_test.go to avoid repeating the boilerplate. Also fixes a pre-existing DefaultCachingTransport leak in TestCaching. --- http/client/client_test.go | 46 +++++++++++++++++++++++++------------- vcr/vcr_test.go | 3 ++- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/http/client/client_test.go b/http/client/client_test.go index 5eb261255a..de11e0075e 100644 --- a/http/client/client_test.go +++ b/http/client/client_test.go @@ -35,13 +35,26 @@ import ( "github.com/stretchr/testify/require" ) +// withClientGlobals captures the package-level globals StrictMode and +// DefaultCachingTransport at call time and restores them via t.Cleanup, +// so subtests can mutate them without leaking state into other tests. +func withClientGlobals(t *testing.T) { + t.Helper() + oldStrict := StrictMode + oldTransport := DefaultCachingTransport + t.Cleanup(func() { + StrictMode = oldStrict + DefaultCachingTransport = oldTransport + }) +} + func TestStrictHTTPClient(t *testing.T) { t.Run("caching transport", func(t *testing.T) { t.Run("strict mode enabled", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) @@ -51,6 +64,7 @@ func TestStrictHTTPClient(t *testing.T) { assert.Equal(t, 0, rt.invocations) }) t.Run("strict mode disabled", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = false @@ -65,10 +79,10 @@ func TestStrictHTTPClient(t *testing.T) { }) t.Run("TLS transport", func(t *testing.T) { t.Run("strict mode enabled", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) @@ -89,10 +103,10 @@ func TestStrictHTTPClient(t *testing.T) { }) }) t.Run("error on HTTP call when strictmode is enabled", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "http://example.com", nil) @@ -102,10 +116,10 @@ func TestStrictHTTPClient(t *testing.T) { assert.Equal(t, 0, rt.invocations) }) t.Run("strict mode rejects IP host", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "https://127.0.0.1/foo", nil) @@ -116,10 +130,10 @@ func TestStrictHTTPClient(t *testing.T) { assert.Equal(t, 0, rt.invocations) }) t.Run("strict mode rejects RFC2606 reserved host", func(t *testing.T) { + withClientGlobals(t) rt := &stubRoundTripper{} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) client := NewWithCache(time.Second) httpRequest, _ := http.NewRequest("GET", "https://service.localhost/foo", nil) @@ -137,37 +151,40 @@ func TestCheckRedirect(t *testing.T) { return req } t.Run("strict mode rejects http redirect", func(t *testing.T) { + withClientGlobals(t) StrictMode = true - t.Cleanup(func() { StrictMode = false }) err := checkRedirect(makeReq("http://example.org"), nil) assert.ErrorContains(t, err, "invalid redirect target") assert.ErrorContains(t, err, "scheme must be https") }) t.Run("strict mode rejects redirect to IP host", func(t *testing.T) { + withClientGlobals(t) StrictMode = true - t.Cleanup(func() { StrictMode = false }) err := checkRedirect(makeReq("https://127.0.0.1/x"), nil) assert.ErrorContains(t, err, "invalid redirect target") assert.ErrorContains(t, err, "hostname is IP") }) t.Run("strict mode rejects redirect to reserved host", func(t *testing.T) { + withClientGlobals(t) StrictMode = true - t.Cleanup(func() { StrictMode = false }) err := checkRedirect(makeReq("https://internal.localhost/x"), nil) assert.ErrorContains(t, err, "invalid redirect target") assert.ErrorContains(t, err, "hostname is RFC2606 reserved") }) t.Run("non-strict mode allows http redirect", func(t *testing.T) { + withClientGlobals(t) StrictMode = false err := checkRedirect(makeReq("http://example.org"), nil) assert.NoError(t, err) }) t.Run("non-strict mode allows redirect to IP host", func(t *testing.T) { + withClientGlobals(t) StrictMode = false err := checkRedirect(makeReq("http://127.0.0.1/x"), nil) assert.NoError(t, err) }) t.Run("redirect cap stops after 10 hops", func(t *testing.T) { + withClientGlobals(t) StrictMode = false via := make([]*http.Request, maxRedirects) err := checkRedirect(makeReq("http://example.org"), via) @@ -175,8 +192,8 @@ func TestCheckRedirect(t *testing.T) { }) t.Run("cap checked before URL validation", func(t *testing.T) { // Even an invalid target should produce the cap error first when via is at the limit. + withClientGlobals(t) StrictMode = true - t.Cleanup(func() { StrictMode = false }) via := make([]*http.Request, maxRedirects) err := checkRedirect(makeReq("http://example.org"), via) assert.ErrorContains(t, err, "stopped after 10 redirects") @@ -220,10 +237,10 @@ func (t *redirectOnceTransport) RoundTrip(req *http.Request) (*http.Response, er func TestStrictHTTPClient_RedirectEndToEnd(t *testing.T) { const initialURL = "https://nuts.nl/" t.Run("strict mode blocks redirect to non-https target", func(t *testing.T) { + withClientGlobals(t) rt := &redirectOnceTransport{redirectTo: "http://example.com/x"} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) c := NewWithCache(time.Second) req, _ := http.NewRequest("GET", initialURL, nil) @@ -236,10 +253,10 @@ func TestStrictHTTPClient_RedirectEndToEnd(t *testing.T) { assert.Len(t, rt.requests, 1, "second request must not be issued") }) t.Run("strict mode blocks redirect to IP host", func(t *testing.T) { + withClientGlobals(t) rt := &redirectOnceTransport{redirectTo: "https://10.0.0.1/x"} DefaultCachingTransport = rt StrictMode = true - t.Cleanup(func() { StrictMode = false }) c := NewWithCache(time.Second) req, _ := http.NewRequest("GET", initialURL, nil) @@ -251,6 +268,7 @@ func TestStrictHTTPClient_RedirectEndToEnd(t *testing.T) { assert.Len(t, rt.requests, 1) }) t.Run("non-strict mode follows http redirect", func(t *testing.T) { + withClientGlobals(t) rt := &redirectOnceTransport{redirectTo: "http://example.com/x"} DefaultCachingTransport = rt StrictMode = false @@ -283,9 +301,8 @@ func TestLimitedReadAll(t *testing.T) { } func TestMaxConns(t *testing.T) { - oldStrictMode := StrictMode + withClientGlobals(t) StrictMode = false - t.Cleanup(func() { StrictMode = oldStrictMode }) // Our safe http Transport has MaxConnsPerHost = 5 // if we request 6 resources multiple times, we expect a max connection usage of 5 @@ -328,9 +345,8 @@ func TestMaxConns(t *testing.T) { } func TestCaching(t *testing.T) { - oldStrictMode := StrictMode + withClientGlobals(t) StrictMode = false - t.Cleanup(func() { StrictMode = oldStrictMode }) // counter for the number of concurrent requests var total atomic.Int32 diff --git a/vcr/vcr_test.go b/vcr/vcr_test.go index 231bc5da7d..29700f3be2 100644 --- a/vcr/vcr_test.go +++ b/vcr/vcr_test.go @@ -86,8 +86,9 @@ func TestVCR_Configure(t *testing.T) { }) t.Run("strictmode passed to client APIs", func(t *testing.T) { ctx := newMockContext(t) + oldStrict := client.StrictMode client.StrictMode = true - t.Cleanup(func() { client.StrictMode = false }) + t.Cleanup(func() { client.StrictMode = oldStrict }) testVC := test.ValidNutsOrganizationCredential(t) issuerDID := did.MustParseDID(testVC.Issuer.String()) testDirectory := io.TestDirectory(t)