Commit 8f30493

Fix moe layer from upstream change (#1274)
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent: 5f7dc4e

3 files changed: +22 −73 lines

tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 5 additions & 22 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Union
 
 import jax
 import jax.numpy as jnp
@@ -119,29 +119,12 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if activation != "silu":
+        if layer.activation != "silu":
             raise NotImplementedError(
                 "Only silu is supported for activation function.")
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
@@ -155,9 +138,9 @@ def apply(
 
         expert_weights = F.softmax(router_logits, dim=-1)
         expert_weights, expert_indices = torch.topk(expert_weights,
-                                                    top_k,
+                                                    layer.top_k,
                                                     dim=-1)
-        if renormalize:
+        if layer.renormalize:
             expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
 
         # cond ffn
```
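This is the fix repeated across all three files: the upstream FusedMoE interface no longer passes routing configuration (top_k, renormalize, scoring_func, activation, and the rest) as arguments to apply(); the method reads them from the layer instead. Below is a minimal runnable sketch of the new attribute-driven routing, where FusedMoEStub is a hypothetical stand-in for vLLM's FusedMoE, which carries these attributes after the upstream change.

```python
# Sketch of the attribute-driven routing used by the patched apply().
# FusedMoEStub is a hypothetical stand-in; the real FusedMoE comes from
# vLLM and exposes top_k/renormalize/scoring_func/activation as attributes.
import torch
import torch.nn.functional as F


class FusedMoEStub(torch.nn.Module):

    def __init__(self,
                 top_k: int,
                 renormalize: bool = True,
                 scoring_func: str = "softmax",
                 activation: str = "silu"):
        super().__init__()
        self.top_k = top_k
        self.renormalize = renormalize
        self.scoring_func = scoring_func
        self.activation = activation


def route(layer: FusedMoEStub,
          router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the routing block above: softmax the logits, keep the
    # top-k experts, optionally renormalize the kept weights.
    if layer.scoring_func != "softmax":
        raise NotImplementedError("Only softmax is supported for scoring_func")
    expert_weights = F.softmax(router_logits, dim=-1)
    expert_weights, expert_indices = torch.topk(expert_weights,
                                                layer.top_k,
                                                dim=-1)
    if layer.renormalize:
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)
    return expert_weights, expert_indices


weights, indices = route(FusedMoEStub(top_k=2), torch.randn(4, 8))
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))  # renormalized
```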

tpu_inference/layers/vllm/quantization/mxfp4.py

Lines changed: 8 additions & 25 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -268,26 +268,9 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
@@ -307,10 +290,10 @@ def apply(
                 b1=w13_bias,
                 b2=w2_bias,
                 gating_output=gating_output,
-                top_k=top_k,
+                top_k=layer.top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                 **self.block_size,
             )
         else:
@@ -321,11 +304,11 @@ def apply(
                 w1_bias=w13_bias,
                 w2_bias=w2_bias,
                 gating_output=gating_output,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
             )
 
         return torch_view(output)
```
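Both kernel call sites now read the same routing flags from the layer, which keeps the fused expert-parallel path and the fallback path consistent by construction. What renormalize_topk_logits toggles can be sketched in jax.numpy; this mirrors the torch routing in compressed_tensors_moe.py and is an assumption about the kernel's observable behavior, not its actual fused implementation.

```python
# jnp sketch of the top-k routing the kernels are configured with.
# Assumption: renormalize_topk_logits rescales the selected top-k routing
# weights to sum to one, matching the torch path in compressed_tensors_moe.py.
import jax
import jax.numpy as jnp


def topk_routing(gating_output: jnp.ndarray, top_k: int,
                 renormalize: bool) -> tuple[jnp.ndarray, jnp.ndarray]:
    scores = jax.nn.softmax(gating_output, axis=-1)
    weights, indices = jax.lax.top_k(scores, top_k)  # (tokens, top_k)
    if renormalize:
        weights = weights / jnp.sum(weights, axis=-1, keepdims=True)
    return weights, indices


gating = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # 4 tokens, 8 experts
weights, indices = topk_routing(gating, top_k=2, renormalize=True)
```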

tpu_inference/layers/vllm/quantization/unquantized.py

Lines changed: 9 additions & 26 deletions
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -303,26 +303,9 @@ def apply(
         layer: torch.nn.Module,
         x: torch.Tensor,
         router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
+        if layer.scoring_func != "softmax":
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
@@ -335,7 +318,7 @@ def apply(
         w2_bias = jax_view(layer.w2_bias)
         gating_output = jax_view(router_logits)
 
-        if self.use_kernel and layer.use_ep:
+        if self.use_kernel:
             output = fused_ep_moe(
                 mesh=self.mesh,
                 tokens=x,
@@ -344,10 +327,10 @@ def apply(
                 b1=w13_bias,
                 b2=w2_bias,
                 gating_output=gating_output,
-                top_k=top_k,
+                top_k=layer.top_k,
                 ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
+                renormalize_topk_logits=layer.renormalize,
+                act_fn=layer.activation,
                 **self.block_size,
             )
         else:
@@ -358,11 +341,11 @@ def apply(
                 w1_bias=w13_bias,
                 w2_bias=w2_bias,
                 gating_output=gating_output,
-                topk=top_k,
-                renormalize=renormalize,
+                topk=layer.top_k,
+                renormalize=layer.renormalize,
                 mesh=self.mesh,
                 use_ep=layer.use_ep,
-                activation=activation,
+                activation=layer.activation,
             )
 
         return torch_view(output)
```
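Besides the signature fix, unquantized.py also relaxes the kernel guard from self.use_kernel and layer.use_ep to self.use_kernel alone, so the fused_ep_moe path no longer requires expert parallelism to be enabled. A structural sketch of the patched dispatch follows; both kernels are hypothetical stubs, and only the branch shape and the layer-attribute plumbing come from the diff above.

```python
# Structural sketch of the patched dispatch in unquantized.py. The two
# stub functions are hypothetical; the real fused_ep_moe and fallback
# operate on JAX arrays over a device mesh.
from types import SimpleNamespace
from typing import Any


def fused_ep_moe_stub(**kwargs: Any) -> str:
    return "fused EP kernel"


def fused_moe_fallback_stub(**kwargs: Any) -> str:
    return "fallback MoE"


def dispatch(use_kernel: bool, layer: Any) -> str:
    if use_kernel:  # previously: use_kernel and layer.use_ep
        return fused_ep_moe_stub(
            top_k=layer.top_k,
            renormalize_topk_logits=layer.renormalize,
            act_fn=layer.activation,
        )
    return fused_moe_fallback_stub(
        topk=layer.top_k,
        renormalize=layer.renormalize,
        use_ep=layer.use_ep,
        activation=layer.activation,
    )


layer = SimpleNamespace(top_k=2, renormalize=True, use_ep=False,
                        activation="silu")
print(dispatch(use_kernel=True, layer=layer))   # fused EP kernel
print(dispatch(use_kernel=False, layer=layer))  # fallback MoE
```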
