diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs
index 27620221cf23c..15a1725a28182 100644
--- a/datafusion/functions-aggregate-common/src/min_max.rs
+++ b/datafusion/functions-aggregate-common/src/min_max.rs
@@ -413,6 +413,31 @@ macro_rules! min_max {
             min_max_generic!(lhs, rhs, $OP)
         }
 
+        // Dictionary scalars: compare the inner values and re-wrap.
+        (
+            ScalarValue::Dictionary(key_type, lhs_inner),
+            ScalarValue::Dictionary(_, rhs_inner),
+        ) => {
+            let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP);
+            ScalarValue::Dictionary(key_type.clone(), Box::new(winner))
+        }
+
+        // Mixed Dictionary/non-Dictionary: unwrap the dict side and
+        // compare, then return in whichever form won.
+        (
+            ScalarValue::Dictionary(_, lhs_inner),
+            rhs,
+        ) => {
+            min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
+        }
+
+        (
+            lhs,
+            ScalarValue::Dictionary(_, rhs_inner),
+        ) => {
+            min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
+        }
+
         e => {
             return internal_err!(
                 "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
@@ -766,9 +791,10 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
         DataType::FixedSizeList(_, _) => {
             min_max_batch_generic(values, Ordering::Greater)?
         }
-        DataType::Dictionary(_, _) => {
-            let values = values.as_any_dictionary().values();
-            min_batch(values)?
+        DataType::Dictionary(key_type, _) => {
+            let dict_values = values.as_any_dictionary().values();
+            let inner = min_batch(dict_values)?;
+            ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
         }
         _ => min_max_batch!(values, min),
     })
@@ -847,9 +873,10 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
         DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
         DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
         DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
-        DataType::Dictionary(_, _) => {
-            let values = values.as_any_dictionary().values();
-            max_batch(values)?
+        DataType::Dictionary(key_type, _) => {
+            let dict_values = values.as_any_dictionary().values();
+            let inner = max_batch(dict_values)?;
+            ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
         }
         _ => min_max_batch!(values, max),
     })
diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs
index ee8f7ac753318..c5fb9f666bea2 100644
--- a/datafusion/functions-aggregate/src/min_max.rs
+++ b/datafusion/functions-aggregate/src/min_max.rs
@@ -1270,4 +1270,104 @@ mod tests {
         assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
         Ok(())
     }
+
+    fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue {
+        ScalarValue::Dictionary(Box::new(key_type), Box::new(inner))
+    }
+
+    #[test]
+    fn test_min_max_dictionary_without_coercion() -> Result<()> {
+        let values = StringArray::from(vec!["b", "c", "a", "d"]);
+        let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]);
+        let dict_array =
+            DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
+        let dict_array_ref = Arc::new(dict_array) as ArrayRef;
+
+        // Pass the raw Dictionary type — no get_min_max_result_type() unwrap.
+        let dict_type = dict_array_ref.data_type().clone();
+
+        let mut min_acc = MinAccumulator::try_new(&dict_type)?;
+        min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
+        let min_result = min_acc.evaluate()?;
+        assert_eq!(
+            min_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
+        );
+
+        let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
+        max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
+        let max_result = max_acc.evaluate()?;
+        assert_eq!(
+            max_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_min_max_dictionary_with_nulls() -> Result<()> {
+        let values = StringArray::from(vec!["b", "c", "a"]);
+        let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]);
+        let dict_array =
+            DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
+        let dict_array_ref = Arc::new(dict_array) as ArrayRef;
+
+        let dict_type = dict_array_ref.data_type().clone();
+
+        let mut min_acc = MinAccumulator::try_new(&dict_type)?;
+        min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
+        let min_result = min_acc.evaluate()?;
+        assert_eq!(
+            min_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
+        );
+
+        let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
+        max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
+        let max_result = max_acc.evaluate()?;
+        assert_eq!(
+            max_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string())))
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_min_max_dictionary_multi_batch() -> Result<()> {
+        let dict_type =
+            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+
+        // First batch.
+        let values1 = StringArray::from(vec!["b", "c"]);
+        let keys1 = Int32Array::from(vec![Some(0), Some(1)]);
+        let batch1 = Arc::new(
+            DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(),
+        ) as ArrayRef;
+
+        // Second batch with a new min and max.
+        let values2 = StringArray::from(vec!["a", "d"]);
+        let keys2 = Int32Array::from(vec![Some(0), Some(1)]);
+        let batch2 = Arc::new(
+            DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(),
+        ) as ArrayRef;
+
+        let mut min_acc = MinAccumulator::try_new(&dict_type)?;
+        min_acc.update_batch(&[Arc::clone(&batch1)])?;
+        min_acc.update_batch(&[Arc::clone(&batch2)])?;
+        let min_result = min_acc.evaluate()?;
+        assert_eq!(
+            min_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string())))
+        );
+
+        let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
+        max_acc.update_batch(&[Arc::clone(&batch1)])?;
+        max_acc.update_batch(&[Arc::clone(&batch2)])?;
+        let max_result = max_acc.evaluate()?;
+        assert_eq!(
+            max_result,
+            dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string())))
+        );
+        Ok(())
+    }
 }