
Commit 166d439

sxu authored and facebook-github-bot committed
StaticAttention runtime support for generate_full_logits=False (pytorch#16171)
Summary: Prefill will produce the logits at position 0 when not generating full logits. Lookahead decoding requires full logits.

Reviewed By: billmguo

Differential Revision: D88790445
1 parent ee236cb commit 166d439
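
To make the new contract concrete: with generate_full_logits=True the model emits one row of logits per prefill token and prefill() returns the index of the last valid row (batch_len - 1); with generate_full_logits=False it emits a single row and prefill() now returns 0. Below is a minimal Python sketch of how a caller could pick the next token either way; the helper name and shapes are illustrative and not part of this commit.

import torch

def next_token_from_prefill(logits: torch.Tensor, last_logit_pos: int) -> int:
    # logits: [rows, vocab]; rows == batch_len with full logits, 1 otherwise.
    # last_logit_pos: the position index returned by prefill().
    return int(logits[last_logit_pos].argmax().item())

# Illustrative shapes only.
vocab = 32
full = torch.randn(8, vocab)    # generate_full_logits=True: one row per prefill token
single = full[-1:].clone()      # generate_full_logits=False: only the last valid row
assert next_token_from_prefill(full, 7) == next_token_from_prefill(single, 0)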

File tree

2 files changed, +28 / -4 lines


examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 17 additions & 3 deletions
@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
   RopeT* rope_freqs_cos;
   RopeT* rope_freqs_sin;
   StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
+  bool generate_full_logits = true;
+  size_t last_valid_token_pos_index = 0;
 };

 StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -607,9 +609,13 @@ class StaticAttentionIOManager {
       batch_len = std::min(input_len, tokens.size() - i);
       if (input_pos_ + batch_len > config_.max_context_len) {
         ET_LOG(Error, "Maximum context size reached, stopping prefill.");
-        return input_len - 1;
+        return config_.generate_full_logits ? input_len - 1 : 0;
       }
       std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      if (!config_.generate_full_logits) {
+        last_valid_token_pos_ = batch_len - 1;
+        set_input(method, config_.last_valid_token_pos_index, &last_valid_token_pos_);
+      }
       prepare(method);
       ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
       update(
@@ -622,10 +628,12 @@ class StaticAttentionIOManager {
         auto* logits = logits_tensor.const_data_ptr<LogitT>();
         logits_callback(executorch::runtime::Span(
             logits,
-            logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
+            logits +
+                (config_.generate_full_logits ? batch_len : 1) *
+                    logits_tensor.size(logits_tensor.dim() - 1)));
       }
     }
-    return batch_len - 1;
+    return config_.generate_full_logits ? batch_len - 1 : 0;
   }

   /**
@@ -648,6 +656,10 @@ class StaticAttentionIOManager {
       mask.set_causal_mask();
       set_input(method, config_.cache_len_to_mask_idx[pair.first], mask.get());
     }
+    if (!config_.generate_full_logits) {
+      last_valid_token_pos_ = 0;
+      set_input(method, config_.last_valid_token_pos_index, &last_valid_token_pos_);
+    }

     while (true) {
       input_buffer[0] = prev_tok;
@@ -685,6 +697,7 @@ class StaticAttentionIOManager {
       size_t window_size,
       size_t n_verifications,
       std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
+    ET_CHECK(config_.generate_full_logits);
     ET_LOG(
         Info,
         "Decoding with lookahead and verification at position %zu",
@@ -968,6 +981,7 @@ class StaticAttentionIOManager {
   std::unordered_map<size_t, PerCacheLenMasks> attentionMasks_;
   std::vector<RopeT> rope_freqs_cos_override_;
   std::vector<RopeT> rope_freqs_sin_override_;
+  int64_t last_valid_token_pos_;
 };

 } // namespace example
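
For context on the logits_callback change above: the span now contains batch_len * vocab values when full logits are generated but only 1 * vocab values otherwise, so a callback should derive the row count from the span size rather than from batch_len. Here is a rough Python sketch of a callback that handles both cases; the function name and the use of torch are illustrative assumptions, not part of the C++ runtime.

import torch

def on_prefill_logits(flat_logits: torch.Tensor, vocab_size: int) -> list[int]:
    # The runtime hands over rows * vocab_size values; rows equals batch_len
    # when generate_full_logits is true and 1 when only the last row is kept.
    rows = flat_logits.numel() // vocab_size
    return flat_logits.view(rows, vocab_size).argmax(dim=-1).tolist()

# Illustrative usage with random data.
vocab_size = 16
print(on_prefill_logits(torch.randn(4 * vocab_size), vocab_size))  # full logits, 4 tokens
print(on_prefill_logits(torch.randn(vocab_size), vocab_size))      # last row only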

examples/models/llama/static_attention.py

Lines changed: 11 additions & 1 deletion
@@ -456,7 +456,12 @@ def decode(
         new_tokens = [init_token]
         for _ in range(n):
             y = self._run_once(model, new_tokens[-1:])[0]
-            new_tokens.append(y[:, :1, ...].argmax().item())
+            print(y.shape, flush=True)
+            if self.generate_full_logits:
+                print(y[:, :1, ...].shape, flush=True)
+                new_tokens.append(y[:, :1, ...].argmax().item())
+            else:
+                new_tokens.append(y.argmax().item())
             if new_tokens[-1] in stop_tokens:
                 break

@@ -607,13 +612,18 @@ def _run_once(
             freqs_cos_override = self.freqs_cos[self.pos : self.pos + self.input_len]
         if freqs_sin_override is None:
             freqs_sin_override = self.freqs_sin[self.pos : self.pos + self.input_len]
+        if not self.generate_full_logits:
+            extra_attn_options = {"last_valid_token_pos": torch.tensor([n_tokens - 1], dtype=torch.long)}
+        else:
+            extra_attn_options = {}
         y, attn_updates = model(
             tokens,
             {
                 "masks": self.masks,
                 "freqs_cos_override": freqs_cos_override,
                 "freqs_sin_override": freqs_sin_override,
                 "in_cache_state": (self.k_caches, self.v_caches),
+                **extra_attn_options,
             },
         )
         non_padded_len = non_padded_len or n_tokens
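
A note on the new last_valid_token_pos option above: static attention runs with fixed-size input chunks, so a chunk can hold fewer real tokens (n_tokens) than its padded length, and this option tells the model which position's logits to keep. A small illustrative snippet follows; the padding value and chunk shape are assumptions, only the tensor construction mirrors _run_once.

import torch

# Assume a fixed chunk of length 4 holding n_tokens = 3 real tokens plus padding.
n_tokens, pad_id = 3, 0
tokens = torch.tensor([[11, 12, 13, pad_id]])
# Mirrors _run_once: point at the last real token, not the last padded slot.
extra_attn_options = {
    "last_valid_token_pos": torch.tensor([n_tokens - 1], dtype=torch.long)
}
assert extra_attn_options["last_valid_token_pos"].item() == 2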
