Skip to content

Commit fc6d0a4

Browse files
Fix shuffle function to report nullability correctly (#19184)
# fix: shuffle should report nullability correctly

- Replace `return_type` with `return_field_from_args` to preserve input nullability
- Add a test to verify that nullability is reported correctly
- Addresses issue #19145

## Which issue does this PR close?

Closes #19145

## Rationale for this change

The `shuffle` UDF was using the default `is_nullable` implementation, which always returns `true` regardless of the input array's nullability. This causes:

1. Incorrect schema inference: non-nullable inputs are incorrectly marked as nullable.
2. Missed optimization opportunities: the query optimizer cannot apply certain optimizations when nullability information is incorrect.
3. Potential runtime errors: incorrect metadata can lead to unexpected behavior in downstream operations.

The shuffle function simply reorders elements within an array without changing the array's structure or nullability, so the output should have the same nullability as the input.

## What changes are included in this PR?

1. **Implemented `return_field_from_args`**: returns the input field directly, preserving both data type and nullability.
2. **Updated `return_type`**: now returns an error directing users to use `return_field_from_args` instead (following DataFusion best practices).
3. **Added a test covering both cases**: verifies that nullable and non-nullable inputs are each handled correctly.

## Are these changes tested?

Yes, this PR includes a new test, `test_shuffle_nullability`, that verifies:

- A non-nullable array input produces a non-nullable output.
- A nullable array input produces a nullable output.
- Data types are preserved correctly in both cases.

Test results:
1 parent cb2f3d2 commit fc6d0a4

File tree

1 file changed

+61
-3
lines changed

1 file changed

+61
-3
lines changed

datafusion/spark/src/function/array/shuffle.rs

Lines changed: 61 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,9 @@ use arrow::datatypes::FieldRef;
2626
use datafusion_common::cast::{
2727
as_fixed_size_list_array, as_large_list_array, as_list_array,
2828
};
29-
use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue};
29+
use datafusion_common::{
30+
exec_err, internal_err, utils::take_function_args, Result, ScalarValue,
31+
};
3032
use datafusion_expr::{
3133
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
3234
Signature, TypeSignature, Volatility,
@@ -87,8 +89,16 @@ impl ScalarUDFImpl for SparkShuffle {
8789
&self.signature
8890
}
8991

90-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
91-
Ok(arg_types[0].clone())
92+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93+
internal_err!("return_field_from_args should be used instead")
94+
}
95+
96+
fn return_field_from_args(
97+
&self,
98+
args: datafusion_expr::ReturnFieldArgs,
99+
) -> Result<FieldRef> {
100+
// Shuffle returns an array with the same type and nullability as the input
101+
Ok(Arc::clone(&args.arg_fields[0]))
92102
}
93103

94104
fn invoke_with_args(
@@ -263,3 +273,51 @@ fn fixed_size_array_shuffle(
263273
Some(nulls.into()),
264274
)?))
265275
}
276+
277+
#[cfg(test)]
278+
mod tests {
279+
use super::*;
280+
use arrow::datatypes::Field;
281+
use datafusion_expr::ReturnFieldArgs;
282+
283+
#[test]
284+
fn test_shuffle_nullability() {
285+
let shuffle = SparkShuffle::new();
286+
287+
// Test with non-nullable array
288+
let non_nullable_field = Arc::new(Field::new(
289+
"arr",
290+
List(Arc::new(Field::new("item", DataType::Int32, true))),
291+
false, // not nullable
292+
));
293+
294+
let result = shuffle
295+
.return_field_from_args(ReturnFieldArgs {
296+
arg_fields: &[Arc::clone(&non_nullable_field)],
297+
scalar_arguments: &[None],
298+
})
299+
.unwrap();
300+
301+
// The result should not be nullable (same as input)
302+
assert!(!result.is_nullable());
303+
assert_eq!(result.data_type(), non_nullable_field.data_type());
304+
305+
// Test with nullable array
306+
let nullable_field = Arc::new(Field::new(
307+
"arr",
308+
List(Arc::new(Field::new("item", DataType::Int32, true))),
309+
true, // nullable
310+
));
311+
312+
let result = shuffle
313+
.return_field_from_args(ReturnFieldArgs {
314+
arg_fields: &[Arc::clone(&nullable_field)],
315+
scalar_arguments: &[None],
316+
})
317+
.unwrap();
318+
319+
// The result should be nullable (same as input)
320+
assert!(result.is_nullable());
321+
assert_eq!(result.data_type(), nullable_field.data_type());
322+
}
323+
}

0 commit comments

Comments
 (0)