Skip to content

Commit 6c38b65

Browse files
Lijiachen1018 and lijiachen19
authored
[fix] fix ascend attention (#394)
fix ascend attention Co-authored-by: lijiachen19 <lijiachen19@huawei.com>
1 parent b6e5f62 commit 6c38b65

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

ucm/integration/vllm/patch/patch_funcs/v092/vllm_ascend_adapt.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,25 @@ def maybe_execute_sparse_attention_finished(
117117
maybe_execute_sparse_attention_finished
118118
)
119119

120-
def unified_ascend_attention_with_output(
120+
vllm_ops = torch.ops.vllm
121+
orig_unified_ascend_attention_with_output = (
122+
vllm_ops.unified_ascend_attention_with_output
123+
)
124+
125+
def _wrap_op_overload(orig, impl):
126+
class _Wrapper:
127+
def __init__(self, orig):
128+
self._orig = orig
129+
130+
def __call__(self, *args, **kwargs):
131+
return impl(*args, **kwargs)
132+
133+
def __getattr__(self, name):
134+
return getattr(self._orig, name)
135+
136+
return _Wrapper(orig)
137+
138+
def unified_ascend_attention_with_output_impl(
121139
query: torch.Tensor,
122140
key: torch.Tensor,
123141
value: torch.Tensor,
@@ -151,8 +169,13 @@ def unified_ascend_attention_with_output(
151169
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
152170
return
153171

172+
vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload(
173+
orig_unified_ascend_attention_with_output,
174+
unified_ascend_attention_with_output_impl,
175+
)
176+
154177
attention_v1.unified_ascend_attention_with_output = (
155-
unified_ascend_attention_with_output
178+
unified_ascend_attention_with_output_impl
156179
)
157180
except ImportError as e:
158181
logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True)

0 commit comments

Comments
 (0)