diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index 485aa99dced..0fe0bb0349b 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -4,7 +4,10 @@ use std::ops::Deref; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, LazyLock}; -use std::{cell::UnsafeCell, collections::BinaryHeap}; +use std::{ + cell::UnsafeCell, + collections::{BinaryHeap, VecDeque}, +}; use std::{cmp::Reverse, fmt::Debug}; use arrow::array::AsArray; @@ -35,7 +38,6 @@ use super::{ use super::{DocInfo, builder::BLOCK_SIZE}; const TERMINATED_DOC_ID: u64 = u64::MAX; - pub static FLAT_SEARCH_PERCENT_THRESHOLD: LazyLock = LazyLock::new(|| { std::env::var("LANCE_FLAT_SEARCH_PERCENT_THRESHOLD") .unwrap_or_else(|_| "10".to_string()) @@ -68,6 +70,7 @@ struct CompressedState { position_block_idx: Option, position_values: Vec, position_offsets: Vec, + block_max_window: BlockMaxWindow, } impl CompressedState { @@ -80,6 +83,7 @@ impl CompressedState { position_block_idx: None, position_values: Vec::new(), position_offsets: Vec::new(), + block_max_window: BlockMaxWindow::new(), } } @@ -114,6 +118,91 @@ impl CompressedState { } } +#[derive(Clone)] +struct BlockMaxWindow { + // Sliding block range used for Lucene-style getMaxScore(upTo). The deque is + // monotonic by score and covers blocks in [start_block_idx, next_block_idx). + start_block_idx: usize, + next_block_idx: usize, + max_scores: VecDeque<(usize, f32)>, +} + +struct BlockMaxScore { + score: f32, + blocks_scanned: usize, +} + +impl BlockMaxWindow { + fn new() -> Self { + Self { + start_block_idx: 0, + next_block_idx: 0, + max_scores: VecDeque::new(), + } + } + + fn reset(&mut self, start_block_idx: usize) { + self.start_block_idx = start_block_idx; + self.next_block_idx = start_block_idx; + self.max_scores.clear(); + } + + fn max_score_up_to( + &mut self, + list: &CompressedPostingList, + start_block_idx: usize, + up_to: u64, + ) -> BlockMaxScore { + if start_block_idx >= list.blocks.len() { + self.reset(start_block_idx); + return BlockMaxScore { + score: 0.0, + blocks_scanned: 0, + }; + } + if start_block_idx < self.start_block_idx || start_block_idx > self.next_block_idx { + self.reset(start_block_idx); + } + self.start_block_idx = start_block_idx; + while matches!(self.max_scores.front(), Some((block_idx, _)) if *block_idx < start_block_idx) + { + self.max_scores.pop_front(); + } + + if list.block_least_doc_id(start_block_idx) as u64 > up_to { + self.reset(start_block_idx); + return BlockMaxScore { + score: 0.0, + blocks_scanned: 0, + }; + } + + self.next_block_idx = self.next_block_idx.max(start_block_idx); + let mut blocks_scanned = 0; + while self.next_block_idx < list.blocks.len() + && list.block_least_doc_id(self.next_block_idx) as u64 <= up_to + { + let score = list.block_max_score(self.next_block_idx); + while matches!(self.max_scores.back(), Some((_, old_score)) if *old_score <= score) { + self.max_scores.pop_back(); + } + self.max_scores.push_back((self.next_block_idx, score)); + self.next_block_idx += 1; + blocks_scanned += 1; + } + + let score = self + .max_scores + .front() + .map(|(_, score)| *score) + .unwrap_or(0.0); + BlockMaxScore { + score, + blocks_scanned, + } + } +} + impl Debug for PostingIterator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PostingIterator") @@ -412,6 +501,27 @@ impl PostingIterator { } } + #[inline] + fn block_max_score_up_to_with_stats(&mut self, up_to: u64) -> BlockMaxScore { + match self.list { + PostingList::Compressed(ref list) => { + let compressed = unsafe { &mut *self.compressed_state_ptr() }; + compressed + .block_max_window + .max_score_up_to(list, self.block_idx, up_to) + } + PostingList::Plain(_) => BlockMaxScore { + score: self.approximate_upper_bound, + blocks_scanned: 0, + }, + } + } + + #[inline] + fn is_compressed(&self) -> bool { + matches!(self.list, PostingList::Compressed(_)) + } + fn block_first_doc(&self) -> Option { match self.list { PostingList::Compressed(ref list) => { @@ -540,6 +650,15 @@ impl PartialEq for TailPosting { } } +#[derive(Default)] +struct AndWindowStats { + windows_wide: usize, + windows_narrow: usize, + windows_skipped: usize, + range_blocks_scanned: usize, + candidates_returned: usize, +} + impl Eq for TailPosting {} impl PartialOrd for TailPosting { @@ -588,6 +707,7 @@ pub struct Wand<'a, S: Scorer> { // Last conjunction doc returned to the caller. The next conjunction search // resumes strictly after this doc, like Lucene's `nextDoc()/advance()`. and_last_doc: Option, + and_window_stats: AndWindowStats, and_candidates_pruned_before_return: usize, docs: &'a DocSet, scorer: S, @@ -651,6 +771,7 @@ impl<'a, S: Scorer> Wand<'a, S> { up_to: None, and_max_score: f32::INFINITY, and_last_doc: None, + and_window_stats: AndWindowStats::default(), and_candidates_pruned_before_return: 0, docs, scorer, @@ -824,6 +945,16 @@ impl<'a, S: Scorer> Wand<'a, S> { self.push_back_leads(doc.doc_id() + 1); } } + if self.operator == Operator::And { + tracing::debug!( + and_windows_wide = self.and_window_stats.windows_wide, + and_windows_narrow = self.and_window_stats.windows_narrow, + and_windows_skipped = self.and_window_stats.windows_skipped, + and_range_blocks_scanned = self.and_window_stats.range_blocks_scanned, + and_candidates_returned = self.and_window_stats.candidates_returned, + "fts conjunction block-max window stats" + ); + } metrics.record_comparisons(num_comparisons); let and_candidates_pruned_before_return = self .and_candidates_pruned_before_return @@ -1035,7 +1166,11 @@ impl<'a, S: Scorer> Wand<'a, S> { // from `tail` iterators that are advanced to the same doc later. fn next(&mut self) -> Result> { if self.operator == Operator::And { - return Ok(self.next_and_candidate().map(|doc| (doc, 0.0))); + let candidate = self.next_and_candidate(); + if candidate.is_some() { + self.and_window_stats.candidates_returned += 1; + } + return Ok(candidate.map(|doc| (doc, 0.0))); } while let Some(target) = self.head_doc() { @@ -1167,6 +1302,14 @@ impl<'a, S: Scorer> Wand<'a, S> { } } + fn posting_block_up_to(posting: &PostingIterator, target: u64) -> u64 { + posting + .next_block_first_doc() + .map(|doc| doc.saturating_sub(1)) + .unwrap_or(TERMINATED_DOC_ID) + .max(target) + } + fn and_move_to_next_block(&mut self, target: u64) { if self.threshold <= 0.0 { self.up_to = Some(target); @@ -1174,19 +1317,65 @@ impl<'a, S: Scorer> Wand<'a, S> { return; } - let mut up_to = TERMINATED_DOC_ID; - let mut max_score = 0.0; + if self.lead.is_empty() { + self.up_to = Some(TERMINATED_DOC_ID); + self.and_max_score = 0.0; + return; + } + for posting in &mut self.lead { posting.shallow_next(target); - let block_end = posting - .next_block_first_doc() - .map(|doc| doc.saturating_sub(1)) - .unwrap_or(TERMINATED_DOC_ID); - up_to = up_to.min(block_end.max(target)); - max_score += posting.block_max_score(); } - self.up_to = Some(up_to); - self.and_max_score = max_score; + + let narrow_up_to = self + .lead + .iter() + .map(|posting| Self::posting_block_up_to(posting, target)) + .min() + .unwrap_or(TERMINATED_DOC_ID); + let narrow_max_score = self + .lead + .iter() + .map(|posting| posting.block_max_score()) + .sum::(); + + if narrow_max_score >= self.threshold { + self.up_to = Some(narrow_up_to); + self.and_max_score = narrow_max_score; + self.and_window_stats.windows_narrow += 1; + return; + } + + let lead_up_to = self + .lead + .first() + .map(|posting| Self::posting_block_up_to(posting, target)) + .unwrap_or(TERMINATED_DOC_ID); + let can_try_wide = lead_up_to > narrow_up_to + && lead_up_to != TERMINATED_DOC_ID + && self.lead.iter().all(|posting| posting.is_compressed()); + + if can_try_wide { + let mut wide_max_score = 0.0; + let mut range_blocks_scanned = 0; + for posting in &mut self.lead { + let block_max = posting.block_max_score_up_to_with_stats(lead_up_to); + wide_max_score += block_max.score; + range_blocks_scanned += block_max.blocks_scanned; + } + self.and_window_stats.range_blocks_scanned += range_blocks_scanned; + + if wide_max_score < self.threshold { + self.up_to = Some(lead_up_to); + self.and_max_score = wide_max_score; + self.and_window_stats.windows_wide += 1; + return; + } + } + + self.up_to = Some(narrow_up_to); + self.and_max_score = narrow_max_score; + self.and_window_stats.windows_narrow += 1; } fn and_advance_target(&mut self, mut target: u64) -> u64 { @@ -1201,6 +1390,7 @@ impl<'a, S: Scorer> Wand<'a, S> { if self.and_max_score >= self.threshold { return target; } + self.and_window_stats.windows_skipped += 1; if up_to == TERMINATED_DOC_ID { return TERMINATED_DOC_ID; } @@ -1900,6 +2090,18 @@ mod tests { } } + fn sorted_candidate_row_ids(candidates: Vec) -> Vec { + let mut row_ids = candidates + .into_iter() + .map(|candidate| match candidate.addr { + CandidateAddr::RowId(row_id) => row_id, + CandidateAddr::Pending(doc_id) => doc_id as u64, + }) + .collect::>(); + row_ids.sort_unstable(); + row_ids + } + #[rstest] #[tokio::test] async fn test_wand(#[values(false, true)] is_compressed: bool) { @@ -2269,6 +2471,229 @@ mod tests { assert_eq!(candidate.0.doc_id(), BLOCK_SIZE as u64); } + #[test] + fn test_and_advance_falls_back_to_narrow_when_range_max_loosens_bound() { + let total = 4 * BLOCK_SIZE as u32; + let mut docs = DocSet::default(); + for i in 0..total { + docs.append(i as u64, 1); + } + + let lead_docs = (0..total).step_by(2).collect::>(); + let follower_docs = (0..total).collect::>(); + let postings = vec![ + PostingIterator::with_query_weight( + String::from("lead"), + 0, + 0, + 1.0, + generate_posting_list(lead_docs, 1.0, Some(vec![1.0, 1.0]), true), + docs.len(), + ), + PostingIterator::with_query_weight( + String::from("follower"), + 1, + 1, + 1.0, + generate_posting_list(follower_docs, 10.0, Some(vec![0.1, 10.0, 0.1, 0.1]), true), + docs.len(), + ), + ]; + + let mut wand = Wand::new(Operator::And, postings.into_iter(), &docs, UnitScorer); + wand.threshold = 5.0; + + let target = wand.and_advance_target(0); + + assert_eq!(target, BLOCK_SIZE as u64); + assert_eq!(wand.up_to, Some((2 * BLOCK_SIZE - 1) as u64)); + assert!( + (wand.and_max_score - 11.0).abs() < 1e-6, + "expected the second narrow window to include the high follower block, got {}", + wand.and_max_score + ); + assert_eq!(wand.and_window_stats.windows_wide, 0); + assert_eq!(wand.and_window_stats.windows_narrow, 2); + assert_eq!(wand.and_window_stats.windows_skipped, 1); + } + + #[test] + fn test_and_advance_uses_narrow_window_for_candidate_ranges() { + let total = 4 * BLOCK_SIZE as u32; + let mut docs = DocSet::default(); + for i in 0..total { + docs.append(i as u64, 1); + } + + let lead_docs = (0..total).step_by(2).collect::>(); + let follower_docs = (0..total).collect::>(); + let postings = vec![ + PostingIterator::with_query_weight( + String::from("lead"), + 0, + 0, + 1.0, + generate_posting_list(lead_docs, 1.0, Some(vec![1.0, 1.0]), true), + docs.len(), + ), + PostingIterator::with_query_weight( + String::from("follower"), + 1, + 1, + 1.0, + generate_posting_list(follower_docs, 1.0, Some(vec![1.0, 1.0, 1.0, 1.0]), true), + docs.len(), + ), + ]; + + let mut wand = Wand::new(Operator::And, postings.into_iter(), &docs, UnitScorer); + wand.threshold = 1.5; + + let target = wand.and_advance_target(0); + + assert_eq!(target, 0); + assert_eq!(wand.up_to, Some((BLOCK_SIZE - 1) as u64)); + assert!((wand.and_max_score - 2.0).abs() < 1e-6); + assert_eq!(wand.and_window_stats.windows_wide, 0); + assert_eq!(wand.and_window_stats.windows_narrow, 1); + assert_eq!(wand.and_window_stats.range_blocks_scanned, 0); + } + + #[test] + fn test_and_wide_window_only_skips_and_does_not_return_candidates() { + let total = 4 * BLOCK_SIZE as u32; + let mut docs = DocSet::default(); + for i in 0..total { + docs.append(i as u64, 1); + } + + let lead_docs = (0..total).step_by(2).collect::>(); + let follower_docs = (0..total).collect::>(); + let postings = vec![ + PostingIterator::with_query_weight( + String::from("lead"), + 0, + 0, + 1.0, + generate_posting_list(lead_docs, 3.0, Some(vec![1.0, 3.0]), true), + docs.len(), + ), + PostingIterator::with_query_weight( + String::from("follower"), + 1, + 1, + 1.0, + generate_posting_list(follower_docs, 3.0, Some(vec![0.1, 0.1, 3.0, 3.0]), true), + docs.len(), + ), + ]; + + let mut wand = Wand::new(Operator::And, postings.into_iter(), &docs, UnitScorer); + wand.threshold = 2.0; + + let candidate = wand.next().unwrap().unwrap(); + + assert_eq!(candidate.0.doc_id(), (2 * BLOCK_SIZE) as u64); + assert_eq!(wand.up_to, Some((3 * BLOCK_SIZE - 1) as u64)); + assert_eq!(wand.and_window_stats.windows_wide, 1); + assert_eq!(wand.and_window_stats.windows_skipped, 1); + assert_eq!(wand.and_window_stats.windows_narrow, 1); + assert_eq!(wand.and_window_stats.candidates_returned, 1); + } + + #[test] + fn test_and_range_max_preserves_exact_top_k() { + let total = 4 * BLOCK_SIZE as u32; + let hot = BLOCK_SIZE as u32..BLOCK_SIZE as u32 + 16; + let mut docs = DocSet::default(); + for doc_id in 0..total { + let doc_tokens = if hot.contains(&doc_id) { 1 } else { 1000 }; + docs.append(doc_id as u64, doc_tokens); + } + + let params = FtsSearchParams::new().with_limit(Some(8)); + let run = |is_compressed: bool| { + let lead_docs = (0..total).step_by(2).collect::>(); + let follower_docs = (0..total).collect::>(); + let lead_scores = is_compressed.then_some(vec![1.0, 0.001]); + let follower_scores = is_compressed.then_some(vec![0.001, 1.0, 0.001, 0.001]); + let postings = vec![ + PostingIterator::with_query_weight( + String::from("lead"), + 0, + 0, + 1.0, + generate_posting_list(lead_docs, 1.0, lead_scores, is_compressed), + docs.len(), + ), + PostingIterator::with_query_weight( + String::from("follower"), + 1, + 1, + 1.0, + generate_posting_list(follower_docs, 1.0, follower_scores, is_compressed), + docs.len(), + ), + ]; + let mut wand = Wand::new( + Operator::And, + postings.into_iter(), + &docs, + InverseDocLengthScorer, + ); + sorted_candidate_row_ids( + wand.search( + ¶ms, + Arc::new(RowAddrMask::default()), + &NoOpMetricsCollector, + ) + .unwrap(), + ) + }; + + let compressed = run(true); + let plain = run(false); + let expected = hot.step_by(2).map(u64::from).collect::>(); + assert_eq!(compressed, expected); + assert_eq!(compressed, plain); + } + + #[test] + fn test_block_max_score_up_to_slides_and_expires_old_max() { + let total = 5 * BLOCK_SIZE as u32; + let posting = generate_posting_list( + (0..total).collect(), + 5.0, + Some(vec![1.0, 4.0, 2.0, 5.0, 3.0]), + true, + ); + let mut posting = PostingIterator::new(String::from("term"), 0, 0, posting, total as usize); + + posting.shallow_next(0); + assert_eq!( + posting + .block_max_score_up_to_with_stats((3 * BLOCK_SIZE - 1) as u64) + .score, + 4.0 + ); + + posting.shallow_next((2 * BLOCK_SIZE) as u64); + assert_eq!( + posting + .block_max_score_up_to_with_stats((4 * BLOCK_SIZE - 1) as u64) + .score, + 5.0 + ); + + posting.shallow_next((4 * BLOCK_SIZE) as u64); + assert_eq!( + posting + .block_max_score_up_to_with_stats((5 * BLOCK_SIZE - 1) as u64) + .score, + 3.0 + ); + } + #[test] fn test_and_candidate_prune_scores_first_term_before_full_score() { let total_docs = 2 * BLOCK_SIZE as u32 + 1;