Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 155 additions & 13 deletions rust/lance/src/dataset/write/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1438,14 +1438,35 @@ impl MergeInsertJob {
self.execute_uncommitted_impl(stream).await
}

/// Join type for the `create_plan` fast path, which builds the join as
/// `source.join(target)` — i.e. the SOURCE is the left input and the TARGET
/// is the right input.
///
/// At scale the optimizer plans this as a `Partitioned` hash join whose
/// build (left) side is hashed and held in memory per partition. Putting the
/// (typically small) source there keeps the per-partition hash tables
/// bounded by the source size, while the (potentially huge) target streams
/// through as the probe side. Neither input carries row statistics the
/// optimizer can compare (the source is a one-shot stream), so
/// `should_swap_join_order` is `false` and the operands are kept as written
/// rather than swapped to a target build side — which would materialize the
/// entire target per partition and can exhaust memory when the target is
/// large and several partitions run concurrently.
///
/// Because the source is the left input, the operands are ordered
/// `(keep_unmatched_source_rows, keep_unmatched_target_rows)`: keeping
/// unmatched left (source) rows is a `Left` join, keeping unmatched right
/// (target) rows is a `Right` join. Every column is referenced downstream by
/// qualified name, so the join output is semantically identical to a
/// target-left orientation.
fn create_plan_join_type(&self) -> JoinType {
let keep_unmatched_source_rows = self.params.insert_not_matched;
let keep_unmatched_target_rows = !matches!(
self.params.delete_not_matched_by_source,
WhenNotMatchedBySource::Keep
);

match (keep_unmatched_target_rows, keep_unmatched_source_rows) {
match (keep_unmatched_source_rows, keep_unmatched_target_rows) {
(false, false) => JoinType::Inner,
(false, true) => JoinType::Right,
(true, false) => JoinType::Left,
Expand Down Expand Up @@ -1492,16 +1513,15 @@ impl MergeInsertJob {
.map_err(crate::Error::from)?;
let source_df_aliased = source_df.alias("source")?;
let scan_aliased = scan.alias("target")?;
// Build the join as source.join(target) so the (typically small) source
// is the hash join's build side and the (potentially huge) target is
// streamed as the probe side. See `create_plan_join_type` for why this
// orientation is kept by the optimizer and avoids materializing the
// whole target per partition.
let join_type = self.create_plan_join_type();
let dataset_schema: Schema = self.dataset.schema().into();
let mut df = scan_aliased
.join(
source_df_aliased,
join_type,
&on_cols_refs,
&on_cols_refs,
None,
)?
let mut df = source_df_aliased
.join(scan_aliased, join_type, &on_cols_refs, &on_cols_refs, None)?
.with_column(
MERGE_ACTION_COLUMN,
merge_insert_action(&self.params, Some(&dataset_schema))?,
Expand Down Expand Up @@ -5436,7 +5456,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=InsertAll, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, \
row_id=true, row_addr=true, full_filter=--, refine_filter=--
Expand Down Expand Up @@ -5484,7 +5504,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateAll, when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand Down Expand Up @@ -5531,7 +5551,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=UpdateIf(source.value > 20), when_not_matched=DoNothing, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NOT NULL AND value@2 > 20 THEN 1 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand Down Expand Up @@ -5585,7 +5605,7 @@ mod tests {
plan,
"MergeInsert: on=[key], when_matched=DoNothing, when_not_matched=InsertAll, when_not_matched_by_source=Keep
CoalescePartitionsExec
ProjectionExec: expr=[_rowid@0 as _rowid, _rowaddr@1 as _rowaddr, value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action]
ProjectionExec: expr=[value@2 as value, key@3 as key, __merge_source_sentinel@4 as __merge_source_sentinel, _rowid@0 as _rowid, _rowaddr@1 as _rowaddr, CASE WHEN _rowaddr@1 IS NULL THEN 2 ELSE 0 END as __action]
HashJoinExec: mode=CollectLeft, join_type=Right, on=[(key@0, key@1)], projection=[_rowid@1, _rowaddr@2, value@3, key@4, __merge_source_sentinel@5]
LanceRead: uri=..., projection=[key], num_fragments=1, range_before=None, range_after=None, row_id=true, row_addr=true, full_filter=--, refine_filter=--
RepartitionExec...
Expand All @@ -5596,6 +5616,128 @@ mod tests {
.unwrap();
}

/// Regression test for the join build-side orientation with a large target.
///
/// `create_plan` must keep the small source on the hash join's build side
/// and stream the large target as the probe side; otherwise the entire
/// target is materialized in the hash table, which can exhaust memory on
/// large tables.
///
/// Covers all four merge shapes, which map to the four hash-join types:
///
/// | `insert_not_matched` | `delete_not_matched_by_source` | join type |
/// |---|---|---|
/// | false | Keep | Inner (update-only) |
/// | false | Delete | Right (delete unmatched target)|
/// | true | Keep | Left (upsert) |
/// | true | Delete | Full (upsert + delete) |
///
/// The toy-sized plan-snapshot tests above cannot catch a regression here:
/// with a tiny target the optimizer freely swaps the build/probe sides, so
/// they would still pass with the operands reversed. This test uses a target
/// larger than DataFusion's collect threshold
/// (`hash_join_single_partition_threshold_rows`, 128K) so the swap logic
/// does not pull the target onto the build side. It asserts the orientation
/// (`LanceRead` on the probe/right input), not the partition mode, so it is
/// independent of the host core count (which determines `CollectLeft` vs
/// `Partitioned`).
#[rstest::rstest]
#[case::inner(false, WhenNotMatchedBySource::Keep, JoinType::Inner)]
#[case::right(false, WhenNotMatchedBySource::Delete, JoinType::Right)]
#[case::left(true, WhenNotMatchedBySource::Keep, JoinType::Left)]
#[case::full(true, WhenNotMatchedBySource::Delete, JoinType::Full)]
#[tokio::test]
async fn test_plan_keeps_target_on_probe_side_at_scale(
#[case] insert_not_matched: bool,
#[case] delete_by_source: WhenNotMatchedBySource,
#[case] expected_join_type: JoinType,
) {
use datafusion::physical_plan::{displayable, joins::HashJoinExec};

// Target with > 128K rows so the optimizer's collect/swap logic does not
// pull the target onto the build side.
let data = lance_datagen::gen_batch()
.with_seed(Seed::from(1))
.col("value", array::step::<UInt32Type>())
.col("key", array::step::<UInt64Type>());
let data = data.into_reader_rows(RowCount::from(50_000), BatchCount::from(8)); // 400K rows
let ds = Dataset::write(data, "memory://", None).await.unwrap();

let mut builder =
crate::dataset::MergeInsertBuilder::try_new(Arc::new(ds), vec!["key".to_string()])
.unwrap();
builder.when_matched(crate::dataset::WhenMatched::UpdateAll);
builder.when_not_matched(if insert_not_matched {
crate::dataset::WhenNotMatched::InsertAll
} else {
crate::dataset::WhenNotMatched::DoNothing
});
if matches!(delete_by_source, WhenNotMatchedBySource::Delete) {
builder.when_not_matched_by_source(WhenNotMatchedBySource::Delete);
}
let job = builder.try_build().unwrap();

// A small source — the side that should be hashed/built.
let new_data = lance_datagen::gen_batch()
.with_seed(Seed::from(2))
.col("value", array::step::<UInt32Type>())
.col("key", array::step::<UInt64Type>());
let new_data = new_data.into_reader_rows(RowCount::from(1000), BatchCount::from(1));
let stream = reader_to_stream(Box::new(new_data));
let plan = job.create_plan(stream).await.unwrap();

// Locate the HashJoinExec in the physical plan.
fn find_hash_join(plan: &Arc<dyn ExecutionPlan>) -> Option<Arc<dyn ExecutionPlan>> {
if plan.as_any().is::<HashJoinExec>() {
return Some(plan.clone());
}
for child in plan.children() {
if let Some(found) = find_hash_join(child) {
return Some(found);
}
}
None
}

let rendered = format!("{}", displayable(plan.as_ref()).indent(true));
let hash_join = find_hash_join(&plan)
.unwrap_or_else(|| panic!("expected a HashJoinExec in the plan:\n{rendered}"));
let hash_join = hash_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect("HashJoinExec");

// Sanity-check that the shape produced the join type we intended to
// exercise, so the source-left operand-order mapping stays in sync with
// `create_plan_join_type`.
assert_eq!(
hash_join.join_type(),
&expected_join_type,
"unexpected join type for this merge shape; plan was:\n{rendered}"
);

// The target scan must be the right (probe) input, not the left (build)
// input — that is the whole point of building source.join(target).
//
// This holds regardless of partition mode: with a large target neither
// `CollectLeft` (when DataFusion's target_partitions is 1) nor
// `Partitioned` (the multi-core default) swaps the target onto the build
// side, because the source stream reports no statistics for
// `should_swap_join_order` to act on. We assert the orientation rather
// than the mode so the test does not depend on the host core count.
let right_has_lance_scan =
format!("{}", displayable(hash_join.right().as_ref()).indent(true))
.contains("LanceRead");
let left_has_lance_scan =
format!("{}", displayable(hash_join.left().as_ref()).indent(true))
.contains("LanceRead");
assert!(
right_has_lance_scan && !left_has_lance_scan,
"target (LanceRead) must be the probe (right) side of the hash join so it \
is streamed rather than materialized; plan was:\n{rendered}"
);
}

#[tokio::test]
async fn test_skip_auto_cleanup() {
let tmpdir = TempStrDir::default();
Expand Down
Loading