243 changes: 198 additions & 45 deletions datafusion/datasource-parquet/src/opener.rs
@@ -24,6 +24,7 @@ use crate::{
ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory,
};
use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::datatypes::DataType;
use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener};
use datafusion_physical_expr::projection::ProjectionExprs;
use datafusion_physical_expr::utils::reassign_expr_columns;
@@ -35,8 +36,10 @@ use std::task::{Context, Poll};

use arrow::datatypes::{SchemaRef, TimeUnit};
use datafusion_common::encryption::FileDecryptionProperties;

use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_common::stats::Precision;
use datafusion_common::{
exec_err, ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics,
};
use datafusion_datasource::{PartitionedFile, TableSchema};
use datafusion_physical_expr::simplifier::PhysicalExprSimplifier;
use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory;
@@ -137,59 +140,60 @@ impl FileOpener for ParquetOpener {

let batch_size = self.batch_size;

// Build partition values map for replacing partition column references
// with their literal values from this file's partition values.
// Calculate the output schema from the original projection (before literal replacement)
// so we get correct field names from column references
let logical_file_schema = Arc::clone(self.table_schema.file_schema());
let output_schema = Arc::new(
self.projection
.project_schema(self.table_schema.table_schema())?,
);

// Build a combined map for replacing column references with literal values.
// This includes:
// 1. Partition column values from the file path (e.g., region=us-west-2)
// 2. Constant columns detected from file statistics (where min == max)
//
// For example, given
// 1. `region` is a partition column,
// 2. the predicate is `region IN ('us-east-1', 'eu-central-1')`, and
// 3. the file path is `/data/region=us-west-2/...`
//    (that is, the partition column value is `us-west-2`).
// Although partition columns *are* constant columns, we don't want to rely on
// their statistics being populated when we can use the partition values
// (which are guaranteed to be present).
//
// The predicate would be rewritten to
// ```sql
// 'us-west-2' IN ('us-east-1', 'eu-central-1')
// ```
// which can be further simplified to `FALSE`, meaning
// the file can be skipped entirely.
// For example, given a partition column `region` and predicate
// `region IN ('us-east-1', 'eu-central-1')` with file path
// `/data/region=us-west-2/...`, the predicate is rewritten to
// `'us-west-2' IN ('us-east-1', 'eu-central-1')` which simplifies to FALSE.
//
// While this particular optimization is done during logical planning,
// there are other cases where partition columns may appear in more
// complex predicates that cannot be simplified until we are about to
// open the file (such as dynamic predicates)
let partition_values: HashMap<&str, &ScalarValue> = self
// While partition column optimization is done during logical planning,
// there are cases where partition columns may appear in more complex
// predicates that cannot be simplified until we open the file (such as
// dynamic predicates).
let mut literal_columns: HashMap<String, ScalarValue> = self
.table_schema
.table_partition_cols()
.iter()
.zip(partitioned_file.partition_values.iter())
.map(|(field, value)| (field.name().as_str(), value))
.map(|(field, value)| (field.name().clone(), value.clone()))
.collect();
// Add constant columns from file statistics.
// Note that if statistics exist for partition columns there will be overlap,
// but since we use a HashMap we simply overwrite the partition values with the
// constant values from statistics (which should be identical).
literal_columns.extend(constant_columns_from_stats(
partitioned_file.statistics.as_deref(),
&logical_file_schema,
));

// Calculate the output schema from the original projection (before literal replacement)
// so we get correct field names from column references
let logical_file_schema = Arc::clone(self.table_schema.file_schema());
let output_schema = Arc::new(
self.projection
.project_schema(self.table_schema.table_schema())?,
);

// Apply partition column replacement to projection expressions
// Apply literal replacements to projection and predicate
let mut projection = self.projection.clone();
if !partition_values.is_empty() {
let mut predicate = self.predicate.clone();
if !literal_columns.is_empty() {
projection = projection.try_map_exprs(|expr| {
replace_columns_with_literals(Arc::clone(&expr), &partition_values)
replace_columns_with_literals(Arc::clone(&expr), &literal_columns)
})?;
predicate = predicate
.map(|p| replace_columns_with_literals(p, &literal_columns))
.transpose()?;
}

// Apply partition column replacement to predicate
let mut predicate = if partition_values.is_empty() {
self.predicate.clone()
} else {
self.predicate
.clone()
.map(|p| replace_columns_with_literals(p, &partition_values))
.transpose()?
};
let reorder_predicates = self.reorder_filters;
let pushdown_filters = self.pushdown_filters;
let force_filter_selections = self.force_filter_selections;
@@ -580,6 +584,67 @@ fn copy_arrow_reader_metrics(
}
}

type ConstantColumns = HashMap<String, ScalarValue>;

/// Extract constant column values from statistics, keyed by column name in the logical file schema.
fn constant_columns_from_stats(
statistics: Option<&Statistics>,
file_schema: &SchemaRef,
) -> ConstantColumns {
let mut constants = HashMap::new();
let Some(statistics) = statistics else {
return constants;
};

let num_rows = match statistics.num_rows {
Precision::Exact(num_rows) => Some(num_rows),
_ => None,
};

for (idx, column_stats) in statistics
.column_statistics
.iter()
.take(file_schema.fields().len())
.enumerate()
{
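// Column statistics are positional: entry `idx` corresponds to field `idx` of the
// logical file schema; entries beyond the file schema's fields (if any) are ignored.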
let field = file_schema.field(idx);
if let Some(value) =
constant_value_from_stats(column_stats, num_rows, field.data_type())
{
constants.insert(field.name().clone(), value);
}
}

constants
}

fn constant_value_from_stats(
column_stats: &ColumnStatistics,
num_rows: Option<usize>,
data_type: &DataType,
) -> Option<ScalarValue> {
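// Case 1: an exact, non-null min equal to the exact max with exactly zero nulls
// means every row holds that same value.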
if let (Precision::Exact(min), Precision::Exact(max)) =
(&column_stats.min_value, &column_stats.max_value)
{
if min == max
&& !min.is_null()
&& matches!(column_stats.null_count, Precision::Exact(0))
{
return Some(min.clone());
}
}

if let (Some(num_rows), Precision::Exact(nulls)) =
(num_rows, &column_stats.null_count)
{
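// Case 2: every row is NULL, so the column is a constant NULL of the column's data type.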
if *nulls == num_rows {
return ScalarValue::try_new_null(data_type).ok();
}
}

None
}

/// Wraps an inner RecordBatchStream and a [`FilePruner`]
///
/// This can terminate the scan early when some dynamic filters is updated after
@@ -840,7 +905,8 @@ fn should_enable_page_index(
mod test {
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use super::{constant_columns_from_stats, ConstantColumns};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use bytes::{BufMut, BytesMut};
use datafusion_common::{
record_batch, stats::Precision, ColumnStatistics, DataFusionError, ScalarValue,
@@ -849,17 +915,104 @@ mod test {
use datafusion_datasource::{file_stream::FileOpener, PartitionedFile, TableSchema};
use datafusion_expr::{col, lit};
use datafusion_physical_expr::{
expressions::DynamicFilterPhysicalExpr, planner::logical2physical,
projection::ProjectionExprs, PhysicalExpr,
expressions::{Column, DynamicFilterPhysicalExpr, Literal},
planner::logical2physical,
projection::ProjectionExprs,
PhysicalExpr,
};
use datafusion_physical_expr_adapter::{
replace_columns_with_literals, DefaultPhysicalExprAdapterFactory,
};
use datafusion_physical_expr_adapter::DefaultPhysicalExprAdapterFactory;
use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet;
use futures::{Stream, StreamExt};
use object_store::{memory::InMemory, path::Path, ObjectStore};
use parquet::arrow::ArrowWriter;

use crate::{opener::ParquetOpener, DefaultParquetFileReaderFactory};

fn constant_int_stats() -> (Statistics, SchemaRef) {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]));
let statistics = Statistics {
num_rows: Precision::Exact(3),
total_byte_size: Precision::Absent,
column_statistics: vec![
ColumnStatistics {
null_count: Precision::Exact(0),
max_value: Precision::Exact(ScalarValue::from(5i32)),
min_value: Precision::Exact(ScalarValue::from(5i32)),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Absent,
},
ColumnStatistics::new_unknown(),
],
};
(statistics, schema)
}

#[test]
fn extract_constant_columns_non_null() {
let (statistics, schema) = constant_int_stats();
let constants = constant_columns_from_stats(Some(&statistics), &schema);
assert_eq!(constants.len(), 1);
assert_eq!(constants.get("a"), Some(&ScalarValue::from(5i32)));
assert!(!constants.contains_key("b"));
}

#[test]
fn extract_constant_columns_all_null() {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let statistics = Statistics {
num_rows: Precision::Exact(2),
total_byte_size: Precision::Absent,
column_statistics: vec![ColumnStatistics {
null_count: Precision::Exact(2),
max_value: Precision::Absent,
min_value: Precision::Absent,
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
byte_size: Precision::Absent,
}],
};

let constants = constant_columns_from_stats(Some(&statistics), &schema);
assert_eq!(
constants.get("a"),
Some(&ScalarValue::Utf8(None)),
"all-null column should be treated as constant null"
);
}

#[test]
fn rewrite_projection_to_literals() {
let (statistics, schema) = constant_int_stats();
let constants = constant_columns_from_stats(Some(&statistics), &schema);
let projection = ProjectionExprs::from_indices(&[0, 1], &schema);

let rewritten = projection
.try_map_exprs(|expr| replace_columns_with_literals(expr, &constants))
.unwrap();
let exprs = rewritten.as_ref();
assert!(exprs[0].expr.as_any().downcast_ref::<Literal>().is_some());
assert!(exprs[1].expr.as_any().downcast_ref::<Column>().is_some());

// Only column `b` should remain in the projection mask
assert_eq!(rewritten.column_indices(), vec![1]);
}

#[test]
fn rewrite_physical_expr_literal() {
let mut constants = ConstantColumns::new();
constants.insert("a".to_string(), ScalarValue::from(7i32));
let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));

let rewritten = replace_columns_with_literals(expr, &constants).unwrap();
assert!(rewritten.as_any().downcast_ref::<Literal>().is_some());
}

async fn count_batches_and_rows(
mut stream: std::pin::Pin<
Box<
18 changes: 13 additions & 5 deletions datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -19,7 +19,9 @@
//! [`PhysicalExprAdapterFactory`], default implementations,
//! and [`replace_columns_with_literals`].

use std::borrow::Borrow;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::Arc;

use arrow::compute::can_cast_types;
@@ -50,19 +52,25 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
/// # Arguments
/// - `expr`: The physical expression in which to replace column references.
/// - `replacements`: A mapping from column names to their corresponding literal `ScalarValue`s.
/// Accepts various HashMap types including `HashMap<&str, &ScalarValue>`,
/// `HashMap<String, ScalarValue>`, `HashMap<String, &ScalarValue>`, etc.
///
/// # Returns
/// - `Result<Arc<dyn PhysicalExpr>>`: The rewritten physical expression with columns replaced by literals.
pub fn replace_columns_with_literals(
pub fn replace_columns_with_literals<K, V>(
expr: Arc<dyn PhysicalExpr>,
replacements: &HashMap<&str, &ScalarValue>,
) -> Result<Arc<dyn PhysicalExpr>> {
expr.transform(|expr| {
replacements: &HashMap<K, V>,
) -> Result<Arc<dyn PhysicalExpr>>
where
K: Borrow<str> + Eq + Hash,
V: Borrow<ScalarValue>,
{
expr.transform_down(|expr| {
if let Some(column) = expr.as_any().downcast_ref::<Column>()
&& let Some(replacement_value) = replacements.get(column.name())
{
return Ok(Transformed::yes(expressions::lit(
(*replacement_value).clone(),
replacement_value.borrow().clone(),
)));
}
Ok(Transformed::no(expr))
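A minimal usage sketch of the generic `replace_columns_with_literals` above (illustrative only, not part of the diff): it shows that both an owned `HashMap<String, ScalarValue>` and a borrowed `HashMap<&str, &ScalarValue>` satisfy the new bounds. The column name `region`, the value `us-west-2`, and the function name `usage_sketch` are assumptions made for the example.

fn usage_sketch() -> datafusion_common::Result<()> {
    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion_common::ScalarValue;
    use datafusion_physical_expr::expressions::{Column, Literal};
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr_adapter::replace_columns_with_literals;

    // Owned keys and values, as produced by `constant_columns_from_stats`.
    let mut owned: HashMap<String, ScalarValue> = HashMap::new();
    owned.insert("region".to_string(), ScalarValue::from("us-west-2"));

    // Borrowed keys and values, matching the previous non-generic signature.
    let value = ScalarValue::from("us-west-2");
    let mut borrowed: HashMap<&str, &ScalarValue> = HashMap::new();
    borrowed.insert("region", &value);

    // A bare column reference to `region` at index 0.
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("region", 0));

    // Both map shapes satisfy the `K: Borrow<str>`, `V: Borrow<ScalarValue>` bounds,
    // and the column reference is rewritten to a literal in both cases.
    let rewritten = replace_columns_with_literals(Arc::clone(&expr), &owned)?;
    assert!(rewritten.as_any().downcast_ref::<Literal>().is_some());

    let rewritten = replace_columns_with_literals(expr, &borrowed)?;
    assert!(rewritten.as_any().downcast_ref::<Literal>().is_some());

    Ok(())
}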