diff --git a/Cargo.lock b/Cargo.lock index 45421a862ad77..50d6f8588b8e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2543,10 +2543,12 @@ name = "datafusion-pruning" version = "52.3.0" dependencies = [ "arrow", + "async-trait", "datafusion-common", "datafusion-datasource", "datafusion-expr", "datafusion-expr-common", + "datafusion-functions-aggregate", "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -2554,6 +2556,7 @@ dependencies = [ "insta", "itertools 0.14.0", "log", + "tokio", ] [[package]] diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 4cf5cc366158b..620aba294a4d8 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -27,8 +27,8 @@ use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. -#[derive(Clone, PartialEq, Eq, Default, Copy)] -pub enum Precision { +#[derive(Clone, Default, Copy)] +pub enum Precision { /// The exact value is known. Used for guaranteeing correctness. /// /// Comes from definitive sources such as: @@ -60,7 +60,7 @@ pub enum Precision { Absent, } -impl Precision { +impl Precision { /// If we have some value (exact or inexact), it returns that value. /// Otherwise, it returns `None`. pub fn get_value(&self) -> Option<&T> { @@ -75,7 +75,7 @@ impl Precision { pub fn map(self, f: F) -> Precision where F: Fn(T) -> U, - U: Debug + Clone + PartialEq + Eq + PartialOrd, + U: Debug + Clone, { match self { Precision::Exact(val) => Precision::Exact(f(val)), @@ -94,6 +94,16 @@ impl Precision { } } + /// Demotes the precision state from exact to inexact (if present). + pub fn to_inexact(self) -> Self { + match self { + Precision::Exact(value) => Precision::Inexact(value), + _ => self, + } + } +} + +impl Precision { /// Returns the maximum of two (possibly inexact) values, conservatively /// propagating exactness information. If one of the input values is /// [`Precision::Absent`], the result is `Absent` too. @@ -127,14 +137,6 @@ impl Precision { (_, _) => Precision::Absent, } } - - /// Demotes the precision state from exact to inexact (if present). - pub fn to_inexact(self) -> Self { - match self { - Precision::Exact(value) => Precision::Inexact(value), - _ => self, - } - } } impl Precision { @@ -318,7 +320,23 @@ impl Precision { } } -impl Debug for Precision { +impl PartialEq for Precision +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Precision::Exact(a), Precision::Exact(b)) => a == b, + (Precision::Inexact(a), Precision::Inexact(b)) => a == b, + (Precision::Absent, Precision::Absent) => true, + _ => false, + } + } +} + +impl Eq for Precision {} + +impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({inner:?})"), @@ -328,7 +346,7 @@ impl Debug for Precision { } } -impl Display for Precision { +impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({inner:?})"), diff --git a/datafusion/pruning/Cargo.toml b/datafusion/pruning/Cargo.toml index e6f4bb6f273c9..9bf2cc8321320 100644 --- a/datafusion/pruning/Cargo.toml +++ b/datafusion/pruning/Cargo.toml @@ -17,9 +17,13 @@ workspace = true [dependencies] arrow = { workspace = true } +async-trait = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-datasource = { workspace = true } +datafusion-expr = { workspace = true, default-features = true } datafusion-expr-common = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true, default-features = true } +datafusion-functions-nested = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } @@ -30,3 +34,4 @@ datafusion-expr = { workspace = true } datafusion-functions-nested = { workspace = true } insta = { workspace = true } itertools = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs index be17f29eaafa0..4fa95cc211a97 100644 --- a/datafusion/pruning/src/lib.rs +++ b/datafusion/pruning/src/lib.rs @@ -19,9 +19,11 @@ mod file_pruner; mod pruning_predicate; +mod statistics; pub use file_pruner::FilePruner; pub use pruning_predicate::{ PredicateRewriter, PruningPredicate, PruningStatistics, RequiredColumns, UnhandledPredicateHook, build_pruning_predicate, }; +pub use statistics::{ResolvedStatistics, StatisticsSource}; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 6f6b00e80abc2..5c06b88a68aac 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow::array::AsArray; use arrow::{ - array::{ArrayRef, BooleanArray, new_null_array}, + array::{ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -41,6 +41,7 @@ use datafusion_common::{ ScalarValue, internal_datafusion_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, }; +use datafusion_expr::ExprFunctionExt; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::expressions::CastColumnExpr; use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; @@ -521,54 +522,99 @@ impl PruningPredicate { &self, statistics: &S, ) -> Result> { - let mut builder = BoolVecBuilder::new(statistics.num_containers()); + let resolved = crate::statistics::resolve_all_sync( + statistics, + &self.all_required_expressions(), + ); + self.evaluate(&resolved) + } - // Try to prove the predicate can't be true for the containers based on - // literal guarantees + /// Evaluates the pruning predicate against pre-resolved statistics. + /// + /// This is the sync evaluation phase of the two-phase + /// resolve-then-evaluate pattern. Statistics are resolved ahead of + /// time (possibly asynchronously via [`StatisticsSource`]) into a + /// [`ResolvedStatistics`] cache, then this method evaluates the + /// predicate against that cache synchronously. + /// + /// Returns the same `true`/`false` semantics as [`Self::prune`]: + /// - `true`: There MAY be rows that match the predicate + /// - `false`: There are no rows that could possibly match + /// + /// Missing entries in `resolved` are treated as unknown (null arrays), + /// which is conservative — the container will not be pruned. + /// + /// [`StatisticsSource`]: crate::StatisticsSource + /// [`ResolvedStatistics`]: crate::ResolvedStatistics + pub fn evaluate( + &self, + resolved: &crate::statistics::ResolvedStatistics, + ) -> Result> { + let mut builder = BoolVecBuilder::new(resolved.num_containers()); + + // Phase 1: Literal guarantees (InList lookups) for literal_guarantee in &self.literal_guarantees { let LiteralGuarantee { column, guarantee, literals, } = literal_guarantee; - if let Some(results) = statistics.contained(column, literals) { - match guarantee { - // `In` means the values in the column must be one of the - // values in the set for the predicate to evaluate to true. - // If `contained` returns false, that means the column is - // not any of the values so we can prune the container - Guarantee::In => builder.combine_array(&results), - // `NotIn` means the values in the column must not be - // any of the values in the set for the predicate to - // evaluate to true. If `contained` returns true, it means the - // column is only in the set of values so we can prune the - // container - Guarantee::NotIn => { - builder.combine_array(&arrow::compute::not(&results)?) - } - } - // if all containers are pruned (has rows that DEFINITELY DO NOT pass the predicate) - // can return early without evaluating the rest of predicates. + + // Build the InList Expr that corresponds to this guarantee + let in_list_expr = literal_guarantee_to_in_list( + column, + literals, + matches!(guarantee, Guarantee::NotIn), + ); + + if let Some(array) = resolved.get(&in_list_expr) + && let Some(bool_arr) = array.as_any().downcast_ref::() + { + builder.combine_array(bool_arr); if builder.check_all_pruned() { return Ok(builder.build()); } } } - // Next, try to prove the predicate can't be true for the containers based - // on min/max values - - // build a RecordBatch that contains the min/max values in the - // appropriate statistics columns for the min/max predicate - let statistics_batch = - build_statistics_record_batch(statistics, &self.required_columns)?; - - // Evaluate the pruning predicate on that record batch and append any results to the builder + // Phase 2: Min/max/null_count/row_count predicate + let statistics_batch = build_statistics_record_batch_from_resolved( + resolved, + &self.required_columns, + )?; builder.combine_value(self.predicate_expr.evaluate(&statistics_batch)?); Ok(builder.build()) } + /// Returns all expressions needed to fully evaluate this predicate, + /// including both aggregate statistics and literal guarantee InLists. + /// + /// Pass these to [`StatisticsSource::expression_statistics`] or + /// [`ResolvedStatistics::resolve`] to pre-fetch the needed data. + /// + /// [`StatisticsSource::expression_statistics`]: crate::StatisticsSource::expression_statistics + /// [`ResolvedStatistics::resolve`]: crate::ResolvedStatistics::resolve + pub fn all_required_expressions(&self) -> Vec { + let mut exprs = Vec::new(); + + // Aggregate stats from RequiredColumns + for (column, statistics_type, _field) in self.required_columns.iter() { + exprs.push(stat_type_to_expr(column, *statistics_type)); + } + + // Literal guarantee InList expressions + for lg in &self.literal_guarantees { + exprs.push(literal_guarantee_to_in_list( + &lg.column, + &lg.literals, + matches!(lg.guarantee, Guarantee::NotIn), + )); + } + + exprs + } + /// Return a reference to the input schema pub fn schema(&self) -> &SchemaRef { &self.schema @@ -913,10 +959,12 @@ impl From> for RequiredColumns { /// -------+-------- /// 5 | 1000 /// ``` +#[cfg(test)] fn build_statistics_record_batch( statistics: &S, required_columns: &RequiredColumns, ) -> Result { + use arrow::array::new_null_array; let mut arrays = Vec::::new(); // For each needed statistics column: for (column, statistics_type, stat_field) in required_columns.iter() { @@ -1943,6 +1991,96 @@ fn wrap_null_count_check_expr( ))) } +/// Convert a [`StatisticsType`] + column into the corresponding logical expression. +fn stat_type_to_expr( + column: &phys_expr::Column, + stat_type: StatisticsType, +) -> datafusion_expr::Expr { + use datafusion_expr::Expr as LExpr; + let col_expr = LExpr::Column(Column::new_unqualified(column.name())); + match stat_type { + StatisticsType::Min => { + datafusion_functions_aggregate::min_max::min_udaf().call(vec![col_expr]) + } + StatisticsType::Max => { + datafusion_functions_aggregate::min_max::max_udaf().call(vec![col_expr]) + } + StatisticsType::NullCount => { + let count_expr = datafusion_functions_aggregate::count::count_udaf() + .call(vec![LExpr::Literal(ScalarValue::Boolean(Some(true)), None)]); + count_expr + .filter(LExpr::IsNull(Box::new(col_expr))) + .build() + .expect("building count filter expr") + } + StatisticsType::RowCount => { + let count_expr = datafusion_functions_aggregate::count::count_udaf() + .call(vec![LExpr::Literal(ScalarValue::Boolean(Some(true)), None)]); + count_expr + .filter(LExpr::IsNotNull(Box::new(col_expr))) + .build() + .expect("building count filter expr") + } + } +} + +/// Convert a [`LiteralGuarantee`] into an `Expr::InList`. +fn literal_guarantee_to_in_list( + column: &Column, + literals: &HashSet, + negated: bool, +) -> datafusion_expr::Expr { + datafusion_expr::Expr::InList(datafusion_expr::expr::InList::new( + Box::new(datafusion_expr::Expr::Column(column.clone())), + literals + .iter() + .map(|s| datafusion_expr::Expr::Literal(s.clone(), None)) + .collect(), + negated, + )) +} + +/// Build a statistics [`RecordBatch`] from a [`crate::ResolvedStatistics`] cache, +/// looking up each required column's expression and falling back to null +/// arrays for missing entries. +fn build_statistics_record_batch_from_resolved( + resolved: &crate::statistics::ResolvedStatistics, + required_columns: &RequiredColumns, +) -> Result { + let mut arrays = Vec::::new(); + let num_containers = resolved.num_containers(); + + for (column, statistics_type, stat_field) in required_columns.iter() { + let data_type = stat_field.data_type(); + let stat_expr = stat_type_to_expr(column, *statistics_type); + + let array = resolved.get_or_null(&stat_expr, data_type); + + assert_eq_or_internal_err!( + num_containers, + array.len(), + "mismatched statistics length. Expected {}, got {}", + num_containers, + array.len() + ); + + let array = arrow::compute::cast(&array, data_type)?; + arrays.push(array); + } + + let schema = Arc::new(required_columns.schema()); + let mut options = RecordBatchOptions::default(); + options.row_count = Some(num_containers); + + trace!( + "Creating statistics batch from resolved for {required_columns:#?} with {arrays:#?}" + ); + + RecordBatch::try_new_with_options(schema, arrays, &options).map_err(|err| { + plan_datafusion_err!("Can not create statistics record batch: {err}") + }) +} + #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) enum StatisticsType { Min, @@ -5441,4 +5579,113 @@ mod tests { "c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1"; assert_eq!(res.to_string(), expected); } + + /// Test that evaluate() produces the same results as prune() for basic predicates + #[test] + fn test_evaluate_matches_prune() { + // i > 5 with 3 containers + let schema = Schema::new(vec![Field::new("i", DataType::Int32, true)]); + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(1), Some(6), Some(3)], // min + vec![Some(4), Some(10), Some(8)], // max + ), + ); + + let expr = col("i").gt(lit(5i32)); + let p = + PruningPredicate::try_new(logical2physical(&expr, &schema), Arc::new(schema)) + .unwrap(); + + let prune_result = p.prune(&statistics).unwrap(); + let resolved = crate::statistics::resolve_all_sync( + &statistics, + &p.all_required_expressions(), + ); + let evaluate_result = p.evaluate(&resolved).unwrap(); + + assert_eq!(prune_result, evaluate_result); + // Container 0: max=4, 4 > 5 is false → prune + // Container 1: max=10, 10 > 5 is true → keep + // Container 2: max=8, 8 > 5 is true → keep + assert_eq!(evaluate_result, vec![false, true, true]); + } + + /// Test evaluate with null counts and row counts + #[test] + fn test_evaluate_with_null_counts() { + let schema = Schema::new(vec![Field::new("i", DataType::Int32, true)]); + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32(vec![Some(0), Some(0)], vec![Some(0), Some(0)]) + .with_null_counts(vec![Some(10), Some(0)]) + .with_row_counts(vec![Some(10), Some(10)]), + ); + + // i = 0: first container is all nulls, should be pruned + let expr = col("i").eq(lit(0i32)); + let p = + PruningPredicate::try_new(logical2physical(&expr, &schema), Arc::new(schema)) + .unwrap(); + + let prune_result = p.prune(&statistics).unwrap(); + let resolved = crate::statistics::resolve_all_sync( + &statistics, + &p.all_required_expressions(), + ); + let evaluate_result = p.evaluate(&resolved).unwrap(); + + assert_eq!(prune_result, evaluate_result); + } + + /// Test evaluate with missing cache entries (should produce null → conservative keep) + #[test] + fn test_evaluate_missing_cache_entries() { + let schema = Schema::new(vec![Field::new("i", DataType::Int32, true)]); + let _statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32(vec![Some(1), Some(6)], vec![Some(4), Some(10)]), + ); + + let expr = col("i").gt(lit(5i32)); + let p = + PruningPredicate::try_new(logical2physical(&expr, &schema), Arc::new(schema)) + .unwrap(); + + // Empty resolved stats — everything should be kept (conservative) + let resolved = crate::statistics::ResolvedStatistics::new_empty(2); + let evaluate_result = p.evaluate(&resolved).unwrap(); + assert_eq!(evaluate_result, vec![true, true]); + } + + /// Test that all_required_expressions() generates the right Expr types + #[test] + fn test_all_required_expressions() { + let schema = Schema::new(vec![Field::new("i", DataType::Int32, true)]); + let expr = col("i").eq(lit(5i32)); + let p = + PruningPredicate::try_new(logical2physical(&expr, &schema), Arc::new(schema)) + .unwrap(); + + let exprs = p.all_required_expressions(); + // i = 5 requires: min(i), max(i), count(*) filter (where i is null), + // count(*) filter (where i is not null) + assert!( + exprs.len() >= 2, + "Expected at least min and max, got {}", + exprs.len() + ); + + // Check that we have min and max expressions + let expr_strings: Vec = exprs.iter().map(|e| e.to_string()).collect(); + assert!( + expr_strings.iter().any(|s| s.contains("min")), + "Expected min expr in {expr_strings:?}" + ); + assert!( + expr_strings.iter().any(|s| s.contains("max")), + "Expected max expr in {expr_strings:?}" + ); + } } diff --git a/datafusion/pruning/src/statistics.rs b/datafusion/pruning/src/statistics.rs new file mode 100644 index 0000000000000..f965c23cef60b --- /dev/null +++ b/datafusion/pruning/src/statistics.rs @@ -0,0 +1,538 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, new_null_array}; +use datafusion_common::pruning::PruningStatistics; +use datafusion_expr::Expr; +use std::collections::{HashMap, HashSet}; + +use datafusion_common::error::DataFusionError; + +/// A source of runtime statistical information for pruning. +/// +/// This trait accepts a set of [`Expr`] expressions and returns +/// statistics for those expressions that can be used for pruning. +/// +/// It is up to implementors to determine how to collect these statistics. +/// Some example use cases include: +/// 1. Matching on basic expressions like `min(column)` or `max(column)` +/// and returning statistics from file metadata. +/// 2. Sampling data at runtime to get more accurate statistics. +/// 3. Querying an external metastore for statistics. +/// +/// # Supported expression types +/// +/// The following expression types are meaningful for pruning: +/// +/// - **Aggregate functions**: `min(column)`, `max(column)`, +/// `count(*) FILTER (WHERE column IS NULL)`, +/// `count(*) FILTER (WHERE column IS NOT NULL)` +/// - **InList**: `column IN (v1, v2, ...)` — see [InList semantics] below. +/// +/// Implementors return `None` for any expression they cannot answer. +/// +/// # InList semantics +/// +/// For `column IN (v1, v2, ..., vN)`, the returned `BooleanArray` has one +/// entry per container with three-valued logic: +/// +/// - `true` — the column in this container ONLY contains values in +/// `{v1, ..., vN}`. Every row in the container satisfies the `IN` +/// predicate (assuming non-null values; see below). +/// - `false` — the column in this container contains NONE of the values +/// in `{v1, ..., vN}`. No row can satisfy the `IN` predicate, so the +/// container can be pruned. +/// - `null` — it is not known whether the column contains any of the +/// values. The container cannot be pruned. +/// +/// ## Null handling +/// +/// - **Null values in the column**: SQL `IN` returns `NULL` when the +/// column value is `NULL`, regardless of the list contents. Containers +/// where the column has null values should return `null` (unknown) +/// unless the implementation can determine that all non-null values +/// still satisfy or violate the predicate. +/// - **Null values in the list** (`column IN (1, NULL, 3)`): Per SQL +/// semantics, `x IN (1, NULL, 3)` returns `TRUE` if `x` is 1 or 3, +/// `NULL` if `x` is any other non-null value (because `x = NULL` is +/// unknown), and `NULL` if `x` is `NULL`. Null literals in the list +/// therefore weaken pruning — a container can no longer return `false` +/// unless it can prove the column has no values at all. +/// - **`NOT IN` with nulls** (`column NOT IN (1, NULL, 3)`): This can +/// never return `TRUE` for non-null column values because `x != NULL` +/// is always unknown. A container can only be pruned if it is known +/// to contain exclusively values in the list. +#[async_trait::async_trait] +pub trait StatisticsSource: Send + Sync { + /// Returns the number of containers (row groups, files, etc.) that + /// statistics are provided for. All returned arrays must have this length. + fn num_containers(&self) -> usize; + + /// Returns statistics for each expression, or `None` for expressions + /// that cannot be answered. + async fn expression_statistics( + &self, + expressions: &[Expr], + ) -> Result>, DataFusionError>; +} + +/// Blanket implementation of [`StatisticsSource`] for types that implement +/// [`PruningStatistics`]. +/// +/// This allows any type that implements [`PruningStatistics`] to be used as +/// a [`StatisticsSource`] without needing to implement the trait directly. +/// +/// The implementation matches on expressions that can be directly answered +/// by the underlying [`PruningStatistics`]: +/// - `min(column)` → [`PruningStatistics::min_values`] +/// - `max(column)` → [`PruningStatistics::max_values`] +/// - `count(*) FILTER (WHERE column IS NOT NULL)` → [`PruningStatistics::row_counts`] +/// - `count(*) FILTER (WHERE column IS NULL)` → [`PruningStatistics::null_counts`] +/// - `column IN (lit1, lit2, ...)` → [`PruningStatistics::contained`] +/// +/// Any other expressions return `None`. +#[async_trait::async_trait] +impl StatisticsSource for T { + fn num_containers(&self) -> usize { + PruningStatistics::num_containers(self) + } + + async fn expression_statistics( + &self, + expressions: &[Expr], + ) -> Result>, DataFusionError> { + Ok(expressions + .iter() + .map(|expr| resolve_expression_sync(self, expr)) + .collect()) + } +} + +/// Pre-resolved statistics cache. Created asynchronously via +/// [`StatisticsSource`], evaluated synchronously by +/// [`PruningPredicate::evaluate`]. +/// +/// Keyed by [`Expr`] so that a single cache can serve multiple +/// [`PruningPredicate`](crate::PruningPredicate) instances (e.g., after dynamic filter changes +/// rebuild the predicate but reuse the same resolved stats). +/// Missing entries are treated as unknown — safe for pruning +/// (the predicate will conservatively keep the container). +/// +/// [`PruningPredicate::evaluate`]: crate::PruningPredicate::evaluate +pub struct ResolvedStatistics { + num_containers: usize, + cache: HashMap, +} + +impl ResolvedStatistics { + /// Create an empty cache with no resolved statistics. + /// All lookups will return `None`, causing `evaluate()` to use + /// null arrays (conservative — no pruning). + pub fn new_empty(num_containers: usize) -> Self { + Self { + num_containers, + cache: HashMap::new(), + } + } + + /// Resolve statistics for the given expressions from an async source. + pub async fn resolve( + source: &(impl StatisticsSource + ?Sized), + expressions: &[Expr], + ) -> Result { + let num_containers = source.num_containers(); + let arrays = source.expression_statistics(expressions).await?; + let cache = expressions + .iter() + .zip(arrays) + .filter_map(|(expr, arr)| arr.map(|a| (expr.clone(), a))) + .collect(); + Ok(Self { + num_containers, + cache, + }) + } + + /// Look up a resolved expression. Returns `None` if not in cache. + pub fn get(&self, expr: &Expr) -> Option<&ArrayRef> { + self.cache.get(expr) + } + + /// Look up a resolved expression, returning a null array of the given + /// type if the expression is not in the cache. + pub fn get_or_null( + &self, + expr: &Expr, + data_type: &arrow::datatypes::DataType, + ) -> ArrayRef { + self.cache + .get(expr) + .cloned() + .unwrap_or_else(|| new_null_array(data_type, self.num_containers)) + } + + /// Returns the number of containers these statistics cover. + pub fn num_containers(&self) -> usize { + self.num_containers + } +} + +/// Resolve a single expression synchronously against a [`PruningStatistics`] impl. +pub(crate) fn resolve_expression_sync( + stats: &(impl PruningStatistics + ?Sized), + expr: &Expr, +) -> Option { + match expr { + Expr::AggregateFunction(func) => resolve_aggregate_function(stats, func), + Expr::InList(in_list) => resolve_in_list(stats, in_list), + _ => None, + } +} + +/// Resolve all expressions synchronously against a [`PruningStatistics`] impl, +/// returning a [`ResolvedStatistics`] cache. +pub(crate) fn resolve_all_sync( + stats: &(impl PruningStatistics + ?Sized), + expressions: &[Expr], +) -> ResolvedStatistics { + let num_containers = stats.num_containers(); + let cache = expressions + .iter() + .filter_map(|expr| { + resolve_expression_sync(stats, expr).map(|arr| (expr.clone(), arr)) + }) + .collect(); + ResolvedStatistics { + num_containers, + cache, + } +} + +/// Resolve an aggregate function expression against [`PruningStatistics`]. +fn resolve_aggregate_function( + stats: &(impl PruningStatistics + ?Sized), + func: &datafusion_expr::expr::AggregateFunction, +) -> Option { + use datafusion_functions_aggregate::count::Count; + use datafusion_functions_aggregate::min_max::{Max, Min}; + + let udaf = func.func.inner(); + + if udaf.as_any().downcast_ref::().is_some() { + // min(column) — reject if there's a filter + if func.params.filter.is_some() { + return None; + } + if let Some(Expr::Column(col)) = func.params.args.first() { + return stats.min_values(col); + } + } else if udaf.as_any().downcast_ref::().is_some() { + // max(column) — reject if there's a filter + if func.params.filter.is_some() { + return None; + } + if let Some(Expr::Column(col)) = func.params.args.first() { + return stats.max_values(col); + } + } else if udaf.as_any().downcast_ref::().is_some() + && let Some(filter) = &func.params.filter + { + match filter.as_ref() { + // count(*) FILTER (WHERE col IS NOT NULL) → row_counts + Expr::IsNotNull(inner) => { + if let Expr::Column(col) = inner.as_ref() { + return stats.row_counts(col); + } + } + // count(*) FILTER (WHERE col IS NULL) → null_counts + Expr::IsNull(inner) => { + if let Expr::Column(col) = inner.as_ref() { + return stats.null_counts(col); + } + } + _ => {} + } + } + + None +} + +/// Resolve an `IN` list expression against [`PruningStatistics::contained`]. +/// +/// Only supports `column IN (literal, literal, ...)`. Returns `None` for +/// expressions with non-column left-hand sides or non-literal list items. +/// +/// For `NOT IN`, the result is inverted: `true` becomes `false` and vice +/// versa, while `null` stays `null`. See the [InList semantics] section +/// on [`StatisticsSource`] for details on null handling. +fn resolve_in_list( + stats: &(impl PruningStatistics + ?Sized), + in_list: &datafusion_expr::expr::InList, +) -> Option { + // Only support `column IN (literal, literal, ...)` + let Expr::Column(col) = in_list.expr.as_ref() else { + return None; + }; + + let mut values = HashSet::with_capacity(in_list.list.len()); + for item in &in_list.list { + match item { + Expr::Literal(scalar, _) => { + values.insert(scalar.clone()); + } + _ => return None, // non-literal in list, can't resolve + } + } + + let result = stats.contained(col, &values)?; + if in_list.negated { + // NOT IN: invert the contained result + let inverted = arrow::compute::not(&result).ok()?; + Some(std::sync::Arc::new(inverted) as ArrayRef) + } else { + Some(std::sync::Arc::new(result) as ArrayRef) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{BooleanArray, Int64Array, UInt64Array}; + use arrow::datatypes::DataType; + use datafusion_common::pruning::PruningStatistics; + use datafusion_common::{Column, ScalarValue}; + use datafusion_expr::ExprFunctionExt; + use std::sync::Arc; + + /// A simple mock PruningStatistics for testing. + #[derive(Debug)] + struct MockStats { + min: ArrayRef, + max: ArrayRef, + null_counts: ArrayRef, + row_counts: ArrayRef, + contained_result: Option, + } + + impl MockStats { + fn new() -> Self { + Self { + min: Arc::new(Int64Array::from(vec![Some(1), Some(10)])), + max: Arc::new(Int64Array::from(vec![Some(5), Some(20)])), + null_counts: Arc::new(UInt64Array::from(vec![0, 2])), + row_counts: Arc::new(UInt64Array::from(vec![100, 100])), + contained_result: None, + } + } + + fn with_contained(mut self, result: BooleanArray) -> Self { + self.contained_result = Some(result); + self + } + } + + impl PruningStatistics for MockStats { + fn min_values(&self, _column: &Column) -> Option { + Some(Arc::clone(&self.min)) + } + fn max_values(&self, _column: &Column) -> Option { + Some(Arc::clone(&self.max)) + } + fn num_containers(&self) -> usize { + self.min.len() + } + fn null_counts(&self, _column: &Column) -> Option { + Some(Arc::clone(&self.null_counts)) + } + fn row_counts(&self, _column: &Column) -> Option { + Some(Arc::clone(&self.row_counts)) + } + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + self.contained_result.clone() + } + } + + fn col_expr(name: &str) -> Expr { + Expr::Column(Column::new_unqualified(name)) + } + + #[test] + fn test_resolve_min() { + let stats = MockStats::new(); + let expr = + datafusion_functions_aggregate::min_max::min_udaf().call(vec![col_expr("a")]); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + let arr = result.unwrap(); + assert_eq!(arr.len(), 2); + } + + #[test] + fn test_resolve_max() { + let stats = MockStats::new(); + let expr = + datafusion_functions_aggregate::min_max::max_udaf().call(vec![col_expr("a")]); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + let arr = result.unwrap(); + assert_eq!(arr.len(), 2); + } + + #[test] + fn test_resolve_count_null() { + let stats = MockStats::new(); + let expr = datafusion_functions_aggregate::count::count_udaf() + .call(vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)]) + .filter(Expr::IsNull(Box::new(col_expr("a")))) + .build() + .unwrap(); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + } + + #[test] + fn test_resolve_count_not_null() { + let stats = MockStats::new(); + let expr = datafusion_functions_aggregate::count::count_udaf() + .call(vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)]) + .filter(Expr::IsNotNull(Box::new(col_expr("a")))) + .build() + .unwrap(); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + } + + #[test] + fn test_resolve_unsupported_returns_none() { + let stats = MockStats::new(); + // A plain column is not a supported expression for stats + let result = resolve_expression_sync(&stats, &col_expr("a")); + assert!(result.is_none()); + } + + #[test] + fn test_resolve_min_with_filter_returns_none() { + let stats = MockStats::new(); + // min(a) FILTER (WHERE a > 0) — not supported + let expr = datafusion_functions_aggregate::min_max::min_udaf() + .call(vec![col_expr("a")]) + .filter(col_expr("a").gt(Expr::Literal(ScalarValue::Int64(Some(0)), None))) + .build() + .unwrap(); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_none()); + } + + #[test] + fn test_resolve_in_list() { + let stats = MockStats::new() + .with_contained(BooleanArray::from(vec![Some(true), Some(false)])); + let expr = Expr::InList(datafusion_expr::expr::InList::new( + Box::new(col_expr("a")), + vec![ + Expr::Literal(ScalarValue::Int64(Some(1)), None), + Expr::Literal(ScalarValue::Int64(Some(2)), None), + ], + false, + )); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + let arr = result.unwrap(); + let bool_arr = arr.as_any().downcast_ref::().unwrap(); + assert!(bool_arr.value(0)); + assert!(!bool_arr.value(1)); + } + + #[test] + fn test_resolve_not_in_list() { + let stats = MockStats::new() + .with_contained(BooleanArray::from(vec![Some(true), Some(false)])); + let expr = Expr::InList(datafusion_expr::expr::InList::new( + Box::new(col_expr("a")), + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)], + true, // negated + )); + let result = resolve_expression_sync(&stats, &expr); + assert!(result.is_some()); + let arr = result.unwrap(); + let bool_arr = arr.as_any().downcast_ref::().unwrap(); + // Inverted: true→false, false→true + assert!(!bool_arr.value(0)); + assert!(bool_arr.value(1)); + } + + #[test] + fn test_resolve_all_sync_builds_cache() { + let stats = MockStats::new(); + let exprs = vec![ + datafusion_functions_aggregate::min_max::min_udaf().call(vec![col_expr("a")]), + datafusion_functions_aggregate::min_max::max_udaf().call(vec![col_expr("a")]), + col_expr("unsupported"), // should be missing from cache + ]; + + let resolved = resolve_all_sync(&stats, &exprs); + assert_eq!(resolved.num_containers(), 2); + assert!(resolved.get(&exprs[0]).is_some()); // min + assert!(resolved.get(&exprs[1]).is_some()); // max + assert!(resolved.get(&exprs[2]).is_none()); // unsupported + } + + #[test] + fn test_resolved_statistics_get_or_null() { + let stats = MockStats::new(); + let min_expr = + datafusion_functions_aggregate::min_max::min_udaf().call(vec![col_expr("a")]); + let resolved = resolve_all_sync(&stats, std::slice::from_ref(&min_expr)); + + // Existing entry + let arr = resolved.get_or_null(&min_expr, &DataType::Int64); + assert_eq!(arr.len(), 2); + assert_eq!(arr.null_count(), 0); + + // Missing entry → null array + let missing = col_expr("missing"); + let arr = resolved.get_or_null(&missing, &DataType::Int32); + assert_eq!(arr.len(), 2); + assert_eq!(arr.null_count(), 2); + } + + #[tokio::test] + async fn test_resolved_statistics_resolve_async() { + let stats = MockStats::new(); + let exprs = vec![ + datafusion_functions_aggregate::min_max::min_udaf().call(vec![col_expr("a")]), + datafusion_functions_aggregate::min_max::max_udaf().call(vec![col_expr("a")]), + ]; + + let resolved = ResolvedStatistics::resolve(&stats, &exprs).await.unwrap(); + assert_eq!(resolved.num_containers(), 2); + assert!(resolved.get(&exprs[0]).is_some()); + assert!(resolved.get(&exprs[1]).is_some()); + } + + #[test] + fn test_new_empty_resolved_statistics() { + let resolved = ResolvedStatistics::new_empty(5); + assert_eq!(resolved.num_containers(), 5); + let expr = col_expr("any"); + assert!(resolved.get(&expr).is_none()); + } +}