diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 8ea67994ea2..6536eebb2c9 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -39,6 +39,8 @@ def __init__(self, vllm_config):
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.chunked_prefill_for_mla = additional_config.get(
             "chunked_prefill_for_mla", False)
+        self.fc_dual_batch = additional_config.get(
+            "fc_dual_batch", False)
         self.enable_weight_nz_layout = additional_config.get(
             "enable_weight_nz_layout", False)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 37edb9767a1..d68026c552c 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -24,6 +24,9 @@
 import torch_npu
 from torch import nn
 from transformers import PretrainedConfig
+
+from vllm.logger import logger
+
 from vllm.attention import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
@@ -140,153 +143,361 @@ def fused_experts_with_mc2(
     is_torchair: bool = False,
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    quant_mode = 0
-    ep_group = get_mc2_group()
-    ep_rank_id = ep_group.rank_in_group
-    ep_world_size = ep_group.world_size
-    tp_world_size = get_tp_group().world_size
-
-    # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`,
-    # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before.
-    global_bs = (
-        math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) *
-        ep_world_size)
-
-    # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
-    need_extra_args = get_ascend_soc_version(
-    ) == AscendSocVersion.A3 or is_torchair
-
-    # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
-    a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
-
-    enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
-
-    moe_expert_num = len(expert_map)
-    kwargs_mc2 = {
-        "x": hidden_states,
-        "expert_ids": topk_ids,
-        "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
-        "moe_expert_num": moe_expert_num,
-        "global_bs": global_bs,
-    }
-
-    stage1_kwargs = {
-        "scales": None,
-        "quant_mode": quant_mode,
-        "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": ep_world_size,
-        "ep_rank_id": ep_rank_id,
-    }
-    if need_extra_args:
-        stage1_kwargs.update({
-            "group_tp": moe_all_to_all_group_name,
-            "tp_world_size": 1,
-            "tp_rank_id": 0,
-        })
-    if a3_need_extra_args and enable_dispatch_v2:
-        stage1_kwargs.update({
-            "x_active_mask": mc2_mask,
-        })
-
-    kwargs_mc2.update(stage1_kwargs)
-
-    output = torch_npu.npu_moe_distribute_dispatch_v2(
-        **kwargs_mc2
-    ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
-        **kwargs_mc2)
-    # comm_stream.wait_stream(torch.npu.current_stream())
-    expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
-        0:5]
-
-    if shared_experts is not None:
+    """Fused MoE implementation with MC2 dispatch/combine
+
+    Args:
+        hidden_states: Input tensor of shape [num_tokens, hidden_size]
+        w1: Expert gate-up weights of shape [num_experts, hidden_size, intermediate_size]
+        w2: Expert down weights of shape [num_experts, intermediate_size, hidden_size]
+        topk_weights: Selected expert weights of shape [num_tokens, top_k]
+        topk_ids: Selected expert indices of shape [num_tokens, top_k]
+        expert_map: Expert mapping tensor
+        moe_all_to_all_group_name: Communication group name
+        shared_experts: Optional shared experts module
+        is_torchair: Whether running in TorchAIR mode
+        mc2_mask: Optional mask tensor for A3 optimization
+
+    Returns:
+        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Output tensor(s)
+    """
+
+    def _build_context() -> dict:
+        ep_group = get_mc2_group()
+        tp_group = get_tp_group()
+        forward_ctx = get_forward_context()
+        soc_version = get_ascend_soc_version()
+
+        return {
+            # basic config
+            "moe_all_to_all_group_name": moe_all_to_all_group_name,
+            "is_torchair": is_torchair,
+            "mc2_mask": mc2_mask,
+            "ep_rank_id": ep_group.rank_in_group,
+            "ep_world_size": ep_group.world_size,
+            "tp_world_size": tp_group.world_size,
+            # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`,
+            # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before.
+            "global_bs": math.ceil(
+                forward_ctx.max_tokens_across_dp / tp_group.world_size
+            )
+            * ep_group.world_size,
+            # feature switch
+            "enable_dispatch_v2": hasattr(torch_npu, "npu_moe_distribute_dispatch_v2"),
+            # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
+            "a3_enabled": soc_version == AscendSocVersion.A3,
+            # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
+            "need_extra_params": (soc_version == AscendSocVersion.A3) or is_torchair,
+        }
+
+    def _process_shared_experts(
+        experts: Any,
+        h_states: torch.Tensor,
+        weights: torch.Tensor,
+        expand_x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute shared expert activations in secondary stream"""
         with npu_stream_switch("moe_secondary", 0):
-            npu_wait_tensor(hidden_states, topk_weights)
-            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(h_states, weights)
+            shared_gate_up, _ = experts.gate_up_proj(h_states)
             npu_wait_tensor(shared_gate_up, expand_x)
-            shared_act = shared_experts.act_fn(shared_gate_up)
+            return experts.act_fn(shared_gate_up)
 
-    w1 = w1.transpose(1, 2)
+    def _expert_forward(
+        expand_x: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        expert_token_nums: torch.Tensor,
+    ) -> torch.Tensor:
+        """Execute expert forward computation (gate_up -> SwiGLU -> down)"""
+        # Gate-up projection
+        group_list = expert_token_nums.to(torch.int64)
+        gate_up_out_list = torch_npu.npu_grouped_matmul(
+            x=[expand_x],
+            weight=[w1],
+            split_item=2,
+            # 1 means count mode, to avoid cumulative operation of the group list
+            group_list_type=1,
+            group_type=0,
+            group_list=group_list,
+        )
+        # TODO: Remove this in the future.
+        gate_up_out = torch.cat(gate_up_out_list, dim=0)
+        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+
+        # Down projection
+        down_list = torch_npu.npu_grouped_matmul(
+            x=[gate_up_out],
+            weight=[w2],
+            split_item=2,
+            group_list_type=1,
+            group_type=0,
+            group_list=group_list,
+        )
+        return torch.cat(down_list, dim=0)
 
-    group_list = expert_token_nums.to(torch.int64)
-    gate_up_out_list = torch_npu.npu_grouped_matmul(
-        x=[expand_x],
-        weight=[w1],
-        split_item=2,
-        # 1 means count mode, to avoid cumulative operation of the group list
-        group_list_type=1,
-        group_type=0,
-        group_list=group_list,
-    )
+    def _build_dispatch_kwargs(
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        moe_expert_num: int,
+        ctx: dict,
+    ) -> dict:
+        """Construct kwargs for MoE dispatch operation"""
+        quant_mode = 0
+        base_kwargs = {
+            "x": hidden_states,
+            "expert_ids": topk_ids,
+            "expert_shard_type": 0,
+            "shared_expert_rank_num": 0,
+            "moe_expert_num": moe_expert_num,
+            "global_bs": ctx["global_bs"],
+            "scales": None,
+            "quant_mode": quant_mode,
+            "group_ep": ctx["moe_all_to_all_group_name"],
+            "ep_world_size": ctx["ep_world_size"],
+            "ep_rank_id": ctx["ep_rank_id"],
+        }
 
-    # TODO: Remove this in the future.
-    gate_up_out = torch.cat(gate_up_out_list, dim=0)
-    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
+        if ctx["need_extra_params"]:
+            base_kwargs.update(
+                {
+                    "group_tp": ctx["moe_all_to_all_group_name"],
+                    "tp_world_size": 1,
+                    "tp_rank_id": 0,
+                }
+            )
 
-    w2 = w2.transpose(1, 2)
-    down_out_list = torch_npu.npu_grouped_matmul(
-        x=[gate_up_out],
-        weight=[w2],
-        split_item=2,
-        group_list_type=1,
-        group_type=0,
-        group_list=group_list,
-    )
+        if ctx["a3_enabled"] and ctx["enable_dispatch_v2"]:
+            base_kwargs["x_active_mask"] = ctx["mc2_mask"]
+
+        return base_kwargs
+
+    def _build_combine_kwargs(
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        moe_expert_num: int,
+        down_out: torch.Tensor,
+        ep_recv_counts: torch.Tensor,
+        tp_recv_counts: torch.Tensor,
+        assist_info: Any,
+        ctx: dict,
+    ) -> dict:
+        """Construct kwargs for MoE combine operation"""
+        base_kwargs = {
+            "expand_x": down_out,
+            "expert_ids": topk_ids,
+            "expert_scales": topk_weights.to(torch.float32),
+            "expert_shard_type": 0,
+            "shared_expert_rank_num": 0,
+            "moe_expert_num": moe_expert_num,
+            "global_bs": ctx["global_bs"],
+            "ep_send_counts": ep_recv_counts,
+            "group_ep": ctx["moe_all_to_all_group_name"],
+            "ep_world_size": ctx["ep_world_size"],
+            "ep_rank_id": ctx["ep_rank_id"],
+        }
 
-    down_out_list = torch.cat(down_out_list, dim=0)
+        if ctx["enable_dispatch_v2"]:
+            base_kwargs["assist_info_for_combine"] = assist_info
+        else:
+            base_kwargs["expand_idx"] = assist_info
+
+        if ctx["need_extra_params"]:
+            base_kwargs.update(
+                {
+                    "tp_send_counts": tp_recv_counts,
+                    "group_tp": ctx["moe_all_to_all_group_name"],
+                    "tp_world_size": 1,
+                    "tp_rank_id": 0,
+                }
+            )
+        if ctx["a3_enabled"] and ctx["enable_dispatch_v2"]:
+            base_kwargs["x_active_mask"] = ctx["mc2_mask"]
+
+        return base_kwargs
+
+    def _prepare_return(
+        hidden_states: torch.Tensor,
+        shared_experts: Optional[Any],
+        shared_act: Optional[torch.Tensor],
+        down_out_list: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """Construct final output with optional shared experts"""
+        if shared_experts is None:
+            return hidden_states
 
-    # moeCombine
-    kwargs_mc2 = {
-        "expand_x": down_out_list,
-        "expert_ids": topk_ids,
-        "expert_scales": topk_weights.to(torch.float32),
-        "expert_shard_type": 0,
-        "shared_expert_rank_num": 0,
-        "moe_expert_num": moe_expert_num,
-        "global_bs": global_bs,
-    }
-    tp_recv_counts = output[5]
-    stage3_kwargs = {
-        "ep_send_counts": ep_recv_counts,
-        "group_ep": moe_all_to_all_group_name,
-        "ep_world_size": ep_world_size,
-        "ep_rank_id": ep_rank_id,
-    }
-    if enable_dispatch_v2:
-        stage3_kwargs.update({
-            "assist_info_for_combine":
-            assist_info_for_combine,
-        })
-    else:
-        stage3_kwargs.update({
-            "expand_idx": assist_info_for_combine,
-        })
-    if need_extra_args:
-        stage3_kwargs.update({
-            "tp_send_counts": tp_recv_counts,
-            "group_tp": moe_all_to_all_group_name,
-            "tp_world_size": 1,
-            "tp_rank_id": 0,
-        })
-    if a3_need_extra_args and enable_dispatch_v2:
-        stage3_kwargs.update({
-            "x_active_mask": mc2_mask,
-        })
-    kwargs_mc2.update(stage3_kwargs)
-
-    hidden_states = torch_npu.npu_moe_distribute_combine_v2(
-        **kwargs_mc2
-    ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
-        **kwargs_mc2)
-
-    if shared_experts is None:
-        return hidden_states
-    else:
         with npu_stream_switch("moe_secondary", 0):
             npu_wait_tensor(shared_act, down_out_list)
             shared_hidden_states, _ = shared_experts.down_proj(shared_act)
         return hidden_states, shared_hidden_states
 
+    def _get_dispatch_func(ctx: dict):
+        if ctx["enable_dispatch_v2"]:
+            return torch_npu.npu_moe_distribute_dispatch_v2
+        else:
+            return torch_npu.npu_moe_distribute_dispatch
+
+    def _get_combine_func(ctx: dict):
+        if ctx["enable_dispatch_v2"]:
+            return torch_npu.npu_moe_distribute_combine_v2
+        else:
+            return torch_npu.npu_moe_distribute_combine
+
+    def _single_stream_execution(
+        ctx: dict,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        moe_expert_num: int,
+        shared_experts: Optional[Any] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        # MoE dispatch phase
+        dispatch_kwargs = _build_dispatch_kwargs(
+            hidden_states, topk_ids, moe_expert_num, ctx
+        )
+        dispatch_output = _get_dispatch_func(ctx)(**dispatch_kwargs)
+        expand_x, _, assist_info, expert_token_nums, ep_recv_counts = dispatch_output[
+            :5
+        ]
+        # Shared experts computation (if any)
+        shared_act = None
+        if shared_experts is not None:
+            shared_act = _process_shared_experts(
+                experts=shared_experts,
+                h_states=hidden_states,
+                weights=topk_weights,
+                expand_x=expand_x,
+            )
+        # Expert forward computation
+        w1_t = w1.transpose(1, 2)
+        w2_t = w2.transpose(1, 2)
+        down_out = _expert_forward(
+            expand_x=expand_x, w1=w1_t, w2=w2_t, expert_token_nums=expert_token_nums
+        )
+        # MoE combine phase
+        combine_kwargs = _build_combine_kwargs(
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+            moe_expert_num=moe_expert_num,
+            down_out=down_out,
+            ep_recv_counts=ep_recv_counts,
+            tp_recv_counts=dispatch_output[5],
+            assist_info=assist_info,
+            ctx=ctx,
+        )
+        hidden_states = _get_combine_func(ctx)(**combine_kwargs)
+        # Final output preparation
+        return _prepare_return(
+            hidden_states=hidden_states,
+            shared_experts=shared_experts,
+            shared_act=shared_act,
+            down_out_list=down_out,
+        )
+
+    def _dual_stream_execution(
+        ctx: dict,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        moe_expert_num: int,
+        top_k: int,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        # split experts ids and weights
+        split_idx = top_k // 2
+        expert_groupA_size = split_idx
+        expert_groupB_size = top_k - split_idx
+
+        topk_ids_A, topk_ids_B = torch.split(
+            topk_ids, [expert_groupA_size, expert_groupB_size], dim=1
+        )
+        topk_weights_A, topk_weights_B = torch.split(
+            topk_weights, [expert_groupA_size, expert_groupB_size], dim=1
+        )
+
+        w1_t = w1.transpose(1, 2)
+        w2_t = w2.transpose(1, 2)
+
+        # default stream: Dispatch A
+        kwargs = _build_dispatch_kwargs(hidden_states, topk_ids_A, moe_expert_num, ctx)
+        dispatchA_result = _get_dispatch_func(ctx)(**kwargs)
+        expand_xA, _, assist_infoA, expert_tokensA, ep_recvA = dispatchA_result[:5]
+
+        # parallel execution: default stream: Compute A + secondary stream: Dispatch B
+        # default stream: Compute A
+        down_outA = _expert_forward(expand_xA, w1_t, w2_t, expert_tokensA)
+        # secondary stream: Dispatch B
+        with npu_stream_switch("moe_secondary1", 0):
+            # Start dispatch B only after dispatch A has produced expand_xA.
+            npu_wait_tensor(hidden_states, expand_xA)
+            kwargs = _build_dispatch_kwargs(
+                hidden_states, topk_ids_B, moe_expert_num, ctx
+            )
+            dispatchB_result = _get_dispatch_func(ctx)(**kwargs)
+        expand_xB, _, assist_infoB, expert_tokensB, ep_recvB = dispatchB_result[:5]
+
+        # parallel execution: default stream: Combine A + secondary stream: Compute B
+        # default stream: Combine A
+        kwargs = _build_combine_kwargs(
+            topk_ids=topk_ids_A,
+            topk_weights=topk_weights_A,
+            moe_expert_num=moe_expert_num,
+            down_out=down_outA,
+            ep_recv_counts=ep_recvA,
+            tp_recv_counts=dispatchA_result[5],
+            assist_info=assist_infoA,
+            ctx=ctx,
+        )
+        resultA = _get_combine_func(ctx)(**kwargs)
+
+        # secondary stream: Compute B
+        with npu_stream_switch("moe_secondary2", 0):
+            # Start compute B only after compute A has finished on the default stream.
+            npu_wait_tensor(expand_xB, down_outA)
+            down_outB = _expert_forward(expand_xB, w1_t, w2_t, expert_tokensB)
+
+        kwargs = _build_combine_kwargs(
+            topk_ids=topk_ids_B,
+            topk_weights=topk_weights_B,
+            moe_expert_num=moe_expert_num,
+            down_out=down_outB,
+            ep_recv_counts=ep_recvB,
+            tp_recv_counts=dispatchB_result[5],
+            assist_info=assist_infoB,
+            ctx=ctx,
+        )
+        resultB = _get_combine_func(ctx)(**kwargs)
+
+        return torch.add(resultA, resultB)
+
+    ctx = _build_context()
+
+    # Dual-stream mode is enabled only when the config flag is set and there are no shared experts.
+    use_dual_stream = get_ascend_config().fc_dual_batch and shared_experts is None
+    if use_dual_stream:
+        return _dual_stream_execution(
+            ctx=ctx,
+            hidden_states=hidden_states,
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+            moe_expert_num=len(expert_map),
+            top_k=top_k,
+            w1=w1,
+            w2=w2,
+        )
+
+    return _single_stream_execution(
+        ctx=ctx,
+        hidden_states=hidden_states,
+        topk_ids=topk_ids,
+        topk_weights=topk_weights,
+        w1=w1,
+        w2=w2,
+        moe_expert_num=len(expert_map),
+        shared_experts=shared_experts,
+    )
 
 
 def apply_mlp(
     hidden_states: torch.Tensor,
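Usage note (illustrative, not part of the patch): the new fc_dual_batch switch is read from vLLM's additional_config, like the other ascend options in ascend_config.py. A minimal sketch of how a deployment might enable it, with the model name as a placeholder:

    # Hypothetical example; model and other engine arguments are placeholders.
    from vllm import LLM

    llm = LLM(
        model="your-moe-model",
        additional_config={"fc_dual_batch": True},
    )

The equivalent CLI form would be --additional-config '{"fc_dual_batch": true}'. Since fused_experts_with_mc2 falls back to _single_stream_execution whenever shared_experts is not None, the flag only changes behaviour for MoE models without shared experts.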