diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/DynamicWorkflowTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/DynamicWorkflowTask.java index e4fff1c3..2957250a 100644 --- a/flytekit-api/src/main/java/org/flyte/api/v1/DynamicWorkflowTask.java +++ b/flytekit-api/src/main/java/org/flyte/api/v1/DynamicWorkflowTask.java @@ -26,4 +26,8 @@ public interface DynamicWorkflowTask { DynamicJobSpec run(Map inputs); RetryStrategy getRetries(); + + default Resources getResources() { + return Resources.builder().build(); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java index 5a95c8d8..93eef36a 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java @@ -93,4 +93,8 @@ public SdkNode apply( public int getRetries() { return 0; } + + public SdkResources getResources() { + return SdkResources.empty(); + } } diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrar.java index 41016a40..7af07f9a 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrar.java @@ -33,6 +33,7 @@ import org.flyte.api.v1.DynamicWorkflowTaskRegistrar; import org.flyte.api.v1.Literal; import org.flyte.api.v1.Node; +import org.flyte.api.v1.Resources; import org.flyte.api.v1.RetryStrategy; import org.flyte.api.v1.TaskIdentifier; import org.flyte.api.v1.TypedInterface; @@ -112,6 +113,11 @@ public DynamicJobSpec run(Map inputs) { public RetryStrategy getRetries() { return RetryStrategy.builder().retries(sdkDynamicWorkflow.getRetries()).build(); } + + @Override + public Resources getResources() { + return sdkDynamicWorkflow.getResources().toIdl(); + } } /** diff --git a/flytekit-java/src/test/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrarTest.java b/flytekit-java/src/test/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrarTest.java index 617021d4..8dbc0c2e 100644 --- a/flytekit-java/src/test/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrarTest.java +++ b/flytekit-java/src/test/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrarTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.hasSize; import com.google.errorprone.annotations.Var; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.flyte.api.v1.Binding; @@ -30,6 +31,7 @@ import org.flyte.api.v1.Literal; import org.flyte.api.v1.OutputReference; import org.flyte.api.v1.Primitive; +import org.flyte.api.v1.Resources; import org.flyte.api.v1.Scalar; import org.flyte.api.v1.TaskIdentifier; import org.flyte.api.v1.TypedInterface; @@ -70,6 +72,14 @@ void shouldLoad() { .inputs(SdkLiteralTypes.integers().asSdkType("n").getVariableMap()) .outputs(SdkLiteralTypes.integers().asSdkType("2n").getVariableMap()) .build())); + assertThat( + dynWf.getResources(), + equalTo( + Resources.builder() + .requests(resources("0.5", "2Gi")) + .limits(resources("2", "5Gi")) + .build())); + ; var spec = dynWf.run(Map.of("n", Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(3))))); assertThat(spec.nodes(), hasSize(3)); @@ -103,6 +113,14 @@ public SdkBindingData run(SdkWorkflowBuilder builder, SdkBindingData } return x; } + + @Override + public SdkResources getResources() { + return SdkResources.builder() + .requests(sdkResources("0.5", "2Gi")) + .limits(sdkResources("2", "5Gi")) + .build(); + } } static class Mult2 extends SdkRunnableTask, SdkBindingData> { @@ -117,4 +135,18 @@ public SdkBindingData run(SdkBindingData input) { return SdkBindingDataFactory.of(input.get() * 2); } } + + private static Map resources(String cpu, String memory) { + Map limits = new HashMap<>(); + limits.put(Resources.ResourceName.CPU, cpu); + limits.put(Resources.ResourceName.MEMORY, memory); + return limits; + } + + private static Map sdkResources(String cpu, String memory) { + Map limits = new HashMap<>(); + limits.put(SdkResources.ResourceName.CPU, cpu); + limits.put(SdkResources.ResourceName.MEMORY, memory); + return limits; + } } diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index 33d51852..a25cbfd9 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -577,6 +577,7 @@ private static TaskTemplate createTaskTemplateForDynamicWorkflow( "{{.taskTemplatePath}}")) .image(image) .env(emptyList()) + .resources(task.getResources()) .build(); return TaskTemplate.builder()