Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ class ExecutionStatus(StrEnum):

INITIALIZING = "initializing"
PROVISIONING = "provisioning"
RUNNING = "running"
FAILED = "failed"
COMPLETED = "completed"
RUNNING = "running"
CACHED = "cached"
# When a step that can be retried failed, its status is set to retrying.
# Once the next retry is attempted, the status is set to retried.
RETRYING = "retrying"
RETRIED = "retried"
STOPPED = "stopped"
Expand All @@ -103,6 +105,7 @@ def is_finished(self) -> bool:
ExecutionStatus.FAILED,
ExecutionStatus.COMPLETED,
ExecutionStatus.CACHED,
ExecutionStatus.RETRYING,
ExecutionStatus.RETRIED,
ExecutionStatus.STOPPED,
}
Expand Down
64 changes: 56 additions & 8 deletions src/zenml/execution/pipeline/dynamic/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Dynamic pipeline execution outputs."""

from abc import abstractmethod
from concurrent.futures import Future
from typing import Any, Iterator, List, Optional, Tuple, Union, overload

Expand All @@ -34,7 +35,27 @@ class OutputArtifact(ArtifactVersionResponse):
StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]]


class _BaseStepRunFuture:
class BaseFuture:
"""Base future."""

@abstractmethod
def running(self) -> bool:
"""Check if the future is running.

Returns:
True if the future is running, False otherwise.
"""

@abstractmethod
def result(self) -> Any:
"""Get the result of the future.

Returns:
The result of the future.
"""


class BaseStepRunFuture(BaseFuture):
"""Base step run future."""

def __init__(
Expand Down Expand Up @@ -62,12 +83,16 @@ def invocation_id(self) -> str:
"""
return self._invocation_id

def _wait(self) -> None:
"""Wait for the step run future to complete."""
self._wrapped.result()
def running(self) -> bool:
"""Check if the step run future is running.

Returns:
True if the step run future is running, False otherwise.
"""
return self._wrapped.running()


class ArtifactFuture(_BaseStepRunFuture):
class ArtifactFuture(BaseStepRunFuture):
"""Future for a step run output artifact."""

def __init__(
Expand Down Expand Up @@ -115,7 +140,7 @@ def load(self, disable_cache: bool = False) -> Any:
return self.result().load(disable_cache=disable_cache)


class StepRunOutputsFuture(_BaseStepRunFuture):
class StepRunOutputsFuture(BaseStepRunFuture):
"""Future for a step run output."""

def __init__(
Expand Down Expand Up @@ -270,7 +295,7 @@ def __len__(self) -> int:
return len(self._output_keys)


class MapResultsFuture:
class MapResultsFuture(BaseFuture):
"""Future that represents the results of a `step.map/product(...)` call."""

def __init__(self, futures: List[StepRunOutputsFuture]) -> None:
Expand All @@ -281,6 +306,14 @@ def __init__(self, futures: List[StepRunOutputsFuture]) -> None:
"""
self.futures = futures

def running(self) -> bool:
"""Check if the map results future is running.

Returns:
True if the map results future is running, False otherwise.
"""
return any(future.running() for future in self.futures)

def result(self) -> List[StepRunOutputs]:
"""Get the step run outputs this future represents.

Expand All @@ -289,6 +322,19 @@ def result(self) -> List[StepRunOutputs]:
"""
return [future.result() for future in self.futures]

def load(self, disable_cache: bool = False) -> List[Any]:
"""Load the step run output artifacts.

Args:
disable_cache: Whether to disable the artifact cache.

Returns:
The step run output artifacts.
"""
return [
future.load(disable_cache=disable_cache) for future in self.futures
]

def unpack(self) -> Tuple[List[ArtifactFuture], ...]:
"""Unpack the map results future.

Expand Down Expand Up @@ -358,4 +404,6 @@ def __len__(self) -> int:
return len(self.futures)


StepRunFuture = Union[ArtifactFuture, StepRunOutputsFuture, MapResultsFuture]
AnyStepRunFuture = Union[
ArtifactFuture, StepRunOutputsFuture, MapResultsFuture
]
Loading
Loading