Skip to content

Commit 3a0064d

Browse files
skushagrarluvaton
andauthored
fix: custom nullability for length (#19175) (#19182)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #19175 ## What changes are included in this PR? - includes custom nullability for length, determined dynamically using input arguments. --------- Co-authored-by: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
1 parent 41f7137 commit 3a0064d

File tree

1 file changed

+49
-5
lines changed
  • datafusion/spark/src/function/string

1 file changed

+49
-5
lines changed

datafusion/spark/src/function/string/length.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
use arrow::array::{
1919
Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType,
2020
};
21-
use arrow::datatypes::{DataType, Int32Type};
21+
use arrow::datatypes::{DataType, Field, FieldRef, Int32Type};
2222
use datafusion_common::exec_err;
2323
use datafusion_expr::{
24-
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25+
Volatility,
2526
};
2627
use datafusion_functions::utils::make_scalar_function;
2728
use std::sync::Arc;
@@ -78,8 +79,9 @@ impl ScalarUDFImpl for SparkLengthFunc {
7879
}
7980

8081
fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result<DataType> {
81-
// spark length always returns Int32
82-
Ok(DataType::Int32)
82+
datafusion_common::internal_err!(
83+
"return_type should not be called, use return_field_from_args instead"
84+
)
8385
}
8486

8587
fn invoke_with_args(
@@ -92,6 +94,15 @@ impl ScalarUDFImpl for SparkLengthFunc {
9294
fn aliases(&self) -> &[String] {
9395
&self.aliases
9496
}
97+
98+
fn return_field_from_args(
99+
&self,
100+
args: ReturnFieldArgs,
101+
) -> datafusion_common::Result<FieldRef> {
102+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
103+
// spark length always returns Int32
104+
Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
105+
}
95106
}
96107

97108
fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
@@ -193,8 +204,9 @@ mod tests {
193204
use crate::function::utils::test::test_scalar_function;
194205
use arrow::array::{Array, Int32Array};
195206
use arrow::datatypes::DataType::Int32;
207+
use arrow::datatypes::{Field, FieldRef};
196208
use datafusion_common::{Result, ScalarValue};
197-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
209+
use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarUDFImpl};
198210

199211
macro_rules! test_spark_length_string {
200212
($INPUT:expr, $EXPECTED:expr) => {
@@ -279,4 +291,36 @@ mod tests {
279291

280292
Ok(())
281293
}
294+
295+
#[test]
296+
fn test_spark_length_nullability() -> Result<()> {
297+
let func = SparkLengthFunc::new();
298+
299+
let nullable_field: FieldRef = Arc::new(Field::new("col", DataType::Utf8, true));
300+
301+
let out_nullable = func.return_field_from_args(ReturnFieldArgs {
302+
arg_fields: &[nullable_field],
303+
scalar_arguments: &[None],
304+
})?;
305+
306+
assert!(
307+
out_nullable.is_nullable(),
308+
"length(col) should be nullable when child is nullable"
309+
);
310+
311+
let non_nullable_field: FieldRef =
312+
Arc::new(Field::new("col", DataType::Utf8, false));
313+
314+
let out_non_nullable = func.return_field_from_args(ReturnFieldArgs {
315+
arg_fields: &[non_nullable_field],
316+
scalar_arguments: &[None],
317+
})?;
318+
319+
assert!(
320+
!out_non_nullable.is_nullable(),
321+
"length(col) should NOT be nullable when child is NOT nullable"
322+
);
323+
324+
Ok(())
325+
}
282326
}

0 commit comments

Comments
 (0)