
[fix] add graph capture patch (like vLLM) for sglang+atom path #497

Open

zhuyuhua-v wants to merge 2 commits into main from yuhua/sglang-graph-capture

Conversation

@zhuyuhua-v
Contributor

Motivation

  • add a shared graph capture patch that wraps the framework's GroupCoordinator.graph_capture with aiter's ca_comm.capture()
  • enable the patch for the SGLang + ATOM plugin path during model preparation
  • refactor the existing ATOM+vLLM-specific implementation to reuse the same shared helper

Technical Details

When ATOM runs as a plugin backend, the model uses aiter collectives such as tensor_model_parallel_fused_allreduce_rmsnorm, but the host framework only enters its own graph capture context. As a result, aiter does not enter capture mode, falls back to the unregistered path, and triggers extra hipMemcpyAsync calls.

This change makes the framework graph capture also enter aiter's ca_comm.capture() so the aiter collectives can run in the expected capture mode and avoid the extra memory copy overhead.
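The wrapping can be sketched as follows. The `GroupCoordinator` and `ca_comm` classes below are simplified stand-ins (the real helper in atom/plugin/graph_capture_patch.py resolves them from the host framework and aiter at runtime), so treat this as an illustration of the nesting, not the actual implementation:

```python
import contextlib
from functools import wraps

class _FakeCaComm:
    """Stand-in for aiter's ca_comm: tracks whether capture mode is active."""
    def __init__(self):
        self.capturing = False

    @contextlib.contextmanager
    def capture(self):
        self.capturing = True
        try:
            yield
        finally:
            self.capturing = False

ca_comm = _FakeCaComm()

class GroupCoordinator:
    """Minimal stand-in for the framework's GroupCoordinator."""
    @contextlib.contextmanager
    def graph_capture(self, graph_capture_context=None):
        yield graph_capture_context or "ctx"

def apply_graph_capture_patch(coordinator_cls, get_ca_comm):
    """Wrap graph_capture so it also enters aiter's ca_comm.capture()."""
    original = coordinator_cls.graph_capture
    if getattr(original, "_aiter_patched", False):
        return True  # already patched; don't double-wrap

    @wraps(original)
    @contextlib.contextmanager
    def wrapped(self, graph_capture_context=None, **kwargs):
        comm = get_ca_comm()
        ca_ctx = comm.capture() if comm is not None else contextlib.nullcontext()
        with ca_ctx:  # aiter enters capture mode alongside the framework
            with original(self, graph_capture_context=graph_capture_context, **kwargs) as ctx:
                yield ctx

    wrapped._aiter_patched = True
    coordinator_cls.graph_capture = wrapped
    return True

apply_graph_capture_patch(GroupCoordinator, lambda: ca_comm)

with GroupCoordinator().graph_capture() as ctx:
    inside = ca_comm.capturing   # aiter is in capture mode inside the context
outside = ca_comm.capturing      # and leaves it on exit
```

Because the framework-side context is entered inside aiter's capture context, any aiter collective issued during graph capture sees capture mode active and takes the registered path.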

Impact

  • fixes the graph capture mismatch for the SGLang + ATOM path
  • keeps the vLLM path aligned by sharing the same implementation
  • reduces unnecessary hipMemcpyAsync overhead in SGLang plugin mode

Test Plan

Test Result

Submission Checklist

@zhuyuhua-v zhuyuhua-v marked this pull request as ready for review April 13, 2026 05:45
Copilot AI review requested due to automatic review settings April 13, 2026 05:45

Copilot AI left a comment


Pull request overview

Adds a shared “graph capture patch” helper so plugin backends (vLLM and SGLang) enter both the host framework’s GroupCoordinator.graph_capture context and aiter’s ca_comm.capture() context, preventing aiter collectives from falling back to non-capture paths (and avoiding extra hipMemcpyAsync overhead) in plugin mode.

Changes:

  • Introduce a shared patching utility (atom.plugin.graph_capture_patch) to wrap a framework’s GroupCoordinator.graph_capture with aiter capture.
  • Refactor the vLLM-specific patch module to delegate to the shared helper.
  • Add a SGLang patch module and invoke it during SGLang plugin prepare_model.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

Files changed:

  • atom/plugin/graph_capture_patch.py: new shared helper to patch a framework's GroupCoordinator.graph_capture to nest aiter capture.
  • atom/plugin/vllm/graph_capture_patch.py: refactored to delegate to the shared helper for vLLM.
  • atom/plugin/sglang/graph_capture_patch.py: new delegating patch module for SGLang.
  • atom/plugin/prepare.py: applies the SGLang patch during model preparation in SGLang plugin mode.


def wrapped(self, graph_capture_context=None, **kwargs):
    aiter_ca_context = _get_aiter_ca_capture_context()
    with aiter_ca_context:
        with original_graph_capture(self, graph_capture_context, **kwargs) as ctx:

Copilot AI Apr 13, 2026


In the wrapper, graph_capture_context is forwarded positionally (original_graph_capture(self, graph_capture_context, **kwargs)). To keep this patch resilient to frameworks that define graph_capture_context as keyword-only, pass it by name (e.g., graph_capture_context=graph_capture_context) when calling the original method.

Suggested change
    with original_graph_capture(self, graph_capture_context, **kwargs) as ctx:
    with original_graph_capture(
        self, graph_capture_context=graph_capture_context, **kwargs
    ) as ctx:

"fused_allreduce_rmsnorm may incur extra hipMemcpyAsync in vLLM plugin mode.",
e,
)
_GRAPH_CAPTURE_PATCH_APPLIED = _apply("vllm.distributed.parallel_state")

Copilot AI Apr 13, 2026


_GRAPH_CAPTURE_PATCH_APPLIED is set to the boolean returned by the shared helper. Since the shared helper returns False both when the patch cannot be applied and when it is already patched, this wrapper will keep retrying on every call (re-importing and re-checking). Consider marking _GRAPH_CAPTURE_PATCH_APPLIED = True after the first attempt, or adjusting the shared helper to return True when it detects an existing patch.

Suggested change
_GRAPH_CAPTURE_PATCH_APPLIED = _apply("vllm.distributed.parallel_state")
_apply("vllm.distributed.parallel_state")
_GRAPH_CAPTURE_PATCH_APPLIED = True


from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply

_GRAPH_CAPTURE_PATCH_APPLIED = _apply("sglang.srt.distributed.parallel_state")

Copilot AI Apr 13, 2026


_GRAPH_CAPTURE_PATCH_APPLIED is set to the boolean returned by the shared helper. Since the shared helper returns False both when the patch cannot be applied and when it is already patched, this wrapper will keep retrying on every call (re-importing and re-checking). Consider marking _GRAPH_CAPTURE_PATCH_APPLIED = True after the first attempt, or adjusting the shared helper to return True when it detects an existing patch.

Suggested change
_GRAPH_CAPTURE_PATCH_APPLIED = _apply("sglang.srt.distributed.parallel_state")
_apply("sglang.srt.distributed.parallel_state")
_GRAPH_CAPTURE_PATCH_APPLIED = True

Comment on lines +85 to +91
# Patch SGLang graph_capture to also enter aiter's ca_comm.capture(),
# avoiding hipMemcpyAsync in aiter collectives when model uses aiter's
# custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py)
if is_sglang():
    from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch

    apply_graph_capture_patch()

Copilot AI Apr 13, 2026


prepare_model now unconditionally attempts to apply the SGLang graph-capture patch in the sglang path. There are existing unit tests for prepare_model in tests/plugin/test_sglang_prepare_model.py, but none assert that this patch hook is invoked (or that failures are handled). Add a test that injects a fake atom.plugin.sglang.graph_capture_patch module and asserts apply_graph_capture_patch() is called once on the happy path.

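The test Copilot asks for can be sketched by injecting a fake module into sys.modules and asserting the hook fires. `prepare_model` below is a simplified stand-in for the real atom/plugin/prepare.py logic, kept only to the branch under test:

```python
import sys
import types
from unittest import mock

def prepare_model(is_sglang):
    # Simplified stand-in for the sglang branch of atom/plugin/prepare.py.
    if is_sglang:
        from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch
        apply_graph_capture_patch()

# Build a fake module hierarchy so the import inside prepare_model resolves
# to our mock instead of the real package.
fake_patch = types.ModuleType("atom.plugin.sglang.graph_capture_patch")
fake_patch.apply_graph_capture_patch = mock.Mock()
for name in ("atom", "atom.plugin", "atom.plugin.sglang"):
    sys.modules.setdefault(name, types.ModuleType(name))
sys.modules["atom.plugin.sglang.graph_capture_patch"] = fake_patch

prepare_model(is_sglang=True)
fake_patch.apply_graph_capture_patch.assert_called_once()

prepare_model(is_sglang=False)  # non-sglang path must not touch the hook
fake_patch.apply_graph_capture_patch.assert_called_once()
```

In a real pytest suite the sys.modules injection would go through the monkeypatch fixture so the fake module is removed after the test.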