diff --git a/.github/dashboard/index.html b/.github/dashboard/index.html
index 094ee9697..07728d33a 100644
--- a/.github/dashboard/index.html
+++ b/.github/dashboard/index.html
@@ -342,7 +342,7 @@
Benchmark Dashboard
1. CONSTANTS & COLOR PALETTE
================================================================ */
// HSL-based palette: each model gets a distinct base hue; backend offsets keep
-// ATOM and ATOM-vLLM visually related but not identical.
+// ATOM, ATOM-vLLM, and ATOM-SGLang visually related but not identical.
const MODEL_HUES = {
'DeepSeek-R1-0528': 210, // blue
'DeepSeek-R1-0528-mtp3': 175, // cyan/teal — distinct from base DeepSeek
@@ -356,6 +356,7 @@ Benchmark Dashboard
const BACKEND_HUE_OFFSETS = {
'ATOM': 0,
'ATOM-vLLM': 28,
+ 'ATOM-SGLang': -30,
};
const FALLBACK_HUES = [45, 330, 190, 30]; // yellow, pink, teal, amber
let fallbackHueIdx = 0;
diff --git a/.github/scripts/oot_benchmark_regression.py b/.github/scripts/plugin_benchmark_regression.py
old mode 100644
new mode 100755
similarity index 93%
rename from .github/scripts/oot_benchmark_regression.py
rename to .github/scripts/plugin_benchmark_regression.py
index cf31c124f..9903ce98b
--- a/.github/scripts/oot_benchmark_regression.py
+++ b/.github/scripts/plugin_benchmark_regression.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""OOT-specific regression summary built on top of shared summarize helpers."""
+"""Regression summary built on top of shared summarize helpers."""
from __future__ import annotations
@@ -39,7 +39,7 @@ def build_report(
def main() -> int:
parser = argparse.ArgumentParser(
- description="Print only the OOT regression report without the full results table"
+ description="Print only the regression report without the full results table"
)
parser.add_argument(
"--result-dir",
@@ -88,4 +88,4 @@ def main() -> int:
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/.github/scripts/oot_benchmark_summary.py b/.github/scripts/plugin_benchmark_summary.py
old mode 100644
new mode 100755
similarity index 93%
rename from .github/scripts/oot_benchmark_summary.py
rename to .github/scripts/plugin_benchmark_summary.py
index ed6f21883..db753aea1
--- a/.github/scripts/oot_benchmark_summary.py
+++ b/.github/scripts/plugin_benchmark_summary.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""Generate a resilient OOT benchmark summary table.
+"""Generate a resilient benchmark summary table.
This script is intentionally tolerant of partial or total benchmark failure:
- missing result JSON => case is marked FAIL
@@ -143,12 +143,12 @@ def _build_rows(result_dir: Path, matrix_payload: dict) -> list[dict]:
return rows
-def _print_markdown_table(rows: list[dict], run_url: str | None) -> None:
+def _print_markdown_table(rows: list[dict], run_url: str | None, title: str) -> None:
total_cases = len(rows)
passed_cases = sum(1 for row in rows if row["status"] == "PASS")
failed_cases = total_cases - passed_cases
- print("## OOT Benchmark Summary\n")
+ print(f"## {title}\n")
if run_url:
print(f"Run: {run_url}\n")
print(
@@ -185,11 +185,11 @@ def _print_markdown_table(rows: list[dict], run_url: str | None) -> None:
def main() -> int:
- parser = argparse.ArgumentParser(description="Summarize OOT benchmark results")
+ parser = argparse.ArgumentParser(description="Summarize benchmark results")
parser.add_argument(
"--result-dir",
required=True,
- help="Directory containing downloaded OOT benchmark JSON files",
+ help="Directory containing downloaded benchmark JSON files",
)
parser.add_argument(
"--matrix-json",
@@ -206,12 +206,17 @@ def main() -> int:
default=None,
help="Optional path to write a structured summary report",
)
+ parser.add_argument(
+ "--title",
+ default="Benchmark Summary",
+ help="Title for the markdown report",
+ )
args = parser.parse_args()
matrix_payload = json.loads(args.matrix_json)
rows = _build_rows(Path(args.result_dir), matrix_payload)
- _print_markdown_table(rows, args.run_url)
+ _print_markdown_table(rows, args.run_url, args.title)
if args.output_json:
report = {
@@ -229,4 +234,4 @@ def main() -> int:
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/.github/scripts/oot_benchmark_to_dashboard.py b/.github/scripts/plugin_benchmark_to_dashboard.py
old mode 100644
new mode 100755
similarity index 80%
rename from .github/scripts/oot_benchmark_to_dashboard.py
rename to .github/scripts/plugin_benchmark_to_dashboard.py
index 3cf026013..c732390e0
--- a/.github/scripts/oot_benchmark_to_dashboard.py
+++ b/.github/scripts/plugin_benchmark_to_dashboard.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""Convert OOT benchmark JSON results to github-action-benchmark input."""
+"""Convert benchmark JSON results to github-action-benchmark input."""
from __future__ import annotations
@@ -9,7 +9,6 @@
from pathlib import Path
VARIANT_RE = re.compile(r"-(mtp\d*)-")
-DEFAULT_BACKEND = "ATOM-vLLM"
def derive_model_name(result_path: Path, payload: dict) -> str:
@@ -59,11 +58,11 @@ def is_dashboard_publish_allowed(payload: dict) -> bool:
return str(publish_flag).strip().lower() not in {"0", "false", "no"}
-def build_entries(result_dir: Path, run_url: str | None) -> list[dict]:
+def build_entries(result_dir: Path, run_url: str | None, default_backend: str) -> list[dict]:
entries: list[dict] = []
for result_path in sorted(result_dir.glob("*.json")):
- if result_path.name == "regression_report.json":
+ if result_path.name == "regression_report.json" or result_path.name.endswith("_benchmark_summary.json"):
continue
try:
@@ -81,20 +80,23 @@ def build_entries(result_dir: Path, run_url: str | None) -> list[dict]:
isl = int(payload.get("random_input_len", 0))
osl = int(payload.get("random_output_len", 0))
conc = int(payload.get("max_concurrency", 0))
- label_prefix = f"{DEFAULT_BACKEND}::{model} {isl}/{osl} c={conc}"
+ label_prefix = f"{default_backend}::{model} {isl}/{osl} c={conc}"
extra = f"Run: {run_url}" if run_url else ""
gpu_name = payload.get("gpu_name", "")
gpu_vram = payload.get("gpu_vram_gb", 0)
rocm_ver = payload.get("rocm_version", "")
- oot_image_tag = payload.get("oot_image_tag", "")
+
+ # Support both OOT and SGLang image tag fields
+ image_tag = payload.get("oot_image_tag", payload.get("sglang_image_tag", ""))
+
if gpu_name:
extra += f" | GPU: {gpu_name}"
if gpu_vram:
extra += f" | VRAM: {gpu_vram}GB"
if rocm_ver:
extra += f" | ROCm: {rocm_ver}"
- if oot_image_tag:
- extra += f" | Docker: {oot_image_tag}"
+ if image_tag:
+ extra += f" | Docker: {image_tag}"
extra = extra or None
append_metric(
@@ -145,10 +147,10 @@ def build_entries(result_dir: Path, run_url: str | None) -> list[dict]:
def main() -> None:
parser = argparse.ArgumentParser(
- description="Convert OOT benchmark JSON files to github-action-benchmark input"
+ description="Convert benchmark JSON files to github-action-benchmark input"
)
parser.add_argument(
- "result_dir", help="Directory containing OOT benchmark JSON files"
+ "result_dir", help="Directory containing benchmark JSON files"
)
parser.add_argument("--output", required=True, help="Output JSON path")
parser.add_argument(
@@ -156,15 +158,20 @@ def main() -> None:
default=None,
help="Optional GitHub Actions run URL added to each metric as extra metadata",
)
+ parser.add_argument(
+ "--default-backend",
+ required=True,
+ help="Default backend name (e.g. ATOM-SGLang or ATOM-vLLM)",
+ )
args = parser.parse_args()
result_dir = Path(args.result_dir)
- entries = build_entries(result_dir, args.run_url)
+ entries = build_entries(result_dir, args.run_url, args.default_backend)
output_path = Path(args.output)
output_path.write_text(json.dumps(entries, indent=2), encoding="utf-8")
- print(f"Generated {len(entries)} OOT entries at {output_path}")
+ print(f"Generated {len(entries)} entries at {output_path}")
if __name__ == "__main__":
- main()
+ main()
diff --git a/.github/scripts/oot_benchmark_validate_baseline.py b/.github/scripts/plugin_benchmark_validate_baseline.py
old mode 100644
new mode 100755
similarity index 85%
rename from .github/scripts/oot_benchmark_validate_baseline.py
rename to .github/scripts/plugin_benchmark_validate_baseline.py
index 35a317f00..23c03ce4e
--- a/.github/scripts/oot_benchmark_validate_baseline.py
+++ b/.github/scripts/plugin_benchmark_validate_baseline.py
@@ -7,8 +7,6 @@
import json
from pathlib import Path
-SKIP_FILENAMES = {"regression_report.json", "oot_benchmark_summary.json"}
-
def is_dashboard_publish_allowed(payload: dict) -> bool:
publish_flag = payload.get("dashboard_publish_allowed")
@@ -22,7 +20,7 @@ def is_dashboard_publish_allowed(payload: dict) -> bool:
def validate_result_dir(result_dir: Path) -> bool:
has_valid_result = False
for path in result_dir.rglob("*.json"):
- if path.name in SKIP_FILENAMES:
+ if path.name == "regression_report.json" or path.name.endswith("_benchmark_summary.json"):
continue
try:
payload = json.loads(path.read_text(encoding="utf-8"))
@@ -38,7 +36,7 @@ def validate_result_dir(result_dir: Path) -> bool:
def main() -> int:
parser = argparse.ArgumentParser(
- description="Validate dashboard-eligible OOT benchmark artifacts"
+ description="Validate dashboard-eligible benchmark artifacts"
)
parser.add_argument("result_dir", help="Directory containing downloaded artifacts")
args = parser.parse_args()
@@ -47,4 +45,4 @@ def main() -> int:
if __name__ == "__main__":
- raise SystemExit(main())
+ raise SystemExit(main())
diff --git a/.github/workflows/atom-sglang-benchmark.yaml b/.github/workflows/atom-sglang-benchmark.yaml
index 2c13dbe85..febc57a2a 100644
--- a/.github/workflows/atom-sglang-benchmark.yaml
+++ b/.github/workflows/atom-sglang-benchmark.yaml
@@ -832,11 +832,12 @@ jobs:
BENCHMARK_MATRIX: ${{ needs.build-benchmark-matrix.outputs.benchmark_matrix }}
RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
run: |
- python3 .github/scripts/sglang_benchmark_summary.py \
+ python3 .github/scripts/plugin_benchmark_summary.py \
--result-dir . \
--matrix-json "${BENCHMARK_MATRIX}" \
--run-url "${RUN_URL}" \
--output-json sglang_benchmark_summary.json \
+ --title "SGLang Benchmark Summary" \
>> "$GITHUB_STEP_SUMMARY"
- name: Read summary stats
@@ -888,7 +889,7 @@ jobs:
continue
fi
- if python3 .github/scripts/sglang_benchmark_validate_baseline.py /tmp/baseline_candidate; then
+ if python3 .github/scripts/plugin_benchmark_validate_baseline.py /tmp/baseline_candidate; then
mv /tmp/baseline_candidate /tmp/baseline
BASELINE_DIR="/tmp/baseline"
echo "Using baseline from run #$PREV_RUN_ID"
@@ -918,7 +919,7 @@ jobs:
>> "$GITHUB_STEP_SUMMARY"
fi
- python3 .github/scripts/sglang_benchmark_regression.py \
+ python3 .github/scripts/plugin_benchmark_regression.py \
--result-dir . \
$BASELINE_ARG \
--output-json regression_report.json \
@@ -945,10 +946,11 @@ jobs:
- name: Transform results for benchmark dashboard
if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0'
run: |
- python3 .github/scripts/sglang_benchmark_to_dashboard.py \
+ python3 .github/scripts/plugin_benchmark_to_dashboard.py \
. \
--output benchmark-action-input.json \
- --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
+ --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \
+ --default-backend "ATOM-SGLang"
- name: Store benchmark result to dashboard
if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0'
@@ -965,19 +967,26 @@ jobs:
max-items-in-chart: 90
github-token: ${{ secrets.GITHUB_TOKEN }}
- - name: Deploy custom dashboard to gh-pages
+ - name: Push dashboard data to gh-pages
if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- CURRENT_SHA=$(git rev-parse HEAD)
- cp .github/dashboard/index.html /tmp/dashboard_index.html
- cp docs/assets/atom_logo.png /tmp/dashboard_logo.png
git fetch origin gh-pages
git checkout gh-pages
- cp /tmp/dashboard_index.html benchmark-dashboard/index.html
- cp /tmp/dashboard_logo.png benchmark-dashboard/atom_logo.png
- git add benchmark-dashboard/
- git diff --cached --quiet || git commit -m "Update SGLang benchmark data and dashboard"
- git push origin gh-pages
- git checkout "$CURRENT_SHA"
+
+ if [ ! -f benchmark-dashboard/data.js ]; then
+ echo "::error::benchmark-dashboard/data.js was not produced by github-action-benchmark"
+ exit 1
+ fi
+
+ cp benchmark-dashboard/data.js /tmp/sglang-benchmark-data.js
+
+ # Rebuild the publish branch from origin/gh-pages so only dashboard data is updated.
+ git checkout -B gh-pages-data-only origin/gh-pages
+ mkdir -p benchmark-dashboard
+ cp /tmp/sglang-benchmark-data.js benchmark-dashboard/data.js
+
+ git add benchmark-dashboard/data.js
+ git diff --cached --quiet || git commit -m "Update SGLang benchmark dashboard data"
+ git push origin HEAD:gh-pages
diff --git a/.github/workflows/atom-vllm-benchmark.yaml b/.github/workflows/atom-vllm-benchmark.yaml
index 7deb974d2..b9c02651a 100644
--- a/.github/workflows/atom-vllm-benchmark.yaml
+++ b/.github/workflows/atom-vllm-benchmark.yaml
@@ -973,11 +973,12 @@ jobs:
BENCHMARK_MATRIX: ${{ needs.build-benchmark-matrix.outputs.benchmark_matrix }}
RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
run: |
- python3 .github/scripts/oot_benchmark_summary.py \
+ python3 .github/scripts/plugin_benchmark_summary.py \
--result-dir . \
--matrix-json "${BENCHMARK_MATRIX}" \
--run-url "${RUN_URL}" \
--output-json oot_benchmark_summary.json \
+ --title "ATOM-vLLM Benchmark Summary" \
>> "$GITHUB_STEP_SUMMARY"
- name: Read summary stats
@@ -1029,7 +1030,7 @@ jobs:
continue
fi
- if python3 .github/scripts/oot_benchmark_validate_baseline.py /tmp/baseline_candidate; then
+ if python3 .github/scripts/plugin_benchmark_validate_baseline.py /tmp/baseline_candidate; then
mv /tmp/baseline_candidate /tmp/baseline
BASELINE_DIR="/tmp/baseline"
echo "Using baseline from run #$PREV_RUN_ID"
@@ -1059,7 +1060,7 @@ jobs:
>> "$GITHUB_STEP_SUMMARY"
fi
- python3 .github/scripts/oot_benchmark_regression.py \
+ python3 .github/scripts/plugin_benchmark_regression.py \
--result-dir . \
$BASELINE_ARG \
--output-json regression_report.json \
@@ -1086,10 +1087,11 @@ jobs:
- name: Transform results for benchmark dashboard
if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0'
run: |
- python3 .github/scripts/oot_benchmark_to_dashboard.py \
+ python3 .github/scripts/plugin_benchmark_to_dashboard.py \
. \
--output benchmark-action-input.json \
- --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"
+ --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \
+ --default-backend "ATOM-vLLM"
- name: Store benchmark result to dashboard
if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0'
diff --git a/.github/workflows/docker-release.yaml b/.github/workflows/docker-release.yaml
index a54230f25..c2ed6e2b6 100644
--- a/.github/workflows/docker-release.yaml
+++ b/.github/workflows/docker-release.yaml
@@ -61,11 +61,11 @@ on:
type: boolean
default: true
only_release_oot:
- description: "Only release the OOT vLLM image: skip native image test/push and always build/push OOT"
+ description: "Release the OOT vLLM image (skips native image test/push). If both this and SGLang are checked, both will be released."
type: boolean
default: false
only_release_sglang:
- description: "Only release the SGLang+ATOM image: skip native image test/push and always build/push SGLang+ATOM"
+ description: "Release the SGLang+ATOM image (skips native image test/push). If both this and OOT are checked, both will be released."
type: boolean
default: false
@@ -193,7 +193,7 @@ jobs:
docker push rocm/atom-dev:${TAG}
- name: Build OOT Docker image
- if: ${{ success() && inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true || inputs.only_release_oot == true) }}
+ if: ${{ success() && (inputs.only_release_oot == true || (inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true))) }}
timeout-minutes: 180
run: |
OOT_BASE_IMAGE="atom_release:ci"
@@ -209,7 +209,7 @@ jobs:
docker inspect atom_oot_release:ci
- name: Push OOT Docker image
- if: ${{ success() && inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true || inputs.only_release_oot == true) }}
+ if: ${{ success() && (inputs.only_release_oot == true || (inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true))) }}
run: |
VLLM_VER="${{ env.VLLM_VERSION }}"
OOT_TAG="vllm-v${VLLM_VER}-nightly_$(date +%Y%m%d)"
@@ -220,7 +220,7 @@ jobs:
docker push rocm/atom-dev:${OOT_LATEST_TAG}
- name: Build SGLang Docker image
- if: ${{ success() && inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true || inputs.only_release_sglang == true) }}
+ if: ${{ success() && (inputs.only_release_sglang == true || (inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true))) }}
timeout-minutes: 180
run: |
SGLANG_BASE_IMAGE="atom_release:ci"
@@ -235,7 +235,7 @@ jobs:
docker inspect atom_sglang_release:ci
- name: Push SGLang+ATOM Docker image
- if: ${{ success() && inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true || inputs.only_release_sglang == true) }}
+ if: ${{ success() && (inputs.only_release_sglang == true || (inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true))) }}
run: |
SGLANG_VER="${{ inputs.sglang_version || env.SGLANG_VERSION }}"
SGLANG_TAG="sglang-v${SGLANG_VER}-nightly_$(date +%Y%m%d)"
diff --git a/atom/plugin/graph_capture_patch.py b/atom/plugin/graph_capture_patch.py
new file mode 100644
index 000000000..956ec5fd5
--- /dev/null
+++ b/atom/plugin/graph_capture_patch.py
@@ -0,0 +1,115 @@
+"""Patch a framework's graph_capture to also enter aiter's ca_comm.capture().
+
+When ATOM model runs as a plugin backend (vLLM or SGLang), the model uses aiter's
+collectives (tensor_model_parallel_fused_allreduce_rmsnorm etc.)
+but the host framework's graph_capture only enters its own ca_comm.capture().
+aiter's ca_comm never enters capture mode, causing _IS_CAPTURING=False ->
+registered=False -> hipMemcpyAsync on every call.
+
+This module provides a shared helper that patches any framework's (vLLM or SGLang)
+GroupCoordinator.graph_capture to also nest aiter's ca_comm.capture(),
+so fused_allreduce_rmsnorm uses registered=True and avoids the extra hipMemcpyAsync.
+"""
+
+import functools
+import logging
+from contextlib import contextmanager, nullcontext
+
+logger = logging.getLogger("atom")
+
+
+def _get_aiter_ca_capture_context():
+ """Lazily get aiter's ca_comm.capture() context, or nullcontext if unavailable."""
+ try:
+ from aiter.dist.parallel_state import get_tp_group
+
+ aiter_tp = get_tp_group()
+ except Exception:
+ return nullcontext()
+
+ if aiter_tp is None:
+ return nullcontext()
+
+ device_communicator = getattr(aiter_tp, "device_communicator", None)
+ if device_communicator is None:
+ return nullcontext()
+
+ aiter_ca_comm = getattr(device_communicator, "ca_comm", None)
+ if aiter_ca_comm is None or getattr(aiter_ca_comm, "disabled", True):
+ return nullcontext()
+
+ capture_method = getattr(aiter_ca_comm, "capture", None)
+ if capture_method is None:
+ return nullcontext()
+
+ return capture_method()
+
+
+def _patched_graph_capture(original_graph_capture):
+ """Wrap a framework's graph_capture to also enter aiter's ca_comm.capture()."""
+
+ @functools.wraps(original_graph_capture)
+ @contextmanager
+ def wrapped(self, graph_capture_context=None, **kwargs):
+ aiter_ca_context = _get_aiter_ca_capture_context()
+ with aiter_ca_context:
+ with original_graph_capture(self, graph_capture_context, **kwargs) as ctx:
+ yield ctx
+
+ return wrapped
+
+
+def apply_graph_capture_patch(framework_module_path: str) -> bool:
+ """Patch a framework's GroupCoordinator.graph_capture to nest aiter's
+ ca_comm.capture().
+
+ Args:
+ framework_module_path: Dotted import path to the framework's
+ parallel_state module containing GroupCoordinator
+ (e.g. "vllm.distributed.parallel_state" or
+ "sglang.srt.distributed.parallel_state").
+
+ Returns:
+ True if the patch was applied, False otherwise.
+ """
+ import importlib
+
+ try:
+ parallel_state = importlib.import_module(framework_module_path)
+ except ImportError as e:
+ logger.debug(
+ "ATOM graph_capture patch: %s not available (%s), skip",
+ framework_module_path,
+ e,
+ )
+ return False
+
+ GroupCoordinator = getattr(parallel_state, "GroupCoordinator", None)
+ if GroupCoordinator is None:
+ logger.debug(
+ "ATOM graph_capture patch: GroupCoordinator not found in %s, skip",
+ framework_module_path,
+ )
+ return False
+
+ original = getattr(GroupCoordinator, "graph_capture", None)
+ if original is None or getattr(original, "_atom_aiter_patched", False):
+ return False
+
+ try:
+ GroupCoordinator.graph_capture = _patched_graph_capture(original)
+ GroupCoordinator.graph_capture._atom_aiter_patched = True # type: ignore
+ logger.info(
+ "ATOM plugin: patched %s.GroupCoordinator.graph_capture to nest "
+ "aiter ca_comm.capture() (avoids hipMemcpyAsync in aiter collectives)",
+ framework_module_path,
+ )
+ return True
+ except Exception as e:
+ logger.warning(
+ "ATOM graph_capture patch for %s failed: %s. "
+ "aiter collectives may incur extra hipMemcpyAsync in plugin mode.",
+ framework_module_path,
+ e,
+ )
+ return False
diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py
index 6ef788bdb..77c6bfc85 100644
--- a/atom/plugin/prepare.py
+++ b/atom/plugin/prepare.py
@@ -82,4 +82,12 @@ def prepare_model(config: Any, engine: str):
# init aiter dist for using aiter custom collective ops
init_aiter_dist(config=atom_config)
+ # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(),
+ # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's
+ # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py)
+ if is_sglang():
+ from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch
+
+ apply_graph_capture_patch()
+
return model_cls(atom_config=atom_config)
diff --git a/atom/plugin/sglang/graph_capture_patch.py b/atom/plugin/sglang/graph_capture_patch.py
new file mode 100644
index 000000000..1fccdad1e
--- /dev/null
+++ b/atom/plugin/sglang/graph_capture_patch.py
@@ -0,0 +1,18 @@
+"""Patch SGLang graph capture to also enter aiter's ca_comm.capture().
+
+Delegates to the shared implementation in atom.plugin.graph_capture_patch.
+"""
+
+_GRAPH_CAPTURE_PATCH_APPLIED = False
+
+
+def apply_graph_capture_patch() -> None:
+ """Patch SGLang's GroupCoordinator.graph_capture to nest aiter's ca_comm.capture()."""
+ global _GRAPH_CAPTURE_PATCH_APPLIED
+
+ if _GRAPH_CAPTURE_PATCH_APPLIED:
+ return
+
+ from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply
+
+ _GRAPH_CAPTURE_PATCH_APPLIED = _apply("sglang.srt.distributed.parallel_state")
diff --git a/atom/plugin/vllm/graph_capture_patch.py b/atom/plugin/vllm/graph_capture_patch.py
index 7765f530e..c07ef5d0b 100644
--- a/atom/plugin/vllm/graph_capture_patch.py
+++ b/atom/plugin/vllm/graph_capture_patch.py
@@ -1,66 +1,11 @@
"""Patch vLLM graph capture to also enter aiter's ca_comm.capture().
-When ATOM model runs as vLLM plugin, the model uses aiter's collectives
-(tensor_model_parallel_fused_allreduce_rmsnorm etc.) but vLLM's graph_capture
-only calls vLLM's ca_comm.capture(). aiter's ca_comm never enters capture mode,
-causing _IS_CAPTURING=False -> registered=False -> hipMemcpyAsync on every call.
-
-This patch wraps vLLM's GroupCoordinator.graph_capture to also nest aiter's
-ca_comm.capture(), so fused_allreduce_rmsnorm uses registered=True and avoids
-the extra hipMemcpyAsync.
+Delegates to the shared implementation in atom.plugin.graph_capture_patch.
"""
-import functools
-import logging
-from contextlib import contextmanager, nullcontext
-
-logger = logging.getLogger("atom")
-
-# Avoid applying patch multiple times
_GRAPH_CAPTURE_PATCH_APPLIED = False
-def _get_aiter_ca_capture_context():
- """Lazily get aiter's ca_comm.capture() context, or nullcontext if unavailable."""
- try:
- from aiter.dist.parallel_state import get_tp_group
-
- aiter_tp = get_tp_group()
- except Exception:
- return nullcontext()
-
- if aiter_tp is None:
- return nullcontext()
-
- device_communicator = getattr(aiter_tp, "device_communicator", None)
- if device_communicator is None:
- return nullcontext()
-
- aiter_ca_comm = getattr(device_communicator, "ca_comm", None)
- if aiter_ca_comm is None or getattr(aiter_ca_comm, "disabled", True):
- return nullcontext()
-
- capture_method = getattr(aiter_ca_comm, "capture", None)
- if capture_method is None:
- return nullcontext()
-
- return capture_method()
-
-
-def _patched_graph_capture(original_graph_capture):
- """Wrap vLLM's graph_capture to also enter aiter's ca_comm.capture()."""
-
- @functools.wraps(original_graph_capture)
- @contextmanager
- def wrapped(self, graph_capture_context=None):
- aiter_ca_context = _get_aiter_ca_capture_context()
- with aiter_ca_context:
- with original_graph_capture(self, graph_capture_context) as ctx:
- yield ctx
-
- return wrapped
-
-
def apply_graph_capture_patch() -> None:
"""Patch vLLM's GroupCoordinator.graph_capture to nest aiter's ca_comm.capture()."""
global _GRAPH_CAPTURE_PATCH_APPLIED
@@ -68,30 +13,6 @@ def apply_graph_capture_patch() -> None:
if _GRAPH_CAPTURE_PATCH_APPLIED:
return
- try:
- import vllm.distributed.parallel_state as parallel_state
-
- GroupCoordinator = getattr(parallel_state, "GroupCoordinator", None)
- if GroupCoordinator is None:
- logger.debug("ATOM graph_capture patch: GroupCoordinator not found, skip")
- return
-
- original = getattr(GroupCoordinator, "graph_capture", None)
- if original is None or getattr(original, "_atom_aiter_patched", False):
- return
+ from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply
- GroupCoordinator.graph_capture = _patched_graph_capture(original)
- GroupCoordinator.graph_capture._atom_aiter_patched = True # type: ignore
- _GRAPH_CAPTURE_PATCH_APPLIED = True
- logger.info(
- "ATOM plugin: patched vLLM graph_capture to nest aiter ca_comm.capture() "
- "(avoids hipMemcpyAsync in fused_allreduce_rmsnorm)"
- )
- except ImportError as e:
- logger.debug("ATOM graph_capture patch: vllm not available (%s), skip", e)
- except Exception as e:
- logger.warning(
- "ATOM graph_capture patch failed: %s. "
- "fused_allreduce_rmsnorm may incur extra hipMemcpyAsync in vLLM plugin mode.",
- e,
- )
+ _GRAPH_CAPTURE_PATCH_APPLIED = _apply("vllm.distributed.parallel_state")
diff --git a/docker/Dockerfile b/docker/Dockerfile
index c4b6f509e..39d8a9451 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -13,9 +13,6 @@ ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
ARG VLLM_COMMIT="b31e9326a7d9394aab8c767f8ebe225c65594b60"
ARG INSTALL_LM_EVAL=1
ARG INSTALL_FASTSAFETENSORS=1
-ARG ROCM_WHEEL_INDEX="https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2.1/"
-ARG ROCM_TORCHVISION_VERSION="0.24.0+rocm7.2.1.gitb919bd0c"
-ARG ROCM_TORCHAUDIO_VERSION="2.9.0+rocm7.2.1.gite3c6ee2b"
# Let PR OOT CI verify whether a pulled prebuilt image still matches this vLLM commit
LABEL com.rocm.atom.vllm_commit="${VLLM_COMMIT}"
@@ -103,11 +100,7 @@ RUN echo "========== [OOT 7/7] Install vLLM runtime dependencies ==========" &&
"${VENV_PYTHON}" -c "import glob, os, torch; print(f'torch.version.hip: {torch.version.hip}'); print(f'torch.version.cuda: {torch.version.cuda}'); torch_lib_dir=os.path.join(os.path.dirname(torch.__file__), 'lib'); print(f'torch lib dir: {torch_lib_dir}'); print(f'libtorch_hip candidates: {glob.glob(os.path.join(torch_lib_dir, \"libtorch_hip.so*\"))}'); assert torch.version.hip is not None, 'Torch is not ROCm build (torch.version.hip is None).'" && \
"${VENV_PYTHON}" -m pip show vllm torch triton torchvision torchaudio amdsmi amd-aiter atom mori || true
-RUN echo "FIXME: rocm/pytorch:latest currently resolves incompatible torchvision/torchaudio wheels; pin ROCm wheels here until the base image is fixed upstream." && \
- "${VENV_PYTHON}" -m pip uninstall -y torchvision torchaudio || true && \
- "${VENV_PYTHON}" -m pip install --no-index --find-links "${ROCM_WHEEL_INDEX}" \
- "torchvision==${ROCM_TORCHVISION_VERSION}" \
- "torchaudio==${ROCM_TORCHAUDIO_VERSION}" && \
+RUN echo "========== [VLLM-ATOM] Validate vision/audio wheels ==========" && \
"${VENV_PYTHON}" -c "import torch, torchvision, torchaudio; from torchvision.transforms import InterpolationMode; from transformers.models.auto.image_processing_auto import get_image_processor_config; print(f'torch: {torch.__version__}'); print(f'torchvision: {torchvision.__version__}'); print(f'torchaudio: {torchaudio.__version__}'); print(f'InterpolationMode: {InterpolationMode.BILINEAR}'); print(f'get_image_processor_config: {get_image_processor_config.__name__}')"
# Restore that exact base-image Triton after all OOT installs finish. The goal
@@ -156,18 +149,15 @@ ARG VENV_PYTHON="/opt/venv/bin/python"
ARG SGLANG_REPO="https://github.com/sgl-project/sglang.git"
ARG SGLANG_REF="v0.5.10"
ARG SGLANG_TRITON_VERSION="3.6.0"
-ARG ROCM_WHEEL_INDEX="https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2.1/"
-ARG ROCM_TORCHVISION_VERSION="0.24.0+rocm7.2.1.gitb919bd0c"
-ARG ROCM_TORCHAUDIO_VERSION="2.9.0+rocm7.2.1.gite3c6ee2b"
LABEL com.rocm.atom.sglang_ref="${SGLANG_REF}"
ENV PATH="/opt/venv/bin:${PATH}"
ENV PYTHONPATH="/app/sglang/python:/app/ATOM:${PYTHONPATH}"
-RUN echo "========== [SGLANG 0/6] Check Aiter/FlyDSL versions before SGLang build ==========" && \
+RUN echo "========== [SGLANG-ATOM 0/6] Check Aiter/FlyDSL versions before SGLang build ==========" && \
"${VENV_PYTHON}" -m pip show atom mori amd-aiter flydsl || true
-RUN echo "========== [SGLANG 1/6] Clone SGLang ==========" && \
+RUN echo "========== [SGLANG-ATOM 1/6] Clone SGLang ==========" && \
rm -rf /app/sglang && \
git clone "${SGLANG_REPO}" /app/sglang && \
cd /app/sglang && \
@@ -176,7 +166,7 @@ RUN echo "========== [SGLANG 1/6] Clone SGLang ==========" && \
echo "sglang ref:" && \
git rev-parse HEAD
-RUN echo "========== [SGLANG 2/6] Build sglang kernel ==========" && \
+RUN echo "========== [SGLANG-ATOM 2/6] Build sglang kernel ==========" && \
"${VENV_PYTHON}" -m pip uninstall -y sgl-kernel sglang-kernel sglang || true && \
"${VENV_PYTHON}" -m pip install --upgrade pip setuptools wheel && \
DETECTED_AMDGPU_TARGET="$("${VENV_PYTHON}" -c "import torch; print(torch.cuda.get_device_properties(0).gcnArchName.split(':')[0] if torch.cuda.is_available() else '')" 2>/dev/null || true)" && \
@@ -192,7 +182,7 @@ RUN echo "========== [SGLANG 2/6] Build sglang kernel ==========" && \
AMDGPU_TARGET="${FINAL_AMDGPU_TARGET}" "${VENV_PYTHON}" setup_rocm.py install && \
"${VENV_PYTHON}" -m pip show sglang-kernel || true
-RUN echo "========== [SGLANG 3/6] Install SGLang dependencies ==========" && \
+RUN echo "========== [SGLANG-ATOM 3/6] Install SGLang dependencies ==========" && \
cd /app/sglang/python && \
rm -f pyproject.toml && \
cp pyproject_other.toml pyproject.toml && \
@@ -226,11 +216,7 @@ RUN echo "========== [SGLANG 3/6] Install SGLang dependencies ==========" && \
rm -f /tmp/sglang-runtime-common.txt && \
"${VENV_PYTHON}" -m pip show sglang torch triton transformers IPython orjson pybase64 petit-kernel wave-lang xgrammar outlines apache-tvm-ffi || true
-RUN echo "========== [SGLANG 4/6] Pin ROCm vision/audio wheels ==========" && \
- "${VENV_PYTHON}" -m pip uninstall -y torchvision torchaudio || true && \
- "${VENV_PYTHON}" -m pip install --no-deps --no-index --find-links "${ROCM_WHEEL_INDEX}" \
- "torchvision==${ROCM_TORCHVISION_VERSION}" \
- "torchaudio==${ROCM_TORCHAUDIO_VERSION}" && \
+RUN echo "========== [SGLANG-ATOM 4/6] Validate vision/audio wheels ==========" && \
"${VENV_PYTHON}" -m sglang.launch_server --help >/dev/null && \
"${VENV_PYTHON}" -c "import os, torch, torchvision, torchaudio, sglang, triton, transformers; from torchvision.io import decode_jpeg; assert torch.version.hip is not None, 'Torch is not ROCm build (torch.version.hip is None).'; print(f'torch: {torch.__version__}'); print(f'triton: {triton.__version__}'); print(f'transformers: {transformers.__version__}'); print(f'torchvision: {torchvision.__version__}'); print(f'torchaudio: {torchaudio.__version__}'); print(f'decode_jpeg: {decode_jpeg.__name__}'); print(f'sglang imported from: {sglang.__file__}'); print(f'PYTHONPATH={os.environ.get(\"PYTHONPATH\", \"\")}')" && \
echo "Validated sglang launch_server entrypoint"
@@ -242,12 +228,11 @@ RUN echo "========== [SGLANG 4/6] Pin ROCm vision/audio wheels ==========" && \
# so pip does not downgrade it back to the torch-pinned 3.5.1 dependency. Keep
# this override local to `atom_sglang` so the OOT/atom_release flow preserves
# its prior Triton behavior.
-RUN echo "========== [SGLANG 5/6] Install validated Triton ==========" && \
- "${VENV_PYTHON}" -m pip uninstall -y triton || true && \
+RUN echo "========== [SGLANG-ATOM 5/6] Install validated Triton ==========" && \
"${VENV_PYTHON}" -m pip install --no-cache-dir "triton==${SGLANG_TRITON_VERSION}" && \
- "${VENV_PYTHON}" -c "import triton; print(f'Installed Triton: {triton.__version__}')"
+ "${VENV_PYTHON}" -m pip show triton
-RUN echo "========== [SGLANG 6/6] Check Aiter/FlyDSL versions after SGLang build ==========" && \
+RUN echo "========== [SGLANG-ATOM 6/6] Check Aiter/FlyDSL versions after SGLang build ==========" && \
"${VENV_PYTHON}" -m pip show atom mori amd-aiter flydsl || true
CMD ["/bin/bash"]