@@ -117,7 +117,25 @@ def maybe_execute_sparse_attention_finished(
         maybe_execute_sparse_attention_finished
     )

-    def unified_ascend_attention_with_output(
+    vllm_ops = torch.ops.vllm
+    orig_unified_ascend_attention_with_output = (
+        vllm_ops.unified_ascend_attention_with_output
+    )
+
+    def _wrap_op_overload(orig, impl):
+        class _Wrapper:
+            def __init__(self, orig):
+                self._orig = orig
+
+            def __call__(self, *args, **kwargs):
+                return impl(*args, **kwargs)
+
+            def __getattr__(self, name):
+                return getattr(self._orig, name)
+
+        return _Wrapper(orig)
+
+    def unified_ascend_attention_with_output_impl(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
@@ -151,8 +169,13 @@ def unified_ascend_attention_with_output(
         maybe_save_kv_layer_to_connector(layer_name, kv_cache)
         return

+    vllm_ops.unified_ascend_attention_with_output = _wrap_op_overload(
+        orig_unified_ascend_attention_with_output,
+        unified_ascend_attention_with_output_impl,
+    )
+
     attention_v1.unified_ascend_attention_with_output = (
-        unified_ascend_attention_with_output
+        unified_ascend_attention_with_output_impl
     )
 except ImportError as e:
     logger.error(f"Failed to patch attention_v1.py: {e}", exc_info=True)
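
The `_Wrapper` in this diff is a small proxy: `torch.ops.vllm.unified_ascend_attention_with_output` is a torch op object (an `OpOverloadPacket` in current PyTorch) rather than a plain function, so assigning a bare Python function in its place would drop attributes such as `.default` that other call sites may still look up. The proxy overrides `__call__` to route invocations to the new implementation, while `__getattr__` falls through to the original op. Below is a minimal, self-contained sketch of the same pattern; `_FakeOp` is an illustrative stand-in for the real torch op, not part of the patch:

```python
# Minimal sketch of the wrapper pattern from the diff. `_FakeOp` is a toy
# stand-in for an op overload packet (callable, with extra attributes);
# it is illustrative only and not part of the actual patch.

def _wrap_op_overload(orig, impl):
    class _Wrapper:
        def __init__(self, orig):
            self._orig = orig

        def __call__(self, *args, **kwargs):
            # Calls are redirected to the replacement implementation.
            return impl(*args, **kwargs)

        def __getattr__(self, name):
            # Any other attribute (e.g. `.default`) resolves on the original.
            return getattr(self._orig, name)

    return _Wrapper(orig)


class _FakeOp:
    default = "original .default overload"

    def __call__(self, x):
        return f"original({x})"


wrapped = _wrap_op_overload(_FakeOp(), lambda x: f"patched({x})")
assert wrapped("q") == "patched(q)"                      # call is redirected
assert wrapped.default == "original .default overload"   # attrs delegated
```

Note that `__call__` must be defined on `_Wrapper` itself: Python resolves dunder methods on the type and bypasses `__getattr__`, so attribute delegation alone would not intercept calls.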