1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
CLAUDE.md
20 changes: 18 additions & 2 deletions lightllm/common/basemodel/basemodel.py
@@ -11,6 +11,7 @@

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.common.basemodel.routing_manager import reset_moe_layer_counter
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.req_manager import ReqManager
@@ -169,6 +170,7 @@ def _init_quant(self):
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")

def _init_weights(self, start_layer_index=0):
reset_moe_layer_counter()
self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
self.trans_layers_weight = [
self.transformer_weight_class(
@@ -276,6 +278,15 @@ def _init_prefill_cuda_graph(self):
self.prefill_graph.warmup(self)

def _init_custom(self):
"""Hook for model-specific initialization. Override in subclasses."""
pass

def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
"""Hook called after forward pass completes. Override in subclasses for post-processing."""
pass

def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
"""Hook called after dual microbatch forward pass completes. Override in subclasses."""
pass

@torch.no_grad()
@@ -284,9 +295,12 @@ def forward(self, model_input: ModelInput):
assert model_input.mem_indexes.is_cuda

if model_input.is_prefill:
return self._prefill(model_input)
result = self._prefill(model_input)
else:
return self._decode(model_input)
result = self._decode(model_input)

self._post_forward(model_input)
return result

def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0):
infer_state = self.infer_state_class()
@@ -679,6 +693,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
dist_group_manager.clear_deepep_buffer()
model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
self._post_forward_dual(model_input0, model_input1)
return model_output0, model_output1

@torch.no_grad()
@@ -772,6 +787,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
infer_state1.init_att_state()

model_output0, model_output1 = self._overlap_tpsp_token_forward(infer_state0, infer_state1=infer_state1)
self._post_forward_dual(model_input0, model_input1)
return model_output0, model_output1

@final
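The `reset_moe_layer_counter()` call added to `_init_weights` pairs with the `get_next_moe_layer_index()` calls in the MoE weight constructors in the files that follow: each MoE weight object picks up a stable, load-order layer index that the capture manager later uses as a buffer slot. `routing_manager.py` itself is not included in this diff, so the sketch below is only an assumed shape for these counter helpers, inferred from their call sites.

```python
# Sketch only: the real lightllm.common.basemodel.routing_manager module is not
# part of this diff; names and behavior are assumed from the call sites above.
_moe_layer_counter = 0


def reset_moe_layer_counter() -> None:
    """Called at the start of _init_weights so indices stay stable across re-initialization."""
    global _moe_layer_counter
    _moe_layer_counter = 0


def get_next_moe_layer_index() -> int:
    """Hand out consecutive indices as MoE weight objects are constructed."""
    global _moe_layer_counter
    index = _moe_layer_counter
    _moe_layer_counter += 1
    return index


def get_moe_layer_count() -> int:
    """Number of MoE layers registered since the last reset."""
    return _moe_layer_counter
```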
@@ -23,6 +23,7 @@
from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
from lightllm.utils.log_utils import init_logger
from lightllm.common.triton_utils.autotuner import Autotuner
from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index


logger = init_logger(__name__)
@@ -43,7 +44,7 @@ def __init__(
quant_cfg=None,
) -> None:
super().__init__()

self.moe_layer_index = get_next_moe_layer_index()
self.layer_num = layer_num
self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
self.quantized_weight = quant_cfg.quantized_weight
@@ -115,6 +116,7 @@ def experts(
topk_group,
num_expert_group,
is_prefill,
microbatch_index: int = 0,
):
topk_weights, topk_ids = select_experts(
hidden_states=input_tensor,
@@ -139,6 +141,11 @@
enable_counter=self.auto_update_redundancy_expert,
)

# Capture with explicit layer index
routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

w1, w1_scale = self.w1
w2, w2_scale = self.w2
return fused_experts_impl(
@@ -162,6 +169,7 @@ def low_latency_dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
microbatch_index: int = 0,
):

topk_weights, topk_idx = select_experts(
@@ -187,6 +195,11 @@
enable_counter=self.auto_update_redundancy_expert,
)

# Capture with explicit layer index
routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index)

topk_idx = topk_idx.to(torch.long)
num_max_dispatch_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank()
recv_x, masked_m, handle, event, hook = dist_group_manager.ep_buffer.low_latency_dispatch(
@@ -204,6 +217,7 @@ def select_experts_and_quant_input(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
microbatch_index: int = 0,
):
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
@@ -226,6 +240,12 @@
expert_counter=self.routed_expert_counter_tensor,
enable_counter=self.auto_update_redundancy_expert,
)

# Capture with explicit layer index
routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_idx, microbatch_index)

M, K = hidden_states.shape
w1, w1_scale = self.w1
block_size_k = 0
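Each routing path in the expert-parallel weight class above ends with the same guard-and-capture pattern: fetch the process-wide manager, and if R3 capture is enabled, record the selected `topk_ids` under this weight's `moe_layer_index` for the current microbatch. Since the manager itself is not shown in this diff, the following is a minimal sketch of what `capture()` and `get_routing_capture_manager()` might look like, assuming one pre-allocated buffer slot per (microbatch, MoE layer); the real `RoutingCaptureManager` may differ.

```python
# Hypothetical sketch of the capture manager behind the guards above; the real
# implementation in routing_manager.py is not part of this diff.
from typing import List, Optional

import torch


class _RoutingCaptureManagerSketch:
    def __init__(self, num_moe_layers: int, topk: int, batch_max_tokens: int, enable_overlap: bool = False):
        num_microbatches = 2 if enable_overlap else 1
        # One slot per (microbatch, MoE layer), sized for the largest forward batch.
        # The real buffer most likely lives on the GPU next to topk_ids.
        self.buffer = torch.zeros((num_microbatches, num_moe_layers, batch_max_tokens, topk), dtype=torch.int32)
        self.num_tokens: List[int] = [0] * num_microbatches

    def capture(self, moe_layer_index: int, topk_ids: torch.Tensor, microbatch_index: int = 0) -> None:
        # topk_ids: [num_tokens, topk] expert ids produced by select_experts().
        num_tokens = topk_ids.shape[0]
        slot = self.buffer[microbatch_index, moe_layer_index]
        slot[:num_tokens].copy_(topk_ids.to(torch.int32))
        self.num_tokens[microbatch_index] = num_tokens


_routing_capture_manager: Optional[_RoutingCaptureManagerSketch] = None


def get_routing_capture_manager() -> Optional[_RoutingCaptureManagerSketch]:
    # Returns None when --enable_return_routed_experts is off, which is why every
    # capture site in this PR checks `if routing_manager is not None`.
    return _routing_capture_manager
```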
@@ -5,6 +5,7 @@
from .base_weight import BaseWeight
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.routing_manager import get_routing_capture_manager, get_next_moe_layer_index


def create_tp_moe_wegiht_obj(
@@ -71,6 +72,7 @@ def __init__(
quant_cfg: Quantcfg = None,
) -> None:
super().__init__()
self.moe_layer_index = get_next_moe_layer_index()
self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
self.quantized_weight = quant_cfg.quantized_weight
if self.quant_method is not None:
@@ -101,7 +103,17 @@ def __init__(
self.w2 = [None, None] # weight, weight_scale
self.lock = threading.Lock()

def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
def experts(
self,
input_tensor,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
microbatch_index: int = 0,
):
from lightllm.common.fused_moe.topk_select import select_experts

topk_weights, topk_ids = select_experts(
@@ -116,6 +128,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
scoring_func=self.scoring_func,
)
topk_weights.mul_(self.routed_scaling_factor)

routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
@@ -315,6 +332,7 @@ def __init__(
quant_cfg: Quantcfg = None,
) -> None:
super().__init__()
self.moe_layer_index = get_next_moe_layer_index()
self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
self.quantized_weight = quant_cfg.quantized_weight
if self.quant_method is not None:
@@ -356,7 +374,17 @@ def __init__(
self.w2 = [None, None, None] # weight, weight_scale, zero_point
self.lock = threading.Lock()

def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
def experts(
self,
input_tensor,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
microbatch_index: int = 0,
):
from lightllm.common.fused_moe.topk_select import select_experts

topk_weights, topk_ids = select_experts(
@@ -371,6 +399,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
scoring_func=self.scoring_func,
)
topk_weights.mul_(self.routed_scaling_factor)

routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
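In both tensor-parallel weight classes above, the `experts()` signature gains a trailing `microbatch_index: int = 0` parameter, so existing call sites keep working unchanged while overlap-mode callers can direct captures into the second slot. A hypothetical call site (call sites are not part of this diff; the attribute and argument values outside the `experts()` signature are assumptions):

```python
# Hypothetical caller in an MoE FFN layer; not part of this diff.
moe_output = self.experts_weight.experts(
    input_tensor=hidden_states,
    router_logits=router_logits,
    top_k=self.top_k,
    renormalize=True,
    use_grouped_topk=False,
    topk_group=None,
    num_expert_group=None,
    microbatch_index=microbatch_index,  # defaults to 0 when the caller is not overlap-aware
)
```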
@@ -6,6 +6,7 @@
from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe_weight_tp import FusedMoeWeightTP
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.routing_manager import get_routing_capture_manager
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
@@ -121,9 +122,23 @@ def router(self, router_logits, top_k):
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
return router_top_value, router_indices

def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
def experts(
self,
input_tensor,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
microbatch_index: int = 0,
):
topk_weights, topk_ids = self.router(router_logits, top_k)

routing_manager = get_routing_capture_manager()
if routing_manager is not None:
routing_manager.capture(self.moe_layer_index, topk_ids, microbatch_index)

w1, w1_scale = self.w1
w2, w2_scale = self.w2
use_fp8_w8a8 = self.quant_method is not None
89 changes: 89 additions & 0 deletions lightllm/common/basemodel/moe_model_mixin.py
@@ -0,0 +1,89 @@
"""Mixin for MoE (Mixture of Experts) models.

Provides R3 (Rollout Router Replay) routing capture functionality for MoE models.
MoE models that want R3 support should inherit from this mixin and call
`_init_routing_capture()` in their `_init_custom()` method.
"""

from lightllm.common.basemodel.batch_objs import ModelInput
from lightllm.common.basemodel.routing_manager import (
create_routing_capture_manager,
get_moe_layer_count,
flush_routing_capture,
flush_routing_capture_dual,
)
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class MoeModelMixin:
"""Mixin class providing R3 routing capture support for MoE models.

Usage:
class MyMoeModel(MoeModelMixin, LlamaTpPartModel):
def _init_custom(self):
super()._init_custom()
self._init_routing_capture() # Enable R3 if flag is set
"""

def _init_routing_capture(self) -> None:
"""Initialize R3 routing capture if enabled via --enable_return_routed_experts.

Should be called in the model's _init_custom() method after weights are loaded.
This method is idempotent - safe to call multiple times.
"""
if not getattr(self.args, "enable_return_routed_experts", False):
return

# Get MoE layer count from counter (set during _init_weights)
num_moe_layers = get_moe_layer_count()
if num_moe_layers == 0:
logger.warning(
"enable_return_routed_experts is set but no MoE layers found. " "Routing capture will not be enabled."
)
return

# Get MoE parameters from model config
n_routed_experts = self.config.get("n_routed_experts", self.config.get("num_experts", 0))
if n_routed_experts == 0:
logger.warning(
"enable_return_routed_experts is set but n_routed_experts=0. " "Routing capture will not be enabled."
)
return

topk = self.config.get("num_experts_per_tok", 1)
num_experts = n_routed_experts

# Check if overlap mode is enabled
enable_overlap = getattr(self.args, "enable_decode_microbatch_overlap", False)

logger.info(
f"Initializing routing capture: num_moe_layers={num_moe_layers}, "
f"topk={topk}, num_experts={num_experts}, enable_overlap={enable_overlap}"
)

create_routing_capture_manager(
num_moe_layers=num_moe_layers,
topk=topk,
num_experts=num_experts,
batch_max_tokens=self.max_total_token_num,
kv_cache_size=self.mem_manager.size,
enable_overlap=enable_overlap,
)

def _post_forward(self, model_input: ModelInput, microbatch_index: int = 0) -> None:
"""Hook called after forward pass completes.

Flushes R3 routing capture data from GPU to CPU buffer.
No-op if R3 is not enabled.
"""
flush_routing_capture(model_input.mem_indexes, microbatch_index)

def _post_forward_dual(self, model_input0: ModelInput, model_input1: ModelInput) -> None:
"""Hook called after dual microbatch forward pass completes.

Flushes R3 routing capture data for both microbatches.
No-op if R3 is not enabled.
"""
flush_routing_capture_dual(model_input0.mem_indexes, model_input1.mem_indexes)
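The mixin delegates the actual transfer of captured ids to `flush_routing_capture` and `flush_routing_capture_dual`, which also live in the unshown `routing_manager.py`. Given that the manager is created with `kv_cache_size=self.mem_manager.size` and that the flush receives `model_input.mem_indexes`, one plausible reading is that per-layer expert ids are scattered into a CPU-side store keyed by kv-cache slot, so they can later be returned per request. A sketch under exactly those assumptions:

```python
# Hypothetical sketch of the flush path; the real flush_routing_capture in
# routing_manager.py is not part of this diff and may differ.
import torch


def flush_routing_capture_sketch(
    manager, cpu_store: torch.Tensor, mem_indexes: torch.Tensor, microbatch_index: int = 0
) -> None:
    # cpu_store: [kv_cache_size, num_moe_layers, topk] CPU tensor allocated at startup.
    # mem_indexes: kv-cache slot assigned to each token of this forward pass.
    if manager is None:  # R3 not enabled, so the hook is a no-op as the docstrings state
        return
    num_tokens = mem_indexes.shape[0]
    captured = manager.buffer[microbatch_index, :, :num_tokens]  # [layers, tokens, topk]
    cpu_store[mem_indexes.cpu().long()] = captured.permute(1, 0, 2).cpu()  # scatter by kv-cache slot
```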