Commit 927c8d5

sxu authored and facebook-github-bot committed
StaticAttention runtime support for generate_full_logits=False (#16171)
Summary: Prefill will produce the logits at position 0 when not generating full logits. Lookahead decoding requires full logits.

Reviewed By: billmguo

Differential Revision: D88790445
1 parent 71ebc50 commit 927c8d5
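
The contract this introduces for prefill callers: the return value is the offset of the last valid token's logits row, which stays batch_len - 1 when the model exports full logits and becomes 0 when it only exports the last position (the logits span then holds a single row of vocab_size). Below is a standalone toy sketch of that indexing; the sizes are made up and nothing in it calls the real StaticAttentionIOManager API.

// Toy illustration of the prefill logits indexing; sizes are made up and
// nothing here touches the real StaticAttentionIOManager API.
#include <cstddef>
#include <cstdio>
#include <initializer_list>
#include <vector>

int main() {
  const std::size_t vocab_size = 8; // toy vocabulary size
  const std::size_t batch_len = 4;  // tokens processed in the prefill batch

  for (bool full_logits : {true, false}) {
    // Full logits: one row per input token. Otherwise: a single row for the
    // last valid token only.
    const std::size_t rows = full_logits ? batch_len : 1;
    std::vector<float> logits(rows * vocab_size, 0.0f);

    // Mirrors the new return value of prefill(): batch_len - 1 with full
    // logits, 0 otherwise. Either way it indexes the last valid token's row.
    const std::size_t last_pos = full_logits ? batch_len - 1 : 0;
    const float* row = logits.data() + last_pos * vocab_size;

    std::printf("full_logits=%d rows=%zu last_pos=%zu offset=%zu\n",
                full_logits ? 1 : 0, rows, last_pos,
                static_cast<std::size_t>(row - logits.data()));
  }
  return 0;
}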

File tree

3 files changed: +33 -5 lines changed


examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 19 additions & 3 deletions
@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
     RopeT* rope_freqs_cos;
     RopeT* rope_freqs_sin;
     StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
+    bool generate_full_logits = true;
+    size_t last_valid_token_pos_index = 0;
   };

   StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -607,9 +609,14 @@ class StaticAttentionIOManager {
       batch_len = std::min(input_len, tokens.size() - i);
       if (input_pos_ + batch_len > config_.max_context_len) {
         ET_LOG(Error, "Maximum context size reached, stopping prefill.");
-        return input_len - 1;
+        return config_.generate_full_logits ? input_len - 1 : 0;
       }
       std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      if (!config_.generate_full_logits) {
+        last_valid_token_pos_ = batch_len - 1;
+        set_input(
+            method, config_.last_valid_token_pos_index, &last_valid_token_pos_);
+      }
       prepare(method);
       ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
       update(
@@ -622,10 +629,12 @@ class StaticAttentionIOManager {
         auto* logits = logits_tensor.const_data_ptr<LogitT>();
         logits_callback(executorch::runtime::Span(
             logits,
-            logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
+            logits +
+                (config_.generate_full_logits ? batch_len : 1) *
+                    logits_tensor.size(logits_tensor.dim() - 1)));
       }
     }
-    return batch_len - 1;
+    return config_.generate_full_logits ? batch_len - 1 : 0;
   }

   /**
@@ -648,6 +657,11 @@ class StaticAttentionIOManager {
       mask.set_causal_mask();
       set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
     }
+    if (!config_.generate_full_logits) {
+      last_valid_token_pos_ = 0;
+      set_input(
+          method, config_.last_valid_token_pos_index, &last_valid_token_pos_);
+    }

     while (true) {
       input_buffer[0] = prev_tok;
@@ -685,6 +699,7 @@ class StaticAttentionIOManager {
       size_t window_size,
       size_t n_verifications,
       std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
+    ET_CHECK(config_.generate_full_logits);
     ET_LOG(
         Info,
         "Decoding with lookahead and verification at position %zu",
@@ -968,6 +983,7 @@ class StaticAttentionIOManager {
   std::unordered_map<size_t, PerCacheLenMasks> attentionMasks_;
   std::vector<RopeT> rope_freqs_cos_override_;
   std::vector<RopeT> rope_freqs_sin_override_;
+  int64_t last_valid_token_pos_;
 };

 } // namespace example
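
Restating the header's contract for the new scalar input: prefill writes batch_len - 1 (the last valid position of the padded batch) and decode writes 0 (single-token batches) into the method input slot named by last_valid_token_pos_index, presumably so the exported model can pick that position's hidden state before the output head. A toy sketch of that bookkeeping follows; all names and the index value are made up rather than taken from the real API.

// Toy mirror of the position bookkeeping added above; no ExecuTorch types,
// and the input index 5 is purely illustrative.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct ToyConfig {
  bool generate_full_logits = true;
  std::size_t last_valid_token_pos_index = 0; // method input slot for the position
};

// Stand-in for set_input(method, index, &last_valid_token_pos_).
void set_scalar_input(std::size_t input_index, std::int64_t value) {
  std::printf("method input[%zu] <- %lld\n", input_index,
              static_cast<long long>(value));
}

int main() {
  ToyConfig config;
  config.generate_full_logits = false;
  config.last_valid_token_pos_index = 5; // illustrative slot

  // Prefill: the last valid token sits at batch_len - 1 of the padded batch.
  const std::size_t batch_len = 4;
  if (!config.generate_full_logits) {
    set_scalar_input(config.last_valid_token_pos_index, batch_len - 1);
  }

  // Decode: one token per step, so the last valid position is always 0.
  if (!config.generate_full_logits) {
    set_scalar_input(config.last_valid_token_pos_index, 0);
  }
  return 0;
}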

examples/models/llama/static_attention.py

Lines changed: 11 additions & 1 deletion
@@ -456,7 +456,10 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, ...].argmax().item())
+            if self.generate_full_logits:
+                new_tokens.append(y[:, :1, ...].argmax().item())
+            else:
+                new_tokens.append(y.argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break

@@ -607,13 +610,20 @@ def _run_once(
             freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
             freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        if not self.generate_full_logits:
+            extra_attn_options = {
+                "last_valid_token_pos": torch.tensor([n_tokens - 1], dtype=torch.long)
+            }
+        else:
+            extra_attn_options = {}
         y, attn_updates = model(
             tokens,
             {
                 "masks": self.masks,
                 "freqs_cos_override": freqs_cos_override,
                 "freqs_sin_override": freqs_sin_override,
                 "in_cache_state": (self.k_caches, self.v_caches),
+                **extra_attn_options,
             },
         )
         non_padded_len = non_padded_len or n_tokens

examples/models/llama/tests/test_static_attention.py

Lines changed: 3 additions & 1 deletion
@@ -248,7 +248,9 @@ def test(style, attention_type):
                 )
                 ys.append(y_i)

-            self.assertTrue(torch.isclose(ys[-1], expected, rtol=1e-3).all())
+            self.assertTrue(
+                torch.isclose(ys[-1].flatten(), expected.flatten(), rtol=1e-3).all()
+            )

         for args in itertools.product(
             ["shift_pointer", "smart_mask"], ["static", "static_mha"]