diff --git a/.gitignore b/.gitignore index 63408699f..3fb49db8b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ dist .vscode tmp/ requirements-musa.txt +CLAUDE.md diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 26d51af3d..b108db1b4 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -169,6 +170,7 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _init_weights(self, start_layer_index=0): + reset_moe_layer_counter() self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ self.transformer_weight_class( @@ -276,6 +278,15 @@ def _init_prefill_cuda_graph(self): self.prefill_graph.warmup(self) def _init_custom(self): + """Hook for model-specific initialization. Override in subclasses.""" + pass + + def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None: + """Hook called after forward pass completes. Override in subclasses for post-processing.""" + pass + + def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None: + """Hook called after dual microbatch forward pass completes. Override in subclasses.""" pass @torch.no_grad() @@ -284,9 +295,12 @@ def forward(self, model_input: ModelInput): assert model_input.mem_indexes.is_cuda if model_input.is_prefill: - return self._prefill(model_input) + result = self._prefill(model_input) else: - return self._decode(model_input) + result = self._decode(model_input) + + self._post_forward(model_input) + return result def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): infer_state = self.infer_state_class() @@ -679,6 +693,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod dist_group_manager.clear_deepep_buffer() model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event + self._post_forward_dual(model_input0, model_input1) return model_output0, model_output1 @torch.no_grad() @@ -772,6 +787,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_att_state() model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1) + self._post_forward_dual(model_input0, model_input1) return model_output0, model_output1 @final diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py index 7dc5b5fdc..f2557af85 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py @@ -23,6 +23,7 @@ from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair from lightllm.utils.log_utils import init_logger from 
lightllm.common.triton_utils.autotuner import Autotuner +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index logger = init_logger(__name__) @@ -43,7 +44,7 @@ def __init__( quant_cfg=None, ) -> None: super().__init__() - + self.moe_layer_index = get_next_moe_layer_index() self.layer_num = layer_num self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight @@ -115,6 +116,7 @@ def experts( topk_group, num_expert_group, is_prefill, + microbatch_index: int = 0, ): topk_weights, topk_ids = select_experts( hidden_states=input_tensor, @@ -139,6 +141,11 @@ def experts( enable_counter=self.auto_update_redundancy_expert, ) + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 return fused_experts_impl( @@ -162,6 +169,7 @@ def low_latency_dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + microbatch_index: int = 0, ): topk_weights, topk_idx = select_experts( @@ -187,6 +195,11 @@ def low_latency_dispatch( enable_counter=self.auto_update_redundancy_expert, ) + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index) + topk_idx = topk_idx.to(torch.long) num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank() recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch( @@ -204,6 +217,7 @@ def select_experts_and_quant_input( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + microbatch_index: int = 0, ): topk_weights, topk_idx = select_experts( hidden_states=hidden_states, @@ -226,6 +240,12 @@ def select_experts_and_quant_input( expert_counter=self.routed_expert_counter_tensor, enable_counter=self.auto_update_redundancy_expert, ) + + # Capture with explicit layer index + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index) + M, K = hidden_states.shape w1, w1_scale = self.w1 block_size_k = 0 diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 9295fa96a..c60d40814 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -5,6 +5,7 @@ from .base_weight import BaseWeight from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index def create_tp_moe_wegiht_obj( @@ -71,6 +72,7 @@ def __init__( quant_cfg: Quantcfg = None, ) -> None: super().__init__() + self.moe_layer_index = get_next_moe_layer_index() self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method is not None: @@ -101,7 +103,17 @@ def __init__( self.w2 = [None, None] # weight, weight_scale self.lock = threading.Lock() - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, 
num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): from lightllm.common.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -116,6 +128,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t scoring_func=self.scoring_func, ) topk_weights.mul_(self.routed_scaling_factor) + + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( @@ -315,6 +332,7 @@ def __init__( quant_cfg: Quantcfg = None, ) -> None: super().__init__() + self.moe_layer_index = get_next_moe_layer_index() self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe") self.quantized_weight = quant_cfg.quantized_weight if self.quant_method is not None: @@ -356,7 +374,17 @@ def __init__( self.w2 = [None, None, None] # weight, weight_scale, zero_point self.lock = threading.Lock() - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): from lightllm.common.fused_moe.topk_select import select_experts topk_weights, topk_ids = select_experts( @@ -371,6 +399,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t scoring_func=self.scoring_func, ) topk_weights.mul_(self.routed_scaling_factor) + + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py index df72cc620..8397ee74f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py @@ -6,6 +6,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id from lightllm.common.quantization import Quantcfg +from lightllm.common.basemodel.routing_manager import get_routing_capture_manager from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -121,9 +122,23 @@ def router(self, router_logits, top_k): router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_top_value, router_indices - def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group): + def experts( + self, + input_tensor, + router_logits, + top_k, + renormalize, + use_grouped_topk, + topk_group, + num_expert_group, + microbatch_index: int = 0, + ): topk_weights, topk_ids = self.router(router_logits, top_k) + routing_manager = get_routing_capture_manager() + if routing_manager is not None: + routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index) + w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not 
None diff --git a/lightllm/common/basemodel/moe_model_mixin.py b/lightllm/common/basemodel/moe_model_mixin.py new file mode 100644 index 000000000..4d730954d --- /dev/null +++ b/lightllm/common/basemodel/moe_model_mixin.py @@ -0,0 +1,89 @@ +"""Mixin for MoE (Mixture of Experts) models. + +Provides R3 (Rollout Router Replay) routing capture functionality for MoE models. +MoE models that want R3 support should inherit from this mixin and call +`_init_routing_capture()` in their `_init_custom()` method. +""" + +from lightllm.common.basemodel.batch_objs import ModelInput +from lightllm.common.basemodel.routing_manager import ( + create_routing_capture_manager, + get_moe_layer_count, + flush_routing_capture, + flush_routing_capture_dual, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MoeModelMixin: + """Mixin class providing R3 routing capture support for MoE models. + + Usage: + class MyMoeModel(MoeModelMixin, LlamaTpPartModel): + def _init_custom(self): + super()._init_custom() + self._init_routing_capture() # Enable R3 if flag is set + """ + + def _init_routing_capture(self) -> None: + """Initialize R3 routing capture if enabled via --enable_return_routed_experts. + + Should be called in the model's _init_custom() method after weights are loaded. + This method is idempotent - safe to call multiple times. + """ + if not getattr(self.args, "enable_return_routed_experts", False): + return + + # Get MoE layer count from counter (set during _init_weights) + num_moe_layers = get_moe_layer_count() + if num_moe_layers == 0: + logger.warning( + "enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled." + ) + return + + # Get MoE parameters from model config + n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0)) + if n_routed_experts == 0: + logger.warning( + "enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled." + ) + return + + topk = self.config.get("num_experts_per_tok", 1) + num_experts = n_routed_experts + + # Check if overlap mode is enabled + enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False) + + logger.info( + f"Initializing routing capture: num_moe_layers={num_moe_layers}, " + f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}" + ) + + create_routing_capture_manager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=self.max_total_token_num, + kv_cache_size=self.mem_manager.size, + enable_overlap=enable_overlap, + ) + + def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None: + """Hook called after forward pass completes. + + Flushes R3 routing capture data from GPU to CPU buffer. + No-op if R3 is not enabled. + """ + flush_routing_capture(model_input.mem_indexes, microbatch_index) + + def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None: + """Hook called after dual microbatch forward pass completes. + + Flushes R3 routing capture data for both microbatches. + No-op if R3 is not enabled. 
+ """ + flush_routing_capture_dual(model_input0.mem_indexes, model_input1.mem_indexes) diff --git a/lightllm/common/basemodel/routing_manager.py b/lightllm/common/basemodel/routing_manager.py new file mode 100644 index 000000000..c7eecfc57 --- /dev/null +++ b/lightllm/common/basemodel/routing_manager.py @@ -0,0 +1,212 @@ +import torch +import numpy as np +from typing import Optional +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# MoE layer counter for auto-incrementing moe_layer_index +_moe_layer_counter: int = 0 + + +def reset_moe_layer_counter() -> None: + """Reset MoE layer counter. Call before model weight initialization.""" + global _moe_layer_counter + _moe_layer_counter = 0 + + +def get_next_moe_layer_index() -> int: + """Get and increment counter. Called by FusedMoeWeight* during init.""" + global _moe_layer_counter + idx = _moe_layer_counter + _moe_layer_counter += 1 + return idx + + +def get_moe_layer_count() -> int: + """Get total MoE layers after weight init (for routing manager creation).""" + return _moe_layer_counter + + +class RoutingCaptureManager: + """Captures MoE routing decisions for export and analysis. + + Supports: + - Explicit moe_layer_index (no auto-increment) + - Double-buffered GPU buffers for two-batch overlap + - Async GPU->CPU flush with CUDA streams + - MTP via computed layer index offsets + """ + + def __init__( + self, + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, + ): + """Initialize routing capture buffers. + + Args: + num_moe_layers: Total MoE layers (main + draft models for MTP) + topk: Number of experts selected per token + num_experts: Total experts (determines int8 vs int16 dtype) + batch_max_tokens: Max tokens per batch + kv_cache_size: Size of KV cache (for mem_index mapping) + enable_overlap: Enable double-buffering for two-batch overlap + """ + self.num_moe_layers = num_moe_layers + self.topk = topk + self.num_experts = num_experts + self.batch_max_tokens = batch_max_tokens + self.kv_cache_size = kv_cache_size + + # Choose dtype based on number of experts + self.dtype = torch.int8 if num_experts <= 127 else torch.int16 + dtype_bytes = 1 if self.dtype == torch.int8 else 2 + + # Number of GPU buffer slots (2 for overlap mode, 1 otherwise) + self.num_slots = 2 if enable_overlap else 1 + + # GPU buffer: [num_slots, num_moe_layers, batch_max_tokens, topk] + gpu_buffer_size = self.num_slots * num_moe_layers * batch_max_tokens * topk * dtype_bytes + self.gpu_buffer = torch.zeros( + (self.num_slots, num_moe_layers, batch_max_tokens, topk), + dtype=self.dtype, + device="cuda", + ) + + # CPU buffer: [num_moe_layers, kv_cache_size, topk] + cpu_buffer_size = num_moe_layers * kv_cache_size * topk * dtype_bytes + self.cpu_buffer = torch.zeros( + (num_moe_layers, kv_cache_size, topk), + dtype=self.dtype, + device="cpu", + pin_memory=True, + ) + + # Per-slot async flush + self.flush_streams = [torch.cuda.Stream() for _ in range(self.num_slots)] + self.flush_events = [torch.cuda.Event() for _ in range(self.num_slots)] + + dtype_name = "int8" if self.dtype == torch.int8 else "int16" + logger.info( + f"RoutingCaptureManager initialized: {num_moe_layers} MoE layers, topk={topk}, " + f"slots={self.num_slots}, GPU={gpu_buffer_size / 1024 / 1024:.2f}MB, " + f"CPU={cpu_buffer_size / 1024 / 1024:.2f}MB, dtype={dtype_name}" + ) + + def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None: + """Capture 
routing decision for a specific MoE layer. + + Args: + moe_layer_index: Explicit index (0 to num_moe_layers-1) + topk_ids: Shape [num_tokens, topk] + microbatch_index: Current microbatch index for overlap mode (default 0) + """ + assert ( + 0 <= moe_layer_index < self.num_moe_layers + ), f"moe_layer_index {moe_layer_index} out of range [0, {self.num_moe_layers})" + slot = microbatch_index % self.num_slots + num_tokens = topk_ids.shape[0] + self.gpu_buffer[slot, moe_layer_index, :num_tokens, :] = topk_ids.to(self.dtype) + + def flush_to_cpu_async(self, mem_indexes: torch.Tensor, microbatch_index: int) -> None: + """Async flush GPU buffer to CPU buffer at mem_index positions. + + Called by backend after forward pass completes. + + Args: + mem_indexes: Shape [num_tokens] - KV cache slot indices + microbatch_index: Which microbatch slot to flush from + """ + num_tokens = mem_indexes.shape[0] + if num_tokens == 0: + return + + slot = microbatch_index % self.num_slots + stream = self.flush_streams[slot] + event = self.flush_events[slot] + + # Wait for inference on this slot to complete + stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(stream): + cpu_indexes = mem_indexes.cpu() + self.cpu_buffer[:, cpu_indexes, :] = self.gpu_buffer[slot, :, :num_tokens, :].cpu() + event.record() + + def extract_for_request(self, mem_indexes: torch.Tensor) -> np.ndarray: + """Extract routing data for a completed request. + + Args: + mem_indexes: KV cache slots belonging to this request (CPU tensor) + + Returns: + numpy array of shape [num_moe_layers, num_tokens, topk] + """ + # Synchronize all pending flushes + for event in self.flush_events: + event.synchronize() + return self.cpu_buffer[:, mem_indexes, :].numpy() + + +# Global instance +g_routing_capture_manager: Optional[RoutingCaptureManager] = None + + +def create_routing_capture_manager( + num_moe_layers: int, + topk: int, + num_experts: int, + batch_max_tokens: int, + kv_cache_size: int, + enable_overlap: bool = False, +) -> None: + """Initialize the global routing capture manager.""" + global g_routing_capture_manager + if g_routing_capture_manager is not None: + logger.warning("RoutingCaptureManager already exists, replacing") + g_routing_capture_manager = RoutingCaptureManager( + num_moe_layers=num_moe_layers, + topk=topk, + num_experts=num_experts, + batch_max_tokens=batch_max_tokens, + kv_cache_size=kv_cache_size, + enable_overlap=enable_overlap, + ) + + +def get_routing_capture_manager() -> Optional[RoutingCaptureManager]: + """Get the global routing capture manager.""" + return g_routing_capture_manager + + +def flush_routing_capture(mem_indexes: torch.Tensor, microbatch_index: int = 0) -> None: + """Flush routing capture to CPU if manager is active. + + Call after forward pass completes. No-op if R3 capture is not enabled. + + Args: + mem_indexes: KV cache slot indices for the batch + microbatch_index: Microbatch index (0 for single batch, 0/1 for overlap) + """ + if g_routing_capture_manager is not None: + g_routing_capture_manager.flush_to_cpu_async(mem_indexes, microbatch_index) + + +def flush_routing_capture_dual(mem_indexes0: torch.Tensor, mem_indexes1: torch.Tensor) -> None: + """Flush routing capture for dual microbatch overlap mode. + + Call after forward pass completes for both microbatches. + No-op if R3 capture is not enabled. 
+ + Args: + mem_indexes0: KV cache slot indices for microbatch 0 + mem_indexes1: KV cache slot indices for microbatch 1 + """ + if g_routing_capture_manager is not None: + g_routing_capture_manager.flush_to_cpu_async(mem_indexes0, microbatch_index=0) + g_routing_capture_manager.flush_to_cpu_async(mem_indexes1, microbatch_index=1) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 8695f2de8..09661098d 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -325,6 +325,7 @@ def _moe_ffn( use_grouped_topk=self.n_group, topk_group=self.topk_group, num_expert_group=self.n_group, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: @@ -351,6 +352,7 @@ def _moe_ffn_edp( topk_group=self.topk_group, num_expert_group=self.n_group, is_prefill=infer_state.is_prefill, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) if self.n_shared_experts is not None: diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0739a8a8..b96a634d8 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -6,6 +6,7 @@ from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.llama.model import LlamaTpPartModel from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_added_mtp_kv_layer_num from lightllm.distributed.communication_op import dist_group_manager @@ -15,7 +16,7 @@ @ModelRegistry(["deepseek_v2", "deepseek_v3"]) -class Deepseek2TpPartModel(LlamaTpPartModel): +class Deepseek2TpPartModel(MoeModelMixin, LlamaTpPartModel): # weight class transformer_weight_class = Deepseek2TransformerLayerWeight @@ -48,6 +49,7 @@ def _init_some_value(self): def _init_custom(self): self._init_to_get_yarn_rotary() dist_group_manager.new_deepep_group(self.config["n_routed_experts"], self.config["hidden_size"]) + self._init_routing_capture() # R3 routing capture for MoE def _verify_params(self): return super()._verify_params() diff --git a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py index d80eefd16..67545f1da 100644 --- a/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gpt_oss/layer_infer/transformer_layer_infer.py @@ -51,6 +51,7 @@ def _ffn(self, input, infer_state, layer_weight: GptOssTransformerLayerWeight) - use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/gpt_oss/model.py b/lightllm/models/gpt_oss/model.py index dc5f2abdf..5c4b5ef03 100644 --- a/lightllm/models/gpt_oss/model.py +++ b/lightllm/models/gpt_oss/model.py @@ -2,6 +2,7 @@ from lightllm.models.gpt_oss.layer_weights.transformer_layer_weight import GptOssTransformerLayerWeight from lightllm.models.llama.model import LlamaTpPartModel from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from 
lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger @@ -10,7 +11,7 @@ @ModelRegistry("gpt_oss") -class GptOssTpPartModel(LlamaTpPartModel): +class GptOssTpPartModel(MoeModelMixin, LlamaTpPartModel): # weight class transformer_weight_class = GptOssTransformerLayerWeight @@ -25,3 +26,7 @@ def __init__(self, kvargs): assert ( get_env_start_args().llm_decode_att_backend[0] == "fa3" ), "For now GPT-OSS type model only support flashattention-3" + + def _init_custom(self): + super()._init_custom() + self._init_routing_capture() # R3 routing capture for MoE diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc..f5a66a6c3 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -74,14 +74,18 @@ def _init_custom(self): rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: self._init_to_get_rotary() - return - - if "rope_type" in rope_scaling: + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): + """Initialize rotary embeddings based on scaling type.""" if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -96,7 +100,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) diff --git a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py index 44e66cff2..a2968f5ab 100644 --- a/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mixtral/layer_infer/transformer_layer_infer.py @@ -1,9 +1,6 @@ -import os import torch -import torch.nn.functional as F from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.mixtral.layer_infer._custom_ops import fused_topk from lightllm.models.mixtral.layer_weights.transformer_layer_weight import MixtralTransformerLayerWeight @@ -19,25 +16,15 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - router_logits = layer_weight.moe_gate.mm(input.view(-1, self.embed_dim_)) - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=self.num_experts_per_tok, + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, renormalize=self.renormalize, - alloc_tensor_func=self.alloc_tensor, - ) - from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl - - return fused_experts_impl( - hidden_states=hidden_states, - w1=layer_weight.experts.w1[0], - w2=layer_weight.experts.w2[0], - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=False, - 
w1_scale=None, - w2_scale=None, - alloc_tensor_func=self.alloc_tensor, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) + return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e8..4a6b611d5 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -2,6 +2,7 @@ import numpy as np from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer @@ -16,7 +17,7 @@ @ModelRegistry("mixtral") -class MixtralTpPartModel(TpPartBaseModel): +class MixtralTpPartModel(MoeModelMixin, TpPartBaseModel): # weight class pre_and_post_weight_class = LlamaPreAndPostLayerWeight transformer_weight_class = MixtralTransformerLayerWeight @@ -45,6 +46,7 @@ def _verify_params(self): def _init_custom(self): self._init_to_get_rotary() + self._init_routing_capture() # R3 routing capture for MoE return def _init_mem_manager(self): diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index c85c423c2..f3db300f3 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -140,6 +140,7 @@ def _moe_ffn( use_grouped_topk=False, topk_group=None, num_expert_group=None, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) return hidden_states.view(num_tokens, hidden_dim) @@ -160,6 +161,7 @@ def _moe_ffn_edp( topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, + microbatch_index=getattr(infer_state, "microbatch_index", 0), ) ep_output = ep_output.view(token_num, hidden_dim) diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a505127..a6bf6a976 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -4,6 +4,7 @@ from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.common.basemodel.moe_model_mixin import MoeModelMixin from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -12,7 +13,7 @@ @ModelRegistry("qwen3_moe") -class Qwen3MOEModel(Qwen3TpPartModel): +class Qwen3MOEModel(MoeModelMixin, Qwen3TpPartModel): # weight class transformer_weight_class = Qwen3MOETransformerLayerWeight @@ -26,3 +27,4 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + self._init_routing_capture() # R3 routing capture for MoE diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 44cc38822..b39854401 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -608,4 +608,10 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache 
fetch for data parallel inference, disabled by default.""", ) + parser.add_argument( + "--enable_return_routed_experts", + action="store_true", + default=False, + help="Enable returning routed expert indices for MoE models (R3 feature).", + ) return parser diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f5..5abd90815 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -53,6 +53,7 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana prompt_token_ids = None is_first_metadata = True input_usage = None + routed_experts_data = None async for sub_req_id, request_output, metadata, finish_status in results_generator: # when set "--return_all_prompt_logprobs", the first token metadata will contains # prompt_logprobs and prompt_token_ids @@ -78,6 +79,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana if finish_status.is_finished(): finish_reason_dict[sub_req_id] = finish_status + if "routed_experts" in metadata: + routed_experts_data = metadata["routed_experts"] n = sampling_params.n sub_ids = list(final_output_dict.keys())[:n] final_output_list = ["".join(final_output_dict[sub_id]) for sub_id in sub_ids] @@ -102,6 +105,8 @@ async def lightllm_generate(request: Request, httpserver_manager: HttpServerMana ret["prompt_logprobs"] = prompt_logprobs if input_usage is not None: ret["input_usage"] = input_usage + if routed_experts_data is not None: + ret["routed_experts"] = routed_experts_data return Response(content=json.dumps(ret, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 212e037e9..1eafb7d18 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import base64 from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538..0978adff4 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -580,6 +580,8 @@ def handle( shm_req.shm_cur_output_len = self.output_len if finish_status.is_finished(): + # Extract routing data before setting finish_status so HTTP server sees it + g_infer_context._extract_routing_data(req_obj) shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a8a5224eb..496e4fb6c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -109,6 +109,7 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -148,6 +149,7 @@ def decode_normal( model_input, run_reqs = prepare_decode_inputs(decode_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = 
self.model.forward(model_input) + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -186,6 +188,7 @@ def prefill_mtp( model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, b_req_idx=model_input.b_req_idx, @@ -236,6 +239,7 @@ def decode_mtp( with torch.cuda.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) + next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) # verify the next_token_ids b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0] diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index bb0e848e7..d8abf4e49 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -145,6 +145,7 @@ def prefill_normal( run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -188,6 +189,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq run_reqs_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits[:run_reqs_num], @@ -236,6 +238,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + logits0 = model_output0.logits logits1 = model_output1.logits @@ -305,6 +308,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + logits0 = model_output0.logits logits1 = model_output1.logits @@ -359,6 +363,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] req_num = len(run_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) + b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] b_req_idx = model_input.b_req_idx[0:req_num] @@ -421,6 +426,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) + mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: logits = model_output.logits[0:req_num, :] @@ -629,6 +635,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I ) = 
padded_overlap_prepare_prefill_inputs(prefill_reqs) with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) + logits0 = model_output0.logits logits1 = model_output1.logits req_num0, req_num1 = len(run_reqs0), len(run_reqs1) @@ -728,6 +735,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) + logits0 = model_output0.logits logits1 = model_output1.logits run_reqs = run_reqs0 + run_reqs1 diff --git a/scripts/run_e2e_r3_test.sh b/scripts/run_e2e_r3_test.sh new file mode 100755 index 000000000..3100f18a7 --- /dev/null +++ b/scripts/run_e2e_r3_test.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# E2E Test Script for R3 Routing Capture Feature +# +# This script starts a LightLLM server with routing capture enabled, +# runs the client test, and verifies the results. +# +# Requirements: +# - A MoE model (DeepSeek-V2/V3, Qwen-MoE, Mixtral, etc.) +# - At least 1 GPU with sufficient memory +# - LightLLM installed +# +# Usage: +# ./scripts/run_e2e_r3_test.sh /path/to/moe/model [--tp N] + +set -e + +MODEL_DIR="${1:-}" +TP="${2:-1}" +PORT=8765 + +if [ -z "$MODEL_DIR" ]; then + echo "Usage: $0 /path/to/moe/model [--tp N]" + echo "" + echo "Example:" + echo " $0 /models/DeepSeek-V3 --tp 8" + echo " $0 /models/Qwen-MoE-A14B --tp 4" + exit 1 +fi + +if [ ! -d "$MODEL_DIR" ]; then + echo "ERROR: Model directory not found: $MODEL_DIR" + exit 1 +fi + +echo "==========================================" +echo "R3 E2E Test: Routing Capture Feature" +echo "==========================================" +echo "Model: $MODEL_DIR" +echo "TP: $TP" +echo "Port: $PORT" +echo "" + +# Kill any existing server on the port +pkill -f "lightllm.server.api_server.*--port $PORT" 2>/dev/null || true +sleep 2 + +# Start server in background +echo "Starting LightLLM server..." +python -m lightllm.server.api_server \ + --model_dir "$MODEL_DIR" \ + --tp "$TP" \ + --port "$PORT" \ + --enable_return_routed_experts \ + --max_total_token_num 8000 \ + --batch_max_tokens 4000 \ + > /tmp/lightllm_r3_test.log 2>&1 & + +SERVER_PID=$! +echo "Server PID: $SERVER_PID" +echo "Log: /tmp/lightllm_r3_test.log" + +# Wait for server to be ready +echo "Waiting for server to be ready..." +MAX_WAIT=300 +WAITED=0 +while [ $WAITED -lt $MAX_WAIT ]; do + if curl -s "http://localhost:$PORT/health" > /dev/null 2>&1; then + echo "Server is ready!" + break + fi + sleep 5 + WAITED=$((WAITED + 5)) + echo " Waited ${WAITED}s..." +done + +if [ $WAITED -ge $MAX_WAIT ]; then + echo "ERROR: Server failed to start within ${MAX_WAIT}s" + echo "Server log:" + tail -50 /tmp/lightllm_r3_test.log + kill $SERVER_PID 2>/dev/null || true + exit 1 +fi + +# Run client test +echo "" +echo "Running R3 client test..." +echo "==========================================" +python test_r3.py --url "http://localhost:$PORT" +TEST_RESULT=$? + +# Cleanup +echo "" +echo "Stopping server..." +kill $SERVER_PID 2>/dev/null || true +wait $SERVER_PID 2>/dev/null || true + +# Report result +echo "" +echo "==========================================" +if [ $TEST_RESULT -eq 0 ]; then + echo "E2E TEST PASSED!" +else + echo "E2E TEST FAILED!" 
+ echo "Server log (last 30 lines):" + tail -30 /tmp/lightllm_r3_test.log +fi +echo "==========================================" + +exit $TEST_RESULT diff --git a/test_r3.py b/test_r3.py new file mode 100644 index 000000000..14157fab5 --- /dev/null +++ b/test_r3.py @@ -0,0 +1,99 @@ +""" +R3 Client Test: Tests the routing capture export feature. + +This test requires a running LightLLM server with: +- A MoE model (e.g., DeepSeek-V2/V3) +- --enable_return_routed_experts flag + +Usage: + python test_r3.py [--url URL] +""" +import sys +import argparse +import requests +import base64 +import numpy as np + + +def test_routing_export(url: str = "http://localhost:8000"): + """Test the routing export feature.""" + print(f"Testing routing export at {url}") + print("-" * 50) + + try: + response = requests.post( + f"{url}/generate", + json={ + "inputs": "What is the capital of France?", + "parameters": { + "max_new_tokens": 50, + "return_routed_experts": True, + }, + }, + timeout=60, + ) + except requests.exceptions.ConnectionError: + print(f"ERROR: Cannot connect to server at {url}") + print("Make sure the LightLLM server is running with --enable_return_routed_experts") + return False + except requests.exceptions.Timeout: + print("ERROR: Request timed out") + return False + + print(f"Status: {response.status_code}") + + if response.status_code != 200: + print(f"ERROR: Request failed with status {response.status_code}") + print(f"Response: {response.text}") + return False + + res = response.json() + print(f"Generated text: {res.get('generated_text', 'N/A')[:100]}...") + + # Check for routed_experts in response + if "routed_experts" not in res or not res["routed_experts"]: + print("\nWARNING: No routed_experts in response.") + print("This could mean:") + print(" - The model is not a MoE model") + print(" - The server was not started with --enable_return_routed_experts") + print(" - The routing capture manager was not initialized") + return False + + # Decode routed_experts from base64 + routing_info = res["routed_experts"] + shape = routing_info["shape"] + dtype = np.dtype(routing_info["dtype"]) + data = base64.b64decode(routing_info["data"]) + routing_array = np.frombuffer(data, dtype=dtype).reshape(shape) + + print(f"\n{'=' * 50}") + print("ROUTING CAPTURE SUCCESS!") + print(f"{'=' * 50}") + print(f"Shape: {shape} # [num_moe_layers, num_tokens, topk]") + print(f"Dtype: {dtype}") + print(f"Num MoE layers: {shape[0]}") + print(f"Num tokens: {shape[1]}") + print(f"Top-K: {shape[2]}") + + # Show sample of routing data + print(f"\nSample routing (first layer, first 5 tokens):") + num_tokens_to_show = min(5, shape[1]) + for i in range(num_tokens_to_show): + print(f" Token {i}: experts {routing_array[0, i, :].tolist()}") + + # Validate data + if np.all(routing_array == 0): + print("\nWARNING: All routing data is zeros. 
Capture may not be working correctly.") + return False + + print("\nTest PASSED!") + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test R3 routing export feature") + parser.add_argument("--url", default="http://localhost:8000", help="Server URL") + args = parser.parse_args() + + success = test_routing_export(args.url) + sys.exit(0 if success else 1) diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/__init__.py b/unit_tests/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/__init__.py b/unit_tests/common/basemodel/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/unit_tests/common/basemodel/test_routing_capture_manager.py b/unit_tests/common/basemodel/test_routing_capture_manager.py new file mode 100644 index 000000000..8d5a1d0bf --- /dev/null +++ b/unit_tests/common/basemodel/test_routing_capture_manager.py @@ -0,0 +1,132 @@ +import pytest +import torch +import numpy as np + + +def test_moe_layer_counter(): + """Counter increments and resets correctly.""" + from lightllm.common.basemodel.routing_manager import ( + reset_moe_layer_counter, + get_next_moe_layer_index, + get_moe_layer_count, + ) + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + + assert get_next_moe_layer_index() == 0 + assert get_next_moe_layer_index() == 1 + assert get_next_moe_layer_index() == 2 + assert get_moe_layer_count() == 3 + + reset_moe_layer_counter() + assert get_moe_layer_count() == 0 + assert get_next_moe_layer_index() == 0 + + +class TestRoutingCaptureManager: + """Tests for the redesigned RoutingCaptureManager.""" + + def test_capture_explicit_layer_index(self): + """Capture stores data at explicit moe_layer_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=4, + topk=8, + num_experts=64, + batch_max_tokens=128, + kv_cache_size=1024, + enable_overlap=False, + ) + + # Capture at layer 2 (not sequential) + topk_ids = torch.randint(0, 64, (10, 8), device="cuda") + manager.capture(moe_layer_index=2, topk_ids=topk_ids) + + # Verify data is at layer 2, not layer 0 + assert torch.equal(manager.gpu_buffer[0, 2, :10, :], topk_ids.to(manager.dtype)) + + def test_double_buffer_overlap_mode(self): + """Double buffer prevents race condition in overlap mode.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=True, + ) + + # Should have 2 buffer slots + assert manager.num_slots == 2 + assert manager.gpu_buffer.shape[0] == 2 + + # Capture to slot 0 (microbatch_index=0) + ids_0 = torch.ones((5, 4), dtype=torch.int64, device="cuda") + manager.capture(moe_layer_index=0, topk_ids=ids_0, microbatch_index=0) + + # Capture to slot 1 (microbatch_index=1) + ids_1 = torch.ones((5, 4), dtype=torch.int64, device="cuda") * 2 + manager.capture(moe_layer_index=0, topk_ids=ids_1, microbatch_index=1) + + # Both slots have different data + assert manager.gpu_buffer[0, 0, 0, 0].item() == 1 + assert manager.gpu_buffer[1, 0, 0, 0].item() == 2 + + def test_flush_and_extract(self): + """Flush transfers data to CPU, extract retrieves by mem_index.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + manager = 
RoutingCaptureManager( + num_moe_layers=2, + topk=4, + num_experts=32, + batch_max_tokens=64, + kv_cache_size=256, + enable_overlap=False, + ) + + # Capture some data (microbatch_index defaults to 0) + topk_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], device="cuda") + manager.capture(moe_layer_index=0, topk_ids=topk_ids) + manager.capture(moe_layer_index=1, topk_ids=topk_ids + 10) + + # Flush to mem_indexes 10 and 11 + mem_indexes = torch.tensor([10, 11], device="cuda") + manager.flush_to_cpu_async(mem_indexes, microbatch_index=0) + + # Extract + result = manager.extract_for_request(mem_indexes.cpu()) + + assert result.shape == (2, 2, 4) # [layers, tokens, topk] + assert result[0, 0, 0] == 1 + assert result[1, 0, 0] == 11 + + def test_dtype_selection(self): + """Uses int8 for <=127 experts, int16 otherwise.""" + from lightllm.common.basemodel.routing_manager import RoutingCaptureManager + + # Small expert count -> int8 + manager_small = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=64, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_small.dtype == torch.int8 + + # Large expert count -> int16 + manager_large = RoutingCaptureManager( + num_moe_layers=1, + topk=2, + num_experts=256, + batch_max_tokens=32, + kv_cache_size=128, + enable_overlap=False, + ) + assert manager_large.dtype == torch.int16
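
Reviewer note: the pieces above (MoE layer counter, RoutingCaptureManager, MoeModelMixin hooks, per-layer capture calls) only meet across several files, so here is a minimal end-to-end sketch of the R3 capture lifecycle for reference. It is not part of the patch; it uses only the APIs introduced in routing_manager.py, and all sizes (2 MoE layers, topk=4, 8 tokens, 32 experts, 64 KV-cache slots) are made-up example values. With enable_overlap=False there is a single GPU slot, so microbatch_index is always 0; the second slot only exists for the two-microbatch overlap path.

# Sketch: R3 routing-capture lifecycle, mirroring how basemodel._init_weights(),
# MoeModelMixin._init_routing_capture(), the FusedMoeWeight* capture calls and
# MoeModelMixin._post_forward() use the routing_manager module. Example values only.
import torch
from lightllm.common.basemodel.routing_manager import (
    reset_moe_layer_counter,
    get_next_moe_layer_index,
    get_moe_layer_count,
    create_routing_capture_manager,
    get_routing_capture_manager,
)

# 1) Weight init: reset the counter, then each FusedMoeWeight* claims its own index.
reset_moe_layer_counter()
layer_indexes = [get_next_moe_layer_index() for _ in range(2)]  # -> [0, 1]

# 2) _init_custom(): create the global manager once the MoE layer count is known.
create_routing_capture_manager(
    num_moe_layers=get_moe_layer_count(),
    topk=4,
    num_experts=32,
    batch_max_tokens=8,
    kv_cache_size=64,
    enable_overlap=False,
)
manager = get_routing_capture_manager()

# 3) Forward pass: every MoE layer stores its topk_ids at its own layer index.
for moe_layer_index in layer_indexes:
    topk_ids = torch.randint(0, 32, (8, 4), device="cuda")
    manager.capture(moe_layer_index, topk_ids, microbatch_index=0)

# 4) _post_forward(): flush the GPU slot into the per-mem_index CPU buffer.
mem_indexes = torch.arange(8, dtype=torch.int64, device="cuda")
manager.flush_to_cpu_async(mem_indexes, microbatch_index=0)

# 5) Request finish: extract the per-request trace, shape [num_moe_layers, num_tokens, topk].
routing = manager.extract_for_request(mem_indexes.cpu())
assert routing.shape == (2, 8, 4)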