
Commit 63afd75

py4 and Pooya Moradi authored

[Spec][Eagle3] Improve perf and compilation time (#1192)

Signed-off-by: Pooya Moradi <pooyam@google.com>
Co-authored-by: Pooya Moradi <pooyam@google.com>
1 parent b26a566 commit 63afd75

File tree: 2 files changed, +59 -22 lines

tpu_inference/runner/compilation_manager.py

Lines changed: 7 additions & 3 deletions
@@ -657,6 +657,7 @@ def _precompile_eagle3_helpers(self) -> None:
         self._run_compilation(
             "eagle3_get_draft_token_ids",
             self.runner.drafter._get_draft_token_ids,
+            self.runner.drafter.state,
             hidden_states,
             num_logits=num_logits,
         )
@@ -701,9 +702,9 @@ def filter_token_and_prepare_initial_inputs_wrapper(
                 num_reqs,
         ):
             target_hidden_states, input_ids, last_token_indices, _ = self.runner.drafter._filter_token_and_prepare_initial_inputs(
-                token_indices, query_start_loc, seq_lens, input_ids,
-                aux_hidden_states, attention_metadata, next_token_ids,
-                num_reqs)
+                self.runner.drafter.state, token_indices, query_start_loc,
+                seq_lens, input_ids, aux_hidden_states, attention_metadata,
+                next_token_ids, num_reqs)
             return target_hidden_states, input_ids, last_token_indices

         input_ids = self._create_dummy_tensor((num_tokens, ), jnp.int32)
@@ -780,6 +781,7 @@ def draft_model_fn_wrapper(
         self._run_compilation(
             "eagle3_prepare_hidden_states_and_input_ids",
             self.runner.drafter._prepare_hidden_states_and_input_ids,
+            self.runner.drafter.state,
             aux_hidden_states,
             query_start_loc,
             target_token_ids,
@@ -814,6 +816,7 @@ def draft_model_fn_wrapper(
         self._run_compilation(
             "eagle3_select_inputs_for_loop_speculation",
             self.runner.drafter._select_inputs_for_loop_speculation,
+            self.runner.drafter.state,
             positions,
             hidden_states,
             hidden_states,
@@ -824,6 +827,7 @@ def draft_model_fn_wrapper(
         self._run_compilation(
             "eagle3_select_draft_token_ids",
             self.runner.drafter._select_draft_token_ids,
+            self.runner.drafter.state,
             hidden_states,
             last_token_indices,
             num_tokens=num_tokens,
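
Each precompiled Eagle3 helper now receives the drafter's weights (self.runner.drafter.state) as an explicit argument instead of reading them from self inside the jitted function. A minimal, self-contained sketch of why this likely helps compile time, using a hypothetical TinyDrafter rather than the repo's runner: with static_argnums=(0, ), anything reached through self is embedded in the traced program as a compile-time constant, whereas an explicit nnx.State argument stays a regular traced input.

import functools

import jax
import jax.numpy as jnp
from flax import nnx


class TinyDrafter:
    """Hypothetical stand-in for the Eagle3 drafter used above."""

    def __init__(self, rngs: nnx.Rngs):
        # Split the draft head into a static graph definition and its weights.
        self.graphdef, self.state = nnx.split(nnx.Linear(8, 16, rngs=rngs))

    @functools.partial(jax.jit, static_argnums=(0, ))
    def _get_draft_token_ids(self, state: nnx.State,
                             hidden_states: jax.Array) -> jax.Array:
        # `state` is a traced pytree argument, not a constant captured via self.
        head = nnx.merge(self.graphdef, state)
        logits = head(hidden_states)
        return jnp.argmax(logits, axis=-1)


drafter = TinyDrafter(nnx.Rngs(0))
# The caller passes the state explicitly, mirroring `self.runner.drafter.state`
# in the precompilation calls above.
draft_ids = drafter._get_draft_token_ids(drafter.state, jnp.ones((4, 8)))
print(draft_ids.shape)  # (4,)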

tpu_inference/spec_decode/jax/eagle3.py

Lines changed: 52 additions & 19 deletions
@@ -6,6 +6,9 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
+from flax import nnx
+from jax import lax
+from jax.sharding import NamedSharding, PartitionSpec
 from vllm.config import VllmConfig

 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
@@ -127,6 +130,17 @@ def _update_inputs_for_loop_speculation(
                                          max_num_blocks_per_req)
         new_block_tables = jnp.where(expanded_exceeds_mask, -1, block_tables)

+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        clamped_positions = lax.with_sharding_constraint(
+            clamped_positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        new_seq_lens = lax.with_sharding_constraint(
+            new_seq_lens, NamedSharding(self.mesh, PartitionSpec(None, )))
+        query_start_loc = lax.with_sharding_constraint(
+            query_start_loc, NamedSharding(self.mesh, PartitionSpec()))
+        new_block_tables = lax.with_sharding_constraint(
+            new_block_tables, NamedSharding(self.mesh, PartitionSpec(None, )))
+
         return positions, clamped_positions, new_seq_lens, query_start_loc, new_block_tables

     @functools.partial(jax.jit, static_argnums=(0, ))
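
The new lax.with_sharding_constraint calls pin the layout of these small bookkeeping arrays inside the jitted function, so the GSPMD partitioner does not have to infer a sharding for them (and possibly insert extra reshards). A standalone sketch of the same pattern, under an assumed 1-D mesh and made-up shapes rather than the repo's mesh setup: PartitionSpec() replicates an array across the mesh, while PartitionSpec(None, ) leaves the leading axis unsharded.

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D device mesh; the repo gets its mesh from the runner instead.
mesh = Mesh(jax.devices(), axis_names=("model", ))


@jax.jit
def clamp_positions(positions: jax.Array, max_position: jax.Array) -> jax.Array:
    clamped = jnp.minimum(positions, max_position)
    # Constrain the intermediate to be replicated across the mesh.
    return lax.with_sharding_constraint(
        clamped, NamedSharding(mesh, PartitionSpec()))


positions = jnp.arange(16)
print(clamp_positions(positions, jnp.int32(7)).sharding)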
@@ -138,6 +152,7 @@ def _stack_draft_token_ids(
     @functools.partial(jax.jit, static_argnums=(0, ))
     def _prepare_hidden_states_and_input_ids(
         self,
+        state: nnx.State,
         aux_hidden_states: tuple[jax.Array, ...],
         query_start_loc: jax.Array,
         target_token_ids: jax.Array,
@@ -146,7 +161,7 @@ def _prepare_hidden_states_and_input_ids(
     ) -> tuple[jax.Array, jax.Array, jax.Array]:
         target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)
         target_hidden_states = self.combine_hidden_states_fn(
-            self.state, target_hidden_states)
+            state, target_hidden_states)

         input_ids, last_token_indices = self._prepare_input_ids(
             query_start_loc, target_token_ids, next_token_ids, num_reqs)
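
For context on the two lines that now take state: Eagle3 feeds the draft model a combination of hidden states tapped from several target-model layers, and combine_hidden_states_fn is the learned projection (held in the drafter's state) that fuses them. A toy sketch with made-up dimensions, using a plain matrix in place of that learned combine:

import jax
import jax.numpy as jnp

num_tokens, hidden_size, num_tapped_layers = 6, 8, 3

# Stand-ins for the aux hidden states gathered from the target model.
aux_hidden_states = tuple(
    jax.random.normal(jax.random.PRNGKey(i), (num_tokens, hidden_size))
    for i in range(num_tapped_layers))

# Stand-in for the learned combine projection carried in the drafter state.
combine_w = jax.random.normal(jax.random.PRNGKey(42),
                              (num_tapped_layers * hidden_size, hidden_size))

target_hidden_states = jnp.concatenate(aux_hidden_states, axis=-1)  # (6, 24)
target_hidden_states = target_hidden_states @ combine_w             # (6, 8)
print(target_hidden_states.shape)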
@@ -193,8 +208,8 @@ def prepare_inputs(
                                  block_tables=device_array(
                                      self.mesh, block_tables))
         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            aux_hidden_states, attn_metadata.query_start_loc, input_ids,
-            next_token_ids, num_reqs)
+            self.state, aux_hidden_states, attn_metadata.query_start_loc,
+            input_ids, next_token_ids, num_reqs)
         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     # Host copies from the metadata prepared by the runner.
@@ -258,12 +273,13 @@ def prepare_inputs(

         attn_metadata = replace(attn_metadata, block_tables=block_tables)
         return self._filter_token_and_prepare_initial_inputs(
-            token_indices, query_start_loc, seq_lens, input_ids,
+            self.state, token_indices, query_start_loc, seq_lens, input_ids,
             aux_hidden_states, attn_metadata, next_token_ids, num_reqs)

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _filter_token_and_prepare_initial_inputs(
         self,
+        state: nnx.State,
         token_indices: jax.Array,
         query_start_loc: jax.Array,
         seq_lens: jax.Array,
@@ -291,35 +307,51 @@ def _filter_token_and_prepare_initial_inputs(
         )

         target_hidden_states, input_ids, last_token_indices = self._prepare_hidden_states_and_input_ids(
-            [h[token_indices] for h in aux_hidden_states], query_start_loc,
-            target_token_ids, next_token_ids, num_reqs)
+            state, [h[token_indices] for h in aux_hidden_states],
+            query_start_loc, target_token_ids, next_token_ids, num_reqs)

         return target_hidden_states, input_ids, last_token_indices, attn_metadata

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_draft_token_ids(
         self,
+        state: nnx.State,
         hidden_states: jax.Array,
         last_token_indices: jax.Array,
     ) -> jax.Array:
         sample_hidden_states = hidden_states[last_token_indices]
-        return self._get_draft_token_ids(sample_hidden_states)
+        sample_hidden_states = lax.with_sharding_constraint(
+            sample_hidden_states,
+            NamedSharding(self.mesh, PartitionSpec(None, None)))
+        return self._get_draft_token_ids(state, sample_hidden_states)

     @functools.partial(jax.jit, static_argnums=(0, ))
-    def _get_draft_token_ids(self, hidden_states: jax.Array) -> jax.Array:
+    def _get_draft_token_ids(self, state: nnx.State,
+                             hidden_states: jax.Array) -> jax.Array:
         lora_metadata = None
-        logits = self.compute_logits_fn(self.state, hidden_states,
-                                        lora_metadata)
-        return jnp.argmax(logits, axis=-1)
+        logits = self.compute_logits_fn(state, hidden_states, lora_metadata)
+        draft_token_ids = jnp.argmax(logits, axis=-1)
+        return lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))

     @functools.partial(jax.jit, static_argnums=(0, ))
     def _select_inputs_for_loop_speculation(
-            self, positions: jax.Array, residual: jax.Array,
+            self, state: nnx.State, positions: jax.Array, residual: jax.Array,
             hidden_states: jax.Array,
             last_token_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
-        return positions[last_token_indices], residual[
-            last_token_indices], self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+        positions = positions[last_token_indices]
+        residual = residual[last_token_indices]
+        draft_token_ids = self._select_draft_token_ids(state, hidden_states,
+                                                       last_token_indices)
+
+        positions = lax.with_sharding_constraint(
+            positions, NamedSharding(self.mesh, PartitionSpec(None, )))
+        residual = lax.with_sharding_constraint(
+            residual, NamedSharding(self.mesh, PartitionSpec(None, None)))
+        draft_token_ids = lax.with_sharding_constraint(
+            draft_token_ids, NamedSharding(self.mesh, PartitionSpec()))
+
+        return positions, residual, draft_token_ids

     def propose(
         self,
@@ -346,11 +378,11 @@ def propose(

         if self.num_speculative_tokens == 1:
             return kv_caches, self._select_draft_token_ids(
-                hidden_states, last_token_indices)
+                self.state, hidden_states, last_token_indices)

         positions, hidden_states, draft_token_ids = self._select_inputs_for_loop_speculation(
-            attn_metadata.input_positions, residual[0], hidden_states,
-            last_token_indices)
+            self.state, attn_metadata.input_positions, residual[0],
+            hidden_states, last_token_indices)

         draft_token_ids_list = [draft_token_ids]

@@ -375,7 +407,8 @@ def propose(
                 attn_metadata,
             )
             hidden_states = residual[0]
-            draft_token_ids = self._get_draft_token_ids(new_hidden_states)
+            draft_token_ids = self._get_draft_token_ids(
+                self.state, new_hidden_states)
             draft_token_ids_list.append(draft_token_ids)

         # [batch_size, num_speculative_tokens]
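
The propose loop above collects one greedy draft id per speculative step. A toy sketch (hypothetical shapes, random logits standing in for the draft head) of how those per-step results end up as the [batch_size, num_speculative_tokens] matrix noted in the comment:

import jax
import jax.numpy as jnp

batch_size, vocab_size, num_speculative_tokens = 4, 32, 3
key = jax.random.PRNGKey(0)

draft_token_ids_list = []
for step in range(num_speculative_tokens):
    # In the real loop these logits come from the draft model's head.
    logits = jax.random.normal(jax.random.fold_in(key, step),
                               (batch_size, vocab_size))
    draft_token_ids_list.append(jnp.argmax(logits, axis=-1))

# [batch_size, num_speculative_tokens]
draft_token_ids = jnp.stack(draft_token_ids_list, axis=1)
print(draft_token_ids.shape)  # (4, 3)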
