2020from typing import Any , Dict , Optional , Union
2121
2222import yaml
23- from kubeflow .trainer import CustomTrainer , TrainerClient
23+ from kubeflow .trainer import CommandTrainer , TrainerClient
24+ from kubeflow .trainer .backends .kubernetes .types import KubernetesBackendConfig
2425from kubernetes import client , config
2526from kubernetes .client .exceptions import ApiException
2627
@@ -46,7 +47,7 @@ class KubeflowExecutor(Executor):
4647
4748 Example:
4849
49- .. code-block:: python
50+ .. code-block:: python
5051
5152 # Configure executor for execution environment
5253 executor = KubeflowExecutor(
@@ -104,6 +105,9 @@ class KubeflowExecutor(Executor):
104105 #: Detach mode flag (set by experiment framework)
105106 _detach_mode : bool = field (init = False , default = False )
106107
108+ #: Enable tcpxo sidecar and related mounts/env in runtime template
109+ enable_tcpxo : bool = False
110+
107111 def __post_init__ (self ):
108112 """Validate executor configuration and setup Kubernetes access."""
109113 if self .nodes < 1 :
def _get_trainer_client(self) -> TrainerClient:
    """Return the cached TrainerClient, creating it on first use."""
    if self._trainer_client is not None:
        return self._trainer_client
    # Scope the client to this executor's namespace via the backend config.
    backend_cfg = KubernetesBackendConfig(namespace=self.namespace)
    self._trainer_client = TrainerClient(backend_config=backend_cfg)
    return self._trainer_client
218223
219224 def _create_cluster_training_runtime (self , configmap_name : str , sha : str ) -> str :
@@ -234,6 +239,7 @@ def _create_cluster_training_runtime(self, configmap_name: str, sha: str) -> str
234239 "cpu_limit" : self .cpu_limit ,
235240 "memory_limit" : self .memory_limit ,
236241 "gpus" : self .gpus ,
242+ "enable_tcpxo" : self .enable_tcpxo ,
237243 }
238244 rendered = fill_template (
239245 template_name = "kubeflow_clustertrainingruntime.yaml.j2" ,
@@ -326,10 +332,8 @@ def _get_additional_files(self, task) -> dict[str, tuple[str, str]]:
326332 logger .info ("Script task - will stage content in ConfigMap" )
327333
328334 elif hasattr (task , "__fn_or_cls__" ):
329- # Partial task - will be handled directly by CustomTrainer, no ConfigMap staging needed
330- logger .info (
331- "Partial task - will be passed directly to CustomTrainer, skipping ConfigMap staging"
332- )
335+ # Partial support not implemented yet for CommandTrainer path
336+ logger .warning ("Partial tasks are not yet supported with Kubeflow CommandTrainer." )
333337
334338 return files_to_stage
335339
@@ -370,43 +374,51 @@ def cleanup_files(self, task_dir: str, task=None):
370374 # Use experiment-specific naming for cleanup
371375 self .packager .cleanup (self ._get_experiment_identifier ())
372376
def _get_custom_trainer(self, task) -> CommandTrainer:
    """Build a CommandTrainer that runs the staged Script task on each node.

    Args:
        task: The task to execute. Only Script tasks are supported; Partial
            tasks (those exposing ``__fn_or_cls__``) are rejected.

    Returns:
        A ``CommandTrainer`` configured with the node count, per-node
        resources, and the command/args to run the staged entry file.

    Raises:
        NotImplementedError: If ``task`` is a Partial task.
    """
    # Reject Partial until implemented.
    if hasattr(task, "__fn_or_cls__"):
        raise NotImplementedError(
            "Partial tasks are not yet supported with Kubeflow CommandTrainer"
        )

    # Per-node resource limits; include only the keys the user configured.
    resources_per_node: dict = {}
    if self.cpu_limit is not None:
        resources_per_node["cpu"] = self.cpu_limit
    if self.memory_limit is not None:
        resources_per_node["memory"] = self.memory_limit
    if self.gpus is not None:
        resources_per_node["nvidia.com/gpu"] = str(self.gpus)

    # Determine command/args from the task's entrypoint.
    entrypoint = getattr(task, "entrypoint", "bash") or "bash"
    mounted_path = f"{self.volume_mount_path}/{self.training_entry}"

    command: list[str]
    args: list[str]
    ep_lower = entrypoint.lower()
    if "bash" in ep_lower:
        # Pass the script path as a positional argument ("bash <file>") rather
        # than via "-c": "bash -c <path>" execs the file directly, which fails
        # for ConfigMap-mounted files (default mode 0644, not executable),
        # while "bash <file>" only needs read permission.
        command = ["/bin/bash"]
        args = [mounted_path]
    elif "python" in ep_lower:
        command = ["python"]
        args = [mounted_path]
    else:
        # Fallback: treat entrypoint as an executable to run the staged file.
        command = [entrypoint]
        args = [mounted_path]

    trainer = CommandTrainer(
        command=command,
        args=args,
        num_nodes=self.nodes,
        resources_per_node=resources_per_node,
    )

    logger.info(
        f"CommandTrainer created with command={trainer.command}, args={trainer.args}, "
        f"num_nodes={trainer.num_nodes}, resources_per_node={trainer.resources_per_node}"
    )

    return trainer
412424
@@ -442,11 +454,15 @@ def delete_trainjob(self, job_name: str):
442454 except Exception as e :
443455 logger .error (f"Failed to delete TrainJob: { e } " )
444456
def get_trainjob_logs(self, job_name: str, follow: bool = False):
    """Get logs from a TrainJob.

    Args:
        job_name: Name of the TrainJob to fetch logs for.
        follow: Whether to stream logs as they are produced.

    Returns:
        Whatever the SDK client returns — an ``Iterator[str]`` in the real
        SDK, though tests may mock it as a dict — or an empty dict on failure.
    """
    try:
        client = self._get_trainer_client()
        # Pass the SDK's return value straight through. The previous
        # isinstance(..., dict) check was dead code: both branches returned
        # the same object unchanged.
        return client.get_job_logs(job_name, follow=follow)
    except Exception as e:
        logger.error(f"Failed to get TrainJob logs: {e}")
        return {}
def _runtime_name(self, sha: str) -> str:
    """Build the ClusterTrainingRuntime name from the experiment id and sha."""
    return sanitize_kubernetes_name(
        f"nemo-runtime-{self._get_experiment_identifier()}-{sha}"
    )
548+
def _get_staged_file_path(self, filename: str) -> str:
    """Return the in-container path where a staged file would be mounted.

    Files staged through ConfigMapPackager are mounted under
    ``volume_mount_path`` with an experiment-specific prefix; for any other
    packager the filename is returned unchanged.
    """
    staged_via_configmap = isinstance(self.packager, ConfigMapPackager) and bool(
        getattr(self, "experiment_name", None)
    )
    if not staged_via_configmap:
        return filename
    return f"{self.volume_mount_path}/{self.experiment_name}-{filename}"
0 commit comments