From 66a12cbebef156f5947548d4a7f72ab871573501 Mon Sep 17 00:00:00 2001
From: wlggraham
Date: Thu, 16 Apr 2026 11:07:50 -0600
Subject: [PATCH] add throttling backoff to ecs plugin

---
 internal/pkg/object/command/ecs/ecs.go      | 105 ++++++++++++++++---
 internal/pkg/object/command/ecs/ecs_test.go | 110 ++++++++++++++++++++
 2 files changed, 200 insertions(+), 15 deletions(-)
 create mode 100644 internal/pkg/object/command/ecs/ecs_test.go

diff --git a/internal/pkg/object/command/ecs/ecs.go b/internal/pkg/object/command/ecs/ecs.go
index 30eaa8e..b2750b7 100644
--- a/internal/pkg/object/command/ecs/ecs.go
+++ b/internal/pkg/object/command/ecs/ecs.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
+	"math/rand"
 	"os"
 	"strings"
 	"time"
@@ -13,6 +15,7 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
 	"github.com/aws/aws-sdk-go-v2/service/ecs"
 	"github.com/aws/aws-sdk-go-v2/service/ecs/types"
+	smithy "github.com/aws/smithy-go"
 	"github.com/hladush/go-telemetry/pkg/telemetry"
 	heimdallAws "github.com/patterninc/heimdall/internal/pkg/aws"
 	heimdallContext "github.com/patterninc/heimdall/pkg/context"
@@ -109,18 +112,21 @@ type executionContext struct {
 }
 
 const (
-	defaultPollingInterval               = duration.Duration(30 * time.Second)
-	defaultTaskTimeout                   = duration.Duration(1 * time.Hour)
-	defaultMaxFailCount                  = 1
-	defaultTaskCount                     = 1
-	startedByPrefix                      = "heimdall-job-"
-	errMaxFailCount                      = "task %s failed %d times (max: %d), giving up"
-	errPollingTimeout                    = "polling timed out for arns %v after %v"
-	errJobTerminated                     = "job marked as stale or canceled"
-	Timeout                FailureReason = "timeout"
-	Error                  FailureReason = "error"
-	maxLogChunkSize                      = 200                // Process 200 log entries at a time
-	maxLogMemoryBytes                    = 1024 * 1024 * 1024 // 1GB safety limit
+	defaultPollingInterval                  = duration.Duration(30 * time.Second)
+	defaultTaskTimeout                      = duration.Duration(1 * time.Hour)
+	defaultMaxFailCount                     = 1
+	defaultTaskCount                        = 1
+	defaultThrottleMaxRetries               = 5
+	throttleBackoffBase                     = time.Second
+	throttleBackoffMax                      = 2 * time.Minute
+	startedByPrefix                         = "heimdall-job-"
+	errMaxFailCount                         = "task %s failed %d times (max: %d), giving up"
+	errPollingTimeout                       = "polling timed out for arns %v after %v"
+	errJobTerminated                        = "job marked as stale or canceled"
+	Timeout                   FailureReason = "timeout"
+	Error                     FailureReason = "error"
+	maxLogChunkSize                         = 200                // Process 200 log entries at a time
+	maxLogMemoryBytes                       = 1024 * 1024 * 1024 // 1GB safety limit
 )
 
 var (
@@ -336,8 +342,12 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error {
 			Tasks: activeARNs,
 		}
 
-		describeOutput, err := execCtx.ecsClient.DescribeTasks(ctx, describeInput)
-		if err != nil {
+		var describeOutput *ecs.DescribeTasksOutput
+		if err := retryWithBackoff(ctx, defaultThrottleMaxRetries, func() error {
+			var descErr error
+			describeOutput, descErr = execCtx.ecsClient.DescribeTasks(ctx, describeInput)
+			return descErr
+		}); err != nil {
 			return err
 		}
 
@@ -644,7 +654,12 @@ func runTask(ctx context.Context, execCtx *executionContext, startedBy string, t
 		},
 	}
 
-	runTaskOutput, err := execCtx.ecsClient.RunTask(ctx, runTaskInput)
+	var runTaskOutput *ecs.RunTaskOutput
+	err := retryWithBackoff(ctx, defaultThrottleMaxRetries, func() error {
+		var runErr error
+		runTaskOutput, runErr = execCtx.ecsClient.RunTask(ctx, runTaskInput)
+		return runErr
+	})
 	if err != nil {
 		return ``, err
 	}
@@ -803,3 +818,60 @@ func (e *commandContext) Cleanup(ctx context.Context, jobID string, c *cluster.C
 
 	return nil
 }
+
+// retryWithBackoff retries op while it fails with an AWS throttling error,
+// waiting between attempts with capped exponential backoff plus jitter.
+//
+// op is invoked at most maxRetries+1 times. A nil or non-throttling error is
+// returned immediately, the last throttling error is returned once retries are
+// exhausted, and ctx.Err() is returned if ctx is done while waiting.
+func retryWithBackoff(ctx context.Context, maxRetries int, op func() error) error {
+	return retryWithDelays(ctx, maxRetries, throttleBackoffBase, throttleBackoffMax, op)
+}
+
+// retryWithDelays implements retryWithBackoff with explicit backoff bounds so
+// tests can exercise the retry loop without sleeping for seconds.
+func retryWithDelays(ctx context.Context, maxRetries int, base, maxDelay time.Duration, op func() error) error {
+	for attempt := 0; ; attempt++ {
+		err := op()
+		if err == nil || !isThrottlingError(err) || attempt >= maxRetries {
+			return err
+		}
+
+		timer := time.NewTimer(backoffDelay(attempt, base, maxDelay))
+		select {
+		case <-ctx.Done():
+			timer.Stop()
+			return ctx.Err()
+		case <-timer.C:
+		}
+	}
+}
+
+// backoffDelay returns the wait that follows the given zero-based attempt:
+// base doubled attempt times, capped at maxDelay, plus up to base of random
+// jitter. Doubling stops once the cap is reached, so large attempt values do
+// not overflow the duration.
+func backoffDelay(attempt int, base, maxDelay time.Duration) time.Duration {
+	delay := base
+	for i := 0; i < attempt && delay < maxDelay; i++ {
+		delay *= 2
+	}
+	if delay > maxDelay {
+		delay = maxDelay
+	}
+	if base > 0 {
+		delay += time.Duration(rand.Int63n(int64(base)))
+	}
+	return delay
+}
+
+// isThrottlingError reports whether err is, or wraps, an AWS API error whose
+// code is one of the throttling codes listed below.
+func isThrottlingError(err error) bool {
+	var apiErr smithy.APIError
+	if errors.As(err, &apiErr) {
+		switch apiErr.ErrorCode() {
+		case "ThrottlingException", "RequestThrottledException", "Throttling", "RequestLimitExceeded":
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/pkg/object/command/ecs/ecs_test.go b/internal/pkg/object/command/ecs/ecs_test.go
new file mode 100644
index 0000000..6369467
--- /dev/null
+++ b/internal/pkg/object/command/ecs/ecs_test.go
@@ -0,0 +1,110 @@
+package ecs
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	smithy "github.com/aws/smithy-go"
+)
+
+// Backoff bounds for tests that really wait between retries, keeping the
+// suite from sleeping for whole seconds.
+const (
+	testBackoffBase = time.Millisecond
+	testBackoffMax  = 5 * time.Millisecond
+)
+
+func TestRetryWithBackoff_Throttling(t *testing.T) {
+	calls := 0
+	throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}
+
+	err := retryWithDelays(context.Background(), 3, testBackoffBase, testBackoffMax, func() error {
+		calls++
+		return throttleErr
+	})
+
+	if err != throttleErr {
+		t.Fatalf("expected throttle error, got %v", err)
+	}
+	if calls != 4 { // 1 initial + 3 retries
+		t.Fatalf("expected 4 calls, got %d", calls)
+	}
+}
+
+func TestRetryWithBackoff_NonThrottling(t *testing.T) {
+	calls := 0
+	otherErr := &smithy.GenericAPIError{Code: "AccessDeniedException", Message: "denied"}
+
+	err := retryWithBackoff(context.Background(), 3, func() error {
+		calls++
+		return otherErr
+	})
+
+	if err != otherErr {
+		t.Fatalf("expected access denied error, got %v", err)
+	}
+	if calls != 1 { // should not retry
+		t.Fatalf("expected 1 call, got %d", calls)
+	}
+}
+
+func TestRetryWithBackoff_SuccessAfterThrottle(t *testing.T) {
+	calls := 0
+	throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}
+
+	err := retryWithDelays(context.Background(), 3, testBackoffBase, testBackoffMax, func() error {
+		calls++
+		if calls < 3 {
+			return throttleErr
+		}
+		return nil
+	})
+
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if calls != 3 {
+		t.Fatalf("expected 3 calls, got %d", calls)
+	}
+}
+
+func TestRetryWithBackoff_ContextCancelled(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel() // cancel immediately
+
+	calls := 0
+	throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}
+
+	err := retryWithBackoff(ctx, 3, func() error {
+		calls++
+		return throttleErr
+	})
+
+	if err != context.Canceled {
+		t.Fatalf("expected context.Canceled, got %v", err)
+	}
+	if calls != 1 {
+		t.Fatalf("expected 1 call, got %d", calls)
+	}
+}
+
+func TestBackoffDelay_Bounds(t *testing.T) {
+	const base, maxDelay = time.Second, 2 * time.Minute
+	cases := []struct {
+		attempt int
+		want    time.Duration // delay before jitter is added
+	}{
+		{0, time.Second},
+		{3, 8 * time.Second},
+		{7, maxDelay},
+		{1000, maxDelay},
+	}
+
+	for _, tc := range cases {
+		got := backoffDelay(tc.attempt, base, maxDelay)
+		if got < tc.want || got >= tc.want+base {
+			t.Fatalf("attempt %d: got %v, want [%v, %v)", tc.attempt, got, tc.want, tc.want+base)
+		}
+	}
+}