From 8a8b7c17df065e06e97aa06cf659ff588b305415 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 25 Dec 2025 18:09:36 -0500 Subject: [PATCH] leaky decimal fixes not implemented Signed-off-by: Connor Tsui --- vortex-array/src/arrays/decimal/array.rs | 18 ++- .../src/arrays/decimal/compute/fill_null.rs | 10 +- vortex-array/src/compute/fill_null.rs | 40 +++++ vortex-array/src/validity.rs | 137 ++++++++++++++---- 4 files changed, 172 insertions(+), 33 deletions(-) diff --git a/vortex-array/src/arrays/decimal/array.rs b/vortex-array/src/arrays/decimal/array.rs index 29671af0808..73182fe5635 100644 --- a/vortex-array/src/arrays/decimal/array.rs +++ b/vortex-array/src/arrays/decimal/array.rs @@ -120,7 +120,7 @@ impl DecimalArray { decimal_dtype: DecimalDType, validity: Validity, ) -> VortexResult { - Self::validate(&buffer, &validity)?; + Self::validate(&buffer, decimal_dtype, &validity)?; // SAFETY: validate ensures all invariants are met. Ok(unsafe { Self::new_unchecked(buffer, decimal_dtype, validity) }) @@ -136,8 +136,10 @@ impl DecimalArray { /// /// The caller must ensure all of the following invariants are satisfied: /// + /// - The storage type `T` must be compatible with the precision (i.e., able to represent all + /// values of the declared precision). /// - All non-null values in `buffer` must be representable within the specified precision. - /// - For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99]. + /// For example, with precision=5 and scale=2, all values must be in range [-999.99, 999.99]. /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`. pub unsafe fn new_unchecked( buffer: Buffer, @@ -145,7 +147,7 @@ impl DecimalArray { validity: Validity, ) -> Self { #[cfg(debug_assertions)] - Self::validate(&buffer, &validity) + Self::validate(&buffer, decimal_dtype, &validity) .vortex_expect("[Debug Assertion]: Invalid `DecimalArray` parameters"); Self { @@ -162,8 +164,18 @@ impl DecimalArray { /// This function checks all the invariants required by [`DecimalArray::new_unchecked`]. pub fn validate( buffer: &Buffer, + // TODO(connor): The decimal array storage type should be able to represent the entire + // domain of the decimal type. + _decimal_dtype: DecimalDType, validity: &Validity, ) -> VortexResult<()> { + // vortex_ensure!( + // T::DECIMAL_TYPE.is_compatible_decimal_value_type(decimal_dtype), + // "Storage type {:?} cannot represent all values of precision {}", + // T::DECIMAL_TYPE, + // decimal_dtype.precision() + // ); + if let Some(len) = validity.maybe_len() { vortex_ensure!( buffer.len() == len, diff --git a/vortex-array/src/arrays/decimal/compute/fill_null.rs b/vortex-array/src/arrays/decimal/compute/fill_null.rs index 0f17b9aa5c5..8d8016886b6 100644 --- a/vortex-array/src/arrays/decimal/compute/fill_null.rs +++ b/vortex-array/src/arrays/decimal/compute/fill_null.rs @@ -28,11 +28,13 @@ impl FillNullKernel for DecimalVTable { let is_invalid = is_valid.to_bool().bit_buffer().not(); match_each_decimal_value_type!(array.values_type(), |T| { let mut buffer = array.buffer::().into_mut(); - let fill_value = fill_value - .as_decimal() + let decimal_scalar = fill_value.as_decimal(); + let decimal_value = decimal_scalar .decimal_value() - .and_then(|v| v.cast::()) - .vortex_expect("top-level fill_null ensure non-null fill value"); + .vortex_expect("fill_null requires a non-null fill value"); + let fill_value = decimal_value + .cast::() + .vortex_expect("fill value does not fit in array's decimal storage type"); for invalid_index in is_invalid.set_indices() { buffer[invalid_index] = fill_value; } diff --git a/vortex-array/src/compute/fill_null.rs b/vortex-array/src/compute/fill_null.rs index 3c5379116d3..1e6f1ea95b4 100644 --- a/vortex-array/src/compute/fill_null.rs +++ b/vortex-array/src/compute/fill_null.rs @@ -5,7 +5,10 @@ use std::sync::LazyLock; use arcref::ArcRef; use vortex_dtype::DType; +use vortex_dtype::DecimalType; +use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexError; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; @@ -15,6 +18,7 @@ use crate::Array; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::ConstantArray; +use crate::arrays::DecimalVTable; use crate::compute::ComputeFn; use crate::compute::ComputeFnVTable; use crate::compute::InvocationArgs; @@ -59,6 +63,14 @@ pub fn fill_null(array: &dyn Array, fill_value: &Scalar) -> VortexResult VortexResult; } @@ -110,6 +122,34 @@ impl ComputeFnVTable for FillNull { vortex_bail!("Cannot fill_null with a null value") } + /* + // For decimal arrays, validate that the fill value fits in the storage type. + if let Some(decimal_dtype) = array.dtype().as_decimal_opt() { + // Try to get the actual storage type from a DecimalArray. Otherwise, use the smallest + // type that can represent the precision. + let storage_type = array + .as_opt::() + .map(|arr| arr.values_type()) + .unwrap_or_else(|| DecimalType::smallest_decimal_value_type(decimal_dtype)); + let decimal_value = fill_value + .as_decimal() + .decimal_value() + .vortex_expect("fill_null checked is_null above"); + + let fits = match_each_decimal_value_type!(storage_type, |T| { + decimal_value.cast::().is_some() + }); + + if !fits { + vortex_bail!( + "fill value {} does not fit in array's decimal storage type {:?}", + decimal_value, + storage_type + ) + } + } + */ + for kernel in kernels { if let Some(output) = kernel.invoke(args)? { return Ok(output); diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 5b224a7278e..f1058d8423e 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -272,37 +272,47 @@ impl Validity { indices: &dyn Array, patches: &Validity, ) -> Self { + use Validity::*; + match (&self, patches) { - (Validity::NonNullable, Validity::NonNullable) => return Validity::NonNullable, - (Validity::NonNullable, _) => { - vortex_panic!("Can't patch a non-nullable validity with nullable validity") + (NonNullable, NonNullable | AllValid) => { + return NonNullable; + } + (NonNullable, Array(_) | AllInvalid) => { + vortex_panic!("Can't patch a non-nullable validity with null values") } - (_, Validity::NonNullable) => { - vortex_panic!("Can't patch a nullable validity with non-nullable validity") + + (AllValid | Array(_) | AllInvalid, NonNullable) => { + vortex_panic!("Can't patch a nullable validity with a non-nullable validity") } - (Validity::AllValid, Validity::AllValid) => return Validity::AllValid, - (Validity::AllInvalid, Validity::AllInvalid) => return Validity::AllInvalid, - _ => {} + + (AllValid, AllValid) => return AllValid, + (AllValid, Array(_) | AllInvalid) => {} + + (AllInvalid, AllInvalid) => return AllInvalid, + (AllInvalid, AllValid | Array(_)) => {} + + (Array(_), _) => {} }; - let own_nullability = if self == Validity::NonNullable { + let own_nullability = if self == NonNullable { Nullability::NonNullable } else { Nullability::Nullable }; let source = match self { - Validity::NonNullable => BoolArray::from(BitBuffer::new_set(len)), - Validity::AllValid => BoolArray::from(BitBuffer::new_set(len)), - Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(len)), - Validity::Array(a) => a.to_bool(), + NonNullable => BoolArray::from(BitBuffer::new_set(len)), + AllValid => BoolArray::from(BitBuffer::new_set(len)), + AllInvalid => BoolArray::from(BitBuffer::new_unset(len)), + Array(a) => a.to_bool(), }; let patch_values = match patches { - Validity::NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())), - Validity::AllValid => BoolArray::from(BitBuffer::new_set(indices.len())), - Validity::AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())), - Validity::Array(a) => a.to_bool(), + NonNullable => BoolArray::from(BitBuffer::new_set(indices.len())), + AllValid => BoolArray::from(BitBuffer::new_set(indices.len())), + AllInvalid => BoolArray::from(BitBuffer::new_unset(indices.len())), + Array(a) => a.to_bool(), }; let patches = Patches::new( @@ -513,21 +523,96 @@ mod tests { use crate::validity::Validity; #[rstest] - #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)] - #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()) + #[case( + Validity::AllValid, + 5, + &[2, 4], + Validity::AllValid, + Validity::AllValid + )] + #[case( + Validity::AllValid, + 5, + &[2, 4], + Validity::AllInvalid, + Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()) + )] + #[case( + Validity::AllValid, + 5, + &[2, 4], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()) + )] + #[case( + Validity::AllInvalid, + 5, + &[2, 4], + Validity::AllValid, + Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()) + )] + #[case( + Validity::AllInvalid, + 5, + &[2, 4], + Validity::AllInvalid, + Validity::AllInvalid + )] + #[case( + Validity::AllInvalid, + 5, + &[2, 4], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array()) + )] + #[case( + Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), + 5, + &[2, 4], + Validity::AllValid, + Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()) )] - #[case(Validity::AllValid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()) + #[case( + Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), + 5, + &[2, 4], + Validity::AllInvalid, + Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()) )] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()) + #[case( + Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), + 5, + &[2, 4], + Validity::Array(BoolArray::from_iter([true, false]).into_array()), + Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array()) )] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::AllInvalid, Validity::AllInvalid)] - #[case(Validity::AllInvalid, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, false, true, false, false]).into_array()) + #[case( + Validity::NonNullable, + 5, + &[2, 4], + Validity::AllValid, + Validity::NonNullable )] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllValid, Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()) + #[case( + Validity::AllValid, + 5, + &[2, 4], + Validity::NonNullable, + Validity::AllValid )] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()) + #[case( + Validity::AllInvalid, + 5, + &[2, 4], + Validity::NonNullable, + Validity::Array(BoolArray::from_iter([false, false, true, false, true]).into_array()) )] - #[case(Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([false, true, true, true, false]).into_array()) + #[case( + Validity::Array(BoolArray::from_iter([false, true, false, true, false]).into_array()), + 5, + &[2, 4], + Validity::NonNullable, + Validity::Array(BoolArray::from_iter([false, true, true, true, true]).into_array()) )] fn patch_validity( #[case] validity: Validity,