From f10f0541d062f8304bd1a5094d9792683b076340 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 13 Apr 2026 01:52:09 -0500 Subject: [PATCH 1/5] add sglang benchmark Signed-off-by: zhuyuhua-v --- .github/dashboard/index.html | 3 +- ...sion.py => plugin_benchmark_regression.py} | 6 ++-- ...summary.py => plugin_benchmark_summary.py} | 19 +++++++---- ...rd.py => plugin_benchmark_to_dashboard.py} | 33 +++++++++++-------- ... => plugin_benchmark_validate_baseline.py} | 8 ++--- .github/workflows/atom-sglang-benchmark.yaml | 14 +++++--- .github/workflows/atom-vllm-benchmark.yaml | 12 ++++--- 7 files changed, 56 insertions(+), 39 deletions(-) rename .github/scripts/{oot_benchmark_regression.py => plugin_benchmark_regression.py} (93%) mode change 100644 => 100755 rename .github/scripts/{oot_benchmark_summary.py => plugin_benchmark_summary.py} (93%) mode change 100644 => 100755 rename .github/scripts/{oot_benchmark_to_dashboard.py => plugin_benchmark_to_dashboard.py} (80%) mode change 100644 => 100755 rename .github/scripts/{oot_benchmark_validate_baseline.py => plugin_benchmark_validate_baseline.py} (85%) mode change 100644 => 100755 diff --git a/.github/dashboard/index.html b/.github/dashboard/index.html index 094ee9697..07728d33a 100644 --- a/.github/dashboard/index.html +++ b/.github/dashboard/index.html @@ -342,7 +342,7 @@

Benchmark Dashboard

1. CONSTANTS & COLOR PALETTE ================================================================ */ // HSL-based palette: each model gets a distinct base hue; backend offsets keep -// ATOM and ATOM-vLLM visually related but not identical. +// ATOM, ATOM-vLLM, and ATOM-SGLang visually related but not identical. const MODEL_HUES = { 'DeepSeek-R1-0528': 210, // blue 'DeepSeek-R1-0528-mtp3': 175, // cyan/teal — distinct from base DeepSeek @@ -356,6 +356,7 @@

Benchmark Dashboard

const BACKEND_HUE_OFFSETS = { 'ATOM': 0, 'ATOM-vLLM': 28, + 'ATOM-SGLang': -30, }; const FALLBACK_HUES = [45, 330, 190, 30]; // yellow, pink, teal, amber let fallbackHueIdx = 0; diff --git a/.github/scripts/oot_benchmark_regression.py b/.github/scripts/plugin_benchmark_regression.py old mode 100644 new mode 100755 similarity index 93% rename from .github/scripts/oot_benchmark_regression.py rename to .github/scripts/plugin_benchmark_regression.py index cf31c124f..9903ce98b --- a/.github/scripts/oot_benchmark_regression.py +++ b/.github/scripts/plugin_benchmark_regression.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""OOT-specific regression summary built on top of shared summarize helpers.""" +"""Regression summary built on top of shared summarize helpers.""" from __future__ import annotations @@ -39,7 +39,7 @@ def build_report( def main() -> int: parser = argparse.ArgumentParser( - description="Print only the OOT regression report without the full results table" + description="Print only the regression report without the full results table" ) parser.add_argument( "--result-dir", @@ -88,4 +88,4 @@ def main() -> int: if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) \ No newline at end of file diff --git a/.github/scripts/oot_benchmark_summary.py b/.github/scripts/plugin_benchmark_summary.py old mode 100644 new mode 100755 similarity index 93% rename from .github/scripts/oot_benchmark_summary.py rename to .github/scripts/plugin_benchmark_summary.py index ed6f21883..db753aea1 --- a/.github/scripts/oot_benchmark_summary.py +++ b/.github/scripts/plugin_benchmark_summary.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Generate a resilient OOT benchmark summary table. +"""Generate a resilient benchmark summary table. This script is intentionally tolerant of partial or total benchmark failure: - missing result JSON => case is marked FAIL @@ -143,12 +143,12 @@ def _build_rows(result_dir: Path, matrix_payload: dict) -> list[dict]: return rows -def _print_markdown_table(rows: list[dict], run_url: str | None) -> None: +def _print_markdown_table(rows: list[dict], run_url: str | None, title: str) -> None: total_cases = len(rows) passed_cases = sum(1 for row in rows if row["status"] == "PASS") failed_cases = total_cases - passed_cases - print("## OOT Benchmark Summary\n") + print(f"## {title}\n") if run_url: print(f"Run: {run_url}\n") print( @@ -185,11 +185,11 @@ def _print_markdown_table(rows: list[dict], run_url: str | None) -> None: def main() -> int: - parser = argparse.ArgumentParser(description="Summarize OOT benchmark results") + parser = argparse.ArgumentParser(description="Summarize benchmark results") parser.add_argument( "--result-dir", required=True, - help="Directory containing downloaded OOT benchmark JSON files", + help="Directory containing downloaded benchmark JSON files", ) parser.add_argument( "--matrix-json", @@ -206,12 +206,17 @@ def main() -> int: default=None, help="Optional path to write a structured summary report", ) + parser.add_argument( + "--title", + default="Benchmark Summary", + help="Title for the markdown report", + ) args = parser.parse_args() matrix_payload = json.loads(args.matrix_json) rows = _build_rows(Path(args.result_dir), matrix_payload) - _print_markdown_table(rows, args.run_url) + _print_markdown_table(rows, args.run_url, args.title) if args.output_json: report = { @@ -229,4 +234,4 @@ def main() -> int: if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) \ No newline at end of file diff --git a/.github/scripts/oot_benchmark_to_dashboard.py b/.github/scripts/plugin_benchmark_to_dashboard.py old mode 100644 new mode 100755 similarity index 80% rename from .github/scripts/oot_benchmark_to_dashboard.py rename to .github/scripts/plugin_benchmark_to_dashboard.py index 3cf026013..c732390e0 --- a/.github/scripts/oot_benchmark_to_dashboard.py +++ b/.github/scripts/plugin_benchmark_to_dashboard.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Convert OOT benchmark JSON results to github-action-benchmark input.""" +"""Convert benchmark JSON results to github-action-benchmark input.""" from __future__ import annotations @@ -9,7 +9,6 @@ from pathlib import Path VARIANT_RE = re.compile(r"-(mtp\d*)-") -DEFAULT_BACKEND = "ATOM-vLLM" def derive_model_name(result_path: Path, payload: dict) -> str: @@ -59,11 +58,11 @@ def is_dashboard_publish_allowed(payload: dict) -> bool: return str(publish_flag).strip().lower() not in {"0", "false", "no"} -def build_entries(result_dir: Path, run_url: str | None) -> list[dict]: +def build_entries(result_dir: Path, run_url: str | None, default_backend: str) -> list[dict]: entries: list[dict] = [] for result_path in sorted(result_dir.glob("*.json")): - if result_path.name == "regression_report.json": + if result_path.name == "regression_report.json" or result_path.name.endswith("_benchmark_summary.json"): continue try: @@ -81,20 +80,23 @@ def build_entries(result_dir: Path, run_url: str | None) -> list[dict]: isl = int(payload.get("random_input_len", 0)) osl = int(payload.get("random_output_len", 0)) conc = int(payload.get("max_concurrency", 0)) - label_prefix = f"{DEFAULT_BACKEND}::{model} {isl}/{osl} c={conc}" + label_prefix = f"{default_backend}::{model} {isl}/{osl} c={conc}" extra = f"Run: {run_url}" if run_url else "" gpu_name = payload.get("gpu_name", "") gpu_vram = payload.get("gpu_vram_gb", 0) rocm_ver = payload.get("rocm_version", "") - oot_image_tag = payload.get("oot_image_tag", "") + + # Support both OOT and SGLang image tag fields + image_tag = payload.get("oot_image_tag", payload.get("sglang_image_tag", "")) + if gpu_name: extra += f" | GPU: {gpu_name}" if gpu_vram: extra += f" | VRAM: {gpu_vram}GB" if rocm_ver: extra += f" | ROCm: {rocm_ver}" - if oot_image_tag: - extra += f" | Docker: {oot_image_tag}" + if image_tag: + extra += f" | Docker: {image_tag}" extra = extra or None append_metric( @@ -145,10 +147,10 @@ def build_entries(result_dir: Path, run_url: str | None) -> list[dict]: def main() -> None: parser = argparse.ArgumentParser( - description="Convert OOT benchmark JSON files to github-action-benchmark input" + description="Convert benchmark JSON files to github-action-benchmark input" ) parser.add_argument( - "result_dir", help="Directory containing OOT benchmark JSON files" + "result_dir", help="Directory containing benchmark JSON files" ) parser.add_argument("--output", required=True, help="Output JSON path") parser.add_argument( @@ -156,15 +158,20 @@ def main() -> None: default=None, help="Optional GitHub Actions run URL added to each metric as extra metadata", ) + parser.add_argument( + "--default-backend", + required=True, + help="Default backend name (e.g. ATOM-SGLang or ATOM-vLLM)", + ) args = parser.parse_args() result_dir = Path(args.result_dir) - entries = build_entries(result_dir, args.run_url) + entries = build_entries(result_dir, args.run_url, args.default_backend) output_path = Path(args.output) output_path.write_text(json.dumps(entries, indent=2), encoding="utf-8") - print(f"Generated {len(entries)} OOT entries at {output_path}") + print(f"Generated {len(entries)} entries at {output_path}") if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/.github/scripts/oot_benchmark_validate_baseline.py b/.github/scripts/plugin_benchmark_validate_baseline.py old mode 100644 new mode 100755 similarity index 85% rename from .github/scripts/oot_benchmark_validate_baseline.py rename to .github/scripts/plugin_benchmark_validate_baseline.py index 35a317f00..23c03ce4e --- a/.github/scripts/oot_benchmark_validate_baseline.py +++ b/.github/scripts/plugin_benchmark_validate_baseline.py @@ -7,8 +7,6 @@ import json from pathlib import Path -SKIP_FILENAMES = {"regression_report.json", "oot_benchmark_summary.json"} - def is_dashboard_publish_allowed(payload: dict) -> bool: publish_flag = payload.get("dashboard_publish_allowed") @@ -22,7 +20,7 @@ def is_dashboard_publish_allowed(payload: dict) -> bool: def validate_result_dir(result_dir: Path) -> bool: has_valid_result = False for path in result_dir.rglob("*.json"): - if path.name in SKIP_FILENAMES: + if path.name == "regression_report.json" or path.name.endswith("_benchmark_summary.json"): continue try: payload = json.loads(path.read_text(encoding="utf-8")) @@ -38,7 +36,7 @@ def validate_result_dir(result_dir: Path) -> bool: def main() -> int: parser = argparse.ArgumentParser( - description="Validate dashboard-eligible OOT benchmark artifacts" + description="Validate dashboard-eligible benchmark artifacts" ) parser.add_argument("result_dir", help="Directory containing downloaded artifacts") args = parser.parse_args() @@ -47,4 +45,4 @@ def main() -> int: if __name__ == "__main__": - raise SystemExit(main()) + raise SystemExit(main()) \ No newline at end of file diff --git a/.github/workflows/atom-sglang-benchmark.yaml b/.github/workflows/atom-sglang-benchmark.yaml index 2c13dbe85..6a0b4df76 100644 --- a/.github/workflows/atom-sglang-benchmark.yaml +++ b/.github/workflows/atom-sglang-benchmark.yaml @@ -832,11 +832,12 @@ jobs: BENCHMARK_MATRIX: ${{ needs.build-benchmark-matrix.outputs.benchmark_matrix }} RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} run: | - python3 .github/scripts/sglang_benchmark_summary.py \ + python3 .github/scripts/plugin_benchmark_summary.py \ --result-dir . \ --matrix-json "${BENCHMARK_MATRIX}" \ --run-url "${RUN_URL}" \ --output-json sglang_benchmark_summary.json \ + --title "SGLang Benchmark Summary" \ >> "$GITHUB_STEP_SUMMARY" - name: Read summary stats @@ -888,7 +889,7 @@ jobs: continue fi - if python3 .github/scripts/sglang_benchmark_validate_baseline.py /tmp/baseline_candidate; then + if python3 .github/scripts/plugin_benchmark_validate_baseline.py /tmp/baseline_candidate; then mv /tmp/baseline_candidate /tmp/baseline BASELINE_DIR="/tmp/baseline" echo "Using baseline from run #$PREV_RUN_ID" @@ -918,7 +919,7 @@ jobs: >> "$GITHUB_STEP_SUMMARY" fi - python3 .github/scripts/sglang_benchmark_regression.py \ + python3 .github/scripts/plugin_benchmark_regression.py \ --result-dir . \ $BASELINE_ARG \ --output-json regression_report.json \ @@ -945,10 +946,11 @@ jobs: - name: Transform results for benchmark dashboard if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0' run: | - python3 .github/scripts/sglang_benchmark_to_dashboard.py \ + python3 .github/scripts/plugin_benchmark_to_dashboard.py \ . \ --output benchmark-action-input.json \ - --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" + --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \ + --default-backend "ATOM-SGLang" - name: Store benchmark result to dashboard if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0' @@ -973,10 +975,12 @@ jobs: CURRENT_SHA=$(git rev-parse HEAD) cp .github/dashboard/index.html /tmp/dashboard_index.html cp docs/assets/atom_logo.png /tmp/dashboard_logo.png + cp docs/assets/atom_logo_mini.png /tmp/dashboard_logo_mini.png git fetch origin gh-pages git checkout gh-pages cp /tmp/dashboard_index.html benchmark-dashboard/index.html cp /tmp/dashboard_logo.png benchmark-dashboard/atom_logo.png + cp /tmp/dashboard_logo_mini.png benchmark-dashboard/atom_logo_mini.png git add benchmark-dashboard/ git diff --cached --quiet || git commit -m "Update SGLang benchmark data and dashboard" git push origin gh-pages diff --git a/.github/workflows/atom-vllm-benchmark.yaml b/.github/workflows/atom-vllm-benchmark.yaml index 7deb974d2..e01ab0040 100644 --- a/.github/workflows/atom-vllm-benchmark.yaml +++ b/.github/workflows/atom-vllm-benchmark.yaml @@ -973,11 +973,12 @@ jobs: BENCHMARK_MATRIX: ${{ needs.build-benchmark-matrix.outputs.benchmark_matrix }} RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} run: | - python3 .github/scripts/oot_benchmark_summary.py \ + python3 .github/scripts/plugin_benchmark_summary.py \ --result-dir . \ --matrix-json "${BENCHMARK_MATRIX}" \ --run-url "${RUN_URL}" \ --output-json oot_benchmark_summary.json \ + --title "OOT Benchmark Summary" \ >> "$GITHUB_STEP_SUMMARY" - name: Read summary stats @@ -1029,7 +1030,7 @@ jobs: continue fi - if python3 .github/scripts/oot_benchmark_validate_baseline.py /tmp/baseline_candidate; then + if python3 .github/scripts/plugin_benchmark_validate_baseline.py /tmp/baseline_candidate; then mv /tmp/baseline_candidate /tmp/baseline BASELINE_DIR="/tmp/baseline" echo "Using baseline from run #$PREV_RUN_ID" @@ -1059,7 +1060,7 @@ jobs: >> "$GITHUB_STEP_SUMMARY" fi - python3 .github/scripts/oot_benchmark_regression.py \ + python3 .github/scripts/plugin_benchmark_regression.py \ --result-dir . \ $BASELINE_ARG \ --output-json regression_report.json \ @@ -1086,10 +1087,11 @@ jobs: - name: Transform results for benchmark dashboard if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0' run: | - python3 .github/scripts/oot_benchmark_to_dashboard.py \ + python3 .github/scripts/plugin_benchmark_to_dashboard.py \ . \ --output benchmark-action-input.json \ - --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" + --run-url "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}" \ + --default-backend "ATOM-vLLM" - name: Store benchmark result to dashboard if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0' From e0542f6e687ae9cec23b04db05a3de68af139902 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 13 Apr 2026 02:16:59 -0500 Subject: [PATCH 2/5] fix docker release workflow Signed-off-by: zhuyuhua-v --- .github/workflows/atom-vllm-benchmark.yaml | 2 +- .github/workflows/docker-release.yaml | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/atom-vllm-benchmark.yaml b/.github/workflows/atom-vllm-benchmark.yaml index e01ab0040..b9c02651a 100644 --- a/.github/workflows/atom-vllm-benchmark.yaml +++ b/.github/workflows/atom-vllm-benchmark.yaml @@ -978,7 +978,7 @@ jobs: --matrix-json "${BENCHMARK_MATRIX}" \ --run-url "${RUN_URL}" \ --output-json oot_benchmark_summary.json \ - --title "OOT Benchmark Summary" \ + --title "ATOM-vLLM Benchmark Summary" \ >> "$GITHUB_STEP_SUMMARY" - name: Read summary stats diff --git a/.github/workflows/docker-release.yaml b/.github/workflows/docker-release.yaml index a54230f25..c2ed6e2b6 100644 --- a/.github/workflows/docker-release.yaml +++ b/.github/workflows/docker-release.yaml @@ -61,11 +61,11 @@ on: type: boolean default: true only_release_oot: - description: "Only release the OOT vLLM image: skip native image test/push and always build/push OOT" + description: "Release the OOT vLLM image (skips native image test/push). If both this and SGLang are checked, both will be released." type: boolean default: false only_release_sglang: - description: "Only release the SGLang+ATOM image: skip native image test/push and always build/push SGLang+ATOM" + description: "Release the SGLang+ATOM image (skips native image test/push). If both this and OOT are checked, both will be released." type: boolean default: false @@ -193,7 +193,7 @@ jobs: docker push rocm/atom-dev:${TAG} - name: Build OOT Docker image - if: ${{ success() && inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true || inputs.only_release_oot == true) }} + if: ${{ success() && (inputs.only_release_oot == true || (inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true))) }} timeout-minutes: 180 run: | OOT_BASE_IMAGE="atom_release:ci" @@ -209,7 +209,7 @@ jobs: docker inspect atom_oot_release:ci - name: Push OOT Docker image - if: ${{ success() && inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true || inputs.only_release_oot == true) }} + if: ${{ success() && (inputs.only_release_oot == true || (inputs.only_release_sglang != true && (github.event_name == 'schedule' || inputs.build_oot_image == true))) }} run: | VLLM_VER="${{ env.VLLM_VERSION }}" OOT_TAG="vllm-v${VLLM_VER}-nightly_$(date +%Y%m%d)" @@ -220,7 +220,7 @@ jobs: docker push rocm/atom-dev:${OOT_LATEST_TAG} - name: Build SGLang Docker image - if: ${{ success() && inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true || inputs.only_release_sglang == true) }} + if: ${{ success() && (inputs.only_release_sglang == true || (inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true))) }} timeout-minutes: 180 run: | SGLANG_BASE_IMAGE="atom_release:ci" @@ -235,7 +235,7 @@ jobs: docker inspect atom_sglang_release:ci - name: Push SGLang+ATOM Docker image - if: ${{ success() && inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true || inputs.only_release_sglang == true) }} + if: ${{ success() && (inputs.only_release_sglang == true || (inputs.only_release_oot != true && (github.event_name == 'schedule' || inputs.build_sglang_image == true))) }} run: | SGLANG_VER="${{ inputs.sglang_version || env.SGLANG_VERSION }}" SGLANG_TAG="sglang-v${SGLANG_VER}-nightly_$(date +%Y%m%d)" From 14e2f4f6bea16a3708333d3ded29b800b4c41f03 Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Mon, 13 Apr 2026 02:55:06 -0500 Subject: [PATCH 3/5] update push dashboard data Signed-off-by: zhuyuhua-v --- .github/workflows/atom-sglang-benchmark.yaml | 29 ++++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/.github/workflows/atom-sglang-benchmark.yaml b/.github/workflows/atom-sglang-benchmark.yaml index 6a0b4df76..febc57a2a 100644 --- a/.github/workflows/atom-sglang-benchmark.yaml +++ b/.github/workflows/atom-sglang-benchmark.yaml @@ -967,21 +967,26 @@ jobs: max-items-in-chart: 90 github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Deploy custom dashboard to gh-pages + - name: Push dashboard data to gh-pages if: needs.resolve-atom-source.outputs.publish_to_dashboard == 'true' && steps.summary-stats.outputs.passed_cases != '0' run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - CURRENT_SHA=$(git rev-parse HEAD) - cp .github/dashboard/index.html /tmp/dashboard_index.html - cp docs/assets/atom_logo.png /tmp/dashboard_logo.png - cp docs/assets/atom_logo_mini.png /tmp/dashboard_logo_mini.png git fetch origin gh-pages git checkout gh-pages - cp /tmp/dashboard_index.html benchmark-dashboard/index.html - cp /tmp/dashboard_logo.png benchmark-dashboard/atom_logo.png - cp /tmp/dashboard_logo_mini.png benchmark-dashboard/atom_logo_mini.png - git add benchmark-dashboard/ - git diff --cached --quiet || git commit -m "Update SGLang benchmark data and dashboard" - git push origin gh-pages - git checkout "$CURRENT_SHA" + + if [ ! -f benchmark-dashboard/data.js ]; then + echo "::error::benchmark-dashboard/data.js was not produced by github-action-benchmark" + exit 1 + fi + + cp benchmark-dashboard/data.js /tmp/sglang-benchmark-data.js + + # Rebuild the publish branch from origin/gh-pages so only dashboard data is updated. + git checkout -B gh-pages-data-only origin/gh-pages + mkdir -p benchmark-dashboard + cp /tmp/sglang-benchmark-data.js benchmark-dashboard/data.js + + git add benchmark-dashboard/data.js + git diff --cached --quiet || git commit -m "Update SGLang benchmark dashboard data" + git push origin HEAD:gh-pages From 1f36af33c3095b9d0f15e86b20c3dfae4f6bd90d Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Wed, 1 Apr 2026 05:24:39 +0000 Subject: [PATCH 4/5] add graph capture patch(like vLLM) for sglang+atom path Signed-off-by: zhuyuhua-v --- atom/plugin/graph_capture_patch.py | 115 ++++++++++++++++++++++ atom/plugin/prepare.py | 8 ++ atom/plugin/sglang/graph_capture_patch.py | 18 ++++ atom/plugin/vllm/graph_capture_patch.py | 85 +--------------- 4 files changed, 144 insertions(+), 82 deletions(-) create mode 100644 atom/plugin/graph_capture_patch.py create mode 100644 atom/plugin/sglang/graph_capture_patch.py diff --git a/atom/plugin/graph_capture_patch.py b/atom/plugin/graph_capture_patch.py new file mode 100644 index 000000000..956ec5fd5 --- /dev/null +++ b/atom/plugin/graph_capture_patch.py @@ -0,0 +1,115 @@ +"""Patch a framework's graph_capture to also enter aiter's ca_comm.capture(). + +When ATOM model runs as a plugin backend (vLLM or SGLang), the model uses aiter's +collectives (tensor_model_parallel_fused_allreduce_rmsnorm etc.) +but the host framework's graph_capture only enters its own ca_comm.capture(). +aiter's ca_comm never enters capture mode, causing _IS_CAPTURING=False -> +registered=False -> hipMemcpyAsync on every call. + +This module provides a shared helper that patches any framework(vLLM or SGLang)'s +GroupCoordinator.graph_capture to also nest aiter's ca_comm.capture(), +so fused_allreduce_rmsnorm uses registered=True and avoids the extra hipMemcpyAsync. +""" + +import functools +import logging +from contextlib import contextmanager, nullcontext + +logger = logging.getLogger("atom") + + +def _get_aiter_ca_capture_context(): + """Lazily get aiter's ca_comm.capture() context, or nullcontext if unavailable.""" + try: + from aiter.dist.parallel_state import get_tp_group + + aiter_tp = get_tp_group() + except Exception: + return nullcontext() + + if aiter_tp is None: + return nullcontext() + + device_communicator = getattr(aiter_tp, "device_communicator", None) + if device_communicator is None: + return nullcontext() + + aiter_ca_comm = getattr(device_communicator, "ca_comm", None) + if aiter_ca_comm is None or getattr(aiter_ca_comm, "disabled", True): + return nullcontext() + + capture_method = getattr(aiter_ca_comm, "capture", None) + if capture_method is None: + return nullcontext() + + return capture_method() + + +def _patched_graph_capture(original_graph_capture): + """Wrap a framework's graph_capture to also enter aiter's ca_comm.capture().""" + + @functools.wraps(original_graph_capture) + @contextmanager + def wrapped(self, graph_capture_context=None, **kwargs): + aiter_ca_context = _get_aiter_ca_capture_context() + with aiter_ca_context: + with original_graph_capture(self, graph_capture_context, **kwargs) as ctx: + yield ctx + + return wrapped + + +def apply_graph_capture_patch(framework_module_path: str) -> bool: + """Patch a framework's GroupCoordinator.graph_capture to nest aiter's + ca_comm.capture(). + + Args: + framework_module_path: Dotted import path to the framework's + parallel_state module containing GroupCoordinator + (e.g. "vllm.distributed.parallel_state" or + "sglang.srt.distributed.parallel_state"). + + Returns: + True if the patch was applied, False otherwise. + """ + import importlib + + try: + parallel_state = importlib.import_module(framework_module_path) + except ImportError as e: + logger.debug( + "ATOM graph_capture patch: %s not available (%s), skip", + framework_module_path, + e, + ) + return False + + GroupCoordinator = getattr(parallel_state, "GroupCoordinator", None) + if GroupCoordinator is None: + logger.debug( + "ATOM graph_capture patch: GroupCoordinator not found in %s, skip", + framework_module_path, + ) + return False + + original = getattr(GroupCoordinator, "graph_capture", None) + if original is None or getattr(original, "_atom_aiter_patched", False): + return False + + try: + GroupCoordinator.graph_capture = _patched_graph_capture(original) + GroupCoordinator.graph_capture._atom_aiter_patched = True # type: ignore + logger.info( + "ATOM plugin: patched %s.GroupCoordinator.graph_capture to nest " + "aiter ca_comm.capture() (avoids hipMemcpyAsync in aiter collectives)", + framework_module_path, + ) + return True + except Exception as e: + logger.warning( + "ATOM graph_capture patch for %s failed: %s. " + "aiter collectives may incur extra hipMemcpyAsync in plugin mode.", + framework_module_path, + e, + ) + return False diff --git a/atom/plugin/prepare.py b/atom/plugin/prepare.py index 6ef788bdb..77c6bfc85 100644 --- a/atom/plugin/prepare.py +++ b/atom/plugin/prepare.py @@ -82,4 +82,12 @@ def prepare_model(config: Any, engine: str): # init aiter dist for using aiter custom collective ops init_aiter_dist(config=atom_config) + # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(), + # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's + # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py) + if is_sglang(): + from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch + + apply_graph_capture_patch() + return model_cls(atom_config=atom_config) diff --git a/atom/plugin/sglang/graph_capture_patch.py b/atom/plugin/sglang/graph_capture_patch.py new file mode 100644 index 000000000..1fccdad1e --- /dev/null +++ b/atom/plugin/sglang/graph_capture_patch.py @@ -0,0 +1,18 @@ +"""Patch SGLang graph capture to also enter aiter's ca_comm.capture(). + +Delegates to the shared implementation in atom.plugin.graph_capture_patch. +""" + +_GRAPH_CAPTURE_PATCH_APPLIED = False + + +def apply_graph_capture_patch() -> None: + """Patch SGLang's GroupCoordinator.graph_capture to nest aiter's ca_comm.capture().""" + global _GRAPH_CAPTURE_PATCH_APPLIED + + if _GRAPH_CAPTURE_PATCH_APPLIED: + return + + from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply + + _GRAPH_CAPTURE_PATCH_APPLIED = _apply("sglang.srt.distributed.parallel_state") diff --git a/atom/plugin/vllm/graph_capture_patch.py b/atom/plugin/vllm/graph_capture_patch.py index 7765f530e..c07ef5d0b 100644 --- a/atom/plugin/vllm/graph_capture_patch.py +++ b/atom/plugin/vllm/graph_capture_patch.py @@ -1,66 +1,11 @@ """Patch vLLM graph capture to also enter aiter's ca_comm.capture(). -When ATOM model runs as vLLM plugin, the model uses aiter's collectives -(tensor_model_parallel_fused_allreduce_rmsnorm etc.) but vLLM's graph_capture -only calls vLLM's ca_comm.capture(). aiter's ca_comm never enters capture mode, -causing _IS_CAPTURING=False -> registered=False -> hipMemcpyAsync on every call. - -This patch wraps vLLM's GroupCoordinator.graph_capture to also nest aiter's -ca_comm.capture(), so fused_allreduce_rmsnorm uses registered=True and avoids -the extra hipMemcpyAsync. +Delegates to the shared implementation in atom.plugin.graph_capture_patch. """ -import functools -import logging -from contextlib import contextmanager, nullcontext - -logger = logging.getLogger("atom") - -# Avoid applying patch multiple times _GRAPH_CAPTURE_PATCH_APPLIED = False -def _get_aiter_ca_capture_context(): - """Lazily get aiter's ca_comm.capture() context, or nullcontext if unavailable.""" - try: - from aiter.dist.parallel_state import get_tp_group - - aiter_tp = get_tp_group() - except Exception: - return nullcontext() - - if aiter_tp is None: - return nullcontext() - - device_communicator = getattr(aiter_tp, "device_communicator", None) - if device_communicator is None: - return nullcontext() - - aiter_ca_comm = getattr(device_communicator, "ca_comm", None) - if aiter_ca_comm is None or getattr(aiter_ca_comm, "disabled", True): - return nullcontext() - - capture_method = getattr(aiter_ca_comm, "capture", None) - if capture_method is None: - return nullcontext() - - return capture_method() - - -def _patched_graph_capture(original_graph_capture): - """Wrap vLLM's graph_capture to also enter aiter's ca_comm.capture().""" - - @functools.wraps(original_graph_capture) - @contextmanager - def wrapped(self, graph_capture_context=None): - aiter_ca_context = _get_aiter_ca_capture_context() - with aiter_ca_context: - with original_graph_capture(self, graph_capture_context) as ctx: - yield ctx - - return wrapped - - def apply_graph_capture_patch() -> None: """Patch vLLM's GroupCoordinator.graph_capture to nest aiter's ca_comm.capture().""" global _GRAPH_CAPTURE_PATCH_APPLIED @@ -68,30 +13,6 @@ def apply_graph_capture_patch() -> None: if _GRAPH_CAPTURE_PATCH_APPLIED: return - try: - import vllm.distributed.parallel_state as parallel_state - - GroupCoordinator = getattr(parallel_state, "GroupCoordinator", None) - if GroupCoordinator is None: - logger.debug("ATOM graph_capture patch: GroupCoordinator not found, skip") - return - - original = getattr(GroupCoordinator, "graph_capture", None) - if original is None or getattr(original, "_atom_aiter_patched", False): - return + from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply - GroupCoordinator.graph_capture = _patched_graph_capture(original) - GroupCoordinator.graph_capture._atom_aiter_patched = True # type: ignore - _GRAPH_CAPTURE_PATCH_APPLIED = True - logger.info( - "ATOM plugin: patched vLLM graph_capture to nest aiter ca_comm.capture() " - "(avoids hipMemcpyAsync in fused_allreduce_rmsnorm)" - ) - except ImportError as e: - logger.debug("ATOM graph_capture patch: vllm not available (%s), skip", e) - except Exception as e: - logger.warning( - "ATOM graph_capture patch failed: %s. " - "fused_allreduce_rmsnorm may incur extra hipMemcpyAsync in vLLM plugin mode.", - e, - ) + _GRAPH_CAPTURE_PATCH_APPLIED = _apply("vllm.distributed.parallel_state") From 64975474b775d02bb2f92cdf617d2e4bed67d1ef Mon Sep 17 00:00:00 2001 From: zhuyuhua-v Date: Wed, 15 Apr 2026 09:42:51 -0500 Subject: [PATCH 5/5] remove pinned ROCm torchvision/torchaudio wheels for vllm and sglang Signed-off-by: zhuyuhua-v --- docker/Dockerfile | 33 +++++++++------------------------ 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index c4b6f509e..39d8a9451 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,9 +13,6 @@ ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" ARG VLLM_COMMIT="b31e9326a7d9394aab8c767f8ebe225c65594b60" ARG INSTALL_LM_EVAL=1 ARG INSTALL_FASTSAFETENSORS=1 -ARG ROCM_WHEEL_INDEX="https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2.1/" -ARG ROCM_TORCHVISION_VERSION="0.24.0+rocm7.2.1.gitb919bd0c" -ARG ROCM_TORCHAUDIO_VERSION="2.9.0+rocm7.2.1.gite3c6ee2b" # Let PR OOT CI verify whether a pulled prebuilt image still matches this vLLM commit LABEL com.rocm.atom.vllm_commit="${VLLM_COMMIT}" @@ -103,11 +100,7 @@ RUN echo "========== [OOT 7/7] Install vLLM runtime dependencies ==========" && "${VENV_PYTHON}" -c "import glob, os, torch; print(f'torch.version.hip: {torch.version.hip}'); print(f'torch.version.cuda: {torch.version.cuda}'); torch_lib_dir=os.path.join(os.path.dirname(torch.__file__), 'lib'); print(f'torch lib dir: {torch_lib_dir}'); print(f'libtorch_hip candidates: {glob.glob(os.path.join(torch_lib_dir, \"libtorch_hip.so*\"))}'); assert torch.version.hip is not None, 'Torch is not ROCm build (torch.version.hip is None).'" && \ "${VENV_PYTHON}" -m pip show vllm torch triton torchvision torchaudio amdsmi amd-aiter atom mori || true -RUN echo "FIXME: rocm/pytorch:latest currently resolves incompatible torchvision/torchaudio wheels; pin ROCm wheels here until the base image is fixed upstream." && \ - "${VENV_PYTHON}" -m pip uninstall -y torchvision torchaudio || true && \ - "${VENV_PYTHON}" -m pip install --no-index --find-links "${ROCM_WHEEL_INDEX}" \ - "torchvision==${ROCM_TORCHVISION_VERSION}" \ - "torchaudio==${ROCM_TORCHAUDIO_VERSION}" && \ +RUN echo "========== [VLLM-ATOM] Validate vision/audio wheels ==========" && \ "${VENV_PYTHON}" -c "import torch, torchvision, torchaudio; from torchvision.transforms import InterpolationMode; from transformers.models.auto.image_processing_auto import get_image_processor_config; print(f'torch: {torch.__version__}'); print(f'torchvision: {torchvision.__version__}'); print(f'torchaudio: {torchaudio.__version__}'); print(f'InterpolationMode: {InterpolationMode.BILINEAR}'); print(f'get_image_processor_config: {get_image_processor_config.__name__}')" # Restore that exact base-image Triton after all OOT installs finish. The goal @@ -156,18 +149,15 @@ ARG VENV_PYTHON="/opt/venv/bin/python" ARG SGLANG_REPO="https://github.com/sgl-project/sglang.git" ARG SGLANG_REF="v0.5.10" ARG SGLANG_TRITON_VERSION="3.6.0" -ARG ROCM_WHEEL_INDEX="https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2.1/" -ARG ROCM_TORCHVISION_VERSION="0.24.0+rocm7.2.1.gitb919bd0c" -ARG ROCM_TORCHAUDIO_VERSION="2.9.0+rocm7.2.1.gite3c6ee2b" LABEL com.rocm.atom.sglang_ref="${SGLANG_REF}" ENV PATH="/opt/venv/bin:${PATH}" ENV PYTHONPATH="/app/sglang/python:/app/ATOM:${PYTHONPATH}" -RUN echo "========== [SGLANG 0/6] Check Aiter/FlyDSL versions before SGLang build ==========" && \ +RUN echo "========== [SGLANG-ATOM 0/6] Check Aiter/FlyDSL versions before SGLang build ==========" && \ "${VENV_PYTHON}" -m pip show atom mori amd-aiter flydsl || true -RUN echo "========== [SGLANG 1/6] Clone SGLang ==========" && \ +RUN echo "========== [SGLANG-ATOM 1/6] Clone SGLang ==========" && \ rm -rf /app/sglang && \ git clone "${SGLANG_REPO}" /app/sglang && \ cd /app/sglang && \ @@ -176,7 +166,7 @@ RUN echo "========== [SGLANG 1/6] Clone SGLang ==========" && \ echo "sglang ref:" && \ git rev-parse HEAD -RUN echo "========== [SGLANG 2/6] Build sglang kernel ==========" && \ +RUN echo "========== [SGLANG-ATOM 2/6] Build sglang kernel ==========" && \ "${VENV_PYTHON}" -m pip uninstall -y sgl-kernel sglang-kernel sglang || true && \ "${VENV_PYTHON}" -m pip install --upgrade pip setuptools wheel && \ DETECTED_AMDGPU_TARGET="$("${VENV_PYTHON}" -c "import torch; print(torch.cuda.get_device_properties(0).gcnArchName.split(':')[0] if torch.cuda.is_available() else '')" 2>/dev/null || true)" && \ @@ -192,7 +182,7 @@ RUN echo "========== [SGLANG 2/6] Build sglang kernel ==========" && \ AMDGPU_TARGET="${FINAL_AMDGPU_TARGET}" "${VENV_PYTHON}" setup_rocm.py install && \ "${VENV_PYTHON}" -m pip show sglang-kernel || true -RUN echo "========== [SGLANG 3/6] Install SGLang dependencies ==========" && \ +RUN echo "========== [SGLANG-ATOM 3/6] Install SGLang dependencies ==========" && \ cd /app/sglang/python && \ rm -f pyproject.toml && \ cp pyproject_other.toml pyproject.toml && \ @@ -226,11 +216,7 @@ RUN echo "========== [SGLANG 3/6] Install SGLang dependencies ==========" && \ rm -f /tmp/sglang-runtime-common.txt && \ "${VENV_PYTHON}" -m pip show sglang torch triton transformers IPython orjson pybase64 petit-kernel wave-lang xgrammar outlines apache-tvm-ffi || true -RUN echo "========== [SGLANG 4/6] Pin ROCm vision/audio wheels ==========" && \ - "${VENV_PYTHON}" -m pip uninstall -y torchvision torchaudio || true && \ - "${VENV_PYTHON}" -m pip install --no-deps --no-index --find-links "${ROCM_WHEEL_INDEX}" \ - "torchvision==${ROCM_TORCHVISION_VERSION}" \ - "torchaudio==${ROCM_TORCHAUDIO_VERSION}" && \ +RUN echo "========== [SGLANG-ATOM 4/6] Validate vision/audio wheels ==========" && \ "${VENV_PYTHON}" -m sglang.launch_server --help >/dev/null && \ "${VENV_PYTHON}" -c "import os, torch, torchvision, torchaudio, sglang, triton, transformers; from torchvision.io import decode_jpeg; assert torch.version.hip is not None, 'Torch is not ROCm build (torch.version.hip is None).'; print(f'torch: {torch.__version__}'); print(f'triton: {triton.__version__}'); print(f'transformers: {transformers.__version__}'); print(f'torchvision: {torchvision.__version__}'); print(f'torchaudio: {torchaudio.__version__}'); print(f'decode_jpeg: {decode_jpeg.__name__}'); print(f'sglang imported from: {sglang.__file__}'); print(f'PYTHONPATH={os.environ.get(\"PYTHONPATH\", \"\")}')" && \ echo "Validated sglang launch_server entrypoint" @@ -242,12 +228,11 @@ RUN echo "========== [SGLANG 4/6] Pin ROCm vision/audio wheels ==========" && \ # so pip does not downgrade it back to the torch-pinned 3.5.1 dependency. Keep # this override local to `atom_sglang` so the OOT/atom_release flow preserves # its prior Triton behavior. -RUN echo "========== [SGLANG 5/6] Install validated Triton ==========" && \ - "${VENV_PYTHON}" -m pip uninstall -y triton || true && \ +RUN echo "========== [SGLANG-ATOM 5/6] Install validated Triton ==========" && \ "${VENV_PYTHON}" -m pip install --no-cache-dir "triton==${SGLANG_TRITON_VERSION}" && \ - "${VENV_PYTHON}" -c "import triton; print(f'Installed Triton: {triton.__version__}')" + "${VENV_PYTHON}" -m pip show triton -RUN echo "========== [SGLANG 6/6] Check Aiter/FlyDSL versions after SGLang build ==========" && \ +RUN echo "========== [SGLANG-ATOM 6/6] Check Aiter/FlyDSL versions after SGLang build ==========" && \ "${VENV_PYTHON}" -m pip show atom mori amd-aiter flydsl || true CMD ["/bin/bash"]