56 changes: 55 additions & 1 deletion nemo_run/core/execution/dgxcloud.py
@@ -21,7 +21,7 @@
import subprocess
import tempfile
import time
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Optional
@@ -31,6 +31,8 @@

from nemo_run.config import get_nemorun_home
from nemo_run.core.execution.base import Executor, ExecutorMacros
from nemo_run.core.execution.launcher import FaultTolerance, Launcher
from nemo_run.core.execution.utils import fill_template
from nemo_run.core.packaging.base import Packager
from nemo_run.core.packaging.git import GitArchivePackager

@@ -556,3 +558,55 @@
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers


@dataclass(kw_only=True)
class DGXCloudRequest:
    launch_cmd: list[str]
    jobs: list[str]
    executor: DGXCloudExecutor
    max_retries: int
    extra_env: dict[str, str]
    launcher: Optional[Launcher] = None

    def materialize(self) -> str:
        """Creates the content of a DGXC entrypoint script."""

        # 1. Environment variables: combine executor defaults with extra envs.
        env_vars = []
        full_env_vars = self.executor.env_vars | self.extra_env
        for key, value in full_env_vars.items():
            env_vars.append(f'export {key.upper()}="{value}"')

        # 2. Prepare template variables.
        vars_to_fill = {
            "max_retries": self.max_retries,
            "env_vars": env_vars,
            "training_command": " ".join(self.launch_cmd),
            "ft_enabled": isinstance(self.launcher, FaultTolerance),
        }

        # 3. Fault tolerance injection.
        if isinstance(self.launcher, FaultTolerance):
            assert (
                self.launcher.cfg_path
                and self.launcher.finished_flag_file
                and self.launcher.job_results_file
            ), "Fault Tolerance requires cfg_path, finished_flag_file, and job_results_file"

            vars_to_fill["fault_tol_cfg_path"] = self.launcher.cfg_path
            vars_to_fill["fault_tol_finished_flag_file"] = self.launcher.finished_flag_file
            vars_to_fill["fault_tol_job_results_file"] = self.launcher.job_results_file

        # Render the template
        entrypoint_script = fill_template("dgxc.sh.j2", vars_to_fill)
        return entrypoint_script

    def __repr__(self) -> str:
        return f"""# DGXC Entrypoint Script Request
# Executor: {self.executor.__class__.__name__}
# Jobs: {self.jobs}
# ---------------------------------------------------
{self.materialize()}
"""
47 changes: 47 additions & 0 deletions nemo_run/core/execution/templates/dgxc.sh.j2
@@ -0,0 +1,47 @@
{%- import "ft_launcher_k8s.j2" as fault_tolerance -%}
#!/bin/bash
#
# Generated by NeMo Run for Kubernetes (PyTorchJob)
#

# 1. Basic Shell Setup
set -evx # Echo script lines and print commands; -e exits on error (disabled around the training command below)
export PYTHONUNBUFFERED=1
export TORCHX_MAX_RETRIES={{max_retries}}

# 2. Environment Variables
# These are strictly user-defined vars (e.g. HYDRA_FULL_ERROR).
# Note: MASTER_ADDR, WORLD_SIZE, RANK are injected automatically by the PyTorchJob operator.
{%- for env_var in env_vars %}
{{env_var}}
{%- endfor %}

# 3. Fault Tolerance: SETUP (Check-in)
# Checks if we are resuming or if we are already finished.
{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) }}
{%- endif %}

# 4. Main Execution
# In PyTorchJob, we usually have exactly one main command (torchrun).
# We assume the variable 'training_command' contains the full torchrun string.

echo "Starting training command..."
set +e # Turn off auto-exit so we can capture the exit code
# ---------------------------------------------------------
{{ training_command }}
# ---------------------------------------------------------
exitcode=$?
set -e

echo "Main command exited with code $exitcode"

# 5. Fault Tolerance: TEARDOWN (Check-out)
# Decides if we should exit 0 (complete) or exit 1 (retry via K8s backoffLimit).
{%- if ft_enabled %}
{{ fault_tolerance.ft_launcher_teardown() }}
{%- else %}
# If FT is disabled, simply pass the exit code through.
# K8s will restart if exitcode != 0 and backoffLimit > 0.
exit $exitcode
{%- endif %}
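
The template can be exercised directly with `fill_template`, mirroring what `materialize()` does. A sketch with the non-FT variable set (variable names match `vars_to_fill` above; values are illustrative):

```python
from nemo_run.core.execution.utils import fill_template

script = fill_template(
    "dgxc.sh.j2",
    {
        "max_retries": 3,
        "env_vars": ['export HYDRA_FULL_ERROR="1"'],
        "training_command": "torchrun --nproc-per-node=8 train.py",
        "ft_enabled": False,  # both fault_tolerance macro calls are skipped
    },
)
print(script)
```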
69 changes: 69 additions & 0 deletions nemo_run/core/execution/templates/ft_launcher_dgxc.j2
@@ -0,0 +1,69 @@
{% macro ft_launcher_setup(fault_tol_cfg_path, fault_tol_finished_flag_file, fault_tol_job_results_file) -%}
# -------------------------------------------------------------------------
# K8s Fault Tolerance Setup (The "Check-In" Desk)
# -------------------------------------------------------------------------

# 1. Export Paths
# IMPORTANT: These paths must reside on a ReadWriteMany (RWX) Persistent Volume
# mounted to all Pods so state is preserved across pod restarts/rescheduling.
export FAULT_TOL_CFG_PATH="{{fault_tol_cfg_path}}"
export FAULT_TOL_FINISHED_FLAG_FILE="{{fault_tol_finished_flag_file}}"
export FAULT_TOL_JOB_RESULTS_FILE="{{fault_tol_job_results_file}}"

# 2. Define Helper Functions
is_training_finished() {
    test -f "$FAULT_TOL_FINISHED_FLAG_FILE"
}

# 3. Check for Previous Success
# In K8s, a Pod might be restarted due to node maintenance even if the job
# logic was done. If the flag file exists, we exit immediately with 0.
if is_training_finished ; then
    echo "[FT-Setup] Found finished flag at $FAULT_TOL_FINISHED_FLAG_FILE."
    echo "[FT-Setup] Training is already complete. Exiting successfully."
    exit 0
fi

# 4. Logging Start
# We use HOSTNAME (usually the pod name) as the identifier, since SLURM_JOB_ID
# does not exist on Kubernetes. We append 'X' (Running/Unknown) to the log.
echo "[FT-Setup] Starting training on $(hostname)..."
# Optional: log the attempt to the shared file (using X for Running).
# Note: in high-scale K8s, writing to a single file from 1000 pods can cause
# lock contention. If scale is small, this is fine.
if [ -n "$FAULT_TOL_JOB_RESULTS_FILE" ]; then
    echo "$(hostname) $(date +%s) X" >> "$FAULT_TOL_JOB_RESULTS_FILE"
fi

{%- endmacro %}

{% macro ft_launcher_teardown() -%}
# -------------------------------------------------------------------------
# K8s Fault Tolerance Teardown (The "Check-Out" Desk)
# -------------------------------------------------------------------------

# 1. Analyze Exit Code from the Main Command
# 'exitcode' is captured in the main script before calling this macro.
if [ "$exitcode" -eq "0" ]; then
RESULT_STATUS="S" # Success
else
RESULT_STATUS="F" # Failure
fi

# 2. Update Log (Optional but helpful for debugging)
if [ -n "$FAULT_TOL_JOB_RESULTS_FILE" ]; then
    # Record the final status for this host. We append a new S/F line rather
    # than rewriting the earlier 'X' entry in place: 'sed -i' on a shared PVC
    # is risky under concurrency, so appending is the safer pattern in K8s.
    echo "$(hostname) $(date +%s) $RESULT_STATUS" >> "$FAULT_TOL_JOB_RESULTS_FILE"
fi

# 3. The Requeue Decision Logic
if [ "$exitcode" -eq "0" ]; then
    # Case A: Script exited successfully.
    # Verification: did it actually finish (create the flag file)?
    if is_training_finished; then
        echo "[FT-Teardown] Job finished successfully and flag file exists."
        exit 0
    else
        # Edge Case: the python script exited 0, but didn't write the flag
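
The teardown's verification step assumes the training process itself creates the finished flag on a clean, final completion. A minimal training-side sketch (a hypothetical helper; the env var name comes from the setup macro above, and the path must live on the shared RWX volume):

```python
import os
from pathlib import Path

def mark_training_finished() -> None:
    """Hypothetical helper: call at the very end of the training script."""
    # FAULT_TOL_FINISHED_FLAG_FILE is exported by ft_launcher_setup; touching
    # it makes restarted pods (and the teardown above) exit 0 instead of retrying.
    flag = os.environ.get("FAULT_TOL_FINISHED_FLAG_FILE")
    if flag:
        Path(flag).touch()
```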
2 changes: 1 addition & 1 deletion nemo_run/core/execution/templates/slurm.sh.j2
@@ -1,4 +1,4 @@
{%- import "ft_launcher.j2" as fault_tolerance -%}
{%- import "ft_launcher_slurm.j2" as fault_tolerance -%}
#!/bin/bash
#
# Generated by NeMo Run
1 change: 1 addition & 0 deletions nemo_run/run/torchx_backend/packaging.py
@@ -203,6 +203,7 @@ def _get_details_from_script(fn_or_script: Script, serialize_configs: bool):
            log_level=launcher.log_level,
            max_retries=executor.retries,
            max_restarts=launcher.max_restarts,
            dgxc=isinstance(executor, DGXCloudExecutor),
            use_env=use_env,
        )
    else:
23 changes: 21 additions & 2 deletions nemo_run/run/torchx_backend/schedulers/dgxcloud.py
Expand Up @@ -37,7 +37,7 @@

from nemo_run.config import get_nemorun_home
from nemo_run.core.execution.base import Executor
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudRequest, DGXCloudState
from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer
from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin

@@ -109,6 +109,23 @@ def _submit_dryrun( # type: ignore
        role = values.apply(role)

        cmd = [role.entrypoint] + role.args

        req = DGXCloudRequest(
            launch_cmd=cmd,
            jobs=[role.name],
            executor=executor,
            max_retries=role.max_retries,
            extra_env=role.env,
            launcher=executor.get_launcher(),
        )

        # Write the rendered entrypoint script into the experiment dir
        path = os.path.join(executor.experiment_dir, f"{executor.job_name}_job.sh")
        script = req.materialize()

        with open(path, "w") as f:
            f.write(script)

        return AppDryRunInfo(
            DGXRequest(app=app, executor=executor, cmd=cmd, name=role.name),
            # Minimal function to show the config, if any
@@ -128,7 +145,9 @@ def schedule(self, dryrun_info: AppDryRunInfo[DGXRequest]) -> str:

        # The DGXCloudExecutor's launch call returns (job_id, status). We point
        # it at the entrypoint script that _submit_dryrun wrote to disk.
        job_id, status = executor.launch(name=req.name, cmd=req.cmd)
        cmd = os.path.join(executor.experiment_dir, f"{executor.job_name}_job.sh")
        req.launch_cmd = ["bash", cmd]
        job_id, status = executor.launch(name=req.name, cmd=req.launch_cmd)
        if not job_id:
            raise RuntimeError("Failed scheduling run on DGX: no job_id returned")
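
One coupling worth flagging for review: `_submit_dryrun` (which writes the script) and `schedule` (which launches it) each rebuild the path inline, so the two expressions must stay in sync. A hedged sketch of the shared derivation (an illustrative helper, not code in this PR):

```python
import os

def _entrypoint_script_path(executor) -> str:
    # Mirrors the two inline os.path.join calls above; shown only to make the
    # dryrun/schedule contract explicit.
    return os.path.join(executor.experiment_dir, f"{executor.job_name}_job.sh")
```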
