Skip to content
Open
1 change: 1 addition & 0 deletions bundle/config/validate/fast_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func (f *fastValidate) Apply(ctx context.Context, rb *bundle.Bundle) diag.Diagno
// Fast mutators with only in-memory checks
JobClusterKeyDefined(),
JobTaskClusterSpec(),
ForEachTask(),

// Blocking mutators. Deployments will fail if these checks fail.
ValidateArtifactPath(),
Expand Down
76 changes: 76 additions & 0 deletions bundle/config/validate/for_each_task.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package validate

import (
"context"
"fmt"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/jobs"
)

// ForEachTask returns the read-only mutator that inspects job tasks using
// for_each_task and reports retry settings placed on the parent task.
func ForEachTask() bundle.ReadOnlyMutator {
	return new(forEachTask)
}

// forEachTask is the mutator type returned by ForEachTask. It embeds
// bundle.RO and declares no fields of its own.
type forEachTask struct{ bundle.RO }

// Name returns the identifier of this mutator, "validate:for_each_task".
func (v *forEachTask) Name() string {
	return "validate:for_each_task"
}

// Apply walks every task of every job in the bundle configuration and, for
// each task that sets for_each_task, collects the diagnostics produced by
// validateForEachTask. Tasks without for_each_task are not inspected.
//
// Nil job entries are skipped so that reading their tasks cannot trigger a
// nil pointer dereference.
//
// The returned diagnostics are empty when nothing is reported.
func (v *forEachTask) Apply(ctx context.Context, b *bundle.Bundle) diag.Diagnostics {
	diags := diag.Diagnostics{}

	jobsPath := dyn.NewPath(dyn.Key("resources"), dyn.Key("jobs"))

	for resourceName, job := range b.Config.Resources.Jobs {
		// The map stores pointers; guard against a nil entry before
		// touching job.Tasks.
		if job == nil {
			continue
		}

		resourcePath := jobsPath.Append(dyn.Key(resourceName))

		for taskIndex, task := range job.Tasks {
			if task.ForEachTask == nil {
				continue
			}

			taskPath := resourcePath.Append(dyn.Key("tasks"), dyn.Index(taskIndex))
			diags = diags.Extend(validateForEachTask(b, task, taskPath))
		}
	}

	return diags
}

// validateForEachTask reports retry-related fields that are set on a parent
// task. It is called by Apply only for tasks whose ForEachTask is non-nil.
//
// A non-zero max_retries yields an error; a non-zero
// min_retry_interval_millis and a true retry_on_timeout each yield a warning.
// Diagnostics are appended in that field order.
func validateForEachTask(b *bundle.Bundle, task jobs.Task, taskPath dyn.Path) diag.Diagnostics {
	// One row per parent-level field that is checked.
	checks := []struct {
		isSet    bool
		field    string
		severity diag.Severity
	}{
		{task.MaxRetries != 0, "max_retries", diag.Error},
		{task.MinRetryIntervalMillis != 0, "min_retry_interval_millis", diag.Warning},
		{task.RetryOnTimeout, "retry_on_timeout", diag.Warning},
	}

	found := diag.Diagnostics{}
	for _, c := range checks {
		if !c.isSet {
			continue
		}
		found = found.Append(invalidRetryFieldDiag(b, task, taskPath, c.field, c.severity))
	}

	return found
}

// invalidRetryFieldDiag builds the diagnostic for a retry field that is set on
// the parent of a for_each_task.
//
// The summary names fieldName; the detail names the task key and tells the
// user to move the field to for_each_task.task.<fieldName>. The diagnostic
// carries the given severity, the locations looked up for taskPath, and
// taskPath itself as its only path.
func invalidRetryFieldDiag(b *bundle.Bundle, task jobs.Task, taskPath dyn.Path, fieldName string, severity diag.Severity) diag.Diagnostic {
	const detailFormat = "Task %q has %s defined at the parent level, but it uses for_each_task.\n" +
		"When using for_each_task, %s must be defined on the nested task (for_each_task.task.%s), not on the parent task."

	var d diag.Diagnostic
	d.Severity = severity
	d.Summary = fmt.Sprintf("Invalid %s configuration for for_each_task", fieldName)
	d.Detail = fmt.Sprintf(detailFormat, task.TaskKey, fieldName, fieldName, fieldName)
	d.Locations = b.Config.GetLocations(taskPath.String())
	d.Paths = []dyn.Path{taskPath}

	return d
}
191 changes: 191 additions & 0 deletions bundle/config/validate/for_each_task_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
package validate

import (
"context"
"testing"

"github.com/databricks/cli/bundle"
"github.com/databricks/cli/bundle/config"
"github.com/databricks/cli/bundle/config/resources"
"github.com/databricks/cli/bundle/internal/bundletest"
"github.com/databricks/cli/libs/diag"
"github.com/databricks/cli/libs/dyn"
"github.com/databricks/databricks-sdk-go/service/jobs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// createBundleWithForEachTask builds a bundle with a single job "job1" whose
// only task is parentTask. When parentTask has no ForEachTask, a default one
// is filled in: inputs "[1, 2, 3]" and a notebook child task "child_task".
// The location of resources.jobs.job1.tasks[0] is set to job.yml, line 1,
// column 1.
func createBundleWithForEachTask(parentTask jobs.Task) *bundle.Bundle {
	if parentTask.ForEachTask == nil {
		defaultChild := jobs.Task{
			TaskKey:      "child_task",
			NotebookTask: &jobs.NotebookTask{NotebookPath: "test.py"},
		}
		parentTask.ForEachTask = &jobs.ForEachTask{
			Inputs: "[1, 2, 3]",
			Task:   defaultChild,
		}
	}

	job := &resources.Job{
		JobSettings: jobs.JobSettings{
			Name:  "My Job",
			Tasks: []jobs.Task{parentTask},
		},
	}

	b := &bundle.Bundle{
		Config: config.Root{
			Resources: config.Resources{
				Jobs: map[string]*resources.Job{"job1": job},
			},
		},
	}

	taskLocation := []dyn.Location{{File: "job.yml", Line: 1, Column: 1}}
	bundletest.SetLocation(b, "resources.jobs.job1.tasks[0]", taskLocation)

	return b
}

func TestForEachTask_InvalidRetryFields(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test looks good overall, but generally for testing mutators, we prefer to use acceptance tests (see acceptance folder). Could you change this test to an acceptance test instead?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the feedback. I added the acceptance tests.

Let me know if I should delete for_each_task_test.go; I removed some of its content. I went the unit-test route because I saw there were some other unit tests there already, but I missed the acceptance part.

tests := []struct {
name string
task jobs.Task
expectedSeverity diag.Severity
expectedSummary string
expectedDetail string
}{
{
name: "max_retries on parent",
task: jobs.Task{
TaskKey: "parent_task",
MaxRetries: 3,
},
expectedSeverity: diag.Error,
expectedSummary: "Invalid max_retries configuration for for_each_task",
expectedDetail: "max_retries must be defined on the nested task",
},
{
name: "min_retry_interval_millis on parent",
task: jobs.Task{
TaskKey: "parent_task",
MinRetryIntervalMillis: 1000,
},
expectedSeverity: diag.Warning,
expectedSummary: "Invalid min_retry_interval_millis configuration for for_each_task",
expectedDetail: "min_retry_interval_millis must be defined on the nested task",
},
{
name: "retry_on_timeout on parent",
task: jobs.Task{
TaskKey: "parent_task",
RetryOnTimeout: true,
},
expectedSeverity: diag.Warning,
expectedSummary: "Invalid retry_on_timeout configuration for for_each_task",
expectedDetail: "retry_on_timeout must be defined on the nested task",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
b := createBundleWithForEachTask(tt.task)

diags := ForEachTask().Apply(ctx, b)

require.Len(t, diags, 1)
assert.Equal(t, tt.expectedSeverity, diags[0].Severity)
assert.Equal(t, tt.expectedSummary, diags[0].Summary)
assert.Contains(t, diags[0].Detail, tt.expectedDetail)
})
}
}

func TestForEachTask_MultipleRetryFieldsOnParent(t *testing.T) {
ctx := context.Background()
b := createBundleWithForEachTask(jobs.Task{
TaskKey: "parent_task",
MaxRetries: 3,
MinRetryIntervalMillis: 1000,
RetryOnTimeout: true,
})

diags := ForEachTask().Apply(ctx, b)
require.Len(t, diags, 3)

errorCount := 0
warningCount := 0
for _, d := range diags {
switch d.Severity {

Check failure on line 120 in bundle/config/validate/for_each_task_test.go

View workflow job for this annotation

GitHub Actions / lint

missing cases in switch of type diag.Severity: diag.Recommendation (exhaustive)
case diag.Error:
errorCount++
case diag.Warning:
warningCount++
}
}
assert.Equal(t, 1, errorCount)
assert.Equal(t, 2, warningCount)
}

func TestForEachTask_ValidConfigurationOnChild(t *testing.T) {
ctx := context.Background()
b := createBundleWithForEachTask(jobs.Task{
TaskKey: "parent_task",
ForEachTask: &jobs.ForEachTask{
Inputs: "[1, 2, 3]",
Task: jobs.Task{
TaskKey: "child_task",
MaxRetries: 3,
NotebookTask: &jobs.NotebookTask{
NotebookPath: "test.py",
},
},
},
})

diags := ForEachTask().Apply(ctx, b)
assert.Empty(t, diags)
}

func TestForEachTask_NoForEachTask(t *testing.T) {
ctx := context.Background()
b := &bundle.Bundle{
Config: config.Root{
Resources: config.Resources{
Jobs: map[string]*resources.Job{
"job1": {
JobSettings: jobs.JobSettings{
Name: "My Job",
Tasks: []jobs.Task{
{
TaskKey: "simple_task",
MaxRetries: 3,
NotebookTask: &jobs.NotebookTask{
NotebookPath: "test.py",
},
},
},
},
},
},
},
},
}

bundletest.SetLocation(b, "resources.jobs.job1.tasks[0]", []dyn.Location{{File: "job.yml", Line: 1, Column: 1}})

diags := ForEachTask().Apply(ctx, b)
assert.Empty(t, diags)
}

func TestForEachTask_RetryOnTimeoutFalse(t *testing.T) {
ctx := context.Background()
b := createBundleWithForEachTask(jobs.Task{
TaskKey: "parent_task",
RetryOnTimeout: false,
})

diags := ForEachTask().Apply(ctx, b)
assert.Empty(t, diags)
}
Loading