From c2c636145da30ee25a3459bee5670fe18b9aa889 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 23 Jan 2026 06:54:16 +0000 Subject: [PATCH 1/4] feat(r3): implement Rollout Router Replay (R3) for MoE routing capture This commit adds R3 (Rollout Router Replay) support to LightLLM, enabling capture and replay of MoE routing decisions for improved performance and debugging capabilities. Key changes: - Add routing_manager module for centralized routing capture/export - Implement routing buffer management with CPU pinned memory - Add R3 support to all MoE models (Mixtral, DeepSeek2, Qwen3-MoE, GPT-OSS) - Refactor MoE layer indexing to use auto-increment counters - Add API endpoints for routing capture control and export - Add comprehensive unit tests for R3 functionality The implementation uses a model-agnostic approach with explicit microbatch indexing passed through the call chain, eliminating reliance on global state. --- .gitignore | 1 + lightllm/common/basemodel/basemodel.py | 53 ++++- .../meta_weights/fused_moe_weight_ep.py | 22 ++- .../meta_weights/fused_moe_weight_tp.py | 37 +++- .../gpt_oss_fused_moe_weight_tp.py | 17 +- lightllm/common/basemodel/routing_manager.py | 184 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 2 + .../layer_weights/transformer_layer_weight.py | 1 + .../layer_infer/transformer_layer_infer.py | 2 + lightllm/models/llama/model.py | 11 +- .../layer_infer/transformer_layer_infer.py | 33 +--- .../layer_infer/transformer_layer_infer.py | 3 + .../layer_weights/transformer_layer_weight.py | 1 + lightllm/server/api_cli.py | 6 + lightllm/server/api_lightllm.py | 5 + lightllm/server/api_models.py | 5 +- lightllm/server/api_openai.py | 11 +- lightllm/server/core/objs/req.py | 41 ++++ lightllm/server/core/objs/sampling_params.py | 3 + lightllm/server/httpserver/manager.py | 20 ++ .../server/router/model_infer/infer_batch.py | 28 +++ .../mode_backend/chunked_prefill/impl.py | 25 +++ .../mode_backend/dp_backend/impl.py | 53 +++++ scripts/run_e2e_r3_test.sh | 109 +++++++++++ test_r3.py | 99 ++++++++++ unit_tests/__init__.py | 0 unit_tests/common/__init__.py | 0 unit_tests/common/basemodel/__init__.py | 0 .../basemodel/test_routing_capture_manager.py | 132 +++++++++++++ 29 files changed, 868 insertions(+), 36 deletions(-) create mode 100644 lightllm/common/basemodel/routing_manager.py create mode 100755 scripts/run_e2e_r3_test.sh create mode 100644 test_r3.py create mode 100644 unit_tests/__init__.py create mode 100644 unit_tests/common/__init__.py create mode 100644 unit_tests/common/basemodel/__init__.py create mode 100644 unit_tests/common/basemodel/test_routing_capture_manager.py diff --git a/.gitignore b/.gitignore index 63408699f..3fb49db8b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 26d51af3d..19f5ec4ee 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,11 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.routing_manager import ( + create_routing_capture_manager, + reset_moe_layer_counter, + get_moe_layer_count, +) from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager 
@@ -169,6 +174,7 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): + reset_moe_layer_counter() self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( @@ -276,7 +282,45 @@ def _init_prefill_cuda_graph(self): self.prefill_graph.warmup(self) def _init_custom(self): - pass + if self.args.enable_return_routed_experts: + # Get MoE layer count from counter (set during _init_weights) + num_moe_layers = get_moe_layer_count() + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. " + "Routing capture will not be enabled." + ) + return + + # Get MoE parameters from model config + n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0)) + if n_routed_experts == 0: + logger.warning( + "enable_return_routed_experts is set but n_routed_experts=0. " + "Routing capture will not be enabled." + ) + return + + topk = self.config.get("num_experts_per_tok", 1) + num_experts = n_routed_experts + + # Check if overlap mode is enabled + enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=self.max_total_token_num, + kv_cache_size=self.mem_manager.size, + enable_overlap=enable_overlap, + ) + return @torch.no_grad() def forward(self, model_input: ModelInput): @@ -284,9 +328,12 @@ def forward(self, model_input: ModelInput): assert model_input.mem_indexes.is_cuda if model_input.is_prefill: - return self._prefill(model_input) + result = self._prefill(model_input) else: - return self._decode(model_input) + result = self._decode(model_input) + + # Note: flush is now handled by backend layer (ChunkedPrefill, DP, etc.) 
+ return result def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): infer_state = self.infer_state_class() diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index 7dc5b5fdc..f2557af85 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -23,6 +23,7 @@ from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index logger = init_logger(__name__) @@ -43,7 +44,7 @@ def __init__( quant_cfg=None, ) -> None: super().__init__() - + self.moe_layer_index = get_next_moe_layer_index() self.layer_num = layer_num self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight @@ -115,6 +116,7 @@ def experts( topk_group, num_expert_group, is_prefill, + microbatch_index: int = 0, ): topk_weights, topk_ids = select_experts( hidden_states=input_tensor, @@ -139,6 +141,11 @@ def experts( enable_counter=self.auto_update_redundancy_expert, ) + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 return fused_experts_impl( @@ -162,6 +169,7 @@ def low_latency_dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + microbatch_index: int = 0, ): topk_weights, topk_idx = select_experts( @@ -187,6 +195,11 @@ def low_latency_dispatch( enable_counter=self.auto_update_redundancy_expert, ) + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index) + topk_idx = topk_idx.to(torch.long) num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( @@ -204,6 +217,7 @@ def select_experts_and_quant_input( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + microbatch_index: int = 0, ): topk_weights, topk_idx = select_experts( hidden_states=hidden_states, @@ -226,6 +240,12 @@ def select_experts_and_quant_input( expert_counter=self.routed_expert_counter_tensor, enable_counter=self.auto_update_redundancy_expert, ) + + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index) + M, K = hidden_states.shape w1, w1_scale = self.w1 block_size_k = 0 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 9295fa96a..c60d40814 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -5,6 +5,7 @@ from .base_weight import BaseWeight from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from 
lightllm.common.quantization import Quantcfg +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index def create_tp_moe_wegiht_obj( @@ -71,6 +72,7 @@ def __init__( quant_cfg: Quantcfg = None, ) -> None: super().__init__() + self.moe_layer_index = get_next_moe_layer_index() self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method is not None: @@ -101,7 +103,17 @@ def __init__( self.w2 = [None, None] # weight, weight_scale self.lock = threading.Lock() - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): from lightllm.common.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -116,6 +128,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t scoring_func=self.scoring_func, ) topk_weights.mul_(self.routed_scaling_factor) + + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( @@ -315,6 +332,7 @@ def __init__( quant_cfg: Quantcfg = None, ) -> None: super().__init__() + self.moe_layer_index = get_next_moe_layer_index() self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method is not None: @@ -356,7 +374,17 @@ def __init__( self.w2 = [None, None, None] # weight, weight_scale, zero_point self.lock = threading.Lock() - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): from lightllm.common.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -371,6 +399,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t scoring_func=self.scoring_func, ) topk_weights.mul_(self.routed_scaling_factor) + + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py index df72cc620..8397ee74f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py @@ -6,6 +6,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -121,9 +122,23 @@ def router(self, router_logits, top_k): router_top_value = 
torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_top_value, router_indices - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): topk_weights, topk_ids = self.router(router_logits, top_k) + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 000000000..ea6d1eea7 --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,184 @@ +import torch +import numpy as np +from typing import Optional +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# MoE layer counter for auto-incrementing moe_layer_index +_moe_layer_counter: int = 0 + + +def reset_moe_layer_counter() -> None: + """Reset MoE layer counter. Call before model weight initialization.""" + global _moe_layer_counter + _moe_layer_counter = 0 + + +def get_next_moe_layer_index() -> int: + """Get and increment counter. Called by FusedMoeWeight* during init.""" + global _moe_layer_counter + idx = _moe_layer_counter + _moe_layer_counter += 1 + return idx + + +def get_moe_layer_count() -> int: + """Get total MoE layers after weight init (for routing manager creation).""" + return _moe_layer_counter + + +class RoutingCaptureManager: + """Captures MoE routing decisions for export and analysis. + + Supports: + - Explicit moe_layer_index (no auto-increment) + - Double-buffered GPU buffers for two-batch overlap + - Async GPU->CPU flush with CUDA streams + - MTP via computed layer index offsets + """ + + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, + ): + """Initialize routing capture buffers. 
+ + Args: + num_moe_layers: Total MoE layers (main + draft models for MTP) + topk: Number of experts selected per token + num_experts: Total experts (determines int8 vs int16 dtype) + batch_max_tokens: Max tokens per batch + kv_cache_size: Size of KV cache (for mem_index mapping) + enable_overlap: Enable double-buffering for two-batch overlap + """ + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.batch_max_tokens = batch_max_tokens + self.kv_cache_size = kv_cache_size + + # Choose dtype based on number of experts + self.dtype = torch.int8 if num_experts <= 127 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.int8 else 2 + + # Number of GPU buffer slots (2 for overlap mode, 1 otherwise) + self.num_slots = 2 if enable_overlap else 1 + + # GPU buffer: [num_slots, num_moe_layers, batch_max_tokens, topk] + gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes + self.gpu_buffer = torch.zeros( + (self.num_slots, num_moe_layers, batch_max_tokens, topk), + dtype=self.dtype, + device="cuda", + ) + + # CPU buffer: [num_moe_layers, kv_cache_size, topk] + cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.cpu_buffer = torch.zeros( + (num_moe_layers, kv_cache_size, topk), + dtype=self.dtype, + device="cpu", + pin_memory=True, + ) + + # Per-slot async flush + self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)] + self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)] + + dtype_name = "int8" if self.dtype == torch.int8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, " + f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}" + ) + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + """Capture routing decision for a specific MoE layer. + + Args: + moe_layer_index: Explicit index (0 to num_moe_layers-1) + topk_ids: Shape [num_tokens, topk] + microbatch_index: Current microbatch index for overlap mode (default 0) + """ + assert ( + 0 <= moe_layer_index < self.num_moe_layers + ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})" + slot = microbatch_index % self.num_slots + num_tokens = topk_ids.shape[0] + self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype) + + def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None: + """Async flush GPU buffer to CPU buffer at mem_index positions. + + Called by backend after forward pass completes. + + Args: + mem_indexes: Shape [num_tokens] - KV cache slot indices + microbatch_index: Which microbatch slot to flush from + """ + num_tokens = mem_indexes.shape[0] + if num_tokens == 0: + return + + slot = microbatch_index % self.num_slots + stream = self.flush_streams[slot] + event = self.flush_events[slot] + + # Wait for inference on this slot to complete + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream): + cpu_indexes = mem_indexes.cpu() + self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu() + event.record() + + def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: + """Extract routing data for a completed request. 
+ + Args: + mem_indexes: KV cache slots belonging to this request (CPU tensor) + + Returns: + numpy array of shape [num_moe_layers, num_tokens, topk] + """ + # Synchronize all pending flushes + for event in self.flush_events: + event.synchronize() + return self.cpu_buffer[:, mem_indexes, :].numpy() + + +# Global instance +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, +) -> None: + """Initialize the global routing capture manager.""" + global g_routing_capture_manager + if g_routing_capture_manager is not None: + logger.warning("RoutingCaptureManager already exists, replacing") + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=batch_max_tokens, + kv_cache_size=kv_cache_size, + enable_overlap=enable_overlap, + ) + + +def get_routing_capture_manager() -> Optional[RoutingCaptureManager]: + """Get the global routing capture manager.""" + return g_routing_capture_manager diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 8695f2de8..09661098d 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -325,6 +325,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -351,6 +352,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index c5a2d3352..eedec5f85 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -246,6 +246,7 @@ def _load_mlp(self, mlp_prefix): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] + self.moe_gate = ROWMMWeight( weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16..0fab063da 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -42,6 +42,7 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) hidden_states = layer_weight.experts.experts( hidden_states, @@ -51,6 +52,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) return 
hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc..f5a66a6c3 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,14 +74,18 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): + """Initialize rotary embeddings based on scaling type.""" if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -96,7 +100,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2..a2968f5ab 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index c85c423c2..27d1945a9 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ 
b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -131,6 +131,7 @@ def _moe_ffn( hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, @@ -140,6 +141,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) return hidden_states.view(num_tokens, hidden_dim) @@ -160,6 +162,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 486f4d696..be130dcaf 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -60,6 +60,7 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) + moe_mode = os.getenv("MOE_MODE", "TP") assert moe_mode in ["EP", "TP"] if moe_mode == "TP": diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..b937bb8c6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -199,6 +199,12 @@ def make_argument_parser() -> argparse.ArgumentParser: choices=["round_robin", "bs_balancer"], help="the dp balancer type, default is bs_balancer", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) parser.add_argument( "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len" ) diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f5..5abd90815 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index f30ecc55f..38edd71fb 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -220,7 +220,8 @@ class ChatCompletionRequest(BaseModel): role_settings: Optional[Dict[str, str]] = None 
character_settings: Optional[List[Dict[str, str]]] = None - # Class variables to store loaded default values + return_routed_experts: Optional[bool] = False + _loaded_defaults: ClassVar[Dict[str, Any]] = {} @classmethod @@ -279,6 +280,8 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo + # R3: Routing data for MoE models (when return_routed_experts=True) + routed_experts: Optional[Dict[str, Any]] = None @field_validator("id", mode="before") def ensure_id_is_str(cls, v): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d91bb1d94..1e608f030 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -208,6 +208,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, + "return_routed_experts": request.return_routed_experts, } # Structured output handling @@ -237,6 +238,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req finish_reason_dict = {} prompt_tokens_dict = {} completion_tokens = 0 + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: from .req_id_generator import convert_sub_id_to_group_id @@ -246,6 +248,8 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status.get_finish_reason() prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"] + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] choices = [] sub_ids = list(final_output_dict.keys())[: request.n] for i in range(request.n): @@ -325,7 +329,12 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req ) choices.append(choice) resp = ChatCompletionResponse( - id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage + id=group_request_id, + created=created_time, + model=request.model, + choices=choices, + usage=usage, + routed_experts=routed_experts_data, ) return resp diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f489aac9c..128423e6e 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -122,6 +122,9 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + ("routing_data_num_moe_layers", ctypes.c_int), + ("routing_data_num_tokens", ctypes.c_int), + ("routing_data_topk", ctypes.c_int), ] def get_str(self): @@ -180,6 +183,10 @@ def init( self.stop_str_matched = False self.stop_str_matched_token_index = -1 + self.routing_data_num_moe_layers = 0 + self.routing_data_num_tokens = 0 + self.routing_data_topk = 0 + self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size @@ -227,6 +234,40 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return + def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int): + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (num_moe_layers, num_tokens, topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) + self.shm_routing_data.create_shm() + self.routing_data_num_moe_layers = num_moe_layers + self.routing_data_num_tokens = num_tokens + self.routing_data_topk = topk + 
return + + def link_routing_data_shm_array(self): + if self.routing_data_num_moe_layers == 0: + return + service_uni_name = get_unique_server_name() + name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" + shape = (self.routing_data_num_moe_layers, self.routing_data_num_tokens, self.routing_data_topk) + self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) + self.shm_routing_data.link_shm() + return + + def get_routing_data(self): + if self.routing_data_num_moe_layers == 0 or not hasattr(self, "shm_routing_data"): + return None + if self.shm_routing_data is None: + return None + return self.shm_routing_data.arr + + def close_routing_data_shm_array(self): + if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: + self.shm_routing_data.close_shm() + self.shm_routing_data = None + return + def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a8..cf13e5d85 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -321,6 +321,7 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache + ("return_routed_experts", ctypes.c_bool), ] _do_sample: bool = False @@ -352,6 +353,7 @@ def init(self, tokenizer, **kwargs): self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) + self.return_routed_experts = kwargs.get("return_routed_experts", False) self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) @@ -497,6 +499,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "return_routed_experts": self.return_routed_experts, } def to_origin_dict(self): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 212e037e9..bd92a6b7d 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import base64 from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -686,6 +687,11 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: + if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: + try: + req.close_routing_data_shm_array() + except Exception as e: + logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -773,6 +779,20 @@ async def handle_loop(self): else: finish_status = FinishStatus(req.finish_status.status) + if req.sample_params.return_routed_experts and req.routing_data_num_moe_layers > 0: + try: + req.link_routing_data_shm_array() + 
routing_data = req.get_routing_data() + if routing_data is not None: + metadata["routed_experts"] = { + "shape": list(routing_data.shape), + "dtype": str(routing_data.dtype), + "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), + } + req.close_routing_data_shm_array() + except Exception as e: + logger.warning(f"Failed to read routing data for req {req_id}: {e}") + token_list.append((req_id, text, metadata, finish_status)) else: break diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538..94f74165c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -113,6 +113,29 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs + def _extract_routing_data(self, req: "InferReq"): + routing_manager = self.req_manager.mem_manager.routing_manager + if routing_manager is None: + logger.debug(f"R3: routing_manager is None for req {req.req_id}") + return + if not req.shm_req.sample_params.return_routed_experts: + logger.debug(f"R3: return_routed_experts is False for req {req.req_id}") + return + if req.cur_kv_len <= 0: + logger.debug(f"R3: cur_kv_len <= 0 for req {req.req_id}") + return + + mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] + num_moe_layers = routing_manager.num_moe_layers + topk = routing_manager.topk + num_tokens = req.cur_kv_len + + logger.debug(f"R3: Extracting routing for req {req.req_id}: {num_moe_layers}x{num_tokens}x{topk}") + routing_data = routing_manager.extract_for_request(mem_indexes.cpu()) + req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk) + req.shm_req.shm_routing_data.arr[:] = routing_data + logger.debug(f"R3: Successfully extracted routing data for req {req.req_id}") + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -155,6 +178,9 @@ def _filter(self, finished_request_ids: List[int]): req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() + + self._extract_routing_data(req) + self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) @@ -580,6 +606,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + # Extract routing data before setting finish_status so HTTP server sees it + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224eb..44940d933 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,6 +24,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from .control_state import ControlState logger = init_logger(__name__) @@ -109,6 +110,12 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not 
self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -148,6 +155,12 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -186,6 +199,12 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -236,6 +255,12 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e7..e1dcdcd1b 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -22,6 +22,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from .control_state import DPControlState @@ -145,6 +146,12 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -188,6 +195,12 @@ def decode_normal(self, event_pack: 
OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -236,6 +249,13 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + + # Flush routing capture to CPU for both microbatches + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) + routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) + logits0 = model_output0.logits logits1 = model_output1.logits @@ -305,6 +325,13 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + + # Flush routing capture to CPU for both microbatches + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) + routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) + logits0 = model_output0.logits logits1 = model_output1.logits @@ -359,6 +386,12 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -421,6 +454,12 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + + # Flush routing capture to CPU + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) + mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -629,6 +668,13 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + + # Flush routing capture to CPU for both microbatches + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) + routing_manager.flush_to_cpu_async(model_input1.mem_indexes, 
microbatch_index=1) + logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -728,6 +774,13 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + + # Flush routing capture to CPU for both microbatches + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) + routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) + logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 diff --git a/scripts/run_e2e_r3_test.sh b/scripts/run_e2e_r3_test.sh new file mode 100755 index 000000000..3100f18a7 --- /dev/null +++ b/scripts/run_e2e_r3_test.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# E2E Test Script for R3 Routing Capture Feature +# +# This script starts a LightLLM server with routing capture enabled, +# runs the client test, and verifies the results. +# +# Requirements: +# - A MoE model (DeepSeek-V2/V3, Qwen-MoE, Mixtral, etc.) +# - At least 1 GPU with sufficient memory +# - LightLLM installed +# +# Usage: +# ./scripts/run_e2e_r3_test.sh /path/to/moe/model [--tp N] + +set -e + +MODEL_DIR="${1:-}" +TP="${2:-1}" +PORT=8765 + +if [ -z "$MODEL_DIR" ]; then + echo "Usage: $0 /path/to/moe/model [--tp N]" + echo "" + echo "Example:" + echo " $0 /models/DeepSeek-V3 --tp 8" + echo " $0 /models/Qwen-MoE-A14B --tp 4" + exit 1 +fi + +if [ ! -d "$MODEL_DIR" ]; then + echo "ERROR: Model directory not found: $MODEL_DIR" + exit 1 +fi + +echo "==========================================" +echo "R3 E2E Test: Routing Capture Feature" +echo "==========================================" +echo "Model: $MODEL_DIR" +echo "TP: $TP" +echo "Port: $PORT" +echo "" + +# Kill any existing server on the port +pkill -f "lightllm.server.api_server.*--port $PORT" 2>/dev/null || true +sleep 2 + +# Start server in background +echo "Starting LightLLM server..." +python -m lightllm.server.api_server \ + --model_dir "$MODEL_DIR" \ + --tp "$TP" \ + --port "$PORT" \ + --enable_return_routed_experts \ + --max_total_token_num 8000 \ + --batch_max_tokens 4000 \ + > /tmp/lightllm_r3_test.log 2>&1 & + +SERVER_PID=$! +echo "Server PID: $SERVER_PID" +echo "Log: /tmp/lightllm_r3_test.log" + +# Wait for server to be ready +echo "Waiting for server to be ready..." +MAX_WAIT=300 +WAITED=0 +while [ $WAITED -lt $MAX_WAIT ]; do + if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then + echo "Server is ready!" + break + fi + sleep 5 + WAITED=$((WAITED + 5)) + echo " Waited ${WAITED}s..." +done + +if [ $WAITED -ge $MAX_WAIT ]; then + echo "ERROR: Server failed to start within ${MAX_WAIT}s" + echo "Server log:" + tail -50 /tmp/lightllm_r3_test.log + kill $SERVER_PID 2>/dev/null || true + exit 1 +fi + +# Run client test +echo "" +echo "Running R3 client test..." +echo "==========================================" +python test_r3.py --url "http://localhost:$PORT" +TEST_RESULT=$? + +# Cleanup +echo "" +echo "Stopping server..." +kill $SERVER_PID 2>/dev/null || true +wait $SERVER_PID 2>/dev/null || true + +# Report result +echo "" +echo "==========================================" +if [ $TEST_RESULT -eq 0 ]; then + echo "E2E TEST PASSED!" +else + echo "E2E TEST FAILED!" 
+ echo "Server log (last 30 lines):" + tail -30 /tmp/lightllm_r3_test.log +fi +echo "==========================================" + +exit $TEST_RESULT diff --git a/test_r3.py b/test_r3.py new file mode 100644 index 000000000..14157fab5 --- /dev/null +++ b/test_r3.py @@ -0,0 +1,99 @@ +""" +R3 Client Test: Tests the routing capture export feature. + +This test requires a running LightLLM server with: +- A MoE model (e.g., DeepSeek-V2/V3) +- --enable_return_routed_experts flag + +Usage: + python test_r3.py [--url URL] +""" +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + """Test the routing export feature.""" + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + "return_routed_experts": True, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + # Check for routed_experts in response + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + # Decode routed_experts from base64 + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype = np.dtype(routing_info["dtype"]) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape} # [num_moe_layers, num_tokens, topk]") + print(f"Dtype: {dtype}") + print(f"Num MoE layers: {shape[0]}") + print(f"Num tokens: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Show sample of routing data + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = min(5, shape[1]) + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + + # Validate data + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 000000000..8d5a1d0bf --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,132 @@ +import pytest +import torch +import numpy as np + + +def test_moe_layer_counter(): + """Counter increments and resets correctly.""" + from lightllm.common.basemodel.routing_manager import ( + reset_moe_layer_counter, + get_next_moe_layer_index, + get_moe_layer_count, + ) + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + + assert get_next_moe_layer_index() == 0 + assert get_next_moe_layer_index() == 1 + assert get_next_moe_layer_index() == 2 + assert get_moe_layer_count() == 3 + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + assert get_next_moe_layer_index() == 0 + + +class TestRoutingCaptureManager: + """Tests for the redesigned RoutingCaptureManager.""" + + def test_capture_explicit_layer_index(self): + """Capture stores data at explicit moe_layer_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + batch_max_tokens=128, + kv_cache_size=1024, + enable_overlap=False, + ) + + # Capture at layer 2 (not sequential) + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=2, topk_ids=topk_ids) + + # Verify data is at layer 2, not layer 0 + assert torch.equal(manager.gpu_buffer[0, 2, :10, :], topk_ids.to(manager.dtype)) + + def test_double_buffer_overlap_mode(self): + """Double buffer prevents race condition in overlap mode.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=True, + ) + + # Should have 2 buffer slots + assert manager.num_slots == 2 + assert manager.gpu_buffer.shape[0] == 2 + + # Capture to slot 0 (microbatch_index=0) + ids_0 = torch.ones((5, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Capture to slot 1 (microbatch_index=1) + ids_1 = torch.ones((5, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Both slots have different data + assert manager.gpu_buffer[0, 0, 0, 0].item() == 1 + assert manager.gpu_buffer[1, 0, 0, 0].item() == 2 + + def test_flush_and_extract(self): + """Flush transfers data to CPU, extract retrieves by mem_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = 
RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=False, + ) + + # Capture some data (microbatch_index defaults to 0) + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids) + manager.capture(moe_layer_index=1, topk_ids=topk_ids + 10) + + # Flush to mem_indexes 10 and 11 + mem_indexes = torch.tensor([10, 11], device="cuda") + manager.flush_to_cpu_async(mem_indexes, microbatch_index=0) + + # Extract + result = manager.extract_for_request(mem_indexes.cpu()) + + assert result.shape == (2, 2, 4) # [layers, tokens, topk] + assert result[0, 0, 0] == 1 + assert result[1, 0, 0] == 11 + + def test_dtype_selection(self): + """Uses int8 for <=127 experts, int16 otherwise.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # Small expert count -> int8 + manager_small = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_small.dtype == torch.int8 + + # Large expert count -> int16 + manager_large = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_large.dtype == torch.int16 From f900dfa8f041021105551c9f1bdb528d46f15872 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 23 Jan 2026 07:04:25 +0000 Subject: [PATCH 2/4] fix(r3): use global accessor for routing_manager in _extract_routing_data The routing_manager was incorrectly accessed as an attribute of mem_manager, but it's a module-level global accessed via get_routing_capture_manager(). This caused routing data extraction to silently fail. --- lightllm/server/router/model_infer/infer_batch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 94f74165c..ba32ebe1d 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,6 +19,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager logger = init_logger(__name__) @@ -114,7 +115,7 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs def _extract_routing_data(self, req: "InferReq"): - routing_manager = self.req_manager.mem_manager.routing_manager + routing_manager = get_routing_capture_manager() if routing_manager is None: logger.debug(f"R3: routing_manager is None for req {req.req_id}") return From b1094e73fcd429bbc8061d0f3720899d9f35a432 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 23 Jan 2026 07:40:17 +0000 Subject: [PATCH 3/4] fix(r3): remove R3 support from /v1 OpenAI-compatible endpoints Keep R3 routing capture feature only on /generate endpoint to maintain OpenAI API compatibility. The /v1/chat/completions endpoint should not have non-standard fields like return_routed_experts or routed_experts. 
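For reference, a minimal client-side sketch of the flow that remains supported (it mirrors test_r3.py and the /generate handling in api_lightllm.py; assumes a MoE model served on localhost:8000 with --enable_return_routed_experts):

    import base64
    import numpy as np
    import requests

    # /generate keeps the R3 fields; /v1/chat/completions no longer does.
    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "inputs": "What is the capital of France?",
            "parameters": {"max_new_tokens": 16, "return_routed_experts": True},
        },
        timeout=60,
    ).json()

    info = resp["routed_experts"]
    experts = np.frombuffer(
        base64.b64decode(info["data"]), dtype=np.dtype(info["dtype"])
    ).reshape(info["shape"])  # [num_moe_layers, num_tokens, topk]
    print(experts[0, :5, :])  # experts routed for the first tokens of MoE layer 0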
--- lightllm/server/api_models.py | 4 ---- lightllm/server/api_openai.py | 5 ----- 2 files changed, 9 deletions(-) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 38edd71fb..6386f6caf 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -220,8 +220,6 @@ class ChatCompletionRequest(BaseModel): role_settings: Optional[Dict[str, str]] = None character_settings: Optional[List[Dict[str, str]]] = None - return_routed_experts: Optional[bool] = False - _loaded_defaults: ClassVar[Dict[str, Any]] = {} @classmethod @@ -280,8 +278,6 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[ChatCompletionResponseChoice] usage: UsageInfo - # R3: Routing data for MoE models (when return_routed_experts=True) - routed_experts: Optional[Dict[str, Any]] = None @field_validator("id", mode="before") def ensure_id_is_str(cls, v): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 1e608f030..39105db3a 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -208,7 +208,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, - "return_routed_experts": request.return_routed_experts, } # Structured output handling @@ -238,7 +237,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req finish_reason_dict = {} prompt_tokens_dict = {} completion_tokens = 0 - routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: from .req_id_generator import convert_sub_id_to_group_id @@ -248,8 +246,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status.get_finish_reason() prompt_tokens_dict[sub_req_id] = metadata["prompt_tokens"] - if "routed_experts" in metadata: - routed_experts_data = metadata["routed_experts"] choices = [] sub_ids = list(final_output_dict.keys())[: request.n] for i in range(request.n): @@ -334,7 +330,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req model=request.model, choices=choices, usage=usage, - routed_experts=routed_experts_data, ) return resp From ca0c282b1584730402f4ddfe39bf7bf1c7ec0f4c Mon Sep 17 00:00:00 2001 From: sufubao Date: Mon, 26 Jan 2026 07:03:52 +0000 Subject: [PATCH 4/4] draft --- lightllm/common/basemodel/basemodel.py | 55 +++--------- lightllm/common/basemodel/moe_model_mixin.py | 89 +++++++++++++++++++ lightllm/common/basemodel/routing_manager.py | 28 ++++++ .../layer_weights/transformer_layer_weight.py | 1 - lightllm/models/deepseek2/model.py | 4 +- .../layer_infer/transformer_layer_infer.py | 1 - lightllm/models/gpt_oss/model.py | 7 +- lightllm/models/mixtral/model.py | 4 +- .../layer_infer/transformer_layer_infer.py | 1 - .../layer_weights/transformer_layer_weight.py | 1 - lightllm/models/qwen3_moe/model.py | 4 +- lightllm/server/api_cli.py | 12 +-- lightllm/server/api_models.py | 1 + lightllm/server/api_openai.py | 6 +- lightllm/server/core/objs/req.py | 41 --------- lightllm/server/core/objs/sampling_params.py | 3 - lightllm/server/httpserver/manager.py | 19 ---- .../server/router/model_infer/infer_batch.py | 27 ------ .../mode_backend/chunked_prefill/impl.py | 21 ----- .../mode_backend/dp_backend/impl.py | 45 ---------- 20 files changed, 152 insertions(+), 218 deletions(-) create mode 100644 lightllm/common/basemodel/moe_model_mixin.py 
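Reviewer note on this draft: it moves R3 setup out of TpPartBaseModel into a
new MoeModelMixin, and it replaces the per-backend flush_to_cpu_async() calls
with _post_forward()/_post_forward_dual() hooks that the base model invokes
after every forward pass. A simplified, standalone sketch of the resulting
dispatch (stub stand-ins, not the real classes; see the diffs below for the
actual code):

    def flush_routing_capture(mem_indexes, microbatch_index=0):
        # Stand-in for the module-level helper added to routing_manager.py:
        # it forwards to the global RoutingCaptureManager and is a no-op
        # when R3 is disabled.
        print(f"flush microbatch={microbatch_index}, tokens={len(mem_indexes)}")

    class TpPartBaseModel:
        def forward(self, model_input):
            model_output = object()          # stand-in for _prefill()/_decode()
            self._post_forward(model_input)  # new hook, no-op in the base class
            return model_output

        def _post_forward(self, model_input, microbatch_index=0):
            pass

    class MoeModelMixin:
        def _post_forward(self, model_input, microbatch_index=0):
            # the real mixin passes model_input.mem_indexes here
            flush_routing_capture(model_input, microbatch_index)

    class Qwen3MOEModel(MoeModelMixin, TpPartBaseModel):
        pass  # real models also call _init_routing_capture() in _init_custom()

    Qwen3MOEModel().forward([10, 11, 12])  # prints: flush microbatch=0, tokens=3

The upshot is that the backends stay MoE-agnostic: dense models inherit the
no-op hooks, and only models that opt into MoeModelMixin pay for the flush.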
diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 19f5ec4ee..b108db1b4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,11 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo -from lightllm.common.basemodel.routing_manager import ( - create_routing_capture_manager, - reset_moe_layer_counter, - get_moe_layer_count, -) +from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -282,45 +278,16 @@ def _init_prefill_cuda_graph(self): self.prefill_graph.warmup(self) def _init_custom(self): - if self.args.enable_return_routed_experts: - # Get MoE layer count from counter (set during _init_weights) - num_moe_layers = get_moe_layer_count() - if num_moe_layers == 0: - logger.warning( - "enable_return_routed_experts is set but no MoE layers found. " - "Routing capture will not be enabled." - ) - return - - # Get MoE parameters from model config - n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0)) - if n_routed_experts == 0: - logger.warning( - "enable_return_routed_experts is set but n_routed_experts=0. " - "Routing capture will not be enabled." - ) - return + """Hook for model-specific initialization. Override in subclasses.""" + pass - topk = self.config.get("num_experts_per_tok", 1) - num_experts = n_routed_experts + def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None: + """Hook called after forward pass completes. Override in subclasses for post-processing.""" + pass - # Check if overlap mode is enabled - enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False) - - logger.info( - f"Initializing routing capture: num_moe_layers={num_moe_layers}, " - f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}" - ) - - create_routing_capture_manager( - num_moe_layers=num_moe_layers, - topk=topk, - num_experts=num_experts, - batch_max_tokens=self.max_total_token_num, - kv_cache_size=self.mem_manager.size, - enable_overlap=enable_overlap, - ) - return + def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None: + """Hook called after dual microbatch forward pass completes. Override in subclasses.""" + pass @torch.no_grad() def forward(self, model_input: ModelInput): @@ -332,7 +299,7 @@ def forward(self, model_input: ModelInput): else: result = self._decode(model_input) - # Note: flush is now handled by backend layer (ChunkedPrefill, DP, etc.) 
+ self._post_forward(model_input) return result def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): @@ -726,6 +693,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod dist_group_manager.clear_deepep_buffer() model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event + self._post_forward_dual(model_input0, model_input1) return model_output0, model_output1 @torch.no_grad() @@ -819,6 +787,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) + self._post_forward_dual(model_input0, model_input1) return model_output0, model_output1 @final diff --git a/lightllm/common/basemodel/moe_model_mixin.py b/lightllm/common/basemodel/moe_model_mixin.py new file mode 100644 index 000000000..4d730954d --- /dev/null +++ b/lightllm/common/basemodel/moe_model_mixin.py @@ -0,0 +1,89 @@ +"""Mixin for MoE (Mixture of Experts) models. + +Provides R3 (Rollout Router Replay) routing capture functionality for MoE models. +MoE models that want R3 support should inherit from this mixin and call +`_init_routing_capture()` in their `_init_custom()` method. +""" + +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.routing_manager import ( + create_routing_capture_manager, + get_moe_layer_count, + flush_routing_capture, + flush_routing_capture_dual, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MoeModelMixin: + """Mixin class providing R3 routing capture support for MoE models. + + Usage: + class MyMoeModel(MoeModelMixin, LlamaTpPartModel): + def _init_custom(self): + super()._init_custom() + self._init_routing_capture() # Enable R3 if flag is set + """ + + def _init_routing_capture(self) -> None: + """Initialize R3 routing capture if enabled via --enable_return_routed_experts. + + Should be called in the model's _init_custom() method after weights are loaded. + This method is idempotent - safe to call multiple times. + """ + if not getattr(self.args, "enable_return_routed_experts", False): + return + + # Get MoE layer count from counter (set during _init_weights) + num_moe_layers = get_moe_layer_count() + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled." + ) + return + + # Get MoE parameters from model config + n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0)) + if n_routed_experts == 0: + logger.warning( + "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled." 
+ ) + return + + topk = self.config.get("num_experts_per_tok", 1) + num_experts = n_routed_experts + + # Check if overlap mode is enabled + enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=self.max_total_token_num, + kv_cache_size=self.mem_manager.size, + enable_overlap=enable_overlap, + ) + + def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None: + """Hook called after forward pass completes. + + Flushes R3 routing capture data from GPU to CPU buffer. + No-op if R3 is not enabled. + """ + flush_routing_capture(model_input.mem_indexes, microbatch_index) + + def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None: + """Hook called after dual microbatch forward pass completes. + + Flushes R3 routing capture data for both microbatches. + No-op if R3 is not enabled. + """ + flush_routing_capture_dual(model_input0.mem_indexes, model_input1.mem_indexes) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py index ea6d1eea7..c7eecfc57 100644 --- a/lightllm/common/basemodel/routing_manager.py +++ b/lightllm/common/basemodel/routing_manager.py @@ -182,3 +182,31 @@ def create_routing_capture_manager( def get_routing_capture_manager() -> Optional[RoutingCaptureManager]: """Get the global routing capture manager.""" return g_routing_capture_manager + + +def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Flush routing capture to CPU if manager is active. + + Call after forward pass completes. No-op if R3 capture is not enabled. + + Args: + mem_indexes: KV cache slot indices for the batch + microbatch_index: Microbatch index (0 for single batch, 0/1 for overlap) + """ + if g_routing_capture_manager is not None: + g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index) + + +def flush_routing_capture_dual(mem_indexes0: torch.Tensor, mem_indexes1: torch.Tensor) -> None: + """Flush routing capture for dual microbatch overlap mode. + + Call after forward pass completes for both microbatches. + No-op if R3 capture is not enabled. 
+ + Args: + mem_indexes0: KV cache slot indices for microbatch 0 + mem_indexes1: KV cache slot indices for microbatch 1 + """ + if g_routing_capture_manager is not None: + g_routing_capture_manager.flush_to_cpu_async(mem_indexes0, microbatch_index=0) + g_routing_capture_manager.flush_to_cpu_async(mem_indexes1, microbatch_index=1) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index eedec5f85..c5a2d3352 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -246,7 +246,6 @@ def _load_mlp(self, mlp_prefix): def _init_moe(self): moe_intermediate_size = self.network_config_["moe_intermediate_size"] - self.moe_gate = ROWMMWeight( weight_names=f"model.layers.{self.layer_num_}.mlp.gate.weight", data_type=self.data_type_, diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0739a8a8..b96a634d8 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -15,7 +16,7 @@ @ModelRegistry(["deepseek_v2", "deepseek_v3"]) -class Deepseek2TpPartModel(LlamaTpPartModel): +class Deepseek2TpPartModel(MoeModelMixin, LlamaTpPartModel): # weight class transformer_weight_class = Deepseek2TransformerLayerWeight @@ -48,6 +49,7 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + self._init_routing_capture() # R3 routing capture for MoE def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index 0fab063da..67545f1da 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -42,7 +42,6 @@ def _gpt_oss_rmsnorm(self, hidden_states, weight, eps=1e-6): def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) -> torch.Tensor: hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(hidden_states) hidden_states = layer_weight.experts.experts( hidden_states, diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index dc5f2abdf..5c4b5ef03 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger @@ -10,7 +11,7 @@ 
@ModelRegistry("gpt_oss") -class GptOssTpPartModel(LlamaTpPartModel): +class GptOssTpPartModel(MoeModelMixin, LlamaTpPartModel): # weight class transformer_weight_class = GptOssTransformerLayerWeight @@ -25,3 +26,7 @@ def __init__(self, kvargs): assert ( get_env_start_args().llm_decode_att_backend[0] == "fa3" ), "For now GPT-OSS type model only support flashattention-3" + + def _init_custom(self): + super()._init_custom() + self._init_routing_capture() # R3 routing capture for MoE diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e8..4a6b611d5 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -16,7 +17,7 @@ @ModelRegistry("mixtral") -class MixtralTpPartModel(TpPartBaseModel): +class MixtralTpPartModel(MoeModelMixin, TpPartBaseModel): # weight class pre_and_post_weight_class = LlamaPreAndPostLayerWeight transformer_weight_class = MixtralTransformerLayerWeight @@ -45,6 +46,7 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + self._init_routing_capture() # R3 routing capture for MoE return def _init_mem_manager(self): diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 27d1945a9..f3db300f3 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -131,7 +131,6 @@ def _moe_ffn( hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(hidden_states) layer_weight.experts.experts( hidden_states, diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index be130dcaf..486f4d696 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -60,7 +60,6 @@ def _init_moe(self): tp_rank=0, tp_world_size=1, ) - moe_mode = os.getenv("MOE_MODE", "TP") assert moe_mode in ["EP", "TP"] if moe_mode == "TP": diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a505127..a6bf6a976 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -12,7 +13,7 @@ @ModelRegistry("qwen3_moe") -class Qwen3MOEModel(Qwen3TpPartModel): +class Qwen3MOEModel(MoeModelMixin, Qwen3TpPartModel): # weight class transformer_weight_class = Qwen3MOETransformerLayerWeight @@ -26,3 +27,4 @@ def 
__init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + self._init_routing_capture() # R3 routing capture for MoE diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b937bb8c6..b39854401 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -199,12 +199,6 @@ def make_argument_parser() -> argparse.ArgumentParser: choices=["round_robin", "bs_balancer"], help="the dp balancer type, default is bs_balancer", ) - parser.add_argument( - "--enable_return_routed_experts", - action="store_true", - default=False, - help="Enable returning routed expert indices for MoE models (R3 feature).", - ) parser.add_argument( "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len" ) @@ -614,4 +608,10 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 6386f6caf..f30ecc55f 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -220,6 +220,7 @@ class ChatCompletionRequest(BaseModel): role_settings: Optional[Dict[str, str]] = None character_settings: Optional[List[Dict[str, str]]] = None + # Class variables to store loaded default values _loaded_defaults: ClassVar[Dict[str, Any]] = {} @classmethod diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 39105db3a..d91bb1d94 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -325,11 +325,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req ) choices.append(choice) resp = ChatCompletionResponse( - id=group_request_id, - created=created_time, - model=request.model, - choices=choices, - usage=usage, + id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage ) return resp diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 128423e6e..f489aac9c 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -122,9 +122,6 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), - ("routing_data_num_moe_layers", ctypes.c_int), - ("routing_data_num_tokens", ctypes.c_int), - ("routing_data_topk", ctypes.c_int), ] def get_str(self): @@ -183,10 +180,6 @@ def init( self.stop_str_matched = False self.stop_str_matched_token_index = -1 - self.routing_data_num_moe_layers = 0 - self.routing_data_num_tokens = 0 - self.routing_data_topk = 0 - self.post_init() self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size @@ -234,40 +227,6 @@ def link_logprobs_shm_array(self): self.shm_logprobs.link_shm() return - def create_routing_data_shm_array(self, num_moe_layers: int, num_tokens: int, topk: int): - service_uni_name = get_unique_server_name() - name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (num_moe_layers, num_tokens, topk) - self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) - self.shm_routing_data.create_shm() - self.routing_data_num_moe_layers = 
num_moe_layers - self.routing_data_num_tokens = num_tokens - self.routing_data_topk = topk - return - - def link_routing_data_shm_array(self): - if self.routing_data_num_moe_layers == 0: - return - service_uni_name = get_unique_server_name() - name = f"{service_uni_name}_shm_routing_{self.index_in_shm_mem}" - shape = (self.routing_data_num_moe_layers, self.routing_data_num_tokens, self.routing_data_topk) - self.shm_routing_data = ShmArray(name, shape, dtype=np.int32) - self.shm_routing_data.link_shm() - return - - def get_routing_data(self): - if self.routing_data_num_moe_layers == 0 or not hasattr(self, "shm_routing_data"): - return None - if self.shm_routing_data is None: - return None - return self.shm_routing_data.arr - - def close_routing_data_shm_array(self): - if hasattr(self, "shm_routing_data") and self.shm_routing_data is not None: - self.shm_routing_data.close_shm() - self.shm_routing_data = None - return - def get_prompt_ids(self): return self.shm_prompt_ids.arr[: self.input_len].tolist() diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index cf13e5d85..d955aa6a8 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -321,7 +321,6 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache - ("return_routed_experts", ctypes.c_bool), ] _do_sample: bool = False @@ -353,7 +352,6 @@ def init(self, tokenizer, **kwargs): self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) - self.return_routed_experts = kwargs.get("return_routed_experts", False) self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) @@ -499,7 +497,6 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, - "return_routed_experts": self.return_routed_experts, } def to_origin_dict(self): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index bd92a6b7d..1eafb7d18 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -687,11 +687,6 @@ async def recycle_resource_loop(self): for req_status in release_req_status: self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) for req in req_status.group_req_objs.shm_req_objs: - if hasattr(req, "shm_routing_data") and req.shm_routing_data is not None: - try: - req.close_routing_data_shm_array() - except Exception as e: - logger.debug(f"Failed to close routing data shm for req {req.request_id}: {e}") await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) @@ -779,20 +774,6 @@ async def handle_loop(self): else: finish_status = FinishStatus(req.finish_status.status) - if req.sample_params.return_routed_experts and req.routing_data_num_moe_layers > 0: - try: - req.link_routing_data_shm_array() - routing_data = req.get_routing_data() - if routing_data is not None: - 
metadata["routed_experts"] = { - "shape": list(routing_data.shape), - "dtype": str(routing_data.dtype), - "data": base64.b64encode(routing_data.tobytes()).decode("ascii"), - } - req.close_routing_data_shm_array() - except Exception as e: - logger.warning(f"Failed to read routing data for req {req_id}: {e}") - token_list.append((req_id, text, metadata, finish_status)) else: break diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index ba32ebe1d..0978adff4 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -19,7 +19,6 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.common.basemodel.routing_manager import get_routing_capture_manager logger = init_logger(__name__) @@ -114,29 +113,6 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: return req_objs - def _extract_routing_data(self, req: "InferReq"): - routing_manager = get_routing_capture_manager() - if routing_manager is None: - logger.debug(f"R3: routing_manager is None for req {req.req_id}") - return - if not req.shm_req.sample_params.return_routed_experts: - logger.debug(f"R3: return_routed_experts is False for req {req.req_id}") - return - if req.cur_kv_len <= 0: - logger.debug(f"R3: cur_kv_len <= 0 for req {req.req_id}") - return - - mem_indexes = self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len] - num_moe_layers = routing_manager.num_moe_layers - topk = routing_manager.topk - num_tokens = req.cur_kv_len - - logger.debug(f"R3: Extracting routing for req {req.req_id}: {num_moe_layers}x{num_tokens}x{topk}") - routing_data = routing_manager.extract_for_request(mem_indexes.cpu()) - req.shm_req.create_routing_data_shm_array(num_moe_layers, num_tokens, topk) - req.shm_req.shm_routing_data.arr[:] = routing_data - logger.debug(f"R3: Successfully extracted routing data for req {req.req_id}") - def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) @@ -179,9 +155,6 @@ def _filter(self, finished_request_ids: List[int]): req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - - self._extract_routing_data(req) - self.free_a_req_mem(free_token_index, req) free_req_index.append(req.req_idx) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 44940d933..496e4fb6c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,7 +24,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args -from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from .control_state import ControlState logger = init_logger(__name__) @@ -111,11 +110,6 @@ def prefill_normal( with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager 
is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -156,11 +150,6 @@ def decode_normal( with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -200,11 +189,6 @@ def prefill_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -256,11 +240,6 @@ def decode_mtp( b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index e1dcdcd1b..d8abf4e49 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -22,7 +22,6 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_scatter_next_token_ids -from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from .control_state import DPControlState @@ -147,11 +146,6 @@ def prefill_normal( with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -196,11 +190,6 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -250,12 +239,6 @@ def 
prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) - # Flush routing capture to CPU for both microbatches - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) - routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) - logits0 = model_output0.logits logits1 = model_output1.logits @@ -326,12 +309,6 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) - # Flush routing capture to CPU for both microbatches - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) - routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) - logits0 = model_output0.logits logits1 = model_output1.logits @@ -387,11 +364,6 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -455,11 +427,6 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) - # Flush routing capture to CPU - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input.mem_indexes, microbatch_index=0) - mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -669,12 +636,6 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) - # Flush routing capture to CPU for both microbatches - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) - routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) - logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -775,12 +736,6 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) - # Flush routing capture to CPU for both microbatches - routing_manager = get_routing_capture_manager() - if routing_manager is not None: - routing_manager.flush_to_cpu_async(model_input0.mem_indexes, microbatch_index=0) - routing_manager.flush_to_cpu_async(model_input1.mem_indexes, microbatch_index=1) - logits0 = 
model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1
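End-of-series note: the captured routing, whether read via
RoutingCaptureManager.extract_for_request() or decoded from the /generate
payload, is laid out as [num_moe_layers, num_tokens, topk] (see the unit
tests in patch 1). A small consumer-side sketch for inspecting that layout
(assumed tooling, not part of the series):

    import numpy as np

    # routing: int array of shape [num_moe_layers, num_tokens, topk], e.g. the
    # decoded /generate payload or extract_for_request() output. Random data
    # stands in for a real capture here.
    routing = np.random.randint(0, 64, size=(4, 10, 8))

    num_experts = 64
    for layer, ids in enumerate(routing):
        # per-layer expert load: how often each expert was selected in this request
        load = np.bincount(ids.reshape(-1), minlength=num_experts)
        print(f"moe layer {layer}: top expert {load.argmax()} hit {load.max()} times")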