diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index dc3e6dcbd388c..8cd5627b84a8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -4096,8 +4096,14 @@ trait ArraySetLike { case _ => false } - @transient protected lazy val ordering: Ordering[Any] = - TypeUtils.getInterpretedOrdering(et) + // If the element type supports proper equals, we use the values directly for comparison, + // otherwise we use the generic comparable wrapper so all types support hash-based operations + @transient protected lazy val keyGenerator: (Any => Any) = + if (TypeUtils.typeWithProperEquals(et)) { + identity + } else { + GenericComparableWrapper.getGenericComparableWrapperFactory(et) + } protected def resultArrayElementNullable = dt.asInstanceOf[ArrayType].containsNull @@ -4203,62 +4209,32 @@ case class ArrayDistinct(child: Expression) } } - override def nullSafeEval(array: Any): Any = { - val data = array.asInstanceOf[ArrayData] - doEvaluation(data) - } - - @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) { - (array: ArrayData) => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new SQLOpenHashSet[Any]() - val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( - prettyName, arrayBuffer.size) - } - arrayBuffer += value - hs.add(value) - }, - (valueNaN: Any) => arrayBuffer += valueNaN) - val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, - (value: Any) => 
withNaNCheckFunc(value), - () => arrayBuffer += null) - var i = 0 - while (i < array.numElements()) { - withNullCheckFunc(array, i) - i += 1 - } - new GenericArrayData(arrayBuffer) - } else { - (data: ArrayData) => { - val array = data.toArray[AnyRef](elementType) - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[AnyRef] - var alreadyStoredNull = false - for (i <- array.indices) { - if (array(i) != null) { - var found = false - var j = 0 - while (!found && j < arrayBuffer.size) { - val va = arrayBuffer(j) - found = (va != null) && ordering.equiv(va, array(i)) - j += 1 - } - if (!found) { - arrayBuffer += array(i) - } - } else { - // De-duplicate the null values. - if (!alreadyStoredNull) { - arrayBuffer += array(i) - alreadyStoredNull = true + override def nullSafeEval(input: Any): Any = { + val array = input.asInstanceOf[ArrayData] + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new SQLOpenHashSet[Any]() + val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => { + val key = keyGenerator(value) + if (!hs.contains(key)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, arrayBuffer.size) } + arrayBuffer += value + hs.add(key) } - } - new GenericArrayData(arrayBuffer) + }, + (valueNaN: Any) => arrayBuffer += valueNaN) + val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, + (value: Any) => withNaNCheckFunc(value), + () => arrayBuffer += null) + var i = 0 + while (i < array.numElements()) { + withNullCheckFunc(array, i) + i += 1 } + new GenericArrayData(arrayBuffer) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -4381,74 +4357,37 @@ trait ArrayBinaryLike case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLike with ComplexTypeMergingExpression { - @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData 
= { - if (TypeUtils.typeWithProperEquals(elementType)) { - (array1, array2) => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new SQLOpenHashSet[Any]() - val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( - prettyName, arrayBuffer.size) - } - arrayBuffer += value - hs.add(value) - }, - (valueNaN: Any) => arrayBuffer += valueNaN) - val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, - (value: Any) => withNaNCheckFunc(value), - () => arrayBuffer += null - ) - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - withNullCheckFunc(array, i) - i += 1 - } - } - new GenericArrayData(arrayBuffer) - } else { - (array1, array2) => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadyIncludeNull = false - Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { - var found = false - if (elem == null) { - if (alreadyIncludeNull) { - found = true - } else { - alreadyIncludeNull = true - } - } else { - // check elem is already stored in arrayBuffer or not? 
- var j = 0 - while (!found && j < arrayBuffer.size) { - val va = arrayBuffer(j) - if (va != null && ordering.equiv(va, elem)) { - found = true - } - j = j + 1 - } - } - if (!found) { - if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( - prettyName, arrayBuffer.length) - } - arrayBuffer += elem - } - })) - new GenericArrayData(arrayBuffer) - } - } - override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - evalUnion(array1, array2) + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new SQLOpenHashSet[Any]() + val withNaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => { + val key = keyGenerator(value) + if (!hs.contains(key)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + throw QueryExecutionErrors.arrayFunctionWithElementsExceedLimitError( + prettyName, arrayBuffer.size) + } + arrayBuffer += value + hs.add(key) + } + }, + (valueNaN: Any) => arrayBuffer += valueNaN) + val withNullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, + (value: Any) => withNaNCheckFunc(value), + () => arrayBuffer += null + ) + Seq(array1, array2).foreach { array => + var i = 0 + while (i < array.numElements()) { + withNullCheckFunc(array, i) + i += 1 + } + } + new GenericArrayData(arrayBuffer) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -4565,110 +4504,57 @@ case class ArrayIntersect(left: Expression, right: Expression) extends ArrayBina override def dataType: DataType = internalDataType - @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { - if (TypeUtils.typeWithProperEquals(elementType)) { - (array1, array2) => - if (array1.numElements() != 0 && array2.numElements() != 0) { - val hs = new SQLOpenHashSet[Any] - val hsResult = new SQLOpenHashSet[Any] - val 
arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => hs.add(value), - (valueNaN: Any) => {} ) - val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, - (value: Any) => withArray2NaNCheckFunc(value), - () => {} - ) - val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult, - (value: Any) => - if (hs.contains(value) && !hsResult.contains(value)) { - arrayBuffer += value - hsResult.add(value) - }, - (valueNaN: Any) => - if (hs.containsNaN()) { - arrayBuffer += valueNaN - }) - val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hsResult, - (value: Any) => withArray1NaNCheckFunc(value), - () => - if (hs.containsNull()) { - arrayBuffer += null - } - ) + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] - var i = 0 - while (i < array2.numElements()) { - withArray2NullCheckFunc(array2, i) - i += 1 + if (array1.numElements() != 0 && array2.numElements() != 0) { + val hs = new SQLOpenHashSet[Any] + val hsResult = new SQLOpenHashSet[Any] + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => hs.add(keyGenerator(value)), + (valueNaN: Any) => {} ) + val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, + (value: Any) => withArray2NaNCheckFunc(value), + () => {} + ) + val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hsResult, + (value: Any) => { + val key = keyGenerator(value) + if (hs.contains(key) && !hsResult.contains(key)) { + arrayBuffer += value + hsResult.add(key) } - i = 0 - while (i < array1.numElements()) { - withArray1NullCheckFunc(array1, i) - i += 1 + }, + (valueNaN: Any) => + if (hs.containsNaN()) { + arrayBuffer += valueNaN + 
}) + val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hsResult, + (value: Any) => withArray1NaNCheckFunc(value), + () => + if (hs.containsNull()) { + arrayBuffer += null } - new GenericArrayData(arrayBuffer) - } else { - new GenericArrayData(Array.emptyObjectArray) - } + ) + + var i = 0 + while (i < array2.numElements()) { + withArray2NullCheckFunc(array2, i) + i += 1 + } + i = 0 + while (i < array1.numElements()) { + withArray1NullCheckFunc(array1, i) + i += 1 + } + new GenericArrayData(arrayBuffer) } else { - (array1, array2) => - if (array1.numElements() != 0 && array2.numElements() != 0) { - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var alreadySeenNull = false - var i = 0 - while (i < array1.numElements()) { - var found = false - val elem1 = array1.get(i, elementType) - if (array1.isNullAt(i)) { - if (!alreadySeenNull) { - var j = 0 - while (!found && j < array2.numElements()) { - found = array2.isNullAt(j) - j += 1 - } - // array2 is scanned only once for null element - alreadySeenNull = true - } - } else { - var j = 0 - while (!found && j < array2.numElements()) { - if (!array2.isNullAt(j)) { - val elem2 = array2.get(j, elementType) - if (ordering.equiv(elem1, elem2)) { - // check whether elem1 is already stored in arrayBuffer - var foundArrayBuffer = false - var k = 0 - while (!foundArrayBuffer && k < arrayBuffer.size) { - val va = arrayBuffer(k) - foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) - k += 1 - } - found = !foundArrayBuffer - } - } - j += 1 - } - } - if (found) { - arrayBuffer += elem1 - } - i += 1 - } - new GenericArrayData(arrayBuffer) - } else { - new GenericArrayData(Array.emptyObjectArray) - } + new GenericArrayData(Array.emptyObjectArray) } } - override def nullSafeEval(input1: Any, input2: Any): Any = { - val array1 = input1.asInstanceOf[ArrayData] - val array2 = input2.asInstanceOf[ArrayData] - - evalIntersect(array1, array2) - } - override def doGenCode(ctx: CodegenContext, 
ev: ExprCode): ExprCode = { val i = ctx.freshName("i") val value = ctx.freshName("value") @@ -4797,93 +4683,43 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL override def dataType: DataType = internalDataType - @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { - if (TypeUtils.typeWithProperEquals(elementType)) { - (array1, array2) => - val hs = new SQLOpenHashSet[Any] - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => hs.add(value), - (valueNaN: Any) => {}) - val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, - (value: Any) => withArray2NaNCheckFunc(value), - () => {} - ) - val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, - (value: Any) => - if (!hs.contains(value)) { - arrayBuffer += value - hs.add(value) - }, - (valueNaN: Any) => arrayBuffer += valueNaN) - val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, - (value: Any) => withArray1NaNCheckFunc(value), - () => arrayBuffer += null - ) - var i = 0 - while (i < array2.numElements()) { - withArray2NullCheckFunc(array2, i) - i += 1 - } - i = 0 - while (i < array1.numElements()) { - withArray1NullCheckFunc(array1, i) - i += 1 - } - new GenericArrayData(arrayBuffer) - } else { - (array1, array2) => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - var scannedNullElements = false - var i = 0 - while (i < array1.numElements()) { - var found = false - val elem1 = array1.get(i, elementType) - if (elem1 == null) { - if (!scannedNullElements) { - var j = 0 - while (!found && j < array2.numElements()) { - found = array2.isNullAt(j) - j += 1 - } - // array2 is scanned only once for null element - scannedNullElements = true - } else { - found = true - } - } else { - var j = 0 - while (!found && j < array2.numElements()) { - val elem2 = array2.get(j, 
elementType) - if (elem2 != null) { - found = ordering.equiv(elem1, elem2) - } - j += 1 - } - if (!found) { - // check whether elem1 is already stored in arrayBuffer - var k = 0 - while (!found && k < arrayBuffer.size) { - val va = arrayBuffer(k) - found = (va != null) && ordering.equiv(va, elem1) - k += 1 - } - } - } - if (!found) { - arrayBuffer += elem1 - } - i += 1 - } - new GenericArrayData(arrayBuffer) - } - } - override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - evalExcept(array1, array2) + val hs = new SQLOpenHashSet[Any] + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => hs.add(keyGenerator(value)), + (valueNaN: Any) => {}) + val withArray2NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, + (value: Any) => withArray2NaNCheckFunc(value), + () => {} + ) + val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs, + (value: Any) => { + val key = keyGenerator(value) + if (!hs.contains(key)) { + arrayBuffer += value + hs.add(key) + } + }, + (valueNaN: Any) => arrayBuffer += valueNaN) + val withArray1NullCheckFunc = SQLOpenHashSet.withNullCheckFunc(elementType, hs, + (value: Any) => withArray1NaNCheckFunc(value), + () => arrayBuffer += null + ) + var i = 0 + while (i < array2.numElements()) { + withArray2NullCheckFunc(array2, i) + i += 1 + } + i = 0 + while (i < array1.numElements()) { + withArray1NullCheckFunc(array1, i) + i += 1 + } + new GenericArrayData(arrayBuffer) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericComparableWrapper.scala new file mode 100644 index 0000000000000..c22f8c9fb20ff --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericComparableWrapper.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction +import org.apache.spark.sql.types.DataType + +/** + * Wraps any internal Spark type with the corresponding [[DataType]] to make it comparable with + * other values. + * It uses Spark's internal murmur hash to compute hash code, and uses PhysicalDataType ordering + * to perform equality checks. + * + * @param dataType the data type for the value + */ +class GenericComparableWrapper private ( + val value: Any, + val dataType: DataType, + val ordering: Ordering[Any]) { + + override def hashCode(): Int = Murmur3HashFunction.hash( + value, + dataType, + 42L, + isCollationAware = true, + // legacyCollationAwareHashing only matters when isCollationAware is false. 
+ legacyCollationAwareHashing = false).toInt + + override def equals(other: Any): Boolean = { + if (!other.isInstanceOf[GenericComparableWrapper]) { + return false + } + val otherWrapper = other.asInstanceOf[GenericComparableWrapper] + if (!otherWrapper.dataType.equals(this.dataType)) { + return false + } + ordering.equiv(value, otherWrapper.value) + } +} + +object GenericComparableWrapper { + /** Creates a shared factory method for a given data type */ + def getGenericComparableWrapperFactory( + dataType: DataType): Any => GenericComparableWrapper = { + val ordering = TypeUtils.getInterpretedOrdering(dataType) + value: Any => new GenericComparableWrapper(value, dataType, ordering) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index 1907ec7c23aa6..5e018438bb735 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2322,6 +2322,29 @@ class CollectionExpressionsSuite Seq[Int](4, 5))) checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4), Seq[Int](4, 5))) + + val structType = StructType(Seq(StructField("a", IntegerType, nullable = true))) + val d0 = Literal.create(Seq(create_row(1), create_row(2), create_row(1), create_row(2)), + ArrayType(structType)) + val d1 = Literal.create(Seq(null, create_row(2), null, create_row(2)), + ArrayType(structType)) + checkEvaluation(ArrayDistinct(d0), Seq(create_row(1), create_row(2))) + checkEvaluation(ArrayDistinct(d1), Seq(null, create_row(2))) + + val nestedStructType = ArrayType(ArrayType(structType)) + val e0 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3)), + Seq(create_row(1), create_row(2))), + nestedStructType) + 
val e1 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(4))), + nestedStructType) + checkEvaluation(ArrayDistinct(e0), + Seq(Seq(create_row(1), create_row(2)), Seq(create_row(3)))) + checkEvaluation(ArrayDistinct(e1), + Seq(Seq(create_row(1), create_row(2)), Seq(create_row(4)))) } test("Array Union") { @@ -2412,6 +2435,29 @@ class CollectionExpressionsSuite assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull) assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull) + + val structType = StructType(Seq(StructField("a", IntegerType, nullable = true))) + val structArray0 = Literal.create(Seq(create_row(1), create_row(2), null), + ArrayType(structType)) + val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null), + ArrayType(structType)) + checkEvaluation(ArrayUnion(structArray0, structArray1), + Seq(create_row(1), create_row(2), null, create_row(3))) + + val nestedStructType = ArrayType(ArrayType(structType)) + val nestedStructArray0 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3))), + nestedStructType) + val nestedStructArray1 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3), create_row(4))), + nestedStructType) + checkEvaluation(ArrayUnion(nestedStructArray0, nestedStructArray1), Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3)), + Seq(create_row(3), create_row(4)) + )) } test("Shuffle") { @@ -2596,6 +2642,29 @@ class CollectionExpressionsSuite assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull) assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull) + + val structType = StructType(Seq(StructField("a", IntegerType, nullable = true))) + val structArray0 = Literal.create(Seq(create_row(1), create_row(2), 
null, create_row(1)), + ArrayType(structType)) + val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null), + ArrayType(structType)) + checkEvaluation(ArrayExcept(structArray0, structArray1), Seq(create_row(1))) + checkEvaluation(ArrayExcept(structArray1, structArray0), Seq(create_row(3))) + + val nestedStructType = ArrayType(ArrayType(structType)) + val nestedStructArray0 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3)), + Seq(create_row(1), create_row(2))), + nestedStructType) + val nestedStructArray1 = Literal.create(Seq( + Seq(create_row(3)), + Seq(create_row(4))), + nestedStructType) + checkEvaluation(ArrayExcept(nestedStructArray0, nestedStructArray1), + Seq(Seq(create_row(1), create_row(2)))) + checkEvaluation(ArrayExcept(nestedStructArray1, nestedStructArray0), + Seq(Seq(create_row(4)))) } test("Array Except - null handling") { @@ -2864,6 +2933,29 @@ class CollectionExpressionsSuite assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull) assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false) assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull) + + val structType = StructType(Seq(StructField("a", IntegerType, nullable = true))) + val structArray0 = Literal.create(Seq(create_row(1), create_row(2), null, create_row(1)), + ArrayType(structType)) + val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null), + ArrayType(structType)) + checkEvaluation(ArrayIntersect(structArray0, structArray1), Seq(create_row(2), null)) + checkEvaluation(ArrayIntersect(structArray1, structArray0), Seq(create_row(2), null)) + + val nestedStructType = ArrayType(ArrayType(structType)) + val nestedStructArray0 = Literal.create(Seq( + Seq(create_row(1), create_row(2)), + Seq(create_row(3)), + Seq(create_row(1), create_row(2))), + nestedStructType) + val nestedStructArray1 = Literal.create(Seq( + Seq(create_row(3)), + Seq(create_row(4))), + 
nestedStructType) + checkEvaluation(ArrayIntersect(nestedStructArray0, nestedStructArray1), + Seq(Seq(create_row(3)))) + checkEvaluation(ArrayIntersect(nestedStructArray1, nestedStructArray0), + Seq(Seq(create_row(3)))) } test("Array Intersect - null handling") { diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out index af5b4f9b129e8..1986dd78cef40 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/array.sql.out @@ -936,3 +936,39 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) -- !query analysis Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] +- OneRowRelation + + +-- !query +select array_union( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D))) +-- !query analysis +Project [array_union(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double))), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0))) AS array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0)))#x] ++- OneRowRelation + + +-- !query +select array_distinct( + array(named_struct('a', -0.0D), named_struct('a', 0.0D), + named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN')))) +-- !query analysis +Project [array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, cast(NaN as double)))) AS array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN)))#x] ++- OneRowRelation + + +-- !query +select 
array_except( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')))) +-- !query analysis +Project [array_except(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)))) AS array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN)))#x] ++- OneRowRelation + + +-- !query +select array_intersect( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D))) +-- !query analysis +Project [array_intersect(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 2.0))) AS array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0)))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/array.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/array.sql.out index c60e2c3737b4b..23863656d7b24 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/array.sql.out @@ -936,3 +936,39 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) -- !query analysis Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x] +- OneRowRelation + + +-- !query +select array_union( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))), + 
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D))) +-- !query analysis +Project [array_union(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double))), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0))) AS array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0)))#x] ++- OneRowRelation + + +-- !query +select array_distinct( + array(named_struct('a', -0.0D), named_struct('a', 0.0D), + named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN')))) +-- !query analysis +Project [array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, cast(NaN as double)))) AS array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN)))#x] ++- OneRowRelation + + +-- !query +select array_except( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')))) +-- !query analysis +Project [array_except(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)))) AS array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN)))#x] ++- OneRowRelation + + +-- !query +select array_intersect( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D))) +-- !query analysis +Project [array_intersect(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 2.0))) AS array_intersect(array(named_struct(a, -0.0), 
named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0)))#x] ++- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql index e923c3cdc3600..67649f1238d00 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/array.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -189,3 +189,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String)); -- SPARK-45599: Confirm 0.0, -0.0, and NaN are handled appropriately. select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN"))); select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))); + +-- SPARK-54698: Confirm 0.0, -0.0, and NaN are handled appropriately for complex types. +select array_union( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D))); +select array_distinct( + array(named_struct('a', -0.0D), named_struct('a', 0.0D), + named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN')))); +select array_except( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')))); +select array_intersect( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D))); diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out index 90a605734c1a4..26d78f3f0d24b 100644 --- a/sql/core/src/test/resources/sql-tests/results/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -1106,3 +1106,43 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) struct<array_distinct(array(0.0, 0.0, 0.0, NaN, NaN)):array<double>>
-- !query output [0.0,NaN] + + +-- !query +select array_union( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D))) +-- !query schema +struct<array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN},{"a":1.0}] + + +-- !query +select array_distinct( + array(named_struct('a', -0.0D), named_struct('a', 0.0D), + named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN')))) +-- !query schema +struct<array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN}] + + +-- !query +select array_except( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')))) +-- !query schema +struct<array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN))):array<struct<a:double>>> +-- !query output +[{"a":1.0}] + + +-- !query +select array_intersect( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D))) +-- !query schema +struct<array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN}] diff --git a/sql/core/src/test/resources/sql-tests/results/nonansi/array.sql.out b/sql/core/src/test/resources/sql-tests/results/nonansi/array.sql.out index 460cb89113abe..e496912e7577f 100644 --- a/sql/core/src/test/resources/sql-tests/results/nonansi/array.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/nonansi/array.sql.out @@ -994,3 +994,43 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN"))) struct<array_distinct(array(0.0, 0.0, 0.0, NaN, NaN)):array<double>> -- !query output [0.0,NaN] + + +-- !query +select array_union( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D))) +-- !query schema +struct<array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN},{"a":1.0}] + + +-- !query +select array_distinct( + array(named_struct('a', -0.0D), named_struct('a', 0.0D), + named_struct('a', DOUBLE('NaN')),
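named_struct('a', DOUBLE('NaN')))) +-- !query schema +struct<array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN}] + + +-- !query +select array_except( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')))) +-- !query schema +struct<array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN))):array<struct<a:double>>> +-- !query output +[{"a":1.0}] + + +-- !query +select array_intersect( + array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)), + array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D))) +-- !query schema +struct<array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0))):array<struct<a:double>>> +-- !query output +[{"a":-0.0},{"a":NaN}]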