diff --git a/docs/source/contributor-guide/adding_a_new_expression.md b/docs/source/contributor-guide/adding_a_new_expression.md index 9ce9af8004..74825f4301 100644 --- a/docs/source/contributor-guide/adding_a_new_expression.md +++ b/docs/source/contributor-guide/adding_a_new_expression.md @@ -271,7 +271,8 @@ How this works is somewhat dependent on the type of expression you're adding. Ex If you're adding a new expression that requires custom protobuf serialization, you may need to: 1. Add a new message to the protobuf definition in `native/proto/src/proto/expr.proto` -2. Update the Rust deserialization code to handle the new protobuf message type +2. Add a native expression handler in `expression_registry.rs` to deserialize the new protobuf message type and + create a native expression For most expressions, you can skip this step if you're using the existing scalar function infrastructure. diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs new file mode 100644 index 0000000000..71fe85ef52 --- /dev/null +++ b/native/core/src/execution/expressions/arithmetic.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Arithmetic expression builders + +/// Macro to generate arithmetic expression builders that need eval_mode handling +#[macro_export] +macro_rules! arithmetic_expr_builder { + ($builder_name:ident, $expr_type:ident, $operator:expr) => { + pub struct $builder_name; + + impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name { + fn build( + &self, + spark_expr: &datafusion_comet_proto::spark_expression::Expr, + input_schema: arrow::datatypes::SchemaRef, + planner: &$crate::execution::planner::PhysicalPlanner, + ) -> Result< + std::sync::Arc, + $crate::execution::operators::ExecutionError, + > { + let expr = $crate::extract_expr!(spark_expr, $expr_type); + let eval_mode = + $crate::execution::planner::from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + $operator, + input_schema, + eval_mode, + ) + } + } + }; +} + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::logical_expr::Operator as DataFusionOperator; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use datafusion_comet_spark_expr::{create_modulo_expr, create_negate_expr, EvalMode}; + +use crate::execution::{ + expressions::extract_expr, + operators::ExecutionError, + planner::{ + from_protobuf_eval_mode, traits::ExpressionBuilder, BinaryExprOptions, PhysicalPlanner, + }, +}; + +/// Macro to define basic arithmetic builders that use eval_mode +macro_rules! define_basic_arithmetic_builders { + ($(($builder:ident, $expr_type:ident, $op:expr)),* $(,)?) => { + $( + arithmetic_expr_builder!($builder, $expr_type, $op); + )* + }; +} + +define_basic_arithmetic_builders![ + (AddBuilder, Add, DataFusionOperator::Plus), + (SubtractBuilder, Subtract, DataFusionOperator::Minus), + (MultiplyBuilder, Multiply, DataFusionOperator::Multiply), + (DivideBuilder, Divide, DataFusionOperator::Divide), +]; + +/// Builder for IntegralDivide expressions (requires special options) +pub struct IntegralDivideBuilder; + +impl ExpressionBuilder for IntegralDivideBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, IntegralDivide); + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + planner.create_binary_expr_with_options( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + BinaryExprOptions { + is_integral_div: true, + }, + eval_mode, + ) + } +} + +/// Builder for Remainder expressions (uses special modulo function) +pub struct RemainderBuilder; + +impl ExpressionBuilder for RemainderBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, Remainder); + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let left = planner.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let right = planner.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + + let result = create_modulo_expr( + left, + right, + expr.return_type + .as_ref() + .map(crate::execution::serde::to_arrow_datatype) + .unwrap(), + input_schema, + eval_mode == EvalMode::Ansi, + &planner.session_ctx().state(), + ); + result.map_err(|e| ExecutionError::GeneralError(e.to_string())) + } +} + +/// Builder for UnaryMinus expressions (uses special negate function) +pub struct UnaryMinusBuilder; + +impl ExpressionBuilder for UnaryMinusBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, UnaryMinus); + let child = planner.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + let result = create_negate_expr(child, expr.fail_on_error); + result.map_err(|e| ExecutionError::GeneralError(e.to_string())) + } +} diff --git a/native/core/src/execution/expressions/bitwise.rs b/native/core/src/execution/expressions/bitwise.rs new file mode 100644 index 0000000000..2e39588b44 --- /dev/null +++ b/native/core/src/execution/expressions/bitwise.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Bitwise expression builders + +use datafusion::logical_expr::Operator as DataFusionOperator; + +use crate::binary_expr_builder; + +/// Macro to define all bitwise builders at once +macro_rules! define_bitwise_builders { + ($(($builder:ident, $expr_type:ident, $op:expr)),* $(,)?) => { + $( + binary_expr_builder!($builder, $expr_type, $op); + )* + }; +} + +define_bitwise_builders![ + ( + BitwiseAndBuilder, + BitwiseAnd, + DataFusionOperator::BitwiseAnd + ), + (BitwiseOrBuilder, BitwiseOr, DataFusionOperator::BitwiseOr), + ( + BitwiseXorBuilder, + BitwiseXor, + DataFusionOperator::BitwiseXor + ), + ( + BitwiseShiftLeftBuilder, + BitwiseShiftLeft, + DataFusionOperator::BitwiseShiftLeft + ), + ( + BitwiseShiftRightBuilder, + BitwiseShiftRight, + DataFusionOperator::BitwiseShiftRight + ), +]; diff --git a/native/core/src/execution/expressions/comparison.rs b/native/core/src/execution/expressions/comparison.rs new file mode 100644 index 0000000000..8312059e90 --- /dev/null +++ b/native/core/src/execution/expressions/comparison.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Comparison expression builders + +use datafusion::logical_expr::Operator as DataFusionOperator; + +use crate::binary_expr_builder; + +/// Macro to define all comparison builders at once +macro_rules! define_comparison_builders { + ($(($builder:ident, $expr_type:ident, $op:expr)),* $(,)?) => { + $( + binary_expr_builder!($builder, $expr_type, $op); + )* + }; +} + +define_comparison_builders![ + (EqBuilder, Eq, DataFusionOperator::Eq), + (NeqBuilder, Neq, DataFusionOperator::NotEq), + (LtBuilder, Lt, DataFusionOperator::Lt), + (LtEqBuilder, LtEq, DataFusionOperator::LtEq), + (GtBuilder, Gt, DataFusionOperator::Gt), + (GtEqBuilder, GtEq, DataFusionOperator::GtEq), + ( + EqNullSafeBuilder, + EqNullSafe, + DataFusionOperator::IsNotDistinctFrom + ), + ( + NeqNullSafeBuilder, + NeqNullSafe, + DataFusionOperator::IsDistinctFrom + ), +]; diff --git a/native/core/src/execution/expressions/logical.rs b/native/core/src/execution/expressions/logical.rs new file mode 100644 index 0000000000..04d09bd660 --- /dev/null +++ b/native/core/src/execution/expressions/logical.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical expression builders + +use datafusion::logical_expr::Operator as DataFusionOperator; +use datafusion::physical_expr::expressions::NotExpr; + +use crate::{binary_expr_builder, unary_expr_builder}; + +/// Macro to define all logical builders at once +macro_rules! define_logical_builders { + () => { + binary_expr_builder!(AndBuilder, And, DataFusionOperator::And); + binary_expr_builder!(OrBuilder, Or, DataFusionOperator::Or); + unary_expr_builder!(NotBuilder, Not, NotExpr::new); + }; +} + +define_logical_builders!(); diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index 9bb8fad456..105afd5952 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -17,6 +17,14 @@ //! Native DataFusion expressions +pub mod arithmetic; +pub mod bitwise; +pub mod comparison; +pub mod logical; +pub mod nullcheck; pub mod subquery; pub use datafusion_comet_spark_expr::EvalMode; + +// Re-export the extract_expr macro for convenience in expression builders +pub use crate::extract_expr; diff --git a/native/core/src/execution/expressions/nullcheck.rs b/native/core/src/execution/expressions/nullcheck.rs new file mode 100644 index 0000000000..3981ab5504 --- /dev/null +++ b/native/core/src/execution/expressions/nullcheck.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Null check expression builders + +use datafusion::physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; + +use crate::unary_expr_builder; + +/// Macro to define all null check builders at once +macro_rules! define_null_check_builders { + () => { + unary_expr_builder!(IsNullBuilder, IsNull, IsNullExpr::new); + unary_expr_builder!(IsNotNullBuilder, IsNotNull, IsNotNullExpr::new); + }; +} + +define_null_check_builders!(); diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index d09393fc90..269ded1e48 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -17,12 +17,16 @@ //! Converts Spark physical plan to DataFusion physical plan +pub mod expression_registry; +pub mod traits; + use crate::execution::operators::IcebergScanExec; use crate::{ errors::ExpressionError, execution::{ expressions::subquery::Subquery, operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + planner::expression_registry::ExpressionRegistry, serde::to_arrow_datatype, shuffle::ShuffleWriterExec, }, @@ -46,8 +50,8 @@ use datafusion::{ logical_expr::Operator as DataFusionOperator, physical_expr::{ expressions::{ - in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, - Literal as DataFusionLiteral, NotExpr, + in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNullExpr, LikeExpr, + Literal as DataFusionLiteral, }, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, @@ -62,9 +66,8 @@ use datafusion::{ prelude::SessionContext, }; use datafusion_comet_spark_expr::{ - create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, create_modulo_expr, - create_negate_expr, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, EvalMode, - SparkHour, SparkMinute, SparkSecond, + create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, + BloomFilterAgg, BloomFilterMightContain, EvalMode, SparkHour, SparkMinute, SparkSecond, }; use iceberg::expr::Bind; @@ -144,7 +147,7 @@ struct JoinParameters { } #[derive(Default)] -struct BinaryExprOptions { +pub struct BinaryExprOptions { pub is_integral_div: bool, } @@ -244,126 +247,13 @@ impl PhysicalPlanner { spark_expr: &Expr, input_schema: SchemaRef, ) -> Result, ExecutionError> { - match spark_expr.expr_struct.as_ref().unwrap() { - ExprStruct::Add(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Plus, - input_schema, - eval_mode, - ) - } - ExprStruct::Subtract(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Minus, - input_schema, - eval_mode, - ) - } - ExprStruct::Multiply(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Multiply, - input_schema, - eval_mode, - ) - } - ExprStruct::Divide(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - eval_mode, - ) - } - ExprStruct::IntegralDivide(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - self.create_binary_expr_with_options( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - BinaryExprOptions { - is_integral_div: true, - }, - eval_mode, - ) - } - ExprStruct::Remainder(expr) => { - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; - // TODO add support for EvalMode::TRY - // https://github.com/apache/datafusion-comet/issues/2021 - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = - self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + // Try to use the modular registry first - this automatically handles any registered expression types + if ExpressionRegistry::global().can_handle(spark_expr) { + return ExpressionRegistry::global().create_expr(spark_expr, input_schema, self); + } - let result = create_modulo_expr( - left, - right, - expr.return_type.as_ref().map(to_arrow_datatype).unwrap(), - input_schema, - eval_mode == EvalMode::Ansi, - &self.session_ctx.state(), - ); - result.map_err(|e| GeneralError(e.to_string())) - } - ExprStruct::Eq(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::Eq; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::Neq(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::NotEq; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::Gt(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::Gt; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::GtEq(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::GtEq; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::Lt(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::Lt; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::LtEq(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::LtEq; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } + // Fall back to the original monolithic match for other expressions + match spark_expr.expr_struct.as_ref().unwrap() { ExprStruct::Bound(bound) => { let idx = bound.index as usize; if idx >= input_schema.fields().len() { @@ -381,28 +271,6 @@ impl PhysicalPlanner { data_type, ))) } - ExprStruct::IsNotNull(is_notnull) => { - let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(IsNotNullExpr::new(child))) - } - ExprStruct::IsNull(is_null) => { - let child = self.create_expr(is_null.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(IsNullExpr::new(child))) - } - ExprStruct::And(and) => { - let left = - self.create_expr(and.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(and.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::And; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::Or(or) => { - let left = - self.create_expr(or.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(or.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::Or; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } ExprStruct::Literal(literal) => { let data_type = to_arrow_datatype(literal.datatype.as_ref().unwrap()); let scalar_value = if literal.is_null { @@ -629,55 +497,6 @@ impl PhysicalPlanner { _ => func, } } - ExprStruct::EqNullSafe(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::IsNotDistinctFrom; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::NeqNullSafe(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::IsDistinctFrom; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::BitwiseAnd(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::BitwiseAnd; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::BitwiseOr(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::BitwiseOr; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::BitwiseXor(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::BitwiseXor; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::BitwiseShiftRight(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::BitwiseShiftRight; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } - ExprStruct::BitwiseShiftLeft(expr) => { - let left = - self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; - let right = self.create_expr(expr.right.as_ref().unwrap(), input_schema)?; - let op = DataFusionOperator::BitwiseShiftLeft; - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } ExprStruct::CaseWhen(case_when) => { let when_then_pairs = case_when .when @@ -726,16 +545,6 @@ impl PhysicalPlanner { self.create_expr(expr.false_expr.as_ref().unwrap(), input_schema)?; Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr))) } - ExprStruct::Not(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; - Ok(Arc::new(NotExpr::new(child))) - } - ExprStruct::UnaryMinus(expr) => { - let child: Arc = - self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; - let result = create_negate_expr(child, expr.fail_on_error); - result.map_err(|e| GeneralError(e.to_string())) - } ExprStruct::NormalizeNanAndZero(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); @@ -894,7 +703,7 @@ impl PhysicalPlanner { } } - fn create_binary_expr( + pub fn create_binary_expr( &self, left: &Expr, right: &Expr, @@ -915,7 +724,7 @@ impl PhysicalPlanner { } #[allow(clippy::too_many_arguments)] - fn create_binary_expr_with_options( + pub fn create_binary_expr_with_options( &self, left: &Expr, right: &Expr, @@ -2825,7 +2634,7 @@ fn rewrite_physical_expr( Ok(expr.rewrite(&mut rewriter).data()?) } -fn from_protobuf_eval_mode(value: i32) -> Result { +pub fn from_protobuf_eval_mode(value: i32) -> Result { match spark_expression::EvalMode::try_from(value)? { spark_expression::EvalMode::Legacy => Ok(EvalMode::Legacy), spark_expression::EvalMode::Try => Ok(EvalMode::Try), diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs new file mode 100644 index 0000000000..f97cb984b1 --- /dev/null +++ b/native/core/src/execution/planner/expression_registry.rs @@ -0,0 +1,267 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression registry for dispatching expression creation + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::{expr::ExprStruct, Expr}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::traits::{ExpressionBuilder, ExpressionType}; + +/// Registry for expression builders +pub struct ExpressionRegistry { + builders: HashMap>, +} + +impl ExpressionRegistry { + /// Create a new expression registry with all builders registered + fn new() -> Self { + let mut registry = Self { + builders: HashMap::new(), + }; + + registry.register_all_expressions(); + registry + } + + /// Get the global shared registry instance + pub fn global() -> &'static ExpressionRegistry { + static REGISTRY: std::sync::OnceLock = std::sync::OnceLock::new(); + REGISTRY.get_or_init(ExpressionRegistry::new) + } + + /// Check if the registry can handle a given expression type + pub fn can_handle(&self, spark_expr: &Expr) -> bool { + if let Ok(expr_type) = Self::get_expression_type(spark_expr) { + self.builders.contains_key(&expr_type) + } else { + false + } + } + + /// Create a physical expression from a Spark protobuf expression + pub fn create_expr( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &super::PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr_type = Self::get_expression_type(spark_expr)?; + + if let Some(builder) = self.builders.get(&expr_type) { + builder.build(spark_expr, input_schema, planner) + } else { + Err(ExecutionError::GeneralError(format!( + "No builder registered for expression type: {:?}", + expr_type + ))) + } + } + + /// Register all expression builders + fn register_all_expressions(&mut self) { + // Register arithmetic expressions + self.register_arithmetic_expressions(); + + // Register comparison expressions + self.register_comparison_expressions(); + + // Register bitwise expressions + self.register_bitwise_expressions(); + + // Register logical expressions + self.register_logical_expressions(); + + // Register null check expressions + self.register_null_check_expressions(); + + // TODO: Register other expression categories in future phases + // self.register_string_expressions(); + // self.register_temporal_expressions(); + // etc. + } + + /// Register arithmetic expression builders + fn register_arithmetic_expressions(&mut self) { + use crate::execution::expressions::arithmetic::*; + + self.builders + .insert(ExpressionType::Add, Box::new(AddBuilder)); + self.builders + .insert(ExpressionType::Subtract, Box::new(SubtractBuilder)); + self.builders + .insert(ExpressionType::Multiply, Box::new(MultiplyBuilder)); + self.builders + .insert(ExpressionType::Divide, Box::new(DivideBuilder)); + self.builders.insert( + ExpressionType::IntegralDivide, + Box::new(IntegralDivideBuilder), + ); + self.builders + .insert(ExpressionType::Remainder, Box::new(RemainderBuilder)); + self.builders + .insert(ExpressionType::UnaryMinus, Box::new(UnaryMinusBuilder)); + } + + /// Register comparison expression builders + fn register_comparison_expressions(&mut self) { + use crate::execution::expressions::comparison::*; + + self.builders + .insert(ExpressionType::Eq, Box::new(EqBuilder)); + self.builders + .insert(ExpressionType::Neq, Box::new(NeqBuilder)); + self.builders + .insert(ExpressionType::Lt, Box::new(LtBuilder)); + self.builders + .insert(ExpressionType::LtEq, Box::new(LtEqBuilder)); + self.builders + .insert(ExpressionType::Gt, Box::new(GtBuilder)); + self.builders + .insert(ExpressionType::GtEq, Box::new(GtEqBuilder)); + self.builders + .insert(ExpressionType::EqNullSafe, Box::new(EqNullSafeBuilder)); + self.builders + .insert(ExpressionType::NeqNullSafe, Box::new(NeqNullSafeBuilder)); + } + + /// Register bitwise expression builders + fn register_bitwise_expressions(&mut self) { + use crate::execution::expressions::bitwise::*; + + self.builders + .insert(ExpressionType::BitwiseAnd, Box::new(BitwiseAndBuilder)); + self.builders + .insert(ExpressionType::BitwiseOr, Box::new(BitwiseOrBuilder)); + self.builders + .insert(ExpressionType::BitwiseXor, Box::new(BitwiseXorBuilder)); + self.builders.insert( + ExpressionType::BitwiseShiftLeft, + Box::new(BitwiseShiftLeftBuilder), + ); + self.builders.insert( + ExpressionType::BitwiseShiftRight, + Box::new(BitwiseShiftRightBuilder), + ); + } + + /// Register logical expression builders + fn register_logical_expressions(&mut self) { + use crate::execution::expressions::logical::*; + + self.builders + .insert(ExpressionType::And, Box::new(AndBuilder)); + self.builders + .insert(ExpressionType::Or, Box::new(OrBuilder)); + self.builders + .insert(ExpressionType::Not, Box::new(NotBuilder)); + } + + /// Register null check expression builders + fn register_null_check_expressions(&mut self) { + use crate::execution::expressions::nullcheck::*; + + self.builders + .insert(ExpressionType::IsNull, Box::new(IsNullBuilder)); + self.builders + .insert(ExpressionType::IsNotNull, Box::new(IsNotNullBuilder)); + } + + /// Extract expression type from Spark protobuf expression + fn get_expression_type(spark_expr: &Expr) -> Result { + match spark_expr.expr_struct.as_ref() { + Some(ExprStruct::Add(_)) => Ok(ExpressionType::Add), + Some(ExprStruct::Subtract(_)) => Ok(ExpressionType::Subtract), + Some(ExprStruct::Multiply(_)) => Ok(ExpressionType::Multiply), + Some(ExprStruct::Divide(_)) => Ok(ExpressionType::Divide), + Some(ExprStruct::IntegralDivide(_)) => Ok(ExpressionType::IntegralDivide), + Some(ExprStruct::Remainder(_)) => Ok(ExpressionType::Remainder), + Some(ExprStruct::UnaryMinus(_)) => Ok(ExpressionType::UnaryMinus), + + Some(ExprStruct::Eq(_)) => Ok(ExpressionType::Eq), + Some(ExprStruct::Neq(_)) => Ok(ExpressionType::Neq), + Some(ExprStruct::Lt(_)) => Ok(ExpressionType::Lt), + Some(ExprStruct::LtEq(_)) => Ok(ExpressionType::LtEq), + Some(ExprStruct::Gt(_)) => Ok(ExpressionType::Gt), + Some(ExprStruct::GtEq(_)) => Ok(ExpressionType::GtEq), + Some(ExprStruct::EqNullSafe(_)) => Ok(ExpressionType::EqNullSafe), + Some(ExprStruct::NeqNullSafe(_)) => Ok(ExpressionType::NeqNullSafe), + + Some(ExprStruct::And(_)) => Ok(ExpressionType::And), + Some(ExprStruct::Or(_)) => Ok(ExpressionType::Or), + Some(ExprStruct::Not(_)) => Ok(ExpressionType::Not), + + Some(ExprStruct::IsNull(_)) => Ok(ExpressionType::IsNull), + Some(ExprStruct::IsNotNull(_)) => Ok(ExpressionType::IsNotNull), + + Some(ExprStruct::BitwiseAnd(_)) => Ok(ExpressionType::BitwiseAnd), + Some(ExprStruct::BitwiseOr(_)) => Ok(ExpressionType::BitwiseOr), + Some(ExprStruct::BitwiseXor(_)) => Ok(ExpressionType::BitwiseXor), + Some(ExprStruct::BitwiseShiftLeft(_)) => Ok(ExpressionType::BitwiseShiftLeft), + Some(ExprStruct::BitwiseShiftRight(_)) => Ok(ExpressionType::BitwiseShiftRight), + + Some(ExprStruct::Bound(_)) => Ok(ExpressionType::Bound), + Some(ExprStruct::Unbound(_)) => Ok(ExpressionType::Unbound), + Some(ExprStruct::Literal(_)) => Ok(ExpressionType::Literal), + Some(ExprStruct::Cast(_)) => Ok(ExpressionType::Cast), + Some(ExprStruct::CaseWhen(_)) => Ok(ExpressionType::CaseWhen), + Some(ExprStruct::In(_)) => Ok(ExpressionType::In), + Some(ExprStruct::If(_)) => Ok(ExpressionType::If), + Some(ExprStruct::Substring(_)) => Ok(ExpressionType::Substring), + Some(ExprStruct::Like(_)) => Ok(ExpressionType::Like), + Some(ExprStruct::Rlike(_)) => Ok(ExpressionType::Rlike), + Some(ExprStruct::CheckOverflow(_)) => Ok(ExpressionType::CheckOverflow), + Some(ExprStruct::ScalarFunc(_)) => Ok(ExpressionType::ScalarFunc), + Some(ExprStruct::NormalizeNanAndZero(_)) => Ok(ExpressionType::NormalizeNanAndZero), + Some(ExprStruct::Subquery(_)) => Ok(ExpressionType::Subquery), + Some(ExprStruct::BloomFilterMightContain(_)) => { + Ok(ExpressionType::BloomFilterMightContain) + } + Some(ExprStruct::CreateNamedStruct(_)) => Ok(ExpressionType::CreateNamedStruct), + Some(ExprStruct::GetStructField(_)) => Ok(ExpressionType::GetStructField), + Some(ExprStruct::ToJson(_)) => Ok(ExpressionType::ToJson), + Some(ExprStruct::ToPrettyString(_)) => Ok(ExpressionType::ToPrettyString), + Some(ExprStruct::ListExtract(_)) => Ok(ExpressionType::ListExtract), + Some(ExprStruct::GetArrayStructFields(_)) => Ok(ExpressionType::GetArrayStructFields), + Some(ExprStruct::ArrayInsert(_)) => Ok(ExpressionType::ArrayInsert), + Some(ExprStruct::Rand(_)) => Ok(ExpressionType::Rand), + Some(ExprStruct::Randn(_)) => Ok(ExpressionType::Randn), + Some(ExprStruct::SparkPartitionId(_)) => Ok(ExpressionType::SparkPartitionId), + Some(ExprStruct::MonotonicallyIncreasingId(_)) => { + Ok(ExpressionType::MonotonicallyIncreasingId) + } + + Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), + Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), + Some(ExprStruct::Second(_)) => Ok(ExpressionType::Second), + Some(ExprStruct::TruncTimestamp(_)) => Ok(ExpressionType::TruncTimestamp), + + Some(other) => Err(ExecutionError::GeneralError(format!( + "Unsupported expression type: {:?}", + other + ))), + None => Err(ExecutionError::GeneralError( + "Expression struct is None".to_string(), + )), + } + } +} diff --git a/native/core/src/execution/planner/traits.rs b/native/core/src/execution/planner/traits.rs new file mode 100644 index 0000000000..3f3467d0d0 --- /dev/null +++ b/native/core/src/execution/planner/traits.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Core traits for the modular planner framework + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use jni::objects::GlobalRef; + +use crate::execution::operators::ScanExec; +use crate::execution::{operators::ExecutionError, spark_plan::SparkPlan}; + +/// Macro to extract a specific expression variant, panicking if called with wrong type. +/// This should be used in expression builders where the registry guarantees the correct +/// expression type has been routed to the builder. +#[macro_export] +macro_rules! extract_expr { + ($spark_expr:expr, $variant:ident) => { + match $spark_expr + .expr_struct + .as_ref() + .expect("expression struct must be present") + { + datafusion_comet_proto::spark_expression::expr::ExprStruct::$variant(expr) => expr, + other => panic!( + "{} builder called with wrong expression type: {:?}", + stringify!($variant), + other + ), + } + }; +} + +/// Macro to generate binary expression builders with minimal boilerplate +#[macro_export] +macro_rules! binary_expr_builder { + ($builder_name:ident, $expr_type:ident, $operator:expr) => { + pub struct $builder_name; + + impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name { + fn build( + &self, + spark_expr: &datafusion_comet_proto::spark_expression::Expr, + input_schema: arrow::datatypes::SchemaRef, + planner: &$crate::execution::planner::PhysicalPlanner, + ) -> Result< + std::sync::Arc, + $crate::execution::operators::ExecutionError, + > { + let expr = $crate::extract_expr!(spark_expr, $expr_type); + let left = planner.create_expr( + expr.left.as_ref().unwrap(), + std::sync::Arc::clone(&input_schema), + )?; + let right = planner.create_expr(expr.right.as_ref().unwrap(), input_schema)?; + Ok(std::sync::Arc::new( + datafusion::physical_expr::expressions::BinaryExpr::new(left, $operator, right), + )) + } + } + }; +} + +/// Macro to generate unary expression builders +#[macro_export] +macro_rules! unary_expr_builder { + ($builder_name:ident, $expr_type:ident, $expr_constructor:expr) => { + pub struct $builder_name; + + impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name { + fn build( + &self, + spark_expr: &datafusion_comet_proto::spark_expression::Expr, + input_schema: arrow::datatypes::SchemaRef, + planner: &$crate::execution::planner::PhysicalPlanner, + ) -> Result< + std::sync::Arc, + $crate::execution::operators::ExecutionError, + > { + let expr = $crate::extract_expr!(spark_expr, $expr_type); + let child = planner.create_expr(expr.child.as_ref().unwrap(), input_schema)?; + Ok(std::sync::Arc::new($expr_constructor(child))) + } + } + }; +} + +/// Trait for building physical expressions from Spark protobuf expressions +pub trait ExpressionBuilder: Send + Sync { + /// Build a DataFusion physical expression from a Spark protobuf expression + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &super::PhysicalPlanner, + ) -> Result, ExecutionError>; +} + +/// Trait for building physical operators from Spark protobuf operators +#[allow(dead_code)] +pub trait OperatorBuilder: Send + Sync { + /// Build a Spark plan from a protobuf operator + fn build( + &self, + spark_plan: &datafusion_comet_proto::spark_operator::Operator, + inputs: &mut Vec>, + partition_count: usize, + planner: &super::PhysicalPlanner, + ) -> Result<(Vec, Arc), ExecutionError>; +} + +/// Enum to identify different expression types for registry dispatch +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionType { + // Arithmetic expressions + Add, + Subtract, + Multiply, + Divide, + IntegralDivide, + Remainder, + UnaryMinus, + + // Comparison expressions + Eq, + Neq, + Lt, + LtEq, + Gt, + GtEq, + EqNullSafe, + NeqNullSafe, + + // Logical expressions + And, + Or, + Not, + + // Null checks + IsNull, + IsNotNull, + + // Bitwise operations + BitwiseAnd, + BitwiseOr, + BitwiseXor, + BitwiseShiftLeft, + BitwiseShiftRight, + + // Other expressions + Bound, + Unbound, + Literal, + Cast, + CaseWhen, + In, + If, + Substring, + Like, + Rlike, + CheckOverflow, + ScalarFunc, + NormalizeNanAndZero, + Subquery, + BloomFilterMightContain, + CreateNamedStruct, + GetStructField, + ToJson, + ToPrettyString, + ListExtract, + GetArrayStructFields, + ArrayInsert, + Rand, + Randn, + SparkPartitionId, + MonotonicallyIncreasingId, + + // Time functions + Hour, + Minute, + Second, + TruncTimestamp, +} + +/// Enum to identify different operator types for registry dispatch +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[allow(dead_code)] +pub enum OperatorType { + Scan, + NativeScan, + IcebergScan, + Projection, + Filter, + HashAgg, + Limit, + Sort, + ShuffleWriter, + ParquetWriter, + Expand, + SortMergeJoin, + HashJoin, + Window, +}