diff --git a/.gitignore b/.gitignore index 5763f502..9775aa96 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,4 @@ checkpoints/ # macOS .DS_Store .vscode +docs/superpowers diff --git a/docs/models/qwen3_5_moe.md b/docs/models/qwen3_5_moe.md new file mode 100644 index 00000000..8d1d945b --- /dev/null +++ b/docs/models/qwen3_5_moe.md @@ -0,0 +1,97 @@ +# Qwen3.5-MoE Training + +## Overview + +Qwen3.5-MoE (`Qwen/Qwen3.6-35B-A3B`) is a **multimodal** Mixture-of-Experts model +with a vision tower plus a hybrid-attention MoE language model. Each decoder +layer is either a **linear-attention** layer (gated delta net) or a **full +softmax-attention** layer, selected per layer via +`config.text_config.layer_types[i]`. The MoE block contains a +**shared_expert** alongside the routed experts. + +The top-level multimodal class is `Qwen3_5MoeForConditionalGeneration` +(`model_type = "qwen3_5_moe"`). + +## Supported Features + +| Feature | Support | +|---------|---------| +| **FSDP2** | ✅ | +| **USP / Sequence Parallel** | ❌ (linear-attention path is not SP-safe) | +| **Muon Optimizer** | ✅ | +| **Liger Kernel** | ✅ | +| **Packing** | ✅ (rmpad) | +| **NSA** | ❌ | +| **Expert Parallelism (EP)** | ✅ | + +**Highlights**: Hybrid attention (linear / full), `shared_expert` + routed +experts, Expert Parallelism via the custom `Qwen3_5MoeExperts` `ParallelStyle`. + +## Quick Start + +See the example configuration and run script: +- **Example Config**: [examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml](../../examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml) +- **Run Script**: [examples/qwen3_5_moe/run.sh](../../examples/qwen3_5_moe/run.sh) + +Verified end-to-end with `cicd/run_traincicd.sh --model-name qwen3_5_moe --gpu-count 4`. + +## Key Configuration + +```yaml +model_config: + load_from_pretrained_path: "Qwen/Qwen3.6-35B-A3B" + # CRITICAL: Qwen3_5MoeConfig is registered in both causal_lm and + # image_text_to_text auto-mappings. 
Without this line we'd silently load the + # text-only Qwen3_5MoeForCausalLM instead of the multimodal + # Qwen3_5MoeForConditionalGeneration. + model_general_type: image_text_to_text + attn_implementation: flash_attention_2 + monkey_patch_kwargs: + # Two patches registered separately for qwen3_5_moe; runner applies them + # in order. "rmpad" accepts no kwargs; the listed kwargs go to "liger". + patch_type: ["liger", "rmpad"] + fused_linear_cross_entropy: true + rms_norm: true + swiglu: true + +trainer_args: + use_liger_kernel: true + use_rmpad: true + fsdp2: true + fsdp_config: + transformer_layer_cls_to_wrap: ["Qwen3_5MoeDecoderLayer"] + sp_ulysses_degree: 1 # SP is not supported + ep_degree: 8 # Expert Parallelism degree +``` + +## Expert Parallelism + +Expert Parallelism (EP) distributes the routed MoE experts across GPUs. +Configure `ep_degree` to match your GPU count (e.g., 2, 4, 8). The FSDP wrap +branches on `decoder_layer.layer_type` (`linear_attn` vs `self_attn`) so that +the gated-delta-net and softmax-attention layers each get the right sharding +plan, while the experts are sharded along the expert dimension via the +`Qwen3_5MoeExperts` `ParallelStyle`. + +## Merging EP Checkpoints + +FSDP2 + EP checkpoints store expert weights as **multi-axis DTensors** with +placements like `(Shard(dim=1), Shard(dim=0))` on a 2D mesh +`(dp_shard_mod_ep, ep)`. The checkpoint merger consolidates these correctly +as of this branch. 
+ +Merge a checkpoint into a single HF-loadable directory with: + +```bash +python -m lmms_engine.merger \ + --checkpoint_path ./output/qwen3_5_moe_a3b_ep8/checkpoint-1000 \ + --output_path ./output/qwen3_5_moe_a3b_ep8/merged-1000 \ + --model_general_type image_text_to_text +``` + +`--model_general_type image_text_to_text` is **required** for the same reason +as at train time: without it the merger instantiates `Qwen3_5MoeForCausalLM` +(text-only) from the saved config and crashes with +`'Qwen3_5MoeConfig' has no attribute 'vocab_size'` (the vocab lives on +`config.text_config`, which the multimodal wrapper knows about but the +text-only causal-LM does not). diff --git a/examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml b/examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml new file mode 100644 index 00000000..46944a96 --- /dev/null +++ b/examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml @@ -0,0 +1,89 @@ +# Unified LMMs Engine Training Configuration for Qwen3.5-MoE (Qwen3.6-35B-A3B) +# +# Multimodal MoE model with hybrid attention (linear / full per layer) and +# shared_expert + routed experts. Expert Parallelism (EP) is supported; sequence +# parallelism (Ulysses) is NOT supported on this model (the gated-delta linear +# attention path is not SP-safe). +# +# For smaller boxes, `ep_degree=4` works on 4 GPUs (verified by cicd +# `cicd/run_traincicd.sh --model-name qwen3_5_moe --gpu-count 4`). 
+ + +trainer_type: fsdp2_trainer + +# Dataset configuration - inline dataset definitions +dataset_config: + dataset_type: vision_iterable + dataset_format: yaml + + datasets: + - path: data/lmms_engine_test/text_example/open_thoughts_5k_parquet + data_folder: "" + data_type: parquet + + # Processor configuration - qwen3_5_moe uses the qwen3_vl processor + processor_config: + processor_name: "Qwen/Qwen3.6-35B-A3B" + processor_type: "qwen3_vl" + + packing: false + packing_strategy: first_fit + packing_length: 10240 + video_backend: qwen_vl_utils + filter_overlong: true + +# Model configuration +model_config: + load_from_pretrained_path: "Qwen/Qwen3.6-35B-A3B" + # Qwen3_5MoeConfig is registered in both causal_lm and image_text_to_text + # auto-mappings; pin to image_text_to_text so we get the multimodal + # Qwen3_5MoeForConditionalGeneration wrapper (vision tower + LM), not the + # text-only Qwen3_5MoeForCausalLM. + model_general_type: image_text_to_text + attn_implementation: "flash_attention_2" + # Two independent patches registered under qwen3_5_moe: "liger" and "rmpad". + # The trainer runner applies them in order ["liger", "rmpad"]. + # Only liger accepts kwargs; rmpad takes none. 
+ monkey_patch_kwargs: + patch_type: ["liger", "rmpad"] + fused_linear_cross_entropy: true + rms_norm: true + swiglu: true + +# Training arguments, mostly compatible with HuggingFace Trainer +trainer_args: + per_device_train_batch_size: 1 + learning_rate: 1.0e-06 + weight_decay: 0.0 + gradient_accumulation_steps: 1 + gradient_checkpointing: true + max_steps: 500 + num_train_epochs: 1 + save_steps: 100 + save_total_limit: 1 + report_to: "none" + output_dir: "./output/qwen3_5_moe_a3b_ep8" + warmup_ratio: 0.0 + warmup_steps: 100 + run_name: "qwen3_5_moe_a3b_ep8" + eval_strategy: "no" + logging_steps: 1 + group_by_length: false + dataloader_num_workers: 0 + bf16: true + lr_scheduler_type: "constant" + use_liger_kernel: true + use_rmpad: true + fsdp2: true + fsdp_config: + transformer_layer_cls_to_wrap: ["Qwen3_5MoeDecoderLayer"] + reshard_after_forward: false + # Sequence parallelism is not supported for qwen3_5_moe (linear-attention + # path is not SP-safe). Keep sp_ulysses_degree=1. + sp_ulysses_degree: 1 + # Expert Parallelism degree. 8 for an 8-GPU node; use 4 on a 4-GPU box. 
def create_model_from_config(model_type, config, model_general_type: str | None = None):
    """Resolve the model class and instantiate the config for ``model_type``.

    Args:
        model_type: HF ``model_type`` string (e.g. ``"qwen3_5_moe"``); must be
            a key of ``CONFIG_MAPPING``.
        config: dict of keyword arguments forwarded to the config class.
        model_general_type: Optional explicit choice; must be a key of
            ``AUTO_REGISTER_MODEL_MAPPING``. Needed when one config class is
            registered under several AutoModel mappings (e.g.
            ``Qwen3_5MoeConfig`` is in both ``causal_lm`` and
            ``image_text_to_text``), because auto-detection returns the first
            match.

    Returns:
        ``(model_class, m_config)``: the selected model class and the
        instantiated config object.

    Raises:
        ValueError: if ``model_type`` is unknown, ``model_general_type`` is not
            a registered key, or the config matches no AutoModel mapping.
    """
    from transformers.models.auto.configuration_auto import CONFIG_MAPPING

    if model_type not in CONFIG_MAPPING:
        raise ValueError(f"Model type '{model_type}' is not found in CONFIG_MAPPING.")
    m_config = CONFIG_MAPPING[model_type](**config)

    # An explicit override bypasses auto-detection entirely.
    if model_general_type is not None:
        if model_general_type not in AUTO_REGISTER_MODEL_MAPPING:
            raise ValueError(
                f"Unknown model_general_type={model_general_type!r}; "
                f"choose one of {list(AUTO_REGISTER_MODEL_MAPPING)}"
            )
        return AUTO_REGISTER_MODEL_MAPPING[model_general_type], m_config

    # Auto-detection: first match wins, so the order below is the priority.
    detection_order = (
        AutoModelForCausalLM,
        AutoModelForImageTextToText,
        AutoModelForMaskedLM,
        AutoModel,
    )
    config_type = type(m_config)
    for model_class in detection_order:
        if config_type in model_class._model_mapping.keys():
            return model_class, m_config

    raise ValueError(f"Model type '{model_type}' is not in any AutoModel mapping.")
) + parser.add_argument( + "--model_general_type", + type=str, + default=None, + choices=["causal_lm", "masked_lm", "image_text_to_text", "general"], + help=( + "Override AutoModel class used to instantiate the merged model. " + "Needed when the same config is registered under multiple AutoModel " + "mappings (e.g. Qwen3_5MoeConfig is in both causal_lm and " + "image_text_to_text). If unset, falls back to auto-detection." + ), + ) + return parser.parse_args() @@ -53,7 +66,11 @@ def main() -> None: print(f"Merging {args.checkpoint_type} checkpoint from {checkpoint_path}") merger = FSDP2Merger(checkpoint_type=args.checkpoint_type) - result_path = merger.merge(checkpoint_path, output_path=output_path) + result_path = merger.merge( + checkpoint_path, + output_path=output_path, + model_general_type=args.model_general_type, + ) print(f"Merged checkpoint saved to: {result_path}") diff --git a/src/lmms_engine/merger/fsdp2.py b/src/lmms_engine/merger/fsdp2.py index cdf400b6..2e6042cc 100644 --- a/src/lmms_engine/merger/fsdp2.py +++ b/src/lmms_engine/merger/fsdp2.py @@ -86,10 +86,19 @@ def process_one_shard(rank: int, model_state_dict_lst: list) -> dict: def consolidate(self, shard_state_dicts: list[dict]) -> dict: """Consolidate sharded FSDP2 state dicts into a single full state dict. - Uses each tensor's ``DTensor.placements`` to decide whether shards are - sharded (concatenate along the sharding dim) or replicated (take one - copy). Falls back to value equality for plain tensors that don't carry - placement metadata. + Uses each tensor's ``DTensor.placements`` and ``device_mesh`` to decide + how to merge shards: + + * Single Shard / Replicate placement (plain FSDP2): concatenate along + the sharding dim or take one copy. + * Multi-placement DTensor (FSDP2 + Expert Parallel): each placement is + one axis of the device mesh. 
We materialize the 2D mesh layout + implied by ``mesh.shape`` + the global rank ordering (``stride``), + then cat along inner mesh axes first and outer mesh axes last so the + local-tensor index arithmetic matches DTensor semantics. + + For plain torch.Tensor entries (e.g. ``inv_freq`` buffers) we just + take the first shard — they're replicated across ranks. Args: shard_state_dicts: List of state dicts from each shard @@ -99,29 +108,47 @@ def consolidate(self, shard_state_dicts: list[dict]) -> dict: """ state_dict: dict = {} - # Gather all tensor shards by key, remembering placements / global shape - # for proper consolidation. We can't use byte-level equality because a - # parameter that happens to be uniform after init (e.g. RMSNorm.weight - # initialized to 1.0) is genuinely sharded but every shard has the - # same values, so equality would silently drop 7/8 of its dim. + # Gather all tensor shards by key, remembering placements / mesh shape + # / mesh stride / global shape for proper consolidation. We can't use + # byte-level equality because a parameter that happens to be uniform + # after init (e.g. RMSNorm.weight initialized to 1.0) is genuinely + # sharded but every shard has the same values, so equality would + # silently drop 7/8 of its dim. placements_per_key: dict = {} + mesh_shape_per_key: dict = {} + mesh_stride_per_key: dict = {} global_shape_per_key: dict = {} for key in set(shard_state_dicts[0].keys()): shards: list[torch.Tensor] = [] placements = None + mesh_shape = None + mesh_stride = None global_shape = None for model_state_shard in shard_state_dicts: tensor = model_state_shard.pop(key) if hasattr(tensor, "_local_tensor"): if placements is None: placements = tensor.placements + mesh = tensor.device_mesh + mesh_shape = tuple(mesh.shape) global_shape = tuple(tensor.shape) + # mesh.mesh accessor requires an initialized PG (we + # don't have one when merging offline). 
DeviceMesh + # uses C-order (row-major) rank layout by default: + # last axis stride=1, then each preceding axis's + # stride = product of all following axes' sizes. + mesh_stride = tuple( + int(torch.tensor(mesh_shape[i + 1 :]).prod().item()) if i < len(mesh_shape) - 1 else 1 + for i in range(len(mesh_shape)) + ) shards.append(tensor._local_tensor.bfloat16()) else: # Plain tensor (e.g. inv_freq buffer): replicated implicitly. shards.append(tensor.bfloat16()) state_dict[key] = shards placements_per_key[key] = placements + mesh_shape_per_key[key] = mesh_shape + mesh_stride_per_key[key] = mesh_stride global_shape_per_key[key] = global_shape # Merge tensors using placements when available, otherwise fall back to @@ -135,25 +162,119 @@ def consolidate(self, shard_state_dicts: list[dict]) -> dict: state_dict[key] = shards[0] continue - # Single placement (FSDP1D): handle Shard / Replicate / Partial. if len(placements) == 1: + # Single placement (FSDP1D): handle Shard / Replicate / Partial. p = placements[0] if p.is_replicate(): state_dict[key] = shards[0] elif p.is_shard(): - state_dict[key] = torch.cat(shards, dim=p.dim) + # When the mesh axis has size 1 (e.g. dp_shard_mod_ep=1 + # for non-expert params in an EP-only config), each rank + # holds a full copy even though placement says Shard(dim). + # Detect this by comparing the global shape on the + # sharded dim with one local shard's: equal = no actual + # split, treat as replicate. + if shards[0].shape[p.dim] == global_shape_per_key[key][p.dim]: + state_dict[key] = shards[0] + else: + state_dict[key] = torch.cat(shards, dim=p.dim) else: raise NotImplementedError( f"Unsupported placement {p} for key '{key}' (only Shard / Replicate are handled)." ) else: - # Multi-axis (e.g. HSDP / 2D mesh): not currently produced by - # the trainer's FSDP2 setup. Fail loudly rather than silently - # mis-consolidating. 
- raise NotImplementedError(f"Multi-placement DTensor not supported: key='{key}' placements={placements}") + # Multi-axis DTensor (FSDP2 + EP). Re-stitch by walking mesh + # axes from inner-most (largest stride dim of the global rank + # index, but mesh.mesh.stride() gives that ordering directly) + # outward — at each step we group consecutive shards by the + # current axis and cat along that placement's dim. + state_dict[key] = self._consolidate_multi_axis( + shards, + placements, + mesh_shape_per_key[key], + mesh_stride_per_key[key], + key, + ) return state_dict + @staticmethod + def _consolidate_multi_axis( + shards: list[torch.Tensor], + placements: tuple, + mesh_shape: tuple[int, ...], + mesh_stride: tuple[int, ...], + key: str, + ) -> torch.Tensor: + """Reduce a list of shards from a multi-axis DTensor into one tensor. + + Strategy: arrange shards into a nested list shaped like ``mesh_shape`` + using ``mesh_stride`` to map global rank -> multi-index, then + recursively cat from inner-most axis outward. Each axis's Placement + tells us which tensor dim to cat along (Shard.dim) or whether it's a + no-op (Replicate / Partial). + """ + from torch.distributed.tensor.placement_types import Replicate, Shard + + assert len(placements) == len(mesh_shape), ( + f"placements vs mesh_shape rank mismatch for key '{key}': " + f"placements={placements} mesh_shape={mesh_shape}" + ) + world_size = 1 + for s in mesh_shape: + world_size *= s + assert len(shards) == world_size, ( + f"shard count mismatch for key '{key}': got {len(shards)} shards, " f"mesh implies {world_size}" + ) + + def rank_to_index(rank: int) -> tuple[int, ...]: + """Map a flat global rank to a multi-index on the device mesh. + + mesh_stride[i] is the stride for axis i: rank //= stride[i], then + % shape[i] gives that axis's coordinate. 
+ """ + return tuple((rank // mesh_stride[i]) % mesh_shape[i] for i in range(len(mesh_shape))) + + # Build a nested-list grid indexed by mesh coords + def make_grid(shape): + if not shape: + return None + head, *tail = shape + return [make_grid(tail) for _ in range(head)] + + grid = make_grid(mesh_shape) + + def grid_set(grid, idx, val): + for i in idx[:-1]: + grid = grid[i] + grid[idx[-1]] = val + + def grid_get(grid, idx): + for i in idx: + grid = grid[i] + return grid + + for rank, shard in enumerate(shards): + idx = rank_to_index(rank) + grid_set(grid, idx, shard) + + # Fold axes from inner-most (last) to outer-most (first). At each + # axis, we cat together the inner sublists per placement[i]. + def fold(subgrid, axis: int) -> torch.Tensor: + if axis == len(mesh_shape) - 1: + # Leaf: subgrid is a list of tensors, one per coord on this axis. + tensors = subgrid + else: + tensors = [fold(child, axis + 1) for child in subgrid] + p = placements[axis] + if isinstance(p, Replicate): + return tensors[0] + if isinstance(p, Shard): + return torch.cat(tensors, dim=p.dim) + raise NotImplementedError(f"Unsupported placement {p} on mesh axis {axis} for key '{key}'.") + + return fold(grid, axis=0) + def _resolve_checkpoint_path(self, path: Path) -> Path: """Resolve checkpoint path, handling parent directories with multiple checkpoints. @@ -218,6 +339,7 @@ def merge( output_path: Path | None = None, model_cls: type | None = None, config: object | None = None, + model_general_type: str | None = None, ) -> Path: """Merge FSDP2 sharded checkpoint into a single consolidated checkpoint. @@ -227,6 +349,10 @@ def merge( output_path: Where to save merged checkpoint. If None, saves to checkpoint_path directly model_cls: Model class to instantiate. If None, infers from checkpoint_path config: Model config. If None, loads from checkpoint_path + model_general_type: Override AutoModel class (causal_lm / + image_text_to_text / masked_lm / general). 
Forwarded to + ``create_model_from_pretrained`` for configs registered under + multiple AutoModel mappings. Returns: Path to the merged checkpoint directory @@ -248,7 +374,7 @@ def merge( # Infer model class and config if not provided if model_cls is None: - model_cls = create_model_from_pretrained(checkpoint_path) + model_cls = create_model_from_pretrained(checkpoint_path, model_general_type=model_general_type) if config is None: config = AutoConfig.from_pretrained(checkpoint_path) diff --git a/src/lmms_engine/models/__init__.py b/src/lmms_engine/models/__init__.py index 1d296b3e..36691772 100644 --- a/src/lmms_engine/models/__init__.py +++ b/src/lmms_engine/models/__init__.py @@ -19,6 +19,7 @@ from .qwen2_audio import apply_liger_kernel_to_qwen2_audio from .qwen3 import apply_liger_kernel_to_qwen3 from .qwen3_5 import apply_liger_kernel_to_qwen3_5 +from .qwen3_5_moe import apply_liger_kernel_to_qwen3_5_moe from .qwen3_moe import apply_liger_kernel_to_qwen3_moe from .qwen3_omni_moe import ( Qwen3OmniMoeThinkerConfig, @@ -52,6 +53,7 @@ "apply_liger_kernel_to_qwen2_5_vl", "apply_liger_kernel_to_qwen2_audio", "apply_liger_kernel_to_qwen3_5", + "apply_liger_kernel_to_qwen3_5_moe", "apply_liger_kernel_to_qwen3_vl", "apply_liger_kernel_to_qwen3_vl_moe", "apply_liger_kernel_to_qwen3_moe", diff --git a/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py b/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py index d4d9fa08..e32c3af0 100644 --- a/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py +++ b/src/lmms_engine/models/qwen3_5/qwen3_5_ops.py @@ -20,6 +20,12 @@ apply_mask_to_padding_states, apply_rotary_pos_emb, ) +from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeAttention, + Qwen3_5MoeGatedDeltaNet, + Qwen3_5MoeModel, + Qwen3_5MoeTextModel, +) from transformers.utils import is_flash_attn_2_available, logging from ..common_ops.rope import qwen3_vl_get_rope_index @@ -106,7 +112,7 @@ def _seq_idx_from_cu_seqlens(cu_seqlens: torch.Tensor, total_tokens: int) -> 
tor def linear_attn_forward( - self: Qwen3_5GatedDeltaNet, + self: Union[Qwen3_5GatedDeltaNet, Qwen3_5MoeGatedDeltaNet], hidden_states: torch.Tensor, cache_params: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, @@ -206,7 +212,7 @@ def linear_attn_forward( def text_model_forward( - self: Qwen3_5TextModel, + self: Union[Qwen3_5TextModel, Qwen3_5MoeTextModel], input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -350,7 +356,7 @@ def decoder_layer_forward( def attn_forward( - self: Qwen3_5Attention, + self: Union[Qwen3_5Attention, Qwen3_5MoeAttention], hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -411,7 +417,7 @@ def attn_forward( def model_forward( - self: Qwen3_5Model, + self: Union[Qwen3_5Model, Qwen3_5MoeModel], input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, diff --git a/src/lmms_engine/models/qwen3_5_moe/__init__.py b/src/lmms_engine/models/qwen3_5_moe/__init__.py new file mode 100644 index 00000000..8aadabf7 --- /dev/null +++ b/src/lmms_engine/models/qwen3_5_moe/__init__.py @@ -0,0 +1,6 @@ +from .monkey_patch import apply_liger_kernel_to_qwen3_5_moe, apply_rmpad_to_qwen3_5_moe + +__all__ = [ + "apply_liger_kernel_to_qwen3_5_moe", + "apply_rmpad_to_qwen3_5_moe", +] diff --git a/src/lmms_engine/models/qwen3_5_moe/monkey_patch.py b/src/lmms_engine/models/qwen3_5_moe/monkey_patch.py new file mode 100644 index 00000000..e1162d88 --- /dev/null +++ b/src/lmms_engine/models/qwen3_5_moe/monkey_patch.py @@ -0,0 +1,130 @@ +"""Monkey patches for transformers.models.qwen3_5_moe. + +Two independent registrations: +- `liger`: rope/rmsnorm/swiglu + fused-LCE forward on the CausalLM class. 
@MONKEY_PATCHER.register("qwen3_5_moe", "liger")
def apply_liger_kernel_to_qwen3_5_moe(
    rope: bool = True,
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
    model: PreTrainedModel = None,
) -> None:
    """Swap qwen3_5_moe building blocks for their Liger-kernel versions.

    Class-level replacements are made on ``modeling_qwen3_5_moe``. When an
    already-built ``model`` is given, its ``forward`` and its RMSNorm modules
    are patched in place as well.

    Args:
        rope: Replace ``apply_rotary_pos_emb`` with the Liger rotary kernel.
        cross_entropy: Replace ``nn.functional.cross_entropy`` (as seen by
            ``transformers.loss.loss_utils``) with the Liger cross entropy.
            Mutually exclusive with ``fused_linear_cross_entropy``.
        fused_linear_cross_entropy: Install ``lce_forward`` as the forward:
            on ``model`` when given, otherwise on ``Qwen3_5MoeForCausalLM``.
        rms_norm: Use ``LigerRMSNorm`` for ``Qwen3_5MoeRMSNorm`` and patch the
            final norm plus per-layer norms of ``model``.
        swiglu: Use ``LigerSwiGLUMLP`` for ``Qwen3_5MoeMLP``.
        model: Optional already-instantiated model to patch in place.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    # The guarded import at module level sets every liger symbol to None
    # together, so one of them is enough to detect a missing install. Honor
    # the "no-op" promise made there instead of rebinding forward to code
    # that needs liger at run time.
    if LigerRMSNorm is None:
        logger.warning("liger kernel not installed; skipping qwen3_5_moe liger patch.")
        return

    from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe

    from .qwen3_5_moe_liger import lce_forward

    if rope:
        modeling_qwen3_5_moe.apply_rotary_pos_emb = liger_rotary_pos_emb

    if rms_norm:
        # NOTE(review): LigerRMSNorm and _patch_rms_norm_module are used with
        # their default settings; confirm they match Qwen3_5MoeRMSNorm's
        # weight convention (offset / casting mode).
        modeling_qwen3_5_moe.Qwen3_5MoeRMSNorm = LigerRMSNorm

    if cross_entropy:
        from liger_kernel.transformers.functional import liger_cross_entropy
        from transformers.loss.loss_utils import nn

        nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        if model is not None:
            model.forward = MethodType(lce_forward, model)
        else:
            modeling_qwen3_5_moe.Qwen3_5MoeForCausalLM.forward = lce_forward

    if swiglu:
        # Qwen3_5MoeMLP (used as shared_expert) is swiglu-style.
        # NOTE(review): this only swaps the class attribute, so modules of an
        # already-built ``model`` are left untouched; _patch_swiglu_module is
        # imported at module level but unused. Confirm whether instance-level
        # patching was intended.
        modeling_qwen3_5_moe.Qwen3_5MoeMLP = LigerSwiGLUMLP

    if model is None or not rms_norm:
        return

    # base_model is Qwen3_5MoeTextModel for ForCausalLM and Qwen3_5MoeModel
    # for ConditionalGeneration; the latter nests the text stack under
    # ``language_model``.
    base_model = getattr(model, model.base_model_prefix, model)
    language_model = getattr(base_model, "language_model", base_model)

    if hasattr(language_model, "norm"):
        _patch_rms_norm_module(language_model.norm)
    for decoder_layer in getattr(language_model, "layers", []):
        for norm_name in ("input_layernorm", "post_attention_layernorm"):
            if hasattr(decoder_layer, norm_name):
                _patch_rms_norm_module(getattr(decoder_layer, norm_name))


@MONKEY_PATCHER.register("qwen3_5_moe", "rmpad")
def apply_rmpad_to_qwen3_5_moe(model: PreTrainedModel = None) -> None:
    """Install the rmpad-aware forwards for qwen3_5_moe.

    The text model, multimodal model, decoder layer, attention, gated delta
    net and sparse MoE block forwards are always replaced at class level
    (the experts forward only on transformers >= 5.0). When ``model`` is
    given, its ``forward`` is additionally rebound to ``lce_forward`` with
    ``use_rmpad=True`` as the default.

    Args:
        model: Optional already-instantiated model whose ``forward`` should be
            rebound. With ``None`` only the class-level patches are applied.
    """
    from transformers.models.qwen3_5_moe import modeling_qwen3_5_moe

    from .qwen3_5_moe_liger import lce_forward
    from .qwen3_5_moe_ops import (
        attn_forward,
        decoder_layer_forward,
        experts_forward,
        gated_delta_net_forward,
        model_forward,
        moe_sparse_layer_forward,
        text_model_forward,
    )

    modeling_qwen3_5_moe.Qwen3_5MoeTextModel.forward = text_model_forward
    modeling_qwen3_5_moe.Qwen3_5MoeModel.forward = model_forward
    modeling_qwen3_5_moe.Qwen3_5MoeDecoderLayer.forward = decoder_layer_forward
    modeling_qwen3_5_moe.Qwen3_5MoeAttention.forward = attn_forward
    modeling_qwen3_5_moe.Qwen3_5MoeGatedDeltaNet.forward = gated_delta_net_forward
    modeling_qwen3_5_moe.Qwen3_5MoeSparseMoeBlock.forward = moe_sparse_layer_forward
    if _IS_TRANSFORMERS_5:
        modeling_qwen3_5_moe.Qwen3_5MoeExperts.forward = experts_forward

    if model is None:
        return

    # ``wraps`` keeps lce_forward's signature visible to introspection.
    @wraps(lce_forward)
    def _rmpad_forward(self, *args, **kwargs):
        # Default only: an explicit use_rmpad from the caller still wins.
        kwargs.setdefault("use_rmpad", True)
        return lce_forward(self, *args, **kwargs)

    model.forward = MethodType(_rmpad_forward, model)
def lce_forward(
    self: Qwen3_5MoeForCausalLM,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    use_rmpad: bool = False,
    **kwargs,
) -> MoeCausalLMOutputWithPast:
    """Causal-LM forward with fused linear + cross-entropy loss.

    Runs ``self.model`` and then either:

    * fuses ``lm_head`` and the cross entropy with
      ``LigerFusedLinearCrossEntropyLoss`` without materializing logits
      (``skip_logits=True``; the default while training with labels), or
    * computes logits with ``lm_head`` and, when ``labels`` is given, the
      loss with ``self.loss_function``.

    If router logits are requested and returned by the backbone, the
    load-balancing loss is computed and added to the loss.

    Args:
        labels: Unshifted targets. Shifted here by one position (per packed
            sequence when ``use_rmpad`` is set).
        logits_to_keep: Accepted for signature compatibility.
            NOTE(review): it is not applied; logits are computed for every
            position.
        skip_logits: Force (``True``) or forbid (``False``) the fused path.
            ``None`` selects it when training and ``labels`` or
            ``shift_labels`` is available.
        use_rmpad: The backbone returns unpadded hidden states;
            ``outputs["seq_lens"]`` holds the sequence boundaries and
            ``outputs["word_idx"]`` the indices used to unpad the labels.
        **kwargs: Forwarded to ``self.model`` and ``self.loss_function``.
            ``shift_labels`` (already-shifted targets, used only when
            ``labels`` is None) and ``num_items_in_batch`` (switches the
            fused loss to sum-then-divide) are consumed here.

    Returns:
        ``MoeCausalLMOutputWithPast``; ``logits`` is None on the fused path.

    Raises:
        ValueError: if the fused path is selected without any labels, or
            ``use_rmpad`` is set but the backbone returned no ``seq_lens``.
    """
    # Top-level config may be Qwen3_5MoeConfig (multimodal wrapper) or
    # Qwen3_5MoeTextConfig (text-only ForCausalLM); text-side fields live on
    # text_config when it exists.
    text_cfg = getattr(self.config, "text_config", self.config)
    if output_attentions is None:
        output_attentions = getattr(self.config, "output_attentions", False)
    if output_router_logits is None:
        output_router_logits = getattr(text_cfg, "output_router_logits", False)
    if output_hidden_states is None:
        output_hidden_states = getattr(self.config, "output_hidden_states", False)

    # Consume loss-only kwargs before the backbone call so they are not
    # forwarded into the decoder stack.
    shift_labels = kwargs.pop("shift_labels", None)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )
    seq_lens = outputs.get("seq_lens", None)
    word_idx = outputs.get("word_idx", None)
    hidden_states = outputs.last_hidden_state

    logits = None
    loss = None

    # Bring the targets into the same unpadded layout as hidden_states.
    if word_idx is not None:
        if labels is not None:
            labels = labels.view(-1)[word_idx.long()]
        elif shift_labels is not None:
            shift_labels = shift_labels.view(-1)[word_idx.long()]

    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        if labels is None:
            if shift_labels is None:
                raise ValueError("skip_logits=True requires `labels` or `shift_labels`.")
            # Targets are already aligned with each position: no shifting.
            loss_hidden = hidden_states
            loss_labels = shift_labels
        elif use_rmpad:
            if seq_lens is None:
                raise ValueError("use_rmpad=True but the model output has no `seq_lens`.")
            # Shift inside each packed sequence so the last token of one
            # sequence never predicts the first token of the next.
            hidden_chunks = []
            label_chunks = []
            for start, end in zip(seq_lens[:-1], seq_lens[1:]):
                hidden_chunks.append(hidden_states[start:end, :][:-1, :].contiguous())
                label_chunks.append(labels[start:end][1:].contiguous())
            loss_hidden = torch.cat(hidden_chunks, dim=0)
            loss_labels = torch.cat(label_chunks, dim=0)
        else:
            loss_hidden = hidden_states[..., :-1, :].contiguous()
            loss_labels = labels[..., 1:].contiguous()

        loss_hidden = loss_hidden.reshape(-1, text_cfg.hidden_size)
        loss_labels = loss_labels.reshape(-1)

        # A missing or None num_items_in_batch means plain mean reduction.
        num_items_in_batch = kwargs.get("num_items_in_batch")
        reduction = "sum" if num_items_in_batch is not None else "mean"
        lce = LigerFusedLinearCrossEntropyLoss(reduction=reduction)
        loss = lce(self.lm_head.weight, loss_hidden, loss_labels)
        if reduction == "sum":
            loss = loss / num_items_in_batch
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            loss = self.loss_function(logits, labels, text_cfg.vocab_size, **kwargs)

    aux_loss = None
    router_logits = getattr(outputs, "router_logits", None)
    if output_router_logits and router_logits is not None:
        # The padded attention mask does not line up with unpadded router
        # logits, so it is dropped in rmpad mode.
        aux_loss_mask = None if use_rmpad else attention_mask
        aux_loss = load_balancing_loss_func(
            router_logits,
            text_cfg.num_experts,
            text_cfg.num_experts_per_tok,
            aux_loss_mask,
        )
        if loss is not None:
            loss = loss + text_cfg.router_aux_loss_coef * aux_loss.to(loss.device)

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=router_logits,
    )
+""" +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.distributed.tensor import DTensor +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeDecoderLayer, + Qwen3_5MoeModel, + Qwen3_5MoeSparseMoeBlock, + Qwen3_5MoeTextModel, +) +from transformers.utils import is_flash_attn_2_available + +# ---- reused as-is from qwen3_5 (dense) ---- +from lmms_engine.models.qwen3_5.qwen3_5_ops import attn_forward +from lmms_engine.models.qwen3_5.qwen3_5_ops import ( # noqa: F401 + linear_attn_forward as gated_delta_net_forward, +) +from lmms_engine.models.qwen3_5.qwen3_5_ops import patch_embed_forward + +from ..common_ops.rope import qwen3_vl_get_rope_index +from ..sequence_packing_utils import BaseModelOutputWithPastAndRmpad, _unpad_input + +if is_flash_attn_2_available(): + from flash_attn.bert_padding import index_first_axis, rearrange + + +# --------------------------------------------------------------------------- +# decoder_layer_forward — same attention dispatch as qwen3_5, but MoE MLP +# returns (hidden, router_logits) tuple instead of a plain tensor. 
# ---------------------------------------------------------------------------
def decoder_layer_forward(
    self: "Qwen3_5MoeDecoderLayer",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["Cache"] = None,
    use_cache: Optional[bool] = False,
    cu_seq_lens: Optional[torch.IntTensor] = None,
    indices: Optional[torch.IntTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    cache_position: Optional[torch.LongTensor] = None,
    output_router_logits: bool = False,
    **kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run one decoder layer: token mixer, then the MoE block, each with a residual.

    The token mixer is chosen by ``self.layer_type``: ``"linear_attention"``
    calls ``self.linear_attn`` and ``"full_attention"`` calls ``self.self_attn``.
    2-D hidden states get a temporary leading axis around ``self.linear_attn``
    and around the norm + ``self.mlp`` step, and it is removed again afterwards.

    Returns:
        The updated hidden states, or ``(hidden_states, router_logits)`` when
        ``output_router_logits`` is true and ``self.mlp`` returned a tuple
        carrying router logits.

    Raises:
        ValueError: If ``self.layer_type`` is neither of the two known values.
    """
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)

    if self.layer_type == "linear_attention":
        needs_squeeze = hidden_states.ndim == 2
        if needs_squeeze:
            hidden_states = hidden_states.unsqueeze(0)
        # The linear-attention branch is always called with attention_mask=None.
        hidden_states = self.linear_attn(
            hidden_states=hidden_states,
            cache_params=past_key_values,
            cache_position=cache_position,
            attention_mask=None,
            cu_seq_lens=cu_seq_lens,
        )
        if needs_squeeze:
            hidden_states = hidden_states.squeeze(0)
    elif self.layer_type == "full_attention":
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cu_seq_lens=cu_seq_lens,
            indices=indices,
            position_embeddings=position_embeddings,
            cache_position=cache_position,
            **kwargs,
        )
    else:
        raise ValueError(f"unknown layer_type={self.layer_type!r}")

    hidden_states = residual + hidden_states

    # MoE block — add a batch dim if rmpad flattened the input to 2D
    residual = hidden_states
    needs_squeeze = hidden_states.ndim == 2
    if needs_squeeze:
        hidden_states = hidden_states.unsqueeze(0)
    hidden_states = self.post_attention_layernorm(hidden_states)
    mlp_output = self.mlp(hidden_states)

    # Qwen3_5MoeSparseMoeBlock returns (Tensor, router_logits); a plain tensor
    # return is accepted as well.
    router_logits = None
    if isinstance(mlp_output, tuple):
        hidden_states, router_logits = mlp_output
    else:
        hidden_states = mlp_output

    if needs_squeeze:
        hidden_states = hidden_states.squeeze(0)
    hidden_states = residual + hidden_states

    if output_router_logits and router_logits is not None:
        return hidden_states, router_logits
    return hidden_states


# ---------------------------------------------------------------------------
# text_model_forward — like qwen3_5 dense, but collects router_logits across
# layers when requested.
# ---------------------------------------------------------------------------
def text_model_forward(
    self: "Qwen3_5MoeTextModel",
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["Cache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    cu_seq_lens: Optional[torch.IntTensor] = None,
    indices: Optional[torch.IntTensor] = None,
    output_router_logits: Optional[bool] = None,
    **kwargs,
) -> "BaseModelOutputWithPastAndRmpad":
    """Run the text decoder stack and optionally collect per-layer router logits.

    Args:
        input_ids / inputs_embeds: Exactly one of the two must be given.
        position_ids: ``None``, a 2-D tensor, or a 3-D tensor with 3 or 4
            leading components. The input is normalised so that component 0
            goes to the decoder layers as text positions; when four components
            result, the remaining three feed ``self.rotary_emb``.
        cu_seq_lens / indices: Passed through to every layer and echoed on the
            output as ``seq_lens`` / ``word_idx``.
        output_router_logits: Defaults to ``self.config.output_router_logits``
            (``False`` if the config has no such field).

    Returns:
        ``BaseModelOutputWithPastAndRmpad`` with the normed hidden states, the
        cache (only when ``use_cache``), and a tuple of router logits or ``None``.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are set.
    """
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if output_router_logits is None:
        output_router_logits = getattr(self.config, "output_router_logits", False)

    if use_cache and past_key_values is None and not torch.jit.is_tracing():
        past_key_values = DynamicCache(config=self.config)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        # The token axis is the second-to-last one: dim 0 for packed
        # (tokens, hidden) embeddings, dim 1 for batched (batch, seq, hidden).
        cache_position = torch.arange(
            past_seen_tokens,
            past_seen_tokens + inputs_embeds.shape[-2],
            device=inputs_embeds.device,
        )

    # Qwen3.5 expects 4-component position_ids ``(text, t, h, w)``.
    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(4, 1, -1)
    elif position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(4, position_ids.shape[0], -1)
    elif position_ids.ndim == 3 and position_ids.shape[0] == 3:
        # Only three components were supplied: prepend a plain 0..S-1 text axis.
        text_axis = (
            torch.arange(position_ids.shape[-1], device=position_ids.device, dtype=position_ids.dtype)
            .view(1, 1, -1)
            .expand(1, position_ids.shape[1], -1)
        )
        position_ids = torch.cat([text_axis, position_ids], dim=0)

    # Component 0 always drives the layers; strip it from the rotary input
    # only when all four components are present.
    text_position_ids = position_ids[0]
    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
        position_ids = position_ids[1:]

    hidden_states = inputs_embeds

    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    all_router_logits = () if output_router_logits else None

    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=text_position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cu_seq_lens=cu_seq_lens,
            indices=indices,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            output_router_logits=output_router_logits,
            **kwargs,
        )
        if isinstance(layer_outputs, tuple):
            hidden_states, router_logits = layer_outputs
            if output_router_logits and router_logits is not None:
                all_router_logits += (router_logits,)
        else:
            hidden_states = layer_outputs

    hidden_states = self.norm(hidden_states)
    return BaseModelOutputWithPastAndRmpad(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values if use_cache else None,
        seq_lens=cu_seq_lens,
        word_idx=indices,
        router_logits=all_router_logits if output_router_logits else None,
    )


# ---------------------------------------------------------------------------
# model_forward — outer multimodal wrapper.
Mirrors qwen3_5 dense +# model_forward, but plumbs output_router_logits through and surfaces +# router_logits on the returned BaseModelOutputWithPastAndRmpad. +# --------------------------------------------------------------------------- +def model_forward( + self: Qwen3_5MoeModel, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + output_router_logits: Optional[bool] = None, + **kwargs, +) -> BaseModelOutputWithPastAndRmpad: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + # ---- un-pad input_ids / inputs_embeds ---- + if input_ids is not None: + original_input_ids = input_ids + input_ids, indices, cu_seq_lens, _ = _unpad_input(input_ids, attention_mask=attention_mask) + batch_size, seq_length = original_input_ids.shape + else: + original_input_ids = None + original_inputs_embeds = inputs_embeds + inputs_embeds, indices, cu_seq_lens, _ = _unpad_input(inputs_embeds, attention_mask=attention_mask) + batch_size, seq_length, _ = original_inputs_embeds.shape + + # ---- compute 3D position ids from padded layout, then gather to packed ---- + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = 
attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + position_ids, rope_deltas = qwen3_vl_get_rope_index( + self, + original_input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + + # position_ids: (c, B, S) -> packed (c, 1, total_tokens) + position_ids = ( + index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1) + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # ---- visual feature injection (still on packed inputs_embeds) ---- + if pixel_values is not None: + image_outputs: BaseModelOutputWithPooling = self.get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_outputs: BaseModelOutputWithPooling = self.get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + # `cu_seq_lens` / `indices` may already be in **kwargs from the collator + # (we compute fresh ones from attention_mask); drop them to avoid duplicate kwargs. 
+ kwargs.pop("cu_seq_lens", None) + kwargs.pop("indices", None) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + indices=indices, + cu_seq_lens=cu_seq_lens, + output_router_logits=output_router_logits, + **kwargs, + ) + + return BaseModelOutputWithPastAndRmpad( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + seq_lens=cu_seq_lens, + word_idx=indices, + router_logits=getattr(outputs, "router_logits", None), + ) + + +# --------------------------------------------------------------------------- +# moe_sparse_layer_forward — routed experts + shared_expert combine +# --------------------------------------------------------------------------- +def moe_sparse_layer_forward( + self: Qwen3_5MoeSparseMoeBlock, + hidden_states: torch.Tensor, + **kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + + # Shared expert path + shared_out = self.shared_expert(hidden_states_flat) + shared_out = torch.sigmoid(self.shared_expert_gate(hidden_states_flat)) * shared_out + + # Router (returns logits, normalized weights, indices) + router_logits, routing_weights, selected_experts = self.gate(hidden_states_flat) + num_experts = self.gate.num_experts + top_k = self.gate.top_k + + # Build per-expert routing tensors (same shape qwen3_moe uses so the EP + # dispatch in Qwen3_5MoeParallelStyle._input_fn is identical) + selected_experts = selected_experts.to(torch.float32) + num_tokens_per_expert = torch.histc(selected_experts, bins=num_experts, min=0, max=num_experts) + selected_experts = selected_experts.to(torch.int64) + num_tokens_per_expert = num_tokens_per_expert.to(torch.int64) + + token_indices_experts_sorted = torch.argsort(selected_experts.view(-1), stable=True) 
+ top_scores_experts_sorted = routing_weights.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // top_k + + token_indices_experts_sorted = token_indices_experts_sorted.reshape(-1, 1).expand(-1, hidden_dim) + routed_input = torch.gather(hidden_states_flat, dim=0, index=token_indices_experts_sorted) + + out_experts_split = self.experts(routed_input, num_tokens_per_expert) + + routed_output = out_experts_split * top_scores_experts_sorted.reshape(-1, 1) + final_hidden_states = torch.zeros_like(hidden_states_flat) + final_hidden_states = final_hidden_states.scatter_add(dim=0, index=token_indices_experts_sorted, src=routed_output) + + # Combine routed + shared + final_hidden_states = final_hidden_states + shared_out + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +# --------------------------------------------------------------------------- +# experts_forward — stacked-parameter experts (same shape as qwen3_moe T>=5) +# --------------------------------------------------------------------------- +def experts_forward(self, *routed_input): + if len(routed_input) == 2 and routed_input[1].ndim == 1: + routed_input = torch.split( + routed_input[0], + split_size_or_sections=routed_input[1].tolist(), + dim=0, + ) + + if isinstance(self.down_proj, DTensor): + down_proj = self.down_proj.to_local() + gate_up_proj = self.gate_up_proj.to_local() + else: + down_proj = self.down_proj + gate_up_proj = self.gate_up_proj + + out_experts_split = [] + for idx, x in enumerate(routed_input): + gate_up = F.linear(x, gate_up_proj[idx]) + gate, up = gate_up.chunk(2, dim=-1) + hidden = self.act_fn(gate) * up + hidden = F.linear(hidden, down_proj[idx]) + out_experts_split.append(hidden) + + return torch.cat(out_experts_split, dim=0) diff --git a/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py b/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py index 
19924a2c..bb63d2f6 100644 --- a/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py +++ b/src/lmms_engine/models/qwen3_moe/qwen3_moe_ops.py @@ -18,15 +18,16 @@ _IS_TRANSFORMERS_5 = is_transformers_version_greater_or_equal_to("5.0") if _IS_TRANSFORMERS_5: from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeExperts + from transformers.utils import is_flash_attn_2_available +from lmms_engine.kernels.attention import varlen_attn from lmms_engine.models.sequence_packing_utils import ( BaseModelOutputWithPastAndRmpad, _unpad_input, ) if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import ( index_first_axis, pad_input, @@ -234,7 +235,7 @@ def attn_forward( max_seqlen = torch.diff(cu_seq_lens).max().item() if cu_seq_lens is not None else None window_size = (-1, -1) - attn_output = flash_attn_varlen_func( + attn_output = varlen_attn( q=query_states, k=key_states, v=value_states, @@ -246,6 +247,7 @@ def attn_forward( window_size=window_size, softmax_scale=self.head_dim**-0.5, dropout_p=0.0, + backend=self.config._attn_implementation, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) diff --git a/src/lmms_engine/parallel/parallelize.py b/src/lmms_engine/parallel/parallelize.py index 98f8b268..a3136bb1 100644 --- a/src/lmms_engine/parallel/parallelize.py +++ b/src/lmms_engine/parallel/parallelize.py @@ -3,12 +3,14 @@ if TYPE_CHECKING: from lmms_engine.train.config import TrainingArguments +from .qwen3_5_moe.parallelize import apply_qwen3_5_moe_parallelize_fn from .qwen3_moe.parallelize import apply_qwen3_moe_parallelize_fn from .qwen3_omni_moe.parallelize import apply_qwen3_omni_moe_parallelize_fn from .qwen3_vl_moe.parallelize import apply_qwen3_vl_moe_parallelize_fn MODEL_TO_PARALLEL_METHOD = { "qwen3_moe": apply_qwen3_moe_parallelize_fn, + "qwen3_5_moe": apply_qwen3_5_moe_parallelize_fn, "qwen3_omni_moe": 
apply_qwen3_omni_moe_parallelize_fn, "qwen3_omni_moe_thinker": apply_qwen3_omni_moe_parallelize_fn, "qwen3_vl_moe": apply_qwen3_vl_moe_parallelize_fn, diff --git a/src/lmms_engine/parallel/qwen3_5_moe/__init__.py b/src/lmms_engine/parallel/qwen3_5_moe/__init__.py new file mode 100644 index 00000000..7d9d7922 --- /dev/null +++ b/src/lmms_engine/parallel/qwen3_5_moe/__init__.py @@ -0,0 +1,8 @@ +from .parallelize import apply_qwen3_5_moe_parallel, apply_qwen3_5_moe_parallelize_fn +from .style import Qwen3_5MoeParallelStyle + +__all__ = [ + "apply_qwen3_5_moe_parallel", + "apply_qwen3_5_moe_parallelize_fn", + "Qwen3_5MoeParallelStyle", +] diff --git a/src/lmms_engine/parallel/qwen3_5_moe/parallelize.py b/src/lmms_engine/parallel/qwen3_5_moe/parallelize.py new file mode 100644 index 00000000..0bc4b8be --- /dev/null +++ b/src/lmms_engine/parallel/qwen3_5_moe/parallelize.py @@ -0,0 +1,120 @@ +"""FSDP2 + Expert Parallel wiring for qwen3_5_moe. + +qwen3_5_moe is multimodal: the top-level model class is +``Qwen3_5MoeForConditionalGeneration``, whose ``.model`` is the multimodal +wrapper ``Qwen3_5MoeModel`` (containing ``visual`` + ``language_model``). +Decoder layers live at ``model.model.language_model.layers`` (same shape as +qwen3_vl_moe — there is no ``.layers`` attribute on the outer wrapper). 
# """  (closing delimiter of the parallelize.py module docstring, which opens on the preceding physical line)

from typing import TYPE_CHECKING

import torch
from loguru import logger
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import parallelize_module
from transformers.models.qwen3_5_moe import Qwen3_5MoeForConditionalGeneration

import lmms_engine.parallel.process_group_manager as pgm
from lmms_engine.utils.fsdp2_utils import fsdp2_load_full_state_dict

from .style import Qwen3_5MoeParallelStyle

if TYPE_CHECKING:
    from lmms_engine.train.config import TrainingArguments


def apply_qwen3_5_moe_parallel(
    model: Qwen3_5MoeForConditionalGeneration,
    ep_mesh: DeviceMesh,
    tp_mesh: DeviceMesh = None,
    **kwargs,
):
    """Apply ``Qwen3_5MoeParallelStyle`` to the experts of every decoder layer.

    Args:
        model: Model whose decoder layers live at ``model.model.language_model.layers``.
        ep_mesh: Device mesh passed to ``parallelize_module`` for the experts.
        tp_mesh: Must be ``None``; tensor parallelism is not supported.
    """
    assert tp_mesh is None, "Tensor Parallelism is not supported yet for Qwen3_5Moe"

    layers = model.model.language_model.layers
    for decoder_layer in layers:
        # A fresh style object per layer, as in the original wiring.
        parallelize_module(
            decoder_layer.mlp.experts,
            device_mesh=ep_mesh,
            parallelize_plan=Qwen3_5MoeParallelStyle(),
        )

    logger.info(f"Applied Qwen3_5MoeParallelStyle to {len(layers)} layers")


def apply_qwen3_5_moe_fsdp2(
    model: Qwen3_5MoeForConditionalGeneration,
    train_args: "TrainingArguments",
    **kwargs,
):
    """Wrap the model with FSDP2 (``fully_shard``).

    Sharded units, in order: the vision tower (if present); per decoder layer
    the MoE block (only when expert parallelism is active, on the
    ``dp_shard_mod_ep`` mesh with ``Shard(1)`` placement) and the attention
    module matching ``layer_type``; the token embedding; finally the whole model.

    Raises:
        ValueError: If a decoder layer has an unknown ``layer_type``.
    """
    fsdp_config = getattr(train_args, "fsdp_config", None) or {}

    # Warn when the option IS set: this function picks the wrap units itself,
    # so a user-supplied value has no effect.
    if fsdp_config.get("transformer_layer_cls_to_wrap", None):
        logger.warning("transformer_layer_cls_to_wrap ignored; qwen3_5_moe wraps decoder layers explicitly.")

    param_dtype = torch.bfloat16 if train_args.bf16 else torch.float16

    if train_args.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    reduce_dtype = getattr(torch, train_args.reduce_dtype)
    output_dtype = getattr(torch, train_args.output_dtype)
    mp_policy = MixedPrecisionPolicy(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        output_dtype=output_dtype,
    )

    dp_mesh = pgm.process_group_manager.device_mesh["fsdp"]

    fsdp_kwargs = {
        "reshard_after_forward": fsdp_config.get("reshard_after_forward", True),
        "mp_policy": mp_policy,
        "mesh": dp_mesh,
    }

    ep_size = pgm.process_group_manager.ep_size
    expert_fsdp_kwargs = None
    if ep_size > 1:

        def _experts_shard_placement_fn(param):
            return Shard(1)

        expert_fsdp_kwargs = dict(fsdp_kwargs)
        expert_fsdp_kwargs["mesh"] = pgm.process_group_manager.device_mesh["dp_shard_mod_ep"]
        expert_fsdp_kwargs["shard_placement_fn"] = _experts_shard_placement_fn

    # Wrap vision tower (same pattern as qwen3_vl_moe)
    if hasattr(model.model, "visual") and model.model.visual is not None:
        fully_shard(model.model.visual, **fsdp_kwargs)

    for decoder_layer in model.model.language_model.layers:
        # MoE block
        if ep_size > 1:
            fully_shard(decoder_layer.mlp, **expert_fsdp_kwargs)

        # Attention block — branch on layer_type
        if decoder_layer.layer_type == "linear_attention":
            fully_shard(decoder_layer.linear_attn, **fsdp_kwargs)
        elif decoder_layer.layer_type == "full_attention":
            fully_shard(decoder_layer.self_attn, **fsdp_kwargs)
        else:
            # Fail loudly rather than guessing which attribute holds the mixer.
            raise ValueError(f"unknown layer_type={decoder_layer.layer_type!r}")

    fully_shard(model.model.language_model.embed_tokens, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)


def apply_qwen3_5_moe_parallelize_fn(
    model: Qwen3_5MoeForConditionalGeneration,
    train_args: "TrainingArguments",
    **kwargs,
):
    """Apply expert parallelism (if ``ep_size > 1``) and FSDP2, then reload weights.

    The state dict is captured before any wrapping and handed to
    ``fsdp2_load_full_state_dict`` afterwards.
    """
    ep_size = pgm.process_group_manager.ep_size
    full_state_dict = model.state_dict()
    if ep_size > 1:
        ep_mesh = pgm.process_group_manager.device_mesh["ep"]
        apply_qwen3_5_moe_parallel(model, ep_mesh=ep_mesh, **kwargs)

    apply_qwen3_5_moe_fsdp2(model, train_args, **kwargs)
    fsdp2_load_full_state_dict(model, full_state_dict)


# diff --git a/src/lmms_engine/parallel/qwen3_5_moe/style.py b/src/lmms_engine/parallel/qwen3_5_moe/style.py
# new file mode 100644
# index 00000000..753b83db
# --- /dev/null
# +++ b/src/lmms_engine/parallel/qwen3_5_moe/style.py
# @@ -0,0 +1,101 @@
"""ParallelStyle for Qwen3_5MoeExperts.

Identical EP dispatch/permute/combine logic as Qwen3MoeParallelStyle; only the
class check for partition_fn differs.
"""
from typing import Optional

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    Shard,
    distribute_module,
    distribute_tensor,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torch.distributed.tensor.placement_types import Placement
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts

import lmms_engine.parallel.process_group_manager as pgm
from lmms_engine.parallel.expert_parallel.utils import (
    _compute_permute_indices,
    _token_combine,
    _token_dispatch,
)


class Qwen3_5MoeParallelStyle(ParallelStyle):
    """Expert-parallel style for ``Qwen3_5MoeExperts``.

    Shards ``gate_up_proj`` / ``down_proj`` along dim 0 and installs input /
    output hooks that dispatch tokens before the experts run and combine them
    afterwards. The splits and permutation computed on the way in are stored
    on the instance and reused on the way out.
    """

    def __init__(
        self,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ) -> None:
        super().__init__()
        self.input_layouts = (input_layouts or Shard(0),)
        self.output_layouts = (output_layouts or Shard(0),)
        self.use_local_output = use_local_output
        self.desired_input_layouts = (Shard(0),)
        # Filled by _input_fn on every forward, consumed by _output_fn.
        self.input_splits = None
        self.output_splits = None
        self.permute_indices = None
        # Set by _apply from the wrapped experts module.
        self.num_experts = None

    def _input_fn(self, inputs, mesh: DeviceMesh):
        """Turn ``(routed_input, num_tokens_per_expert)`` into per-expert chunks."""
        routed_input, num_tokens_per_expert = inputs
        ep_world_size = pgm.process_group_manager.ep_world_size
        if ep_world_size > 1:
            if self.num_experts is None:
                # Otherwise this surfaces below as an opaque ``NoneType // int`` error.
                raise RuntimeError(
                    "Qwen3_5MoeParallelStyle.num_experts is unset; the style must be applied "
                    "to a Qwen3_5MoeExperts module before its forward runs."
                )
            (routed_input, input_splits, output_splits, num_tokens_per_expert_group) = _token_dispatch(
                routed_input, num_tokens_per_expert
            )
            permute_indices, split_sizes = _compute_permute_indices(
                torch.tensor(num_tokens_per_expert_group, device=routed_input.device),
                ep_world_size,
                self.num_experts // ep_world_size,
            )
            routed_input = routed_input[permute_indices]
            routed_input = torch.split(
                routed_input[: sum(output_splits)],
                split_size_or_sections=split_sizes,
                dim=0,
            )
            self.input_splits = input_splits
            self.output_splits = output_splits
            self.permute_indices = permute_indices
        else:
            routed_input = torch.split(
                routed_input,
                split_size_or_sections=num_tokens_per_expert.tolist(),
                dim=0,
            )
        return routed_input

    def _output_fn(self, output, mesh: DeviceMesh):
        """Undo the input permutation and combine tokens across EP ranks."""
        if pgm.process_group_manager.ep_world_size > 1:
            output[self.permute_indices] = output.clone()
            output = _token_combine(output, self.input_splits, self.output_splits)
        return output

    @staticmethod
    def _partition_fn(name, mod, device_mesh):
        """Re-register the stacked expert weights as DTensors sharded on dim 0."""
        if isinstance(mod, Qwen3_5MoeExperts):
            expert_parallel_dim = 0
            mod.register_parameter(
                "gate_up_proj",
                nn.Parameter(distribute_tensor(mod.gate_up_proj, device_mesh, [Shard(expert_parallel_dim)])),
            )
            mod.register_parameter(
                "down_proj",
                nn.Parameter(distribute_tensor(mod.down_proj, device_mesh, [Shard(expert_parallel_dim)])),
            )

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        if isinstance(module, Qwen3_5MoeExperts):
            self.num_experts = module.num_experts
        # Register hooks in the three-argument (module, inputs/outputs, mesh)
        # form; the two-argument form is deprecated by distribute_module.
        return distribute_module(
            module,
            device_mesh,
            partition_fn=Qwen3_5MoeParallelStyle._partition_fn,
            input_fn=lambda mod, inputs, mesh: self._input_fn(inputs, mesh),
            output_fn=lambda mod, outputs, mesh: self._output_fn(outputs, mesh),
        )


# diff --git a/src/lmms_engine/train/runner.py b/src/lmms_engine/train/runner.py
# index e3a3be41..fcacbcb9 100644
# --- a/src/lmms_engine/train/runner.py
# +++ b/src/lmms_engine/train/runner.py
# @@ -85,7 +85,11 @@ def _build_model(self):
#   if init_config is None:
#   # If no nested config, use the load_from_config dict directly (excluding model_type)
#   init_config = {k: v for k, v in load_from_config.items() if k != "model_type"}
# - model_class, m_config = create_model_from_config(model_type, init_config)
# + model_class, m_config = create_model_from_config(
# + model_type,
# + init_config,
# + model_general_type=self.model_config.model_general_type,
# + )
#   model = model_class.from_config(m_config, **model_kwargs)
#   else:
#   raise ValueError("No model name or pretrained path provided.
# Please provide one of them.")
# diff --git a/test/train/qwen3_5_moe/__init__.py b/test/train/qwen3_5_moe/__init__.py
# new file mode 100644
# index 00000000..e69de29b
# diff --git a/test/train/qwen3_5_moe/test_qwen3_5_moe.py b/test/train/qwen3_5_moe/test_qwen3_5_moe.py
# new file mode 100644
# index 00000000..7f351b20
# --- /dev/null
# +++ b/test/train/qwen3_5_moe/test_qwen3_5_moe.py
# @@ -0,0 +1,37 @@
import os
import unittest
from unittest import TestCase

from utils import launch_torchrun_training, with_multi_gpu_training, with_temp_dir


class TestQwen3_5Moe(TestCase):
    @with_temp_dir
    @with_multi_gpu_training
    def test_train_fsdp2(self, temp_dir, nproc_per_node):
        """Test Qwen3.5 MoE training with FSDP2 and Expert Parallelism using torchrun subprocess."""

        script_path = os.path.join(os.path.dirname(__file__), "train_qwen3_5_moe_ep.py")

        result = launch_torchrun_training(
            script_path=script_path,
            output_dir=temp_dir,
            nproc_per_node=nproc_per_node,
            timeout=600,
        )

        self.assertIsNotNone(result, "Training process should not be None")

        # Print the captured output before checking the exit status: a failed
        # assertion aborts the test, and the logs matter most in that case.
        if result.stdout:
            print("Training stdout:", result.stdout)
        if result.stderr:
            print("Training stderr:", result.stderr)

        self.assertEqual(
            result.returncode,
            0,
            f"Training failed with return code {result.returncode}",
        )


if __name__ == "__main__":
    unittest.main()

# diff --git a/test/train/qwen3_5_moe/train_qwen3_5_moe_ep.py b/test/train/qwen3_5_moe/train_qwen3_5_moe_ep.py
# new file mode 100644
# index 00000000..89970925
# --- /dev/null
# +++ b/test/train/qwen3_5_moe/train_qwen3_5_moe_ep.py
# @@ -0,0 +1,127 @@
#!/usr/bin/env python
"""CI/CD smoke test: tiny qwen3_5_moe with Expert Parallelism.

Builds the model via `load_from_config` (random init) — no checkpoint
download. Top-level model is Qwen3_5MoeForConditionalGeneration with
model_type='qwen3_5_moe', so our liger/rmpad monkey patches and EP
parallelize fn dispatch correctly.

Usage:
    torchrun --nproc_per_node=8 test/train/qwen3_5_moe/train_qwen3_5_moe_ep.py \\
        --output_dir ./output/qwen3_5_moe_ep4 --ep_degree 4
"""
import argparse

from lmms_engine.launch.cli import create_train_task


def main():
    """Parse CLI options, build a tiny random-init qwen3_5_moe config, and train it."""
    parser = argparse.ArgumentParser(description="Train Qwen3.5 MoE model with Expert Parallelism")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for training")
    parser.add_argument("--ep_degree", type=int, default=2, choices=[2, 4, 8], help="Expert parallelism degree")
    parser.add_argument("--max_steps", type=int, default=10, help="Maximum number of training steps")
    parser.add_argument("--processor_name", default="Qwen/Qwen3.6-35B-A3B")
    # The launcher options below are accepted but not read by this script.
    parser.add_argument("--nproc_per_node", type=int, default=None)
    parser.add_argument("--nnodes", type=int, default=1)
    parser.add_argument("--node_rank", type=int, default=0)
    parser.add_argument("--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=str, default="8000")

    # Unrecognised arguments are tolerated and discarded.
    args, _unknown = parser.parse_known_args()

    # Shared by the text tower and the vision tower's output projection.
    text_hidden_size = 256

    cfg = {
        "trainer_type": "fsdp2_trainer",
        "dataset_config": {
            "dataset_type": "vision_iterable",
            "dataset_format": "yaml",
            "datasets": [
                {
                    "path": "data/lmms_engine_test/text_example/open_thoughts_5k.parquet",
                    "data_folder": "",
                    "data_type": "parquet",
                }
            ],
            "processor_config": {
                "processor_name": args.processor_name,
                "processor_type": "qwen3_vl",
            },
            "packing": False,
            "video_backend": "qwen_vl_utils",
        },
        "model_config": {
            "load_from_config": {
                "model_type": "qwen3_5_moe",
                "text_config": {
                    "hidden_size": text_hidden_size,
                    "intermediate_size": 512,
                    "num_hidden_layers": 4,
                    "num_attention_heads": 8,
                    "num_key_value_heads": 4,
                    "num_experts": 8,
                    "num_experts_per_tok": 2,
                    "shared_expert_intermediate_size": 256,
                    "layer_types": [
                        "linear_attention",
                        "full_attention",
                        "linear_attention",
                        "full_attention",
                    ],
                    "head_dim": 32,
                    # match Qwen/Qwen3.6-35B-A3B tokenizer vocab (incl. image/video special tokens)
                    "vocab_size": 248320,
                },
                "vision_config": {
                    "depth": 2,
                    "hidden_size": 128,
                    "intermediate_size": 256,
                    "num_heads": 4,
                    "out_hidden_size": text_hidden_size,
                    "num_position_embeddings": 64,
                },
            },
            "attn_implementation": "flash_attention_2",
            "model_general_type": "image_text_to_text",
            "monkey_patch_kwargs": {
                "patch_type": ["liger", "rmpad"],
                "fused_linear_cross_entropy": True,
                "rms_norm": True,
                "swiglu": True,
            },
        },
        "trainer_args": {
            "per_device_train_batch_size": 1,
            "gradient_checkpointing": True,
            "num_train_epochs": 1,
            "max_steps": args.max_steps,
            "report_to": "none",
            "output_dir": args.output_dir,
            "warmup_ratio": 0.0,
            "eval_strategy": "no",
            "save_strategy": "no",
            "dataloader_num_workers": 2,
            "bf16": True,
            "lr_scheduler_type": "cosine",
            "use_liger_kernel": True,
            "use_rmpad": True,
            "fsdp2": True,
            "group_by_length": True,
            "fsdp_config": {
                "transformer_layer_cls_to_wrap": ["Qwen3_5MoeDecoderLayer"],
                "reshard_after_forward": False,
            },
            "ep_degree": args.ep_degree,
            "sp_ulysses_degree": 1,
        },
    }

    print(f"\n{'='*70}\nqwen3_5_moe EP test ep={args.ep_degree} steps={args.max_steps}\n{'='*70}\n")
    train_task = create_train_task(cfg)
    train_task.build()
    train_task.run()
    print(f"\n{'='*70}\nEP test completed successfully\n{'='*70}\n")


if __name__ == "__main__":
    main()