diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index ac13fc0c585..c23dc1c4e78 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -49,7 +49,7 @@ use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result}; use lance_select::{RowAddrMask, RowAddrTreeMap}; use roaring::RoaringBitmap; use std::sync::LazyLock; -use tokio::task::spawn_blocking; +use tokio::{sync::OnceCell, task::spawn_blocking}; use tracing::{info, instrument}; use super::encoding::{PositionBlockBuilder, decode_group_starts}; @@ -419,6 +419,7 @@ pub struct InvertedIndex { tokenizer: Box, token_set_format: TokenSetFormat, pub(crate) partitions: Vec>, + corpus_stats: Arc>, // Fragments which are contained in the index, but no longer in the dataset. // These should be pruned at search time since we don't prune them at update time. deleted_fragments: RoaringBitmap, @@ -666,21 +667,26 @@ impl InvertedIndex { /// `LazyDocSet`. Avoids materializing the full DocSet just to get /// these two scalars. async fn aggregate_corpus_stats(&self) -> Result<(u64, usize)> { - let io_parallelism = self.store.io_parallelism(); - let num_docs: usize = self.partitions.iter().map(|p| p.docs.len()).sum(); - let futures = self - .partitions - .iter() - .map(|p| { - let docs = p.docs.clone(); - async move { docs.total_tokens_num().await } + self.corpus_stats + .get_or_try_init(|| async { + let io_parallelism = self.store.io_parallelism(); + let num_docs: usize = self.partitions.iter().map(|p| p.docs.len()).sum(); + let futures = self + .partitions + .iter() + .map(|p| { + let docs = p.docs.clone(); + async move { docs.total_tokens_num().await } + }) + .collect::>(); + let totals: Vec = stream::iter(futures) + .buffer_unordered(io_parallelism) + .try_collect() + .await?; + Ok((totals.into_iter().sum(), num_docs)) }) - .collect::>(); - let totals: Vec = stream::iter(futures) - .buffer_unordered(io_parallelism) - .try_collect() - .await?; - Ok((totals.into_iter().sum(), num_docs)) + .await + .copied() } /// Sum the posting-list length for `term` across this index's partitions @@ -997,6 +1003,7 @@ impl InvertedIndex { docs: Arc::new(LazyDocSet::from_loaded(docs)), token_set_format: TokenSetFormat::Arrow, })], + corpus_stats: Arc::new(OnceCell::new()), deleted_fragments: RoaringBitmap::new(), })) } @@ -1084,6 +1091,7 @@ impl InvertedIndex { tokenizer, token_set_format, partitions, + corpus_stats: Arc::new(OnceCell::new()), deleted_fragments, })) } @@ -2193,7 +2201,7 @@ enum PostingMetadata { /// `ensure_metadata_loaded`, and the stats path can also fetch a single /// token via `posting_len_for_token` without forcing the bulk load. V2 { - metadata: tokio::sync::OnceCell, + metadata: OnceCell, }, } @@ -2272,7 +2280,7 @@ impl PostingListReader { } } else { PostingMetadata::V2 { - metadata: tokio::sync::OnceCell::new(), + metadata: OnceCell::new(), } }; @@ -2373,10 +2381,10 @@ impl PostingListReader { } /// Async access to a single token's posting list length. For v2 - /// indexes this reads a single row from `LENGTH_COL` if the bulk metadata - /// has not been loaded yet, and never triggers the bulk load itself. The - /// stats path uses this so a single-term `df` lookup costs O(1) bytes - /// rather than O(num_unique_tokens). + /// indexes this reads one row of posting metadata if the bulk metadata has + /// not been loaded yet, and never triggers the bulk load itself. The stats + /// path uses this so a single-term `df` lookup costs O(1) bytes rather + /// than O(num_unique_tokens). pub(crate) async fn posting_len_for_token(&self, token_id: u32) -> Result { match &self.metadata { PostingMetadata::LegacyV1 { .. } => Ok(self.posting_len(token_id)), @@ -2384,13 +2392,10 @@ impl PostingListReader { if let Some(metadata) = metadata.get() { return Ok(metadata.lengths[token_id as usize] as usize); } - let token_id = token_id as usize; - let batch = self - .reader - .read_range(token_id..token_id + 1, Some(&[LENGTH_COL])) - .await?; - let len = batch[LENGTH_COL].as_primitive::().value(0); - Ok(len as usize) + let (_, length) = self.posting_metadata_for_token(token_id).await?; + length + .map(|len| len as usize) + .ok_or_else(|| Error::index("posting length metadata missing".to_string())) } } } @@ -2416,17 +2421,20 @@ impl PostingListReader { Some(loaded.lengths[token_id as usize]), )); } - let token_id_usize = token_id as usize; - let batch = self - .reader - .read_range( - token_id_usize..token_id_usize + 1, - Some(&[MAX_SCORE_COL, LENGTH_COL]), - ) + let metadata = self + .index_cache + .get_or_insert_with_key(PostingMetadataKey { token_id }, || async move { + let token_id = token_id as usize; + let batch = self + .reader + .read_range(token_id..token_id + 1, Some(&[MAX_SCORE_COL, LENGTH_COL])) + .await?; + let max_score = batch[MAX_SCORE_COL].as_primitive::().value(0); + let length = batch[LENGTH_COL].as_primitive::().value(0); + Ok(PostingMetadataValue { max_score, length }) + }) .await?; - let max_score = batch[MAX_SCORE_COL].as_primitive::().value(0); - let length = batch[LENGTH_COL].as_primitive::().value(0); - Ok((Some(max_score), Some(length))) + Ok((Some(metadata.max_score), Some(metadata.length))) } } } @@ -3241,6 +3249,29 @@ impl CacheKey for PostingListGroupKey { } } +#[derive(Debug, Clone, DeepSizeOf)] +struct PostingMetadataValue { + max_score: f32, + length: u32, +} + +#[derive(Debug, Clone)] +struct PostingMetadataKey { + token_id: u32, +} + +impl CacheKey for PostingMetadataKey { + type ValueType = PostingMetadataValue; + + fn key(&self) -> std::borrow::Cow<'_, str> { + format!("posting-metadata-{}", self.token_id).into() + } + + fn type_name() -> &'static str { + "PostingMetadata" + } +} + #[derive(Debug, Clone)] pub struct PositionKey { pub token_id: u32, @@ -6845,6 +6876,7 @@ mod tests { // when the test exercises a scoring path. async fn load_counted_v2_index( num_tokens: usize, + cache: LanceCache, ) -> (Arc, Arc, TempObjDir) { let tmpdir = TempObjDir::default(); let inner_store = Arc::new(LanceIndexStore::new( @@ -6889,7 +6921,7 @@ mod tests { posting_file: posting_file_path(0), counter: counter.clone(), }); - let index = InvertedIndex::load(counting_store, None, &LanceCache::no_cache()) + let index = InvertedIndex::load(counting_store, None, &cache) .await .unwrap(); (index, counter, tmpdir) @@ -6902,9 +6934,9 @@ mod tests { /// /// * `InvertedIndex::load` does not touch the posting file at all /// (`InvertedPartition::load` only needs the token file and docs file). - /// * `bm25_stats_for_terms(["t0"])` reads exactly one row from the - /// posting file (the single LENGTH_COL entry for token 0) regardless - /// of how many unique tokens the partition has. + /// * `bm25_stats_for_terms(["t0"])` reads exactly one metadata row from + /// the posting file for token 0 regardless of how many unique tokens the + /// partition has. /// /// Before this refactor, `PostingListReader::try_new` did /// `read_range(0..num_rows, [MAX_SCORE_COL, LENGTH_COL])`, so the @@ -6917,7 +6949,8 @@ mod tests { #[case::tokens_1000(1000)] #[tokio::test] async fn test_bm25_stats_for_terms_is_lazy(#[case] num_tokens: usize) { - let (index, counter, _tmpdir) = load_counted_v2_index(num_tokens).await; + let (index, counter, _tmpdir) = + load_counted_v2_index(num_tokens, LanceCache::no_cache()).await; assert!( !index.partitions[0].inverted_list.is_legacy_layout(), "this test only proves the lazy path for v2 indexes", @@ -6957,6 +6990,38 @@ mod tests { ); } + #[tokio::test] + async fn test_bm25_stats_for_terms_reuses_posting_metadata_cache() { + let cache = LanceCache::with_capacity(1024 * 1024); + let (index, counter, _tmpdir) = load_counted_v2_index(100, cache.clone()).await; + + let terms = ["t0".to_string()]; + let first = index.bm25_stats_for_terms(&terms).await.unwrap(); + assert_eq!(first, (100, 100, vec![1])); + assert_eq!(counter.metadata_rows_read(), 1); + + let second = index.bm25_stats_for_terms(&terms).await.unwrap(); + assert_eq!(second, first); + assert_eq!( + counter.metadata_rows_read(), + 1, + "repeated stats for the same token should reuse cached posting metadata", + ); + } + + #[tokio::test] + async fn test_aggregate_corpus_stats_reuses_cached_value() { + let (index, _counter, _tmpdir) = load_counted_v2_index(100, LanceCache::no_cache()).await; + assert!(index.corpus_stats.get().is_none()); + + let first = index.aggregate_corpus_stats().await.unwrap(); + assert_eq!(first, (100, 100)); + assert_eq!(index.corpus_stats.get().copied(), Some(first)); + + let second = index.aggregate_corpus_stats().await.unwrap(); + assert_eq!(second, first); + } + #[tokio::test] async fn test_grouped_posting_lists_read_one_group_per_neighborhood() { // Cold-start scoring must not bulk-read the full `0..num_tokens` @@ -6966,7 +7031,8 @@ mod tests { // total token count. let num_tokens = 500; let queried_tokens: [u32; 4] = [0, 1, 2, 3]; - let (index, counter, _tmpdir) = load_counted_v2_index(num_tokens).await; + let (index, counter, _tmpdir) = + load_counted_v2_index(num_tokens, LanceCache::no_cache()).await; let inverted_list = index.partitions[0].inverted_list.clone(); assert!( !inverted_list.is_legacy_layout(),