[fix] add graph capture patch (like vLLM) for sglang+atom path #497
zhuyuhua-v wants to merge 2 commits into main from
Conversation
Signed-off-by: zhuyuhua-v <yuhzhu@amd.com>
Pull request overview
Adds a shared “graph capture patch” helper so plugin backends (vLLM and SGLang) enter both the host framework’s GroupCoordinator.graph_capture context and aiter’s ca_comm.capture() context, preventing aiter collectives from falling back to non-capture paths (and avoiding extra hipMemcpyAsync overhead) in plugin mode.
Changes:
- Introduce a shared patching utility (`atom.plugin.graph_capture_patch`) to wrap a framework's `GroupCoordinator.graph_capture` with aiter capture.
- Refactor the vLLM-specific patch module to delegate to the shared helper.
- Add an SGLang patch module and invoke it during SGLang plugin `prepare_model`.
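The dual-context entry this PR implements can be sketched with stand-in context managers; everything below is illustrative, not the real aiter or framework API:

```python
import contextlib

@contextlib.contextmanager
def framework_graph_capture():
    # Stand-in for the host framework's GroupCoordinator.graph_capture.
    yield "framework-ctx"

@contextlib.contextmanager
def aiter_ca_capture():
    # Stand-in for aiter's ca_comm.capture(); entering it is what keeps
    # aiter collectives on the registered capture path.
    yield

@contextlib.contextmanager
def patched_graph_capture():
    # The patch nests the framework's capture context inside aiter's,
    # so both are active for the duration of graph capture.
    with aiter_ca_capture():
        with framework_graph_capture() as ctx:
            yield ctx

with patched_graph_capture() as ctx:
    assert ctx == "framework-ctx"
```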
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `atom/plugin/graph_capture_patch.py` | New shared helper to patch a framework's `GroupCoordinator.graph_capture` to nest aiter capture. |
| `atom/plugin/vllm/graph_capture_patch.py` | Refactored to delegate to the shared helper for vLLM. |
| `atom/plugin/sglang/graph_capture_patch.py` | New delegating patch module for SGLang. |
| `atom/plugin/prepare.py` | Applies the SGLang patch during model preparation in SGLang plugin mode. |
```python
def wrapped(self, graph_capture_context=None, **kwargs):
    aiter_ca_context = _get_aiter_ca_capture_context()
    with aiter_ca_context:
        with original_graph_capture(self, graph_capture_context, **kwargs) as ctx:
```
In the wrapper, graph_capture_context is forwarded positionally (original_graph_capture(self, graph_capture_context, **kwargs)). To keep this patch resilient to frameworks that define graph_capture_context as keyword-only, pass it by name (e.g., graph_capture_context=graph_capture_context) when calling the original method.
Suggested change:
```suggestion
        with original_graph_capture(
            self, graph_capture_context=graph_capture_context, **kwargs
        ) as ctx:
```
```python
        "fused_allreduce_rmsnorm may incur extra hipMemcpyAsync in vLLM plugin mode.",
        e,
    )

_GRAPH_CAPTURE_PATCH_APPLIED = _apply("vllm.distributed.parallel_state")
```
_GRAPH_CAPTURE_PATCH_APPLIED is set to the boolean returned by the shared helper. Since the shared helper returns False both when the patch cannot be applied and when it is already patched, this wrapper will keep retrying on every call (re-importing and re-checking). Consider marking _GRAPH_CAPTURE_PATCH_APPLIED = True after the first attempt, or adjusting the shared helper to return True when it detects an existing patch.
Suggested change:
```suggestion
_apply("vllm.distributed.parallel_state")
_GRAPH_CAPTURE_PATCH_APPLIED = True
```
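The reviewer's alternative fix, making the shared helper itself idempotent, could look like the sketch below: mark the wrapper with a sentinel attribute and report an already-patched method as success. The helper body, sentinel name, and the bare pass-through wrapper are all illustrative, not the PR's actual code:

```python
import functools
import importlib

_PATCH_SENTINEL = "_atom_graph_capture_patched"

def apply_graph_capture_patch(module_name: str) -> bool:
    """Patch <module>.GroupCoordinator.graph_capture once.

    Returns True on success *or* if the patch is already in place,
    so callers stop retrying; False only when the framework module
    or class cannot be resolved yet.
    """
    try:
        module = importlib.import_module(module_name)
        coordinator = module.GroupCoordinator
    except (ImportError, AttributeError):
        return False  # framework not importable yet; caller may retry

    original = coordinator.graph_capture
    if getattr(original, _PATCH_SENTINEL, False):
        return True  # already patched: report success, do not re-wrap

    @functools.wraps(original)
    def wrapped(self, graph_capture_context=None, **kwargs):
        # The real patch would also enter aiter's ca_comm.capture() here;
        # this sketch only forwards to the original method.
        return original(self, graph_capture_context=graph_capture_context, **kwargs)

    setattr(wrapped, _PATCH_SENTINEL, True)
    coordinator.graph_capture = wrapped
    return True
```

Calling the helper twice against the same module then patches only once and returns True both times.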
```python
from atom.plugin.graph_capture_patch import apply_graph_capture_patch as _apply

_GRAPH_CAPTURE_PATCH_APPLIED = _apply("sglang.srt.distributed.parallel_state")
```
_GRAPH_CAPTURE_PATCH_APPLIED is set to the boolean returned by the shared helper. Since the shared helper returns False both when the patch cannot be applied and when it is already patched, this wrapper will keep retrying on every call (re-importing and re-checking). Consider marking _GRAPH_CAPTURE_PATCH_APPLIED = True after the first attempt, or adjusting the shared helper to return True when it detects an existing patch.
Suggested change:
```suggestion
_apply("sglang.srt.distributed.parallel_state")
_GRAPH_CAPTURE_PATCH_APPLIED = True
```
```python
    # Patch SGLang graph_capture to also enter aiter's ca_comm.capture(),
    # avoiding hipMemcpyAsync in aiter collectives when model uses aiter's
    # custom all_reduce (same fix as atom/plugin/vllm/graph_capture_patch.py)
    if is_sglang():
        from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch

        apply_graph_capture_patch()
```
prepare_model now unconditionally attempts to apply the SGLang graph-capture patch in the sglang path. There are existing unit tests for prepare_model in tests/plugin/test_sglang_prepare_model.py, but none assert that this patch hook is invoked (or that failures are handled). Add a test that injects a fake atom.plugin.sglang.graph_capture_patch module and asserts apply_graph_capture_patch() is called once on the happy path.
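A test along the lines suggested above could inject a stub patch module before calling the sglang branch and assert the hook fires exactly once. The sketch below is self-contained: `prepare_model` here is a simplified stand-in for the function in `atom/plugin/prepare.py`, and the stubbing details are illustrative:

```python
import sys
import types
from unittest import mock

# Illustrative stand-in for the sglang branch of prepare_model;
# the real function lives in atom/plugin/prepare.py in this PR.
def prepare_model(is_sglang=lambda: True):
    if is_sglang():
        from atom.plugin.sglang.graph_capture_patch import apply_graph_capture_patch
        apply_graph_capture_patch()

def test_sglang_patch_applied_once():
    # Inject fake parent packages plus the patch module so the
    # deferred import inside prepare_model resolves to our mock.
    fakes = {
        name: types.ModuleType(name)
        for name in ("atom", "atom.plugin", "atom.plugin.sglang",
                     "atom.plugin.sglang.graph_capture_patch")
    }
    hook = mock.Mock()
    fakes["atom.plugin.sglang.graph_capture_patch"].apply_graph_capture_patch = hook
    with mock.patch.dict(sys.modules, fakes):
        prepare_model()
    hook.assert_called_once()
```

`mock.patch.dict` restores `sys.modules` on exit, so the stub modules do not leak into other tests.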
Motivation
Wrap `GroupCoordinator.graph_capture` with aiter's `ca_comm.capture()`.

Technical Details
When ATOM runs as a plugin backend, the model uses aiter collectives such as `tensor_model_parallel_fused_allreduce_rmsnorm`, but the host framework only enters its own graph capture context. As a result, aiter does not enter capture mode, falls back to the unregistered path, and triggers extra `hipMemcpyAsync` calls. This change makes the framework graph capture also enter aiter's `ca_comm.capture()` so the aiter collectives can run in the expected capture mode and avoid the extra memory copy overhead.

Impact
Avoids extra `hipMemcpyAsync` overhead in SGLang plugin mode.

Test Plan
Test Result
Submission Checklist