@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
438438 RopeT* rope_freqs_cos;
439439 RopeT* rope_freqs_sin;
440440 StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
441+ bool generate_full_logits = true ;
442+ size_t last_valid_token_pos_index = 0 ;
441443 };
442444
443445 StaticAttentionIOManager (StaticAttentionIOConfig config)
@@ -607,9 +609,14 @@ class StaticAttentionIOManager {
607609 batch_len = std::min (input_len, tokens.size () - i);
608610 if (input_pos_ + batch_len > config_.max_context_len ) {
609611 ET_LOG (Error, " Maximum context size reached, stopping prefill." );
610- return input_len - 1 ;
612+ return config_. generate_full_logits ? input_len - 1 : 0 ;
611613 }
612614 std::copy (&tokens[i], &tokens[i + batch_len], input_buffer.begin ());
615+ if (!config_.generate_full_logits ) {
616+ last_valid_token_pos_ = batch_len - 1 ;
617+ set_input (
618+ method, config_.last_valid_token_pos_index , &last_valid_token_pos_);
619+ }
613620 prepare (method);
614621 ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
615622 update (
@@ -622,10 +629,12 @@ class StaticAttentionIOManager {
622629 auto * logits = logits_tensor.const_data_ptr <LogitT>();
623630 logits_callback (executorch::runtime::Span (
624631 logits,
625- logits + batch_len * logits_tensor.size (logits_tensor.dim () - 1 )));
632+ logits +
633+ (config_.generate_full_logits ? batch_len : 1 ) *
634+ logits_tensor.size (logits_tensor.dim () - 1 )));
626635 }
627636 }
628- return batch_len - 1 ;
637+ return config_. generate_full_logits ? batch_len - 1 : 0 ;
629638 }
630639
631640 /* *
@@ -648,6 +657,11 @@ class StaticAttentionIOManager {
648657 mask.set_causal_mask ();
649658 set_input (method, config_.cache_len_to_mask_idx [pair.first ], mask.get ());
650659 }
660+ if (!config_.generate_full_logits ) {
661+ last_valid_token_pos_ = 0 ;
662+ set_input (
663+ method, config_.last_valid_token_pos_index , &last_valid_token_pos_);
664+ }
651665
652666 while (true ) {
653667 input_buffer[0 ] = prev_tok;
@@ -685,6 +699,7 @@ class StaticAttentionIOManager {
685699 size_t window_size,
686700 size_t n_verifications,
687701 std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
702+ ET_CHECK (config_.generate_full_logits );
688703 ET_LOG (
689704 Info,
690705 " Decoding with lookahead and verification at position %zu" ,
@@ -968,6 +983,7 @@ class StaticAttentionIOManager {
968983 std::unordered_map<size_t , PerCacheLenMasks> attentionMasks_;
969984 std::vector<RopeT> rope_freqs_cos_override_;
970985 std::vector<RopeT> rope_freqs_sin_override_;
986+ int64_t last_valid_token_pos_;
971987};
972988
973989} // namespace example
0 commit comments