
Commit b3c44c9

StaticAttention runtime support for generate_full_logits=False
Differential Revision: D88790445
Pull Request resolved: #16171
1 parent 756c709
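For context (not stated in the commit itself): generate_full_logits controls whether the exported model emits one row of logits per input position or only the logits for the last valid token. A minimal sketch of the difference, with illustrative shapes and names rather than the ExecuTorch API:

```python
import torch

vocab_size = 32000
batch_len = 4  # tokens fed in one prefill chunk

# generate_full_logits=True: one logits row per input position;
# the next token comes from the last row, index batch_len - 1.
full = torch.randn(1, batch_len, vocab_size)
next_tok = full[0, batch_len - 1].argmax().item()

# generate_full_logits=False: the model keeps only the last valid
# token's logits, so there is a single row and its index is always 0.
last_only = torch.randn(1, 1, vocab_size)
next_tok = last_only[0, 0].argmax().item()
```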

3 files changed: +35 −5 lines changed

examples/models/llama/runner/static_attention_io_manager.h
Lines changed: 21 additions & 3 deletions

@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
   RopeT* rope_freqs_cos;
   RopeT* rope_freqs_sin;
   StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
+  bool generate_full_logits = true;
+  std::optional<size_t> last_valid_token_pos_index = 0;
 };

 StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -607,9 +609,16 @@ class StaticAttentionIOManager {
     batch_len = std::min(input_len, tokens.size() - i);
     if (input_pos_ + batch_len > config_.max_context_len) {
       ET_LOG(Error, "Maximum context size reached, stopping prefill.");
-      return input_len - 1;
+      return config_.generate_full_logits ? input_len - 1 : 0;
     }
     std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+    if (!config_.generate_full_logits && config_.last_valid_token_pos_index) {
+      last_valid_token_pos_ = batch_len - 1;
+      set_input(
+          method,
+          *config_.last_valid_token_pos_index,
+          &last_valid_token_pos_);
+    }
     prepare(method);
     ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
     update(
@@ -622,10 +631,12 @@ class StaticAttentionIOManager {
       auto* logits = logits_tensor.const_data_ptr<LogitT>();
       logits_callback(executorch::runtime::Span(
           logits,
-          logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
+          logits +
+              (config_.generate_full_logits ? batch_len : 1) *
+                  logits_tensor.size(logits_tensor.dim() - 1)));
     }
   }
-  return batch_len - 1;
+  return config_.generate_full_logits ? batch_len - 1 : 0;
 }

 /**
@@ -648,6 +659,11 @@ class StaticAttentionIOManager {
     mask.set_causal_mask();
     set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
   }
+  if (!config_.generate_full_logits && config_.last_valid_token_pos_index) {
+    last_valid_token_pos_ = 0;
+    set_input(
+        method, *config_.last_valid_token_pos_index, &last_valid_token_pos_);
+  }

   while (true) {
     input_buffer[0] = prev_tok;
@@ -685,6 +701,7 @@ class StaticAttentionIOManager {
     size_t window_size,
     size_t n_verifications,
     std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
+  ET_CHECK(config_.generate_full_logits);
   ET_LOG(
       Info,
       "Decoding with lookahead and verification at position %zu",
@@ -968,6 +985,7 @@ class StaticAttentionIOManager {
   std::unordered_map<size_t, PerCacheLenMasks> attentionMasks_;
   std::vector<RopeT> rope_freqs_cos_override_;
   std::vector<RopeT> rope_freqs_sin_override_;
+  int64_t last_valid_token_pos_;
 };

 } // namespace example
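Two things change at runtime when full logits are disabled: prefill returns 0 instead of batch_len - 1 (there is only one logits row to index), and the span passed to logits_callback covers one row instead of batch_len rows; the last valid position within the padded input reaches the model through the new last_valid_token_pos_index input. Lookahead decoding still needs logits at every position to verify speculated tokens, hence the new ET_CHECK. A sketch of the indexing logic in Python (illustrative only, not the C++ API):

```python
import torch

def logits_span(logits: torch.Tensor, batch_len: int,
                generate_full_logits: bool) -> torch.Tensor:
    # Mirrors the span handed to logits_callback: batch_len rows of
    # vocab-sized logits with full logits, a single row otherwise.
    vocab = logits.size(-1)
    rows = batch_len if generate_full_logits else 1
    return logits.reshape(-1, vocab)[:rows]

def last_logits_index(batch_len: int, generate_full_logits: bool) -> int:
    # Mirrors prefill's return value: the row holding the last valid
    # token's logits, which is always row 0 when only one row exists.
    return batch_len - 1 if generate_full_logits else 0
```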

examples/models/llama/static_attention.py
Lines changed: 11 additions & 1 deletion

@@ -456,7 +456,10 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, ...].argmax().item())
+            if self.generate_full_logits:
+                new_tokens.append(y[:, :1, ...].argmax().item())
+            else:
+                new_tokens.append(y.argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break

@@ -607,13 +610,20 @@ def _run_once(
            freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
        if freqs_sin_override is None:
            freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+       if not self.generate_full_logits:
+           extra_attn_options = {
+               "last_valid_token_pos": torch.tensor([n_tokens - 1], dtype=torch.long)
+           }
+       else:
+           extra_attn_options = {}
        y, attn_updates = model(
            tokens,
            {
                "masks": self.masks,
                "freqs_cos_override": freqs_cos_override,
                "freqs_sin_override": freqs_sin_override,
                "in_cache_state": (self.k_caches, self.v_caches),
+               **extra_attn_options,
            },
        )
        non_padded_len = non_padded_len or n_tokens
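On the eager/reference side, _run_once now forwards last_valid_token_pos so a model that prunes its output can pick the right hidden state before the output projection (inputs are padded to a fixed length, so the last valid token is generally not the last position). A hypothetical consumer of that option, assuming hidden states of shape [batch, seq, dim]; the real attention module's interface may differ:

```python
import torch

def select_last_valid(hidden: torch.Tensor,
                      last_valid_token_pos: torch.Tensor) -> torch.Tensor:
    # hidden: [batch, seq_len, dim]; keep only the last valid position
    # (positions after it are padding) before projecting to the vocab.
    return hidden[:, last_valid_token_pos, :]

h = torch.randn(1, 8, 16)
pos = torch.tensor([4], dtype=torch.long)  # positions 5..7 are padding
assert select_last_valid(h, pos).shape == (1, 1, 16)
```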

examples/models/llama/tests/test_static_attention.py
Lines changed: 3 additions & 1 deletion

@@ -248,7 +248,9 @@ def test(style, attention_type):
            )
            ys.append(y_i)

-           self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+           self.assertTrue(
+               torch.isclose(ys[-1].flatten(), expected.flatten(), rtol=1e-3).all()
+           )

        for args in itertools.product(
            ["shift_pointer", "smart_mask"], ["static", "static_mha"]
