Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 64 additions & 15 deletions internal/pkg/object/command/ecs/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"math"
"math/rand"
"os"
"strings"
"time"
Expand All @@ -13,6 +15,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
smithy "github.com/aws/smithy-go"
"github.com/hladush/go-telemetry/pkg/telemetry"
heimdallAws "github.com/patterninc/heimdall/internal/pkg/aws"
heimdallContext "github.com/patterninc/heimdall/pkg/context"
Expand Down Expand Up @@ -109,18 +112,21 @@ type executionContext struct {
}

const (
defaultPollingInterval = duration.Duration(30 * time.Second)
defaultTaskTimeout = duration.Duration(1 * time.Hour)
defaultMaxFailCount = 1
defaultTaskCount = 1
startedByPrefix = "heimdall-job-"
errMaxFailCount = "task %s failed %d times (max: %d), giving up"
errPollingTimeout = "polling timed out for arns %v after %v"
errJobTerminated = "job marked as stale or canceled"
Timeout FailureReason = "timeout"
Error FailureReason = "error"
maxLogChunkSize = 200 // Process 200 log entries at a time
maxLogMemoryBytes = 1024 * 1024 * 1024 // 1GB safety limit
defaultPollingInterval = duration.Duration(30 * time.Second)
defaultTaskTimeout = duration.Duration(1 * time.Hour)
defaultMaxFailCount = 1
defaultTaskCount = 1
defaultThrottleMaxRetries = 5
throttleBackoffBase = time.Second
throttleBackoffMax = 2 * time.Minute
startedByPrefix = "heimdall-job-"
errMaxFailCount = "task %s failed %d times (max: %d), giving up"
errPollingTimeout = "polling timed out for arns %v after %v"
errJobTerminated = "job marked as stale or canceled"
Timeout FailureReason = "timeout"
Error FailureReason = "error"
maxLogChunkSize = 200 // Process 200 log entries at a time
maxLogMemoryBytes = 1024 * 1024 * 1024 // 1GB safety limit
)

var (
Expand Down Expand Up @@ -336,8 +342,12 @@ func (execCtx *executionContext) pollForCompletion(ctx context.Context) error {
Tasks: activeARNs,
}

describeOutput, err := execCtx.ecsClient.DescribeTasks(ctx, describeInput)
if err != nil {
var describeOutput *ecs.DescribeTasksOutput
if err := retryWithBackoff(ctx, defaultThrottleMaxRetries, func() error {
var descErr error
describeOutput, descErr = execCtx.ecsClient.DescribeTasks(ctx, describeInput)
return descErr
}); err != nil {
return err
}

Expand Down Expand Up @@ -644,7 +654,12 @@ func runTask(ctx context.Context, execCtx *executionContext, startedBy string, t
},
}

runTaskOutput, err := execCtx.ecsClient.RunTask(ctx, runTaskInput)
var runTaskOutput *ecs.RunTaskOutput
err := retryWithBackoff(ctx, defaultThrottleMaxRetries, func() error {
var runErr error
runTaskOutput, runErr = execCtx.ecsClient.RunTask(ctx, runTaskInput)
return runErr
})
if err != nil {
return ``, err
}
Expand Down Expand Up @@ -803,3 +818,37 @@ func (e *commandContext) Cleanup(ctx context.Context, jobID string, c *cluster.C
return nil

}

// retryWithBackoff retries op on throttling errors using exponential backoff with jitter.
func retryWithBackoff(ctx context.Context, maxRetries int, op func() error) error {
var err error
for attempt := 0; attempt <= maxRetries; attempt++ {
err = op()
if err == nil {
return nil
}
if !isThrottlingError(err) || attempt == maxRetries {
return err
}
delay := min(time.Duration(math.Pow(2, float64(attempt)))*throttleBackoffBase, throttleBackoffMax)
delay += time.Duration(rand.Int63n(int64(throttleBackoffBase)))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
}
}
return err
}

// isThrottlingError reports whether err is an AWS throttling error.
func isThrottlingError(err error) bool {
var apiErr smithy.APIError
if errors.As(err, &apiErr) {
switch apiErr.ErrorCode() {
case "ThrottlingException", "RequestThrottledException", "Throttling", "RequestLimitExceeded":
return true
}
}
return false
}
79 changes: 79 additions & 0 deletions internal/pkg/object/command/ecs/ecs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ecs

import (
"context"
"testing"

smithy "github.com/aws/smithy-go"
)

func TestRetryWithBackoff_Throttling(t *testing.T) {
calls := 0
throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}

err := retryWithBackoff(context.Background(), 3, func() error {
calls++
return throttleErr
})

if err != throttleErr {
t.Fatalf("expected throttle error, got %v", err)
}
if calls != 4 { // 1 initial + 3 retries
t.Fatalf("expected 4 calls, got %d", calls)
}
}

func TestRetryWithBackoff_NonThrottling(t *testing.T) {
calls := 0
otherErr := &smithy.GenericAPIError{Code: "AccessDeniedException", Message: "denied"}

retryWithBackoff(context.Background(), 3, func() error {
calls++
return otherErr
})

if calls != 1 { // should not retry
t.Fatalf("expected 1 call, got %d", calls)
}
}

func TestRetryWithBackoff_SuccessAfterThrottle(t *testing.T) {
calls := 0
throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}

err := retryWithBackoff(context.Background(), 3, func() error {
calls++
if calls < 3 {
return throttleErr
}
return nil
})

if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if calls != 3 {
t.Fatalf("expected 3 calls, got %d", calls)
}
}

func TestRetryWithBackoff_ContextCancelled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately

calls := 0
throttleErr := &smithy.GenericAPIError{Code: "ThrottlingException", Message: "Rate exceeded"}

err := retryWithBackoff(ctx, 3, func() error {
calls++
return throttleErr
})

if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
if calls != 1 {
t.Fatalf("expected 1 call, got %d", calls)
}
}
Loading