
Commit 03bb57e

Update KubeflowExecutor to use CommandTrainer
Refactor the KubeflowExecutor class to replace CustomTrainer with CommandTrainer for improved task handling. Introduce a new enable_tcpxo flag that adds a GPUDirect-TCPXO (tcpxo-daemon) sidecar and its related mounts and environment to the runtime template. The implementation now validates entrypoints and handles task configuration more robustly, ensuring compatibility with CommandTrainer.

- Added enable_tcpxo flag to runtime template
- Updated TrainerClient initialization with KubernetesBackendConfig
- Enhanced error handling for unsupported tasks
- Improved logging for trainer configurations and commands

Signed-off-by: Krishnaswamy Subramanian <subramk@thoughtworks.com>
1 parent dfff05e commit 03bb57e
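
For orientation, a minimal usage sketch of the new flag. The constructor values below are hypothetical placeholders; only the fields touched by this commit (nodes, cpu_limit, memory_limit, gpus, namespace, enable_tcpxo) are taken from the diff.

from nemo_run.core.execution.kubeflow import KubeflowExecutor

# enable_tcpxo=True makes _create_cluster_training_runtime render the tcpxo-daemon
# sidecar, the GKE networking annotations, and the extra host mounts into the
# ClusterTrainingRuntime template.
executor = KubeflowExecutor(
    nodes=2,
    gpus=8,
    cpu_limit="32",
    memory_limit="256Gi",
    namespace="training",
    enable_tcpxo=True,
)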

File tree

4 files changed: +174 -68 lines


nemo_run/core/execution/kubeflow.py

Lines changed: 66 additions & 36 deletions
@@ -20,7 +20,8 @@
 from typing import Any, Dict, Optional, Union
 
 import yaml
-from kubeflow.trainer import CustomTrainer, TrainerClient
+from kubeflow.trainer import CommandTrainer, TrainerClient
+from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
 from kubernetes import client, config
 from kubernetes.client.exceptions import ApiException
 
@@ -46,7 +47,7 @@ class KubeflowExecutor(Executor):
 
     Example:
 
-    .. code-block:: python
+    . code-block:: python
 
        # Configure executor for execution environment
        executor = KubeflowExecutor(
@@ -104,6 +105,9 @@ class KubeflowExecutor(Executor):
     #: Detach mode flag (set by experiment framework)
     _detach_mode: bool = field(init=False, default=False)
 
+    #: Enable tcpxo sidecar and related mounts/env in runtime template
+    enable_tcpxo: bool = False
+
     def __post_init__(self):
         """Validate executor configuration and setup Kubernetes access."""
         if self.nodes < 1:
@@ -213,7 +217,8 @@ def _get_trainer_client(self) -> TrainerClient:
         """Get or create a TrainerClient instance."""
         if self._trainer_client is None:
             # Initialize client with the executor's namespace
-            self._trainer_client = TrainerClient(namespace=self.namespace)
+            k8s_backend_config = KubernetesBackendConfig(namespace=self.namespace)
+            self._trainer_client = TrainerClient(backend_config=k8s_backend_config)
         return self._trainer_client
 
     def _create_cluster_training_runtime(self, configmap_name: str, sha: str) -> str:
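
For reference, the new client wiring in isolation, a minimal sketch built from the imports and constructor call shown in this diff (the namespace value is a placeholder):

from kubeflow.trainer import TrainerClient
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig

# TrainerClient no longer takes the namespace directly; it is wrapped in a backend config.
backend_config = KubernetesBackendConfig(namespace="my-namespace")
trainer_client = TrainerClient(backend_config=backend_config)
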
@@ -234,6 +239,7 @@ def _create_cluster_training_runtime(self, configmap_name: str, sha: str) -> str
             "cpu_limit": self.cpu_limit,
             "memory_limit": self.memory_limit,
             "gpus": self.gpus,
+            "enable_tcpxo": self.enable_tcpxo,
         }
         rendered = fill_template(
             template_name="kubeflow_clustertrainingruntime.yaml.j2",
@@ -326,10 +332,8 @@ def _get_additional_files(self, task) -> dict[str, tuple[str, str]]:
             logger.info("Script task - will stage content in ConfigMap")
 
         elif hasattr(task, "__fn_or_cls__"):
-            # Partial task - will be handled directly by CustomTrainer, no ConfigMap staging needed
-            logger.info(
-                "Partial task - will be passed directly to CustomTrainer, skipping ConfigMap staging"
-            )
+            # Partial support not implemented yet for CommandTrainer path
+            logger.warning("Partial tasks are not yet supported with Kubeflow CommandTrainer.")
 
         return files_to_stage
 
@@ -370,43 +374,51 @@ def cleanup_files(self, task_dir: str, task=None):
         # Use experiment-specific naming for cleanup
         self.packager.cleanup(self._get_experiment_identifier())
 
-    def _get_custom_trainer(self, task) -> CustomTrainer:
-        """Get the CustomTrainer configuration for the training job."""
-        trainer_kwargs: dict = {"num_nodes": self.nodes}
+    def _get_custom_trainer(self, task) -> CommandTrainer:
+        """Build a CommandTrainer for a Script task. Partial is not yet supported."""
+        # Reject Partial until implemented
+        if hasattr(task, "__fn_or_cls__"):
+            raise NotImplementedError(
+                "Partial tasks are not yet supported with Kubeflow CommandTrainer"
+            )
+
         resources_per_node: dict = {}
         if self.cpu_limit is not None:
             resources_per_node["cpu"] = self.cpu_limit
         if self.memory_limit is not None:
             resources_per_node["memory"] = self.memory_limit
         if self.gpus is not None:
             resources_per_node["nvidia.com/gpu"] = str(self.gpus)
-        trainer_kwargs["resources_per_node"] = resources_per_node
 
-        if hasattr(task, "__fn_or_cls__"):
-            trainer_kwargs["func"] = task.__fn_or_cls__
-            if hasattr(task, "__arguments__") and task.__arguments__:
-                trainer_kwargs["func_args"] = task.__arguments__
+        # Determine command/args based on entrypoint
+        entrypoint = getattr(task, "entrypoint", "bash") or "bash"
+        mounted_path = f"{self.volume_mount_path}/{self.training_entry}"
+
+        command: list[str]
+        args: list[str]
+        ep_lower = entrypoint.lower()
+        if "bash" in ep_lower:
+            command = ["/bin/bash"]
+            args = ["-c", mounted_path]
+        elif "python" in ep_lower:
+            command = ["python"]
+            args = [mounted_path]
         else:
-            # Script task - set python_file and check for bash scripts
-            trainer_kwargs["python_file"] = f"{self.volume_mount_path}/{self.training_entry}"
-
-            # Check if this is a bash script and set appropriate command
-            if hasattr(task, "inline") and task.inline:
-                entrypoint = getattr(task, "entrypoint", "bash")
-                if entrypoint and "bash" in entrypoint.lower():
-                    trainer_kwargs["command"] = ["/bin/bash"]
-                    logger.info("Using bash command for script execution")
-            # For Python scripts, let SDK auto-detect based on runtime
-
-        # Debug logging to see what we're passing to CustomTrainer
-        logger.info(f"Creating CustomTrainer with kwargs: {trainer_kwargs}")
-
-        trainer = CustomTrainer(**trainer_kwargs)
+            # Fallback: treat entrypoint as executable to run the staged file
+            command = [entrypoint]
+            args = [mounted_path]
+
+        trainer = CommandTrainer(
+            command=command,
+            args=args,
+            num_nodes=self.nodes,
+            resources_per_node=resources_per_node,
+        )
 
-        # Debug logging to see what CustomTrainer actually received
-        logger.info(f"CustomTrainer created with func: {trainer.func}")
-        logger.info(f"CustomTrainer created with func_args: {trainer.func_args}")
-        logger.info(f"CustomTrainer created with python_file: {trainer.python_file}")
+        logger.info(
+            f"CommandTrainer created with command={trainer.command}, args={trainer.args}, "
+            f"num_nodes={trainer.num_nodes}, resources_per_node={trainer.resources_per_node}"
+        )
 
         return trainer
 
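The entrypoint-to-command mapping above, factored out as a plain function for illustration. The function name, signature, and the example path are hypothetical; the branch logic mirrors the committed code.

def resolve_command(entrypoint, mounted_path):
    # Default to bash when the task does not declare an entrypoint.
    entrypoint = entrypoint or "bash"
    ep_lower = entrypoint.lower()
    if "bash" in ep_lower:
        # Bash scripts are run via `bash -c <staged file>`.
        return ["/bin/bash"], ["-c", mounted_path]
    if "python" in ep_lower:
        # Python scripts are passed to the interpreter directly.
        return ["python"], [mounted_path]
    # Fallback: treat the entrypoint itself as the executable for the staged file.
    return [entrypoint], [mounted_path]

# e.g. resolve_command("bash", "/nemo_run/train.sh") -> (["/bin/bash"], ["-c", "/nemo_run/train.sh"])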

@@ -442,11 +454,15 @@ def delete_trainjob(self, job_name: str):
         except Exception as e:
             logger.error(f"Failed to delete TrainJob: {e}")
 
-    def get_trainjob_logs(self, job_name: str, follow: bool = False) -> dict:
+    def get_trainjob_logs(self, job_name: str, follow: bool = False):
         """Get logs from a TrainJob."""
         try:
             client = self._get_trainer_client()
-            return client.get_job_logs(job_name, follow=follow)
+            logs_iter = client.get_job_logs(job_name, follow=follow)
+            # Some tests mock this as a dict; in real SDK it's an Iterator[str]
+            if isinstance(logs_iter, dict):
+                return logs_iter
+            return logs_iter
         except Exception as e:
             logger.error(f"Failed to get TrainJob logs: {e}")
             return {}
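
A sketch of consuming the updated return value. The job name and the minimal constructor call are placeholders, and a real run needs cluster access; per the comment in the hunk, the real SDK yields strings while some tests mock a dict.

from nemo_run.core.execution.kubeflow import KubeflowExecutor

executor = KubeflowExecutor(nodes=1)  # assumed minimal construction
logs = executor.get_trainjob_logs("my-trainjob", follow=False)
if isinstance(logs, dict):
    # mocked/dict-shaped result
    for name, text in logs.items():
        print(name, text)
else:
    # real SDK: iterator of log lines
    for line in logs:
        print(line)
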
@@ -529,3 +545,17 @@ def _runtime_name(self, sha: str) -> str:
         """Build CRT name from the shared experiment identifier and sha."""
         identifier = self._get_experiment_identifier()
         return sanitize_kubernetes_name(f"nemo-runtime-{identifier}-{sha}")
+
+    def _get_staged_file_path(self, filename: str) -> str:
+        """Return path where a staged file would be mounted inside the container.
+
+        If using ConfigMapPackager, files are mounted under volume_mount_path with
+        experiment-specific prefix. Otherwise, return the filename unchanged.
+        """
+        if (
+            isinstance(self.packager, ConfigMapPackager)
+            and hasattr(self, "experiment_name")
+            and self.experiment_name
+        ):
+            return f"{self.volume_mount_path}/{self.experiment_name}-{filename}"
+        return filename
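
A standalone mirror of the new _get_staged_file_path logic, for illustration only; the real method also checks that the packager is a ConfigMapPackager, and the names and values below are placeholders.

def staged_file_path(volume_mount_path, experiment_name, filename):
    # With an experiment name set, staged files live under the mount path
    # with an experiment-specific prefix; otherwise the filename is returned unchanged.
    if experiment_name:
        return f"{volume_mount_path}/{experiment_name}-{filename}"
    return filename

print(staged_file_path("/nemo_run/scripts", "demo-exp", "train.sh"))
# /nemo_run/scripts/demo-exp-train.sh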

nemo_run/core/execution/templates/kubeflow_clustertrainingruntime.yaml.j2

Lines changed: 91 additions & 1 deletion
@@ -3,6 +3,8 @@ kind: ClusterTrainingRuntime
 metadata:
   name: {{ runtime_name }}
   namespace: {{ namespace }}
+  labels:
+    trainer.kubeflow.org/framework: torch
 spec:
   mlPolicy:
     numNodes: {{ nodes }}
@@ -17,22 +19,110 @@ spec:
       metadata:
         labels:
           trainer.kubeflow.org/trainjob-ancestor-step: trainer
+        {% if enable_tcpxo %}
+        annotations:
+          devices.gke.io/container.tcpxo-daemon: |
+            - path: /dev/nvidia0
+            - path: /dev/nvidia1
+            - path: /dev/nvidia2
+            - path: /dev/nvidia3
+            - path: /dev/nvidia4
+            - path: /dev/nvidia5
+            - path: /dev/nvidia6
+            - path: /dev/nvidia7
+            - path: /dev/nvidiactl
+            - path: /dev/nvidia-uvm
+            - path: /dev/dmabuf_import_helper
+          networking.gke.io/default-interface: eth0
+          networking.gke.io/interfaces: |
+            [
+              {"interfaceName":"eth0","network":"default"},
+              {"interfaceName":"eth1","network":"vpc1"},
+              {"interfaceName":"eth2","network":"vpc2"},
+              {"interfaceName":"eth3","network":"vpc3"},
+              {"interfaceName":"eth4","network":"vpc4"},
+              {"interfaceName":"eth5","network":"vpc5"},
+              {"interfaceName":"eth6","network":"vpc6"},
+              {"interfaceName":"eth7","network":"vpc7"},
+              {"interfaceName":"eth8","network":"vpc8"}
+            ]
+        {% endif %}
       spec:
         template:
           spec:
             volumes:
              - name: workspace
                configMap:
                  name: {{ configmap_name }}
+                 defaultMode: 0755
+             - name: mistral-checkpoint
+               persistentVolumeClaim:
+                 claimName: mistral-checkpoint
+             - name: libraries
+               hostPath:
+                 path: /home/kubernetes/bin/nvidia/lib64
+             - name: sys
+               hostPath:
+                 path: /sys
+             - name: proc-sys
+               hostPath:
+                 path: /proc/sys
+             - name: aperture-devices
+               hostPath:
+                 path: /dev/aperture_devices
+             - name: dshm
+               emptyDir:
+                 medium: Memory
+                 sizeLimit: 2048Gi
            containers:
              - name: node
                image: {{ image }}
+               env:
+                 - name: LD_LIBRARY_PATH
+                   value: /usr/local/nvidia/lib64
+                 - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY
+                   value: /dev/aperture_devices
               volumeMounts:
                 - name: workspace
                   mountPath: {{ volume_mount_path }}
+                - name: mistral-checkpoint
+                  mountPath: /workspace
+                - name: dshm
+                  mountPath: /dev/shm
+                - name: aperture-devices
+                  mountPath: /dev/aperture_devices
               resources:
-                requests: {}
+                requests:
+                  {% if cpu_limit %}cpu: {{ cpu_limit }}{% endif %}
+                  {% if memory_limit %}memory: {{ memory_limit }}{% endif %}
+                  {% if gpus %}"nvidia.com/gpu": {{ gpus }}{% endif %}
                 limits:
                   {% if cpu_limit %}cpu: {{ cpu_limit }}{% endif %}
                   {% if memory_limit %}memory: {{ memory_limit }}{% endif %}
                   {% if gpus %}"nvidia.com/gpu": {{ gpus }}{% endif %}
+            {% if enable_tcpxo %}
+             - name: tcpxo-daemon
+               image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.15
+               imagePullPolicy: Always
+               command: ["/bin/sh", "-c"]
+               args:
+                 - |
+                   set -ex
+                   chmod 755 /fts/entrypoint_rxdm_container.sh
+                   /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr
+               env:
+                 - name: LD_LIBRARY_PATH
+                   value: /usr/local/nvidia/lib64
+               securityContext:
+                 capabilities:
+                   add:
+                     - NET_ADMIN
+                     - NET_BIND_SERVICE
+               volumeMounts:
+                 - name: libraries
+                   mountPath: /usr/local/nvidia
+                 - name: sys
+                   mountPath: /hostsysfs
+                 - name: proc-sys
+                   mountPath: /hostprocsysfs
+            {% endif %}
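
An illustrative rendering of the updated template with enable_tcpxo toggled on. The executor itself goes through fill_template; this standalone snippet uses jinja2 directly, and all values are placeholders. The variable names mirror the template and the template_args dict built in _create_cluster_training_runtime.

from jinja2 import Template

with open("kubeflow_clustertrainingruntime.yaml.j2") as f:
    rendered = Template(f.read()).render(
        runtime_name="nemo-runtime-demo-abc123",
        namespace="training",
        nodes=2,
        image="nvcr.io/nvidia/nemo:latest",
        configmap_name="demo-workspace",
        volume_mount_path="/nemo_run/scripts",
        cpu_limit="32",
        memory_limit="256Gi",
        gpus=8,
        enable_tcpxo=True,  # emits the tcpxo-daemon sidecar, GKE annotations, and host mounts
    )
print(rendered)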

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ dependencies = [
     "packaging",
     "toml",
     "kubernetes>=28.0.0",
-    "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git#subdirectory=python",
+    "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git@main",
 ]
 readme = "README.md"
 requires-python = ">= 3.10"
@@ -58,7 +58,7 @@ skypilot-all = ["skypilot[all]>=0.10.0"]
 ray = ["kubernetes"]
 kubernetes = [
     "kubernetes>=28.0.0",
-    "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git#subdirectory=python",
+    "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git@main",
 ]
 
 [dependency-groups]
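
A quick import smoke test against the re-pinned fork; these are exactly the symbols the refactor depends on.

# Should succeed once the @main pin of the jskswamy/kubeflow-sdk fork is installed.
from kubeflow.trainer import CommandTrainer, TrainerClient
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig

print(CommandTrainer.__name__, TrainerClient.__name__, KubernetesBackendConfig.__name__)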
