Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 113 additions & 81 deletions datafusion/physical-plan/src/topk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,67 @@ impl TopKThreshold {
}
}

/// Borrowed, copyable view of a single [`TopKRow`].
///
/// It exposes the wrapped row's encoded bytes so they can be compared with
/// the row bytes stored in a [`TopKThreshold`], without looking up the
/// `RecordBatch` the row came from.
#[derive(Clone, Copy)]
struct TopKHeapBoundaryRow<'a> {
    /// The wrapped heap row.
    row: &'a TopKRow,
}

impl<'a> TopKHeapBoundaryRow<'a> {
fn new(row: &'a TopKRow) -> Self {
Self { row }
}

fn full_sort_key_row(&self) -> &[u8] {
self.row.row()
}

fn is_more_selective_than(&self, current: Option<&TopKThreshold>) -> bool {
current
.map(|current| self.full_sort_key_row() < current.full_sort_key_row())
.unwrap_or(true)
}
}

/// Borrowed, copyable pairing of a [`TopKRow`] with a `RecordBatch`.
///
/// `row.index` is used as the row offset into `batch` when the row's values
/// are extracted, so `batch` must be the batch that `row` refers to.
#[derive(Clone, Copy)]
struct TopKHeapBoundary<'a> {
    /// The heap row; its `index` is the offset of the row inside `batch`.
    row: &'a TopKRow,
    /// The batch that `row.index` points into.
    batch: &'a RecordBatch,
}

impl<'a> TopKHeapBoundary<'a> {
    /// Pairs a heap row with the batch its `index` points into.
    fn new(row: &'a TopKRow, batch: &'a RecordBatch) -> Self {
        Self { row, batch }
    }

    /// Evaluates every sort expression against the boundary row and returns
    /// one `ScalarValue` per expression, in the order of `sort_exprs`.
    ///
    /// # Errors
    ///
    /// Propagates any error from expression evaluation or scalar extraction,
    /// and returns an internal error if an expression yields an array whose
    /// length is not 1.
    fn threshold_values(
        &self,
        sort_exprs: &[PhysicalSortExpr],
    ) -> Result<Vec<ScalarValue>> {
        // The single-row view of the batch is identical for every sort
        // expression, so build it once instead of re-slicing per expression.
        let boundary_batch = self.batch.slice(self.row.index, 1);

        let mut scalar_values = Vec::with_capacity(sort_exprs.len());
        for sort_expr in sort_exprs {
            // `evaluate` only needs a shared reference to the expression, so
            // it is called through the existing `Arc` without cloning it.
            let scalar = match sort_expr.expr.evaluate(&boundary_batch)? {
                ColumnarValue::Scalar(scalar) => scalar,
                // A one-row batch should yield a one-element array.
                ColumnarValue::Array(array) if array.len() == 1 => {
                    ScalarValue::try_from_array(&array, 0)?
                }
                array => {
                    return internal_err!("Expected a scalar value, got {:?}", array);
                }
            };
            scalar_values.push(scalar);
        }

        Ok(scalar_values)
    }

    /// Builds a `TopKThreshold` from an owned copy of the boundary row's
    /// encoded bytes and the supplied optional common-prefix row.
    fn threshold(&self, common_prefix_row: Option<Vec<u8>>) -> TopKThreshold {
        TopKThreshold::new(self.row.row().to_vec(), common_prefix_row)
    }
}

impl TopKDynamicFilters {
/// Create a new `TopKDynamicFilters` with the given expression
pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
Expand Down Expand Up @@ -424,6 +485,28 @@ impl TopK {
replacements
}

/// Returns a boundary-row view of `self.heap.max()`, or `None` when the
/// heap reports no max row.
fn current_heap_boundary_row(&self) -> Option<TopKHeapBoundaryRow<'_>> {
    let max_row = self.heap.max()?;
    Some(TopKHeapBoundaryRow::new(max_row))
}

/// Resolves `self.heap.max()` into a full boundary (row plus its batch).
///
/// Returns `Ok(None)` when the heap reports no max row, and forwards the
/// error from `heap_boundary` when the row's batch cannot be found.
fn current_heap_boundary(&self) -> Result<Option<TopKHeapBoundary<'_>>> {
    self.heap
        .max()
        .map(|max_row| self.heap_boundary(max_row))
        .transpose()
}

/// Looks up the stored batch for `row.batch_id` and pairs it with `row`.
///
/// # Errors
///
/// Returns an internal error when the heap's store has no entry for
/// `row.batch_id`.
fn heap_boundary<'a>(&'a self, row: &'a TopKRow) -> Result<TopKHeapBoundary<'a>> {
    let Some(entry) = self.heap.store.get(row.batch_id) else {
        return Err(internal_datafusion_err!("Invalid batch ID in TopKRow"));
    };

    Ok(TopKHeapBoundary::new(row, &entry.batch))
}

/// Update the filter representation of our TopK heap.
/// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`,
/// and the current heap values `[(1, 5), (1, 4), (2, 3)]`,
Expand All @@ -436,42 +519,31 @@ impl TopK {
/// ```
fn update_filter(&mut self) -> Result<()> {
// If the heap doesn't have k elements yet, we can't create thresholds
let Some(max_row) = self.heap.max() else {
let Some(boundary_row) = self.current_heap_boundary_row() else {
return Ok(());
};

let new_threshold_row = max_row.row();

// Fast path: check if the current value in topk is better than what is
// currently set in the filter with a read only lock
let needs_update = self
.filter
.read()
.shared_threshold
.as_ref()
.map(|current_threshold| {
// new < current means new threshold is more selective
new_threshold_row < current_threshold.full_sort_key_row()
})
.unwrap_or(true); // No current threshold, so we need to set one
let needs_update = {
let filter = self.filter.read();
boundary_row.is_more_selective_than(filter.shared_threshold.as_ref())
};

// exit early if the current values are better
if !needs_update {
return Ok(());
}

let boundary = self.heap_boundary(boundary_row.row)?;

// Extract scalar values BEFORE acquiring lock to reduce critical section
let thresholds = match self.heap.get_threshold_values(&self.expr)? {
Some(t) => t,
None => return Ok(()),
};
let thresholds = boundary.threshold_values(&self.expr)?;

// Build the filter expression OUTSIDE any synchronization
let predicate = Self::build_filter_expression(&self.expr, &thresholds)?;
let new_threshold = TopKThreshold::new(
new_threshold_row.to_vec(),
self.encode_topk_common_prefix_row(max_row)?,
);
let new_threshold =
boundary.threshold(self.encode_topk_common_prefix_row(boundary)?);

// Update the threshold. The lock was released while we built the expression,
// so the shared threshold may have changed; re-check that ours is still the best.
Expand Down Expand Up @@ -629,44 +701,45 @@ impl TopK {
return Ok(());
}

// Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full).
let Some(max_topk_row) = self.heap.max() else {
// Early exit only from the local heap once it has a full boundary row.
let Some(boundary) = self.current_heap_boundary()? else {
return Ok(());
};

// Encode the local heap max row's common-prefix projection.
let Some(heap_common_prefix_row) =
self.encode_topk_common_prefix_row(max_topk_row)?
else {
return Ok(());
};

// If the last row's prefix is strictly greater than the max prefix, mark as finished.
if batch_common_prefix > heap_common_prefix_row.as_slice() {
if self.batch_prefix_exceeds_heap_boundary(batch_common_prefix, boundary)? {
self.finished = true;
}

Ok(())
}

/// Reports whether `batch_common_prefix` compares strictly greater than the
/// encoded common-prefix row of `boundary`.
///
/// Returns `Ok(false)` when no common-prefix row is produced for the
/// boundary, and forwards any encoding error.
fn batch_prefix_exceeds_heap_boundary(
    &self,
    batch_common_prefix: &[u8],
    boundary: TopKHeapBoundary<'_>,
) -> Result<bool> {
    let exceeds = match self.encode_topk_common_prefix_row(boundary)? {
        Some(heap_prefix) => batch_common_prefix > heap_prefix.as_slice(),
        // No prefix to compare against, so the batch cannot exceed it.
        None => false,
    };

    Ok(exceeds)
}

fn encode_topk_common_prefix_row(
&self,
topk_row: &TopKRow,
boundary: TopKHeapBoundary<'_>,
) -> Result<Option<Vec<u8>>> {
let Some(prefix_converter) = &self.common_sort_prefix_converter else {
return Ok(None);
};

let store_entry = self
.heap
.store
.get(topk_row.batch_id)
.ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?;
let mut scratch = prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW);
self.append_common_prefix_row(
prefix_converter,
&store_entry.batch,
topk_row.index,
boundary.batch,
boundary.row.index,
&mut scratch,
)?;
Ok(Some(scratch.row(0).as_ref().to_vec()))
Expand Down Expand Up @@ -951,47 +1024,6 @@ impl TopKHeap {
+ self.store.size()
+ self.owned_bytes
}

/// Evaluates every sort expression against the heap's max row and returns
/// one `ScalarValue` per expression, in the order of `sort_exprs`.
///
/// Returns `Ok(None)` when `self.max()` yields no row (the heap doesn't have
/// k elements yet, so no threshold can be created).
///
/// # Errors
///
/// Returns an internal error when the max row's batch is missing from the
/// store or when an expression yields an array whose length is not 1, and
/// propagates errors from expression evaluation or scalar extraction.
fn get_threshold_values(
    &self,
    sort_exprs: &[PhysicalSortExpr],
) -> Result<Option<Vec<ScalarValue>>> {
    // If the heap doesn't have k elements yet, we can't create thresholds
    let Some(max_row) = self.max() else {
        return Ok(None);
    };

    // Get the batch that contains the max row
    let Some(batch_entry) = self.store.get(max_row.batch_id) else {
        return internal_err!("Invalid batch ID in TopKRow");
    };

    // The single-row view is the same for every sort expression, so slice
    // once up front instead of once per expression.
    let max_row_batch = batch_entry.batch.slice(max_row.index, 1);

    // Extract threshold values for each sort expression
    let mut scalar_values = Vec::with_capacity(sort_exprs.len());
    for sort_expr in sort_exprs {
        // `evaluate` only needs a shared reference, so the expression's
        // `Arc` is used in place rather than cloned.
        let scalar = match sort_expr.expr.evaluate(&max_row_batch)? {
            ColumnarValue::Scalar(scalar) => scalar,
            ColumnarValue::Array(array) if array.len() == 1 => {
                // Extract the first (and only) value from the array
                ScalarValue::try_from_array(&array, 0)?
            }
            array => {
                return internal_err!("Expected a scalar value, got {:?}", array);
            }
        };

        scalar_values.push(scalar);
    }

    Ok(Some(scalar_values))
}
}

/// Represents one of the top K rows held in this heap. Orders
Expand Down