diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 11cf54c904ac8..050732a380e06 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -195,6 +195,68 @@ impl TopKThreshold { } } +#[derive(Clone, Copy)] +struct TopKHeapBoundaryRow<'a> { + row: &'a TopKRow, +} + +impl<'a> TopKHeapBoundaryRow<'a> { + fn new(row: &'a TopKRow) -> Self { + Self { row } + } + + fn full_sort_key_row(&self) -> &[u8] { + self.row.row() + } + + fn is_more_selective_than(&self, current: Option<&TopKThreshold>) -> bool { + current + .map(|current| self.full_sort_key_row() < current.full_sort_key_row()) + .unwrap_or(true) + } +} + +#[derive(Clone, Copy)] +struct TopKHeapBoundary<'a> { + row: &'a TopKRow, + batch: &'a RecordBatch, +} + +impl<'a> TopKHeapBoundary<'a> { + fn new(row: &'a TopKRow, batch: &'a RecordBatch) -> Self { + Self { row, batch } + } + + fn threshold_values( + &self, + sort_exprs: &[PhysicalSortExpr], + ) -> Result> { + let mut scalar_values = Vec::with_capacity(sort_exprs.len()); + for sort_expr in sort_exprs { + let value = sort_expr + .expr + .evaluate(&self.batch.slice(self.row.index, 1))?; + + let scalar = match value { + ColumnarValue::Scalar(scalar) => scalar, + ColumnarValue::Array(array) if array.len() == 1 => { + ScalarValue::try_from_array(&array, 0)? + } + array => { + return internal_err!("Expected a scalar value, got {:?}", array); + } + }; + scalar_values.push(scalar); + } + + Ok(scalar_values) + } + + fn threshold(&self, common_prefix_row: Option>) -> TopKThreshold { + TopKThreshold::new(self.row.row().to_vec(), common_prefix_row) + } +} + impl TopKDynamicFilters { /// Create a new `TopKDynamicFilters` with the given expression pub fn new(expr: Arc) -> Self { @@ -424,6 +486,28 @@ impl TopK { replacements } + fn current_heap_boundary_row(&self) -> Option> { + self.heap.max().map(TopKHeapBoundaryRow::new) + } + + fn current_heap_boundary(&self) -> Result>> { + let Some(row) = self.heap.max() else { + return Ok(None); + }; + + self.heap_boundary(row).map(Some) + } + + fn heap_boundary<'a>(&'a self, row: &'a TopKRow) -> Result> { + let batch_entry = self + .heap + .store + .get(row.batch_id) + .ok_or_else(|| internal_datafusion_err!("Invalid batch ID in TopKRow"))?; + + Ok(TopKHeapBoundary::new(row, &batch_entry.batch)) + } + /// Update the filter representation of our TopK heap. /// For example, given the sort expression `ORDER BY a DESC, b ASC LIMIT 3`, /// and the current heap values `[(1, 5), (1, 4), (2, 3)]`, @@ -436,42 +520,31 @@ impl TopK { /// ``` fn update_filter(&mut self) -> Result<()> { // If the heap doesn't have k elements yet, we can't create thresholds - let Some(max_row) = self.heap.max() else { + let Some(boundary_row) = self.current_heap_boundary_row() else { return Ok(()); }; - let new_threshold_row = max_row.row(); - // Fast path: check if the current value in topk is better than what is // currently set in the filter with a read only lock - let needs_update = self - .filter - .read() - .shared_threshold - .as_ref() - .map(|current_threshold| { - // new < current means new threshold is more selective - new_threshold_row < current_threshold.full_sort_key_row() - }) - .unwrap_or(true); // No current threshold, so we need to set one + let needs_update = { + let filter = self.filter.read(); + boundary_row.is_more_selective_than(filter.shared_threshold.as_ref()) + }; // exit early if the current values are better if !needs_update { return Ok(()); } + let boundary = self.heap_boundary(boundary_row.row)?; + // Extract scalar values BEFORE acquiring lock to reduce critical section - let thresholds = match self.heap.get_threshold_values(&self.expr)? { - Some(t) => t, - None => return Ok(()), - }; + let thresholds = boundary.threshold_values(&self.expr)?; // Build the filter expression OUTSIDE any synchronization let predicate = Self::build_filter_expression(&self.expr, &thresholds)?; - let new_threshold = TopKThreshold::new( - new_threshold_row.to_vec(), - self.encode_topk_common_prefix_row(max_row)?, - ); + let new_threshold = + boundary.threshold(self.encode_topk_common_prefix_row(boundary)?); // update the threshold. Since there was a lock gap, we must check if it is still the best // may have changed while we were building the expression without the lock @@ -629,44 +702,45 @@ impl TopK { return Ok(()); } - // Early exit if the heap is not full (`heap.max()` only returns `Some` if the heap is full). - let Some(max_topk_row) = self.heap.max() else { - return Ok(()); - }; - - // Encode the local heap max row's common-prefix projection. - let Some(heap_common_prefix_row) = - self.encode_topk_common_prefix_row(max_topk_row)? - else { + // Early exit only from the local heap once it has a full boundary row. + let Some(boundary) = self.current_heap_boundary()? else { return Ok(()); }; - // If the last row's prefix is strictly greater than the max prefix, mark as finished. - if batch_common_prefix > heap_common_prefix_row.as_slice() { + if self.batch_prefix_exceeds_heap_boundary(batch_common_prefix, boundary)? { self.finished = true; } Ok(()) } + fn batch_prefix_exceeds_heap_boundary( + &self, + batch_common_prefix: &[u8], + boundary: TopKHeapBoundary<'_>, + ) -> Result { + let Some(heap_common_prefix_row) = + self.encode_topk_common_prefix_row(boundary)? + else { + return Ok(false); + }; + + Ok(batch_common_prefix > heap_common_prefix_row.as_slice()) + } + fn encode_topk_common_prefix_row( &self, - topk_row: &TopKRow, + boundary: TopKHeapBoundary<'_>, ) -> Result>> { let Some(prefix_converter) = &self.common_sort_prefix_converter else { return Ok(None); }; - let store_entry = self - .heap - .store - .get(topk_row.batch_id) - .ok_or(internal_datafusion_err!("Invalid batch id in topK heap"))?; let mut scratch = prefix_converter.empty_rows(1, ESTIMATED_BYTES_PER_ROW); self.append_common_prefix_row( prefix_converter, - &store_entry.batch, - topk_row.index, + boundary.batch, + boundary.row.index, &mut scratch, )?; Ok(Some(scratch.row(0).as_ref().to_vec())) @@ -951,47 +1025,6 @@ impl TopKHeap { + self.store.size() + self.owned_bytes } - - fn get_threshold_values( - &self, - sort_exprs: &[PhysicalSortExpr], - ) -> Result>> { - // If the heap doesn't have k elements yet, we can't create thresholds - let max_row = match self.max() { - Some(row) => row, - None => return Ok(None), - }; - - // Get the batch that contains the max row - let batch_entry = match self.store.get(max_row.batch_id) { - Some(entry) => entry, - None => return internal_err!("Invalid batch ID in TopKRow"), - }; - - // Extract threshold values for each sort expression - let mut scalar_values = Vec::with_capacity(sort_exprs.len()); - for sort_expr in sort_exprs { - // Extract the value for this column from the max row - let expr = Arc::clone(&sort_expr.expr); - let value = expr.evaluate(&batch_entry.batch.slice(max_row.index, 1))?; - - // Convert to scalar value - should be a single value since we're evaluating on a single row batch - let scalar = match value { - ColumnarValue::Scalar(scalar) => scalar, - ColumnarValue::Array(array) if array.len() == 1 => { - // Extract the first (and only) value from the array - ScalarValue::try_from_array(&array, 0)? - } - array => { - return internal_err!("Expected a scalar value, got {:?}", array); - } - }; - - scalar_values.push(scalar); - } - - Ok(Some(scalar_values)) - } } /// Represents one of the top K rows held in this heap. Orders