diff --git a/core/capabilities/fakes/http_action.go b/core/capabilities/fakes/http_action.go index 948a5a58637..205d4c3859e 100644 --- a/core/capabilities/fakes/http_action.go +++ b/core/capabilities/fakes/http_action.go @@ -6,6 +6,7 @@ import ( "errors" "io" "net/http" + "slices" "strings" "time" @@ -18,12 +19,16 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types/core" ) -var _ httpserver.ClientCapability = (*DirectHTTPAction)(nil) -var _ services.Service = (*DirectHTTPAction)(nil) -var _ commonCap.ExecutableCapability = (*DirectHTTPAction)(nil) +var ( + _ httpserver.ClientCapability = (*DirectHTTPAction)(nil) + _ services.Service = (*DirectHTTPAction)(nil) + _ commonCap.ExecutableCapability = (*DirectHTTPAction)(nil) +) -const HTTPActionID = "http-actions@0.1.0" -const HTTPActionServiceName = "HttpActionService" +const ( + HTTPActionID = "http-actions@0.1.0" + HTTPActionServiceName = "HttpActionService" +) var directHTTPActionInfo = commonCap.MustNewCapabilityInfo( HTTPActionID, @@ -90,9 +95,17 @@ func (fh *DirectHTTPAction) SendRequest(ctx context.Context, metadata commonCap. return &responseAndMetadata, caperrors.NewPrivateSystemError(err, caperrors.Unknown) } - // Add headers - for k, v := range input.GetHeaders() { - req.Header.Set(k, v) + // Add headers: prefer MultiHeaders, fall back to deprecated Headers + if len(input.GetMultiHeaders()) > 0 { + for k, v := range input.GetMultiHeaders() { + for _, val := range v.GetValues() { + req.Header.Add(k, val) + } + } + } else { + for k, v := range input.GetHeaders() { //nolint: staticcheck // deprecated + req.Header.Set(k, v) + } } // Make the HTTP request @@ -124,18 +137,23 @@ func (fh *DirectHTTPAction) SendRequest(ctx context.Context, metadata commonCap. return &responseAndMetadata, caperrors.NewPrivateSystemError(err, caperrors.Unknown) } - // Convert headers - headers := make(map[string]string) + // Convert headers: Headers (comma-joined for backwards compat) and MultiHeaders (per capability) + headers := make(map[string]string, len(resp.Header)) + multiHeaders := make(map[string]*customhttp.HeaderValues, len(resp.Header)) for k, v := range resp.Header { - // Join multiple header values with comma + if len(v) == 0 { + continue + } headers[k] = strings.Join(v, ", ") + multiHeaders[k] = &customhttp.HeaderValues{Values: slices.Clone(v)} } // Create response response := &customhttp.Response{ - StatusCode: uint32(resp.StatusCode), //nolint:gosec // status code is always in valid range - Headers: headers, - Body: respBody, + StatusCode: uint32(resp.StatusCode), //nolint:gosec // status code is always in valid range + Headers: headers, + MultiHeaders: multiHeaders, + Body: respBody, } responseAndMetadata := commonCap.ResponseAndMetadata[*customhttp.Response]{ Response: response, diff --git a/core/capabilities/fakes/http_action_test.go b/core/capabilities/fakes/http_action_test.go new file mode 100644 index 00000000000..e04e1aa5e07 --- /dev/null +++ b/core/capabilities/fakes/http_action_test.go @@ -0,0 +1,132 @@ +package fakes + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + commonCap "github.com/smartcontractkit/chainlink-common/pkg/capabilities" + customhttp "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/actions/http" + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +func TestDirectHTTPAction_RequestHeaders(t *testing.T) { + t.Run("MultiHeaders are sent in request", func(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(srv.Close) + + lggr := logger.Test(t) + action := NewDirectHTTPAction(lggr) + require.NoError(t, action.Start(context.Background())) + t.Cleanup(func() { _ = action.Close() }) + + input := &customhttp.Request{ + Url: srv.URL, + Method: "GET", + MultiHeaders: map[string]*customhttp.HeaderValues{ + "Authorization": {Values: []string{"Bearer test-token"}}, + }, + } + metadata := commonCap.RequestMetadata{} + + result, err := action.SendRequest(context.Background(), metadata, input) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "Bearer test-token", receivedAuth, "Authorization header should be sent") + }) + + t.Run("Headers (deprecated) are sent in request when MultiHeaders empty", func(t *testing.T) { + var receivedAuth string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(srv.Close) + + lggr := logger.Test(t) + action := NewDirectHTTPAction(lggr) + require.NoError(t, action.Start(context.Background())) + t.Cleanup(func() { _ = action.Close() }) + + input := &customhttp.Request{ + Url: srv.URL, + Method: "GET", + Headers: map[string]string{"Authorization": "Basic legacy-auth"}, + } + metadata := commonCap.RequestMetadata{} + + result, err := action.SendRequest(context.Background(), metadata, input) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, "Basic legacy-auth", receivedAuth, "Authorization header should be sent via deprecated Headers") + }) +} + +func TestDirectHTTPAction_ResponseHeadersAndMultiHeaders(t *testing.T) { + t.Run("response has both Headers and MultiHeaders populated", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Add("Set-Cookie", "sessionid=abc123; Path=/") + w.Header().Add("Set-Cookie", "csrf=xyz789; Path=/") + w.Header().Add("X-Custom", "single-value") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + t.Cleanup(srv.Close) + + lggr := logger.Test(t) + action := NewDirectHTTPAction(lggr) + require.NoError(t, action.Start(context.Background())) + t.Cleanup(func() { _ = action.Close() }) + + input := &customhttp.Request{ + Url: srv.URL, + Method: "GET", + } + metadata := commonCap.RequestMetadata{} + + result, err := action.SendRequest(context.Background(), metadata, input) + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Response) + + resp := result.Response + assert.Equal(t, uint32(200), resp.StatusCode) + + // Headers (comma-joined, backwards compat) + require.NotNil(t, resp.Headers) //nolint:staticcheck // testing deprecated field + assert.Contains(t, resp.Headers, "Content-Type") //nolint:staticcheck // testing deprecated field + assert.Equal(t, "application/json", resp.Headers["Content-Type"]) //nolint:staticcheck // testing deprecated field + assert.Contains(t, resp.Headers, "Set-Cookie") //nolint:staticcheck // testing deprecated field + assert.Contains(t, resp.Headers["Set-Cookie"], "sessionid=abc123") //nolint:staticcheck // testing deprecated field + assert.Contains(t, resp.Headers["Set-Cookie"], "csrf=xyz789") //nolint:staticcheck // testing deprecated field + assert.Equal(t, "single-value", resp.Headers["X-Custom"]) //nolint:staticcheck // testing deprecated field + + // MultiHeaders (per-value slices) + require.NotNil(t, resp.MultiHeaders) + assert.Contains(t, resp.MultiHeaders, "Content-Type") + assert.Equal(t, []string{"application/json"}, resp.MultiHeaders["Content-Type"].GetValues()) + + setCookie := resp.MultiHeaders["Set-Cookie"] + require.NotNil(t, setCookie) + vals := setCookie.GetValues() + require.Len(t, vals, 2) + assert.Contains(t, vals, "sessionid=abc123; Path=/") + assert.Contains(t, vals, "csrf=xyz789; Path=/") + + assert.Contains(t, resp.MultiHeaders, "X-Custom") + assert.Equal(t, []string{"single-value"}, resp.MultiHeaders["X-Custom"].GetValues()) + }) +}