From 38fada4d959c27b49c40bc37c66d2380fa211d1f Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Thu, 11 Dec 2025 10:57:36 -0800
Subject: [PATCH] StaticAttention runtime support for generate_full_logits=False (#16171)

Summary:
Prefill will produce the last token's logits at position 0 when not
generating full logits. Lookahead decoding requires full logits.

Reviewed By: billmguo

Differential Revision: D88790445
---
 .../runner/static_attention_io_manager.h  | 24 +++++++++++++++++++++---
 examples/models/llama/static_attention.py | 12 +++++++++++-
 .../llama/tests/test_static_attention.py  |  4 +++-
 3 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h
index 06fbffbef83..c4e851f0b0c 100644
--- a/examples/models/llama/runner/static_attention_io_manager.h
+++ b/examples/models/llama/runner/static_attention_io_manager.h
@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
     StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
+    bool generate_full_logits = true;
+    std::optional<size_t> last_valid_token_pos_index = 0;
   };
 
   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -607,9 +609,16 @@ class StaticAttentionIOManager {
       batch_len = std::min(input_len, tokens.size() - i);
       if (input_pos_ + batch_len > config_.max_context_len) {
         ET_LOG(Error, "Maximum context size reached, stopping prefill.");
-        return input_len - 1;
+        return config_.generate_full_logits ? input_len - 1 : 0;
       }
       std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      if (!config_.generate_full_logits && config_.last_valid_token_pos_index) {
+        last_valid_token_pos_ = batch_len - 1;
+        set_input(
+            method,
+            *config_.last_valid_token_pos_index,
+            &last_valid_token_pos_);
+      }
       prepare(method);
       ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
       update(
@@ -622,10 +631,12 @@ class StaticAttentionIOManager {
         auto* logits = logits_tensor.const_data_ptr();
         logits_callback(executorch::runtime::Span(
             logits,
-            logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
+            logits +
+                (config_.generate_full_logits ? batch_len : 1) *
+                    logits_tensor.size(logits_tensor.dim() - 1)));
       }
     }
-    return batch_len - 1;
+    return config_.generate_full_logits ? batch_len - 1 : 0;
   }
 
   /**
@@ -648,6 +659,11 @@ class StaticAttentionIOManager {
       mask.set_causal_mask();
       set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
     }
+    if (!config_.generate_full_logits && config_.last_valid_token_pos_index) {
+      last_valid_token_pos_ = 0;
+      set_input(
+          method, *config_.last_valid_token_pos_index, &last_valid_token_pos_);
+    }
 
     while (true) {
       input_buffer[0] = prev_tok;
@@ -685,6 +701,7 @@ class StaticAttentionIOManager {
       size_t window_size,
      size_t n_verifications,
      std::unordered_map> suffix_caches) {
+    ET_CHECK(config_.generate_full_logits);
     ET_LOG(
         Info,
         "Decoding with lookahead and verification at position %zu",
@@ -968,6 +985,7 @@ class StaticAttentionIOManager {
   std::unordered_map attentionMasks_;
   std::vector rope_freqs_cos_override_;
   std::vector rope_freqs_sin_override_;
+  int64_t last_valid_token_pos_;
 };
 
 } // namespace example
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 871f40ffc69..f97873ce646 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -456,7 +456,10 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, ...].argmax().item())
+            if self.generate_full_logits:
+                new_tokens.append(y[:, :1, ...].argmax().item())
+            else:
+                new_tokens.append(y.argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break
 
@@ -607,6 +610,12 @@ def _run_once(
             freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
             freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        if not self.generate_full_logits:
+            extra_attn_options = {
+                "last_valid_token_pos": torch.tensor([n_tokens - 1], dtype=torch.long)
+            }
+        else:
+            extra_attn_options = {}
         y, attn_updates = model(
             tokens,
             {
@@ -614,6 +623,7 @@ def _run_once(
                 "freqs_cos_override": freqs_cos_override,
                 "freqs_sin_override": freqs_sin_override,
                 "in_cache_state": (self.k_caches, self.v_caches),
+                **extra_attn_options,
             },
         )
         non_padded_len = non_padded_len or n_tokens
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 0d407968c0e..460727449df 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -248,7 +248,9 @@ def test(style, attention_type):
                 )
                 ys.append(y_i)
 
-            self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+            self.assertTrue(
+                torch.isclose(ys[-1].flatten(), expected.flatten(), rtol=1e-3).all()
+            )
 
         for args in itertools.product(
             ["shift_pointer", "smart_mask"], ["static", "static_mha"]
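
Usage note (illustrative sketch, not part of the diff above): the snippet
below restates the indexing convention this patch introduces. With
generate_full_logits enabled, the method output holds one row of logits per
input token, the span handed to logits_callback covers batch_len rows, and
prefill returns batch_len - 1; with it disabled, the output holds a single
row for the last valid token, the span shrinks to one row, and prefill
returns 0. Every identifier in the sketch (pick_last_token_logits, vocab_size,
batch_len, the float logit type) is hypothetical and only mirrors the logic
in the C++ hunks.

#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the patch's convention: return a pointer to
// the logits row of the last valid token in the batch.
const float* pick_last_token_logits(
    const std::vector<float>& logits, // holds rows * vocab_size valid entries
    size_t batch_len,
    size_t vocab_size,
    bool generate_full_logits) {
  // Number of logits rows the runtime exposes (the callback span covers
  // rows * vocab_size values).
  size_t rows = generate_full_logits ? batch_len : 1;
  // Row index of the last valid token, i.e. the position prefill returns.
  size_t last_row = generate_full_logits ? batch_len - 1 : 0;
  (void)rows; // kept only to document the span sizing
  return logits.data() + last_row * vocab_size;
}

The lookahead-and-verification decode path now asserts generate_full_logits
for the same reason the summary states: it consumes logits for more positions
than just the last token, so a single-row output is not enough.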