Skip to content

Commit 9c3ae40

Browse files
sxufacebook-github-bot
authored andcommitted
StaticAttetnion runtime support for generate_full_logits=False
Summary: Prefill will produce the logits at position 0 when not generating full logits. Lookahead decoding requires full logits. Differential Revision: D88790445
1 parent d39d64b commit 9c3ae40

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ class StaticAttentionIOManager {
438438
RopeT* rope_freqs_cos;
439439
RopeT* rope_freqs_sin;
440440
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
441+
bool generate_full_logits = true;
441442
};
442443

443444
StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -607,7 +608,7 @@ class StaticAttentionIOManager {
607608
batch_len = std::min(input_len, tokens.size() - i);
608609
if (input_pos_ + batch_len > config_.max_context_len) {
609610
ET_LOG(Error, "Maximum context size reached, stopping prefill.");
610-
return input_len - 1;
611+
return config_.generate_full_logits ? input_len - 1 : 0;
611612
}
612613
std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
613614
prepare(method);
@@ -622,10 +623,12 @@ class StaticAttentionIOManager {
622623
auto* logits = logits_tensor.const_data_ptr<LogitT>();
623624
logits_callback(executorch::runtime::Span(
624625
logits,
625-
logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
626+
logits +
627+
(config_.generate_full_logits ? batch_len : 1) *
628+
logits_tensor.size(logits_tensor.dim() - 1)));
626629
}
627630
}
628-
return batch_len - 1;
631+
return config_.generate_full_logits ? batch_len - 1 : 0;
629632
}
630633

631634
/**
@@ -685,6 +688,7 @@ class StaticAttentionIOManager {
685688
size_t window_size,
686689
size_t n_verifications,
687690
std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
691+
ET_CHECK(config_.generate_full_logits);
688692
ET_LOG(
689693
Info,
690694
"Decoding with lookahead and verification at position %zu",

0 commit comments

Comments
 (0)