Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,41 @@ def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):

self._task_config = task_config

# Resolve task_type dynamically so overrides to nnodes take effect at serialization time
@property
def task_type(self) -> str:
    """Return the task type computed from the current (possibly overridden) ``nnodes``.

    A single-node configuration maps to the standalone type ("python-task");
    everything else maps to the distributed type ("pytorch"). Torch is not
    imported here so this stays lightweight during serialization.
    """
    nnodes = self._task_config.nnodes

    def _single_node(value) -> bool:
        # Integers: standalone iff exactly one node.
        if isinstance(value, int):
            return value == 1
        # Strings: either a plain count ("1") or a "min:max" range ("1:1").
        if isinstance(value, str):
            text = value.strip()
            if ":" not in text:
                return int(text) == 1
            low, high = (part.strip() for part in text.split(":", 1))
            # Only the exact 1:1 range counts as single-node.
            return int(low) == 1 and int(high) == 1
        # Unknown types are handed back to the caller as a failure,
        # which falls through to the multi-node default below.
        raise TypeError(f"unsupported nnodes type: {type(value)!r}")

    try:
        standalone = _single_node(nnodes)
    except Exception:
        # Any parsing issue falls back to multi-node behavior (safe default).
        return self._ELASTIC_TASK_TYPE
    return self._ELASTIC_TASK_TYPE_STANDALONE if standalone else self._ELASTIC_TASK_TYPE

def _execute(self, **kwargs) -> Any:
"""
Execute the task function using torch distributed's `elastic_launch`.
Expand All @@ -362,13 +397,38 @@ def _execute(self, **kwargs) -> Any:
except ImportError:
raise ImportError(TORCH_IMPORT_ERROR_MESSAGE)

nnodes_str = os.environ.get("PET_NNODES", str(self._task_config.nnodes))
# Determine if we are in single-pod mode (python-task) or multi-node (pytorch)
single_pod_mode = self.task_type == self._ELASTIC_TASK_TYPE_STANDALONE

# Ensure single-pod mode never waits for multiple nodes due to env overrides
nnodes_str_env = os.environ.get("PET_NNODES")
nnodes_str = "1" if single_pod_mode else (nnodes_str_env or str(self._task_config.nnodes))
min_nodes, max_nodes = run.parse_min_max_nnodes(nnodes_str)

nproc_per_node = int(os.environ.get("PET_NPROC_PER_NODE", self._task_config.nproc_per_node))
max_restarts = int(os.environ.get("PET_MAX_RESTARTS", self._task_config.max_restarts))
monitor_interval = int(os.environ.get("PET_MONITOR_INTERVAL", self._task_config.monitor_interval))
rdzv_endpoint = os.environ.get("PET_RDZV_ENDPOINT", "localhost:0")
nproc_per_node_env = os.environ.get("PET_NPROC_PER_NODE")
nproc_per_node = int(nproc_per_node_env or self._task_config.nproc_per_node)
max_restarts_env = os.environ.get("PET_MAX_RESTARTS")
max_restarts = int(max_restarts_env or self._task_config.max_restarts)
monitor_interval_env = os.environ.get("PET_MONITOR_INTERVAL")
monitor_interval = int(monitor_interval_env or self._task_config.monitor_interval)

# In single-pod mode, always rendezvous on loopback to avoid cross-pod endpoints from env overrides
rdzv_endpoint_env = os.environ.get("PET_RDZV_ENDPOINT")
rdzv_endpoint = "127.0.0.1:0" if single_pod_mode else (rdzv_endpoint_env or "localhost:0")

# Emit detailed debug so misconfigurations are obvious in logs
logger.info(
"[Elastic] mode=%s nnodes_str=%s -> min_nodes=%s max_nodes=%s nproc_per_node=%s "
"rdzv_backend=%s rdzv_endpoint=%s start_method=%s",
"single-pod" if single_pod_mode else "multi-node",
nnodes_str_env if nnodes_str_env is not None else str(self._task_config.nnodes),
min_nodes,
max_nodes,
nproc_per_node,
self.rdzv_backend,
rdzv_endpoint,
self._task_config.start_method,
)

# If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system.
# Doing so to copy the default behavior of torchrun.
Expand Down