Skip to content

Commit adcecff

Browse files
committed
functions-aggregate: Support dictionary scalar coercion for min/max
1 parent 6e0dde0 commit adcecff

2 files changed

Lines changed: 115 additions & 6 deletions

File tree

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,31 @@ macro_rules! min_max {
413413
min_max_generic!(lhs, rhs, $OP)
414414
}
415415

416+
// Dictionary scalars: compare the inner values and re-wrap.
417+
(
418+
ScalarValue::Dictionary(key_type, lhs_inner),
419+
ScalarValue::Dictionary(_, rhs_inner),
420+
) => {
421+
let winner = min_max_generic!(lhs_inner.as_ref(), rhs_inner.as_ref(), $OP);
422+
ScalarValue::Dictionary(key_type.clone(), Box::new(winner))
423+
}
424+
425+
// Mixed Dictionary/non-Dictionary: unwrap the dict side and
426+
// compare, then return in whichever form won.
427+
(
428+
ScalarValue::Dictionary(_, lhs_inner),
429+
rhs,
430+
) => {
431+
min_max_generic!(lhs_inner.as_ref(), rhs, $OP)
432+
}
433+
434+
(
435+
lhs,
436+
ScalarValue::Dictionary(_, rhs_inner),
437+
) => {
438+
min_max_generic!(lhs, rhs_inner.as_ref(), $OP)
439+
}
440+
416441
e => {
417442
return internal_err!(
418443
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
@@ -766,9 +791,10 @@ pub fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
766791
DataType::FixedSizeList(_, _) => {
767792
min_max_batch_generic(values, Ordering::Greater)?
768793
}
769-
DataType::Dictionary(_, _) => {
770-
let values = values.as_any_dictionary().values();
771-
min_batch(values)?
794+
DataType::Dictionary(key_type, _) => {
795+
let dict_values = values.as_any_dictionary().values();
796+
let inner = min_batch(dict_values)?;
797+
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
772798
}
773799
_ => min_max_batch!(values, min),
774800
})
@@ -847,9 +873,10 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
847873
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
848874
DataType::LargeList(_) => min_max_batch_generic(values, Ordering::Less)?,
849875
DataType::FixedSizeList(_, _) => min_max_batch_generic(values, Ordering::Less)?,
850-
DataType::Dictionary(_, _) => {
851-
let values = values.as_any_dictionary().values();
852-
max_batch(values)?
876+
DataType::Dictionary(key_type, _) => {
877+
let dict_values = values.as_any_dictionary().values();
878+
let inner = max_batch(dict_values)?;
879+
ScalarValue::Dictionary(key_type.clone(), Box::new(inner))
853880
}
854881
_ => min_max_batch!(values, max),
855882
})

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,4 +1270,86 @@ mod tests {
12701270
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
12711271
Ok(())
12721272
}
1273+
1274+
fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue {
1275+
ScalarValue::Dictionary(Box::new(key_type), Box::new(inner))
1276+
}
1277+
1278+
#[test]
1279+
fn test_min_max_dictionary_without_coercion() -> Result<()> {
1280+
let values = StringArray::from(vec!["b", "c", "a", "d"]);
1281+
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]);
1282+
let dict_array =
1283+
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1284+
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1285+
1286+
// Pass the raw Dictionary type — no get_min_max_result_type() unwrap.
1287+
let dict_type = dict_array_ref.data_type().clone();
1288+
1289+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1290+
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1291+
let min_result = min_acc.evaluate()?;
1292+
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));
1293+
1294+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1295+
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1296+
let max_result = max_acc.evaluate()?;
1297+
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))));
1298+
Ok(())
1299+
}
1300+
1301+
#[test]
1302+
fn test_min_max_dictionary_with_nulls() -> Result<()> {
1303+
let values = StringArray::from(vec!["b", "c", "a"]);
1304+
let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]);
1305+
let dict_array =
1306+
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1307+
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1308+
1309+
let dict_type = dict_array_ref.data_type().clone();
1310+
1311+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1312+
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1313+
let min_result = min_acc.evaluate()?;
1314+
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));
1315+
1316+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1317+
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1318+
let max_result = max_acc.evaluate()?;
1319+
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("c".to_string()))));
1320+
Ok(())
1321+
}
1322+
1323+
#[test]
1324+
fn test_min_max_dictionary_multi_batch() -> Result<()> {
1325+
let dict_type =
1326+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1327+
1328+
// First batch.
1329+
let values1 = StringArray::from(vec!["b", "c"]);
1330+
let keys1 = Int32Array::from(vec![Some(0), Some(1)]);
1331+
let batch1 = Arc::new(
1332+
DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(),
1333+
) as ArrayRef;
1334+
1335+
// Second batch with a new min and max.
1336+
let values2 = StringArray::from(vec!["a", "d"]);
1337+
let keys2 = Int32Array::from(vec![Some(0), Some(1)]);
1338+
let batch2 = Arc::new(
1339+
DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(),
1340+
) as ArrayRef;
1341+
1342+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1343+
min_acc.update_batch(&[Arc::clone(&batch1)])?;
1344+
min_acc.update_batch(&[Arc::clone(&batch2)])?;
1345+
let min_result = min_acc.evaluate()?;
1346+
assert_eq!(min_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("a".to_string()))));
1347+
1348+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1349+
max_acc.update_batch(&[Arc::clone(&batch1)])?;
1350+
max_acc.update_batch(&[Arc::clone(&batch2)])?;
1351+
let max_result = max_acc.evaluate()?;
1352+
assert_eq!(max_result, dict_scalar(DataType::Int32, ScalarValue::Utf8(Some("d".to_string()))));
1353+
Ok(())
1354+
}
12731355
}

0 commit comments

Comments
 (0)