 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -127,6 +130,17 @@ def _update_inputs_for_loop_speculation(
             max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)
 
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables
 
     @functools.partial(jax.jit, static_argnums=(0, ))
@@ -138,6 +152,7 @@ def _stack_draft_token_ids(
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -146,7 +161,7 @@ def _prepare_hidden_states_and_input_ids(
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-            self.state, target_hidden_states)
+            state, target_hidden_states)
 
         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
@@ -193,8 +208,8 @@ def prepare_inputs(
                 block_tables=device_array(
                     self.mesh, block_tables))
             target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-                aux_hidden_states, attn_metadata.query_start_loc, input_ids,
-                next_token_ids, num_reqs)
+                self.state, aux_hidden_states, attn_metadata.query_start_loc,
+                input_ids, next_token_ids, num_reqs)
             return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
         # Host copies from the metadata prepared by the runner.
@@ -258,12 +273,13 @@ def prepare_inputs(
 
         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -291,35 +307,51 @@ def _filter_token_and_prepare_initial_inputs(
         )
 
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states], query_start_loc,
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)
 
         return target_hidden_states, input_ids, last_token_indices, attn_metadata
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-        return self._get_draft_token_ids(sample_hidden_states)
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)
 
     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(self.state, hidden_states,
-                                        lora_metadata)
-        return jnp.argmax(logits, axis=-1)
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
 
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-            self, positions: jax.Array, residual: jax.Array,
+            self, state: nnx.State, positions: jax.Array, residual: jax.Array,
            hidden_states: jax.Array,
            last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-        return positions[last_token_indices], residual[
-            last_token_indices], self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids
 
     def propose(
         self,
@@ -346,11 +378,11 @@ def propose(
 
         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)
 
         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0], hidden_states,
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)
 
         draft_token_ids_list = [draft_token_ids]
 
@@ -375,7 +407,8 @@ def propose(
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)
 
         # [batch_size, num_speculative_tokens]
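
The diff applies two patterns: it threads the model's nnx.State into the jitted helpers as an explicit argument instead of reading self.state inside them, and it pins intermediates to explicit layouts with lax.with_sharding_constraint. Below is a minimal standalone sketch of the same two patterns, with toy names (toy_select, an nnx.Linear stand-in model, a hypothetical "dp" mesh axis) that are not part of the tpu_inference code.

# Sketch only: toy model and mesh, not the tpu_inference implementation.
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over all local devices; the real code uses self.mesh.
mesh = Mesh(np.array(jax.devices()), ("dp", ))

# Split an NNX module into its static graph structure and its array state.
model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)


@jax.jit
def toy_select(state: nnx.State, x: jax.Array) -> jax.Array:
    # The state is a traced pytree argument, not a value closed over and
    # embedded into the jitted computation as compile-time constants.
    module = nnx.merge(graphdef, state)
    y = module(x)
    # Pin the intermediate to an explicit layout: PartitionSpec(None, None)
    # keeps both dims unsharded, mirroring the constraints in the diff.
    y = lax.with_sharding_constraint(
        y, NamedSharding(mesh, PartitionSpec(None, None)))
    token_ids = jnp.argmax(y, axis=-1)
    # Fully replicated result, like PartitionSpec() above.
    return lax.with_sharding_constraint(
        token_ids, NamedSharding(mesh, PartitionSpec()))


print(toy_select(state, jnp.ones((2, 4))).sharding)

lax.with_sharding_constraint fixes the sharding of a value at that point inside a jitted computation rather than leaving it to the partitioner to infer, and passing the state as a pytree argument keeps the weights as runtime inputs to the compiled function.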