@@ -438,6 +438,8 @@ class StaticAttentionIOManager {
438438 RopeT* rope_freqs_cos;
439439 RopeT* rope_freqs_sin;
440440 StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
441+ bool generate_full_logits = true ;
442+ std::optional<size_t > last_valid_token_pos_index = 0 ;
441443 };
442444
443445 StaticAttentionIOManager (StaticAttentionIOConfig config)
@@ -607,9 +609,16 @@ class StaticAttentionIOManager {
607609 batch_len = std::min (input_len, tokens.size () - i);
608610 if (input_pos_ + batch_len > config_.max_context_len ) {
609611 ET_LOG (Error, " Maximum context size reached, stopping prefill." );
610- return input_len - 1 ;
612+ return config_. generate_full_logits ? input_len - 1 : 0 ;
611613 }
612614 std::copy (&tokens[i], &tokens[i + batch_len], input_buffer.begin ());
615+ if (!config_.generate_full_logits && config_.last_valid_token_pos_index ) {
616+ last_valid_token_pos_ = batch_len - 1 ;
617+ set_input (
618+ method,
619+ *config_.last_valid_token_pos_index ,
620+ &last_valid_token_pos_);
621+ }
613622 prepare (method);
614623 ET_CHECK (method.execute () == executorch::runtime::Error::Ok);
615624 update (
@@ -622,10 +631,12 @@ class StaticAttentionIOManager {
622631 auto * logits = logits_tensor.const_data_ptr <LogitT>();
623632 logits_callback (executorch::runtime::Span (
624633 logits,
625- logits + batch_len * logits_tensor.size (logits_tensor.dim () - 1 )));
634+ logits +
635+ (config_.generate_full_logits ? batch_len : 1 ) *
636+ logits_tensor.size (logits_tensor.dim () - 1 )));
626637 }
627638 }
628- return batch_len - 1 ;
639+ return config_. generate_full_logits ? batch_len - 1 : 0 ;
629640 }
630641
631642 /* *
@@ -648,6 +659,11 @@ class StaticAttentionIOManager {
648659 mask.set_causal_mask ();
649660 set_input (method, config_.cache_len_to_mask_idx [pair.first ], mask.get ());
650661 }
662+ if (!config_.generate_full_logits && config_.last_valid_token_pos_index ) {
663+ last_valid_token_pos_ = 0 ;
664+ set_input (
665+ method, *config_.last_valid_token_pos_index , &last_valid_token_pos_);
666+ }
651667
652668 while (true ) {
653669 input_buffer[0 ] = prev_tok;
@@ -685,6 +701,7 @@ class StaticAttentionIOManager {
685701 size_t window_size,
686702 size_t n_verifications,
687703 std::unordered_map<TokenT, SuffixCache<TokenT>> suffix_caches) {
704+ ET_CHECK (config_.generate_full_logits );
688705 ET_LOG (
689706 Info,
690707 " Decoding with lookahead and verification at position %zu" ,
@@ -968,6 +985,7 @@ class StaticAttentionIOManager {
968985 std::unordered_map<size_t , PerCacheLenMasks> attentionMasks_;
969986 std::vector<RopeT> rope_freqs_cos_override_;
970987 std::vector<RopeT> rope_freqs_sin_override_;
988+ int64_t last_valid_token_pos_;
971989};
972990
973991} // namespace example
0 commit comments