From d893b7006bd3a531fa9ef6317e1bd6f71c2454ac Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Sat, 6 Dec 2025 06:26:50 +0000 Subject: [PATCH 1/2] Fix MoE crash and hang issues Signed-off-by: Jevin Jiang --- tpu_inference/kernels/fused_moe/v1/kernel.py | 361 ++++++++++-------- .../layers/vllm/quantization/unquantized.py | 74 +++- 2 files changed, 259 insertions(+), 176 deletions(-) diff --git a/tpu_inference/kernels/fused_moe/v1/kernel.py b/tpu_inference/kernels/fused_moe/v1/kernel.py index 76604f893..c7eeed74e 100644 --- a/tpu_inference/kernels/fused_moe/v1/kernel.py +++ b/tpu_inference/kernels/fused_moe/v1/kernel.py @@ -19,7 +19,7 @@ def align_to(x, a): def get_dtype_packing(dtype): - bits = dtypes.bit_width(dtype) + bits = dtypes.itemsize_bits(dtype) return 32 // bits @@ -65,18 +65,19 @@ def ref_moe( top_k: int, *, renormalize_topk_logits: bool = False, - activation="silu", + act_fn: str = "silu", subc_quant_wsz: int | None = None, w1_scale: ( jax.Array | None - ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size) + ) = None, # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size) w2_scale: ( jax.Array | None - ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size) - b1: jax.Array | None = None, # (num_experts, 2, intermediate_size) - b2: jax.Array | None = None, # (num_experts, hidden_size) + ) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size) + b1: jax.Array + | None = None, # F32(num_experts, 2, 1, intermediate_size) + b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size) ): n_tokens = tokens.shape[0] # num_tokens @@ -97,7 +98,7 @@ def ref_moe( # Process each token individually for i in range(n_tokens): - curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, d_model] + curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, hidden_size] assigned_expert_ids = top_k_indices[ i] # [top_k] - indices of selected experts for token i tok_expert_act = [] @@ -108,19 +109,19 @@ def ref_moe( expert_w1 = w1[expert_id, 0].astype(jnp.float32) expert_w3 = w1[expert_id, 1].astype(jnp.float32) if w1_scale is not None: - expert_w1 *= jnp.repeat(w1_scale[expert_id, 0], + expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0], subc_quant_wsz, axis=0)[:hidden_size] - expert_w3 *= jnp.repeat(w1_scale[expert_id, 1], + expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0], subc_quant_wsz, axis=0)[:hidden_size] expert_weight_1 = jnp.concat( [expert_w1, expert_w3], - axis=-1) # [d_model, 2 * intermediate_size] + axis=-1) # [hidden_size, 2 * intermediate_size] expert_weight_2 = w2[expert_id].astype( - jnp.float32) # [intermediate_size, d_model] + jnp.float32) # [intermediate_size, hidden_size] if w2_scale is not None: - expert_weight_2 *= jnp.repeat(w2_scale[expert_id], + expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0], subc_quant_wsz, axis=0)[:intermediate_size] @@ -132,32 +133,33 @@ def ref_moe( gmm_1_out, 2, axis=-1) # [1, intermediate_size], [1, intermediate_size] if b1 is not None: - gmm1_w1_proj += b1[expert_id:expert_id + 1, 0] - gmm1_w3_proj += b1[expert_id:expert_id + 1, 1] + gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0] + gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0] # Apply gated activation: activation(gate) * up - act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation) + act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn) # Second linear layer (down projection) - gmm_2_out = act @ expert_weight_2 # [1, d_model] + gmm_2_out = act @ expert_weight_2 # [1, hidden_size] if 
b2 is not None: - gmm_2_out += b2[expert_id:expert_id + 1] + gmm_2_out += b2[expert_id:expert_id + 1, 0] tok_expert_act.append(gmm_2_out) # Combine outputs from all selected experts experts_act = jnp.concatenate(tok_expert_act, - axis=0) # [top_k, d_model] + axis=0) # [top_k, hidden_size] # Weighted sum using top-k gating weights top_k_weights = top_k_logits[i] # [top_k] top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1] weighted_output = jnp.sum(experts_act * top_k_weights, axis=0, - keepdims=True) # [1, d_model] + keepdims=True) # [1, hidden_size] t_outputs.append(weighted_output.astype(tokens.dtype)) - return jnp.concatenate(t_outputs, axis=0) # [num_tokens, d_model] + return jnp.concatenate(t_outputs, + axis=0) # [actual_num_tokens, hidden_size] def _fused_ep_moe_kernel( @@ -177,7 +179,7 @@ def _fused_ep_moe_kernel( # Output output_hbm, # (local_num_tokens, hidden_size) # Scratch - t2e_routing_x2_smem, # (2, bt, padded_num_experts) + t2e_routing_x2_smem, # (2, bt, padded_top_k) d2e_count_x2_smem, # (2, num_devices, 1, padded_num_experts) expert_offsets_x2_smem, # (2, 2, padded_num_experts): for a2a_s and a2a_g expert_starts_x2_smem, # (2, 1, padded_num_experts) @@ -227,6 +229,11 @@ def _fused_ep_moe_kernel( local_num_tokens = tokens_hbm.shape[0] local_num_experts, intermediate_size, hidden_size = w2_hbm.shape right_id = (my_id + 1) % num_devices + num_experts = a2a_g_hbm.shape[0] + padded_num_experts = d2e_count_x2_smem.shape[-1] + padded_top_k = t2e_routing_x2_smem.shape[-1] + assert padded_num_experts == align_to(num_experts, 128) + assert padded_top_k == align_to(top_k, 128) t_dtype = tokens_hbm.dtype t_packing = get_dtype_packing(t_dtype) @@ -300,35 +307,40 @@ def wait_fetch_b_gating(bt_id): def get_top_k(input, top_k, renormalize_topk_logits): assert len(input.shape) == 2, input.shape input = input.astype(jnp.float32) + padded_k_shape = (input.shape[0], padded_top_k) top_k_logits_lst = [] top_k_indices_lst = [] t2e = jnp.zeros(input.shape, dtype=jnp.int32) - t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32) + t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32) iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1) - top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32) + padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1) + top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32) for k_id in range(top_k): # TODO(jevinjiang): return both top_k values and indices in Mosaic top_k_logits = jnp.broadcast_to( - jnp.max(input, axis=1, keepdims=True), - (input.shape[0], 128)).astype(input.dtype) + jnp.max(input[:, :num_experts], axis=1, keepdims=True), + padded_k_shape, + ).astype(input.dtype) + top_k_logits_lst.append(top_k_logits) if renormalize_topk_logits: top_k_logits_sum += top_k_logits - top_k_logits_lst.append(top_k_logits) # TODO(jevinjiang): support bf16 argmax in Mosaic top_k_indices = jnp.broadcast_to( - jnp.argmax(input, axis=1, keepdims=True), input.shape) + jnp.argmax(input[:, :num_experts], axis=1, keepdims=True), + padded_k_shape, + ) top_k_indices_lst.append(top_k_indices) - t2e_routing = jnp.where(iota == k_id, top_k_indices, t2e_routing) - mask = iota == top_k_indices + t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices, + t2e_routing) + mask = iota == broadcast_minor(top_k_indices, input.shape) t2e += mask.astype(jnp.int32) if k_id != top_k - 1: input = jnp.where(mask, -jnp.inf, input) if renormalize_topk_logits: for k_id in range(top_k): - top_k_logits_lst[ - k_id] = top_k_logits_lst[k_id] / 
top_k_logits_sum + top_k_logits_lst[k_id] /= top_k_logits_sum expert_sizes = jnp.sum(t2e, axis=0, keepdims=True) expert_starts = jnp.zeros_like(expert_sizes) @@ -1071,27 +1083,38 @@ def run_per_bt(bt_id, e_sem_id): all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts, expert_sizes) + sync_barrier() + # Start a2a scatter for first active expert. start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0) def run_per_expert(local_e_id, e_sem_id): sync_barrier() + + # Prefetch weights for CURRENT active expert. + # TODO(jevinjiang): It is hard to prefetch weights in previous iteration + # because the expert_ffn keeps overwriting the buffers. Triple buffering + # could resolve this but it takes more VMEM scratch. Need further + # experiment on this. + start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0) + start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0) + + # Next ids. next_e_sem_id = lax.select(e_sem_id == 0, 1, 0) next_local_e_id = local_e_id + 1 + # Start a2a scatter for NEXT active expert. @pl.when(next_local_e_id < local_num_experts) def _(): start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id) - # Prefetch weights for active expert. - start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0) - start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0) - - # Wait for a2a scatter and perform FFN for active expert. + # Wait a2a scatter for CURRENT active expert. wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id) + + # Perform FFN for CURRENT active expert. expert_ffn(bt_id, e_sem_id, local_e_id) - # Wait for a2a gather to send back tokens for active expert. + # Start a2a gather to send back tokens for CURRENT active expert. start_a2a_gather(bt_id, e_sem_id, local_e_id) # A must-wait before next sync_barrier. @@ -1104,7 +1127,10 @@ def _(): e_sem_id, unroll=False) + # Wait to receive a2a gather for ALL experts. wait_a2a_gather_recv_all() + + # Accumulate results for current batch. output = bt_acc(bt_id, top_k_logits_lst) # Make sure it is safe to overwrite output buffer. @@ -1158,18 +1184,18 @@ def fused_ep_moe( w2: jax.Array, # (num_experts, intermediate_size, hidden_size) gating_output: jax.Array, # (num_tokens, num_experts) top_k: int, + *, renormalize_topk_logits: bool = False, act_fn: str = "silu", - *, subc_quant_wsz: int | None = None, w1_scale: ( jax.Array | None - ) = None, # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size) + ) = None, # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size) w2_scale: ( jax.Array | None - ) = None, # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size) - b1: jax.Array | None = None, # (num_experts, 2, intermediate_size) - b2: jax.Array | None = None, # (num_experts, hidden_size) + ) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size) + b1: jax.Array | None = None, # F32(num_experts, 2, 1, intermediate_size) + b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size) # Kernel tuning parameters. bt: int, bf: int, @@ -1182,75 +1208,159 @@ def fused_ep_moe( ep_axis_name: str = "model", ): # TODO(jevinjiang): move all these assertions to validation function. 
-    # Assert all other axes have length of 1
-    assert len(mesh.shape) == 2, "Expect 2D mesh"
-    assert ("data" in mesh.shape
-            and mesh.shape["data"] == 1), "Expect data axis size of 1"
+    if len(mesh.shape) != 2:
+        raise NotImplementedError("Only 2D mesh is supported.")
+
+    for axis_name in mesh.axis_names:
+        if axis_name == ep_axis_name:
+            continue
+        if mesh.shape[axis_name] != 1:
+            raise NotImplementedError(
+                f"Expected all non-ep axes to have size 1 in {mesh.shape=}")
     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
-    num_tokens, actual_hidden_size = tokens.shape
-    num_experts, actual_intermediate_size, _ = w2.shape
+    num_tokens, hidden_size = tokens.shape
+    num_experts, intermediate_size, _ = w2.shape
 
-    assert num_tokens % ep_size == 0
-    assert num_experts % ep_size == 0
+    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+        raise ValueError(
+            f"Expected {w1.shape=} to be"
+            f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+    if w2.shape != (num_experts, intermediate_size, hidden_size):
+        raise ValueError(f"Expected {w2.shape=} to be"
+                         f" {(num_experts, intermediate_size, hidden_size)}.")
+
+    if gating_output.shape != (num_tokens, num_experts):
+        raise ValueError(
+            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+        )
+
+    if not (0 < top_k <= num_experts):
+        raise ValueError(
+            f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+            " 128. Did you pad them with zeros outside the kernel?")
 
+    if num_tokens % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+    if num_experts % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_experts=} to be aligned to {ep_size=}.")
     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
+    padded_top_k = align_to(top_k, 128)
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)
 
+    # Override bt and btc based on the local token count.
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send bt to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
     if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
         if subc_quant_wsz % 256 != 0:
-            raise NotImplementedError(
-                "Sub-quantized window is not aligned to 256.")
+            raise ValueError(
+                f"Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
-    # We force compute size of contracting dim to subc_quant_wsz. So we can
+    # We force compute size of contracting dim to be subc_quant_wsz. So we can
     # apply same scale after matmul and accumulation. 
bd1c = subc_quant_wsz * t_packing bfc = subc_quant_wsz - assert bfc % 128 == 0 - assert bd1c % (t_packing * 128) == 0 - assert bd2c % (t_packing * 128) == 0 - assert bf % bfc == 0 - assert bd1 % bd1c == 0 - assert bd2 % bd2c == 0 - - btc = min(btc, bt * num_devices) - hidden_size = align_to(actual_hidden_size, 128 * t_packing) - # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic - # masking inside the kernel. - hidden_size = align_to(hidden_size, bd1) - hidden_size = align_to(hidden_size, bd2) - intermediate_size = align_to(actual_intermediate_size, bf) - - # TODO(jevinjiang): we should dump scale as the kernel expected shape in the + if bfc % 128 != 0: + raise ValueError(f"Expected {bfc=} to be aligned to 128.") + if bd1c % (t_packing * 128) != 0: + raise ValueError( + f"Expected {bd1c=} to be aligned to {t_packing * 128}.") + if bd2c % (t_packing * 128) != 0: + raise ValueError( + f"Expected {bd2c=} to be aligned to {t_packing * 128}.") + if bf % bfc != 0: + raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.") + if bd1 % bd1c != 0: + raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.") + if bd2 % bd2c != 0: + raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.") + if hidden_size % bd1 != 0 or hidden_size % bd2 != 0: + raise ValueError( + f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.") + if intermediate_size % bf != 0: + raise ValueError( + f"Expected {intermediate_size=} to be aligned to {bf=}.") + + # Note: we should dump scale as the kernel expected shape in the # checkpoint offline or reshape right after weight loading. if w1_scale is not None: - assert w1_scale.shape[0] == w1.shape[0] - assert w1_scale.shape[1] == w1.shape[1] == 2 - assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz) - assert w1_scale.shape[3] == w1.shape[3] - w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2) + expected_w1_scale_shape = ( + num_experts, + 2, + hidden_size // subc_quant_wsz, + 1, + intermediate_size, + ) + if w1_scale.shape != expected_w1_scale_shape: + raise ValueError( + f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.") + if w1_scale.dtype != jnp.float32: + w1_scale = w1_scale.astype(jnp.float32) if w2_scale is not None: - assert w2_scale.shape[0] == w2.shape[0] - assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz) - assert w2_scale.shape[2] == w2.shape[2] - w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2) + expected_w2_scale_shape = ( + num_experts, + intermediate_size // subc_quant_wsz, + 1, + hidden_size, + ) + if w2_scale.shape != expected_w2_scale_shape: + raise ValueError( + f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.") + if w2_scale.dtype != jnp.float32: + w2_scale = w2_scale.astype(jnp.float32) if b1 is not None: - assert b1.shape[0] == w1.shape[0] - assert b1.shape[1] == w1.shape[1] == 2 - assert b1.shape[2] == w1.shape[3] - b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2) + expected_b1_shape = (num_experts, 2, 1, intermediate_size) + if b1.shape != expected_b1_shape: + raise ValueError( + f"Expected {b1.shape=} to be {expected_b1_shape}.") + if b1.dtype != jnp.float32: + b1 = b1.astype(jnp.float32) if b2 is not None: - assert b2.shape[0] == w2.shape[0] - assert b2.shape[1] == w2.shape[2] - b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2) + expected_b2_shape = (num_experts, 1, hidden_size) + if b2.shape != expected_b2_shape: + raise ValueError( + f"Expected {b2.shape=} to be {expected_b2_shape}.") + if b2.dtype != jnp.float32: 
+ b2 = b2.astype(jnp.float32) # Prepare inputs for the kernel. if padded_num_experts != gating_output.shape[-1]: @@ -1260,83 +1370,11 @@ def fused_ep_moe( constant_values=-jnp.inf, ) - if (hidden_size != actual_hidden_size - or intermediate_size != actual_intermediate_size): - tokens = jnp.pad( - tokens, - ((0, 0), (0, hidden_size - actual_hidden_size)), - constant_values=0, - ) - w1 = jnp.pad( - w1, - ( - (0, 0), - (0, 0), - (0, hidden_size - actual_hidden_size), - (0, intermediate_size - actual_intermediate_size), - ), - constant_values=0, - ) - w2 = jnp.pad( - w2, - ( - (0, 0), - (0, intermediate_size - actual_intermediate_size), - (0, hidden_size - actual_hidden_size), - ), - constant_values=0, - ) - if w1_scale is not None: - w1_scale = jnp.pad( - w1_scale, - ( - (0, 0), - (0, 0), - (0, - cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]), - (0, 0), - (0, intermediate_size - w1_scale.shape[-1]), - ), - constant_values=0, - ) - if w2_scale is not None: - w2_scale = jnp.pad( - w2_scale, - ( - (0, 0), - (0, cdiv(intermediate_size, subc_quant_wsz) - - w2_scale.shape[-3]), - (0, 0), - (0, hidden_size - w2_scale.shape[-1]), - ), - constant_values=0, - ) - if b1 is not None: - b1 = jnp.pad( - b1, - ( - (0, 0), - (0, 0), - (0, 0), - (0, intermediate_size - b1.shape[-1]), - ), - constant_values=0, - ) - if b2 is not None: - b2 = jnp.pad( - b2, - ( - (0, 0), - (0, 0), - (0, hidden_size - b2.shape[-1]), - ), - constant_values=0, - ) - tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing) hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM) - scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}" + renorm_str = "-renorm_k" if renormalize_topk_logits else "" + scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}" fused_moe = jax.named_scope(scope_name)( pl.pallas_call( functools.partial( @@ -1375,7 +1413,7 @@ def fused_ep_moe( out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), scratch_shapes=([ # t2e_routing_x2_smem - pltpu.SMEM((2, bt, padded_num_experts), jnp.int32), + pltpu.SMEM((2, bt, padded_top_k), jnp.int32), # d2e_count_x2_smem pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32), @@ -1552,7 +1590,7 @@ def kernel( a2a_g_hbm_scratch = pl.empty( (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype) - results = kernel( + return kernel( tokens, w1, w2, @@ -1563,4 +1601,3 @@ def kernel( gating_output, a2a_g_hbm_scratch, ) - return results[:, :actual_hidden_size] diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py index 3371ddb23..1a02ef233 100644 --- a/tpu_inference/layers/vllm/quantization/unquantized.py +++ b/tpu_inference/layers/vllm/quantization/unquantized.py @@ -36,6 +36,10 @@ logger = init_logger(__name__) +def align_to(a, b): + return (a + b - 1) // b * b + + @register_quantization_config(get_tpu_quant_method(UNQUANTIZED)) class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig): @@ -223,29 +227,66 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Current format: # w13_weight: (num_experts, 2*intermediate_size, hidden_size) # w2_weight: (num_experts, hidden_size, intermediate_size) + num_experts = w13_weight.shape[0] + intermediate_size = w13_weight.shape[1] // 2 + hidden_size = w13_weight.shape[2] + + padded_intermediate_size = align_to(intermediate_size, 256) + padded_hidden_size = 
align_to(hidden_size, 256) - w13_reshaped = w13_weight.reshape(num_experts, 2, - intermediate_size, hidden_size) + w13_weight = w13_weight.reshape(num_experts, 2, intermediate_size, + hidden_size) + w13_weight = jnp.transpose(w13_weight, (0, 1, 3, 2)) - # Transpose non-constracting dim to right most dim - w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3) - w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2) + # Transpose w2_weight to (num_experts, intermediate_size, hidden_size) + w2_weight = jnp.transpose(w2_weight, (0, 2, 1)) + + w13_weight = jnp.pad( + w13_weight, + ((0, 0), (0, 0), (0, padded_hidden_size - hidden_size), + (0, padded_intermediate_size - intermediate_size)), + constant_values=0) + + w2_weight = jnp.pad( + w2_weight, + ((0, 0), (0, padded_intermediate_size - intermediate_size), + (0, padded_hidden_size - hidden_size)), + constant_values=0) # Apply EP sharding ep_sharding = NamedSharding(self.mesh, P("model")) w13_weight = jax.device_put( - w13_weight_transposed, Format(Layout((0, 1, 2, 3)), - ep_sharding)) - w2_weight = jax.device_put(w2_weight_transposed, - Format(Layout((0, 1, 2)), ep_sharding)) + w13_weight, + Format(Layout((0, 1, 2, 3)), + NamedSharding(self.mesh, P("model", None, None, None)))) + w2_weight = jax.device_put( + w2_weight, + Format(Layout((0, 1, 2)), + NamedSharding(self.mesh, P("model", None, None)))) if self.moe.has_bias: - w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size) + w13_bias = w13_bias.astype(jnp.float32).reshape( + num_experts, 2, 1, intermediate_size) + w2_bias = w2_bias.astype(jnp.float32).reshape( + num_experts, 1, hidden_size) + + w13_bias = jnp.pad( + w13_bias, + ((0, 0), (0, 0), (0, 0), + (0, padded_intermediate_size - intermediate_size)), + constant_values=0) + + w2_bias = jnp.pad(w2_bias, + ((0, 0), (0, 0), + (0, padded_hidden_size - hidden_size)), + constant_values=0) + + # Apply EP sharding w13_bias = jax.device_put( - w13_bias, Format(Layout((0, 1, 2)), ep_sharding)) - w2_bias = jax.device_put(w2_bias, - Format(Layout((0, 1)), ep_sharding)) + w13_bias, Format(Layout((0, 1, 2, 3)), ep_sharding)) + w2_bias = jax.device_put( + w2_bias, Format(Layout((0, 1, 2)), ep_sharding)) else: if layer.use_ep: @@ -319,6 +360,11 @@ def apply( gating_output = jax_view(router_logits) if self.use_kernel: + actual_hidden_size = x.shape[-1] + padded_hidden_size = align_to(actual_hidden_size, 256) + x = jnp.pad(x, + ((0, 0), (0, padded_hidden_size - actual_hidden_size)), + constant_values=0) output = fused_ep_moe( mesh=self.mesh, tokens=x, @@ -332,7 +378,7 @@ def apply( renormalize_topk_logits=layer.renormalize, act_fn=layer.activation, **self.block_size, - ) + )[:, :actual_hidden_size] else: output = fused_moe_func( hidden_states=x, From 8e951c910c8e9663286f37178d31984099611768 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Sat, 6 Dec 2025 06:29:13 +0000 Subject: [PATCH 2/2] Fix bit_width Signed-off-by: Jevin Jiang --- tpu_inference/kernels/fused_moe/v1/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tpu_inference/kernels/fused_moe/v1/kernel.py b/tpu_inference/kernels/fused_moe/v1/kernel.py index c7eeed74e..917bab996 100644 --- a/tpu_inference/kernels/fused_moe/v1/kernel.py +++ b/tpu_inference/kernels/fused_moe/v1/kernel.py @@ -19,7 +19,7 @@ def align_to(x, a): def get_dtype_packing(dtype): - bits = dtypes.itemsize_bits(dtype) + bits = dtypes.bit_width(dtype) return 32 // bits
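
Usage note (illustrative sketch, not taken from either patch): with the padding code deleted from fused_ep_moe, the caller now owns the alignment contract, which is exactly what the unquantized path above does. It pads the hidden and intermediate dims up to a multiple of 256 when processing weights, pads the activations in apply(), and slices the padded hidden columns off the kernel output. A rough sketch of that calling convention, where block_sizes is a hypothetical stand-in for the bt/bf/bd1/bd2 (and their *c) tuning parameters:

import jax.numpy as jnp

from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe


def align_to(x, a):
    return (x + a - 1) // a * a


def call_fused_ep_moe_padded(mesh, tokens, w1, w2, gating_output, top_k,
                             block_sizes):
    # tokens: (num_tokens, hidden); w1: (E, 2, hidden, inter); w2: (E, inter, hidden).
    hidden = tokens.shape[-1]
    inter = w2.shape[1]
    pad_h = align_to(hidden, 256) - hidden
    pad_f = align_to(inter, 256) - inter
    # Zero-pad activations and weights so the kernel's alignment checks pass.
    tokens = jnp.pad(tokens, ((0, 0), (0, pad_h)))
    w1 = jnp.pad(w1, ((0, 0), (0, 0), (0, pad_h), (0, pad_f)))
    w2 = jnp.pad(w2, ((0, 0), (0, pad_f), (0, pad_h)))
    out = fused_ep_moe(mesh=mesh, tokens=tokens, w1=w1, w2=w2,
                       gating_output=gating_output, top_k=top_k,
                       **block_sizes)
    # Strip the padded hidden columns, as apply() does above.
    return out[:, :hidden]

The 256 here mirrors the unquantized layer; the kernel itself only checks 128-lane alignment plus divisibility by the chosen bd1/bd2/bf blocks, so the caller-side pad has to satisfy both.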
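
Related sketch, also illustrative: the kernel no longer inserts the singleton sublane axis into the quantization scales itself, so checkpoints that still store scales in the old (num_experts, 2, hidden_size // subc_quant_wsz, intermediate_size) and (num_experts, intermediate_size // subc_quant_wsz, hidden_size) layouts need a one-time reshape, matching what the deleted expand_dims calls used to do inside fused_ep_moe:

import jax.numpy as jnp


def to_kernel_scale_layout(w1_scale, w2_scale):
    # Old layouts: w1_scale (E, 2, H // wsz, I), w2_scale (E, I // wsz, H).
    # The kernel now expects the singleton axis to already be present:
    #   w1_scale -> F32(E, 2, H // wsz, 1, I)
    #   w2_scale -> F32(E, I // wsz, 1, H)
    w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
    w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
    return w1_scale, w2_scale

Doing this once at weight-loading time (or dumping the scales in the new layout offline, as the comment in the kernel suggests) keeps the per-call path free of reshapes.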
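
Finally, a host-side sketch of the routing math that get_top_k and ref_moe implement: an iterative top-k built from repeated argmax plus masking, with optional renormalization of the selected gating logits. The names below are hypothetical, and the kernel version additionally pads num_experts and top_k up to 128-lane tiles (padded_num_experts, padded_top_k), but the arithmetic is the same:

import jax.numpy as jnp


def topk_route_reference(gating_logits, top_k, renormalize=False):
    # gating_logits: (num_tokens, num_experts)
    logits = gating_logits.astype(jnp.float32)
    num_experts = logits.shape[-1]
    vals, idx = [], []
    for _ in range(top_k):
        v = jnp.max(logits, axis=1)      # best remaining expert per token
        i = jnp.argmax(logits, axis=1)
        vals.append(v)
        idx.append(i)
        # Mask the chosen expert so the next round picks the runner-up.
        chosen = jnp.arange(num_experts)[None, :] == i[:, None]
        logits = jnp.where(chosen, -jnp.inf, logits)
    vals = jnp.stack(vals, axis=1)       # (num_tokens, top_k)
    idx = jnp.stack(idx, axis=1)         # (num_tokens, top_k)
    if renormalize:
        vals = vals / jnp.sum(vals, axis=1, keepdims=True)
    return vals, idx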