# Dynamic task_type so overrides to nnodes take effect at serialization time.
@property
def task_type(self) -> str:
    """Return the task type based on the current (possibly overridden) ``nnodes``.

    Single-node configurations serialize as a plain ``"python-task"`` (standalone
    elastic launch inside one pod); multi-node configurations serialize as
    ``"pytorch"`` (PyTorchJob). Accepted ``nnodes`` forms:

    * ``int`` — standalone iff equal to 1;
    * numeric string such as ``"1"`` — standalone iff it parses to 1;
    * range string such as ``"1:2"`` — standalone iff it is exactly ``1:1``.

    Any unparsable string or unrecognized type falls back to multi-node
    (``"pytorch"``), the safe default: a multi-node plan that turns out to be
    single-node still works, whereas the reverse would hang the rendezvous.

    Kept deliberately torch-free so it is cheap and safe during serialization.
    """
    n = self._task_config.nnodes

    # Fast path: an integer node count needs no parsing and cannot raise.
    if isinstance(n, int):
        return self._ELASTIC_TASK_TYPE_STANDALONE if n == 1 else self._ELASTIC_TASK_TYPE

    if isinstance(n, str):
        s = n.strip()
        # Only int() can legitimately fail here, so catch exactly ValueError;
        # a broad `except Exception` would mask unrelated programming errors.
        try:
            if ":" in s:
                min_part, max_part = s.split(":", 1)
                min_n = int(min_part.strip())
                max_n = int(max_part.strip())
                single = min_n == 1 and max_n == 1  # "1:1" is the only standalone range
            else:
                single = int(s) == 1
        except ValueError:
            # Unparsable string: fall back to multi-node behavior (safe default).
            return self._ELASTIC_TASK_TYPE
        return self._ELASTIC_TASK_TYPE_STANDALONE if single else self._ELASTIC_TASK_TYPE

    # Unknown type (e.g. None); default to multi-node.
    return self._ELASTIC_TASK_TYPE
@@ -362,13 +397,38 @@ def _execute(self, **kwargs) -> Any: except ImportError: raise ImportError(TORCH_IMPORT_ERROR_MESSAGE) - nnodes_str = os.environ.get("PET_NNODES", str(self._task_config.nnodes)) + # Determine if we are in single-pod mode (python-task) or multi-node (pytorch) + single_pod_mode = self.task_type == self._ELASTIC_TASK_TYPE_STANDALONE + + # Ensure single-pod mode never waits for multiple nodes due to env overrides + nnodes_str_env = os.environ.get("PET_NNODES") + nnodes_str = "1" if single_pod_mode else (nnodes_str_env or str(self._task_config.nnodes)) min_nodes, max_nodes = run.parse_min_max_nnodes(nnodes_str) - nproc_per_node = int(os.environ.get("PET_NPROC_PER_NODE", self._task_config.nproc_per_node)) - max_restarts = int(os.environ.get("PET_MAX_RESTARTS", self._task_config.max_restarts)) - monitor_interval = int(os.environ.get("PET_MONITOR_INTERVAL", self._task_config.monitor_interval)) - rdzv_endpoint = os.environ.get("PET_RDZV_ENDPOINT", "localhost:0") + nproc_per_node_env = os.environ.get("PET_NPROC_PER_NODE") + nproc_per_node = int(nproc_per_node_env or self._task_config.nproc_per_node) + max_restarts_env = os.environ.get("PET_MAX_RESTARTS") + max_restarts = int(max_restarts_env or self._task_config.max_restarts) + monitor_interval_env = os.environ.get("PET_MONITOR_INTERVAL") + monitor_interval = int(monitor_interval_env or self._task_config.monitor_interval) + + # In single-pod mode, always rendezvous on loopback to avoid cross-pod endpoints from env overrides + rdzv_endpoint_env = os.environ.get("PET_RDZV_ENDPOINT") + rdzv_endpoint = "127.0.0.1:0" if single_pod_mode else (rdzv_endpoint_env or "localhost:0") + + # Emit detailed debug so misconfigurations are obvious in logs + logger.info( + "[Elastic] mode=%s nnodes_str=%s -> min_nodes=%s max_nodes=%s nproc_per_node=%s " + "rdzv_backend=%s rdzv_endpoint=%s start_method=%s", + "single-pod" if single_pod_mode else "multi-node", + nnodes_str_env if nnodes_str_env is not None else 
str(self._task_config.nnodes), + min_nodes, + max_nodes, + nproc_per_node, + self.rdzv_backend, + rdzv_endpoint, + self._task_config.start_method, + ) # If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system. # Doing so to copy the default behavior of torchrun.