diff --git a/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java b/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java index ca0a7a46c70..31333cfd76b 100755 --- a/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java +++ b/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java @@ -53,6 +53,7 @@ public static final class Builder { private Integer minNgramLength; private Integer maxNgramLength; private Boolean prefixOnly; + private Boolean disableCrossArrayUnnest; private Boolean skipMerge; /** @@ -223,6 +224,22 @@ public Builder prefixOnly(boolean prefixOnly) { return this; } + /** + * Configure whether flattened JSON tokenization avoids cross-array unnesting. + * + *

When true, sibling arrays are indexed independently instead of producing their Cartesian + * product. This can reduce index build memory for JSON records with multiple arrays but can + * sacrifice result accuracy for queries that constrain values across those arrays. The default + * is false. + * + * @param disableCrossArrayUnnest whether to avoid cross-array unnesting + * @return this builder + */ + public Builder disableCrossArrayUnnest(boolean disableCrossArrayUnnest) { + this.disableCrossArrayUnnest = disableCrossArrayUnnest; + return this; + } + /** * Configure whether to skip the partition merge stage after indexing. If true, skip the * partition merge stage after indexing. This can be useful for distributed indexing where merge @@ -280,6 +297,9 @@ public ScalarIndexParams build() { if (prefixOnly != null) { params.put("prefix_only", prefixOnly); } + if (disableCrossArrayUnnest != null) { + params.put("disable_cross_array_unnest", disableCrossArrayUnnest); + } if (skipMerge != null) { params.put("skip_merge", skipMerge); } diff --git a/protos/index_old.proto b/protos/index_old.proto index 601aa2681da..97b24f9b463 100644 --- a/protos/index_old.proto +++ b/protos/index_old.proto @@ -39,4 +39,10 @@ message InvertedIndexDetails { uint32 min_ngram_length = 9; uint32 max_ngram_length = 10; bool prefix_only = 11; + // JSON document tokenization mode. Absent means SingleDocument JSON tokenization, + // which is how indexes written before flattened JSON sub-docs are interpreted. + optional string json_tokenizer_mode = 12; + // If true, avoid cross-array unnesting during flattened JSON tokenization. + // The default false value preserves exact Cartesian-product semantics. + bool disable_cross_array_unnest = 13; } diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index e96d9305ce5..9e1cb31efb4 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3255,6 +3255,12 @@ def create_scalar_index( ``[1, num_compute_cpus]``. If unset, Lance uses ``num_compute_cpus`` workers unless ``LANCE_FTS_NUM_SHARDS`` is set. This parameter is only used for the current build and is not persisted with the index. + disable_cross_array_unnest: bool, default False + This is for the ``INVERTED`` index on JSON columns. If True, flattened + JSON tokenization indexes sibling arrays independently instead of + producing their Cartesian product. This reduces index build memory for + records with multiple arrays but can sacrifice result accuracy for + queries that constrain values across those arrays. base_tokenizer: str, default "simple" This is for the ``INVERTED`` index. The base tokenizer to use. The value can be: diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 7ddfbbc0dc8..f7de546b6db 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -4727,7 +4727,10 @@ def test_json_inverted_match_query(tmp_path): stem=True, lower_case=True, remove_stop_words=True, + disable_cross_array_unnest=True, ) + details = dataset.describe_indices()[0].details + assert details["disable_cross_array_unnest"] is True # Test match query with token exceeding max_token_length results = dataset.to_table( @@ -4743,7 +4746,7 @@ def test_json_inverted_match_query(tmp_path): # Test language match results = dataset.to_table( - full_text_query=MatchQuery("Language,str,english", "json_col") + full_text_query=MatchQuery("Language[*],str,english", "json_col") ) assert results.num_rows == 3 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 8bfa81aeae4..f8cd3094de8 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -2292,6 +2292,12 @@ impl Dataset { if let Some(prefix_only) = kwargs.get_item("prefix_only")? { params = params.ngram_prefix_only(prefix_only.extract()?); } + if let Some(disable_cross_array_unnest) = + kwargs.get_item("disable_cross_array_unnest")? + { + params = params + .disable_cross_array_unnest(disable_cross_array_unnest.extract()?); + } if let Some(memory_limit) = kwargs.get_item("memory_limit")? { params = params.memory_limit_mb(memory_limit.extract()?); } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index d0bb0e40d3a..46ba6b0056f 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -147,11 +147,11 @@ impl InvertedIndexPlugin { } }); - let details = pbold::InvertedIndexDetails::try_from(¶ms)?; let mut inverted_index = InvertedIndexBuilder::new_with_fragment_mask(params, fragment_mask) .with_progress(progress); let files = inverted_index.update(data, index_store, None).await?; + let details = pbold::InvertedIndexDetails::try_from(inverted_index.params())?; Ok(CreatedIndex { index_details: prost_types::Any::from_msg(&details).unwrap(), index_version: current_fts_format_version().index_version(), diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 24b1eb50203..d97fb017c25 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -3,7 +3,7 @@ use super::encoding::encode_group_starts; use super::{InvertedIndexParams, index::*}; -use crate::scalar::inverted::document_tokenizer::DocType; +use crate::scalar::inverted::document_tokenizer::{DocType, JsonTokenizerMode}; use crate::scalar::inverted::json::JsonTextStream; use crate::scalar::inverted::tokenizer::document_tokenizer::LanceTokenizer; #[cfg(test)] @@ -239,13 +239,19 @@ impl InvertedIndexBuilder { /// Constructed as `(fragment_id as u64) << 32`. /// When provided, ensures that generated IDs belong to the specified fragment. pub fn from_existing_index( - params: InvertedIndexParams, + mut params: InvertedIndexParams, store: Option>, partitions: Vec, token_set_format: TokenSetFormat, fragment_mask: Option, deleted_fragments: RoaringBitmap, ) -> Self { + if (store.is_some() || !partitions.is_empty()) + && params.lance_tokenizer.as_deref() == Some("json") + && params.json_tokenizer_mode.is_none() + { + params.json_tokenizer_mode = Some(JsonTokenizerMode::SingleDocument); + } Self { params, partitions, @@ -260,6 +266,17 @@ impl InvertedIndexBuilder { } } + fn configure_json_tokenizer_mode_for_new_data(&mut self, doc_type: DocType) { + if self.params.lance_tokenizer.is_none() { + self.params.lance_tokenizer = Some(doc_type.as_ref().to_string()); + } + if self.params.lance_tokenizer.as_deref() == Some("json") + && self.params.json_tokenizer_mode.is_none() + { + self.params.json_tokenizer_mode = Some(JsonTokenizerMode::FlattenedSubDocs); + } + } + pub fn with_posting_tail_codec(mut self, posting_tail_codec: PostingTailCodec) -> Self { self.format_version = InvertedListFormatVersion::from_posting_tail_codec(posting_tail_codec); @@ -283,6 +300,10 @@ impl InvertedIndexBuilder { self } + pub(crate) fn params(&self) -> &InvertedIndexParams { + &self.params + } + pub async fn update( &mut self, new_data: SendableRecordBatchStream, @@ -293,12 +314,9 @@ impl InvertedIndexBuilder { let doc_col = schema.field(0).name(); // infer lance_tokenizer based on document type - if self.params.lance_tokenizer.is_none() { - let schema = new_data.schema(); - let field = schema.column_with_name(doc_col).expect_ok()?.1; - let doc_type = DocType::try_from(field)?; - self.params.lance_tokenizer = Some(doc_type.as_ref().to_string()); - } + let field = schema.column_with_name(doc_col).expect_ok()?.1; + let doc_type = DocType::try_from(field)?; + self.configure_json_tokenizer_mode_for_new_data(doc_type); let new_data = document_input(new_data, doc_col)?; @@ -326,11 +344,9 @@ impl InvertedIndexBuilder { let schema = new_data.schema(); let doc_col = schema.field(0).name(); - if self.params.lance_tokenizer.is_none() { - let field = schema.column_with_name(doc_col).expect_ok()?.1; - let doc_type = DocType::try_from(field)?; - self.params.lance_tokenizer = Some(doc_type.as_ref().to_string()); - } + let field = schema.column_with_name(doc_col).expect_ok()?.1; + let doc_type = DocType::try_from(field)?; + self.configure_json_tokenizer_mode_for_new_data(doc_type); let mut files = self .merge_existing_segments(dest_store, old_segments, old_data_filter.as_ref()) @@ -1233,142 +1249,141 @@ impl IndexWorker { let with_position = self.has_position(); for (doc, row_id) in docs { - let builder_was_empty = self.builder.docs.is_empty(); - let old_temporary_memory_size = self.temporary_memory_size(); - let old_token_memory_size = self.builder.tokens.memory_size() as u64; - let doc_id = self.builder.docs.len() as u32; - let mut token_num: u32 = 0; - let mut posting_memory_delta = 0i64; - if with_position { - if self.token_ids.capacity() < self.last_token_count { - self.token_ids - .reserve(self.last_token_count - self.token_ids.capacity()); - } - self.token_ids.clear(); - let builder = &mut self.builder; - let token_ids = &mut self.token_ids; - let memory_size = &mut self.memory_size; - let posting_tail_codec = builder.posting_tail_codec; - - let mut token_stream = self.tokenizer.token_stream_for_doc(doc); - while token_stream.advance() { - let token = token_stream.token_mut(); - let token_text = std::mem::take(&mut token.text); - let token_id = builder.tokens.add(token_text); - if token_id as usize == builder.posting_lists.len() { - let old_posting_lists_overhead_size = (builder.posting_lists.capacity() - * std::mem::size_of::()) - as u64; - builder.posting_lists.push( - PostingListBuilder::new_with_posting_tail_codec( - true, - posting_tail_codec, - ), - ); - let new_posting_lists_overhead_size = (builder.posting_lists.capacity() - * std::mem::size_of::()) - as u64; - Self::adjust_tracked_value( - memory_size, - old_posting_lists_overhead_size, - new_posting_lists_overhead_size, - ); - } - let posting_list = &mut builder.posting_lists[token_id as usize]; - let old_posting_memory_size = posting_list.size(); - if posting_list.add_occurrence(doc_id, token.position as u32)? { - token_ids.push(token_id); - } - let new_posting_memory_size = posting_list.size(); - posting_memory_delta += - new_posting_memory_size as i64 - old_posting_memory_size as i64; - token_num += 1; - } - } else { - if self.token_ids.capacity() < self.last_token_count { - self.token_ids - .reserve(self.last_token_count - self.token_ids.capacity()); + self.total_doc_length += doc.len(); + let sub_docs = self.tokenizer.token_streams_for_doc(doc)?; + for tokens in sub_docs { + self.process_tokenized_doc(row_id, tokens, with_position) + .await?; + } + } + + Ok(()) + } + + async fn process_tokenized_doc( + &mut self, + row_id: u64, + mut tokens: Vec, + with_position: bool, + ) -> Result<()> { + let builder_was_empty = self.builder.docs.is_empty(); + let old_temporary_memory_size = self.temporary_memory_size(); + let old_token_memory_size = self.builder.tokens.memory_size() as u64; + let doc_id = self.builder.docs.len() as u32; + let mut token_num: u32 = 0; + let mut posting_memory_delta = 0i64; + if with_position { + if self.token_ids.capacity() < self.last_token_count { + self.token_ids + .reserve(self.last_token_count - self.token_ids.capacity()); + } + self.token_ids.clear(); + let builder = &mut self.builder; + let token_ids = &mut self.token_ids; + let memory_size = &mut self.memory_size; + let posting_tail_codec = builder.posting_tail_codec; + + for token in &mut tokens { + let token_text = std::mem::take(&mut token.text); + let token_id = builder.tokens.add(token_text); + if token_id as usize == builder.posting_lists.len() { + let old_posting_lists_overhead_size = (builder.posting_lists.capacity() + * std::mem::size_of::()) + as u64; + builder + .posting_lists + .push(PostingListBuilder::new_with_posting_tail_codec( + true, + posting_tail_codec, + )); + let new_posting_lists_overhead_size = (builder.posting_lists.capacity() + * std::mem::size_of::()) + as u64; + Self::adjust_tracked_value( + memory_size, + old_posting_lists_overhead_size, + new_posting_lists_overhead_size, + ); } - self.token_ids.clear(); - - let mut token_stream = self.tokenizer.token_stream_for_doc(doc); - while token_stream.advance() { - let token = token_stream.token_mut(); - let token_text = std::mem::take(&mut token.text); - let token_id = self.builder.tokens.add(token_text); - self.token_ids.push(token_id); - token_num += 1; + let posting_list = &mut builder.posting_lists[token_id as usize]; + let old_posting_memory_size = posting_list.size(); + if posting_list.add_occurrence(doc_id, token.position as u32)? { + token_ids.push(token_id); } + let new_posting_memory_size = posting_list.size(); + posting_memory_delta += + new_posting_memory_size as i64 - old_posting_memory_size as i64; + token_num += 1; } - self.adjust_tracked_memory_size( - old_token_memory_size, - self.builder.tokens.memory_size() as u64, - ); + } else { + if self.token_ids.capacity() < self.last_token_count { + self.token_ids + .reserve(self.last_token_count - self.token_ids.capacity()); + } + self.token_ids.clear(); - if !with_position { - let old_posting_lists_overhead_size = self.posting_lists_overhead_size(); - self.builder - .posting_lists - .resize_with(self.builder.tokens.len(), || { - PostingListBuilder::new_with_posting_tail_codec( - false, - self.builder.posting_tail_codec, - ) - }); - let new_posting_lists_overhead_size = self.posting_lists_overhead_size(); - Self::adjust_tracked_value( - &mut self.memory_size, - old_posting_lists_overhead_size, - new_posting_lists_overhead_size, - ); + for mut token in tokens { + let token_text = std::mem::take(&mut token.text); + let token_id = self.builder.tokens.add(token_text); + self.token_ids.push(token_id); + token_num += 1; } + } + self.adjust_tracked_memory_size( + old_token_memory_size, + self.builder.tokens.memory_size() as u64, + ); - let old_doc_memory_size = self.builder.docs.memory_size() as u64; - let appended_doc_id = self.builder.docs.append(row_id, token_num); - debug_assert_eq!(appended_doc_id, doc_id); - self.adjust_tracked_memory_size( - old_doc_memory_size, - self.builder.docs.memory_size() as u64, + if !with_position { + let old_posting_lists_overhead_size = self.posting_lists_overhead_size(); + self.builder + .posting_lists + .resize_with(self.builder.tokens.len(), || { + PostingListBuilder::new_with_posting_tail_codec( + false, + self.builder.posting_tail_codec, + ) + }); + let new_posting_lists_overhead_size = self.posting_lists_overhead_size(); + Self::adjust_tracked_value( + &mut self.memory_size, + old_posting_lists_overhead_size, + new_posting_lists_overhead_size, ); - self.total_doc_length += doc.len(); + } - if with_position { - for &token_id in &self.token_ids { - let (old_posting_memory_size, new_posting_memory_size) = { - let posting_list = &mut self.builder.posting_lists[token_id as usize]; - let old_posting_memory_size = posting_list.size(); - posting_list.finish_open_doc(doc_id)?; - let new_posting_memory_size = posting_list.size(); - (old_posting_memory_size, new_posting_memory_size) - }; - posting_memory_delta += - new_posting_memory_size as i64 - old_posting_memory_size as i64; - } - Self::apply_delta(&mut self.memory_size, posting_memory_delta); - } else if token_num > 0 { - self.token_ids.sort_unstable(); - let mut iter = self.token_ids.iter(); - let mut current = *iter.next().unwrap(); - let mut count = 1u32; - for &token_id in iter { - if token_id == current { - count += 1; - continue; - } + let old_doc_memory_size = self.builder.docs.memory_size() as u64; + let appended_doc_id = self.builder.docs.append(row_id, token_num); + debug_assert_eq!(appended_doc_id, doc_id); + self.adjust_tracked_memory_size( + old_doc_memory_size, + self.builder.docs.memory_size() as u64, + ); - let (old_posting_memory_size, new_posting_memory_size) = { - let posting_list = &mut self.builder.posting_lists[current as usize]; - let old_posting_memory_size = posting_list.size(); - posting_list.add(doc_id, PositionRecorder::Count(count)); - let new_posting_memory_size = posting_list.size(); - (old_posting_memory_size, new_posting_memory_size) - }; - posting_memory_delta += - new_posting_memory_size as i64 - old_posting_memory_size as i64; - - current = token_id; - count = 1; + if with_position { + for &token_id in &self.token_ids { + let (old_posting_memory_size, new_posting_memory_size) = { + let posting_list = &mut self.builder.posting_lists[token_id as usize]; + let old_posting_memory_size = posting_list.size(); + posting_list.finish_open_doc(doc_id)?; + let new_posting_memory_size = posting_list.size(); + (old_posting_memory_size, new_posting_memory_size) + }; + posting_memory_delta += + new_posting_memory_size as i64 - old_posting_memory_size as i64; + } + Self::apply_delta(&mut self.memory_size, posting_memory_delta); + } else if token_num > 0 { + self.token_ids.sort_unstable(); + let mut iter = self.token_ids.iter(); + let mut current = *iter.next().unwrap(); + let mut count = 1u32; + for &token_id in iter { + if token_id == current { + count += 1; + continue; } + let (old_posting_memory_size, new_posting_memory_size) = { let posting_list = &mut self.builder.posting_lists[current as usize]; let old_posting_memory_size = posting_list.size(); @@ -1378,29 +1393,36 @@ impl IndexWorker { }; posting_memory_delta += new_posting_memory_size as i64 - old_posting_memory_size as i64; - Self::apply_delta(&mut self.memory_size, posting_memory_delta); - } - self.last_token_count = self.token_ids.len(); - self.trim_temporary_buffers(); - self.adjust_tracked_memory_size( - old_temporary_memory_size, - self.temporary_memory_size(), - ); - if self.builder.docs.len() == 1 && self.memory_size > self.worker_memory_limit_bytes { - return Err(Error::invalid_input(format!( - "single document row_id={} exceeds worker memory limit: {} > {} bytes", - row_id, self.memory_size, self.worker_memory_limit_bytes - ))); + current = token_id; + count = 1; } + let (old_posting_memory_size, new_posting_memory_size) = { + let posting_list = &mut self.builder.posting_lists[current as usize]; + let old_posting_memory_size = posting_list.size(); + posting_list.add(doc_id, PositionRecorder::Count(count)); + let new_posting_memory_size = posting_list.size(); + (old_posting_memory_size, new_posting_memory_size) + }; + posting_memory_delta += new_posting_memory_size as i64 - old_posting_memory_size as i64; + Self::apply_delta(&mut self.memory_size, posting_memory_delta); + } + self.last_token_count = self.token_ids.len(); + self.trim_temporary_buffers(); + self.adjust_tracked_memory_size(old_temporary_memory_size, self.temporary_memory_size()); - if self.builder.docs.len() as u32 == u32::MAX - || (!builder_was_empty && self.memory_size >= self.worker_memory_limit_bytes) - { - self.flush().await?; - } + if self.builder.docs.len() == 1 && self.memory_size > self.worker_memory_limit_bytes { + return Err(Error::invalid_input(format!( + "single document row_id={} exceeds worker memory limit: {} > {} bytes", + row_id, self.memory_size, self.worker_memory_limit_bytes + ))); } + if self.builder.docs.len() as u32 == u32::MAX + || (!builder_was_empty && self.memory_size >= self.worker_memory_limit_bytes) + { + self.flush().await?; + } Ok(()) } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 56547c6510b..02581bb5418 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -54,6 +54,7 @@ use tracing::{info, instrument}; use super::encoding::{PositionBlockBuilder, decode_group_starts}; use super::iter::PostingListIterator; use super::lazy_docset::LazyDocSet; +use super::tokenizer::document_tokenizer::JsonTokenizerMode; use super::{InvertedIndexBuilder, InvertedIndexParams, wand::*}; use super::{ builder::{ @@ -693,10 +694,24 @@ impl InvertedIndex { &local_scorer }; - let limit = params.limit.unwrap_or(usize::MAX); - if limit == 0 { + let requested_limit = params.limit.unwrap_or(usize::MAX); + if requested_limit == 0 { return Ok((Vec::new(), Vec::new())); } + let should_deduplicate_rows = + self.params.json_tokenizer_mode == Some(JsonTokenizerMode::FlattenedSubDocs); + let candidate_limit = if should_deduplicate_rows { + usize::MAX + } else { + requested_limit + }; + let search_params = if should_deduplicate_rows { + let mut params = params.as_ref().clone(); + params.limit = None; + Arc::new(params) + } else { + params.clone() + }; let mask = prefilter.mask(); let mut candidates = BinaryHeap::new(); @@ -711,7 +726,7 @@ impl InvertedIndex { .map(|part| { let part = part.clone(); let tokens = tokens.clone(); - let params = params.clone(); + let params = search_params.clone(); let mask = mask.clone(); let metrics = metrics.clone(); let shared_threshold = shared_threshold.clone(); @@ -802,7 +817,7 @@ impl InvertedIndex { score += idf_by_position[term_index as usize] * scorer.doc_weight(freq, doc_length); } - if candidates.len() < limit { + if candidates.len() < candidate_limit { candidates.push(Reverse(ScoredDoc::new(row_id, score))); } else if candidates.peek().unwrap().0.score.0 < score { candidates.pop(); @@ -811,11 +826,16 @@ impl InvertedIndex { } } - Ok(candidates + let (row_ids, scores): (Vec<_>, Vec<_>) = candidates .into_sorted_vec() .into_iter() .map(|Reverse(doc)| (doc.row_id, doc.score.0)) - .unzip()) + .unzip(); + if should_deduplicate_rows { + Ok(deduplicate_scored_rows(row_ids, scores, requested_limit)) + } else { + Ok((row_ids, scores)) + } } async fn load_legacy_index( @@ -4967,6 +4987,7 @@ async fn tokenize_and_count( // thread is invisible to the caller's poll timer otherwise). let start = std::time::Instant::now(); let batch = batch?; + let mut row_ids_builder = UInt64Builder::with_capacity(batch.num_rows()); let mut all_token_counts = UInt64Builder::with_capacity(batch.num_rows()); let mut query_token_counts = FixedSizeListBuilder::with_capacity( UInt64Builder::with_capacity(batch.num_rows() * query_tokens.len()), @@ -4975,8 +4996,10 @@ async fn tokenize_and_count( ); let mut temp_query_token_counts = Vec::with_capacity(query_tokens.len()); let doc_iter = iter_str_array(batch.column(doc_col_idx)); - for doc in doc_iter { + let row_id_array = batch[ROW_ID].as_primitive::(); + for (doc, row_id) in doc_iter.zip(row_id_array.values().iter().copied()) { let Some(doc) = doc else { + row_ids_builder.append_value(row_id); all_token_counts.append_value(0); query_token_counts .values() @@ -4985,31 +5008,37 @@ async fn tokenize_and_count( continue; }; - temp_query_token_counts.clear(); - temp_query_token_counts.extend(std::iter::repeat_n(0, query_tokens.len())); - - let mut stream = tokenizer.token_stream_for_doc(doc); - let mut all_tokens = 0; - while let Some(token) = stream.next() { - all_tokens += 1; - if let Some(token_index) = query_tokens.token_index(&token.text) { - temp_query_token_counts[token_index] += 1; + let sub_docs = tokenizer.token_streams_for_doc(doc).map_err(|err| { + datafusion::error::DataFusionError::Execution(err.to_string()) + })?; + for tokens in sub_docs { + row_ids_builder.append_value(row_id); + temp_query_token_counts.clear(); + temp_query_token_counts + .extend(std::iter::repeat_n(0, query_tokens.len())); + + let mut all_tokens = 0; + for token in tokens { + all_tokens += 1; + if let Some(token_index) = query_tokens.token_index(&token.text) { + temp_query_token_counts[token_index] += 1; + } } + all_token_counts.append_value(all_tokens); + for count in temp_query_token_counts.iter().copied() { + query_token_counts.values().append_value(count); + } + query_token_counts.append(true); } - all_token_counts.append_value(all_tokens); - for count in temp_query_token_counts.iter().copied() { - query_token_counts.values().append_value(count); - } - query_token_counts.append(true); } - let row_ids = batch[ROW_ID].clone(); + let row_ids = row_ids_builder.finish(); let all_token_counts = all_token_counts.finish(); let query_token_counts = query_token_counts.finish(); let result_batch = RecordBatch::try_new( output_schema, vec![ - row_ids, + Arc::new(row_ids) as ArrayRef, Arc::new(all_token_counts) as ArrayRef, Arc::new(query_token_counts) as ArrayRef, ], @@ -5092,10 +5121,54 @@ fn initialize_scorer( MemBM25Scorer::new(total_tokens, num_docs, token_counts_map) } +fn deduplicate_scored_rows( + row_ids: Vec, + scores: Vec, + limit: usize, +) -> (Vec, Vec) { + let mut scores_by_row_id: HashMap = HashMap::with_capacity(row_ids.len()); + for (row_id, score) in row_ids.into_iter().zip(scores) { + scores_by_row_id + .entry(row_id) + .and_modify(|existing| { + if score > *existing { + *existing = score; + } + }) + .or_insert(score); + } + + let mut scored_rows = scores_by_row_id.into_iter().collect::>(); + scored_rows.sort_unstable_by(|(left_row_id, left_score), (right_row_id, right_score)| { + right_score + .total_cmp(left_score) + .then_with(|| left_row_id.cmp(right_row_id)) + }); + scored_rows.truncate(scored_rows.len().min(limit)); + scored_rows.into_iter().unzip() +} + +fn deduplicate_fts_batch(batch: RecordBatch, limit: usize) -> Result { + let row_ids = batch[ROW_ID].as_primitive::().values().to_vec(); + let scores = batch[SCORE_COL] + .as_primitive::() + .values() + .to_vec(); + let (row_ids, scores) = deduplicate_scored_rows(row_ids, scores, limit); + Ok(RecordBatch::try_new( + FTS_SCHEMA.clone(), + vec![ + Arc::new(UInt64Array::from(row_ids)) as ArrayRef, + Arc::new(Float32Array::from(scores)) as ArrayRef, + ], + )?) +} + fn flat_bm25_score( query_tokens: &Tokens, counted_input: &RecordBatch, scorer: &MemBM25Scorer, + require_all_query_tokens: bool, ) -> Result { let mut row_ids_builder = UInt64Builder::with_capacity(counted_input.num_rows()); let mut scores_builder = Float32Builder::with_capacity(counted_input.num_rows()); @@ -5131,12 +5204,16 @@ fn flat_bm25_score( } let doc_norm = K1 * (1.0 - B + B * num_tokens_in_doc as f32 / scorer.avg_doc_length()); let mut score = 0.0; + let mut has_all_query_tokens = true; for token in query_tokens { let freq = query_token_counts_iter.next().expect_ok()? as f32; + if freq == 0.0 { + has_all_query_tokens = false; + } let idf = idf(scorer.num_docs_containing_token(token), scorer.num_docs()); score += idf * (freq * (K1 + 1.0) / (freq + doc_norm)); } - if score > 0.0 { + if score > 0.0 && (!require_all_query_tokens || has_all_query_tokens) { row_ids_builder.append_value(row_id); scores_builder.append_value(score); } @@ -5194,6 +5271,13 @@ pub async fn flat_bm25_search_stream_with_metrics( // Pre-await synchronous work: query tokenization + chunk-stream setup. let pre_await_start = std::time::Instant::now(); let query_tokens = Arc::new(collect_query_tokens(&query, &mut tokenizer)); + let should_deduplicate_rows = + tokenizer.json_tokenizer_mode() == Some(JsonTokenizerMode::FlattenedSubDocs); + let require_all_query_tokens = should_deduplicate_rows + && query_tokens + .as_ref() + .into_iter() + .any(|token| token.contains("$idx,number,")); // A query that tokenizes to no terms (e.g. only stop words) has no // searchable content and matches nothing. Return early rather than @@ -5243,7 +5327,15 @@ pub async fn flat_bm25_search_stream_with_metrics( // All post-await work is synchronous; time the scorer + score + slicing loop together. let post_await_start = std::time::Instant::now(); let scorer = initialize_scorer(base_scorer.as_ref(), query_tokens.as_ref(), &counted_input); - let scores = flat_bm25_score(query_tokens.as_ref(), &counted_input, &scorer)?; + let mut scores = flat_bm25_score( + query_tokens.as_ref(), + &counted_input, + &scorer, + require_all_query_tokens, + )?; + if should_deduplicate_rows { + scores = deduplicate_fts_batch(scores, usize::MAX)?; + } // Finally we emit batches according to the target batch size let num_out_batches = scores.num_rows().div_ceil(target_batch_size); diff --git a/rust/lance-index/src/scalar/inverted/query.rs b/rust/lance-index/src/scalar/inverted/query.rs index 7388bc0401f..c8dc40b02bb 100644 --- a/rust/lance-index/src/scalar/inverted/query.rs +++ b/rust/lance-index/src/scalar/inverted/query.rs @@ -811,13 +811,19 @@ pub fn has_query_token( tokenizer: &mut Box, query_tokens: &Tokens, ) -> bool { - let mut stream = tokenizer.token_stream_for_doc(text); - while let Some(token) = stream.next() { - if query_tokens.contains(&token.text) { - return true; + match tokenizer.token_streams_for_doc(text) { + Ok(sub_docs) => { + for tokens in sub_docs { + for token in tokens { + if query_tokens.contains(&token.text) { + return true; + } + } + } + false } + Err(_) => false, } - false } pub fn fill_fts_query_column( diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs index 6024747025b..1d93c0fa47e 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer.rs @@ -3,7 +3,7 @@ use lance_core::{Error, Result}; use serde::{Deserialize, Serialize}; -use std::{env, path::PathBuf}; +use std::{env, path::PathBuf, str::FromStr}; #[cfg(feature = "tokenizer-jieba")] mod jieba; @@ -20,7 +20,7 @@ use lindera::LinderaTokenizerBuilder; use crate::pbold; use crate::scalar::inverted::tokenizer::document_tokenizer::{ - JsonTokenizer, LanceTokenizer, TextTokenizer, + JsonTokenizer, JsonTokenizerMode, LanceTokenizer, TextTokenizer, }; pub use lance_tokenizer::Language; use lance_tokenizer::{ @@ -97,6 +97,19 @@ pub struct InvertedIndexParams { #[serde(default)] pub(crate) prefix_only: bool, + /// JSON tokenization mode. `None` means the caller did not provide a mode. + /// Existing JSON indexes without this field are interpreted as `SingleDocument`; + /// new JSON indexes default this to `FlattenedSubDocs` during index build. + #[serde(default)] + pub(crate) json_tokenizer_mode: Option, + + /// If true, flattened JSON tokenization avoids cross-array unnesting. + /// This reduces sub-doc explosion for JSON records with multiple sibling + /// arrays by indexing each array independently instead of producing their + /// Cartesian product. Default is false for exact query semantics. + #[serde(default)] + pub(crate) disable_cross_array_unnest: bool, + /// Total memory limit in MiB for the build stage. /// /// This is split evenly across FTS workers at build time. By default Lance @@ -137,6 +150,11 @@ impl TryFrom<&InvertedIndexParams> for pbold::InvertedIndexDetails { min_ngram_length: params.min_ngram_length, max_ngram_length: params.max_ngram_length, prefix_only: params.prefix_only, + json_tokenizer_mode: params + .json_tokenizer_mode + .filter(|mode| *mode == JsonTokenizerMode::FlattenedSubDocs) + .map(|mode| mode.as_ref().to_string()), + disable_cross_array_unnest: params.disable_cross_array_unnest, }) } } @@ -164,6 +182,12 @@ impl TryFrom<&pbold::InvertedIndexDetails> for InvertedIndexParams { min_ngram_length: details.min_ngram_length, max_ngram_length: details.max_ngram_length, prefix_only: details.prefix_only, + json_tokenizer_mode: details + .json_tokenizer_mode + .as_deref() + .map(JsonTokenizerMode::from_str) + .transpose()?, + disable_cross_array_unnest: details.disable_cross_array_unnest, memory_limit_mb: defaults.memory_limit_mb, num_workers: defaults.num_workers, }) @@ -218,6 +242,8 @@ impl InvertedIndexParams { min_ngram_length: default_min_ngram_length(), max_ngram_length: default_max_ngram_length(), prefix_only: false, + json_tokenizer_mode: None, + disable_cross_array_unnest: false, memory_limit_mb: None, num_workers: None, } @@ -228,6 +254,18 @@ impl InvertedIndexParams { self } + /// Set how JSON documents are tokenized by the Lance JSON tokenizer. + pub fn json_tokenizer_mode(mut self, mode: JsonTokenizerMode) -> Self { + self.json_tokenizer_mode = Some(mode); + self + } + + /// Set whether flattened JSON tokenization avoids cross-array unnesting. + pub fn disable_cross_array_unnest(mut self, disable_cross_array_unnest: bool) -> Self { + self.disable_cross_array_unnest = disable_cross_array_unnest; + self + } + pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self { self.base_tokenizer = base_tokenizer; self @@ -373,7 +411,12 @@ impl InvertedIndexParams { match self.lance_tokenizer { Some(ref t) if t == "text" => Ok(Box::new(TextTokenizer::new(tokenizer))), - Some(ref t) if t == "json" => Ok(Box::new(JsonTokenizer::new(tokenizer))), + Some(ref t) if t == "json" => Ok(Box::new(JsonTokenizer::new( + tokenizer, + self.json_tokenizer_mode + .unwrap_or(JsonTokenizerMode::SingleDocument), + self.disable_cross_array_unnest, + ))), None => Ok(Box::new(TextTokenizer::new(tokenizer))), _ => Err(Error::invalid_input(format!( "unknown lance tokenizer {}", @@ -440,6 +483,8 @@ pub fn language_model_home() -> Option { #[cfg(test)] mod tests { use super::InvertedIndexParams; + use crate::pbold; + use crate::scalar::inverted::tokenizer::document_tokenizer::JsonTokenizerMode; use lance_tokenizer::TokenStream; #[test] @@ -490,6 +535,34 @@ mod tests { assert_eq!(json.get("num_workers"), Some(&serde_json::Value::from(3))); } + #[test] + fn test_json_tokenizer_mode_details_round_trip() { + let params = InvertedIndexParams::default() + .lance_tokenizer("json".to_string()) + .json_tokenizer_mode(JsonTokenizerMode::FlattenedSubDocs) + .disable_cross_array_unnest(true); + let details = pbold::InvertedIndexDetails::try_from(¶ms).unwrap(); + assert_eq!( + details.json_tokenizer_mode.as_deref(), + Some("flattened_sub_docs") + ); + assert!(details.disable_cross_array_unnest); + + let decoded_params = InvertedIndexParams::try_from(&details).unwrap(); + assert_eq!( + decoded_params.json_tokenizer_mode, + Some(JsonTokenizerMode::FlattenedSubDocs) + ); + assert!(decoded_params.disable_cross_array_unnest); + + let single_document_params = InvertedIndexParams::default() + .lance_tokenizer("json".to_string()) + .json_tokenizer_mode(JsonTokenizerMode::SingleDocument); + let details = pbold::InvertedIndexDetails::try_from(&single_document_params).unwrap(); + assert_eq!(details.json_tokenizer_mode, None); + assert!(!details.disable_cross_array_unnest); + } + #[test] fn test_build_icu_tokenizer() { let mut tokenizer = InvertedIndexParams::default() diff --git a/rust/lance-index/src/scalar/inverted/tokenizer/document_tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer/document_tokenizer.rs index 62dd7b1aa3b..8f40ac02e66 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer/document_tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer/document_tokenizer.rs @@ -4,8 +4,11 @@ use arrow_schema::{DataType, Field}; use lance_arrow::ARROW_EXT_NAME_KEY; use lance_arrow::json::JSON_EXT_NAME; +use lance_core::Error; use lance_tokenizer::{BoxTokenStream, TextAnalyzer, Token, TokenStream}; +use serde::{Deserialize, Serialize}; use serde_json::Value; +use std::str::FromStr; /// Document type for full text search. #[derive(Debug, Clone)] @@ -14,6 +17,39 @@ pub enum DocType { Json, } +/// Controls how JSON documents are represented inside the inverted index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum JsonTokenizerMode { + /// Emit one token stream for each source JSON document. + SingleDocument, + /// Flatten arrays into multiple sub-doc token streams for each source JSON document. + FlattenedSubDocs, +} + +impl AsRef for JsonTokenizerMode { + fn as_ref(&self) -> &str { + match self { + Self::SingleDocument => "single_document", + Self::FlattenedSubDocs => "flattened_sub_docs", + } + } +} + +impl FromStr for JsonTokenizerMode { + type Err = Error; + + fn from_str(value: &str) -> std::result::Result { + match value { + "single_document" => Ok(Self::SingleDocument), + "flattened_sub_docs" => Ok(Self::FlattenedSubDocs), + _ => Err(Error::invalid_input(format!( + "unknown JSON tokenizer mode {value:?}; expected 'single_document' or 'flattened_sub_docs'" + ))), + } + } +} + impl AsRef for DocType { fn as_ref(&self) -> &str { match self { @@ -78,10 +114,27 @@ pub trait LanceTokenizer: Send + Sync + std::fmt::Debug { fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a>; /// Tokenize document text for index. fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a>; + /// Tokenize document text into one or more internal inverted-index documents. + fn token_streams_for_doc(&mut self, text: &str) -> lance_core::Result>> { + let mut stream = self.token_stream_for_doc(text); + let mut tokens = Vec::new(); + while let Some(token) = stream.next() { + tokens.push(token.clone()); + } + Ok(vec![tokens]) + } /// Clone the tokenizer. fn box_clone(&self) -> Box; /// Get document type. fn doc_type(&self) -> DocType; + /// Get the JSON tokenization mode, if this tokenizer handles JSON documents. + fn json_tokenizer_mode(&self) -> Option { + None + } + /// Whether flattened JSON tokenization avoids cross-array unnesting. + fn disable_cross_array_unnest(&self) -> bool { + false + } } impl Clone for Box { @@ -128,39 +181,75 @@ impl LanceTokenizer for TextTokenizer { #[derive(Clone)] pub struct JsonTokenizer { tokenizer: TextAnalyzer, + mode: JsonTokenizerMode, + disable_cross_array_unnest: bool, } impl JsonTokenizer { - pub fn new(tokenizer: TextAnalyzer) -> Self { - Self { tokenizer } + pub fn new( + tokenizer: TextAnalyzer, + mode: JsonTokenizerMode, + disable_cross_array_unnest: bool, + ) -> Self { + Self { + tokenizer, + mode, + disable_cross_array_unnest, + } } } impl std::fmt::Debug for JsonTokenizer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "JsonTokenizer") + f.debug_struct("JsonTokenizer") + .field("mode", &self.mode) + .field( + "disable_cross_array_unnest", + &self.disable_cross_array_unnest, + ) + .finish() } } impl LanceTokenizer for JsonTokenizer { fn token_stream_for_search<'a>(&'a mut self, query_text: &'a str) -> BoxTokenStream<'a> { - let tokens = flatten_triplet(query_text, &mut self.tokenizer).unwrap(); + let tokens = flatten_triplet(query_text, self.mode, &mut self.tokenizer).unwrap(); BoxTokenStream::new(TTStream { tokens, index: 0 }) } fn token_stream_for_doc<'a>(&'a mut self, text: &'a str) -> BoxTokenStream<'a> { - let value: Value = match serde_json::from_slice(text.as_bytes()) { - Ok(v) => v, - Err(e) => { - panic!("JSON parse error: {:?}", e); - } - }; - let mut tokens = vec![]; - let mut position = 0; - flatten_json(&value, "", &mut tokens, &mut position, &mut self.tokenizer); + let tokens = self + .token_streams_for_doc(text) + .unwrap() + .into_iter() + .next() + .unwrap_or_default(); BoxTokenStream::new(TTStream { tokens, index: 0 }) } + fn token_streams_for_doc(&mut self, text: &str) -> lance_core::Result>> { + let value: Value = serde_json::from_slice(text.as_bytes()).map_err(|err| { + Error::invalid_input(format!( + "failed to parse JSON document for FTS indexing: {err}" + )) + })?; + + match self.mode { + JsonTokenizerMode::SingleDocument => { + let mut tokens = Vec::new(); + let mut position = 0; + flatten_json(&value, "", &mut tokens, &mut position, &mut self.tokenizer); + Ok(vec![tokens]) + } + JsonTokenizerMode::FlattenedSubDocs => Ok(flatten_json_sub_docs( + &value, + "", + &mut self.tokenizer, + self.disable_cross_array_unnest, + )), + } + } + fn box_clone(&self) -> Box { Box::new(self.clone()) } @@ -168,9 +257,21 @@ impl LanceTokenizer for JsonTokenizer { fn doc_type(&self) -> DocType { DocType::Json } + + fn json_tokenizer_mode(&self) -> Option { + Some(self.mode) + } + + fn disable_cross_array_unnest(&self) -> bool { + self.disable_cross_array_unnest + } } -fn flatten_triplet(text: &str, tokenizer: &mut TextAnalyzer) -> lance_core::Result> { +fn flatten_triplet( + text: &str, + mode: JsonTokenizerMode, + tokenizer: &mut TextAnalyzer, +) -> lance_core::Result> { let mut token_vec = Vec::new(); let mut idx = 0; @@ -184,6 +285,21 @@ fn flatten_triplet(text: &str, tokenizer: &mut TextAnalyzer) -> lance_core::Resu let field = parts[0]; let v_type = parts[1]; let value = parts[2]; + let (field, mut index_tokens) = match mode { + JsonTokenizerMode::SingleDocument => (field.to_string(), Vec::new()), + JsonTokenizerMode::FlattenedSubDocs => normalize_flattened_json_path(field)?, + }; + + for index_token in index_tokens.drain(..) { + token_vec.push(Token { + offset_from: 0, + offset_to: 0, + position: idx, + text: index_token, + position_length: 1, + }); + idx += 1; + } match v_type { "number" | "bool" | "null" => { @@ -220,6 +336,46 @@ fn flatten_triplet(text: &str, tokenizer: &mut TextAnalyzer) -> lance_core::Resu Ok(token_vec) } +fn normalize_flattened_json_path(path: &str) -> lance_core::Result<(String, Vec)> { + let mut normalized = String::with_capacity(path.len()); + let mut index_tokens = Vec::new(); + let mut chars = path.char_indices().peekable(); + + while let Some((_, ch)) = chars.next() { + if ch != '[' { + normalized.push(ch); + continue; + } + + let index_path = normalized.clone(); + let mut array_index = String::new(); + let mut found_right_bracket = false; + for (_, bracket_ch) in chars.by_ref() { + if bracket_ch == ']' { + found_right_bracket = true; + break; + } + array_index.push(bracket_ch); + } + if !found_right_bracket { + return Err(Error::invalid_input(format!( + "missing right bracket in JSON path {path:?}" + ))); + } + if array_index.is_empty() { + return Err(Error::invalid_input(format!( + "empty array index in JSON path {path:?}" + ))); + } + if array_index != "*" { + index_tokens.push(format!("{index_path}$idx,number,{array_index}")); + } + normalized.push('.'); + } + + Ok((normalized, index_tokens)) +} + fn flatten_json( value: &Value, prefix: &str, @@ -277,6 +433,168 @@ fn flatten_json( } } +fn flatten_json_sub_docs( + value: &Value, + prefix: &str, + tokenizer: &mut TextAnalyzer, + disable_cross_array_unnest: bool, +) -> Vec> { + let token_texts = + flatten_json_sub_doc_terms(value, prefix, tokenizer, disable_cross_array_unnest); + token_texts + .into_iter() + .map(|sub_doc| { + sub_doc + .into_iter() + .enumerate() + .map(|(position, text)| Token { + offset_from: 0, + offset_to: 0, + position, + text, + position_length: 1, + }) + .collect() + }) + .collect() +} + +fn flatten_json_sub_doc_terms( + value: &Value, + prefix: &str, + tokenizer: &mut TextAnalyzer, + disable_cross_array_unnest: bool, +) -> Vec> { + match value { + Value::Object(map) => { + let mut non_nested = Vec::new(); + let mut nested: Vec>> = Vec::new(); + + for (key, child) in map { + let child_prefix = if prefix.is_empty() { + key.clone() + } else { + format!("{prefix}.{key}") + }; + let child_terms = flatten_json_sub_doc_terms( + child, + &child_prefix, + tokenizer, + disable_cross_array_unnest, + ); + match child_terms.len() { + 0 => {} + 1 => non_nested.extend(child_terms.into_iter().next().unwrap()), + _ => nested.push(child_terms), + } + } + + match nested.len() { + 0 if non_nested.is_empty() => Vec::new(), + 0 => vec![non_nested], + 1 => nested + .pop() + .unwrap() + .into_iter() + .map(|mut sub_doc| { + sub_doc.extend(non_nested.iter().cloned()); + sub_doc + }) + .collect(), + _ if disable_cross_array_unnest => unnest_json_sub_docs(&nested, &non_nested), + _ => cross_join_json_sub_docs(&nested, &non_nested), + } + } + Value::Array(arr) => { + let mut sub_docs = Vec::new(); + let child_prefix = format!("{prefix}."); + for (array_index, child) in arr.iter().enumerate() { + let mut child_terms = flatten_json_sub_doc_terms( + child, + &child_prefix, + tokenizer, + disable_cross_array_unnest, + ); + for sub_doc in &mut child_terms { + sub_doc.push(format!("{prefix}$idx,number,{array_index}")); + } + sub_docs.extend(child_terms); + } + sub_docs + } + Value::String(text) => { + let mut token_texts = Vec::new(); + let mut tokens = tokenizer.token_stream(text); + while let Some(token) = tokens.next() { + token_texts.push(format!("{prefix},str,{}", token.text)); + } + if token_texts.is_empty() { + Vec::new() + } else { + vec![token_texts] + } + } + _ => { + let value_type = match value { + Value::Null => "null", + Value::Bool(_) => "bool", + Value::Number(_) => "number", + _ => unreachable!(), + }; + vec![vec![format!("{prefix},{value_type},{value}")]] + } + } +} + +fn cross_join_json_sub_docs( + nested: &[Vec>], + non_nested: &[String], +) -> Vec> { + let capacity = nested + .iter() + .map(|sub_docs| sub_docs.len()) + .product::(); + let mut results = Vec::with_capacity(capacity); + let mut current = Vec::new(); + cross_join_json_sub_docs_inner(nested, 0, non_nested, &mut current, &mut results); + results +} + +fn unnest_json_sub_docs(nested: &[Vec>], non_nested: &[String]) -> Vec> { + let capacity = nested.iter().map(|sub_docs| sub_docs.len()).sum::(); + let mut results = Vec::with_capacity(capacity); + for sub_docs in nested { + for child in sub_docs { + let mut sub_doc = child.clone(); + sub_doc.extend(non_nested.iter().cloned()); + results.push(sub_doc); + } + } + results +} + +fn cross_join_json_sub_docs_inner( + nested: &[Vec>], + nested_index: usize, + non_nested: &[String], + current: &mut Vec, + results: &mut Vec>, +) { + if nested_index == nested.len() { + let mut sub_doc = current.clone(); + sub_doc.extend(non_nested.iter().cloned()); + results.push(sub_doc); + return; + } + + for child in &nested[nested_index] { + let old_len = current.len(); + current.extend(child.iter().cloned()); + cross_join_json_sub_docs_inner(nested, nested_index + 1, non_nested, current, results); + current.truncate(old_len); + } +} + struct TTStream { tokens: Vec, index: usize, @@ -304,7 +622,8 @@ impl TokenStream for TTStream { #[cfg(test)] mod tests { use crate::scalar::inverted::tokenizer::document_tokenizer::{ - JsonTokenizer, LanceTokenizer, flatten_json, flatten_triplet, + JsonTokenizer, JsonTokenizerMode, LanceTokenizer, flatten_json, flatten_json_sub_docs, + flatten_triplet, }; use lance_tokenizer::{SimpleTokenizer, TextAnalyzer, Token}; use serde_json::Value; @@ -318,8 +637,11 @@ mod tests { {"c": "e"} ] }"#; - let mut tokenizer = - JsonTokenizer::new(TextAnalyzer::builder(SimpleTokenizer::default()).build()); + let mut tokenizer = JsonTokenizer::new( + TextAnalyzer::builder(SimpleTokenizer::default()).build(), + JsonTokenizerMode::SingleDocument, + false, + ); let mut stream = tokenizer.token_stream_for_doc(text); let mut tokens: Vec = vec![]; @@ -368,7 +690,8 @@ mod tests { fn test_flatten_triplet() { let text = r#"a,number,1;b.c,str,d;b.c,str,e;d,str,hello world;e,number,1.0"#; let mut tokenizer = TextAnalyzer::builder(SimpleTokenizer::default()).build(); - let tokens = flatten_triplet(text, &mut tokenizer).unwrap(); + let tokens = + flatten_triplet(text, JsonTokenizerMode::SingleDocument, &mut tokenizer).unwrap(); assert_eq!(tokens.len(), 6); assert_token(&tokens[0], 0, "a,number,1"); @@ -379,6 +702,175 @@ mod tests { assert_token(&tokens[5], 5, "e,number,1.0"); } + #[test] + fn test_flattened_sub_docs_design_example() { + let doc0 = flattened_sub_doc_texts(r#"{"foo":[{"bar":["x","y"]}]}"#); + let doc1 = flattened_sub_doc_texts(r#"{"foo":[{"bar":["y"]},{"bar":"z"}]}"#); + + assert_eq!( + doc0, + vec![ + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,0", + "foo..bar.,str,x", + ]), + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,1", + "foo..bar.,str,y", + ]), + ] + ); + assert_eq!( + doc1, + vec![ + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,0", + "foo..bar.,str,y", + ]), + sorted_tokens(["foo$idx,number,1", "foo..bar,str,z"]), + ] + ); + + let mut tokenizer = TextAnalyzer::builder(SimpleTokenizer::default()).build(); + let exact_tokens = flatten_triplet( + "foo[0].bar[0],str,y", + JsonTokenizerMode::FlattenedSubDocs, + &mut tokenizer, + ) + .unwrap(); + assert_token_texts( + &exact_tokens, + &[ + "foo$idx,number,0", + "foo..bar$idx,number,0", + "foo..bar.,str,y", + ], + ); + + let wildcard_tokens = flatten_triplet( + "foo[0].bar[*],str,y", + JsonTokenizerMode::FlattenedSubDocs, + &mut tokenizer, + ) + .unwrap(); + assert_token_texts(&wildcard_tokens, &["foo$idx,number,0", "foo..bar.,str,y"]); + } + + #[test] + fn test_flattened_sub_docs_sibling_array_example() { + let doc0 = flattened_sub_doc_texts( + r#"{"foo":[{"bar":["x","y"]},{"bar":["a","b"]}],"foo2":["u"]}"#, + ); + let doc1 = flattened_sub_doc_texts(r#"{"foo":[{"bar":["y","z"]}],"foo2":["u"]}"#); + + assert_eq!( + doc0, + vec![ + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,0", + "foo..bar.,str,x", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,1", + "foo..bar.,str,y", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + sorted_tokens([ + "foo$idx,number,1", + "foo..bar$idx,number,0", + "foo..bar.,str,a", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + sorted_tokens([ + "foo$idx,number,1", + "foo..bar$idx,number,1", + "foo..bar.,str,b", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + ] + ); + assert_eq!( + doc1, + vec![ + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,0", + "foo..bar.,str,y", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + sorted_tokens([ + "foo$idx,number,0", + "foo..bar$idx,number,1", + "foo..bar.,str,z", + "foo2$idx,number,0", + "foo2.,str,u", + ]), + ] + ); + } + + #[test] + fn test_disable_cross_array_unnest_indexes_arrays_independently() { + let cross_joined = flattened_sub_doc_texts(r#"{"a":["x","y"],"b":["u","v"],"c":1}"#); + let disabled = flattened_sub_doc_texts_with_disable_cross_array_unnest( + r#"{"a":["x","y"],"b":["u","v"],"c":1}"#, + ); + + assert_eq!( + sorted_sub_docs(cross_joined), + sorted_sub_docs(vec![ + sorted_tokens([ + "a$idx,number,0", + "a.,str,x", + "b$idx,number,0", + "b.,str,u", + "c,number,1" + ]), + sorted_tokens([ + "a$idx,number,0", + "a.,str,x", + "b$idx,number,1", + "b.,str,v", + "c,number,1" + ]), + sorted_tokens([ + "a$idx,number,1", + "a.,str,y", + "b$idx,number,0", + "b.,str,u", + "c,number,1" + ]), + sorted_tokens([ + "a$idx,number,1", + "a.,str,y", + "b$idx,number,1", + "b.,str,v", + "c,number,1" + ]), + ]) + ); + assert_eq!( + sorted_sub_docs(disabled), + sorted_sub_docs(vec![ + sorted_tokens(["a$idx,number,0", "a.,str,x", "c,number,1"]), + sorted_tokens(["a$idx,number,1", "a.,str,y", "c,number,1"]), + sorted_tokens(["b$idx,number,0", "b.,str,u", "c,number,1"]), + sorted_tokens(["b$idx,number,1", "b.,str,v", "c,number,1"]), + ]) + ); + } + fn assert_token(token: &Token, position: usize, text: &str) { assert_eq!( token.position, position, @@ -390,4 +882,43 @@ mod tests { "expected text {text} but {token:?}" ); } + + fn flattened_sub_doc_texts(json: &str) -> Vec> { + flattened_sub_doc_texts_with_mode(json, false) + } + + fn flattened_sub_doc_texts_with_disable_cross_array_unnest(json: &str) -> Vec> { + flattened_sub_doc_texts_with_mode(json, true) + } + + fn flattened_sub_doc_texts_with_mode( + json: &str, + disable_cross_array_unnest: bool, + ) -> Vec> { + let value: Value = serde_json::from_str(json).unwrap(); + let mut tokenizer = TextAnalyzer::builder(SimpleTokenizer::default()).build(); + flatten_json_sub_docs(&value, "", &mut tokenizer, disable_cross_array_unnest) + .into_iter() + .map(|tokens| sorted_tokens(tokens.into_iter().map(|token| token.text))) + .collect() + } + + fn sorted_tokens(tokens: impl IntoIterator>) -> Vec { + let mut tokens = tokens.into_iter().map(Into::into).collect::>(); + tokens.sort(); + tokens + } + + fn sorted_sub_docs(mut sub_docs: Vec>) -> Vec> { + sub_docs.sort(); + sub_docs + } + + fn assert_token_texts(tokens: &[Token], expected: &[&str]) { + let actual = tokens + .iter() + .map(|token| token.text.as_str()) + .collect::>(); + assert_eq!(actual, expected); + } } diff --git a/rust/lance/src/dataset/tests/dataset_index.rs b/rust/lance/src/dataset/tests/dataset_index.rs index beb6e2b99fd..1c6b68ba16a 100644 --- a/rust/lance/src/dataset/tests/dataset_index.rs +++ b/rust/lance/src/dataset/tests/dataset_index.rs @@ -2728,6 +2728,155 @@ async fn prepare_json_dataset() -> (Dataset, String) { (dataset, json_col) } +#[tokio::test] +async fn test_json_inverted_flattened_sub_doc_array_paths() { + let ids = Arc::new(UInt64Array::from(vec![0, 1])); + let json_col = "json_field".to_string(); + let json_values = Arc::new(StringArray::from(vec![ + r#"{"foo":[{"bar":["x","y"]},{"bar":["a","b"]}],"foo2":["u"]}"#, + r#"{"foo":[{"bar":["y","z"]}],"foo2":["u"]}"#, + ])); + + let mut metadata = HashMap::new(); + metadata.insert( + ARROW_EXT_NAME_KEY.to_string(), + ARROW_JSON_EXT_NAME.to_string(), + ); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + Field::new("id", DataType::UInt64, false), + Field::new(&json_col, DataType::Utf8, false).with_metadata(metadata), + ]) + .into(), + vec![ids as ArrayRef, json_values as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let stream = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write(stream, "memory://test/flattened_json_array_paths", None) + .await + .unwrap(); + + dataset + .create_index( + &[&json_col], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .lance_tokenizer("json".to_string()) + .stem(false) + .remove_stop_words(false), + true, + ) + .await + .unwrap(); + + let exact_query = FullTextSearchQuery { + query: FtsQuery::Match( + MatchQuery::new("foo[0].bar[0],str,y".to_string()).with_column(Some(json_col.clone())), + ), + limit: None, + wand_factor: None, + }; + let exact_batch = dataset + .scan() + .full_text_search(exact_query) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let exact_ids = exact_batch["id"].as_primitive::().values(); + assert_eq!(exact_ids, &[1]); + + let wildcard_query = FullTextSearchQuery { + query: FtsQuery::Match( + MatchQuery::new("foo[0].bar[*],str,y".to_string()).with_column(Some(json_col.clone())), + ), + limit: None, + wand_factor: None, + }; + let wildcard_batch = dataset + .scan() + .full_text_search(wildcard_query) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let wildcard_ids = wildcard_batch["id"].as_primitive::().values(); + assert_eq!(wildcard_batch.num_rows(), 2, "ids={wildcard_ids:?}"); + assert!(wildcard_ids.contains(&0), "ids={wildcard_ids:?}"); + assert!(wildcard_ids.contains(&1), "ids={wildcard_ids:?}"); +} + +#[tokio::test] +async fn test_json_inverted_flattened_sub_doc_prevents_cross_object_match() { + let ids = Arc::new(UInt64Array::from(vec![0, 1])); + let json_col = "json_field".to_string(); + let json_values = Arc::new(StringArray::from(vec![ + r#"{"cart_id":3234234,"cart":[{"product_type":"sneakers","attributes":{"color":"white"}},{"product_type":"t-shirt","attributes":{"color":"red"}}]}"#, + r#"{"cart_id":3234235,"cart":[{"product_type":"sneakers","attributes":{"color":"red"}}]}"#, + ])); + + let mut metadata = HashMap::new(); + metadata.insert( + ARROW_EXT_NAME_KEY.to_string(), + ARROW_JSON_EXT_NAME.to_string(), + ); + let batch = RecordBatch::try_new( + arrow_schema::Schema::new(vec![ + Field::new("id", DataType::UInt64, false), + Field::new(&json_col, DataType::Utf8, false).with_metadata(metadata), + ]) + .into(), + vec![ids as ArrayRef, json_values as ArrayRef], + ) + .unwrap(); + let schema = batch.schema(); + let stream = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema); + let mut dataset = Dataset::write( + stream, + "memory://test/flattened_json_cross_object_match", + None, + ) + .await + .unwrap(); + + dataset + .create_index( + &[&json_col], + IndexType::Inverted, + None, + &InvertedIndexParams::default() + .lance_tokenizer("json".to_string()) + .stem(false) + .remove_stop_words(false), + true, + ) + .await + .unwrap(); + + let query = FullTextSearchQuery { + query: FtsQuery::Match( + MatchQuery::new( + "cart[*].product_type,str,sneakers;cart[*].attributes.color,str,red".to_string(), + ) + .with_column(Some(json_col.clone())) + .with_operator(Operator::And), + ), + limit: None, + wand_factor: None, + }; + let batch = dataset + .scan() + .full_text_search(query) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let ids = batch["id"].as_primitive::().values(); + assert_eq!(ids, &[1]); +} + #[tokio::test] async fn test_json_inverted_fuzziness_query() { let (mut dataset, json_col) = prepare_json_dataset().await; @@ -3088,7 +3237,7 @@ async fn test_json_inverted_multimatch_query() { match_queries: vec![ MatchQuery::new("Title,str,harrypotter".to_string()) .with_column(Some(json_col.clone())), - MatchQuery::new("Language,str,english".to_string()) + MatchQuery::new("Language[*],str,english".to_string()) .with_column(Some(json_col.clone())), ], }), @@ -3130,7 +3279,7 @@ async fn test_json_inverted_boolean_query() { should: vec![], must: vec![ FtsQuery::Match( - MatchQuery::new("Language,str,english".to_string()) + MatchQuery::new("Language[*],str,english".to_string()) .with_column(Some(json_col.clone())), ), FtsQuery::Match( diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index 76db9300d0f..a6f25f174d2 100644 --- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -40,9 +40,11 @@ use crate::{Dataset, index::DatasetIndexInternalExt}; use lance_index::metrics::MetricsCollector; use lance_index::scalar::inverted::builder::ScoredDoc; use lance_index::scalar::inverted::builder::document_input; -use lance_index::scalar::inverted::document_tokenizer::{DocType, JsonTokenizer, LanceTokenizer}; +use lance_index::scalar::inverted::document_tokenizer::{ + DocType, JsonTokenizer, JsonTokenizerMode, LanceTokenizer, +}; use lance_index::scalar::inverted::query::{ - BoostQuery, FtsSearchParams, MatchQuery, PhraseQuery, Tokens, collect_query_tokens, + BoostQuery, FtsSearchParams, MatchQuery, Operator, PhraseQuery, Tokens, collect_query_tokens, has_query_token, }; use lance_index::scalar::inverted::tokenizer::document_tokenizer::TextTokenizer; @@ -156,6 +158,28 @@ fn default_text_tokenizer() -> Box { )) } +fn query_has_concrete_json_array_index(query: &str) -> bool { + query.split(';').any(|triple| { + let path = triple + .split_once(',') + .map(|(path, _)| path) + .unwrap_or(triple); + let mut remaining = path; + while let Some(left_bracket) = remaining.find('[') { + let after_left = &remaining[left_bracket + 1..]; + let Some(right_bracket) = after_left.find(']') else { + return false; + }; + let array_index = &after_left[..right_bracket]; + if !array_index.is_empty() && array_index != "*" { + return true; + } + remaining = &after_left[right_bracket + 1..]; + } + false + }) +} + pub struct FtsIndexMetrics { index_metrics: IndexMetrics, partitions_searched: Count, @@ -489,11 +513,27 @@ impl ExecutionPlan for MatchQueryExec { Box::new(TextTokenizer::new(tokenizer)) as Box } DocType::Json => { - Box::new(JsonTokenizer::new(tokenizer)) as Box + let index_tokenizer = first_index.tokenizer(); + let mode = index_tokenizer + .json_tokenizer_mode() + .unwrap_or(JsonTokenizerMode::SingleDocument); + Box::new(JsonTokenizer::new( + tokenizer, + mode, + index_tokenizer.disable_cross_array_unnest(), + )) as Box } } } }; + let force_and_operator = tokenizer.json_tokenizer_mode() + == Some(JsonTokenizerMode::FlattenedSubDocs) + && query_has_concrete_json_array_index(&query.terms); + let operator = if force_and_operator { + Operator::And + } else { + query.operator + }; let tokens = collect_query_tokens(&query.terms, &mut tokenizer); let base_scorer = match preset_base_scorer { Some(scorer) => scorer, @@ -511,7 +551,7 @@ impl ExecutionPlan for MatchQueryExec { &indices, tokens, params, - query.operator, + operator, pre_filter, metrics.clone(), base_scorer,