Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions datafusion/functions-aggregate-common/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,31 @@ macro_rules! min_max {
min_max_generic!(lhs, rhs, $OP)
}

// Dictionary scalars: compare the inner values and re-wrap.
(
ScalarValue::Dictionary(key_type, lhs_inner),
ScalarValue::Dictionary(_, rhs_inner),
) => {
let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP);
ScalarValue::Dictionary(key_type.clone(), Box::new(winner))
}

// Mixed Dictionary/non-Dictionary: unwrap the dict side and
// compare, then return in whichever form won.
(
ScalarValue::Dictionary(_, lhs_inner),
rhs,
) => {
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
}

(
lhs,
ScalarValue::Dictionary(_, rhs_inner),
) => {
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
}

e => {
return internal_err!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Expand Down Expand Up @@ -766,9 +791,10 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::FixedSizeList(_, _) => {
min_max_batch_generic(values, Ordering::Greater)?
}
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
min_batch(values)?
DataType::Dictionary(key_type, _) => {
let dict_values = values.as_any_dictionary().values();
let inner = min_batch(dict_values)?;
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
}
_ => min_max_batch!(values, min),
})
Expand Down Expand Up @@ -847,9 +873,10 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
DataType::Dictionary(_, _) => {
let values = values.as_any_dictionary().values();
max_batch(values)?
DataType::Dictionary(key_type, _) => {
let dict_values = values.as_any_dictionary().values();
let inner = max_batch(dict_values)?;
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
}
_ => min_max_batch!(values, max),
})
Expand Down
82 changes: 82 additions & 0 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,4 +1270,86 @@ mod tests {
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
Ok(())
}

fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue {
ScalarValue::Dictionary(Box::new(key_type), Box::new(inner))
}

#[test]
fn test_min_max_dictionary_without_coercion() -> Result<()> {
let values = StringArray::from(vec!["b", "c", "a", "d"]);
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]);
let dict_array =
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
let dict_array_ref = Arc::new(dict_array) as ArrayRef;

// Pass the raw Dictionary type — no get_min_max_result_type() unwrap.
let dict_type = dict_array_ref.data_type().clone();

let mut min_acc = MinAccumulator::try_new(&dict_type)?;
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let min_result = min_acc.evaluate()?;
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));

let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let max_result = max_acc.evaluate()?;
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))));
Ok(())
}

#[test]
fn test_min_max_dictionary_with_nulls() -> Result<()> {
let values = StringArray::from(vec!["b", "c", "a"]);
let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]);
let dict_array =
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
let dict_array_ref = Arc::new(dict_array) as ArrayRef;

let dict_type = dict_array_ref.data_type().clone();

let mut min_acc = MinAccumulator::try_new(&dict_type)?;
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let min_result = min_acc.evaluate()?;
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));

let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
let max_result = max_acc.evaluate()?;
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string()))));
Ok(())
}

#[test]
fn test_min_max_dictionary_multi_batch() -> Result<()> {
let dict_type =
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

// First batch.
let values1 = StringArray::from(vec!["b", "c"]);
let keys1 = Int32Array::from(vec![Some(0), Some(1)]);
let batch1 = Arc::new(
DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(),
) as ArrayRef;

// Second batch with a new min and max.
let values2 = StringArray::from(vec!["a", "d"]);
let keys2 = Int32Array::from(vec![Some(0), Some(1)]);
let batch2 = Arc::new(
DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(),
) as ArrayRef;

let mut min_acc = MinAccumulator::try_new(&dict_type)?;
min_acc.update_batch(&[Arc::clone(&batch1)])?;
min_acc.update_batch(&[Arc::clone(&batch2)])?;
let min_result = min_acc.evaluate()?;
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));

let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
max_acc.update_batch(&[Arc::clone(&batch1)])?;
max_acc.update_batch(&[Arc::clone(&batch2)])?;
let max_result = max_acc.evaluate()?;
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))));
Ok(())
}
}
Loading