 from zenml.execution.pipeline.dynamic.run_context import (
     DynamicPipelineRunContext,
 )
-from zenml.execution.pipeline.dynamic.utils import _Unmapped
+from zenml.execution.pipeline.dynamic.utils import (
+    _Unmapped,
+    wait_for_step_run_to_finish,
+)
 from zenml.execution.step.utils import launch_step
 from zenml.logger import get_logger
 from zenml.logging.step_logging import setup_pipeline_logging
     PipelineRunResponse,
     PipelineRunUpdate,
     PipelineSnapshotResponse,
+    StepRunResponse,
 )
 from zenml.orchestrators.publish_utils import (
     publish_failed_pipeline_run,
+    publish_failed_step_run,
     publish_successful_pipeline_run,
 )
 from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline
@@ -118,11 +123,20 @@ def __init__(
         self._executor = ThreadPoolExecutor(max_workers=10)
         self._pipeline: Optional["DynamicPipeline"] = None
         self._orchestrator = Stack.from_model(snapshot.stack).orchestrator
-        self._orchestrator_run_id = (
-            self._orchestrator.get_orchestrator_run_id()
-        )
         self._futures: List[StepRunOutputsFuture] = []

+        self._existing_step_runs: Dict[str, "StepRunResponse"] = {}
+        if run and run.orchestrator_run_id:
+            logger.info("Continuing existing run `%s`.", str(run.id))
+            self._orchestrator_run_id = run.orchestrator_run_id
+
+            if run.status.is_in_progress:
+                self._existing_step_runs = run.steps.copy()
+        else:
+            self._orchestrator_run_id = (
+                self._orchestrator.get_orchestrator_run_id()
+            )
+
     @property
     def pipeline(self) -> "DynamicPipeline":
         """The pipeline that the runner is executing.
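Note on the `__init__` hunk above: when the run passed to the runner already carries an `orchestrator_run_id`, the change treats this as a resume, reuses that ID, and (for runs still in progress) remembers the existing step runs keyed by invocation ID; otherwise it asks the orchestrator for a fresh ID. A minimal illustrative sketch of that decision, where `ResumeInfo` and `new_orchestrator_run_id` are hypothetical stand-ins for the runner attributes and are not part of this diff:

    # Illustrative only; mirrors the branching in __init__ above.
    from dataclasses import dataclass, field
    from typing import Any, Dict


    @dataclass
    class ResumeInfo:
        orchestrator_run_id: str
        existing_step_runs: Dict[str, Any] = field(default_factory=dict)


    def resolve_resume_info(run: Any, new_orchestrator_run_id: str) -> ResumeInfo:
        # Resume: reuse the run's orchestrator run ID and, if the run is
        # still in progress, keep its step runs keyed by invocation ID.
        if run is not None and run.orchestrator_run_id:
            steps = dict(run.steps) if run.status.is_in_progress else {}
            return ResumeInfo(run.orchestrator_run_id, steps)
        # Fresh start: use a newly generated orchestrator run ID.
        return ResumeInfo(new_orchestrator_run_id)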
@@ -153,17 +167,28 @@ def pipeline(self) -> "DynamicPipeline":

     def run_pipeline(self) -> None:
         """Run the pipeline."""
+        existing_logs_response = None
+        if self._run:
+            for log_response in self._run.log_collection or []:
+                if log_response.source == "orchestrator":
+                    existing_logs_response = log_response
+                    break
+
         with setup_pipeline_logging(
             source="orchestrator",
             snapshot=self._snapshot,
+            logs_response=existing_logs_response,
         ) as logs_request:
             if self._run:
+                run_update = PipelineRunUpdate(
+                    add_logs=[logs_request] if logs_request else None,
+                )
+                if not self._run.orchestrator_run_id:
+                    run_update.orchestrator_run_id = self._orchestrator_run_id
+
                 run = Client().zen_store.update_run(
                     run_id=self._run.id,
-                    run_update=PipelineRunUpdate(
-                        orchestrator_run_id=self._orchestrator_run_id,
-                        add_logs=[logs_request] if logs_request else None,
-                    ),
+                    run_update=run_update,
                 )
             else:
                 run = create_placeholder_run(
@@ -205,6 +230,7 @@ def run_pipeline(self) -> None:
             self._executor.shutdown(wait=True, cancel_futures=True)

         publish_successful_pipeline_run(run.id)
+        logger.info("Pipeline completed successfully.")

     @overload
     def launch_step(
@@ -260,21 +286,114 @@ def launch_step(
             after=after,
         )

-        def _launch() -> StepRunOutputs:
+        should_retry = _should_retry_locally(
+            compiled_step,
+            self._snapshot.pipeline_configuration.docker_settings,
+        )
+
+        def _run_step(
+            remaining_retries: Optional[int] = None,
+        ) -> StepRunOutputs:
+            # TODO: maybe pass run here to avoid extra server requests?
             step_run = launch_step(
                 snapshot=self._snapshot,
                 step=compiled_step,
                 orchestrator_run_id=self._orchestrator_run_id,
-                retry=_should_retry_locally(
-                    compiled_step,
-                    self._snapshot.pipeline_configuration.docker_settings,
-                ),
+                retry=should_retry,
+                remaining_retries=remaining_retries,
             )
             return _load_step_run_outputs(step_run.id)

+        existing_step_run = self._existing_step_runs.get(
+            compiled_step.spec.invocation_id
+        )
+        if existing_step_run:
+            if existing_step_run.config != compiled_step.config:
+                logger.warning(
+                    "Configuration for step `%s` changed since the "
+                    "orchestration environment was restarted. If the step "
+                    "needs to be retried, it will use the new configuration.",
+                    existing_step_run.name,
+                )
+
+            def _workload() -> StepRunOutputs:
+                nonlocal existing_step_run
+                assert existing_step_run
+
+                if existing_step_run.status.is_successful:
+                    return _load_step_run_outputs(existing_step_run.id)
+
+                runtime = get_step_runtime(
+                    step_config=compiled_step.config,
+                    pipeline_docker_settings=self._snapshot.pipeline_configuration.docker_settings,
+                )
+                if (
+                    runtime == StepRuntime.INLINE
+                    and existing_step_run.status.is_in_progress
+                ):
+                    # Inline steps that are in running state didn't have the
+                    # chance to report their failure back to ZenML before the
+                    # orchestration environment was shut down. But there is no
+                    # way that they're actually still running if we're in a new
+                    # orchestration environment, so we mark them as failed and
+                    # potentially restart them depending on the retry config.
+                    existing_step_run = publish_failed_step_run(
+                        existing_step_run.id
+                    )
+
+                remaining_retries = 0
+
+                if should_retry:
+                    max_retries = (
+                        compiled_step.config.retry.max_retries
+                        if compiled_step.config.retry
+                        else 0
+                    )
+                    remaining_retries = max(
+                        0, 1 + max_retries - existing_step_run.version
+                    )
+
+                if existing_step_run.status.is_in_progress:
+                    logger.info(
+                        "Restarting the monitoring of existing step `%s` "
+                        "(step run ID: %s). Remaining retries: %d",
+                        existing_step_run.name,
+                        existing_step_run.id,
+                        remaining_retries,
+                    )
+
+                if remaining_retries > 0:
+                    step_run = wait_for_step_run_to_finish(
+                        existing_step_run.id
+                    )
+                    if not step_run.status.is_successful:
+                        logger.error(
+                            "Failed to run step `%s`.",
+                            existing_step_run.name,
+                        )
+                        return _run_step(remaining_retries=remaining_retries)
+                    else:
+                        return _load_step_run_outputs(existing_step_run.id)
+                else:
+                    step_run = wait_for_step_run_to_finish(
+                        existing_step_run.id
+                    )
+                    if not step_run.status.is_successful:
+                        # This is the last retry, in which case we have to raise
+                        # an error that the step failed.
+                        # TODO: Make this better by raising the actual exception
+                        # that caused the step to fail instead of just a generic
+                        # runtime error.
+                        raise RuntimeError(
+                            f"Failed to run step `{existing_step_run.name}`."
+                        )
+                    return _load_step_run_outputs(existing_step_run.id)
+        else:
+            _workload = _run_step
+
         if concurrent:
             ctx = contextvars.copy_context()
-            future = self._executor.submit(ctx.run, _launch)
+            future = self._executor.submit(ctx.run, _workload)
             step_run_future = StepRunOutputsFuture(
                 wrapped=future,
                 invocation_id=compiled_step.spec.invocation_id,
@@ -283,7 +402,7 @@ def _launch() -> StepRunOutputs:
             self._futures.append(step_run_future)
             return step_run_future
         else:
-            return _launch()
+            return _workload()

     def map(
         self,
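One detail worth spelling out from the `_workload` logic above is the remaining-retry arithmetic, `remaining_retries = max(0, 1 + max_retries - existing_step_run.version)`: assuming the step run `version` counts attempts starting at 1, a step configured with `max_retries=3` whose second attempt was interrupted still has two retries left. A small sketch of just that calculation (the helper name is illustrative, not part of this diff):

    # Illustrative helper; mirrors the arithmetic used in _workload above.
    def remaining_retries(max_retries: int, attempts_so_far: int) -> int:
        # One initial attempt plus max_retries attempts are allowed in total;
        # whatever hasn't been used yet may still be retried.
        return max(0, 1 + max_retries - attempts_so_far)


    assert remaining_retries(max_retries=3, attempts_so_far=1) == 3
    assert remaining_retries(max_retries=3, attempts_so_far=2) == 2
    assert remaining_retries(max_retries=0, attempts_so_far=1) == 0  # no retries configured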
@@ -337,12 +456,7 @@ def compile_dynamic_step_invocation(
     pipeline: "DynamicPipeline",
     step: "BaseStep",
     inputs: Dict[str, Any],
-    after: Union[
-        "StepRunFuture",
-        "ArtifactFuture",
-        Sequence[Union["StepRunFuture", "ArtifactFuture"]],
-        None,
-    ] = None,
+    after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
     id: Optional[str] = None,
 ) -> "Step":
     """Compile a dynamic step invocation.