Merged
17 commits
72fdca5
refactor(qwen3_moe): use varlen_attn(backend=...) for attn_implementa…
kcz358 May 15, 2026
12063f1
chore: ignore docs/superpowers (local plans/specs)
kcz358 May 15, 2026
eeed7f8
feat(qwen3_5_moe): package skeleton with empty monkey patch entries
kcz358 May 15, 2026
21316ad
feat(qwen3_5_moe): MoE-specific forwards; reuse qwen3_5 attn/linear_a…
kcz358 May 15, 2026
72381c3
refactor(qwen3_5): widen self: type hints to Union[dense, moe] for sh…
kcz358 May 15, 2026
d61c416
feat(qwen3_5_moe): liger + rmpad monkey patches (OV2-style split)
kcz358 May 15, 2026
4324195
feat(qwen3_5_moe): EP ParallelStyle (mirrors qwen3_moe with Qwen3_5Mo…
kcz358 May 15, 2026
788d8e8
feat(qwen3_5_moe): EP parallelize fn + register MODEL_TO_PARALLEL_METHOD
kcz358 May 15, 2026
efb3380
feat(mapping_func): create_model_from_config accepts model_general_ty…
kcz358 May 15, 2026
b83f995
test(qwen3_5_moe): tiny-config EP smoke test (ep_degree 2/4/8)
kcz358 May 15, 2026
cdfb637
test(qwen3_5_moe): unittest wrapper + standardize train script argparse
kcz358 May 15, 2026
0958d48
fix(qwen3_5_moe): walk multimodal wrapper for layers (model.model.lan…
kcz358 May 15, 2026
888abb3
fix(qwen3_5_moe): decoder_layer always returns Tensor (drop router_lo…
kcz358 May 15, 2026
a7ad33f
fix(qwen3_5_moe): lce_forward handles Qwen3_5MoeConfig outer wrapper
kcz358 May 15, 2026
b3bf975
feat(merger): support FSDP2 + EP multi-axis DTensor checkpoints
kcz358 May 15, 2026
1cdfd5e
docs(qwen3_5_moe): example yaml + run.sh + model doc with EP merger u…
kcz358 May 15, 2026
1beb541
feat(qwen3_5_moe): wire router auxiliary loss through outer model
kcz358 May 15, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -176,3 +176,4 @@ checkpoints/
 # macOS
 .DS_Store
 .vscode
+docs/superpowers
97 changes: 97 additions & 0 deletions docs/models/qwen3_5_moe.md
@@ -0,0 +1,97 @@
# Qwen3.5-MoE Training

## Overview

Qwen3.5-MoE (`Qwen/Qwen3.6-35B-A3B`) is a **multimodal** Mixture-of-Experts model
with a vision tower plus a hybrid-attention MoE language model. Each decoder
layer is either a **linear-attention** layer (gated delta net) or a **full
softmax-attention** layer, selected per layer via
`config.text_config.layer_types[i]`. The MoE block contains a
**shared_expert** alongside the routed experts.
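
As a rough illustration of the per-layer selection, the sketch below groups layer indices by attention kind from a `layer_types` list. The example list, the `"linear_attention"` / `"full_attention"` strings, and the helper name `split_layers_by_type` are assumptions made for illustration; check the actual values in `config.text_config.layer_types` for your checkpoint.

```python
from collections import defaultdict


def split_layers_by_type(layer_types):
    """Group decoder-layer indices by their attention type string."""
    groups = defaultdict(list)
    for idx, layer_type in enumerate(layer_types):
        groups[layer_type].append(idx)
    return dict(groups)


# Hypothetical 8-layer pattern: every fourth layer uses full softmax attention.
example_layer_types = (["linear_attention"] * 3 + ["full_attention"]) * 2

groups = split_layers_by_type(example_layer_types)
print(groups["linear_attention"])  # indices of the gated-delta-net layers
print(groups["full_attention"])    # indices of the softmax-attention layers
```

A grouping like this is handy when checking that a sharding or patching rule touches the layers you expect.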

The top-level multimodal class is `Qwen3_5MoeForConditionalGeneration`
(`model_type = "qwen3_5_moe"`).

## Supported Features

| Feature | Support |
|---------|---------|
| **FSDP2** | ✅ |
| **USP / Sequence Parallel** | ❌ (linear-attention path is not SP-safe) |
| **Muon Optimizer** | ✅ |
| **Liger Kernel** | ✅ |
| **Packing** | ✅ (rmpad) |
| **NSA** | ❌ |
| **Expert Parallelism (EP)** | ✅ |

**Highlights**: Hybrid attention (linear / full), `shared_expert` + routed
experts, Expert Parallelism via the custom `Qwen3_5MoeExperts` `ParallelStyle`.

## Quick Start

See the example configuration and run script:
- **Example Config**: [examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml](../../examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml)
- **Run Script**: [examples/qwen3_5_moe/run.sh](../../examples/qwen3_5_moe/run.sh)

Verified end-to-end with `cicd/run_traincicd.sh --model-name qwen3_5_moe --gpu-count 4`.

## Key Configuration

```yaml
model_config:
  load_from_pretrained_path: "Qwen/Qwen3.6-35B-A3B"
  # CRITICAL: Qwen3_5MoeConfig is registered in both causal_lm and
  # image_text_to_text auto-mappings. Without this line we'd silently load the
  # text-only Qwen3_5MoeForCausalLM instead of the multimodal
  # Qwen3_5MoeForConditionalGeneration.
  model_general_type: image_text_to_text
  attn_implementation: flash_attention_2
  monkey_patch_kwargs:
    # Two patches registered separately for qwen3_5_moe; runner applies them
    # in order. "rmpad" accepts no kwargs; the listed kwargs go to "liger".
    patch_type: ["liger", "rmpad"]
    fused_linear_cross_entropy: true
    rms_norm: true
    swiglu: true

trainer_args:
  use_liger_kernel: true
  use_rmpad: true
  fsdp2: true
  fsdp_config:
    transformer_layer_cls_to_wrap: ["Qwen3_5MoeDecoderLayer"]
  sp_ulysses_degree: 1 # SP is not supported
  ep_degree: 8 # Expert Parallelism degree
```
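
The kwargs routing described in the config comments can be pictured with the sketch below. It is not the engine's runner: `PATCHES`, `ACCEPTS_KWARGS`, `apply_liger`, `apply_rmpad`, and `apply_patches` are hypothetical names. The only behavior taken from the config comments is that patches run in list order and that only `liger` receives the extra kwargs.

```python
def apply_liger(log, **kwargs):
    # Hypothetical stand-in: records which liger options were requested.
    log.append(("liger", dict(sorted(kwargs.items()))))


def apply_rmpad(log):
    # Hypothetical stand-in: rmpad takes no options.
    log.append(("rmpad", {}))


PATCHES = {"liger": apply_liger, "rmpad": apply_rmpad}
ACCEPTS_KWARGS = {"liger"}


def apply_patches(monkey_patch_kwargs):
    """Apply patches in the listed order, forwarding kwargs only where accepted."""
    kwargs = dict(monkey_patch_kwargs)
    patch_types = kwargs.pop("patch_type")
    log = []
    for name in patch_types:
        if name not in PATCHES:
            raise ValueError(f"Unknown patch_type {name!r}")
        if name in ACCEPTS_KWARGS:
            PATCHES[name](log, **kwargs)
        else:
            PATCHES[name](log)
    return log


applied = apply_patches(
    {
        "patch_type": ["liger", "rmpad"],
        "fused_linear_cross_entropy": True,
        "rms_norm": True,
        "swiglu": True,
    }
)
print(applied)
```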

## Expert Parallelism

Expert Parallelism (EP) distributes the routed MoE experts across GPUs.
Configure `ep_degree` to match your GPU count (e.g., 2, 4, 8). The FSDP wrap
branches on `decoder_layer.layer_type` (`linear_attn` vs `self_attn`) so that
the gated-delta-net and softmax-attention layers each get the right sharding
plan, while the experts are sharded along the expert dimension via the
`Qwen3_5MoeExperts` `ParallelStyle`.
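
To make the relationship between GPU count, `ep_degree`, and expert ownership concrete, here is a minimal sketch. It assumes an even, contiguous split of experts across EP ranks and a `(world_size // ep_degree, ep_degree)` mesh; the function name `ep_layout` and the expert count of 64 are hypothetical, and the actual `ParallelStyle` may assign experts differently.

```python
def ep_layout(world_size, ep_degree, num_experts):
    """Return the 2D mesh shape and the expert index range owned by each EP rank."""
    if world_size % ep_degree != 0:
        raise ValueError("world_size must be divisible by ep_degree")
    if num_experts % ep_degree != 0:
        raise ValueError("num_experts must be divisible by ep_degree")
    mesh_shape = (world_size // ep_degree, ep_degree)
    per_rank = num_experts // ep_degree
    owned = {r: range(r * per_rank, (r + 1) * per_rank) for r in range(ep_degree)}
    return mesh_shape, owned


# Hypothetical setup: 8 GPUs, ep_degree=4, 64 routed experts.
mesh_shape, owned = ep_layout(world_size=8, ep_degree=4, num_experts=64)
print(mesh_shape)          # sizes of the (data-parallel, ep) mesh axes
print(list(owned[1])[:3])  # first experts held by EP rank 1
```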

## Merging EP Checkpoints

FSDP2 + EP checkpoints store expert weights as **multi-axis DTensors** with
placements like `(Shard(dim=1), Shard(dim=0))` on a 2D mesh
`(dp_shard_mod_ep, ep)`. The checkpoint merger consolidates these correctly
as of this branch.
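
The placement `(Shard(dim=1), Shard(dim=0))` means the first mesh axis splits tensor dim 1 and the second mesh axis (`ep`) splits tensor dim 0. The pure-Python sketch below mimics that layout on a small matrix and reassembles it. It is an illustration of the idea only, assuming evenly divisible shapes; `shard_2d` and `merge_2d` are hypothetical helpers, and the real merger operates on torch DTensors.

```python
def shard_2d(matrix, dp_size, ep_size):
    """Split rows across the ep axis (tensor dim 0) and columns across the dp axis (tensor dim 1)."""
    rows, cols = len(matrix), len(matrix[0])
    r, c = rows // ep_size, cols // dp_size
    return {
        (dp, ep): [row[dp * c:(dp + 1) * c] for row in matrix[ep * r:(ep + 1) * r]]
        for dp in range(dp_size)
        for ep in range(ep_size)
    }


def merge_2d(shards, dp_size, ep_size):
    """Reassemble the full matrix from shards keyed by (dp, ep) mesh coordinates."""
    full = []
    for ep in range(ep_size):  # undo Shard(dim=0) on the ep axis
        blocks = [shards[(dp, ep)] for dp in range(dp_size)]
        for row_parts in zip(*blocks):  # undo Shard(dim=1) on the dp axis
            full.append([x for part in row_parts for x in part])
    return full


original = [[r * 4 + c for c in range(4)] for r in range(4)]
shards = shard_2d(original, dp_size=2, ep_size=2)
restored = merge_2d(shards, dp_size=2, ep_size=2)
print(restored == original)
```

Gathering along only one mesh axis would leave each tensor partially sharded, which is why both axes have to be undone.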

Merge a checkpoint into a single HF-loadable directory with:

```bash
python -m lmms_engine.merger \
    --checkpoint_path ./output/qwen3_5_moe_a3b_ep8/checkpoint-1000 \
    --output_path ./output/qwen3_5_moe_a3b_ep8/merged-1000 \
    --model_general_type image_text_to_text
```

`--model_general_type image_text_to_text` is **required** for the same reason
as at train time: without it the merger instantiates `Qwen3_5MoeForCausalLM`
(text-only) from the saved config and crashes with
`'Qwen3_5MoeConfig' has no attribute 'vocab_size'` (the vocab lives on
`config.text_config`, which the multimodal wrapper knows about but the
text-only causal-LM does not).
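
The ambiguity can be sketched without `transformers` at all. In the simplified model below, `MAPPINGS` is a hypothetical stand-in for the AutoModel registries and `resolve_model_class` is a hypothetical helper that returns class names as strings; it only illustrates why first-match auto-detection lands on the text-only class and how an explicit override avoids that.

```python
# Hypothetical stand-ins for the AutoModel registries; the real ones live in transformers.
MAPPINGS = {
    "causal_lm": {"Qwen3_5MoeConfig": "Qwen3_5MoeForCausalLM"},
    "image_text_to_text": {"Qwen3_5MoeConfig": "Qwen3_5MoeForConditionalGeneration"},
}


def resolve_model_class(config_name, model_general_type=None):
    """Pick a model class name, honoring an explicit override before auto-detection."""
    if model_general_type is not None:
        if model_general_type not in MAPPINGS:
            raise ValueError(f"Unknown model_general_type={model_general_type!r}")
        return MAPPINGS[model_general_type][config_name]
    # Auto-detection: the first mapping that knows the config wins
    # (dicts preserve insertion order, so causal_lm is checked first).
    for mapping in MAPPINGS.values():
        if config_name in mapping:
            return mapping[config_name]
    raise ValueError(f"{config_name} is not in any mapping")


auto_choice = resolve_model_class("Qwen3_5MoeConfig")
pinned_choice = resolve_model_class("Qwen3_5MoeConfig", "image_text_to_text")
print(auto_choice)
print(pinned_choice)
```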
89 changes: 89 additions & 0 deletions examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml
@@ -0,0 +1,89 @@
# Unified LMMs Engine Training Configuration for Qwen3.5-MoE (Qwen3.6-35B-A3B)
#
# Multimodal MoE model with hybrid attention (linear / full per layer) and
# shared_expert + routed experts. Expert Parallelism (EP) is supported; sequence
# parallelism (Ulysses) is NOT supported on this model (the gated-delta linear
# attention path is not SP-safe).
#
# For smaller machines, `ep_degree=4` works on 4 GPUs (verified with
# `cicd/run_traincicd.sh --model-name qwen3_5_moe --gpu-count 4`).


trainer_type: fsdp2_trainer

# Dataset configuration - inline dataset definitions
dataset_config:
  dataset_type: vision_iterable
  dataset_format: yaml

  datasets:
    - path: data/lmms_engine_test/text_example/open_thoughts_5k_parquet
      data_folder: ""
      data_type: parquet

  # Processor configuration - qwen3_5_moe uses the qwen3_vl processor
  processor_config:
    processor_name: "Qwen/Qwen3.6-35B-A3B"
    processor_type: "qwen3_vl"

  packing: false
  packing_strategy: first_fit
  packing_length: 10240
  video_backend: qwen_vl_utils
  filter_overlong: true

# Model configuration
model_config:
  load_from_pretrained_path: "Qwen/Qwen3.6-35B-A3B"
  # Qwen3_5MoeConfig is registered in both causal_lm and image_text_to_text
  # auto-mappings; pin to image_text_to_text so we get the multimodal
  # Qwen3_5MoeForConditionalGeneration wrapper (vision tower + LM), not the
  # text-only Qwen3_5MoeForCausalLM.
  model_general_type: image_text_to_text
  attn_implementation: "flash_attention_2"
  # Two independent patches registered under qwen3_5_moe: "liger" and "rmpad".
  # The trainer runner applies them in order ["liger", "rmpad"].
  # Only liger accepts kwargs; rmpad takes none.
  monkey_patch_kwargs:
    patch_type: ["liger", "rmpad"]
    fused_linear_cross_entropy: true
    rms_norm: true
    swiglu: true

# Training arguments, mostly compatible with HuggingFace Trainer
trainer_args:
  per_device_train_batch_size: 1
  learning_rate: 1.0e-06
  weight_decay: 0.0
  gradient_accumulation_steps: 1
  gradient_checkpointing: true
  max_steps: 500
  num_train_epochs: 1
  save_steps: 100
  save_total_limit: 1
  report_to: "none"
  output_dir: "./output/qwen3_5_moe_a3b_ep8"
  warmup_ratio: 0.0
  warmup_steps: 100
  run_name: "qwen3_5_moe_a3b_ep8"
  eval_strategy: "no"
  logging_steps: 1
  group_by_length: false
  dataloader_num_workers: 0
  bf16: true
  lr_scheduler_type: "constant"
  use_liger_kernel: true
  use_rmpad: true
  fsdp2: true
  fsdp_config:
    transformer_layer_cls_to_wrap: ["Qwen3_5MoeDecoderLayer"]
    reshard_after_forward: false
  # Sequence parallelism is not supported for qwen3_5_moe (linear-attention
  # path is not SP-safe). Keep sp_ulysses_degree=1.
  sp_ulysses_degree: 1
  # Expert Parallelism degree. 8 for an 8-GPU node; use 4 on a 4-GPU box.
  ep_degree: 8
  enable_profiler: false
  profiler_config:
    start_step: 1
    end_step: 3
11 changes: 11 additions & 0 deletions examples/qwen3_5_moe/run.sh
@@ -0,0 +1,11 @@
# Number of GPUs
NGPUS=8

# Training command
torchrun --nproc_per_node=${NGPUS} \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr=127.0.0.1 \
    --master_port=12357 \
    -m lmms_engine.launch.cli \
    config_yaml=examples/qwen3_5_moe/qwen3_5_moe_ep8.yaml
48 changes: 35 additions & 13 deletions src/lmms_engine/mapping_func.py
@@ -118,20 +118,42 @@ def create_model_from_pretrained(
     return model_class
 
 
-def create_model_from_config(model_type, config):
+def create_model_from_config(model_type, config, model_general_type: str | None = None):
+    """Build a model class + config from a model_type string and a config dict.
+
+    Args:
+        model_type: HF model_type string (e.g. ``"qwen3_5_moe"``).
+        config: dict of kwargs forwarded to the corresponding config class.
+        model_general_type: Optional override; one of the keys in
+            ``AUTO_REGISTER_MODEL_MAPPING``. Use it to disambiguate when the
+            same config is registered under multiple AutoModel mappings (e.g.
+            ``Qwen3_5MoeConfig`` is in both ``causal_lm`` and
+            ``image_text_to_text``; without the override we'd silently pick
+            the wrong wrapper).
+    """
     from transformers.models.auto.configuration_auto import CONFIG_MAPPING
 
-    if model_type in CONFIG_MAPPING:
-        config_class = CONFIG_MAPPING[model_type]
-        m_config = config_class(**config)
-        if type(m_config) in AutoModelForCausalLM._model_mapping.keys():
-            model_class = AutoModelForCausalLM
-        elif type(m_config) in AutoModelForImageTextToText._model_mapping.keys():
-            model_class = AutoModelForImageTextToText
-        elif type(m_config) in AutoModelForMaskedLM._model_mapping.keys():
-            model_class = AutoModelForMaskedLM
-        elif type(m_config) in AutoModel._model_mapping.keys():
-            model_class = AutoModel
-    else:
+    if model_type not in CONFIG_MAPPING:
         raise ValueError(f"Model type '{model_type}' is not found in CONFIG_MAPPING.")
+    config_class = CONFIG_MAPPING[model_type]
+    m_config = config_class(**config)
+
+    if model_general_type is not None:
+        if model_general_type not in AUTO_REGISTER_MODEL_MAPPING:
+            raise ValueError(
+                f"Unknown model_general_type={model_general_type!r}; "
+                f"choose one of {list(AUTO_REGISTER_MODEL_MAPPING)}"
+            )
+        return AUTO_REGISTER_MODEL_MAPPING[model_general_type], m_config
+
+    if type(m_config) in AutoModelForCausalLM._model_mapping.keys():
+        model_class = AutoModelForCausalLM
+    elif type(m_config) in AutoModelForImageTextToText._model_mapping.keys():
+        model_class = AutoModelForImageTextToText
+    elif type(m_config) in AutoModelForMaskedLM._model_mapping.keys():
+        model_class = AutoModelForMaskedLM
+    elif type(m_config) in AutoModel._model_mapping.keys():
+        model_class = AutoModel
+    else:
+        raise ValueError(f"Model type '{model_type}' is not in any AutoModel mapping.")
     return model_class, m_config
19 changes: 18 additions & 1 deletion src/lmms_engine/merger/__main__.py
@@ -38,6 +38,19 @@ def parse_args() -> argparse.Namespace:
         help="Type of checkpoint to merge: 'regular' for main model weights, 'ema' for EMA weights",
     )
 
+    parser.add_argument(
+        "--model_general_type",
+        type=str,
+        default=None,
+        choices=["causal_lm", "masked_lm", "image_text_to_text", "general"],
+        help=(
+            "Override AutoModel class used to instantiate the merged model. "
+            "Needed when the same config is registered under multiple AutoModel "
+            "mappings (e.g. Qwen3_5MoeConfig is in both causal_lm and "
+            "image_text_to_text). If unset, falls back to auto-detection."
+        ),
+    )
+
     return parser.parse_args()
 
 
@@ -53,7 +66,11 @@ def main() -> None:
 
     print(f"Merging {args.checkpoint_type} checkpoint from {checkpoint_path}")
     merger = FSDP2Merger(checkpoint_type=args.checkpoint_type)
-    result_path = merger.merge(checkpoint_path, output_path=output_path)
+    result_path = merger.merge(
+        checkpoint_path,
+        output_path=output_path,
+        model_general_type=args.model_general_type,
+    )
 
     print(f"Merged checkpoint saved to: {result_path}")
 