diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index b14421c963f..ad74873ad17 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -1438,6 +1438,27 @@ impl MergeInsertJob { self.execute_uncommitted_impl(stream).await } + /// Join type for the `create_plan` fast path, which builds the join as + /// `source.join(target)` — i.e. the SOURCE is the left input and the TARGET + /// is the right input. + /// + /// At scale the optimizer plans this as a `Partitioned` hash join whose + /// build (left) side is hashed and held in memory per partition. Putting the + /// (typically small) source there keeps the per-partition hash tables + /// bounded by the source size, while the (potentially huge) target streams + /// through as the probe side. Neither input carries row statistics the + /// optimizer can compare (the source is a one-shot stream), so + /// `should_swap_join_order` is `false` and the operands are kept as written + /// rather than swapped to a target build side — which would materialize the + /// entire target per partition and can exhaust memory when the target is + /// large and several partitions run concurrently. + /// + /// Because the source is the left input, the operands are ordered + /// `(keep_unmatched_source_rows, keep_unmatched_target_rows)`: keeping + /// unmatched left (source) rows is a `Left` join, keeping unmatched right + /// (target) rows is a `Right` join. Every column is referenced downstream by + /// qualified name, so the join output is semantically identical to a + /// target-left orientation. 
fn create_plan_join_type(&self) -> JoinType { let keep_unmatched_source_rows = self.params.insert_not_matched; let keep_unmatched_target_rows = !matches!( @@ -1445,7 +1466,7 @@ impl MergeInsertJob { WhenNotMatchedBySource::Keep ); - match (keep_unmatched_target_rows, keep_unmatched_source_rows) { + match (keep_unmatched_source_rows, keep_unmatched_target_rows) { (false, false) => JoinType::Inner, (false, true) => JoinType::Right, (true, false) => JoinType::Left, @@ -1492,16 +1513,15 @@ impl MergeInsertJob { .map_err(crate::Error::from)?; let source_df_aliased = source_df.alias("source")?; let scan_aliased = scan.alias("target")?; + // Build the join as source.join(target) so the (typically small) source + // is the hash join's build side and the (potentially huge) target is + // streamed as the probe side. See `create_plan_join_type` for why this + // orientation is kept by the optimizer and avoids materializing the + // whole target per partition. let join_type = self.create_plan_join_type(); let dataset_schema: Schema = self.dataset.schema().into(); - let mut df = scan_aliased - .join( - source_df_aliased, - join_type, - &on_cols_refs, - &on_cols_refs, - None, - )? + let mut df = source_df_aliased + .join(scan_aliased, join_type, &on_cols_refs, &on_cols_refs, None)? 
.with_column( MERGE_ACTION_COLUMN, merge_insert_action(&self.params, Some(&dataset_schema))?, @@ -5436,7 +5456,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \ row_id=true, row_addr=true, full_filter=--, refine_filter=-- @@ -5484,7 +5504,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... 
@@ -5531,7 +5551,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... @@ -5585,7 +5605,7 @@ mod tests { plan, "MergeInsert: on=[key], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep CoalescePartitionsExec - ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action] + ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action] HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5] LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=-- RepartitionExec... 
@@ -5596,6 +5616,128 @@ mod tests { .unwrap(); } + /// Regression test for the join build-side orientation with a large target. + /// + /// `create_plan` must keep the small source on the hash join's build side + /// and stream the large target as the probe side; otherwise the entire + /// target is materialized in the hash table, which can exhaust memory on + /// large tables. + /// + /// Covers all four merge shapes, which map to the four hash-join types: + /// + /// | `insert_not_matched` | `delete_not_matched_by_source` | join type | + /// |---|---|---| + /// | false | Keep | Inner (update-only) | + /// | false | Delete | Right (delete unmatched target)| + /// | true | Keep | Left (upsert) | + /// | true | Delete | Full (upsert + delete) | + /// + /// The toy-sized plan-snapshot tests above cannot catch a regression here: + /// with a tiny target the optimizer freely swaps the build/probe sides, so + /// they would still pass with the operands reversed. This test uses a target + /// larger than DataFusion's collect threshold + /// (`hash_join_single_partition_threshold_rows`, 128K) so the swap logic + /// does not pull the target onto the build side. It asserts the orientation + /// (`LanceRead` on the probe/right input), not the partition mode, so it is + /// independent of the host core count (which determines `CollectLeft` vs + /// `Partitioned`). 
+ #[rstest::rstest] + #[case::inner(false, WhenNotMatchedBySource::Keep, JoinType::Inner)] + #[case::right(false, WhenNotMatchedBySource::Delete, JoinType::Right)] + #[case::left(true, WhenNotMatchedBySource::Keep, JoinType::Left)] + #[case::full(true, WhenNotMatchedBySource::Delete, JoinType::Full)] + #[tokio::test] + async fn test_plan_keeps_target_on_probe_side_at_scale( + #[case] insert_not_matched: bool, + #[case] delete_by_source: WhenNotMatchedBySource, + #[case] expected_join_type: JoinType, + ) { + use datafusion::physical_plan::{displayable, joins::HashJoinExec}; + + // Target with > 128K rows so the optimizer's collect/swap logic does not + // pull the target onto the build side. + let data = lance_datagen::gen_batch() + .with_seed(Seed::from(1)) + .col("value", array::step::()) + .col("key", array::step::()); + let data = data.into_reader_rows(RowCount::from(50_000), BatchCount::from(8)); // 400K rows + let ds = Dataset::write(data, "memory://", None).await.unwrap(); + + let mut builder = + crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()]) + .unwrap(); + builder.when_matched(crate::dataset::WhenMatched::UpdateAll); + builder.when_not_matched(if insert_not_matched { + crate::dataset::WhenNotMatched::InsertAll + } else { + crate::dataset::WhenNotMatched::DoNothing + }); + if matches!(delete_by_source, WhenNotMatchedBySource::Delete) { + builder.when_not_matched_by_source(WhenNotMatchedBySource::Delete); + } + let job = builder.try_build().unwrap(); + + // A small source — the side that should be hashed/built. + let new_data = lance_datagen::gen_batch() + .with_seed(Seed::from(2)) + .col("value", array::step::()) + .col("key", array::step::()); + let new_data = new_data.into_reader_rows(RowCount::from(1000), BatchCount::from(1)); + let stream = reader_to_stream(Box::new(new_data)); + let plan = job.create_plan(stream).await.unwrap(); + + // Locate the HashJoinExec in the physical plan. 
+ fn find_hash_join(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> { + if plan.as_any().is::<HashJoinExec>() { + return Some(plan.clone()); + } + for child in plan.children() { + if let Some(found) = find_hash_join(child) { + return Some(found); + } + } + None + } + + let rendered = format!("{}", displayable(plan.as_ref()).indent(true)); + let hash_join = find_hash_join(&plan) + .unwrap_or_else(|| panic!("expected a HashJoinExec in the plan:\n{rendered}")); + let hash_join = hash_join + .as_any() + .downcast_ref::<HashJoinExec>() + .expect("HashJoinExec"); + + // Sanity-check that the shape produced the join type we intended to + // exercise, so the source-left operand-order mapping stays in sync with + // `create_plan_join_type`. + assert_eq!( + hash_join.join_type(), + &expected_join_type, + "unexpected join type for this merge shape; plan was:\n{rendered}" + ); + + // The target scan must be the right (probe) input, not the left (build) + // input — that is the whole point of building source.join(target). + // + // This holds regardless of partition mode: with a large target neither + // `CollectLeft` (when DataFusion's target_partitions is 1) nor + // `Partitioned` (the multi-core default) swaps the target onto the build + // side, because the source stream reports no statistics for + // `should_swap_join_order` to act on. We assert the orientation rather + // than the mode so the test does not depend on the host core count. + let right_has_lance_scan = + format!("{}", displayable(hash_join.right().as_ref()).indent(true)) + .contains("LanceRead"); + let left_has_lance_scan = + format!("{}", displayable(hash_join.left().as_ref()).indent(true)) + .contains("LanceRead"); + assert!( + right_has_lance_scan && !left_has_lance_scan, + "target (LanceRead) must be the probe (right) side of the hash join so it \ + is streamed rather than materialized; plan was:\n{rendered}" + ); + } + #[tokio::test] async fn test_skip_auto_cleanup() { let tmpdir = TempStrDir::default();