diff --git a/README.md b/README.md
index 9ef19030..c0841363 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@ To learn more, click on each link. This represents the typical order that NeMo R
 - [Why Use NeMo Run?](#why-use-nemo-run)
 - [Install NeMo Run](#install-nemo-run)
 - [Get Started](#get-started)
+  - [Supported Executors](#supported-executors)
 - [Design Philosophy and Inspiration](#design-philosophy-and-inspiration)
   - [Pythonic](#pythonic)
   - [Modular](#modular)
@@ -36,6 +37,12 @@ To install the project, use the following command:

pip install git+https://github.com/NVIDIA-NeMo/Run.git
```
+
+For Kubeflow support, install with the `kubernetes` optional dependency:
+
+```bash
+pip install "nemo_run[kubernetes] @ git+https://github.com/NVIDIA-NeMo/Run.git"
+```
+
Make sure you have `pip` installed and configured properly.

## Get Started

@@ -59,6 +66,20 @@ local_executor = run.LocalExecutor()

run.run(partial_func, executor=local_executor, name="llama3_8b_pretraining")
```
+
+## Supported Executors
+
+NeMo Run supports multiple executors for different computing environments:
+
+- **LocalExecutor**: Execute tasks locally on your machine
+- **DockerExecutor**: Execute tasks in Docker containers
+- **SlurmExecutor**: Execute tasks on Slurm clusters
+- **SkypilotExecutor**: Execute tasks on cloud platforms via SkyPilot
+- **DGXCloudExecutor**: Execute tasks on NVIDIA DGX Cloud
+- **LeptonExecutor**: Execute tasks on NVIDIA DGX Cloud Lepton clusters
+- **KubeflowExecutor**: Execute tasks on Kubernetes using Kubeflow Trainer
+
+For detailed information about each executor, see the [Execution Guide](./docs/source/guides/execution.md).
+
## Design Philosophy and Inspiration

In building NeMo Run, we drew inspiration from and relied on the following primary libraries. We would like to extend our gratitude for their work.
diff --git a/docs/source/guides/execution.md b/docs/source/guides/execution.md
index 1eb8d82e..8a5e119e 100644
--- a/docs/source/guides/execution.md
+++ b/docs/source/guides/execution.md
@@ -3,11 +3,13 @@

After configuring NeMo-Run, the next step is to execute it. NeMo-Run decouples configuration from execution, allowing you to configure a function or task once and then execute it across multiple environments. With NeMo-Run, you can choose to execute a single task or multiple tasks simultaneously on different remote clusters, managing them under an experiment. This brings us to the core building blocks for execution: `run.Executor` and `run.Experiment`.

Each execution of a single configured task requires an executor. NeMo-Run provides `run.Executor`, a set of APIs to configure your remote executor and set up the packaging of your code. Currently, we support:
+
- `run.LocalExecutor`
- `run.DockerExecutor`
- `run.SlurmExecutor` with an optional `SSHTunnel` for executing on Slurm clusters from your local machine
- `run.SkypilotExecutor` (available under the optional feature `skypilot` in the python package).
- `run.LeptonExecutor`
+- `run.KubeflowExecutor`

A tuple of task and executor forms an execution unit. A key goal of NeMo-Run is to allow you to mix and match tasks and executors to arbitrarily define execution units.

@@ -19,17 +21,20 @@ The `run.Experiment` takes care of storing the run metadata, launching it on the

> **_NOTE:_** All the experiment metadata is stored under the `NEMORUN_HOME` env var on the machine where you launch the experiments. By default, the value of `NEMORUN_HOME` is `~/.run`. Be sure to change this according to your needs.
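For example, a minimal experiment that pairs one configured task with two executors might look like the following sketch (`your_training_fn` and `your_slurm_executor` are illustrative placeholders, not APIs from this repository):

```python
import nemo_run as run

# A configured task: run.Partial captures the function plus its arguments
partial_func = run.Partial(your_training_fn)  # your_training_fn is hypothetical

with run.Experiment("llama3-8b-pretraining") as exp:
    # The same task can be paired with different executors to form execution units
    exp.add(partial_func, executor=run.LocalExecutor(), name="local-run")
    exp.add(partial_func, executor=your_slurm_executor(), name="slurm-run")
    exp.run()
```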
## Executors
+
Executors are dataclasses that configure your remote executor and set up the packaging of your code. All supported executors inherit from the base class `run.Executor`, but have configuration parameters specific to their execution environment. There is an initial cost to understanding the specifics of your executor and setting it up, but this effort is easily amortized over time.

Each `run.Executor` has two attributes: `packager` and `launcher`. The `packager` specifies how to package the code for execution, while the `launcher` determines which tool to use for launching the task.

### Launchers
+
We support the following `launchers`:
+
- `default` or `None`: This will directly launch your task without using any special launchers. Set `executor.launcher = None` (which is the default value) if you don't want to use a specific launcher.
- `torchrun` or `run.Torchrun`: This will launch the task using `torchrun`. See the `Torchrun` class for configuration options. Enable it with `executor.launcher = "torchrun"` or `executor.launcher = Torchrun(...)`.
- `ft` or `run.core.execution.FaultTolerance`: This will launch the task using NVIDIA's fault-tolerant launcher. See the `FaultTolerance` class for configuration options. Enable it with `executor.launcher = "ft"` or `executor.launcher = FaultTolerance(...)`.

-> **_NOTE:_** Launcher may not work very well with `run.Script`. Please report any issues at https://github.com/NVIDIA-NeMo/Run/issues.
+> **_NOTE:_** Launchers may not work well with `run.Script`. Please report any issues at <https://github.com/NVIDIA-NeMo/Run/issues>.

### Packagers

@@ -43,31 +48,38 @@ The packager support matrix is described below:

| SkypilotExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager |
| DGXCloudExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager |
| LeptonExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager |
+| KubeflowExecutor | run.ConfigMapPackager |

`run.Packager` is a passthrough base packager.

`run.GitArchivePackager` uses `git archive` to package your code. Refer to the API reference for `run.GitArchivePackager` to see the exact mechanics of packaging using `git archive`. At a high level, it works in the following way:
+
1. base_path = `git rev-parse --show-toplevel`.
2. Optionally define a subpath as `base_path/GitArchivePackager.subpath` by setting the `subpath` attribute on `GitArchivePackager`.
3. `cd base_path && git archive --format=tar.gz --output={output_file} {GitArchivePackager.subpath}:{subpath}`

This extracted tar file becomes the working directory for your job. As an example, given the following directory structure with `subpath="src"`:
+
```
- docs
- src
  - your_library
- tests
```
+
Your working directory at the time of execution will look like:
+
```
- your_library
```
+
If you're executing a Python function, this working directory will automatically be included in your Python path.

> **_NOTE:_** git archive doesn't package uncommitted changes. In the future, we may add support for including uncommitted changes while honoring `.gitignore`.

`run.PatternPackager` is a packager that uses a pattern to package your code. It is useful for packaging code that is not under version control.
For example, if you have a directory structure like this:
+
```
- docs
- src
@@ -86,6 +98,7 @@ cd {relative_path} && find {relative_include_pattern} -type f

Each sub-packager in the `sub_packagers` dictionary is assigned a key, which becomes the directory name under which its contents are placed in the final archive. If `extract_at_root` is set to `True`, all contents are placed directly in the root of the archive, potentially overwriting files if names conflict.

Example:
+
```python
import nemo_run as run
import os
@@ -100,9 +113,11 @@ hybrid_packager = run.HybridPackager(

# Usage with an executor:
# executor.packager = hybrid_packager
```
+
This would create an archive where the contents of `src` are under a `code/` directory and matched `configs/*.yaml` files are under a `configs/` directory.

### Defining Executors
+
Next, we'll describe how to set up each of the executors below.

#### LocalExecutor
@@ -137,6 +152,7 @@ run.DockerExecutor(

The SlurmExecutor enables launching the configured task on a Slurm cluster with Pyxis. Additionally, you can configure a `run.SSHTunnel`, which enables you to execute tasks on the Slurm cluster from your local machine while NeMo-Run manages the SSH connection for you. This setup supports use cases such as launching the same task on multiple Slurm clusters.

Below is an example of configuring a `SlurmExecutor`:
+
```python
def your_slurm_executor(nodes: int = 1, container_image: str = DEFAULT_IMAGE):
    # SSH Tunnel
@@ -197,9 +213,11 @@ The `dependency_type` parameter specifies the type of dependency relationship:

This functionality enables you to create complex workflows with proper orchestration between different tasks, such as starting a training job only after data preparation is complete, or running an evaluation only after training finishes successfully.

#### SkypilotExecutor
+
This executor is used to configure [Skypilot](https://skypilot.readthedocs.io/en/latest/docs/index.html). Make sure Skypilot is installed using `pip install "nemo_run[skypilot]"` and at least one cloud is configured using `sky check`.

Here's an example of the `SkypilotExecutor` for Kubernetes:
+
```python
def your_skypilot_executor(nodes: int, devices: int, container_image: str):
    return SkypilotExecutor(
@@ -228,7 +246,7 @@ As demonstrated in the examples, defining executors in Python offers great flex

The `DGXCloudExecutor` integrates with a DGX Cloud cluster's Run:ai API to launch distributed jobs. It uses REST API calls to authenticate, identify the target project and cluster, and submit the job specification.

-> **_WARNING:_** Currently, the `DGXCloudExecutor` is only supported when launching experiments *from* a pod running on the DGX Cloud cluster itself. Furthermore, this launching pod must have access to a Persistent Volume Claim (PVC) where the experiment/job directories will be created, and this same PVC must also be configured to be mounted by the job being launched.
+> **_WARNING:_** Currently, the `DGXCloudExecutor` is only supported when launching experiments _from_ a pod running on the DGX Cloud cluster itself. Furthermore, this launching pod must have access to a Persistent Volume Claim (PVC) where the experiment/job directories will be created, and this same PVC must also be configured to be mounted by the job being launched.
Here's an example configuration:

@@ -303,3 +321,233 @@ def your_lepton_executor(nodes: int, gpus_per_node: int, container_image: str):

executor = your_lepton_executor(nodes=4, gpus_per_node=8, container_image="your-nemo-image")
```
+
+#### KubeflowExecutor
+
+The `KubeflowExecutor` launches distributed training jobs on Kubernetes using the Kubeflow Trainer SDK. It builds on Kubeflow Trainer's split between runtimes and jobs: a ClusterTrainingRuntime describes the pod template, volume mounts, and torch policy, and each run submits a TrainJob against that runtime.
+
+The executor accepts tasks passed through the Experiment API, either `run.Script` (script-based) or `run.Partial` (function-based), and uses `ConfigMapPackager` to stage code into Kubernetes ConfigMaps for training.
+
+> **_NOTE:_** The `KubeflowExecutor` creates an experiment-specific ClusterTrainingRuntime (and the ConfigMap it mounts) automatically, so the credentials used to launch the experiment must be allowed to create these resources in your cluster.
+
+Here's an example configuration:
+
+```python
+from nemo_run.core.packaging.configmap import ConfigMapPackager
+from nemo_run.core.execution.kubeflow import KubeflowExecutor
+
+def your_kubeflow_executor(nodes: int = 2, gpus_per_node: int = 4):
+    # Configure the ConfigMapPackager for staging files
+    packager = ConfigMapPackager(
+        include_pattern="*.py",
+        relative_path=".",
+        namespace="default"
+    )
+
+    executor = KubeflowExecutor(
+        # Basic configuration
+        nodes=nodes,
+        ntasks_per_node=gpus_per_node,
+        namespace="default",
+
+        # Resource configuration
+        cpu_limit="8",
+        memory_limit="16Gi",
+        gpus_per_node=gpus_per_node,
+        container_image="nvcr.io/nvidia/nemo:dev",
+
+        # Packager for staging files
+        packager=packager,
+    )
+    return executor
+
+# Example usage:
+executor = your_kubeflow_executor(nodes=2, gpus_per_node=4)
+```
+
+##### Script-Based Execution
+
+For script-based execution, the executor stages your script's content into a ConfigMap and runs it with `torchrun`:
+
+```python
+executor = KubeflowExecutor(
+    packager=ConfigMapPackager(include_pattern="*.py"),
+    nodes=2,
+    gpus_per_node=4,
+)
+
+training_script = run.Script(inline="python train.py")
+
+with run.Experiment("mistral_training") as exp:
+    exp.add(training_script, executor=executor)
+    exp.run()
+```
+
+##### Function-Based Execution
+
+For function-based execution, pass a `run.Partial`; the executor serializes the configured function and executes it:
+
+```python
+def my_training_function():
+    """Training function that will be serialized and executed."""
+    print("Training started!")
+    # Your training logic here
+    print("Training completed!")
+
+executor = KubeflowExecutor(
+    packager=ConfigMapPackager(include_pattern="*.py"),
+    nodes=2,
+    gpus_per_node=4,
+)
+
+with run.Experiment("mistral_training") as exp:
+    exp.add(run.Partial(my_training_function), executor=executor)  # Function is serialized and shipped
+    exp.run()
+```
+
+##### Advanced Configuration
+
+For more complex scenarios, you can configure additional options:
+
+```python
+def advanced_kubeflow_executor():
+    # Custom packager configuration
+    packager = ConfigMapPackager(
+        include_pattern=["*.py", "*.yaml", "*.json"],
+        relative_path=".",
+        namespace="ml-training",
+        configmap_prefix="my-workspace"
+    )
+
+    return KubeflowExecutor(
+        # Basic configuration
+        nodes=4,
+        ntasks_per_node=8,
+        namespace="ml-training",
+
+        # Resource configuration
+        cpu_limit="16",
+        memory_limit="64Gi",
+        gpus_per_node=8,
+        container_image="nvcr.io/nvidia/nemo:dev",
+
+        # Packager
+        packager=packager,
+    )
+```
+
+##### File Staging with ConfigMapPackager
+
+The `ConfigMapPackager` stages your files into Kubernetes ConfigMaps for training:
+
+```python
+from nemo_run.core.packaging.configmap import ConfigMapPackager
+
+# Basic configuration
+packager = ConfigMapPackager(
+    include_pattern="*.py",            # Files to include
+    relative_path=".",                 # Base path for files
+    namespace="default",               # Kubernetes namespace
+    configmap_prefix="nemo-workspace"  # ConfigMap name prefix
+)
+
+# Advanced file staging
+packager = ConfigMapPackager(
+    include_pattern=["*.py", "*.yaml", "*.json", "configs/*"],
+    relative_path=".",
+    namespace="default"
+)
+
+# Stage specific directories
+packager = ConfigMapPackager(
+    include_pattern="src/**/*.py",
+    relative_path=".",
+    namespace="default"
+)
+```
+
+> **_NOTE:_** ConfigMaps have a 1MB size limit. For larger files, consider using PVC-based staging (future feature) or Git-based staging with volume mounts.
+
+##### Prerequisites
+
+Before using the `KubeflowExecutor`, ensure:
+
+1. **Kubernetes cluster is accessible**
+   - `kubectl` is installed and configured
+   - You have access to the target cluster: `kubectl cluster-info`
+   - Proper authentication and authorization are set up
+
+2. **Kubeflow Trainer is installed** in your Kubernetes cluster
+   - Trainer controller is running: `kubectl get pods -n kubeflow-system`
+   - Custom resources are available: `kubectl get crd | grep trainer`
+
+3. **You can manage ClusterTrainingRuntime resources** (the executor creates an experiment-specific runtime automatically)
+   - Verify permissions: `kubectl auth can-i create clustertrainingruntimes`
+   - Inspect existing runtimes: `kubectl get clustertrainingruntimes`
+
+4. **NeMo Run with Kubernetes support** is installed
+   - Install with Kubernetes extras: `pip install "nemo_run[kubernetes]"`
+   - Verify ConfigMapPackager is available: `python -c "from nemo_run.core.packaging.configmap import ConfigMapPackager; print('ConfigMapPackager available')"`
+
+5. **Target namespace exists** and you have permissions to create resources
+   - Check namespace: `kubectl get namespace <namespace>`
+   - Verify permissions: `kubectl auth can-i create trainjobs -n <namespace>`
+
+##### Architecture
+
+The `KubeflowExecutor` separates cluster-level configuration from per-run concerns:
+
+- **ClusterTrainingRuntime**: A cluster-scoped template defining the pod spec, volume mounts, networking, and torch policy; NeMo Run renders one per experiment from a bundled template
+- **TrainJob**: The per-run resource submitted against that runtime via the Kubeflow Trainer SDK
+- **NeMo Run**: Handles file staging via `ConfigMapPackager`, runtime creation, and job submission
+
+This architecture keeps infrastructure configuration standardized while letting each experiment bring its own code.
+
+##### Monitoring and Debugging
+
+You can monitor your Kubeflow jobs using standard Kubernetes commands:
+
+```bash
+# List TrainJobs
+kubectl get trainjobs -n default
+
+# Get job details
+kubectl describe trainjob <trainjob-name> -n default
+
+# Get pod logs
+kubectl logs -f <pod-name> -n default
+
+# List ConfigMaps
+kubectl get configmaps -n default
+```
+
+##### Troubleshooting
+
+Common issues and solutions:
+
+1. **ClusterTrainingRuntime not found**
+   - Verify that the launching credentials can create cluster-scoped resources: `kubectl auth can-i create clustertrainingruntimes`
+
+2. **ConfigMap size exceeded**
+   - Reduce the amount of staged code or use a different staging strategy
+
+3. **Kubeflow SDK not available**
+   - Install kubeflow-trainer package: `pip install kubeflow-trainer`
+
+4. **Kubernetes client not configured**
+   - Configure kubectl or set the `KUBECONFIG` environment variable
diff --git a/examples/kubeflow/README.md b/examples/kubeflow/README.md
new file mode 100644
index 00000000..0ad28b75
--- /dev/null
+++ b/examples/kubeflow/README.md
@@ -0,0 +1,114 @@
+# KubeflowExecutor Example
+
+This example demonstrates how to use NeMo Run's `KubeflowExecutor` to run distributed training jobs on Kubernetes using Kubeflow Trainer.
+
+## Overview
+
+The `KubeflowExecutor` enables distributed training on Kubernetes clusters using Kubeflow Trainer. This example includes CLI factory functions that make it easy to configure and use `KubeflowExecutor` from the command line.
+
+## Files
+
+- `hello_kubeflow.py` - Complete example with CLI integration
+- `README.md` - This documentation file
+
+## CLI Integration
+
+The example includes CLI factory functions for easy configuration:
+
+### Available Factories
+
+#### `kubeflow_gpu`
+
+GPU training configuration with default settings:
+
+- 2 nodes, 8 GPUs per node
+- 16 CPU cores, 64Gi memory per node
+- NVIDIA PyTorch container image
+
+#### `kubeflow_cpu`
+
+CPU training configuration:
+
+- 1 node, no GPUs
+- 8 CPU cores, 32Gi memory per node
+- NVIDIA PyTorch container image
+
+### Usage Examples
+
+```bash
+# Use default GPU configuration
+python hello_kubeflow.py executor=kubeflow_gpu
+
+# Customize GPU configuration
+python hello_kubeflow.py executor=kubeflow_gpu executor.nodes=4 executor.gpus_per_node=16
+
+# Use CPU configuration
+python hello_kubeflow.py executor=kubeflow_cpu
+
+# Use the CLI entrypoint
+python hello_kubeflow.py train_with_kubeflow executor=kubeflow_gpu epochs=20
+```
+
+## Prerequisites
+
+1. **Kubernetes cluster** with Kubeflow Trainer installed
+2. **Permissions** to create ClusterTrainingRuntime, ConfigMap, and TrainJob resources (NeMo Run creates experiment-specific ones automatically)
+3. **kubectl** configured to access your cluster
+4. **NeMo Run** with KubeflowExecutor support
+
+## Running the Example
+
+1. **Ensure prerequisites are met**:
+
+   ```bash
+   # Check kubectl access
+   kubectl get nodes
+
+   # Check that Kubeflow Trainer's custom resources are installed
+   kubectl get crd | grep trainer
+   ```
+
+2. **Run the example**:
+
+   ```bash
+   cd examples/kubeflow
+   python hello_kubeflow.py
+   ```
+
+3. **Use CLI integration**:
+
+   ```bash
+   # GPU training
+   python hello_kubeflow.py executor=kubeflow_gpu
+
+   # CPU training
+   python hello_kubeflow.py executor=kubeflow_cpu
+
+   # CLI entrypoint
+   python hello_kubeflow.py train_with_kubeflow executor=kubeflow_gpu epochs=20
+   ```
+
+## Key Features
+
+- **CLI Integration**: Factory functions for easy configuration
+- **Resource Management**: GPU and CPU training configurations
+- **Distributed Training**: Multi-node training support
+- **File Staging**: Automatic file packaging via ConfigMapPackager
+
+## Troubleshooting
+
+### Common Issues
+
+1. **ClusterTrainingRuntime not found**:
+
+   ```bash
+   kubectl get clustertrainingruntime
+   ```
+
+2. **Kubeflow Trainer not installed**:
+
+   ```bash
+   kubectl get pods -n kubeflow-system
+   ```
+
+3. **Resource allocation**: Ensure your cluster has sufficient resources.
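+
+4. **ConfigMap size limits**: Staged files share a ConfigMap with roughly a 1MB budget. A quick way to check how close a workspace ConfigMap is to that limit (the `nemo-workspace` prefix is the packager default; the exact name depends on your experiment):
+
+   ```bash
+   # List workspace ConfigMaps created by ConfigMapPackager
+   kubectl get configmaps -n default | grep nemo-workspace
+
+   # Approximate a ConfigMap's serialized size in bytes against the ~1MB limit
+   kubectl get configmap <configmap-name> -n default -o yaml | wc -c
+   ```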
diff --git a/examples/kubeflow/hello_kubeflow.py b/examples/kubeflow/hello_kubeflow.py
new file mode 100644
index 00000000..996a2e69
--- /dev/null
+++ b/examples/kubeflow/hello_kubeflow.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+"""
+Hello Kubeflow Example
+
+This example demonstrates how to use NeMo Run's KubeflowExecutor to run
+distributed training jobs on Kubernetes using Kubeflow Trainer.
+
+Prerequisites:
+1. Kubernetes cluster with Kubeflow Trainer installed
+2. Permission to create ClusterTrainingRuntime, ConfigMap, and TrainJob resources
+   (NeMo Run creates experiment-specific ones automatically)
+3. kubectl configured to access your cluster
+
+This example shows both file-based and function-based execution modes.
+"""
+
+import logging
+from pathlib import Path
+
+import nemo_run as run
+
+from nemo_run.core.execution.kubeflow import KubeflowExecutor
+from nemo_run.core.packaging.configmap import ConfigMapPackager
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def create_training_script():
+    """Create a simple training script for demonstration."""
+    script_content = '''#!/usr/bin/env python3
+"""
+Simple training script for KubeflowExecutor demonstration.
+"""
+import os
+import torch
+import torch.distributed as dist
+
+def main():
+    """Main training function."""
+    print("🚀 Starting distributed training with KubeflowExecutor!")
+
+    # Initialize distributed training
+    if dist.is_available():
+        dist.init_process_group(backend="nccl")
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        print(f"📊 Process {rank}/{world_size} initialized")
+    else:
+        print("⚠️ Distributed training not available")
+        rank = 0
+        world_size = 1
+
+    # Simulate training
+    print(f"🎯 Training on process {rank}/{world_size}")
+
+    # Create a simple model
+    model = torch.nn.Linear(10, 1)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    # Simulate training steps
+    for step in range(5):
+        # Simulate forward pass
+        x = torch.randn(32, 10)
+        y = model(x)
+        loss = y.mean()
+
+        # Backward pass
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if rank == 0:
+            print(f"📈 Step {step}: Loss = {loss.item():.4f}")
+
+    print(f"✅ Training completed on process {rank}")
+
+    if dist.is_available():
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
+'''
+
+    script_path = Path("train_script.py")
+    with open(script_path, "w") as f:
+        f.write(script_content)
+
+    return script_path
+
+
+def training_function():
+    """Function-based training example."""
+    import torch
+    import torch.distributed as dist
+
+    print("🎯 Function-based training started!")
+
+    # Initialize distributed training
+    if dist.is_available():
+        dist.init_process_group(backend="nccl")
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+
+    print(f"📊 Process {rank}/{world_size} in function-based training")
+
+    # Simulate training
+    model = torch.nn.Linear(10, 1)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    for step in range(3):
+        x = torch.randn(16, 10)
+        y = model(x)
+        loss = y.mean()
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if rank == 0:
+            print(f"📈 Function Step {step}: Loss = {loss.item():.4f}")
+
+    print(f"✅ Function-based training completed on process {rank}")
+
+    if dist.is_available():
+        dist.destroy_process_group()
+
+
+# CLI Factory Functions for KubeflowExecutor
+@run.cli.factory
+@run.autoconvert
+def kubeflow_gpu(
+    nodes: int = 2,
+    gpus: int = 8,
+    cpu_limit: str = "16",
+    memory_limit: str = "64Gi",
+    image: str = "nvcr.io/nvidia/pytorch:23.12-py3",
+    namespace: str = "default",
+) -> KubeflowExecutor:
+    """Factory for GPU training with KubeflowExecutor."""
+    return KubeflowExecutor(
+        nodes=nodes,
+        ntasks_per_node=gpus,  # one process per GPU
+        gpus_per_node=gpus,
+        cpu_limit=cpu_limit,
+        memory_limit=memory_limit,
+        container_image=image,
+        namespace=namespace,
+        packager=ConfigMapPackager(),
+    )
+
+
+@run.cli.factory
+@run.autoconvert
+def kubeflow_cpu(
+    nodes: int = 1,
+    cpu_limit: str = "8",
+    memory_limit: str = "32Gi",
+    image: str = "nvcr.io/nvidia/pytorch:23.12-py3",
+    namespace: str = "default",
+) -> KubeflowExecutor:
+    """Factory for CPU training with KubeflowExecutor."""
+    return KubeflowExecutor(
+        nodes=nodes,
+        cpu_limit=cpu_limit,
+        memory_limit=memory_limit,
+        container_image=image,
+        namespace=namespace,
+        packager=ConfigMapPackager(),
+    )
+
+
+@run.cli.entrypoint
+def train_with_kubeflow(
+    executor: KubeflowExecutor = kubeflow_gpu(),
+    epochs: int = 10,
+    batch_size: int = 32,
+):
+    """
+    Train a model using KubeflowExecutor.
+
+    Args:
+        executor: KubeflowExecutor configuration
+        epochs: Number of training epochs
+        batch_size: Batch size for training
+    """
+    print("🚀 Starting training with KubeflowExecutor")
+    print(f"🔧 Executor: {executor}")
+    print(f"📊 Epochs: {epochs}, Batch Size: {batch_size}")
+
+    # Simulate training process
+    for epoch in range(epochs):
+        print(f"📈 Epoch {epoch + 1}/{epochs}")
+
+    print("✅ Training completed!")
+
+
+def main():
+    """Main function demonstrating KubeflowExecutor usage."""
+    logger.info("🚀 Starting KubeflowExecutor example")
+
+    # Create training script
+    script_path = create_training_script()
+    logger.info(f"📝 Created training script: {script_path}")
+
+    # Example 1: File-based execution
+    logger.info("📁 Example 1: File-based execution")
+
+    # Configure the packager
+    packager = ConfigMapPackager(include_pattern="*.py", relative_path=".", namespace="default")
+
+    # Create KubeflowExecutor for GPU training
+    gpu_executor = KubeflowExecutor(
+        nodes=2,
+        ntasks_per_node=8,
+        gpus_per_node=8,
+        cpu_limit="16",
+        memory_limit="64Gi",
+        namespace="default",
+        packager=packager,
+    )
+
+    # Example 2: CPU training
+    logger.info("⚙️ Example 2: CPU training")
+
+    cpu_executor = KubeflowExecutor(
+        nodes=1,
+        cpu_limit="8",
+        memory_limit="32Gi",
+        namespace="default",
+        packager=packager,
+    )
+
+    # Run experiments
+    logger.info("🎯 Running GPU training experiment")
+
+    with run.Experiment("kubeflow_gpu_training") as exp:
+        exp.add(
+            run.Script(inline="python train_script.py"),
+            executor=gpu_executor,
+            name="gpu_training",
+        )
+        # Submit and return without waiting (requires a reachable cluster)
+        exp.run(detach=True)
+
+    logger.info("🎯 Running CPU training experiment")
+
+    with run.Experiment("kubeflow_cpu_training") as exp:
+        exp.add(
+            run.Script(inline="python train_script.py"),
+            executor=cpu_executor,
+            name="cpu_training",
+        )
+        exp.run(detach=True)
+
+    # Clean up
+    if script_path.exists():
+        script_path.unlink()
+        logger.info(f"🧹 Cleaned up {script_path}")
+
+    logger.info("✅ KubeflowExecutor example completed!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nemo_run/core/execution/__init__.py b/nemo_run/core/execution/__init__.py
index 7c787a16..0537a5d5 100644
--- a/nemo_run/core/execution/__init__.py
+++ b/nemo_run/core/execution/__init__.py
@@ -14,6 +14,7 @@
 # limitations under the License.
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor +from nemo_run.core.execution.kubeflow import KubeflowExecutor from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.local import LocalExecutor from nemo_run.core.execution.skypilot import SkypilotExecutor @@ -25,4 +26,5 @@ "SkypilotExecutor", "DGXCloudExecutor", "LeptonExecutor", + "KubeflowExecutor", ] diff --git a/nemo_run/core/execution/kubeflow.py b/nemo_run/core/execution/kubeflow.py new file mode 100644 index 00000000..3f469eab --- /dev/null +++ b/nemo_run/core/execution/kubeflow.py @@ -0,0 +1,756 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Optional, Union + +import yaml +from kubeflow.trainer import CommandTrainer, TrainerClient +from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig +from kubernetes import client, config +from kubernetes.client.exceptions import ApiException + +from nemo_run.config import Partial, Script +from nemo_run.core.execution.base import Executor, ExecutorMacros +from nemo_run.core.execution.utils import ( + fill_template, +) +from nemo_run.core.packaging.base import sanitize_kubernetes_name +from nemo_run.core.packaging.configmap import ConfigMapPackager + +logger = logging.getLogger(__name__) + + +def _build_trainer_command(task, mounted_path: str) -> tuple[list[str], list[str]]: + """Return (command, args) for CommandTrainer based on task type/content. 
+ + - Partial: treat as python entry + - Script: use task.entrypoint/inline if present + """ + entrypoint = getattr(task, "entrypoint", "") + inline = getattr(task, "inline", "") + is_partial = hasattr(task, "__fn_or_cls__") + ep = "python" if is_partial else entrypoint.strip() + is_python = is_partial or bool(re.search(r"(^|/)?python(\d+(\.\d+)*)?$", ep, re.IGNORECASE)) + is_bash = bool(re.search(r"(^|/)?bash$", ep, re.IGNORECASE)) + + # Shared PET-derived rendezvous args + base_args: list[str] = [ + "--nnodes", + "${PET_NNODES}", + "--nproc_per_node", + "${PET_NPROC_PER_NODE}", + "--rdzv_backend", + "c10d", + "--rdzv_endpoint", + "${PET_MASTER_ADDR}:${PET_MASTER_PORT}", + ] + + # Pass-through for bash inline that already includes torchrun + if is_bash and re.search(r"(^|\s)torchrun(\s|$)", inline): + return [mounted_path], [] + + # Build args once; add --no-python for non-python entrypoints + args: list[str] = [*base_args] + if not is_python: + args.append("--no-python") + args.append(mounted_path) + + return ["torchrun"], args + + +def _materialize_task_content_for_staging(self, task) -> tuple[str, str]: + """Return (content, entrypoint) for staging Script or Partial into ConfigMap.""" + + def _read_text(file_path: str) -> str: + with open(file_path, "r", encoding="utf-8") as f: + return f.read() + + if hasattr(task, "inline") and task.inline: + entrypoint = getattr(task, "entrypoint", "bash") or "bash" + inline_val = task.inline.strip() + if inline_val.startswith("/") and inline_val.endswith(".sh"): + local_script_path = inline_val.replace("/nemo_run/scripts/", f"{self.job_dir}/scripts/") + if not os.path.exists(local_script_path): + raise FileNotFoundError(f"TorchX script file not found: {local_script_path}") + return _read_text(local_script_path), entrypoint + return inline_val, entrypoint + + if hasattr(task, "__fn_or_cls__"): + scripts_dir = os.path.join(self.job_dir, "scripts") + os.makedirs(scripts_dir, exist_ok=True) + script_filename = os.path.join(scripts_dir, f"{self.training_entry}.sh") + if hasattr(task, "to_command"): + _ = task.to_command(with_entrypoint=False, filename=script_filename, is_local=True) + content = _read_text(script_filename) + else: + raise ValueError("Cannot stage Partial: task does not support to_command()") + return content, "python" + + raise ValueError("Unsupported task type for staging") + + +@dataclass +class StorageMount: + """Generic storage mount configuration. + + kind="pvc" currently supported. Future kinds: hostPath, emptyDir, nfs. 
+ """ + + mount_path: str + read_only: bool = False + name: Optional[str] = None + + # PVC-specific + pvc_claim_name: Optional[str] = None + create_if_missing: bool = False + size: Optional[str] = None + storage_class: Optional[str] = None + access_modes: list[str] = field(default_factory=lambda: ["ReadWriteOnce"]) + kind: str = "pvc" + + def to_template_fragment(self, index: int) -> dict[str, Any]: + vol_name = self.get_volume_name(index) + claim_name_sanitized = self.get_pvc_claim_name() + if self.kind == "pvc" and self.pvc_claim_name: + return { + "name": vol_name, + "claim_name": claim_name_sanitized, + "mount_path": self.mount_path, + "read_only": self.read_only, + } + raise ValueError(f"Unsupported StorageMount config: {self}") + + def get_volume_name(self, index: int) -> str: + """Return a DNS-1123 safe volume name, defaulting to pvc-{index}.""" + base = self.name or f"pvc-{index}" + return sanitize_kubernetes_name(base).lower() + + def get_pvc_claim_name(self) -> Optional[str]: + """Return a DNS-1123 safe PVC claim name or None if unset.""" + if not self.pvc_claim_name: + return None + return sanitize_kubernetes_name(self.pvc_claim_name).lower() + + +@dataclass +class AdditionalPackages: + """Optional package installation configuration for the training container. + + Fields map directly to SDK `CommandTrainer` parameters. + """ + + packages_to_install: Optional[list[str]] = None + pip_index_urls: Optional[list[str]] = None + pip_extra_args: Optional[list[str]] = None + + def as_trainer_kwargs(self) -> Dict[str, Any]: + """Return subset of kwargs for CommandTrainer based on configured fields.""" + allowed = {"packages_to_install", "pip_index_urls", "pip_extra_args"} + return asdict( + self, + dict_factory=lambda items: { + k: (list(v) if isinstance(v, list) else v) for k, v in items if k in allowed and v + }, + ) + + +@dataclass(kw_only=True) +class KubeflowExecutor(Executor): + """ + Dataclass to configure Kubeflow executor for distributed training jobs. + + This executor uses the Kubeflow Trainer SDK to create and manage TrainJob objects. + It supports execution of tasks passed from the Experiment API (Script, Partial, Config). + + The actual execution details (torchrun vs python, command construction) are handled + by the Kubeflow SDK through the Runtime and Trainer objects. + + Example: + + . 
code-block:: python + + # Configure executor for execution environment + executor = KubeflowExecutor( + name="myexec", + namespace="default", + ) + + # Use with Experiment API + training_script = run.Script(inline="python train.py") + + with run.Experiment("training") as exp: + exp.add(training_script, executor=executor) + exp.run() + """ + + #: Number of nodes for distributed training + nodes: int = 1 + + #: Number of processes per node (typically matches number of GPUs) + ntasks_per_node: int = 1 + + #: Kubernetes namespace for the training job + namespace: str = "default" + + #: Resource limits for CPU + cpu_limit: Optional[str] = None + + #: Resource limits for memory + memory_limit: Optional[str] = None + + #: Number of GPUs per node to request + gpus_per_node: Optional[int] = None + + #: Container image for training jobs + container_image: str = "nvcr.io/nvidia/nemo:dev" + + #: Training job filename + training_entry: str = "experiment" + + #: Workspace mount path for staged files (default: /src) + workspace_mount_path: str = "/src" + + #: TrainerClient instance for managing TrainJob objects + _trainer_client: Optional[TrainerClient] = field(init=False, repr=False, default=None) + + #: Job name (set from task_id during assign) + job_name: str = field(init=False, default="") + + #: Current task being executed (set by Experiment API) + _current_task: Optional[Union[Script, Partial]] = None + + #: Kubernetes connectivity status + _kubernetes_available: bool = field(init=False, default=False) + + #: Detach mode flag (set by experiment framework) + _detach_mode: bool = field(init=False, default=False) + + #: Enable tcpxo sidecar and related mounts/env in runtime template + enable_tcpxo: bool = False + + storage_mounts: list["StorageMount"] = field(default_factory=list) + + #: Optional package installation configuration + additional_packages: Optional[AdditionalPackages] = None + + def __post_init__(self): + """Validate executor configuration and setup Kubernetes access.""" + if self.nodes < 1: + raise ValueError("nodes must be >= 1") + if self.ntasks_per_node < 1: + raise ValueError("ntasks_per_node must be >= 1") + + # Setup Kubernetes configuration + self._setup_kubernetes_config() + + def _setup_kubernetes_config(self): + """Setup Kubernetes configuration for ClusterTrainingRuntime operations.""" + try: + # Try in-cluster config first (when running inside Kubernetes) + config.load_incluster_config() + logger.info("Using in-cluster Kubernetes configuration") + except config.ConfigException: + try: + # Try local kubeconfig (when running locally) + config.load_kube_config() + logger.info("Using local kubeconfig") + except config.ConfigException: + logger.warning( + "Could not load Kubernetes configuration - ClusterTrainingRuntime operations require Kubernetes" + ) + self._kubernetes_available = False + return + + # Test Kubernetes connectivity + try: + api_client = client.CoreV1Api() + api_client.list_namespace() + logger.info("Kubernetes connectivity verified") + self._kubernetes_available = True + except Exception as e: + logger.warning(f"Kubernetes connectivity test failed: {e}") + self._kubernetes_available = False + + def assign( + self, + exp_id: str, + exp_dir: str, + task_id: str, + task_dir: str, + ): + """Assign experiment and task information to the executor.""" + self.experiment_id = exp_id + self.experiment_name = re.sub(r"([_\d]+)", "", exp_id) + self.experiment_dir = exp_dir + self.job_dir = os.path.join(exp_dir, task_dir) + self.job_name = task_id + + logger.info( + 
f"KubeflowExecutor assigned: experiment_id={self.experiment_id}, job_name={self.job_name}" + ) + + def set_detach_mode(self, detach: bool): + """Set detach mode for the executor.""" + self._detach_mode = detach + logger.info(f"KubeflowExecutor detach mode set to: {detach}") + + def nnodes(self) -> int: + """Return the number of nodes for distributed training.""" + return self.nodes + + def nproc_per_node(self) -> int: + """Return the number of processes per node.""" + return self.ntasks_per_node + + def macro_values(self) -> Optional[ExecutorMacros]: + return None + + def get_launcher_prefix(self) -> Optional[list[str]]: + """Get launcher prefix for profiling if enabled.""" + launcher = self.get_launcher() + if launcher and hasattr(launcher, "nsys_profile") and launcher.nsys_profile: + os.makedirs(os.path.join(self.job_dir, launcher.nsys_folder), exist_ok=True) + return launcher.get_nsys_prefix(profile_dir=self.job_dir) + return None + + def get_nsys_entrypoint(self) -> str: + """Get nsys entrypoint for profiling.""" + return "nsys" + + def supports_launcher_transform(self) -> bool: + """Return whether this executor supports launcher transforms.""" + return False + + def package_configs(self, *cfgs: tuple[str, str]) -> list[str]: + """Package configuration files for the job.""" + filenames = [] + basepath = os.path.join(self.job_dir, "configs") + os.makedirs(basepath, exist_ok=True) + for name, cfg in cfgs: + filename = os.path.join(basepath, name) + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "w") as f: + f.write(cfg) + filenames.append(filename) + return filenames + + def create_job_dir(self): + """Create the job directory.""" + os.makedirs(self.job_dir, exist_ok=True) + + def _get_trainer_client(self) -> TrainerClient: + """Get or create a TrainerClient instance.""" + if self._trainer_client is None: + # Initialize client with the executor's namespace + k8s_backend_config = KubernetesBackendConfig(namespace=self.namespace) + self._trainer_client = TrainerClient(backend_config=k8s_backend_config) + return self._trainer_client + + def _create_cluster_training_runtime(self, configmap_name: str, sha: str) -> str: + """Create or replace a ClusterTrainingRuntime bound to the given ConfigMap.""" + runtime_name = self._runtime_name(sha) + + if not hasattr(self, "_kubernetes_available") or not self._kubernetes_available: + raise RuntimeError("Kubernetes is not available; cannot create ClusterTrainingRuntime") + + api_client = client.CustomObjectsApi() + # Ensure storage objects exist prior to runtime creation + self._ensure_storage() + + # Ensure env secret exists prior to runtime creation + env_from_secrets: list[str] = self._ensure_env_secret(sha) + + template_vars = { + "runtime_name": runtime_name, + "namespace": self.namespace, + "nodes": self.nodes, + "image": self.container_image, + "workspace_mount_path": self.workspace_mount_path, + "configmap_name": configmap_name, + "cpu_limit": self.cpu_limit, + "memory_limit": self.memory_limit, + "gpus": self.gpus_per_node, + "num_proc_per_node": self.ntasks_per_node, + "enable_tcpxo": self.enable_tcpxo, + "storage_pvc_mounts": self._get_normalized_storage_mounts(), + "env_from_secrets": env_from_secrets, + } + rendered = fill_template( + template_name="kubeflow_clustertrainingruntime.yaml.j2", + variables=template_vars, + ) + runtime_body = yaml.safe_load(rendered) # type: ignore[assignment] + + try: + api_client.create_cluster_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + 
plural="clustertrainingruntimes", + body=runtime_body, + ) + logger.info(f"Created ClusterTrainingRuntime: {runtime_name}") + except ApiException as e: + if e.status == 409: + # Resource already exists, fetch it first to get resourceVersion + try: + existing_runtime_obj = api_client.get_cluster_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + plural="clustertrainingruntimes", + name=runtime_name, + ) + existing_runtime: Dict[str, Any] = existing_runtime_obj # type: ignore[assignment] + # Update the resourceVersion in our new body + runtime_body["metadata"]["resourceVersion"] = existing_runtime["metadata"][ + "resourceVersion" + ] # type: ignore[index] + + # Replace the existing ClusterTrainingRuntime + api_client.replace_cluster_custom_object( + group="trainer.kubeflow.org", + version="v1alpha1", + plural="clustertrainingruntimes", + name=runtime_name, + body=runtime_body, + ) + logger.info(f"Replaced existing ClusterTrainingRuntime: {runtime_name}") + except Exception as replace_error: + logger.error( + f"Failed to replace existing ClusterTrainingRuntime: {replace_error}" + ) + raise + else: + logger.error(f"Failed to create ClusterTrainingRuntime: {e}") + raise + return runtime_name + + def _ensure_storage(self) -> None: + """Create PVCs for storage_mounts with create_if_missing=True.""" + if not self.storage_mounts: + return + core_client = client.CoreV1Api() + for sm in self.storage_mounts: + if sm.kind != "pvc" or not sm.create_if_missing or not sm.pvc_claim_name: + continue + sanitized_claim = sm.get_pvc_claim_name() + try: + core_client.read_namespaced_persistent_volume_claim( + name=sanitized_claim, namespace=self.namespace + ) + continue + except ApiException as e: + if e.status != 404: + logger.warning(f"PVC check failed for {sm.pvc_claim_name}: {e}") + continue + pvc_yaml = fill_template( + template_name="kubeflow_pvc.yaml.j2", + variables={ + "name": sanitized_claim, + "namespace": self.namespace, + "size": sm.size or "100Gi", + "access_modes": sm.access_modes, + "storage_class": sm.storage_class, + }, + ) + pvc_manifest: Dict[str, Any] = yaml.safe_load(pvc_yaml) + try: + core_client.create_namespaced_persistent_volume_claim( + namespace=self.namespace, body=pvc_manifest + ) + logger.info(f"Created PVC {sm.pvc_claim_name} in {self.namespace}") + except ApiException as e: + if e.status == 409: + logger.info(f"PVC {sm.pvc_claim_name} already exists") + else: + logger.warning(f"Failed to create PVC {sm.pvc_claim_name}: {e}") + + def _get_normalized_storage_mounts(self) -> list[dict[str, Any]]: + """Normalize storage_mounts (currently kind=pvc) to template fragments.""" + normalized: list[dict[str, Any]] = [] + for j, sm in enumerate(self.storage_mounts, start=1): + try: + frag = sm.to_template_fragment(index=j) + normalized.append(frag) + except Exception: + continue + return normalized + + def _get_additional_files(self, task) -> dict[str, tuple[str, str]]: + """Get additional files to stage based on task type. 
+ + Returns: + Dict mapping filename to (content, entrypoint) tuples + """ + files_to_stage = {} + + if task is None: + return files_to_stage + + if (hasattr(task, "inline") and task.inline) or hasattr(task, "__fn_or_cls__"): + try: + content, entrypoint = _materialize_task_content_for_staging(self, task) + files_to_stage[self.training_entry] = (content, entrypoint) + logger.info("Staged task content in ConfigMap") + except Exception as e: + logger.warning(f"Failed staging task content: {e}") + + return files_to_stage + + def _ensure_env_secret(self, sha: str) -> list[str]: + """Ensure a Secret exists when env_vars are configured; return list of envFrom names.""" + if not self.env_vars: + return [] + generated_secret_name = self._env_secret_name(sha) + try: + core_client = client.CoreV1Api() + body = client.V1Secret( + metadata=client.V1ObjectMeta(name=generated_secret_name, namespace=self.namespace), + string_data=self.env_vars, + type="Opaque", + ) + core_client.create_namespaced_secret(namespace=self.namespace, body=body) + logger.info(f"Created Secret {generated_secret_name} in {self.namespace}") + except ApiException as e: + if e.status == 409: + # Secret exists; patch to ensure latest env_vars are reflected + try: + patch_body = {"stringData": self.env_vars, "type": "Opaque"} + core_client.patch_namespaced_secret( + name=generated_secret_name, namespace=self.namespace, body=patch_body + ) + logger.info( + f"Patched Secret {generated_secret_name} with updated stringData in {self.namespace}" + ) + except Exception as patch_err: + logger.warning(f"Failed to patch Secret {generated_secret_name}: {patch_err}") + else: + logger.warning(f"Failed to create Secret {generated_secret_name}: {e}") + return [generated_secret_name] + + def stage_files(self, task_dir: str, task=None) -> tuple[str, str]: + """Stage files using the packager. + + Adds additional files based on task content and packages along with + any original files configured on the packager. Returns the ConfigMap name. 
+ """ + if not isinstance(self.packager, ConfigMapPackager): + return (task_dir, "") + + # Get additional files to stage based on task type + additional_files = self._get_additional_files(task) + + # Stage all additional files + experiment_id = self._get_experiment_identifier() + for filename, (content, entrypoint) in additional_files.items(): + self.packager.add_file(experiment_id, filename, content, entrypoint=entrypoint) + + try: + configmap_name, sha = self.packager.package_with_hash(experiment_id) + logger.info(f"Staged files into ConfigMap: {configmap_name} (sha={sha or 'n/a'})") + return (configmap_name, sha) + except Exception as e: + logger.error(f"Failed to stage files: {e}") + raise + + def _get_experiment_identifier(self) -> str: + """Return experiment_id; raise if not assigned yet.""" + if hasattr(self, "experiment_name") and self.experiment_name: + return f"{self.experiment_name}" + raise RuntimeError("Executor not assigned to experiment; missing experiment_name") + + def cleanup_files(self, task_dir: str, task=None): + """Clean up staged files.""" + if isinstance(self.packager, ConfigMapPackager): + # Use experiment-specific naming for cleanup + self.packager.cleanup(self._get_experiment_identifier()) + + def _get_custom_trainer(self, task) -> CommandTrainer: + """Build a CommandTrainer for a Script or Partial task using launcher semantics.""" + + resources_per_node: dict = {} + if self.cpu_limit is not None: + resources_per_node["cpu"] = self.cpu_limit + if self.memory_limit is not None: + resources_per_node["memory"] = self.memory_limit + if self.gpus_per_node is not None: + resources_per_node["nvidia.com/gpu"] = str(self.gpus_per_node) + + mounted_path = f"{self.workspace_mount_path}/{self.training_entry}" + command, args = _build_trainer_command(task, mounted_path) + + trainer_kwargs: Dict[str, Any] = { + "command": command, + "args": args, + "num_nodes": self.nodes, + "resources_per_node": resources_per_node, + } + if self.additional_packages: + trainer_kwargs.update(self.additional_packages.as_trainer_kwargs()) + + trainer = CommandTrainer(**trainer_kwargs) + + logger.info( + f"CommandTrainer created with command={trainer.command}, args={trainer.args}, " + f"num_nodes={trainer.num_nodes}, resources_per_node={trainer.resources_per_node}" + ) + + return trainer + + def create_trainjob(self, job_name: str, task, runtime_name: str) -> str: + """Create a TrainJob using the Kubeflow SDK.""" + try: + client = self._get_trainer_client() + trainer = self._get_custom_trainer(task) + runtime = client.get_runtime(runtime_name) + job_id = client.train(runtime=runtime, trainer=trainer) + logger.info(f"Created TrainJob: {job_id}") + return job_id + except Exception as e: + logger.error(f"Failed to create TrainJob: {e}") + raise + + def get_trainjob_status(self, job_name: str) -> str: + """Get the status of a TrainJob.""" + try: + client = self._get_trainer_client() + job = client.get_job(job_name) + return job.status or "Unknown" + except Exception as e: + logger.error(f"Failed to get TrainJob status: {e}") + return "Unknown" + + def delete_trainjob(self, job_name: str): + """Delete a TrainJob.""" + try: + client = self._get_trainer_client() + client.delete_job(job_name) + logger.info(f"Deleted TrainJob: {job_name}") + except Exception as e: + logger.error(f"Failed to delete TrainJob: {e}") + + def get_trainjob_logs(self, job_name: str, follow: bool = False): + """Get logs from a TrainJob.""" + try: + client = self._get_trainer_client() + logs_iter = client.get_job_logs(job_name, 
follow=follow) + # Some tests mock this as a dict; in real SDK it's an Iterator[str] + if isinstance(logs_iter, dict): + return logs_iter + return logs_iter + except Exception as e: + logger.error(f"Failed to get TrainJob logs: {e}") + return {} + + def prepare_runtime(self, task=None) -> tuple[str, str]: + """Atomically prepare runtime dependencies for this executor. + + Steps: + - Create a unique ConfigMap for this experiment that includes: + * Initial training code (from ConfigMapPackager) + * Dynamic experiment scripts (created during task execution) + - Create a unique ClusterTrainingRuntime that references that ConfigMap + + Returns (runtime_name, sha). Raises on failure so callers don't proceed to submit(). + """ + # Stage files to ensure we have the latest content and ConfigMap + configmap_name, sha = self.stage_files(task_dir="", task=task) + + # Create runtime bound to this ConfigMap + try: + runtime_name = self._create_cluster_training_runtime( + configmap_name=configmap_name, sha=sha + ) + logger.info(f"Prepared runtime: {runtime_name}") + return (runtime_name, sha) + except Exception: + raise + + def submit(self, task, job_name: str) -> str: + """ + Submit a job using the Kubeflow SDK. + + Prepares the ConfigMap and ClusterTrainingRuntime (idempotent) and + then creates the TrainJob. + """ + if not hasattr(self, "experiment_id") or not self.experiment_id: + raise RuntimeError("Executor not assigned to experiment") + + try: + # Prepare runtime dependencies (stages files and creates runtime) + runtime_name, _ = self.prepare_runtime(task=task) + + job_id = self.create_trainjob(job_name, task, runtime_name) + logger.info(f"Submitted job {job_name} with ID: {job_id}") + return job_id + + except Exception as e: + logger.error(f"Failed to submit job {job_name}: {e}") + raise + + def monitor(self, job_id: str) -> str: + """Monitor the status of a job.""" + if not hasattr(self, "experiment_id") or not self.experiment_id: + raise RuntimeError("Executor not assigned to experiment") + try: + status = self.get_trainjob_status(job_id) + logger.debug(f"Job {job_id} status: {status}") + return status + except Exception as e: + logger.error(f"Failed to monitor job {job_id}: {e}") + return "Unknown" + + def cleanup(self, handle: str) -> None: + """Clean up resources associated with a job.""" + if not hasattr(self, "experiment_id") or not self.experiment_id: + raise RuntimeError("Executor not assigned to experiment") + try: + logger.info( + "KubeflowExecutor.cleanup: not deleting job or runtime; align with non-TorchX executors (Lepton/DGXCloud)" + ) + return + except Exception as e: + logger.error(f"Failed to cleanup job {handle}: {e}") + + def info(self) -> str: + """Get information about the executor configuration.""" + return f"KubeflowExecutor (nodes={self.nodes}, gpus={self.gpus_per_node or 0})" + + def _runtime_name(self, sha: str) -> str: + """Build CRT name from the shared experiment identifier and sha.""" + identifier = self._get_experiment_identifier() + return sanitize_kubernetes_name(f"nemo-runtime-{identifier}-{sha}") + + def _env_secret_name(self, sha: str) -> str: + """Return a deterministic Secret name for env vars derived from experiment+sha.""" + identifier = self._get_experiment_identifier() + return sanitize_kubernetes_name(f"nemo-env-{identifier}-{sha}") + + def _get_staged_file_path(self, filename: str) -> str: + """Return path where a staged file would be mounted inside the container. 
+ + If using ConfigMapPackager, files are mounted under workspace_mount_path with + experiment-specific prefix. Otherwise, return the filename unchanged. + """ + if ( + isinstance(self.packager, ConfigMapPackager) + and hasattr(self, "experiment_name") + and self.experiment_name + ): + return f"{self.workspace_mount_path}/{self.experiment_name}-{filename}" + return filename diff --git a/nemo_run/core/execution/templates/kubeflow_clustertrainingruntime.yaml.j2 b/nemo_run/core/execution/templates/kubeflow_clustertrainingruntime.yaml.j2 new file mode 100644 index 00000000..198e4e33 --- /dev/null +++ b/nemo_run/core/execution/templates/kubeflow_clustertrainingruntime.yaml.j2 @@ -0,0 +1,157 @@ +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: {{ runtime_name }} + namespace: {{ namespace }} + labels: + trainer.kubeflow.org/framework: torch +spec: + mlPolicy: + numNodes: {{ nodes }} + torch: + numProcPerNode: {{ num_proc_per_node }} + template: + spec: + replicatedJobs: + - name: node + replicas: 1 + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: trainer + spec: + template: + metadata: + {% if enable_tcpxo %} + annotations: + devices.gke.io/container.tcpxo-daemon: | + - path: /dev/nvidia0 + - path: /dev/nvidia1 + - path: /dev/nvidia2 + - path: /dev/nvidia3 + - path: /dev/nvidia4 + - path: /dev/nvidia5 + - path: /dev/nvidia6 + - path: /dev/nvidia7 + - path: /dev/nvidiactl + - path: /dev/nvidia-uvm + - path: /dev/dmabuf_import_helper + networking.gke.io/default-interface: eth0 + networking.gke.io/interfaces: | + [ + {"interfaceName":"eth0","network":"default"}, + {"interfaceName":"eth1","network":"vpc1"}, + {"interfaceName":"eth2","network":"vpc2"}, + {"interfaceName":"eth3","network":"vpc3"}, + {"interfaceName":"eth4","network":"vpc4"}, + {"interfaceName":"eth5","network":"vpc5"}, + {"interfaceName":"eth6","network":"vpc6"}, + {"interfaceName":"eth7","network":"vpc7"}, + {"interfaceName":"eth8","network":"vpc8"} + ] + {% endif %} + spec: + affinity: + podAntiAffinity: + requiredDuringSchedulingIgnoredDuringExecution: + - labelSelector: + matchExpressions: + - key: jobset.sigs.k8s.io/replicatedjob-name + operator: In + values: + - node + topologyKey: kubernetes.io/hostname + volumes: + - name: workspace + configMap: + name: {{ configmap_name }} + defaultMode: 0755 + {% if storage_pvc_mounts %} + {% for pvc in storage_pvc_mounts %} + - name: {{ pvc.name }} + persistentVolumeClaim: + claimName: {{ pvc.claim_name }} + {% endfor %} + {% endif %} + - name: libraries + hostPath: + path: /home/kubernetes/bin/nvidia/lib64 + - name: sys + hostPath: + path: /sys + - name: proc-sys + hostPath: + path: /proc/sys + - name: aperture-devices + hostPath: + path: /dev/aperture_devices + - name: dshm + emptyDir: + medium: Memory + sizeLimit: 2048Gi + initContainers: + {% if enable_tcpxo %} + - name: tcpxo-daemon + image: us-docker.pkg.dev/gce-ai-infra/gpudirect-tcpxo/tcpgpudmarxd-dev:v1.0.15 + imagePullPolicy: Always + restartPolicy: Always + command: ["/bin/sh", "-c"] + args: + - | + set -ex + chmod 755 /fts/entrypoint_rxdm_container.sh + /fts/entrypoint_rxdm_container.sh --num_hops=2 --num_nics=8 --uid= --alsologtostderr + env: + - name: LD_LIBRARY_PATH + value: /usr/local/nvidia/lib64 + securityContext: + capabilities: + add: + - NET_ADMIN + - NET_BIND_SERVICE + volumeMounts: + - name: libraries + mountPath: /usr/local/nvidia + - name: sys + mountPath: /hostsysfs + - name: proc-sys + mountPath: /hostprocsysfs + {% endif %} + containers: + - 
name: node + image: {{ image }} + env: + - name: LD_LIBRARY_PATH + value: /usr/local/nvidia/lib64 + - name: NCCL_FASTRAK_LLCM_DEVICE_DIRECTORY + value: /dev/aperture_devices + {% if env_from_secrets and env_from_secrets|length > 0 %} + envFrom: + {% for s in env_from_secrets %} + - secretRef: + name: {{ s }} + {% endfor %} + {% endif %} + volumeMounts: + - name: workspace + mountPath: {{ workspace_mount_path }} + {% if storage_pvc_mounts %} + {% for pvc in storage_pvc_mounts %} + - name: {{ pvc.name }} + mountPath: {{ pvc.mount_path }} + {% if pvc.read_only %}readOnly: true{% endif %} + {% endfor %} + {% endif %} + - name: dshm + mountPath: /dev/shm + - name: aperture-devices + mountPath: /dev/aperture_devices + resources: + requests: + {% if cpu_limit %}cpu: {{ cpu_limit }}{% endif %} + {% if memory_limit %}memory: {{ memory_limit }}{% endif %} + {% if gpus %}"nvidia.com/gpu": {{ gpus }}{% endif %} + limits: + {% if cpu_limit %}cpu: {{ cpu_limit }}{% endif %} + {% if memory_limit %}memory: {{ memory_limit }}{% endif %} + {% if gpus %}"nvidia.com/gpu": {{ gpus }}{% endif %} diff --git a/nemo_run/core/execution/templates/kubeflow_pvc.yaml.j2 b/nemo_run/core/execution/templates/kubeflow_pvc.yaml.j2 new file mode 100644 index 00000000..6da137ef --- /dev/null +++ b/nemo_run/core/execution/templates/kubeflow_pvc.yaml.j2 @@ -0,0 +1,16 @@ +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: {{ name }} + namespace: {{ namespace }} +spec: + accessModes: + {% for mode in access_modes %} + - {{ mode }} + {% endfor %} + resources: + requests: + storage: {{ size }} + {% if storage_class %} + storageClassName: {{ storage_class }} + {% endif %} diff --git a/nemo_run/core/packaging/__init__.py b/nemo_run/core/packaging/__init__.py index 2d935ccc..4ca65b92 100644 --- a/nemo_run/core/packaging/__init__.py +++ b/nemo_run/core/packaging/__init__.py @@ -14,8 +14,15 @@ # limitations under the License. from nemo_run.core.packaging.base import Packager +from nemo_run.core.packaging.configmap import ConfigMapPackager from nemo_run.core.packaging.git import GitArchivePackager from nemo_run.core.packaging.hybrid import HybridPackager from nemo_run.core.packaging.pattern import PatternPackager -__all__ = ["Packager", "GitArchivePackager", "PatternPackager", "HybridPackager"] +__all__ = [ + "Packager", + "ConfigMapPackager", + "GitArchivePackager", + "PatternPackager", + "HybridPackager", +] diff --git a/nemo_run/core/packaging/base.py b/nemo_run/core/packaging/base.py index 95bd25d0..5a65023c 100644 --- a/nemo_run/core/packaging/base.py +++ b/nemo_run/core/packaging/base.py @@ -23,6 +23,22 @@ logger = logging.getLogger(__name__) +def sanitize_kubernetes_name(name: str) -> str: + """ + Sanitize a string to be used as a Kubernetes resource name. + + Replaces underscores with hyphens to comply with RFC 1123 subdomain rules. + This is a common pattern used across the codebase for Kubernetes resource naming. + + Args: + name: The string to sanitize + + Returns: + A sanitized string suitable for use as a Kubernetes resource name + """ + return name.replace("_", "-") + + @dataclass(kw_only=True) class Packager(ConfigurableMixin): """ diff --git a/nemo_run/core/packaging/configmap.py b/nemo_run/core/packaging/configmap.py new file mode 100644 index 00000000..4b48519a --- /dev/null +++ b/nemo_run/core/packaging/configmap.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from kubernetes import client, config
+from kubernetes.client.rest import ApiException
+from kubernetes.config.config_exception import ConfigException
+
+from nemo_run.core.packaging.base import Packager, sanitize_kubernetes_name
+
+logger = logging.getLogger(__name__)
+
+
+# Kubernetes limits a ConfigMap to 1MiB in total, so use that as a conservative cap.
+MAX_CONFIGMAP_SIZE = 1024 * 1024  # 1MiB
+
+
+@dataclass(kw_only=True)
+class ConfigMapPackager(Packager):
+    """
+    Packages files into a Kubernetes ConfigMap for use in distributed jobs.
+    """
+
+    include_pattern: str | List[str] = "*.py"
+    relative_path: str | List[str] = "."
+    namespace: str = "default"
+    configmap_prefix: str = "nemo-workspace"
+    base_path: Optional[Path] = None
+    key_prefix: Optional[str] = None
+
+    # Internal store for additional in-memory files per experiment identifier
+    _additional_files: Dict[str, Dict[str, str]] = field(
+        default_factory=dict
+    )  # experiment_id -> {filename: content}
+
+    def __post_init__(self):
+        """Initialize the Kubernetes client, falling back from in-cluster config to kubeconfig."""
+        try:
+            config.load_incluster_config()
+            self.v1 = client.CoreV1Api()
+        except ConfigException:
+            try:
+                config.load_kube_config()
+                self.v1 = client.CoreV1Api()
+            except ConfigException:
+                logger.warning(
+                    "Could not load Kubernetes config, ConfigMap creation will be skipped"
+                )
+                self.v1 = None
+
+    def get_container_file_path(self, filename: str, volume_mount_path: str = "/workspace") -> str:
+        """
+        Get the path at which a packaged file is accessible inside the container.
+
+        Files are stored in the ConfigMap under sanitized keys (forward slashes
+        replaced with hyphens) and the ConfigMap volume is mounted at
+        ``volume_mount_path``, so the container path is the mount path joined
+        with the sanitized key.
+
+        Args:
+            filename: The file path (relative to the packaging root) to resolve
+            volume_mount_path: The ConfigMap volume mount path in the container
+
+        Returns:
+            The full path where the file is accessible in the container
+        """
+        configmap_key = self._sanitize_configmap_key(Path(filename))
+        return f"{volume_mount_path}/{configmap_key}"
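+
+    # Illustration (hypothetical filename, default mount path): a file staged as
+    # "src/train.py" is stored under the ConfigMap key "src-train.py", so
+    # get_container_file_path("src/train.py") returns "/workspace/src-train.py".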
+
+    def _sanitize_configmap_key(self, rel_path: Path) -> str:
+        """
+        Sanitize a file path into a valid Kubernetes ConfigMap key.
+
+        ConfigMap keys cannot contain forward slashes (/), so we replace them
+        with hyphens (-), producing a flat key that still reflects the file's
+        original relative path.
+
+        Args:
+            rel_path: Relative path of the file from the base directory
+
+        Returns:
+            A sanitized ConfigMap key that complies with Kubernetes naming rules
+        """
+        # Replace forward slashes with hyphens to satisfy the key format.
+        # Preserve underscores and dots in file names: only the ConfigMap NAME must
+        # be DNS-1123 safe; keys may contain underscores. See the ConfigMaps docs:
+        # https://kubernetes.io/docs/concepts/configuration/configmap/
+        return str(rel_path).replace("/", "-")
+
+    def package_default(self, name: str) -> str:
+        """
+        Package using internal defaults so callers only provide a name.
+
+        - base_path: defaults to Path.cwd()
+        - key_prefix: defaults to the sanitized name
+
+        The raw name is passed through to package(), which resolves the fully
+        prefixed ConfigMap name exactly once.
+        """
+        path = self.base_path or Path.cwd()
+        job_dir = self.key_prefix or sanitize_kubernetes_name(name)
+        return self.package(path=path, job_dir=job_dir, name=name)
+
+    def add_file(
+        self,
+        experiment_identifier: str,
+        filename: str,
+        content: str,
+        entrypoint: Optional[str] = None,
+    ) -> None:
+        """Add an in-memory file to be included for a specific experiment.
+
+        The content is normalized by ensuring a shebang exists at the top. The
+        interpreter is selected based on the provided entrypoint hint.
+
+        Args:
+            experiment_identifier: Logical experiment key used to group files
+            filename: The file name to expose inside the ConfigMap mount
+            content: Raw file content
+            entrypoint: Optional hint ("python" or "bash"); defaults to python
+        """
+        normalized = content or ""
+        leading = normalized.lstrip()
+        if not leading.startswith("#!"):
+            ep = (entrypoint or "python").lower()
+            shebang = "#!/usr/bin/env python3" if "python" in ep else "#!/usr/bin/env bash"
+            normalized = f"{shebang}\n{normalized}"
+
+        if experiment_identifier not in self._additional_files:
+            self._additional_files[experiment_identifier] = {}
+        self._additional_files[experiment_identifier][filename] = normalized
+
+    def package_with_hash(self, name: str) -> tuple[str, str]:
+        """Package files and return (configmap_name, sha) based on content.
+
+        This method collects files from disk based on include_pattern/relative_path
+        and merges them with any additional in-memory files previously added via
+        add_file(...). It computes a content hash over all entries (stable ordering)
+        and uses that to produce a deterministic, content-addressed ConfigMap name.
+
+        Args:
+            name: Experiment identifier used to group additional files and as key prefix
+
+        Returns:
+            Tuple of (configmap_name, short sha256 hex digest)
+        """
+        base_path = self.base_path or Path.cwd()
+
+        # Collect files from disk
+        files_to_stage = self._find_files_to_package(base_path)
+
+        configmap_data: Dict[str, str] = {}
+        for file_path in files_to_stage:
+            rel_path = file_path.relative_to(base_path)
+            configmap_key = self._sanitize_configmap_key(rel_path)
+            try:
+                with open(file_path, "r", encoding="utf-8") as f:
+                    configmap_data[configmap_key] = f.read()
+            except Exception as e:
+                logger.warning(f"Could not read file {file_path}: {e}")
+
+        # Merge additional in-memory files
+        for fname, fcontent in self._additional_files.get(name, {}).items():
+            rel_path = Path(fname)
+            configmap_key = self._sanitize_configmap_key(rel_path)
+            configmap_data[configmap_key] = fcontent
+
+        if not configmap_data:
+            logger.warning("No files found to package into ConfigMap")
+            # Fallback name without hash
+            return (self.resolve_configmap_name(name), "")
+
+        # Enforce size limit
+        total_size = sum(len(v.encode("utf-8")) for v in configmap_data.values())
+        if total_size > MAX_CONFIGMAP_SIZE:
+            logger.error(
+                f"Total content size ({total_size} bytes) exceeds ConfigMap limit ({MAX_CONFIGMAP_SIZE} bytes)."
+ ) + return (self.resolve_configmap_name(name), "") + + # Compute hash over sorted keys and contents + hasher = hashlib.sha256() + for key in sorted(configmap_data.keys()): + hasher.update(key.encode("utf-8")) + hasher.update(b"\0") + hasher.update(configmap_data[key].encode("utf-8")) + + sha = hasher.hexdigest()[:8] + configmap_name = self.resolve_configmap_name(f"{name}-{sha}") + + if self.v1 is None: + logger.warning("Kubernetes client not available, skipping ConfigMap creation") + return (configmap_name, sha) + + body = client.V1ConfigMap( + metadata=client.V1ObjectMeta(name=configmap_name), data=configmap_data + ) + try: + self.v1.create_namespaced_config_map(namespace=self.namespace, body=body) + logger.info( + f"Created ConfigMap: {configmap_name} with {len(configmap_data)} files (sha={sha})" + ) + except ApiException as e: + if e.status == 409: + logger.info( + f"ConfigMap already exists (content-addressed): {configmap_name} (sha={sha})" + ) + else: + logger.error(f"Failed to create ConfigMap {configmap_name}: {e}") + return (configmap_name, sha) + + def package(self, path: Path, job_dir: str, name: str) -> str: + """ + Package files into a Kubernetes ConfigMap. + Args: + path: Base path to search for files + job_dir: Directory prefix for organizing files within the ConfigMap + name: Name for the ConfigMap + Returns: + The name of the created ConfigMap (or intended name if not created) + """ + # Resolve the final ConfigMap name centrally + configmap_name = self.resolve_configmap_name(name) + + if self.v1 is None: + logger.warning("Kubernetes client not available, skipping ConfigMap creation") + return configmap_name + files_to_stage = self._find_files_to_package(path) + if not files_to_stage: + logger.warning("No files found to package into ConfigMap") + return configmap_name + + # Check total size of files to be staged + total_size = sum(file_path.stat().st_size for file_path in files_to_stage) + if total_size > MAX_CONFIGMAP_SIZE: + logger.error( + f"Total file size ({total_size} bytes) exceeds ConfigMap limit ({MAX_CONFIGMAP_SIZE} bytes). " + f"Consider using a different staging method for large files." 
+ ) + return configmap_name + + if self.debug: + logger.debug( + f"Found {len(files_to_stage)} files to package (total size: {total_size} bytes)" + ) + for file_path in files_to_stage: + logger.debug(f" - {file_path} ({file_path.stat().st_size} bytes)") + + configmap_data = {} + for file_path in files_to_stage: + rel_path = file_path.relative_to(path) + # Use the sanitization method to create a valid ConfigMap key + configmap_key = self._sanitize_configmap_key(rel_path) + try: + with open(file_path, "r", encoding="utf-8") as f: + configmap_data[configmap_key] = f.read() + except Exception as e: + logger.warning(f"Could not read file {file_path}: {e}") + + if not configmap_data: + logger.warning("No files could be read for ConfigMap") + return configmap_name + + body = client.V1ConfigMap( + metadata=client.V1ObjectMeta(name=configmap_name), data=configmap_data + ) + try: + self.v1.create_namespaced_config_map(namespace=self.namespace, body=body) + logger.info(f"Created ConfigMap: {configmap_name} with {len(configmap_data)} files") + except ApiException as e: + if e.status == 409: + # Update existing ConfigMap with new data + try: + self.v1.replace_namespaced_config_map( + name=configmap_name, namespace=self.namespace, body=body + ) + logger.info( + f"Replaced ConfigMap: {configmap_name} with {len(configmap_data)} files" + ) + except ApiException as e2: + logger.error(f"Failed to replace ConfigMap {configmap_name}: {e2}") + else: + logger.error(f"Failed to create ConfigMap {configmap_name}: {e}") + return configmap_name + + def resolve_configmap_name(self, name: str) -> str: + """ + Resolve the full ConfigMap name from a caller-provided suffix. + + Centralizes naming logic so callers never assemble full names. + Ensures the final name has the prefix exactly once. + """ + return sanitize_kubernetes_name(f"{self.configmap_prefix}-{name}") + + def _find_files_to_package(self, base_path: Path) -> List[Path]: + """ + Find files to package based on include_pattern and relative_path. + Args: + base_path: The base directory to search from + Returns: + List of Path objects for files to include + """ + files = [] + patterns = ( + [self.include_pattern] + if isinstance(self.include_pattern, str) + else self.include_pattern + ) + rel_paths = ( + [self.relative_path] if isinstance(self.relative_path, str) else self.relative_path + ) + for pattern, rel_path in zip(patterns, rel_paths): + search_path = base_path / rel_path + if search_path.exists(): + for file_path in search_path.rglob(pattern): + if file_path.is_file(): + files.append(file_path) + return sorted(set(files)) + + def cleanup(self, name: str) -> None: + """ + Delete the ConfigMap from Kubernetes. 
+        Args:
+            name: The name suffix of the ConfigMap to delete
+        """
+        if self.v1 is None:
+            return
+        # Use the same resolution logic as in package()
+        configmap_name = self.resolve_configmap_name(name)
+        try:
+            self.v1.delete_namespaced_config_map(name=configmap_name, namespace=self.namespace)
+            logger.info(f"Cleaned up ConfigMap: {configmap_name}")
+        except ApiException as e:
+            if e.status == 404:
+                logger.info(f"ConfigMap {configmap_name} not found")
+            else:
+                logger.error(f"Failed to clean up ConfigMap {configmap_name}: {e}")
diff --git a/nemo_run/run/experiment.py b/nemo_run/run/experiment.py
index 49b9e43e..e1e89762 100644
--- a/nemo_run/run/experiment.py
+++ b/nemo_run/run/experiment.py
@@ -52,6 +52,7 @@
 from nemo_run.core.execution.base import Executor
 from nemo_run.core.execution.dgxcloud import DGXCloudExecutor
 from nemo_run.core.execution.docker import DockerExecutor
+from nemo_run.core.execution.kubeflow import KubeflowExecutor
 from nemo_run.core.execution.lepton import LeptonExecutor
 from nemo_run.core.execution.local import LocalExecutor
 from nemo_run.core.execution.skypilot import SkypilotExecutor
@@ -204,12 +205,14 @@ class Experiment(ConfigurableMixin):
         DockerExecutor,
         DGXCloudExecutor,
         LeptonExecutor,
+        KubeflowExecutor,
     )
     _DETACH_SUPPORTED_EXECUTORS = (
         SlurmExecutor,
         SkypilotExecutor,
         DGXCloudExecutor,
         LeptonExecutor,
+        KubeflowExecutor,
     )
     _DEPENDENCY_SUPPORTED_EXECUTORS = (SlurmExecutor,)
     _RUNNER_DEPENDENT_EXECUTORS = (LocalExecutor,)
@@ -339,6 +342,7 @@ def __init__(
         self.log_level = log_level
 
         self._runner = get_runner(component_defaults=None, experiment=self)
+        self._detach_mode = False  # Will be set in _run_dag
 
         if not _reconstruct:
             self.executor = executor if executor else LocalExecutor()
@@ -468,6 +472,13 @@ def _add_single_job(
             task_dir=name if reuse_job_dir else task_id,
         )
 
+        # Propagate detach mode to executors that support it
+        if hasattr(self, "detach") and hasattr(executor, "set_detach_mode"):
+            self.console.log(
+                f"Setting detach mode to {self.detach} on executor {type(executor).__name__}"
+            )
+            executor.set_detach_mode(self.detach)
+
         cloned = copy.deepcopy(task) if isinstance(task, Script) else task.clone()
         job = Job(
             id=task_id,
@@ -780,6 +801,10 @@ def _run_dag(self, detach: bool, tail_logs: bool, executors: set[Executor]):
                 )
                 wait = False
         self.detach = detach
+        self._detach_mode = detach
+
+        # Recreate the runner so detach mode is forwarded to scheduler params
+        self._runner = get_runner(component_defaults=None, detach_mode=detach, experiment=self)
 
         for level in order:
             # Launch jobs in this level concurrently since they are independent
diff --git a/nemo_run/run/torchx_backend/runner.py b/nemo_run/run/torchx_backend/runner.py
index 7de27e83..bb93987c 100644
--- a/nemo_run/run/torchx_backend/runner.py
+++ b/nemo_run/run/torchx_backend/runner.py
@@ -112,6 +112,7 @@ def schedule(self, dryrun_info: AppDryRunInfo) -> AppHandle:
 
 def get_runner(
     component_defaults: Optional[dict[str, dict[str, str]]] = None,
+    detach_mode: bool = False,
     **scheduler_params: Any,
 ) -> Runner:
     """
@@ -144,5 +145,9 @@
     """
     name = "nemo_run"
 
+    # Add detach_mode to scheduler_params for the kubeflow scheduler
+    if 
detach_mode: + scheduler_params["detach_mode"] = detach_mode + scheduler_factories = get_scheduler_factories() return Runner(name, scheduler_factories, component_defaults, scheduler_params=scheduler_params) diff --git a/nemo_run/run/torchx_backend/schedulers/api.py b/nemo_run/run/torchx_backend/schedulers/api.py index 5ade157d..b971ec90 100644 --- a/nemo_run/run/torchx_backend/schedulers/api.py +++ b/nemo_run/run/torchx_backend/schedulers/api.py @@ -20,6 +20,7 @@ from nemo_run.core.execution.base import Executor from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.docker import DockerExecutor +from nemo_run.core.execution.kubeflow import KubeflowExecutor from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.local import LocalExecutor from nemo_run.core.execution.skypilot import SkypilotExecutor @@ -32,6 +33,7 @@ DockerExecutor: "docker_persistent", DGXCloudExecutor: "dgx_cloud", LeptonExecutor: "lepton", + KubeflowExecutor: "kubeflow", } REVERSE_EXECUTOR_MAPPING: dict[str, Type[Executor]] = { @@ -41,6 +43,7 @@ "docker_persistent": DockerExecutor, "dgx_cloud": DGXCloudExecutor, "lepton": LeptonExecutor, + "kubeflow": KubeflowExecutor, } diff --git a/nemo_run/run/torchx_backend/schedulers/kubeflow.py b/nemo_run/run/torchx_backend/schedulers/kubeflow.py new file mode 100644 index 00000000..e2f6bc98 --- /dev/null +++ b/nemo_run/run/torchx_backend/schedulers/kubeflow.py @@ -0,0 +1,238 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from datetime import datetime +from typing import Any, Iterable, Optional + +from torchx.schedulers.api import ( + AppDryRunInfo, + DescribeAppResponse, + Stream, +) +from torchx.specs.api import AppDef, AppState + +from nemo_run.core.execution.base import Executor +from nemo_run.core.execution.kubeflow import KubeflowExecutor +from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin + +logger = logging.getLogger(__name__) + + +class KubeflowScheduler(SchedulerMixin): + """ + TorchX scheduler for Kubeflow Trainer. + + This scheduler integrates with the KubeflowExecutor to submit and manage + training jobs using the Kubeflow Trainer SDK. + """ + + def __init__( + self, + session_name: str, + namespace: str = "default", + detach_mode: bool = False, + **kwargs: Any, + ) -> None: + self.backend = "kubeflow" + self.session_name = session_name + self.namespace = namespace + self.detach_mode = detach_mode + self._apps: dict[str, dict[str, Any]] = {} + + def _submit_dryrun(self, app: AppDef, cfg: Executor) -> AppDryRunInfo[dict[str, Any]]: + """Create a dry run info for the Kubeflow job.""" + assert isinstance(cfg, KubeflowExecutor), ( + f"{cfg.__class__} not supported for kubeflow scheduler." 
+ ) + + # Convert AppDef to Kubeflow job configuration + job_config = self._appdef_to_kubeflow_config(app, cfg) + + return AppDryRunInfo( + job_config, + lambda _: f"Kubeflow job: {app.name}", + ) + + def schedule(self, dryrun_info: AppDryRunInfo[dict[str, Any]]) -> str: + """Submit the job to Kubeflow.""" + job_config = dryrun_info.request + cfg = job_config["executor"] + + # Create the TrainJob using KubeflowExecutor + # Extract the task from the app definition + app = job_config["app"] + task = None + + # Try to extract task from the app roles + if app.roles and len(app.roles) > 0: + main_role = app.roles[0] + if main_role.args: + # Create a simple task object for the executor + from nemo_run.config import Script + + task = Script(inline=" ".join(main_role.args)) + + if task is None: + # Create a default task if none found + from nemo_run.config import Script + + task = Script(inline="echo 'No task specified'") + + # Delegate fully to executor; it handles ConfigMap/CRT prep and TrainJob creation + job_id = cfg.submit(task, app.name) + + # Store job info for later reference + self._apps[job_id] = { + "app": job_config["app"], + "executor": cfg, + "job_id": job_id, + "state": AppState.SUBMITTED, + } + + logger.info(f"Submitted Kubeflow job: {job_id}") + return job_id + + def describe(self, app_id: str) -> Optional[DescribeAppResponse]: + """Get the status of a Kubeflow job.""" + if app_id not in self._apps: + return None + + job_info = self._apps[app_id] + executor = job_info["executor"] + + try: + status = executor.get_trainjob_status(app_id) + # Map Kubeflow status to TorchX AppState + app_state = self._map_kubeflow_status_to_torchx(status) + + return DescribeAppResponse( + app_id=app_id, + state=app_state, + num_restarts=0, # Kubeflow handles restarts internally + msg=f"Kubeflow job status: {status}", + structured_error_msg="", + roles_statuses=[], + ) + except Exception as e: + logger.error(f"Failed to describe job {app_id}: {e}") + return None + + def log_iter( + self, + app_id: str, + role_name: str, + k: int = 0, + regex: Optional[str] = None, + since: Optional[datetime] = None, + until: Optional[datetime] = None, + should_tail: bool = False, + streams: Optional[Stream] = None, + ) -> Iterable[str]: + """Get logs from the Kubeflow job.""" + if app_id not in self._apps: + return [] + + job_info = self._apps[app_id] + executor = job_info["executor"] + + try: + logs = executor.get_trainjob_logs(app_id, follow=should_tail) + # For now, return a simple log message + # In a real implementation, you'd parse the actual logs + log_lines = [f"Kubeflow job {app_id} logs:"] + if logs: + log_lines.extend(str(logs).split("\n")) + else: + log_lines.append("No logs available yet") + + return log_lines + except Exception as e: + logger.error(f"Failed to get logs for job {app_id}: {e}") + return [f"Error getting logs: {e}"] + + def cancel(self, app_id: str) -> None: + """Cancel a Kubeflow job.""" + if app_id not in self._apps: + return + + job_info = self._apps[app_id] + executor = job_info["executor"] + + try: + executor.delete_trainjob(app_id) + logger.info(f"Cancelled Kubeflow job: {app_id}") + except Exception as e: + logger.error(f"Failed to cancel job {app_id}: {e}") + + def _appdef_to_kubeflow_config(self, app: AppDef, cfg: KubeflowExecutor) -> dict[str, Any]: + """Convert AppDef to Kubeflow job configuration.""" + # Return the config for executor submission + return { + "app": app, + "executor": cfg, + } + + def _map_kubeflow_status_to_torchx(self, kubeflow_status: str) -> AppState: + """Map 
Kubeflow job status to TorchX AppState.""" + status_lower = kubeflow_status.lower() + + if "running" in status_lower or "pending" in status_lower: + return AppState.RUNNING + elif "succeeded" in status_lower or "completed" in status_lower: + return AppState.SUCCEEDED + elif "failed" in status_lower or "error" in status_lower: + return AppState.FAILED + elif "cancelled" in status_lower or "terminated" in status_lower: + return AppState.CANCELLED + else: + return AppState.UNKNOWN + + def _validate(self, app: AppDef, scheduler: str) -> None: + """Validate the app definition for Kubeflow.""" + # For now, skip validation as Kubeflow handles this internally + pass + + def close(self) -> None: + """Clean up resources when the scheduler is closed.""" + # Cancel all running jobs unless in detach mode + for app_id in list(self._apps.keys()): + try: + # Check if scheduler is in detach mode + if self.detach_mode: + logger.info(f"Skipping cleanup for job {app_id} in detach mode") + continue + + self.cancel(app_id) + except Exception as e: + logger.error(f"Failed to cancel job {app_id} during close: {e}") + + # Clear the apps dictionary + self._apps.clear() + + +def create_scheduler( + session_name: str, + namespace: str = "default", + detach_mode: bool = False, + **kwargs: Any, +) -> KubeflowScheduler: + """Create a Kubeflow scheduler instance.""" + return KubeflowScheduler( + session_name=session_name, + namespace=namespace, + detach_mode=detach_mode, + **kwargs, + ) diff --git a/pyproject.toml b/pyproject.toml index 56bfbdc5..3eaf0c68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,8 @@ dependencies = [ "leptonai>=0.25.0", "packaging", "toml", + "kubernetes>=28.0.0", + "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git@main", ] readme = "README.md" requires-python = ">= 3.10" @@ -47,17 +49,16 @@ skypilot = "nemo_run.run.torchx_backend.schedulers.skypilot:create_scheduler" local_persistent = "nemo_run.run.torchx_backend.schedulers.local:create_scheduler" docker_persistent = "nemo_run.run.torchx_backend.schedulers.docker:create_scheduler" dgx_cloud = "nemo_run.run.torchx_backend.schedulers.dgxcloud:create_scheduler" +kubeflow = "nemo_run.run.torchx_backend.schedulers.kubeflow:create_scheduler" lepton = "nemo_run.run.torchx_backend.schedulers.lepton:create_scheduler" [project.optional-dependencies] -skypilot = [ - "skypilot[kubernetes]>=0.10.0", -] -skypilot-all = [ - "skypilot[all]>=0.10.0", -] -ray = [ - "kubernetes" +skypilot = ["skypilot[kubernetes]>=0.10.0"] +skypilot-all = ["skypilot[all]>=0.10.0"] +ray = ["kubernetes"] +kubernetes = [ + "kubernetes>=28.0.0", + "kubeflow @ git+https://github.com/jskswamy/kubeflow-sdk.git@main", ] [dependency-groups] @@ -71,12 +72,10 @@ dev = [ "ipykernel>=6.29.4", "ipywidgets>=8.1.2", "jupyter>=1.1.1", - "pytest-cov" + "pytest-cov", ] -lint = [ - "ruff>=0.4.4", -] +lint = ["ruff>=0.4.4"] docs = [ "astroid==3.3.8", @@ -99,20 +98,23 @@ conflicts = [ [ { group = "docs", name = "colorama" }, { extra = "skypilot", name = "colorama" }, - { extra = "skypilot-all", name = "colorama" } - ] + { extra = "skypilot-all", name = "colorama" }, + ], ] [tool.pytest.ini_options] -markers = [ - "slow: marks tests as slow (deselect with '-m \"not slow\"')", -] +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] addopts = '-m "not slow"' [tool.coverage.run] branch = true include = ["nemo_run/**/*.py"] -omit = ["nemo_run/core/tunnel/callback.py", "nemo_run/help.py", "nemo_run/**/__init__.py", "nemo_run/**/_version.py"] +omit = [ + 
"nemo_run/core/tunnel/callback.py", + "nemo_run/help.py", + "nemo_run/**/__init__.py", + "nemo_run/**/_version.py", +] [tool.coverage.report] # Regexes for lines to exclude from consideration @@ -132,7 +134,7 @@ exclude_also = [ # Don't complain about abstract methods, they aren't run: "@(abc\\.)?abstractmethod", - ] +] ignore_errors = true @@ -146,7 +148,7 @@ allow-direct-references = true packages = ["nemo_run"] [tool.hatch.version] -path = "nemo_run/package_info.py" +path = "nemo_run/package_info.py" [tool.ruff] line-length = 100 diff --git a/test/core/execution/Run.code-workspace b/test/core/execution/Run.code-workspace new file mode 100644 index 00000000..cbdcb7b3 --- /dev/null +++ b/test/core/execution/Run.code-workspace @@ -0,0 +1,23 @@ +{ + "folders": [ + { + "path": "../../.." + }, + { + "path": "../../../../../twlabs/mpt-platform-workbench" + }, + { + "path": "../../../../../twlabs/mpt-platform-mle-experiments/kubernetes/NeMo" + }, + { + "path": "../../../../../kubeflow/sdk" + }, + { + "path": "../../../../../twlabs/mpt-platform-mle-experiments/gpt-pretrain-kubeflow" + }, + { + "path": "../../../../../kubeflow/trainer" + } + ], + "settings": {} +} diff --git a/test/core/execution/test_kubeflow.py b/test/core/execution/test_kubeflow.py new file mode 100644 index 00000000..b47e3b69 --- /dev/null +++ b/test/core/execution/test_kubeflow.py @@ -0,0 +1,946 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock, patch + +import pytest +from kubernetes import config +from kubernetes.client.exceptions import ApiException + +from nemo_run.config import Partial, Script +from nemo_run.core.execution.kubeflow import ( + AdditionalPackages, + KubeflowExecutor, + StorageMount, +) +from nemo_run.core.execution.utils import fill_template +from nemo_run.core.packaging.base import Packager +from nemo_run.core.packaging.configmap import ConfigMapPackager + + +class TestStorageMounts: + def test_get_volume_name_defaults_and_sanitizes(self): + # Explicit name is sanitized + sm_named = StorageMount( + mount_path="/mnt/a", + name="bad_Name", + pvc_claim_name="claim-a", + ) + assert sm_named.get_volume_name(5) == "bad-name" + + # No name -> defaults to pvc-{index} + sm_default = StorageMount( + mount_path="/mnt/a", + ) + assert sm_default.get_volume_name(3) == "pvc-3" + + def test_get_pvc_claim_name_sanitizes_and_none(self): + # Sanitizes underscores to hyphens + sm_claim = StorageMount( + mount_path="/mnt/a", + pvc_claim_name="my_claim", + ) + assert sm_claim.get_pvc_claim_name() == "my-claim" + + # None stays None + sm_none = StorageMount( + mount_path="/mnt/a", + ) + assert sm_none.get_pvc_claim_name() is None + + def test_storage_mount_name_sanitization(self): + executor = KubeflowExecutor() + executor.storage_mounts = [ + StorageMount( + mount_path="/mnt/a", + read_only=False, + name="mistral_checkpoint", + pvc_claim_name="claim-a", + kind="pvc", + ) + ] + + frags = executor._get_normalized_storage_mounts() + assert frags[0]["name"] == "mistral-checkpoint" + + def test_storage_mounts_normalization_to_template(self): + executor = KubeflowExecutor() + # Create storage mounts + + executor.storage_mounts = [ + StorageMount( + mount_path="/mnt/a", + read_only=True, + name="data-a", + pvc_claim_name="claim-a", + kind="pvc", + ), + StorageMount( + mount_path="/mnt/b", + read_only=False, + pvc_claim_name="claim-b", + kind="pvc", + ), + ] + + frags = executor._get_normalized_storage_mounts() + assert len(frags) == 2 + assert frags[0]["name"] == "data-a" + assert frags[0]["claim_name"] == "claim-a" + assert frags[0]["mount_path"] == "/mnt/a" + assert frags[0]["read_only"] is True + assert frags[1]["name"].startswith("pvc-") + assert frags[1]["claim_name"] == "claim-b" + assert frags[1]["mount_path"] == "/mnt/b" + assert frags[1]["read_only"] is False + + def test_crt_template_renders_storage_pvc(self): + # Render CRT template directly with storage_pvc_mounts + + rendered = fill_template( + template_name="kubeflow_clustertrainingruntime.yaml.j2", + variables={ + "runtime_name": "rt", + "namespace": "ns", + "nodes": 1, + "image": "img", + "workspace_mount_path": "/src", + "configmap_name": "cfg", + "cpu_limit": None, + "memory_limit": None, + "gpus": None, + "enable_tcpxo": False, + "storage_pvc_mounts": [ + { + "name": "data-a", + "claim_name": "claim-a", + "mount_path": "/mnt/a", + "read_only": True, + } + ], + }, + ) + + assert "persistentVolumeClaim" in rendered + assert "claim-a" in rendered + assert "mountPath: /mnt/a" in rendered + assert "readOnly: true" in rendered + + def test_crt_template_renders_envfrom_secret(self): + rendered = fill_template( + template_name="kubeflow_clustertrainingruntime.yaml.j2", + variables={ + "runtime_name": "rt", + "namespace": "ns", + "nodes": 1, + "image": "img", + "workspace_mount_path": "/src", + "configmap_name": "cfg", + "cpu_limit": None, + "memory_limit": None, + "gpus": None, + "enable_tcpxo": False, + "storage_pvc_mounts": [], + 
"env_from_secrets": ["my-secret"], + }, + ) + + assert "envFrom:" in rendered + assert "name: my-secret" in rendered + + +def test_crt_template_renders_nodes_and_numproc(): + rendered = fill_template( + template_name="kubeflow_clustertrainingruntime.yaml.j2", + variables={ + "runtime_name": "rt", + "namespace": "ns", + "nodes": 2, + "num_proc_per_node": 8, + "image": "img", + "workspace_mount_path": "/src", + "configmap_name": "cfg", + "cpu_limit": None, + "memory_limit": None, + "gpus": None, + "enable_tcpxo": False, + "storage_pvc_mounts": [], + }, + ) + + assert "numNodes: 2" in rendered + assert "numProcPerNode: 8" in rendered + + +def test_crt_template_renders_gpu_resources_in_requests_and_limits(): + rendered = fill_template( + template_name="kubeflow_clustertrainingruntime.yaml.j2", + variables={ + "runtime_name": "rt", + "namespace": "ns", + "nodes": 1, + "num_proc_per_node": 8, + "image": "img", + "workspace_mount_path": "/src", + "configmap_name": "cfg", + "cpu_limit": None, + "memory_limit": None, + "gpus": 8, + "enable_tcpxo": False, + "storage_pvc_mounts": [], + }, + ) + + # GPU count should be present under both requests and limits + assert '"nvidia.com/gpu": 8' in rendered + + def test_pvc_creation_when_missing(self, mocker): + # Configure an executor with a PVC that should be created + + from nemo_run.core.execution.kubeflow import StorageMount + + executor = KubeflowExecutor(namespace="default") + executor.storage_mounts = [ + StorageMount( + mount_path="/mnt/a", + pvc_claim_name="claim_a", + create_if_missing=True, + size="200Gi", + storage_class="standard", + access_modes=["ReadWriteOnce"], + ) + ] + + mock_core = mocker.patch("kubernetes.client.CoreV1Api") + api = mock_core.return_value + api.read_namespaced_persistent_volume_claim.side_effect = ApiException(status=404) + + executor._ensure_storage() + + assert api.create_namespaced_persistent_volume_claim.called + args, kwargs = api.create_namespaced_persistent_volume_claim.call_args + body = kwargs["body"] + assert body["metadata"]["name"] == "claim-a" + assert body["spec"]["resources"]["requests"]["storage"] == "200Gi" + assert body.get("spec", {}).get("storageClassName") == "standard" + + def test_pvc_creation_skipped_when_exists(self, mocker): + # Should not call create when PVC exists + + executor = KubeflowExecutor(namespace="default") + executor.storage_mounts = [ + StorageMount( + mount_path="/mnt/a", + pvc_claim_name="claim_a", + create_if_missing=True, + ) + ] + + mock_core = mocker.patch("kubernetes.client.CoreV1Api") + api = mock_core.return_value + # read succeeds (no exception) + executor._ensure_storage() + + assert not api.create_namespaced_persistent_volume_claim.called + + +def test_kubeflow_executor_default_init(): + """Test that KubeflowExecutor initializes with defaults.""" + executor = KubeflowExecutor() + + assert executor.nodes == 1 + assert executor.ntasks_per_node == 1 + assert executor.namespace == "default" + assert executor.gpus_per_node is None + assert executor.job_name == "" + assert executor.workspace_mount_path == "/src" + assert isinstance(executor.packager, Packager) + + +def test_kubeflow_executor_custom_init(): + """Test that KubeflowExecutor initializes with custom values.""" + custom_config = { + "nodes": 2, + "ntasks_per_node": 4, + "namespace": "training", + "gpus_per_node": 8, + "workspace_mount_path": "/custom/workspace", + } + + executor = KubeflowExecutor(**custom_config) + + assert executor.nodes == 2 + assert executor.ntasks_per_node == 4 + assert executor.namespace == 
"training" + assert executor.gpus_per_node == 8 + assert executor.workspace_mount_path == "/custom/workspace" + + +def test_kubeflow_executor_validation(): + """Test parameter validation.""" + with pytest.raises(ValueError, match="nodes must be >= 1"): + KubeflowExecutor(nodes=0) + + with pytest.raises(ValueError, match="ntasks_per_node must be >= 1"): + KubeflowExecutor(ntasks_per_node=0) + + +def test_kubeflow_executor_assign(): + """Test that assign method sets the correct directories.""" + executor = KubeflowExecutor() + exp_id = "exp-123" + exp_dir = "/tmp/exp" + task_id = "task-1" + task_dir = "task_dir" + + executor.assign(exp_id, exp_dir, task_id, task_dir) + + assert executor.experiment_id == exp_id + assert executor.experiment_dir == exp_dir + assert executor.job_dir == f"{exp_dir}/{task_dir}" + assert executor.job_name == task_id + + +def test_kubeflow_executor_nnodes(): + """Test that nnodes returns the correct number of nodes.""" + expected_nodes = 3 + executor = KubeflowExecutor(nodes=expected_nodes) + + result = executor.nnodes() + + assert result == expected_nodes + + +def test_kubeflow_executor_nproc_per_node(): + """Test that nproc_per_node returns the correct number of processes.""" + expected_procs = 4 + executor = KubeflowExecutor(ntasks_per_node=expected_procs) + + result = executor.nproc_per_node() + + assert result == expected_procs + + +# _get_runtime was removed; runtime_name is passed explicitly + + +@pytest.mark.parametrize( + "executor_kwargs,expected_nodes", + [ + ( + { + "nodes": 2, + "gpus_per_node": 8, + "cpu_limit": "16", + "memory_limit": "32Gi", + }, + 2, + ), + ( + { + "nodes": 1, + "gpus_per_node": 4, + "workspace_mount_path": "/custom/workspace", + }, + 1, + ), + ], +) +def test_kubeflow_executor_get_custom_trainer_inline(executor_kwargs, expected_nodes): + """Test _get_custom_trainer with inline Script using SDK func embedding.""" + script_task = Script(inline="python train.py") + executor = KubeflowExecutor(**executor_kwargs) + executor.packager = ConfigMapPackager() + # Simulate the assignment process to set the experiment name + executor.assign("exp-123", "/tmp/exp", "task-1", "task_dir") + mock_trainer_instance = MagicMock() + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + mock_trainer.return_value = mock_trainer_instance + + result = executor._get_custom_trainer(script_task) + + assert result == mock_trainer_instance + mock_trainer.assert_called_once() + + call_args = mock_trainer.call_args[1] + assert call_args["num_nodes"] == expected_nodes + # CommandTrainer should be invoked with runtime-aware command/args + mounted_path = f"{executor.workspace_mount_path}/{executor.training_entry}" + assert call_args.get("command") in (["/bin/bash"], ["python"], ["bash"], ["torchrun"]) + assert mounted_path in " ".join(call_args.get("args", [])) + + resources = call_args["resources_per_node"] + if "cpu_limit" in executor_kwargs: + assert resources["cpu"] == executor_kwargs["cpu_limit"] + if "memory_limit" in executor_kwargs: + assert resources["memory"] == executor_kwargs["memory_limit"] + if "gpus_per_node" in executor_kwargs: + assert resources["nvidia.com/gpu"] == str(executor_kwargs["gpus_per_node"]) + + +def test_kubeflow_executor_get_custom_trainer_function_based(): + """Partial is supported: ensure launcher produces torchrun with PET flags.""" + + def dummy_function(): + return "function result" + + partial_task = Partial(dummy_function) + executor = KubeflowExecutor(nodes=1, gpus_per_node=4) + executor.packager = 
ConfigMapPackager() + executor.assign("exp-123", "/tmp/exp", "task-1", "task_dir") + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + instance = MagicMock() + mock_trainer.return_value = instance + + result = executor._get_custom_trainer(partial_task) + + assert result == instance + mock_trainer.assert_called_once() + + kwargs = mock_trainer.call_args[1] + assert kwargs["command"] in (["/bin/bash"], ["torchrun"]) + args_joined = " ".join(kwargs.get("args", [])) + assert "--nnodes ${PET_NNODES}" in args_joined + assert "--nproc_per_node ${PET_NPROC_PER_NODE}" in args_joined + assert "--rdzv_backend c10d" in args_joined + assert "--rdzv_endpoint ${PET_MASTER_ADDR}:${PET_MASTER_PORT}" in args_joined + + +def test_kubeflow_executor_get_custom_trainer_fallback(): + """Test _get_custom_trainer fallback behavior when using non-ConfigMapPackager.""" + script_task = Script(inline="python train.py") + executor = KubeflowExecutor() + # Use a different packager type to test fallback behavior + executor.packager = MagicMock() # Not a ConfigMapPackager + mock_trainer_instance = MagicMock() + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + mock_trainer.return_value = mock_trainer_instance + + result = executor._get_custom_trainer(script_task) + + assert result == mock_trainer_instance + mock_trainer.assert_called_once() + + call_args = mock_trainer.call_args[1] + assert call_args["num_nodes"] == 1 + mounted_path = f"{executor.workspace_mount_path}/{executor.training_entry}" + assert mounted_path in " ".join(call_args.get("args", [])) + + +class TestEnvSecretHandling: + def test_secret_creation_without_conflict(self, mocker): + executor = KubeflowExecutor(namespace="default") + executor.packager = ConfigMapPackager() + executor.assign("exp-abc", "/tmp/exp", "task-1", "task_dir") + + executor.env_vars = {"CONFIG_KEY1": "xyz", "FOO": "bar"} + + mock_core = mocker.patch("kubernetes.client.CoreV1Api") + api = mock_core.return_value + # No exception on create (no conflict) + api.create_namespaced_secret.return_value = None + + with patch("nemo_run.core.execution.kubeflow.fill_template") as ft: + ft.return_value = "apiVersion: v1\nkind: ClusterTrainingRuntime\nmetadata: {}" + with patch("kubernetes.client.CustomObjectsApi") as mock_coa: + coa = mock_coa.return_value + coa.create_cluster_custom_object.return_value = {} + # Ensure executor believes Kubernetes is available for this test + executor._kubernetes_available = True + executor._create_cluster_training_runtime(configmap_name="cfg", sha="beadfeed") + + # Ensure create was called, and patch was NOT called + assert api.create_namespaced_secret.called + assert not api.patch_namespaced_secret.called + + # Capture variables passed to template and assert env_from_secrets includes our secret + called_vars = ft.call_args[1]["variables"] + assert "env_from_secrets" in called_vars + assert isinstance(called_vars["env_from_secrets"], list) + assert len(called_vars["env_from_secrets"]) == 1 + + def test_secret_creation_and_patch_on_conflict(self, mocker): + executor = KubeflowExecutor(namespace="default") + executor.packager = ConfigMapPackager() + # Simulate assignment to set experiment identifier used in secret name + executor.assign("exp-xyz", "/tmp/exp", "task-1", "task_dir") + + # Set env vars that should be converted to a Secret + executor.env_vars = {"CONFIG_KEY1": "abc", "OTHER": "val"} + + # Mock k8s CoreV1Api to simulate create 409 then patch + mock_core = 
mocker.patch("kubernetes.client.CoreV1Api") + api = mock_core.return_value + from kubernetes.client.exceptions import ApiException + + # First call: create raises 409 (already exists) + api.create_namespaced_secret.side_effect = ApiException(status=409) + + # Run ensure function indirectly via _create_cluster_training_runtime + with patch("nemo_run.core.execution.kubeflow.fill_template") as ft: + ft.return_value = "apiVersion: v1\nkind: ClusterTrainingRuntime\nmetadata: {}" + with patch("kubernetes.client.CustomObjectsApi") as mock_coa: + coa = mock_coa.return_value + coa.create_cluster_custom_object.return_value = {} + # Should call patch on conflict + # Ensure executor believes Kubernetes is available for this test + executor._kubernetes_available = True + executor._create_cluster_training_runtime(configmap_name="cfg", sha="deadbeef") + + assert api.patch_namespaced_secret.called + + +def test_kubeflow_executor_create_trainjob(): + """Test create_trainjob method.""" + executor = KubeflowExecutor(nodes=1) + script_task = Script(inline="print('Training')") + expected_job_id = "job-123" + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.train.return_value = expected_job_id + + result = executor.create_trainjob("test-job", script_task, "nemo-runtime-exp-abc-12345678") + + assert result == expected_job_id + mock_client_instance.train.assert_called_once() + _, kwargs = mock_client_instance.train.call_args + assert "trainer" in kwargs and kwargs["trainer"] is not None + + +def test_kubeflow_executor_get_trainjob_status(): + """Test get_trainjob_status method.""" + executor = KubeflowExecutor() + executor.packager = ConfigMapPackager() + expected_status = "Running" + job_name = "job-123" + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_job = MagicMock() + mock_job.status = expected_status + mock_client_instance.get_job.return_value = mock_job + + status = executor.get_trainjob_status(job_name) + + assert status == expected_status + mock_client_instance.get_job.assert_called_once_with(job_name) + + +def test_kubeflow_executor_delete_trainjob(): + """Test delete_trainjob method.""" + executor = KubeflowExecutor() + job_name = "job-123" + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + + executor.delete_trainjob(job_name) + + mock_client_instance.delete_job.assert_called_once_with(job_name) + + +def test_kubeflow_executor_get_trainjob_logs(): + """Test get_trainjob_logs method.""" + executor = KubeflowExecutor() + job_name = "job-123" + expected_logs = {"logs": "test logs"} + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.get_job_logs.return_value = expected_logs + + logs = executor.get_trainjob_logs(job_name, follow=True) + + assert logs == expected_logs + mock_client_instance.get_job_logs.assert_called_once_with(job_name, follow=True) + + +def test_kubeflow_executor_get_trainer_client(): + """Test _get_trainer_client method.""" + executor = KubeflowExecutor() + mock_client_instance = MagicMock() + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + 
mock_client.return_value = mock_client_instance + + result = executor._get_trainer_client() + + assert result == mock_client_instance + mock_client.assert_called_once() + + result2 = executor._get_trainer_client() + + assert result2 == mock_client_instance + assert mock_client.call_count == 1 + + +def test_kubeflow_executor_post_init(): + """Test __post_init__ method with valid configuration.""" + expected_nodes = 1 + expected_ntasks = 1 + + executor = KubeflowExecutor(nodes=expected_nodes, ntasks_per_node=expected_ntasks) + + assert executor.nodes == expected_nodes + assert executor.ntasks_per_node == expected_ntasks + + +def test_kubeflow_executor_create_trainjob_with_error(): + """Test create_trainjob method with error handling.""" + executor = KubeflowExecutor() + script_task = Script(inline="print('Training')") + error_message = "TrainJob creation failed" + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.train.side_effect = Exception(error_message) + + with pytest.raises(Exception, match=error_message): + executor.create_trainjob("test-job", script_task, "nemo-runtime-exp-abc-12345678") + + +def test_kubeflow_executor_get_trainjob_status_with_error(): + """Test get_trainjob_status method with error handling.""" + executor = KubeflowExecutor() + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.get_job.side_effect = Exception("Status check failed") + + status = executor.get_trainjob_status("job-123") + + assert status == "Unknown" + + +def test_kubeflow_executor_delete_trainjob_with_error(): + """Test delete_trainjob method with error handling.""" + executor = KubeflowExecutor() + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.delete_job.side_effect = Exception("Delete failed") + + executor.delete_trainjob("job-123") + + +def test_kubeflow_executor_get_trainjob_logs_with_error(): + """Test get_trainjob_logs method with error handling.""" + executor = KubeflowExecutor() + + with patch("nemo_run.core.execution.kubeflow.TrainerClient") as mock_client: + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.get_job_logs.side_effect = Exception("Log retrieval failed") + + logs = executor.get_trainjob_logs("job-123") + + assert logs == {} + + +def test_kubeflow_executor_info(): + """Test info method.""" + expected_nodes = 2 + expected_gpus = 4 + executor = KubeflowExecutor(nodes=expected_nodes, gpus_per_node=expected_gpus) + + info = executor.info() + + expected_info = f"KubeflowExecutor (nodes={expected_nodes}, gpus={expected_gpus})" + assert expected_info in info + + +def test_kubeflow_executor_stage_files(): + """Test stage_files method.""" + executor = KubeflowExecutor() + executor.packager = ConfigMapPackager() + executor.experiment_id = "exp-123" + executor.experiment_name = "exp123" + executor.experiment_dir = "/tmp/exp" + expected_configmap_name = "nemo-workspace-exp-123-abcdef12" + expected_sha = "abcdef12" + + with patch.object(executor.packager, "package_with_hash") as mock_package: + mock_package.return_value = (expected_configmap_name, expected_sha) + + result_name, result_sha = executor.stage_files("task-dir", 
task=Script(inline="print('x')")) + + assert result_name == expected_configmap_name + assert result_sha == expected_sha + mock_package.assert_called_once() + + +def test_kubeflow_executor_cleanup_files(): + """Test cleanup_files method.""" + executor = KubeflowExecutor() + executor.packager = ConfigMapPackager() + executor.experiment_id = "exp-123" + executor.experiment_name = "exp123" + + with patch.object(executor.packager, "cleanup") as mock_cleanup: + executor.cleanup_files("task-dir") + + mock_cleanup.assert_called_once() + + +def test_kubeflow_executor_get_staged_file_path(): + """Test _get_staged_file_path method.""" + executor = KubeflowExecutor() + executor.packager = ConfigMapPackager() + filename = "test.py" + # Set experiment_name since we didn't call assign + executor.experiment_name = "expname" + expected_path = "/src/expname-test.py" + + result = executor._get_staged_file_path(filename) + + assert result == expected_path + + +def test_kubeflow_executor_get_staged_file_path_non_configmap(): + """Test _get_staged_file_path with non-ConfigMap packager.""" + executor = KubeflowExecutor() + from nemo_run.core.packaging import PatternPackager + + executor.packager = PatternPackager(include_pattern="*.py", relative_path=".") + filename = "test.py" + + result = executor._get_staged_file_path(filename) + + assert result == filename + + +def test_kubeflow_executor_invalid_task(): + """Test that KubeflowExecutor handles invalid task types by defaulting to python_file.""" + executor = KubeflowExecutor(nodes=1) + invalid_task = "invalid_task" + + mock_trainer_instance = MagicMock() + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + mock_trainer.return_value = mock_trainer_instance + + result = executor._get_custom_trainer(invalid_task) + + assert result == mock_trainer_instance + mock_trainer.assert_called_once() + + call_args = mock_trainer.call_args[1] + # Invalid tasks are treated like script and use staged entry path + mounted_path = f"{executor.workspace_mount_path}/{executor.training_entry}" + assert mounted_path in " ".join(call_args.get("args", [])) + + +def test_kubeflow_executor_kubernetes_setup(): + """Test Kubernetes configuration setup.""" + with patch("kubernetes.config.load_incluster_config") as mock_incluster: + with patch("kubernetes.config.load_kube_config") as mock_kubeconfig: + with patch("kubernetes.client.CoreV1Api") as mock_core: + mock_core.return_value.list_namespace.return_value = None + + executor = KubeflowExecutor() + + assert executor._kubernetes_available is True + + +def test_kubeflow_executor_kubernetes_setup_failure(): + """Test Kubernetes configuration setup failure.""" + + with patch( + "kubernetes.config.load_incluster_config", + side_effect=config.ConfigException("Config error"), + ): + with patch( + "kubernetes.config.load_kube_config", side_effect=config.ConfigException("Config error") + ): + with patch("kubernetes.client.CoreV1Api") as mock_core: + mock_core.return_value.list_namespace.side_effect = Exception("API error") + + executor = KubeflowExecutor() + + assert executor._kubernetes_available is False + + +def test_kubeflow_executor_detach_mode(): + """Test detach mode setting.""" + executor = KubeflowExecutor() + + executor.set_detach_mode(True) + + assert executor._detach_mode is True + + executor.set_detach_mode(False) + + assert executor._detach_mode is False + + +def test_kubeflow_executor_macro_values(): + """Test macro_values method.""" + executor = KubeflowExecutor() + + result = executor.macro_values() 
+ + assert result is None + + +def test_kubeflow_executor_injects_torchrun_for_script(): + """Script tasks should run under torchrun with PET-derived rendezvous flags.""" + executor = KubeflowExecutor(nodes=2, ntasks_per_node=8) + executor.packager = ConfigMapPackager() + # Simulate assignment to set experiment fields + executor.assign("exp-abc123", "/tmp/exp", "task-1", "task_dir") + + script_task = Script(inline="python mistral.py") + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + instance = MagicMock() + mock_trainer.return_value = instance + + result = executor._get_custom_trainer(script_task) + + assert result == instance + mock_trainer.assert_called_once() + + kwargs = mock_trainer.call_args[1] + # Use direct torchrun invocation with PET-derived flags + assert kwargs["command"] == ["torchrun"] + args_list = kwargs.get("args") + assert isinstance(args_list, list) and len(args_list) >= 2 + args_joined = " ".join(args_list) + assert "--nnodes ${PET_NNODES}" in args_joined + assert "--nproc_per_node ${PET_NPROC_PER_NODE}" in args_joined + assert "--rdzv_backend c10d" in args_joined + assert "--rdzv_endpoint ${PET_MASTER_ADDR}:${PET_MASTER_PORT}" in args_joined + # Mounted script path + mounted_path = f"{executor.workspace_mount_path}/{executor.training_entry}" + assert mounted_path in args_joined + + +def test_kubeflow_executor_wraps_bash_script_without_torchrun(): + executor = KubeflowExecutor(nodes=2, ntasks_per_node=8) + executor.packager = ConfigMapPackager() + executor.assign("exp-abc123", "/tmp/exp", "task-1", "task_dir") + + script_task = Script(entrypoint="bash", inline="#!/bin/bash\necho hello") + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + instance = MagicMock() + mock_trainer.return_value = instance + + result = executor._get_custom_trainer(script_task) + + assert result == instance + mock_trainer.assert_called_once() + + kwargs = mock_trainer.call_args[1] + assert kwargs["command"] == ["torchrun"] + args_list = kwargs.get("args") + assert isinstance(args_list, list) and len(args_list) >= 2 + args_joined = " ".join(args_list) + assert "--nnodes ${PET_NNODES}" in args_joined + assert "--nproc_per_node ${PET_NPROC_PER_NODE}" in args_joined + assert "--rdzv_backend c10d" in args_joined + assert "--rdzv_endpoint ${PET_MASTER_ADDR}:${PET_MASTER_PORT}" in args_joined + assert "--no-python" in args_joined + + +def test_kubeflow_executor_pass_through_bash_with_torchrun(): + executor = KubeflowExecutor(nodes=2, ntasks_per_node=8) + executor.packager = ConfigMapPackager() + executor.assign("exp-def456", "/tmp/exp", "task-2", "task_dir") + + script_task = Script(entrypoint="bash", inline="#!/bin/bash\n torchrun train.py") + + with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer: + instance = MagicMock() + mock_trainer.return_value = instance + + result = executor._get_custom_trainer(script_task) + + assert result == instance + mock_trainer.assert_called_once() + + kwargs = mock_trainer.call_args[1] + mounted_path = f"{executor.workspace_mount_path}/{executor.training_entry}" + # Pass-through: command should be the staged script path, no PET flags injection + assert kwargs["command"] == [mounted_path] + args_list = kwargs.get("args") + assert args_list == [] + + +def test_kubeflow_executor_injects_torchrun_for_partial(): + """Partial should also run under torchrun using the launcher transform.""" + executor = KubeflowExecutor(nodes=2, ntasks_per_node=8) + executor.packager = 
+    executor.assign("exp-partial", "/tmp/exp", "task-3", "task_dir")
+
+    def _dummy(x, y=2):
+        return x + y
+
+    task = Partial(_dummy, 1, y=3)
+
+    with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer:
+        instance = MagicMock()
+        mock_trainer.return_value = instance
+
+        result = executor._get_custom_trainer(task)
+
+        assert result == instance
+        mock_trainer.assert_called_once()
+
+        kwargs = mock_trainer.call_args[1]
+        assert kwargs["command"] in (["/bin/bash"], ["torchrun"])
+        args_list = kwargs.get("args")
+        assert isinstance(args_list, list) and len(args_list) >= 2
+        args_joined = " ".join(args_list)
+        assert (kwargs["command"][0] == "torchrun") or ("torchrun" in args_joined)
+        assert "--nnodes ${PET_NNODES}" in args_joined
+        assert "--nproc_per_node ${PET_NPROC_PER_NODE}" in args_joined
+        assert "--rdzv_backend c10d" in args_joined
+        assert "--rdzv_endpoint ${PET_MASTER_ADDR}:${PET_MASTER_PORT}" in args_joined
+
+
+def test_executor_additional_packages_forwarding():
+    """AdditionalPackages settings should be forwarded verbatim to CommandTrainer."""
+    script_task = Script(inline="python train.py")
+    executor = KubeflowExecutor(nodes=1, ntasks_per_node=4)
+    executor.packager = ConfigMapPackager()
+    executor.assign("exp-abc123", "/tmp/exp", "task-1", "task_dir")
+
+    executor.additional_packages = AdditionalPackages(
+        packages_to_install=["nemo==2.0.0", "deepspeed>=0.14.0"],
+        pip_index_urls=["https://pypi.org/simple", "https://extra/simple"],
+        pip_extra_args=["--no-cache-dir", "--find-links", "/wheels"],
+    )
+
+    with patch("nemo_run.core.execution.kubeflow.CommandTrainer") as mock_trainer:
+        instance = MagicMock()
+        mock_trainer.return_value = instance
+
+        res = executor._get_custom_trainer(script_task)
+
+        assert res == instance
+        kwargs = mock_trainer.call_args[1]
+        assert kwargs["packages_to_install"] == ["nemo==2.0.0", "deepspeed>=0.14.0"]
+        assert kwargs["pip_index_urls"] == ["https://pypi.org/simple", "https://extra/simple"]
+        assert kwargs["pip_extra_args"] == ["--no-cache-dir", "--find-links", "/wheels"]
diff --git a/test/core/packaging/test_configmap.py b/test/core/packaging/test_configmap.py
new file mode 100644
index 00000000..3c7ad8e4
--- /dev/null
+++ b/test/core/packaging/test_configmap.py
@@ -0,0 +1,461 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
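+
+"""Unit tests for ConfigMapPackager and the sanitize_kubernetes_name helper."""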
+
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from nemo_run.core.packaging.base import sanitize_kubernetes_name
+from nemo_run.core.packaging.configmap import ConfigMapPackager
+
+
+class TestSanitizeKubernetesName:
+    """Test cases for the sanitize_kubernetes_name function."""
+
+    @pytest.mark.parametrize(
+        "input_name,expected_output",
+        [
+            # Basic sanitization
+            ("test_name", "test-name"),
+            ("my_experiment_id", "my-experiment-id"),
+            ("task_dir", "task-dir"),
+            # No underscores - should remain unchanged
+            ("test-name", "test-name"),
+            ("experiment", "experiment"),
+            ("taskdir", "taskdir"),
+            # Multiple consecutive underscores
+            ("test__name", "test--name"),
+            ("my___experiment", "my---experiment"),
+            # Underscores at the beginning and end
+            ("_test_name_", "-test-name-"),
+            ("_experiment", "-experiment"),
+            ("experiment_", "experiment-"),
+            # Edge cases
+            ("", ""),
+            ("_", "-"),
+            # Mixed characters including underscores
+            ("test_123_name", "test-123-name"),
+            ("my-experiment_123", "my-experiment-123"),
+            ("mistral_training_task_dir", "mistral-training-task-dir"),
+            # Real-world examples
+            ("mistral_training", "mistral-training"),
+            ("nemo_mistral_workspace", "nemo-mistral-workspace"),
+        ],
+    )
+    def test_sanitize_kubernetes_name(self, input_name, expected_output):
+        """Test the sanitize_kubernetes_name function with various inputs."""
+        assert sanitize_kubernetes_name(input_name) == expected_output
+
+
+class TestConfigMapPackager:
+    """Test cases for the ConfigMapPackager class."""
+
+    def test_configmap_packager_default_init(self):
+        """Test that ConfigMapPackager initializes with default values."""
+        packager = ConfigMapPackager()
+
+        assert packager.include_pattern == "*.py"
+        assert packager.relative_path == "."
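+        # These defaults yield ConfigMap names of the form "nemo-workspace-<job>"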
+        assert packager.namespace == "default"
+        assert packager.configmap_prefix == "nemo-workspace"
+
+    def test_configmap_packager_custom_init(self):
+        """Test that ConfigMapPackager initializes with custom values."""
+        packager = ConfigMapPackager(
+            include_pattern=["*.py", "*.yaml"],
+            relative_path=["src", "config"],
+            namespace="training",
+            configmap_prefix="custom-prefix",
+        )
+
+        assert packager.include_pattern == ["*.py", "*.yaml"]
+        assert packager.relative_path == ["src", "config"]
+        assert packager.namespace == "training"
+        assert packager.configmap_prefix == "custom-prefix"
+
+    @pytest.mark.parametrize(
+        "rel_path,expected_key",
+        [
+            # Basic file names
+            (Path("mistral.py"), "mistral.py"),
+            (Path("train.py"), "train.py"),
+            # Files with nested paths (forward slashes become hyphens)
+            (Path("src/train.py"), "src-train.py"),
+            (Path("config/model.yaml"), "config-model.yaml"),
+            (Path("src/models/mistral.py"), "src-models-mistral.py"),
+            (Path("configs/training/hyperparams.yaml"), "configs-training-hyperparams.yaml"),
+            # Edge cases
+            (Path("file.with.dots.py"), "file.with.dots.py"),
+            # Real-world examples
+            (Path("src/training/script.py"), "src-training-script.py"),
+        ],
+    )
+    def test_sanitize_configmap_key(self, rel_path, expected_key):
+        """Test the _sanitize_configmap_key method with various inputs."""
+        packager = ConfigMapPackager()
+        result = packager._sanitize_configmap_key(rel_path)
+        assert result == expected_key
+
+    @pytest.mark.parametrize(
+        "rel_path,expected_key",
+        [
+            # Test that forward slashes are properly replaced with hyphens
+            (Path("some/dir/mistral.py"), "some-dir-mistral.py"),
+            (Path("workspace/subdir/src/train.py"), "workspace-subdir-src-train.py"),
+            (
+                Path("nemo/mistral/workspace/config/model.yaml"),
+                "nemo-mistral-workspace-config-model.yaml",
+            ),
+            # Test with multiple forward slashes
+            (Path("task/dir/subdir/file.py"), "task-dir-subdir-file.py"),
+            (Path("src/models/mistral.py"), "src-models-mistral.py"),
+            # Test with mixed forward slashes and existing hyphens
+            (Path("task-dir/subdir/file.py"), "task-dir-subdir-file.py"),
+            (Path("workspace/sub-dir/src/train.py"), "workspace-sub-dir-src-train.py"),
+        ],
+    )
+    def test_sanitize_configmap_key_forward_slash_replacement(self, rel_path, expected_key):
+        """Test that forward slashes are properly replaced with hyphens in ConfigMap keys."""
+        packager = ConfigMapPackager()
+        result = packager._sanitize_configmap_key(rel_path)
+        assert result == expected_key
+
+    def test_sanitize_configmap_key_with_simple_filename(self):
+        """Test _sanitize_configmap_key with a simple filename."""
+        packager = ConfigMapPackager()
+        result = packager._sanitize_configmap_key(Path("mistral.py"))
+        assert result == "mistral.py"
+
+    def test_sanitize_configmap_key_with_special_characters(self):
+        """Test _sanitize_configmap_key keeps underscores in keys (allowed by K8s)."""
+        packager = ConfigMapPackager()
+        result = packager._sanitize_configmap_key(Path("file_with_underscores.py"))
+        assert result == "file_with_underscores.py"
+
+    def test_sanitize_configmap_key_with_complex_paths(self):
+        """Test _sanitize_configmap_key with complex nested paths."""
+        packager = ConfigMapPackager()
+
+        # Test deeply nested paths
+        result = packager._sanitize_configmap_key(Path("src/models/transformers/mistral/config.py"))
+        expected = "src-models-transformers-mistral-config.py"
+        assert result == expected
+
+    def test_find_files_to_package_with_multiple_patterns(self):
+        """Test _find_files_to_package with multiple include patterns."""
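+        # Two patterns across two search roots; all four mocked files should be collected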
+        packager = ConfigMapPackager(
+            include_pattern=["*.py", "*.yaml"], relative_path=["src", "config"]
+        )
+
+        # Create test directory structure
+        with (
+            patch("pathlib.Path.exists", return_value=True),
+            patch("pathlib.Path.rglob") as mock_rglob,
+            patch("pathlib.Path.is_file", return_value=True),
+        ):
+            # Mock files found by rglob
+            mock_files = [
+                Path("/tmp/src/train.py"),
+                Path("/tmp/src/model.py"),
+                Path("/tmp/config/hyperparams.yaml"),
+                Path("/tmp/config/config.yaml"),
+            ]
+            mock_rglob.return_value = mock_files
+
+            result = packager._find_files_to_package(Path("/tmp"))
+
+            # Should find all files from both patterns
+            assert len(result) == 4
+            assert all(file in result for file in mock_files)
+
+    def test_find_files_to_package_with_nonexistent_paths(self):
+        """Test _find_files_to_package when search paths don't exist."""
+        packager = ConfigMapPackager(include_pattern=["*.py"], relative_path=["nonexistent"])
+
+        with patch("pathlib.Path.exists", return_value=False):
+            result = packager._find_files_to_package(Path("/tmp"))
+
+        # Should return an empty list when paths don't exist
+        assert result == []
+
+    def test_package_with_file_reading_exception(self):
+        """Test package method when file reading fails."""
+        tmp_path = Path("/tmp")
+        mock_v1 = MagicMock()
+
+        with (
+            patch(
+                "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+                lambda self: setattr(self, "v1", mock_v1),
+            ),
+            patch("pathlib.Path.exists", return_value=True),
+            patch("pathlib.Path.rglob", return_value=[Path("/tmp/test.py")]),
+            patch("pathlib.Path.is_file", return_value=True),
+            patch("pathlib.Path.stat") as mock_stat,
+            patch("builtins.open", side_effect=PermissionError("Permission denied")),
+        ):
+            mock_stat.return_value.st_size = 100
+            packager = ConfigMapPackager()
+            configmap_name = packager.package(tmp_path, "task-dir", "testjob")
+
+            # Should return the ConfigMap name but not create it due to the file reading error
+            assert configmap_name == "nemo-workspace-testjob"
+            assert not mock_v1.create_namespaced_config_map.called
+
+    def test_package_with_configmap_already_exists(self):
+        """Test package method when the ConfigMap already exists (409 conflict)."""
+        tmp_path = Path("/tmp")
+        mock_v1 = MagicMock()
+
+        # Mock ApiException for a 409 conflict
+        from kubernetes.client.exceptions import ApiException
+
+        mock_v1.create_namespaced_config_map.side_effect = ApiException(status=409)
+
+        with (
+            patch(
+                "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+                lambda self: setattr(self, "v1", mock_v1),
+            ),
+            patch("pathlib.Path.exists", return_value=True),
+            patch("pathlib.Path.rglob", return_value=[Path("/tmp/test.py")]),
+            patch("pathlib.Path.is_file", return_value=True),
+            patch("pathlib.Path.stat") as mock_stat,
+            patch("builtins.open", create=True) as mock_open,
+        ):
+            mock_stat.return_value.st_size = 100
+            mock_open.return_value.__enter__.return_value.read.return_value = "print('hello')"
+
+            packager = ConfigMapPackager()
+            configmap_name = packager.package(tmp_path, "task-dir", "testjob")
+
+            # Should return the ConfigMap name even when it already exists
+            assert configmap_name == "nemo-workspace-testjob"
+            mock_v1.create_namespaced_config_map.assert_called_once()
+
+    def test_package_with_other_api_exception(self):
+        """Test package method when ConfigMap creation fails with another error."""
+        tmp_path = Path("/tmp")
+        mock_v1 = MagicMock()
+
+        # Mock ApiException for another error
+        from kubernetes.client.exceptions import ApiException
+
+        mock_v1.create_namespaced_config_map.side_effect = ApiException(status=500)
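+        # A non-409 API error should not propagate; package still returns the name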
+
+        with (
+            patch(
+                "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+                lambda self: setattr(self, "v1", mock_v1),
+            ),
+            patch("pathlib.Path.exists", return_value=True),
+            patch("pathlib.Path.rglob", return_value=[Path("/tmp/test.py")]),
+            patch("pathlib.Path.is_file", return_value=True),
+            patch("pathlib.Path.stat") as mock_stat,
+            patch("builtins.open", create=True) as mock_open,
+        ):
+            mock_stat.return_value.st_size = 100
+            mock_open.return_value.__enter__.return_value.read.return_value = "print('hello')"
+
+            packager = ConfigMapPackager()
+            configmap_name = packager.package(tmp_path, "task-dir", "testjob")
+
+            # Should return the ConfigMap name even when creation fails
+            assert configmap_name == "nemo-workspace-testjob"
+            mock_v1.create_namespaced_config_map.assert_called_once()
+
+    def test_cleanup_with_configmap_not_found(self):
+        """Test cleanup when the ConfigMap doesn't exist (404 error)."""
+        mock_v1 = MagicMock()
+
+        # Mock ApiException for 404 not found
+        from kubernetes.client.exceptions import ApiException
+
+        mock_v1.delete_namespaced_config_map.side_effect = ApiException(status=404)
+
+        with patch(
+            "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+            lambda self: setattr(self, "v1", mock_v1),
+        ):
+            packager = ConfigMapPackager()
+            # Should not raise an exception when the ConfigMap doesn't exist
+            packager.cleanup("testjob")
+            mock_v1.delete_namespaced_config_map.assert_called_once()
+
+    def test_cleanup_with_other_api_exception(self):
+        """Test cleanup when ConfigMap deletion fails with another error."""
+        mock_v1 = MagicMock()
+
+        # Mock ApiException for another error
+        from kubernetes.client.exceptions import ApiException
+
+        mock_v1.delete_namespaced_config_map.side_effect = ApiException(status=500)
+
+        with patch(
+            "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+            lambda self: setattr(self, "v1", mock_v1),
+        ):
+            packager = ConfigMapPackager()
+            # Should not raise an exception when deletion fails
+            packager.cleanup("testjob")
+            mock_v1.delete_namespaced_config_map.assert_called_once()
+
+
+@pytest.fixture
+def temp_py_files(tmp_path):
+    """Create test files for packaging."""
+    # Create some test files
+    file1 = tmp_path / "a.py"
+    file2 = tmp_path / "b.py"
+    file3 = tmp_path / "subdir" / "c.py"
+    file3.parent.mkdir()
+
+    file1.write_text("print('A')\n")
+    file2.write_text("print('B')\n")
+    file3.write_text("print('C')\n")
+
+    return tmp_path, [file1, file2, file3]
+
+
+def test_package_creates_configmap_with_job_dir(temp_py_files):
+    """Test that package creates a ConfigMap with the correct data."""
+    tmp_path, files = temp_py_files
+    mock_v1 = MagicMock()
+
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", mock_v1),
+    ):
+        packager = ConfigMapPackager(include_pattern="*.py", relative_path=".", namespace="test-ns")
+        configmap_name = packager.package(tmp_path, "test-job", "testjob")
+
+        assert configmap_name == "nemo-workspace-testjob"
+        assert mock_v1.create_namespaced_config_map.called
+
+        _, kwargs = mock_v1.create_namespaced_config_map.call_args
+        assert kwargs["namespace"] == "test-ns"
+
+        data = kwargs["body"].data
+        for file_path in files:
+            rel_path = file_path.relative_to(tmp_path)
+            configmap_key = packager._sanitize_configmap_key(rel_path)
+            assert configmap_key in data
+            assert data[configmap_key] == file_path.read_text()
+
+
+def test_cleanup_deletes_configmap():
+    """Test that cleanup deletes the ConfigMap."""
+    mock_v1 = MagicMock()
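+    # The patched __post_init__ below injects this mock in place of a real Kubernetes client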
+
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", mock_v1),
+    ):
+        packager = ConfigMapPackager()
+        packager.cleanup("testjob")
+
+        assert mock_v1.delete_namespaced_config_map.called
+        _, kwargs = mock_v1.delete_namespaced_config_map.call_args
+        assert kwargs["name"] == "nemo-workspace-testjob"
+        assert kwargs["namespace"] == "default"
+
+
+def test_find_files_to_package(temp_py_files):
+    """Test file finding logic."""
+    tmp_path, files = temp_py_files
+
+    # Add a non-Python file to test filtering
+    txt_file = tmp_path / "b.txt"
+    txt_file.write_text("text file")
+
+    packager = ConfigMapPackager(include_pattern="*.py", relative_path=".")
+    found_files = packager._find_files_to_package(tmp_path)
+
+    # Use files from the fixture to keep the test maintainable
+    assert len(found_files) == len(files)  # Should find all Python files from the fixture
+
+    # Check that all fixture files are found
+    for file_path in files:
+        assert file_path in found_files
+
+    # Check that the non-Python file is NOT found
+    assert txt_file not in found_files
+
+
+def test_package_no_files_found(temp_py_files):
+    """Test behavior when no files match the pattern."""
+    tmp_path, _ = temp_py_files
+    mock_v1 = MagicMock()
+
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", mock_v1),
+    ):
+        packager = ConfigMapPackager(include_pattern="*.nonexistent", relative_path=".")
+        configmap_name = packager.package(tmp_path, "test-job", "testjob")
+
+        assert configmap_name == "nemo-workspace-testjob"
+        # Should not call create_namespaced_config_map
+        assert not mock_v1.create_namespaced_config_map.called
+
+
+def test_package_kubernetes_client_unavailable(temp_py_files):
+    """Test behavior when the Kubernetes client is not available."""
+    tmp_path, _ = temp_py_files
+
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", None),
+    ):
+        packager = ConfigMapPackager()
+        configmap_name = packager.package(tmp_path, "test-job", "testjob")
+
+        assert configmap_name == "nemo-workspace-testjob"
+
+
+def test_cleanup_kubernetes_client_unavailable():
+    """Test cleanup behavior when the Kubernetes client is not available."""
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", None),
+    ):
+        packager = ConfigMapPackager()
+        # Should not raise any exception
+        packager.cleanup("testjob")
+
+
+def test_package_with_large_files(temp_py_files):
+    """Test that package handles large files appropriately."""
+    tmp_path, _ = temp_py_files
+    mock_v1 = MagicMock()
+
+    # Create a large file that would exceed the 1MB limit
+    large_file = tmp_path / "large_file.py"
+    large_content = "print('x')\n" * 200000  # Create a large file (~2.2MB)
+    large_file.write_text(large_content)
+
+    with patch(
+        "nemo_run.core.packaging.configmap.ConfigMapPackager.__post_init__",
+        lambda self: setattr(self, "v1", mock_v1),
+    ):
+        packager = ConfigMapPackager(include_pattern="*.py", relative_path=".", debug=True)
+        configmap_name = packager.package(tmp_path, "test-job", "testjob")
+
+        # Should return the ConfigMap name but not create it due to the size limit
+        assert configmap_name == "nemo-workspace-testjob"
+        # Should not call create_namespaced_config_map due to the size limit
+        assert not mock_v1.create_namespaced_config_map.called
diff --git a/uv.lock b/uv.lock
index 363caffc..4b8ba285 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3775,6 +3775,24 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/53/6b/caf27d5a40618c7e945a1c68e1961c2d3637edfce9ebb0edc27c9ff53c1c/knack-0.11.0-py3-none-any.whl", hash = "sha256:6704c867840978a119a193914a90e2e98c7be7dff764c8fcd8a2286c5a978d00", size = 60848, upload-time = "2023-07-26T06:23:30.221Z" },
 ]
 
+[[package]]
+name = "kubeflow"
+version = "0.1.0"
+source = { git = "https://github.com/jskswamy/kubeflow-sdk.git?subdirectory=python#42739715e04aee91e7b4a13735ff96f603d035b4" }
+dependencies = [
+    { name = "kubeflow-trainer-api" },
+    { name = "kubernetes" },
+    { name = "pydantic" },
+]
+
+[[package]]
+name = "kubeflow-trainer-api"
+version = "2.0.0"
+source = { git = "https://github.com/kubeflow/trainer.git?subdirectory=api%2Fpython_api&rev=master#d997dd96f38feeda45af2a24179e515d388425e4" }
+dependencies = [
+    { name = "pydantic" },
+]
+
 [[package]]
 name = "kubernetes"
 version = "32.0.1"
@@ -4407,6 +4425,8 @@ dependencies = [
     { name = "fiddle" },
     { name = "inquirerpy" },
     { name = "jinja2" },
+    { name = "kubeflow" },
+    { name = "kubernetes" },
     { name = "leptonai" },
     { name = "networkx" },
     { name = "omegaconf" },
@@ -4418,6 +4438,10 @@ dependencies = [
 ]
 
 [package.optional-dependencies]
+kubernetes = [
+    { name = "kubeflow" },
+    { name = "kubernetes" },
+]
 ray = [
     { name = "kubernetes" },
 ]
@@ -4463,6 +4487,10 @@ requires-dist = [
     { name = "fiddle", specifier = ">=0.3.0" },
     { name = "inquirerpy", specifier = ">=0.3.4" },
     { name = "jinja2", specifier = ">=3.1.4" },
+    { name = "kubeflow", git = "https://github.com/jskswamy/kubeflow-sdk.git?subdirectory=python" },
+    { name = "kubeflow", marker = "extra == 'kubernetes'", git = "https://github.com/jskswamy/kubeflow-sdk.git?subdirectory=python" },
+    { name = "kubernetes", specifier = ">=28.0.0" },
+    { name = "kubernetes", marker = "extra == 'kubernetes'", specifier = ">=28.0.0" },
     { name = "kubernetes", marker = "extra == 'ray'" },
     { name = "leptonai", specifier = ">=0.25.0" },
     { name = "networkx", specifier = ">=3.3" },
@@ -4475,7 +4503,7 @@ requires-dist = [
     { name = "torchx", specifier = ">=0.7.0" },
     { name = "typer", specifier = ">=0.12.3" },
 ]
-provides-extras = ["ray", "skypilot", "skypilot-all"]
+provides-extras = ["kubernetes", "ray", "skypilot", "skypilot-all"]
 
 [package.metadata.requires-dev]
 dev = [