diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index e9a975239a481..83248779e40d4 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -684,8 +684,8 @@ fn test_simplify_concat_ws() { // the delimiter is an empty string { - let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]); - let expected = concat(vec![col("a"), lit("cb")]); + let expr = concat_ws(lit(""), vec![col("c1"), lit("c"), lit("b")]); + let expected = concat(vec![col("c1"), lit("cb")]); test_simplify(expr, expected); } @@ -695,7 +695,7 @@ fn test_simplify_concat_ws() { lit("-"), vec![ null.clone(), - col("c0"), + col("c1"), lit("hello"), null.clone(), lit("rust"), @@ -707,7 +707,7 @@ fn test_simplify_concat_ws() { ); let expected = concat_ws( lit("-"), - vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")], + vec![col("c1"), lit("hello-rust"), col("c1"), lit("-")], ); test_simplify(expr, expected) } @@ -738,8 +738,8 @@ fn test_simplify_concat_ws_with_null() { // null delimiter (nested) { - let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]); - let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]); + let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c1")]); + let expr = concat_ws(sub_expr, vec![col("c1"), col("c1")]); test_simplify(expr, null); } } @@ -754,16 +754,35 @@ fn test_simplify_concat() -> Result<()> { lit("hello "), null.clone(), lit("rust"), - lit(ScalarValue::Utf8View(Some("!".to_string()))), + lit("!"), col("c2"), lit(""), null, col("c5"), ]); let expr_datatype = expr.get_type(schema.as_ref())?; + let expected = concat(vec![col("c1"), lit("hello rust!"), col("c2"), col("c5")]); + let expected_datatype = expected.get_type(schema.as_ref())?; + assert_eq!(expr_datatype, expected_datatype); + test_simplify(expr, expected); + + let null = lit(ScalarValue::Binary(None)); + let expr = concat(vec![ + null.clone(), + col("c1"), + lit(vec![0xde_u8]), + null.clone(), + lit(vec![0xad_u8]), + lit(vec![0xbe_u8, 0xef]), + col("c2"), + lit(Vec::new()), + null, + col("c5"), + ]); + let expr_datatype = expr.get_type(schema.as_ref())?; let expected = concat(vec![ col("c1"), - lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), + lit(vec![0xde_u8, 0xad, 0xbe, 0xef]), col("c2"), col("c5"), ]); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index af51f66faa97c..d7e2ef667e6de 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -97,11 +97,8 @@ impl ScalarUDFImpl for ConcatFunc { } } - /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid - /// potential overflow on LargeUtf8 input. - /// For binaries, use the similar hierarchy fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(deduce_return_type(arg_types)) + Ok(arg_types[0].clone()) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. @@ -245,9 +242,97 @@ impl ScalarUDFImpl for ConcatFunc { fn simplify( &self, args: Vec, - _info: &SimplifyContext, + info: &SimplifyContext, ) -> Result { - simplify_concat(args) + let data_types = args + .iter() + .map(|expr| info.get_data_type(expr)) + .collect::>>()?; + let return_type = self.return_type(&data_types)?; + + let mut new_args = Vec::with_capacity(args.len()); + let mut contiguous_scalar: Vec = vec![]; + + fn form_scalar(scalar: &[u8], data_type: &DataType) -> Expr { + match data_type { + // Technically we're guaranteed UTF8 safety since all input types + // should be a common type, i.e. all strings or all binary. + // Using from_utf8_lossy here just for safety, as the performance + // impact is probably minimal on this simplification path. + DataType::Utf8 => lit(ScalarValue::Utf8(Some( + String::from_utf8_lossy(scalar).to_string(), + ))), + DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some( + String::from_utf8_lossy(scalar).to_string(), + ))), + DataType::Utf8View => lit(ScalarValue::Utf8View(Some( + String::from_utf8_lossy(scalar).to_string(), + ))), + DataType::Binary => lit(ScalarValue::Binary(Some(scalar.to_vec()))), + DataType::LargeBinary => { + lit(ScalarValue::LargeBinary(Some(scalar.to_vec()))) + } + DataType::BinaryView => { + lit(ScalarValue::BinaryView(Some(scalar.to_vec()))) + } + _ => unreachable!(), + } + } + + for arg in &args { + match arg { + Expr::Literal(sv, _) if sv.is_null() => {} + + Expr::Literal( + ScalarValue::Utf8(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) + | ScalarValue::Utf8View(Some(v)), + _, + ) => { + contiguous_scalar.extend(v.as_bytes()); + } + Expr::Literal( + ScalarValue::Binary(Some(v)) + | ScalarValue::LargeBinary(Some(v)) + | ScalarValue::BinaryView(Some(v)), + _, + ) => { + contiguous_scalar.extend(v); + } + + Expr::Literal(x, _) => { + return internal_err!( + "Unexpected datatype during simplification, expected string or binary got {}", + x.data_type() + ); + } + + // Non-literal blocks further simplification, finish what we've + // done so far and reset + arg => { + if !contiguous_scalar.is_empty() { + new_args.push(form_scalar(&contiguous_scalar, &return_type)); + contiguous_scalar.clear(); + } + new_args.push(arg.clone()); + } + } + } + + if !contiguous_scalar.is_empty() { + new_args.push(form_scalar(&contiguous_scalar, &return_type)); + } + + if args.len() != new_args.len() { + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func: concat(), + args: new_args, + }, + ))) + } else { + Ok(ExprSimplifyResult::Original(args)) + } } fn documentation(&self) -> Option<&Documentation> { @@ -259,24 +344,6 @@ impl ScalarUDFImpl for ConcatFunc { } } -pub(crate) fn deduce_return_type(arg_types: &[DataType]) -> DataType { - use DataType::*; - if arg_types.contains(&BinaryView) { - BinaryView - } else if arg_types.contains(&LargeBinary) { - // Serves LargeBinary and FixedSizeBinary inputs - LargeBinary - } else if arg_types.contains(&Binary) { - Binary - } else if arg_types.contains(&Utf8View) { - Utf8View - } else if arg_types.contains(&LargeUtf8) { - LargeUtf8 - } else { - Utf8 - } -} - /// Coerce all arguments to the widest type within the binary / string family pub(crate) fn coerce_arg_types(arg_types: &[DataType]) -> Result> { let has_binary = arg_types.iter().any(|dt| dt.is_binary()); @@ -311,113 +378,13 @@ fn build_concat( Ok(ColumnarValue::Array(array)) } -pub(crate) fn simplify_concat(args: Vec) -> Result { - // Skip simplification when binary literals are present, because it - // handles only strings - for arg in &args { - match arg { - Expr::Literal(dt, _) if dt.data_type().is_binary() => { - return Ok(ExprSimplifyResult::Original(args)); - } - _ => {} - } - } - - let mut new_args = Vec::with_capacity(args.len()); - let mut contiguous_scalar = "".to_string(); - - let return_type = { - let data_types: Vec<_> = args - .iter() - .filter_map(|expr| match expr { - Expr::Literal(l, _) => Some(l.data_type()), - _ => None, - }) - .collect(); - ConcatFunc::new().return_type(&data_types) - }?; - - for arg in args.clone() { - match arg { - Expr::Literal(ScalarValue::Utf8(None), _) => {} - Expr::Literal(ScalarValue::LargeUtf8(None), _) => {} - Expr::Literal(ScalarValue::Utf8View(None), _) => {} - - // filter out `null` args - // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. - // Concatenate it with the `contiguous_scalar`. - Expr::Literal(ScalarValue::Utf8(Some(v)), _) => { - contiguous_scalar += &v; - } - Expr::Literal(ScalarValue::LargeUtf8(Some(v)), _) => { - contiguous_scalar += &v; - } - Expr::Literal(ScalarValue::Utf8View(Some(v)), _) => { - contiguous_scalar += &v; - } - - Expr::Literal(x, _) => { - return internal_err!( - "The scalar {x} should be casted to string type during the type coercion." - ); - } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` (if it is not empty) and reset it to empty string. - // Then pushing this arg to the `new_args`. - arg => { - if !contiguous_scalar.is_empty() { - match return_type { - DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => new_args - .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), - DataType::Utf8View => new_args - .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), - _ => unreachable!(), - } - contiguous_scalar = "".to_string(); - } - new_args.push(arg); - } - } - } - - if !contiguous_scalar.is_empty() { - match return_type { - DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => { - new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) - } - DataType::Utf8View => { - new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) - } - _ => unreachable!(), - } - } - - if !args.eq(&new_args) { - Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( - ScalarFunction { - func: concat(), - args: new_args, - }, - ))) - } else { - Ok(ExprSimplifyResult::Original(args)) - } -} - #[cfg(test)] mod tests { use super::*; use crate::utils::test::test_function; use DataType::*; - use arrow::array::{ - ArrayRef, BinaryArray, BinaryViewArray, LargeBinaryArray, StringArray, - }; + use arrow::array::{BinaryArray, BinaryViewArray, LargeBinaryArray, StringArray}; use arrow::array::{LargeStringArray, StringViewArray}; - use arrow::datatypes::Field; - use datafusion_common::config::ConfigOptions; - use std::sync::Arc; #[test] fn test_functions() -> Result<()> { @@ -456,10 +423,10 @@ mod tests { test_function!( ConcatFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("cc".to_string()))), ], Ok(Some("aacc")), &str, @@ -469,9 +436,9 @@ mod tests { test_function!( ConcatFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("cc".to_string()))), ], Ok(Some("aacc")), &str, @@ -482,7 +449,7 @@ mod tests { ConcatFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("cc".to_string()))), ], Ok(Some("aacc")), &str, @@ -510,7 +477,7 @@ mod tests { test_function!( ConcatFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::Binary(Some( + ColumnarValue::Scalar(ScalarValue::LargeBinary(Some( "Café".as_bytes().into() ))), ColumnarValue::Scalar(ScalarValue::LargeBinary(Some( @@ -525,7 +492,7 @@ mod tests { test_function!( ConcatFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::Binary(Some( + ColumnarValue::Scalar(ScalarValue::BinaryView(Some( "Café".as_bytes().into() ))), ColumnarValue::Scalar(ScalarValue::BinaryView(Some( @@ -575,103 +542,4 @@ mod tests { ); Ok(()) } - - #[test] - fn test_array_string() -> Result<()> { - let c0 = - ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); - let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ - Some("x"), - None, - Some("z"), - ]))); - let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); - let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ - Some("a"), - None, - Some("b"), - ]))); - let arg_fields = vec![ - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), - Field::new("a", Utf8, true), - Field::new("a", Utf8View, true), - Field::new("a", Utf8View, true), - ] - .into_iter() - .map(Arc::new) - .collect::>(); - - let args = ScalarFunctionArgs { - args: vec![c0, c1, c2, c3, c4], - arg_fields, - number_rows: 3, - return_field: Field::new("f", Utf8View, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = ConcatFunc::new().invoke_with_args(args)?; - let expected = - Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) - as ArrayRef; - match &result { - ColumnarValue::Array(array) => { - assert_eq!(&expected, array); - } - _ => panic!(), - } - Ok(()) - } - - #[test] - fn test_array_binary() -> Result<()> { - let c0 = ColumnarValue::Array(Arc::new(BinaryArray::from_vec(vec![ - b"foo", b"bar", b"baz", - ]))); - let c1 = ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(b",".to_vec()))); - let c2 = ColumnarValue::Array(Arc::new(BinaryArray::from_opt_vec(vec![ - Some(b"x"), - None, - Some(b"z"), - ]))); - let c3 = ColumnarValue::Scalar(ScalarValue::BinaryView(Some(b",".to_vec()))); - let c4 = ColumnarValue::Array(Arc::new(BinaryViewArray::from_iter(vec![ - Some(b"a"), - None, - Some(b"b"), - ]))); - let arg_fields = vec![ - Field::new("a", Binary, true), - Field::new("a", LargeBinary, true), - Field::new("a", Binary, true), - Field::new("a", BinaryView, true), - Field::new("a", BinaryView, true), - ] - .into_iter() - .map(Arc::new) - .collect::>(); - - let args = ScalarFunctionArgs { - args: vec![c0, c1, c2, c3, c4], - arg_fields, - number_rows: 3, - return_field: Field::new("f", BinaryView, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - - let result = ConcatFunc::new().invoke_with_args(args)?; - let expected = Arc::new(BinaryViewArray::from_iter(vec![ - Some(b"foo,x,a".to_vec()), - Some(b"bar,,".to_vec()), - Some(b"baz,z,b".to_vec()), - ])) as ArrayRef; - match &result { - ColumnarValue::Array(array) => { - assert_eq!(&expected, array); - } - _ => panic!(), - } - Ok(()) - } } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 8cb6869974813..f8b8185f786c3 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -22,7 +22,7 @@ use crate::binaries::{ ConcatBinaryBuilder, ConcatBinaryViewBuilder, ConcatLargeBinaryBuilder, }; use crate::string::concat; -use crate::string::concat::{coerce_arg_types, deduce_return_type, simplify_concat}; +use crate::string::concat::coerce_arg_types; use crate::string::concat_ws; use crate::strings::{ ColumnarValueRef, ConcatBuilder, ConcatLargeStringBuilder, ConcatStringBuilder, @@ -104,9 +104,8 @@ impl ScalarUDFImpl for ConcatWsFunc { } } - /// Match the return type to the input types. Delegates to `concat` implementation. fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(deduce_return_type(arg_types)) + Ok(arg_types[0].clone()) } /// Concatenates all but the first argument, with separators. The first @@ -145,7 +144,6 @@ impl ScalarUDFImpl for ConcatWsFunc { ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) | ScalarValue::BinaryView(Some(v)) => v.as_slice(), - ScalarValue::FixedSizeBinary(_, Some(v)) => v.as_slice(), scalar if scalar.is_null() => { return Ok(null_scalar(&return_datatype)); } @@ -163,9 +161,6 @@ impl ScalarUDFImpl for ConcatWsFunc { ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) | ScalarValue::BinaryView(Some(v)) => values.push(v.as_slice()), - ScalarValue::FixedSizeBinary(_, Some(v)) => { - values.push(v.as_slice()) - } // skip null scalar if scalar.is_null() => {} other => { @@ -311,11 +306,146 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify( &self, args: Vec, - _info: &SimplifyContext, + info: &SimplifyContext, ) -> Result { - match &args[..] { - [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals), - _ => Ok(ExprSimplifyResult::Original(args)), + let (delimiter, other_args) = match &args[..] { + [delimiter, vals @ ..] => (delimiter, vals), + _ => return Ok(ExprSimplifyResult::Original(args)), + }; + + let data_types = other_args + .iter() + .map(|expr| info.get_data_type(expr)) + .collect::>>()?; + let return_type = self.return_type(&data_types)?; + + // No implementation for binary, only simple null filtering + if return_type.is_binary() { + let mut filtered_args = other_args + .iter() + .filter(|x| !is_null(x)) + .cloned() + .collect::>(); + filtered_args.insert(0, delimiter.clone()); + if filtered_args.len() != args.len() { + return Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func: concat_ws(), + args: filtered_args, + }, + ))); + } else { + return Ok(ExprSimplifyResult::Original(args)); + } + } + + match delimiter { + // If the delimiter is null, then the value of the whole expression is null. + Expr::Literal(literal, _) if literal.is_null() => { + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + ScalarValue::try_new_null(&return_type)?, + None, + ))) + } + // Behaves like a simple concat if empty delimiter + Expr::Literal( + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) + | ScalarValue::Utf8View(Some(delimiter)), + _, + ) if delimiter.is_empty() => Ok(ExprSimplifyResult::Simplified( + Expr::ScalarFunction(ScalarFunction { + func: concat(), + args: other_args.to_vec(), + }), + )), + Expr::Literal( + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) + | ScalarValue::Utf8View(Some(delimiter)), + _, + ) => { + let typed_lit = |s: String| -> Expr { + match return_type { + DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some(s))), + DataType::Utf8View => lit(ScalarValue::Utf8View(Some(s))), + _ => lit(s), + } + }; + + let mut new_args = Vec::with_capacity(other_args.len()); + new_args.push(typed_lit(delimiter.to_string())); + let mut contiguous_scalar = None; + for arg in other_args { + match arg { + // filter out null args + Expr::Literal(scalar, _) if scalar.is_null() => {} + Expr::Literal( + ScalarValue::Utf8(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) + | ScalarValue::Utf8View(Some(v)), + _, + ) => match contiguous_scalar { + None => contiguous_scalar = Some(v.to_string()), + Some(mut pre) => { + pre += delimiter; + pre += v; + contiguous_scalar = Some(pre) + } + }, + Expr::Literal(s, _) => { + return internal_err!( + "The scalar {s} should be casted to string type during the type coercion." + ); + } + // If the arg is not a literal, we should first push the current `contiguous_scalar` + // to the `new_args` and reset it to None. + // Then pushing this arg to the `new_args`. + arg => { + if let Some(val) = contiguous_scalar { + new_args.push(typed_lit(val)); + } + new_args.push(arg.clone()); + contiguous_scalar = None; + } + } + } + if let Some(val) = contiguous_scalar { + new_args.push(typed_lit(val)); + } + + if args.len() != new_args.len() { + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func: concat_ws(), + args: new_args, + }, + ))) + } else { + Ok(ExprSimplifyResult::Original(args)) + } + } + Expr::Literal(d, _) => internal_err!( + "The scalar {d} should be casted to string type during the type coercion." + ), + _ => { + let mut new_args = other_args + .iter() + .filter(|x| !is_null(x)) + .cloned() + .collect::>(); + new_args.insert(0, delimiter.clone()); + if new_args.len() != args.len() { + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction { + func: concat_ws(), + args: new_args, + }, + ))) + } else { + Ok(ExprSimplifyResult::Original(args)) + } + } } } @@ -359,139 +489,6 @@ fn null_scalar(dt: &DataType) -> ColumnarValue { ) } -fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { - // Preserve the delimiter's string type for any new literals produced - // during simplification. - let delimiter_type = match delimiter { - Expr::Literal(v, _) => v.data_type(), - _ => DataType::Utf8, - }; - - // Shortcut for binary delimiters - if delimiter_type.is_binary() { - let mut args = args - .iter() - .filter(|x| !is_null(x)) - .cloned() - .collect::>(); - args.insert(0, delimiter.clone()); - return Ok(ExprSimplifyResult::Original(args)); - } - - let typed_lit = |s: String| -> Expr { - match delimiter_type { - DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some(s))), - DataType::Utf8View => lit(ScalarValue::Utf8View(Some(s))), - _ => lit(s), - } - }; - - match delimiter { - Expr::Literal( - ScalarValue::Utf8(delimiter) - | ScalarValue::LargeUtf8(delimiter) - | ScalarValue::Utf8View(delimiter), - _, - ) => { - match delimiter { - // When the delimiter is the empty string, replace `concat_ws` - // with `concat` - Some(delimiter) if delimiter.is_empty() => { - match simplify_concat(args.to_vec())? { - ExprSimplifyResult::Original(_) => { - Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( - ScalarFunction { - func: concat(), - args: args.to_vec(), - }, - ))) - } - expr => Ok(expr), - } - } - Some(delimiter) => { - let mut new_args = Vec::with_capacity(args.len()); - new_args.push(typed_lit(delimiter.to_string())); - let mut contiguous_scalar = None; - for arg in args { - match arg { - // filter out null args - Expr::Literal( - ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::Utf8View(None), - _, - ) => {} - Expr::Literal( - ScalarValue::Utf8(Some(v)) - | ScalarValue::LargeUtf8(Some(v)) - | ScalarValue::Utf8View(Some(v)), - _, - ) => match contiguous_scalar { - None => contiguous_scalar = Some(v.to_string()), - Some(mut pre) => { - pre += delimiter; - pre += v; - contiguous_scalar = Some(pre) - } - }, - Expr::Literal(s, _) => { - return internal_err!( - "The scalar {s} should be casted to string type during the type coercion." - ); - } - // If the arg is not a literal, we should first push the current `contiguous_scalar` - // to the `new_args` and reset it to None. - // Then pushing this arg to the `new_args`. - arg => { - if let Some(val) = contiguous_scalar { - new_args.push(typed_lit(val)); - } - new_args.push(arg.clone()); - contiguous_scalar = None; - } - } - } - if let Some(val) = contiguous_scalar { - new_args.push(typed_lit(val)); - } - - Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( - ScalarFunction { - func: concat_ws(), - args: new_args, - }, - ))) - } - // If the delimiter is null, then the value of the whole expression is null. - None => { - let null_scalar = match delimiter_type { - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Utf8View => ScalarValue::Utf8View(None), - _ => ScalarValue::Utf8(None), - }; - Ok(ExprSimplifyResult::Simplified(Expr::Literal( - null_scalar, - None, - ))) - } - } - } - Expr::Literal(d, _) => internal_err!( - "The scalar {d} should be casted to string type during the type coercion." - ), - _ => { - let mut args = args - .iter() - .filter(|&x| !is_null(x)) - .cloned() - .collect::>(); - args.insert(0, delimiter.clone()); - Ok(ExprSimplifyResult::Original(args)) - } - } -} - fn is_null(expr: &Expr) -> bool { match expr { Expr::Literal(v, _) => v.is_null(), diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index be5ced2edfbf0..22ba7eb13e869 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -88,14 +88,18 @@ impl ScalarUDFImpl for SparkConcat { // Spark semantics: concat returns NULL if ANY input is NULL let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - let arg_types: Vec = args - .arg_fields - .iter() - .map(|f| f.data_type().clone()) - .collect(); - let dt = ConcatFunc::new().return_type(&arg_types)?; - - Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) + let return_type = if args.arg_fields.is_empty() { + DataType::Utf8 + } else { + let arg_types: Vec = args + .arg_fields + .iter() + .map(|f| f.data_type().clone()) + .collect(); + ConcatFunc::new().return_type(&arg_types)? + }; + + Ok(Arc::new(Field::new("concat", return_type, nullable))) } } diff --git a/datafusion/sqllogictest/test_files/string/concat.slt b/datafusion/sqllogictest/test_files/string/concat.slt index 1749a57591bc6..1a5325a64653b 100644 --- a/datafusion/sqllogictest/test_files/string/concat.slt +++ b/datafusion/sqllogictest/test_files/string/concat.slt @@ -136,3 +136,17 @@ SELECT concat_ws(x'7c', column1, column2) from t; statement ok drop table t + +query TT +SELECT + concat(c0, ',', c1, ',', c2), + arrow_cast(concat(arrow_cast(c0, 'Binary'), arrow_cast(',', 'Binary'), arrow_cast(c1, 'Binary'), arrow_cast(',', 'Binary'), arrow_cast(c2, 'Binary')), 'Utf8') +FROM VALUES + ('foo', 'x', 'a'), + ('bar', null, null), + ('baz', 'z', 'b') +t(c0, c1, c2) +---- +foo,x,a foo,x,a +bar,, bar,, +baz,z,b baz,z,b