diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4219c24bfc9c9..b5193031a1782 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1388,6 +1388,82 @@ impl LogicalPlan { } } + /// Returns the skip (offset) of this plan node, if it has one. + /// + /// Only [`LogicalPlan::Limit`] carries a skip value; all other variants + /// return `Ok(None)`. Returns `Ok(None)` for a zero skip. + pub fn skip(&self) -> Result> { + match self { + LogicalPlan::Limit(limit) => match limit.get_skip_type()? { + SkipType::Literal(0) => Ok(None), + SkipType::Literal(n) => Ok(Some(n)), + SkipType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Sort(_) => Ok(None), + LogicalPlan::TableScan(_) => Ok(None), + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + + /// Returns the fetch (limit) of this plan node, if it has one. + /// + /// [`LogicalPlan::Sort`], [`LogicalPlan::TableScan`], and + /// [`LogicalPlan::Limit`] may carry a fetch value; all other variants + /// return `Ok(None)`. + pub fn fetch(&self) -> Result> { + match self { + LogicalPlan::Sort(Sort { fetch, .. }) => Ok(*fetch), + LogicalPlan::TableScan(TableScan { fetch, .. }) => Ok(*fetch), + LogicalPlan::Limit(limit) => match limit.get_fetch_type()? { + FetchType::Literal(s) => Ok(s), + FetchType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + /// If this node's expressions contains any references to an outer subquery pub fn contains_outer_reference(&self) -> bool { let mut contains = false; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 755ffdbafc869..da4d457913ed2 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -791,6 +791,13 @@ impl OptimizerRule for PushDownFilter { filter.predicate = new_predicate; } + // If the child has a fetch (limit) or skip (offset), pushing a filter + // below it would change semantics: the limit/offset should apply before + // the filter, not after. + if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Filter(child_filter) => { let parents_predicates = split_conjunction_owned(filter.predicate); @@ -4222,4 +4229,63 @@ mod tests { " ) } + + #[test] + fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> { + let scan = test_table_scan()?; + let scan_with_fetch = match scan { + LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan { + fetch: Some(10), + ..scan + }), + _ => unreachable!(), + }; + let plan = LogicalPlanBuilder::from(scan_with_fetch) + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed into the table scan when it has a fetch (limit) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + TableScan: test, fetch=10 + " + ) + } + + #[test] + fn filter_push_down_through_sort_without_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("a").sort(true, true)])? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter should be pushed below the sort + assert_optimized_plan_equal!( + plan, + @r" + Sort: test.a ASC NULLS FIRST + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) + } + + #[test] + fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort_with_limit(vec![col("a").sort(true, true)], Some(5))? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed below the sort when it has a fetch (limit), + // because the limit should apply before the filter. + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + Sort: test.a ASC NULLS FIRST, fetch=5 + TableScan: test + " + ) + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 1274e954eaeb3..ea0069d3b75be 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -40,7 +40,6 @@ use std::sync::Arc; use datafusion_common::Result; use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use itertools::Itertools; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FilterPushdownPhase { @@ -359,6 +358,17 @@ impl ChildFilterDescription { }) } + /// Mark all parent filters as unsupported for this child. + pub fn all_unsupported(parent_filters: &[Arc]) -> Self { + Self { + parent_filters: parent_filters + .iter() + .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) + .collect(), + self_filters: vec![], + } + } + /// Add a self filter (from the current node) to be pushed down to this child. pub fn with_self_filter(mut self, filter: Arc) -> Self { self.self_filters.push(filter); @@ -434,15 +444,9 @@ impl FilterDescription { children: &[&Arc], ) -> Self { let mut desc = Self::new(); - let child_filters = parent_filters - .iter() - .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) - .collect_vec(); for _ in 0..children.len() { - desc = desc.with_child(ChildFilterDescription { - parent_filters: child_filters.clone(), - self_filters: vec![], - }); + desc = + desc.with_child(ChildFilterDescription::all_unsupported(parent_filters)); } desc } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 475738cca3f05..30a0e301823f7 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -1404,12 +1404,23 @@ impl ExecutionPlan for SortExec { parent_filters: Vec>, config: &datafusion_common::config::ConfigOptions, ) -> Result { - if !matches!(phase, FilterPushdownPhase::Post) { + if phase != FilterPushdownPhase::Post { + if self.fetch.is_some() { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } return FilterDescription::from_children(parent_filters, &self.children()); } - let mut child = - ChildFilterDescription::from_child(&parent_filters, self.input())?; + // In Post phase: block parent filters when fetch is set, + // but still push the TopK dynamic filter (self-filter). + let mut child = if self.fetch.is_some() { + ChildFilterDescription::all_unsupported(&parent_filters) + } else { + ChildFilterDescription::from_child(&parent_filters, self.input())? + }; if let Some(filter) = &self.filter && config.optimizer.enable_topk_dynamic_filter_pushdown @@ -1430,8 +1441,10 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; + use crate::filter_pushdown::{FilterPushdownPhase, PushedDown}; use crate::test; use crate::test::TestMemoryExec; use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; @@ -1441,6 +1454,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; + use datafusion_common::config::ConfigOptions; use datafusion_common::test_util::batches_to_string; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::RecordBatchStream; @@ -2705,4 +2719,68 @@ mod tests { Ok(()) } + + fn make_sort_exec_with_fetch(fetch: Option) -> SortExec { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let input = Arc::new(EmptyExec::new(schema)); + SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + input, + ) + .with_fetch(fetch) + } + + #[test] + fn test_sort_with_fetch_blocks_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Sort with fetch (TopK) must not allow filters to be pushed below it. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + Ok(()) + } + + #[test] + fn test_sort_without_fetch_allows_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(None); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Plain sort (no fetch) is filter-commutative. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::Yes + )); + Ok(()) + } + + #[test] + fn test_sort_with_fetch_allows_topk_self_filter_in_post_phase() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + assert!(sort.filter.is_some(), "TopK filter should be created"); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![Arc::new(Column::new("a", 0))], + &config, + )?; + // Parent filters are still blocked in the Post phase. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + // But the TopK self-filter should be pushed down. + assert_eq!(desc.self_filters()[0].len(), 1); + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 96471411e0f95..8fb0edde79e25 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -869,6 +869,45 @@ limit 1000; statement ok DROP TABLE test_limit_with_partitions; +# Tests for filter pushdown behavior with Sort + LIMIT (fetch). + +statement ok +CREATE TABLE t(id INT, value INT) AS VALUES +(1, 100), +(2, 200), +(3, 300), +(4, 400), +(5, 500); + +# Take the 3 smallest values (100, 200, 300), then filter value > 200. +query II +SELECT * FROM (SELECT * FROM t ORDER BY value LIMIT 3) sub WHERE sub.value > 200; +---- +3 300 + +# Take the 3 largest values (500, 400, 300), then filter value < 400. +query II +SELECT * FROM (SELECT * FROM t ORDER BY value DESC LIMIT 3) sub WHERE sub.value < 400; +---- +3 300 + +# The filter stays above the sort+fetch in the plan. +query TT +EXPLAIN SELECT * FROM (SELECT * FROM t ORDER BY value LIMIT 3) sub WHERE sub.value > 200; +---- +logical_plan +01)SubqueryAlias: sub +02)--Filter: t.value > Int32(200) +03)----Sort: t.value ASC NULLS LAST, fetch=3 +04)------TableScan: t projection=[id, value] +physical_plan +01)FilterExec: value@1 > 200 +02)--SortExec: TopK(fetch=3), expr=[value@1 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +statement ok +DROP TABLE t; + # Tear down src_table table: statement ok DROP TABLE src_table; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 8ac8724683a8a..7feeb210f1608 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -3198,16 +3198,17 @@ EXPLAIN SELECT * FROM (SELECT *, ROW_NUMBER() OVER(ORDER BY a ASC) as rn1 ---- logical_plan 01)Sort: rn1 ASC NULLS LAST -02)--Sort: rn1 ASC NULLS LAST, fetch=5 -03)----Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -04)------Filter: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW < UInt64(50) +02)--Filter: rn1 < UInt64(50) +03)----Sort: rn1 ASC NULLS LAST, fetch=5 +04)------Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 05)--------WindowAggr: windowExpr=[[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 06)----------TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] physical_plan -01)ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] -02)--FilterExec: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 < 50, fetch=5 -03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] -04)------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] +01)FilterExec: rn1@5 < 50 +02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] +03)----GlobalLimitExec: skip=0, fetch=5 +04)------BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Field { "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW": UInt64 }, frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] +05)--------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # Top level sort is pushed down through BoundedWindowAggExec as its SUM result does already satisfy the required # global order. The existing sort is for the second-term lexicographical ordering requirement, which is being