Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions kempnerforge/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ModelConfig:
moe_gradient_scale: bool = False # Per-expert gradient normalization
moe_bias_schedule: str = "constant" # "constant", "cosine_decay", "linear_warmup"
moe_packed_experts: bool = False # Pack expert weights into one tensor per projection
moe_expert_ffn_multiplier: float = 1.0 # expert FFN hidden vs dense (0.5 = fine-grained)

def __post_init__(self) -> None:
if self.n_kv_heads is None:
Expand Down Expand Up @@ -99,6 +100,8 @@ def __post_init__(self) -> None:
f"Unknown moe_bias_schedule: '{self.moe_bias_schedule}'. "
"Options: 'constant', 'cosine_decay', 'linear_warmup'"
)
if self.moe_expert_ffn_multiplier <= 0:
raise ValueError("moe_expert_ffn_multiplier must be positive")

@property
def is_moe(self) -> bool:
Expand All @@ -118,6 +121,19 @@ def computed_ffn_hidden_dim(self) -> int:
raw = int(4 * self.dim * (2 / 3) * self.ffn_dim_multiplier)
return 256 * math.ceil(raw / 256)

@property
def computed_expert_ffn_hidden_dim(self) -> int:
"""Per-expert FFN hidden dim = ``computed_ffn_hidden_dim`` * ``moe_expert_ffn_multiplier``.

Rounded to a multiple of 16 for tensor-core alignment. With the default
multiplier 1.0 this equals ``computed_ffn_hidden_dim`` (zero behavior
change); set 0.5 for fine-grained experts so top-2 routing matches the
dense FFN's activated FLOPs (2 * F/2 = F). Applies to routed and shared
experts wherever they are built (build_moe and MoMa's ExpertChoiceMoE).
"""
raw = int(self.computed_ffn_hidden_dim * self.moe_expert_ffn_multiplier)
return max(16, 16 * round(raw / 16))

@property
def num_params_estimate(self) -> int:
"""Rough total parameter count estimate (excluding embedding if tied).
Expand All @@ -134,11 +150,13 @@ def num_params_estimate(self) -> int:
norm = 2 * d # 2 norms per layer

if self.is_moe:
he = self.computed_expert_ffn_hidden_dim # may be fine-grained (< h)
expert_mlp = d * he + d * he + he * d
n_moe = sum(1 for i in range(self.n_layers) if (i + 1) % self.moe_frequency == 0)
n_dense = self.n_layers - n_moe
router = d * self.num_experts # gate linear per MoE layer
shared_mlp = self.moe_shared_experts * mlp
moe_per_layer = attn + self.num_experts * mlp + router + shared_mlp + norm
shared_mlp = self.moe_shared_experts * expert_mlp
moe_per_layer = attn + self.num_experts * expert_mlp + router + shared_mlp + norm
dense_per_layer = attn + mlp + norm
layer_params = n_moe * moe_per_layer + n_dense * dense_per_layer
else:
Expand Down
2 changes: 1 addition & 1 deletion kempnerforge/model/moma.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def __init__(
{
m: ExpertChoiceMoE(
dim=config.dim,
hidden_dim=config.computed_ffn_hidden_dim,
hidden_dim=config.computed_expert_ffn_hidden_dim,
num_experts=experts_per_modality[m],
capacity_factor=capacity_factor_per_modality[m],
activation=config.activation,
Expand Down
2 changes: 1 addition & 1 deletion kempnerforge/model/mot.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def __init__(
{
m: build_moe(
dim=config.dim,
hidden_dim=config.computed_ffn_hidden_dim,
hidden_dim=config.computed_expert_ffn_hidden_dim,
num_experts=config.num_experts,
top_k=config.moe_top_k,
activation=config.activation,
Expand Down
2 changes: 1 addition & 1 deletion kempnerforge/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, config: ModelConfig, layer_idx: int) -> None:
if use_moe:
self.mlp = build_moe(
dim=config.dim,
hidden_dim=config.computed_ffn_hidden_dim,
hidden_dim=config.computed_expert_ffn_hidden_dim,
num_experts=config.num_experts,
top_k=config.moe_top_k,
activation=config.activation,
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,35 @@ def test_moe_validation_passes(self):
assert m.is_moe is True
assert m.moe_top_k == 2

def test_expert_ffn_multiplier_default_matches_dense(self):
m = ModelConfig(dim=1024, n_heads=16, num_experts=4, moe_top_k=2)
assert m.moe_expert_ffn_multiplier == 1.0
assert m.computed_expert_ffn_hidden_dim == m.computed_ffn_hidden_dim

def test_expert_ffn_multiplier_half(self):
m = ModelConfig(
dim=1024, n_heads=16, num_experts=4, moe_top_k=2, moe_expert_ffn_multiplier=0.5
)
assert m.computed_expert_ffn_hidden_dim == m.computed_ffn_hidden_dim // 2

def test_expert_ffn_isoflop_top2(self):
# top-2 with half-size experts matches the dense FFN's activated FLOPs (2 * F/2 = F)
m = ModelConfig(
dim=1024, n_heads=16, num_experts=4, moe_top_k=2, moe_expert_ffn_multiplier=0.5
)
assert m.moe_top_k * m.computed_expert_ffn_hidden_dim == m.computed_ffn_hidden_dim

def test_expert_ffn_multiplier_rejects_nonpositive(self):
with pytest.raises(ValueError, match="moe_expert_ffn_multiplier must be positive"):
ModelConfig(num_experts=4, moe_top_k=2, moe_expert_ffn_multiplier=0.0)

def test_finegrained_reduces_param_estimate(self):
common = dict(dim=512, n_layers=2, n_heads=8, vocab_size=1000, num_experts=4, moe_top_k=2)
assert (
ModelConfig(**common, moe_expert_ffn_multiplier=0.5).num_params_estimate
< ModelConfig(**common).num_params_estimate
)

def test_moe_rejects_top_k_greater_than_experts(self):
with pytest.raises(ValueError, match="moe_top_k.*must be <= num_experts"):
ModelConfig(num_experts=8, moe_top_k=16)
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,26 @@ def test_dense_model_completely_unchanged(self):
actual = sum(p.numel() for p in model.parameters())
assert actual == config.num_params_estimate

def test_finegrained_experts_built_smaller(self):
"""moe_expert_ffn_multiplier=0.5 -> experts have F/2 hidden; dense layers unchanged."""
config = ModelConfig(
**_SMALL, num_experts=4, moe_top_k=2, moe_frequency=2, moe_expert_ffn_multiplier=0.5
)
f = config.computed_ffn_hidden_dim
model = Transformer(config)
for _name, layer in model.layers.items():
if isinstance(layer.mlp, MoEMLP): # layers 1, 3
assert layer.mlp.experts[0].gate_proj.weight.shape[0] == f // 2
else: # dense layers 0, 2
assert layer.mlp.gate_proj.weight.shape[0] == f

def test_finegrained_param_count_matches(self):
"""num_params_estimate matches actual params for a fine-grained MoE model."""
config = ModelConfig(**_SMALL, num_experts=4, moe_top_k=2, moe_expert_ffn_multiplier=0.5)
model = Transformer(config)
actual = sum(p.numel() for p in model.parameters())
assert actual == config.num_params_estimate


# ---------------------------------------------------------------------------
# SigmoidTopKRouter (DeepSeek-V3 style)
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test_moma.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ def _config(
)


def test_moma_finegrained_experts():
"""MoMa expert-choice MoE experts honor moe_expert_ffn_multiplier (fine-grained)."""
cfg = ModelConfig(
dim=64, n_layers=2, n_heads=4, vocab_size=128, max_seq_len=64, moe_expert_ffn_multiplier=0.5
)
ffn = MoMaFFN(
cfg,
modalities=("image", "text"),
experts_per_modality={"image": 4, "text": 4},
capacity_factor_per_modality={"image": 1.0, "text": 1.0},
)
expert = ffn.experts["text"].experts[0]
assert expert.gate_proj.weight.shape[0] == cfg.computed_ffn_hidden_dim // 2


# ---------------------------------------------------------------------------
# MoMaConfig
# ---------------------------------------------------------------------------
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_mot.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,27 @@ def _config(
)


def test_mot_finegrained_experts():
"""MoT MoE experts honor moe_expert_ffn_multiplier (fine-grained)."""
from kempnerforge.model.moe import MoEMLP

cfg = ModelConfig(
dim=64,
n_layers=2,
n_heads=4,
vocab_size=128,
max_seq_len=64,
num_experts=4,
moe_top_k=2,
moe_frequency=1,
moe_expert_ffn_multiplier=0.5,
)
block = MoTBlock(cfg, modalities=("image", "text"), layer_idx=0)
moe = block.mlp["text"]
assert isinstance(moe, MoEMLP)
assert moe.experts[0].gate_proj.weight.shape[0] == cfg.computed_ffn_hidden_dim // 2


# ---------------------------------------------------------------------------
# MoTAttention — structural
# ---------------------------------------------------------------------------
Expand Down
Loading