Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 110 additions & 44 deletions rust/lance-index/src/scalar/inverted/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use lance_core::{Error, ROW_ID, ROW_ID_FIELD, Result};
use lance_select::{RowAddrMask, RowAddrTreeMap};
use roaring::RoaringBitmap;
use std::sync::LazyLock;
use tokio::task::spawn_blocking;
use tokio::{sync::OnceCell, task::spawn_blocking};
use tracing::{info, instrument};

use super::encoding::{PositionBlockBuilder, decode_group_starts};
Expand Down Expand Up @@ -419,6 +419,7 @@ pub struct InvertedIndex {
tokenizer: Box<dyn LanceTokenizer>,
token_set_format: TokenSetFormat,
pub(crate) partitions: Vec<Arc<InvertedPartition>>,
corpus_stats: Arc<OnceCell<(u64, usize)>>,
// Fragments which are contained in the index, but no longer in the dataset.
// These should be pruned at search time since we don't prune them at update time.
deleted_fragments: RoaringBitmap,
Expand Down Expand Up @@ -666,21 +667,26 @@ impl InvertedIndex {
/// `LazyDocSet`. Avoids materializing the full DocSet just to get
/// these two scalars.
async fn aggregate_corpus_stats(&self) -> Result<(u64, usize)> {
let io_parallelism = self.store.io_parallelism();
let num_docs: usize = self.partitions.iter().map(|p| p.docs.len()).sum();
let futures = self
.partitions
.iter()
.map(|p| {
let docs = p.docs.clone();
async move { docs.total_tokens_num().await }
self.corpus_stats
.get_or_try_init(|| async {
let io_parallelism = self.store.io_parallelism();
let num_docs: usize = self.partitions.iter().map(|p| p.docs.len()).sum();
let futures = self
.partitions
.iter()
.map(|p| {
let docs = p.docs.clone();
async move { docs.total_tokens_num().await }
})
.collect::<Vec<_>>();
let totals: Vec<u64> = stream::iter(futures)
.buffer_unordered(io_parallelism)
.try_collect()
.await?;
Ok((totals.into_iter().sum(), num_docs))
})
.collect::<Vec<_>>();
let totals: Vec<u64> = stream::iter(futures)
.buffer_unordered(io_parallelism)
.try_collect()
.await?;
Ok((totals.into_iter().sum(), num_docs))
.await
.copied()
}

/// Sum the posting-list length for `term` across this index's partitions
Expand Down Expand Up @@ -997,6 +1003,7 @@ impl InvertedIndex {
docs: Arc::new(LazyDocSet::from_loaded(docs)),
token_set_format: TokenSetFormat::Arrow,
})],
corpus_stats: Arc::new(OnceCell::new()),
deleted_fragments: RoaringBitmap::new(),
}))
}
Expand Down Expand Up @@ -1084,6 +1091,7 @@ impl InvertedIndex {
tokenizer,
token_set_format,
partitions,
corpus_stats: Arc::new(OnceCell::new()),
deleted_fragments,
}))
}
Expand Down Expand Up @@ -2193,7 +2201,7 @@ enum PostingMetadata {
/// `ensure_metadata_loaded`, and the stats path can also fetch a single
/// token via `posting_len_for_token` without forcing the bulk load.
V2 {
metadata: tokio::sync::OnceCell<LoadedPostingMetadata>,
metadata: OnceCell<LoadedPostingMetadata>,
},
}

Expand Down Expand Up @@ -2272,7 +2280,7 @@ impl PostingListReader {
}
} else {
PostingMetadata::V2 {
metadata: tokio::sync::OnceCell::new(),
metadata: OnceCell::new(),
}
};

Expand Down Expand Up @@ -2373,24 +2381,21 @@ impl PostingListReader {
}

/// Async access to a single token's posting list length. For v2
/// indexes this reads a single row from `LENGTH_COL` if the bulk metadata
/// has not been loaded yet, and never triggers the bulk load itself. The
/// stats path uses this so a single-term `df` lookup costs O(1) bytes
/// rather than O(num_unique_tokens).
/// indexes this reads one row of posting metadata if the bulk metadata has
/// not been loaded yet, and never triggers the bulk load itself. The stats
/// path uses this so a single-term `df` lookup costs O(1) bytes rather
/// than O(num_unique_tokens).
pub(crate) async fn posting_len_for_token(&self, token_id: u32) -> Result<usize> {
match &self.metadata {
PostingMetadata::LegacyV1 { .. } => Ok(self.posting_len(token_id)),
PostingMetadata::V2 { metadata } => {
if let Some(metadata) = metadata.get() {
return Ok(metadata.lengths[token_id as usize] as usize);
}
let token_id = token_id as usize;
let batch = self
.reader
.read_range(token_id..token_id + 1, Some(&[LENGTH_COL]))
.await?;
let len = batch[LENGTH_COL].as_primitive::<UInt32Type>().value(0);
Ok(len as usize)
let (_, length) = self.posting_metadata_for_token(token_id).await?;
length
.map(|len| len as usize)
.ok_or_else(|| Error::index("posting length metadata missing".to_string()))
}
}
}
Expand All @@ -2416,17 +2421,20 @@ impl PostingListReader {
Some(loaded.lengths[token_id as usize]),
));
}
let token_id_usize = token_id as usize;
let batch = self
.reader
.read_range(
token_id_usize..token_id_usize + 1,
Some(&[MAX_SCORE_COL, LENGTH_COL]),
)
let metadata = self
.index_cache
.get_or_insert_with_key(PostingMetadataKey { token_id }, || async move {
let token_id = token_id as usize;
let batch = self
.reader
.read_range(token_id..token_id + 1, Some(&[MAX_SCORE_COL, LENGTH_COL]))
.await?;
let max_score = batch[MAX_SCORE_COL].as_primitive::<Float32Type>().value(0);
let length = batch[LENGTH_COL].as_primitive::<UInt32Type>().value(0);
Ok(PostingMetadataValue { max_score, length })
})
.await?;
let max_score = batch[MAX_SCORE_COL].as_primitive::<Float32Type>().value(0);
let length = batch[LENGTH_COL].as_primitive::<UInt32Type>().value(0);
Ok((Some(max_score), Some(length)))
Ok((Some(metadata.max_score), Some(metadata.length)))
}
}
}
Expand Down Expand Up @@ -3241,6 +3249,29 @@ impl CacheKey for PostingListGroupKey {
}
}

/// Per-token posting metadata stored in the index cache.
///
/// This is the `ValueType` of `PostingMetadataKey`; it carries the two scalars
/// read from a single posting-file row (`MAX_SCORE_COL` and `LENGTH_COL`).
#[derive(Debug, Clone, DeepSizeOf)]
struct PostingMetadataValue {
// Value read from `MAX_SCORE_COL` for the token's row.
max_score: f32,
// Value read from `LENGTH_COL` for the token's row (posting list length).
length: u32,
}

/// Cache key addressing the cached posting metadata of a single token.
///
/// Its `CacheKey` implementation renders the key as
/// `posting-metadata-<token_id>`.
#[derive(Debug, Clone)]
struct PostingMetadataKey {
// Id of the token whose posting metadata is being looked up.
token_id: u32,
}

/// Makes [`PostingMetadataKey`] usable as a cache key whose entries hold a
/// [`PostingMetadataValue`].
impl CacheKey for PostingMetadataKey {
    type ValueType = PostingMetadataValue;

    /// Builds the entry name for this token: `posting-metadata-<token_id>`.
    fn key(&self) -> std::borrow::Cow<'_, str> {
        // The name is formatted per call, so it is always an owned string.
        let rendered = format!("posting-metadata-{}", self.token_id);
        std::borrow::Cow::Owned(rendered)
    }

    /// Static label returned as the cache type name for these entries.
    fn type_name() -> &'static str {
        "PostingMetadata"
    }
}

#[derive(Debug, Clone)]
pub struct PositionKey {
pub token_id: u32,
Expand Down Expand Up @@ -6845,6 +6876,7 @@ mod tests {
// when the test exercises a scoring path.
async fn load_counted_v2_index(
num_tokens: usize,
cache: LanceCache,
) -> (Arc<InvertedIndex>, Arc<PostingMetadataCounter>, TempObjDir) {
let tmpdir = TempObjDir::default();
let inner_store = Arc::new(LanceIndexStore::new(
Expand Down Expand Up @@ -6889,7 +6921,7 @@ mod tests {
posting_file: posting_file_path(0),
counter: counter.clone(),
});
let index = InvertedIndex::load(counting_store, None, &LanceCache::no_cache())
let index = InvertedIndex::load(counting_store, None, &cache)
.await
.unwrap();
(index, counter, tmpdir)
Expand All @@ -6902,9 +6934,9 @@ mod tests {
///
/// * `InvertedIndex::load` does not touch the posting file at all
/// (`InvertedPartition::load` only needs the token file and docs file).
/// * `bm25_stats_for_terms(["t0"])` reads exactly one row from the
/// posting file (the single LENGTH_COL entry for token 0) regardless
/// of how many unique tokens the partition has.
/// * `bm25_stats_for_terms(["t0"])` reads exactly one metadata row from
/// the posting file for token 0 regardless of how many unique tokens the
/// partition has.
///
/// Before this refactor, `PostingListReader::try_new` did
/// `read_range(0..num_rows, [MAX_SCORE_COL, LENGTH_COL])`, so the
Expand All @@ -6917,7 +6949,8 @@ mod tests {
#[case::tokens_1000(1000)]
#[tokio::test]
async fn test_bm25_stats_for_terms_is_lazy(#[case] num_tokens: usize) {
let (index, counter, _tmpdir) = load_counted_v2_index(num_tokens).await;
let (index, counter, _tmpdir) =
load_counted_v2_index(num_tokens, LanceCache::no_cache()).await;
assert!(
!index.partitions[0].inverted_list.is_legacy_layout(),
"this test only proves the lazy path for v2 indexes",
Expand Down Expand Up @@ -6957,6 +6990,38 @@ mod tests {
);
}

#[tokio::test]
async fn test_bm25_stats_for_terms_reuses_posting_metadata_cache() {
    // Use a real (non-disabled) cache so per-token posting metadata can be
    // retained between the two stats calls below.
    let cache = LanceCache::with_capacity(1024 * 1024);
    let (index, counter, _tmpdir) = load_counted_v2_index(100, cache.clone()).await;

    let query_terms = ["t0".to_string()];

    // Cold call: exactly one metadata row is read from the posting file.
    let cold_stats = index.bm25_stats_for_terms(&query_terms).await.unwrap();
    assert_eq!(cold_stats, (100, 100, vec![1]));
    assert_eq!(counter.metadata_rows_read(), 1);

    // Warm call: same answer, and the read counter must not move.
    let warm_stats = index.bm25_stats_for_terms(&query_terms).await.unwrap();
    assert_eq!(warm_stats, cold_stats);
    let rows_read_after_warm_call = counter.metadata_rows_read();
    assert_eq!(
        rows_read_after_warm_call,
        1,
        "repeated stats for the same token should reuse cached posting metadata",
    );
}

#[tokio::test]
async fn test_aggregate_corpus_stats_reuses_cached_value() {
    let (index, _counter, _tmpdir) = load_counted_v2_index(100, LanceCache::no_cache()).await;

    // Nothing has asked for corpus stats yet, so the cell starts out empty.
    let stats_before = index.corpus_stats.get();
    assert!(stats_before.is_none());

    // The first call computes the stats and stores them in the cell.
    let computed = index.aggregate_corpus_stats().await.unwrap();
    assert_eq!(computed, (100, 100));
    let stored = index.corpus_stats.get().copied();
    assert_eq!(stored, Some(computed));

    // A second call must report the same stats.
    let repeated = index.aggregate_corpus_stats().await.unwrap();
    assert_eq!(repeated, computed);
}

#[tokio::test]
async fn test_grouped_posting_lists_read_one_group_per_neighborhood() {
// Cold-start scoring must not bulk-read the full `0..num_tokens`
Expand All @@ -6966,7 +7031,8 @@ mod tests {
// total token count.
let num_tokens = 500;
let queried_tokens: [u32; 4] = [0, 1, 2, 3];
let (index, counter, _tmpdir) = load_counted_v2_index(num_tokens).await;
let (index, counter, _tmpdir) =
load_counted_v2_index(num_tokens, LanceCache::no_cache()).await;
let inverted_list = index.partitions[0].inverted_list.clone();
assert!(
!inverted_list.is_legacy_layout(),
Expand Down
Loading