Skip to content

Commit 9cacefa

Browse files
committed
functions-aggregate: Support dictionary scalar coercion for min/max
1 parent 6e0dde0 commit 9cacefa

2 files changed

Lines changed: 104 additions & 4 deletions

File tree

datafusion/functions-aggregate-common/src/min_max.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,15 @@ pub struct MaxAccumulator {
432432
impl MaxAccumulator {
433433
/// new max accumulator
434434
pub fn try_new(datatype: &DataType) -> Result<Self> {
435+
// Unwrap dictionary types to match what max_batch() returns: it
436+
// recurses into the dictionary's values array, so the resulting
437+
// ScalarValue is always the value type, not the dictionary type.
438+
let dt = match datatype {
439+
DataType::Dictionary(_, v) => v.as_ref(),
440+
other => other,
441+
};
435442
Ok(Self {
436-
max: ScalarValue::try_from(datatype)?,
443+
max: ScalarValue::try_from(dt)?,
437444
})
438445
}
439446
}
@@ -473,8 +480,15 @@ pub struct MinAccumulator {
473480
impl MinAccumulator {
474481
/// new min accumulator
475482
pub fn try_new(datatype: &DataType) -> Result<Self> {
483+
// Unwrap dictionary types to match what min_batch() returns: it
484+
// recurses into the dictionary's values array, so the resulting
485+
// ScalarValue is always the value type, not the dictionary type.
486+
let dt = match datatype {
487+
DataType::Dictionary(_, v) => v.as_ref(),
488+
other => other,
489+
};
476490
Ok(Self {
477-
min: ScalarValue::try_from(datatype)?,
491+
min: ScalarValue::try_from(dt)?,
478492
})
479493
}
480494
}

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,12 @@ pub struct SlidingMaxAccumulator {
392392
impl SlidingMaxAccumulator {
393393
/// new max accumulator
394394
pub fn try_new(datatype: &DataType) -> Result<Self> {
395+
let dt = match datatype {
396+
DataType::Dictionary(_, v) => v.as_ref(),
397+
other => other,
398+
};
395399
Ok(Self {
396-
max: ScalarValue::try_from(datatype)?,
400+
max: ScalarValue::try_from(dt)?,
397401
moving_max: MovingMax::<ScalarValue>::new(),
398402
})
399403
}
@@ -679,8 +683,12 @@ pub struct SlidingMinAccumulator {
679683

680684
impl SlidingMinAccumulator {
681685
pub fn try_new(datatype: &DataType) -> Result<Self> {
686+
let dt = match datatype {
687+
DataType::Dictionary(_, v) => v.as_ref(),
688+
other => other,
689+
};
682690
Ok(Self {
683-
min: ScalarValue::try_from(datatype)?,
691+
min: ScalarValue::try_from(dt)?,
684692
moving_min: MovingMin::<ScalarValue>::new(),
685693
})
686694
}
@@ -1270,4 +1278,82 @@ mod tests {
12701278
assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string())));
12711279
Ok(())
12721280
}
1281+
1282+
#[test]
1283+
fn test_min_max_dictionary_without_coercion() -> Result<()> {
1284+
let values = StringArray::from(vec!["b", "c", "a", "d"]);
1285+
let keys = Int32Array::from(vec![Some(0), Some(1), Some(2), Some(3)]);
1286+
let dict_array =
1287+
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1288+
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1289+
1290+
// Pass the raw Dictionary type — no get_min_max_result_type() unwrap.
1291+
let dict_type = dict_array_ref.data_type().clone();
1292+
1293+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1294+
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1295+
let min_result = min_acc.evaluate()?;
1296+
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1297+
1298+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1299+
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1300+
let max_result = max_acc.evaluate()?;
1301+
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
1302+
Ok(())
1303+
}
1304+
1305+
#[test]
1306+
fn test_min_max_dictionary_with_nulls() -> Result<()> {
1307+
let values = StringArray::from(vec!["b", "c", "a"]);
1308+
let keys = Int32Array::from(vec![None, Some(0), None, Some(1), Some(2)]);
1309+
let dict_array =
1310+
DictionaryArray::try_new(keys, Arc::new(values) as ArrayRef).unwrap();
1311+
let dict_array_ref = Arc::new(dict_array) as ArrayRef;
1312+
1313+
let dict_type = dict_array_ref.data_type().clone();
1314+
1315+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1316+
min_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1317+
let min_result = min_acc.evaluate()?;
1318+
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1319+
1320+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1321+
max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?;
1322+
let max_result = max_acc.evaluate()?;
1323+
assert_eq!(max_result, ScalarValue::Utf8(Some("c".to_string())));
1324+
Ok(())
1325+
}
1326+
1327+
#[test]
1328+
fn test_min_max_dictionary_multi_batch() -> Result<()> {
1329+
let dict_type =
1330+
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
1331+
1332+
// First batch.
1333+
let values1 = StringArray::from(vec!["b", "c"]);
1334+
let keys1 = Int32Array::from(vec![Some(0), Some(1)]);
1335+
let batch1 = Arc::new(
1336+
DictionaryArray::try_new(keys1, Arc::new(values1) as ArrayRef).unwrap(),
1337+
) as ArrayRef;
1338+
1339+
// Second batch with a new min and max.
1340+
let values2 = StringArray::from(vec!["a", "d"]);
1341+
let keys2 = Int32Array::from(vec![Some(0), Some(1)]);
1342+
let batch2 = Arc::new(
1343+
DictionaryArray::try_new(keys2, Arc::new(values2) as ArrayRef).unwrap(),
1344+
) as ArrayRef;
1345+
1346+
let mut min_acc = MinAccumulator::try_new(&dict_type)?;
1347+
min_acc.update_batch(&[Arc::clone(&batch1)])?;
1348+
min_acc.update_batch(&[Arc::clone(&batch2)])?;
1349+
let min_result = min_acc.evaluate()?;
1350+
assert_eq!(min_result, ScalarValue::Utf8(Some("a".to_string())));
1351+
1352+
let mut max_acc = MaxAccumulator::try_new(&dict_type)?;
1353+
max_acc.update_batch(&[Arc::clone(&batch1)])?;
1354+
max_acc.update_batch(&[Arc::clone(&batch2)])?;
1355+
let max_result = max_acc.evaluate()?;
1356+
assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string())));
1357+
Ok(())
1358+
}
12731359
}

0 commit comments

Comments
 (0)