Skip to content

Commit 957fa61

Browse files
committed
Enable retries for dynamic pipeline function execution
1 parent d2071ec commit 957fa61

File tree

6 files changed

+209
-28
lines changed

6 files changed

+209
-28
lines changed

src/zenml/enums.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def is_finished(self) -> bool:
104104
ExecutionStatus.COMPLETED,
105105
ExecutionStatus.CACHED,
106106
ExecutionStatus.RETRIED,
107+
ExecutionStatus.RETRYING,
107108
ExecutionStatus.STOPPED,
108109
}
109110

@@ -125,6 +126,20 @@ def is_failed(self) -> bool:
125126
"""
126127
return self in {ExecutionStatus.FAILED}
127128

129+
@property
130+
def is_in_progress(self) -> bool:
131+
"""Whether the execution status refers to an in progress execution.
132+
133+
Returns:
134+
Whether the execution status refers to an in progress execution.
135+
"""
136+
return self in {
137+
ExecutionStatus.INITIALIZING,
138+
ExecutionStatus.PROVISIONING,
139+
ExecutionStatus.RUNNING,
140+
ExecutionStatus.STOPPING,
141+
}
142+
128143

129144
class LoggingLevels(Enum):
130145
"""Enum for logging levels."""

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,19 @@ def result(self) -> List[StepRunOutputs]:
289289
"""
290290
return [future.result() for future in self.futures]
291291

292+
def load(self, disable_cache: bool = False) -> List[Any]:
293+
"""Load the step run output artifacts.
294+
295+
Args:
296+
disable_cache: Whether to disable the artifact cache.
297+
298+
Returns:
299+
The step run output artifacts.
300+
"""
301+
return [
302+
future.load(disable_cache=disable_cache) for future in self.futures
303+
]
304+
292305
def unpack(self) -> Tuple[List[ArtifactFuture], ...]:
293306
"""Unpack the map results future.
294307

src/zenml/execution/pipeline/dynamic/runner.py

Lines changed: 135 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@
5050
from zenml.execution.pipeline.dynamic.run_context import (
5151
DynamicPipelineRunContext,
5252
)
53-
from zenml.execution.pipeline.dynamic.utils import _Unmapped
53+
from zenml.execution.pipeline.dynamic.utils import (
54+
_Unmapped,
55+
wait_for_step_run_to_finish,
56+
)
5457
from zenml.execution.step.utils import launch_step
5558
from zenml.logger import get_logger
5659
from zenml.logging.step_logging import setup_pipeline_logging
@@ -59,9 +62,11 @@
5962
PipelineRunResponse,
6063
PipelineRunUpdate,
6164
PipelineSnapshotResponse,
65+
StepRunResponse,
6266
)
6367
from zenml.orchestrators.publish_utils import (
6468
publish_failed_pipeline_run,
69+
publish_failed_step_run,
6570
publish_successful_pipeline_run,
6671
)
6772
from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline
@@ -118,11 +123,20 @@ def __init__(
118123
self._executor = ThreadPoolExecutor(max_workers=10)
119124
self._pipeline: Optional["DynamicPipeline"] = None
120125
self._orchestrator = Stack.from_model(snapshot.stack).orchestrator
121-
self._orchestrator_run_id = (
122-
self._orchestrator.get_orchestrator_run_id()
123-
)
124126
self._futures: List[StepRunOutputsFuture] = []
125127

128+
self._existing_step_runs: Dict[str, "StepRunResponse"] = {}
129+
if run and run.orchestrator_run_id:
130+
logger.info("Continuing existing run `%s`.", str(run.id))
131+
self._orchestrator_run_id = run.orchestrator_run_id
132+
133+
if run.status.is_in_progress:
134+
self._existing_step_runs = run.steps.copy()
135+
else:
136+
self._orchestrator_run_id = (
137+
self._orchestrator.get_orchestrator_run_id()
138+
)
139+
126140
@property
127141
def pipeline(self) -> "DynamicPipeline":
128142
"""The pipeline that the runner is executing.
@@ -153,17 +167,28 @@ def pipeline(self) -> "DynamicPipeline":
153167

154168
def run_pipeline(self) -> None:
155169
"""Run the pipeline."""
170+
existing_logs_response = None
171+
if self._run:
172+
for log_response in self._run.log_collection or []:
173+
if log_response.source == "orchestrator":
174+
existing_logs_response = log_response
175+
break
176+
156177
with setup_pipeline_logging(
157178
source="orchestrator",
158179
snapshot=self._snapshot,
180+
logs_response=existing_logs_response,
159181
) as logs_request:
160182
if self._run:
183+
run_update = PipelineRunUpdate(
184+
add_logs=[logs_request] if logs_request else None,
185+
)
186+
if not self._run.orchestrator_run_id:
187+
run_update.orchestrator_run_id = self._orchestrator_run_id
188+
161189
run = Client().zen_store.update_run(
162190
run_id=self._run.id,
163-
run_update=PipelineRunUpdate(
164-
orchestrator_run_id=self._orchestrator_run_id,
165-
add_logs=[logs_request] if logs_request else None,
166-
),
191+
run_update=run_update,
167192
)
168193
else:
169194
run = create_placeholder_run(
@@ -205,6 +230,7 @@ def run_pipeline(self) -> None:
205230
self._executor.shutdown(wait=True, cancel_futures=True)
206231

207232
publish_successful_pipeline_run(run.id)
233+
logger.info("Pipeline completed successfully.")
208234

209235
@overload
210236
def launch_step(
@@ -260,21 +286,114 @@ def launch_step(
260286
after=after,
261287
)
262288

263-
def _launch() -> StepRunOutputs:
289+
should_retry = _should_retry_locally(
290+
compiled_step,
291+
self._snapshot.pipeline_configuration.docker_settings,
292+
)
293+
294+
def _run_step(
295+
remaining_retries: Optional[int] = None,
296+
) -> StepRunOutputs:
297+
# TODO: maybe pass run here to avoid extra server requests?
264298
step_run = launch_step(
265299
snapshot=self._snapshot,
266300
step=compiled_step,
267301
orchestrator_run_id=self._orchestrator_run_id,
268-
retry=_should_retry_locally(
269-
compiled_step,
270-
self._snapshot.pipeline_configuration.docker_settings,
271-
),
302+
retry=should_retry,
303+
remaining_retries=remaining_retries,
272304
)
273305
return _load_step_run_outputs(step_run.id)
274306

307+
existing_step_run = self._existing_step_runs.get(
308+
compiled_step.spec.invocation_id
309+
)
310+
if existing_step_run:
311+
if existing_step_run.config != compiled_step.config:
312+
logger.warning(
313+
"Configuration for step `%s` changed since the the "
314+
"orchestration environment was restarted. If the step "
315+
"needs to be retried, it will use the new configuration.",
316+
existing_step_run.name,
317+
)
318+
319+
def _workload() -> StepRunOutputs:
320+
nonlocal existing_step_run
321+
assert existing_step_run
322+
323+
if existing_step_run.status.is_successful:
324+
return _load_step_run_outputs(existing_step_run.id)
325+
326+
runtime = get_step_runtime(
327+
step_config=compiled_step.config,
328+
pipeline_docker_settings=self._snapshot.pipeline_configuration.docker_settings,
329+
)
330+
if (
331+
runtime == StepRuntime.INLINE
332+
and existing_step_run.status.is_in_progress
333+
):
334+
# Inline steps that are in running state didn't have the
335+
# chance to report their failure back to ZenML before the
336+
# orchestration environment was shut down. But there is no
337+
# way that they're actually still running if we're in a new
338+
# orchestration environment, so we mark them as failed and
339+
# potentially restart them depending on the retry config.
340+
existing_step_run = publish_failed_step_run(
341+
existing_step_run.id
342+
)
343+
344+
remaining_retries = 0
345+
346+
if should_retry:
347+
max_retries = (
348+
compiled_step.config.retry.max_retries
349+
if compiled_step.config.retry
350+
else 0
351+
)
352+
remaining_retries = max(
353+
0, 1 + max_retries - existing_step_run.version
354+
)
355+
356+
if existing_step_run.status.is_in_progress:
357+
logger.info(
358+
"Restarting the monitoring of existing step `%s` "
359+
"(step run ID: %s). Remaining retries: %d",
360+
existing_step_run.name,
361+
existing_step_run.id,
362+
remaining_retries,
363+
)
364+
365+
if remaining_retries > 0:
366+
step_run = wait_for_step_run_to_finish(
367+
existing_step_run.id
368+
)
369+
if not step_run.status.is_successful:
370+
logger.error(
371+
"Failed to run step `%s`.",
372+
existing_step_run.name,
373+
)
374+
return _run_step(remaining_retries=remaining_retries)
375+
else:
376+
return _load_step_run_outputs(existing_step_run.id)
377+
else:
378+
step_run = wait_for_step_run_to_finish(
379+
existing_step_run.id
380+
)
381+
if not step_run.status.is_successful:
382+
# This is the last retry, in which case we have to raise
383+
# an error that the step failed.
384+
# TODO: Make this better by raising the actual exception
385+
# that caused the step to fail instead of just a generic
386+
# runtime error.
387+
raise RuntimeError(
388+
f"Failed to run step `{existing_step_run.name}`."
389+
)
390+
return _load_step_run_outputs(existing_step_run.id)
391+
else:
392+
_workload = _run_step
393+
275394
if concurrent:
276395
ctx = contextvars.copy_context()
277-
future = self._executor.submit(ctx.run, _launch)
396+
future = self._executor.submit(ctx.run, _workload)
278397
step_run_future = StepRunOutputsFuture(
279398
wrapped=future,
280399
invocation_id=compiled_step.spec.invocation_id,
@@ -283,7 +402,7 @@ def _launch() -> StepRunOutputs:
283402
self._futures.append(step_run_future)
284403
return step_run_future
285404
else:
286-
return _launch()
405+
return _workload()
287406

288407
def map(
289408
self,
@@ -337,12 +456,7 @@ def compile_dynamic_step_invocation(
337456
pipeline: "DynamicPipeline",
338457
step: "BaseStep",
339458
inputs: Dict[str, Any],
340-
after: Union[
341-
"StepRunFuture",
342-
"ArtifactFuture",
343-
Sequence[Union["StepRunFuture", "ArtifactFuture"]],
344-
None,
345-
] = None,
459+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
346460
id: Optional[str] = None,
347461
) -> "Step":
348462
"""Compile a dynamic step invocation.

src/zenml/execution/pipeline/dynamic/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,19 @@
1313
# permissions and limitations under the License.
1414
"""Dynamic pipeline execution utilities."""
1515

16+
import time
1617
from typing import (
1718
Generic,
1819
TypeVar,
1920
)
21+
from uuid import UUID
22+
23+
from zenml.client import Client
24+
from zenml.logger import get_logger
25+
from zenml.models import StepRunResponse
26+
27+
logger = get_logger(__name__)
28+
2029

2130
T = TypeVar("T")
2231

@@ -46,3 +55,31 @@ def unmapped(value: T) -> _Unmapped[T]:
4655
The wrapped value.
4756
"""
4857
return _Unmapped(value)
58+
59+
60+
def wait_for_step_run_to_finish(step_run_id: UUID) -> "StepRunResponse":
61+
"""Wait until a step run is finished.
62+
63+
Args:
64+
step_run_id: The ID of the step run.
65+
66+
Returns:
67+
The finished step run.
68+
"""
69+
sleep_interval = 1
70+
max_sleep_interval = 64
71+
72+
while True:
73+
step_run = Client().zen_store.get_run_step(step_run_id)
74+
75+
if step_run.status.is_finished:
76+
return step_run
77+
78+
logger.debug(
79+
"Waiting for step run with ID %s to finish (current status: %s)",
80+
step_run_id,
81+
step_run.status,
82+
)
83+
time.sleep(sleep_interval)
84+
if sleep_interval < max_sleep_interval:
85+
sleep_interval *= 2

src/zenml/execution/step/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import time
1717
from typing import (
1818
TYPE_CHECKING,
19+
Optional,
1920
)
2021

2122
from zenml.config.step_configurations import Step
@@ -39,6 +40,7 @@ def launch_step(
3940
step: "Step",
4041
orchestrator_run_id: str,
4142
retry: bool = False,
43+
remaining_retries: Optional[int] = None,
4244
) -> StepRunResponse:
4345
"""Launch a step.
4446
@@ -47,6 +49,8 @@ def launch_step(
4749
step: The step to run.
4850
orchestrator_run_id: The orchestrator run ID.
4951
retry: Whether to retry the step if it fails.
52+
remaining_retries: The number of remaining retries. If not passed, this
53+
will be read from the step configuration.
5054
5155
Raises:
5256
RunStoppedException: If the run was stopped.
@@ -69,7 +73,10 @@ def _launch_without_retry() -> StepRunResponse:
6973
else:
7074
retries = 0
7175
retry_config = step.config.retry
72-
max_retries = retry_config.max_retries if retry_config else 0
76+
if remaining_retries is None:
77+
max_retries = retry_config.max_retries if retry_config else 0
78+
else:
79+
max_retries = remaining_retries
7380
delay = retry_config.delay if retry_config else 0
7481
backoff = retry_config.backoff if retry_config else 1
7582

src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,7 @@ def _submit_orchestrator_job(
775775
annotations=annotations,
776776
settings=settings,
777777
pod_settings=orchestrator_pod_settings,
778-
# In dynamic pipelines restarting the orchestrator pod is not
779-
# supported yet. It will create new runs for each restart which
780-
# we have to avoid.
781-
backoff_limit=0
782-
if snapshot.is_dynamic
783-
else settings.orchestrator_job_backoff_limit,
778+
backoff_limit=settings.orchestrator_job_backoff_limit,
784779
)
785780

786781
if snapshot.schedule:

0 commit comments

Comments
 (0)