Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.catalyst.expressions.Murmur3HashFunction
import org.apache.spark.sql.types.DataType

/**
* Wraps any internal Spark type with the corresponding [[DataType]] to make it comparable with
* other values.
* It uses Spark's internal murmur hash to compute hash code, and uses PhysicalDataType ordering
* to perform equality checks.
*
* @param dataType the data type for the value
*/
class GenericComparableWrapper private (
val value: Any,
val dataType: DataType,
val ordering: Ordering[Any]) {

override def hashCode(): Int = Murmur3HashFunction.hash(
value,
dataType,
42L,
isCollationAware = true,
// legacyCollationAwareHashing only matters when isCollationAware is false.
legacyCollationAwareHashing = false).toInt

override def equals(other: Any): Boolean = {
if (!other.isInstanceOf[GenericComparableWrapper]) {
return false
}
val otherWrapper = other.asInstanceOf[GenericComparableWrapper]
if (!otherWrapper.dataType.equals(this.dataType)) {
return false
}
ordering.equiv(value, otherWrapper.value)
}
}

object GenericComparableWrapper {
/** Creates a shared factory method for a given data type */
def getGenericComparableWrapperFactory(
dataType: DataType): Any => GenericComparableWrapper = {
val ordering = TypeUtils.getInterpretedOrdering(dataType)
value: Any => new GenericComparableWrapper(value, dataType, ordering)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,29 @@ class CollectionExpressionsSuite
Seq[Int](4, 5)))
checkEvaluation(ArrayDistinct(c4), Seq[Seq[Int]](null, Seq[Int](1, 2), Seq[Int](3, 4),
Seq[Int](4, 5)))

val structType = StructType(Seq(StructField("a", IntegerType, nullable = true)))
val d0 = Literal.create(Seq(create_row(1), create_row(2), create_row(1), create_row(2)),
ArrayType(structType))
val d1 = Literal.create(Seq(null, create_row(2), null, create_row(2)),
ArrayType(structType))
checkEvaluation(ArrayDistinct(d0), Seq(create_row(1), create_row(2)))
checkEvaluation(ArrayDistinct(d1), Seq(null, create_row(2)))

val nestedStructType = ArrayType(ArrayType(structType))
val e0 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3)),
Seq(create_row(1), create_row(2))),
nestedStructType)
val e1 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(4))),
nestedStructType)
checkEvaluation(ArrayDistinct(e0),
Seq(Seq(create_row(1), create_row(2)), Seq(create_row(3))))
checkEvaluation(ArrayDistinct(e1),
Seq(Seq(create_row(1), create_row(2)), Seq(create_row(4))))
}

test("Array Union") {
Expand Down Expand Up @@ -2412,6 +2435,29 @@ class CollectionExpressionsSuite
assert(ArrayUnion(a00, a02).dataType.asInstanceOf[ArrayType].containsNull)
assert(ArrayUnion(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayUnion(a20, a22).dataType.asInstanceOf[ArrayType].containsNull)

val structType = StructType(Seq(StructField("a", IntegerType, nullable = true)))
val structArray0 = Literal.create(Seq(create_row(1), create_row(2), null),
ArrayType(structType))
val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null),
ArrayType(structType))
checkEvaluation(ArrayUnion(structArray0, structArray1),
Seq(create_row(1), create_row(2), null, create_row(3)))

val nestedStructType = ArrayType(ArrayType(structType))
val nestedStructArray0 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3))),
nestedStructType)
val nestedStructArray1 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3), create_row(4))),
nestedStructType)
checkEvaluation(ArrayUnion(nestedStructArray0, nestedStructArray1), Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3)),
Seq(create_row(3), create_row(4))
))
}

test("Shuffle") {
Expand Down Expand Up @@ -2596,6 +2642,29 @@ class CollectionExpressionsSuite
assert(ArrayExcept(a04, a05).dataType.asInstanceOf[ArrayType].containsNull)
assert(ArrayExcept(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayExcept(a24, a22).dataType.asInstanceOf[ArrayType].containsNull)

val structType = StructType(Seq(StructField("a", IntegerType, nullable = true)))
val structArray0 = Literal.create(Seq(create_row(1), create_row(2), null, create_row(1)),
ArrayType(structType))
val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null),
ArrayType(structType))
checkEvaluation(ArrayExcept(structArray0, structArray1), Seq(create_row(1)))
checkEvaluation(ArrayExcept(structArray1, structArray0), Seq(create_row(3)))

val nestedStructType = ArrayType(ArrayType(structType))
val nestedStructArray0 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3)),
Seq(create_row(1), create_row(2))),
nestedStructType)
val nestedStructArray1 = Literal.create(Seq(
Seq(create_row(3)),
Seq(create_row(4))),
nestedStructType)
checkEvaluation(ArrayExcept(nestedStructArray0, nestedStructArray1),
Seq(Seq(create_row(1), create_row(2))))
checkEvaluation(ArrayExcept(nestedStructArray1, nestedStructArray0),
Seq(Seq(create_row(4))))
}

test("Array Except - null handling") {
Expand Down Expand Up @@ -2864,6 +2933,29 @@ class CollectionExpressionsSuite
assert(ArrayIntersect(a04, a05).dataType.asInstanceOf[ArrayType].containsNull)
assert(ArrayIntersect(a20, a21).dataType.asInstanceOf[ArrayType].containsNull === false)
assert(ArrayIntersect(a23, a24).dataType.asInstanceOf[ArrayType].containsNull)

val structType = StructType(Seq(StructField("a", IntegerType, nullable = true)))
val structArray0 = Literal.create(Seq(create_row(1), create_row(2), null, create_row(1)),
ArrayType(structType))
val structArray1 = Literal.create(Seq(create_row(2), create_row(3), null),
ArrayType(structType))
checkEvaluation(ArrayIntersect(structArray0, structArray1), Seq(create_row(2), null))
checkEvaluation(ArrayIntersect(structArray1, structArray0), Seq(create_row(2), null))

val nestedStructType = ArrayType(ArrayType(structType))
val nestedStructArray0 = Literal.create(Seq(
Seq(create_row(1), create_row(2)),
Seq(create_row(3)),
Seq(create_row(1), create_row(2))),
nestedStructType)
val nestedStructArray1 = Literal.create(Seq(
Seq(create_row(3)),
Seq(create_row(4))),
nestedStructType)
checkEvaluation(ArrayIntersect(nestedStructArray0, nestedStructArray1),
Seq(Seq(create_row(3))))
checkEvaluation(ArrayIntersect(nestedStructArray1, nestedStructArray0),
Seq(Seq(create_row(3))))
}

test("Array Intersect - null handling") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,3 +936,39 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN")))
-- !query analysis
Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x]
+- OneRowRelation


-- !query
select array_union(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)))
-- !query analysis
Project [array_union(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double))), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0))) AS array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0)))#x]
+- OneRowRelation


-- !query
select array_distinct(
array(named_struct('a', -0.0D), named_struct('a', 0.0D),
named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN'))))
-- !query analysis
Project [array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, cast(NaN as double)))) AS array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN)))#x]
+- OneRowRelation


-- !query
select array_except(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN'))))
-- !query analysis
Project [array_except(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)))) AS array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN)))#x]
+- OneRowRelation


-- !query
select array_intersect(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D)))
-- !query analysis
Project [array_intersect(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 2.0))) AS array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0)))#x]
+- OneRowRelation
Original file line number Diff line number Diff line change
Expand Up @@ -936,3 +936,39 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN")))
-- !query analysis
Project [array_distinct(array(cast(0.0 as double), cast(0.0 as double), cast(0.0 as double), cast(NaN as double), cast(NaN as double))) AS array_distinct(array(0.0, 0.0, 0.0, NaN, NaN))#x]
+- OneRowRelation


-- !query
select array_union(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)))
-- !query analysis
Project [array_union(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double))), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0))) AS array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0)))#x]
+- OneRowRelation


-- !query
select array_distinct(
array(named_struct('a', -0.0D), named_struct('a', 0.0D),
named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN'))))
-- !query analysis
Project [array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, cast(NaN as double)))) AS array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN)))#x]
+- OneRowRelation


-- !query
select array_except(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN'))))
-- !query analysis
Project [array_except(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)))) AS array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN)))#x]
+- OneRowRelation


-- !query
select array_intersect(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D)))
-- !query analysis
Project [array_intersect(array(named_struct(a, -0.0), named_struct(a, cast(NaN as double)), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, cast(NaN as double)), named_struct(a, 2.0))) AS array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0)))#x]
+- OneRowRelation
14 changes: 14 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/array.sql
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,17 @@ select array_prepend(array(CAST(NULL AS String)), CAST(NULL AS String));
-- SPARK-45599: Confirm 0.0, -0.0, and NaN are handled appropriately.
select array_union(array(0.0, -0.0, DOUBLE("NaN")), array(0.0, -0.0, DOUBLE("NaN")));
select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN")));

-- SPARK-54698: Confirm 0.0, -0.0, and NaN are handled appropriately for complex types.
select array_union(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)));
select array_distinct(
array(named_struct('a', -0.0D), named_struct('a', 0.0D),
named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN'))));
select array_except(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN'))));
select array_intersect(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D)));
40 changes: 40 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/array.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -1106,3 +1106,43 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN")))
struct<array_distinct(array(0.0, 0.0, 0.0, NaN, NaN)):array<double>>
-- !query output
[0.0,NaN]


-- !query
select array_union(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)))
-- !query schema
struct<array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN},{"a":1.0}]


-- !query
select array_distinct(
array(named_struct('a', -0.0D), named_struct('a', 0.0D),
named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN'))))
-- !query schema
struct<array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN}]


-- !query
select array_except(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN'))))
-- !query schema
struct<array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN))):array<struct<a:double>>>
-- !query output
[{"a":1.0}]


-- !query
select array_intersect(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D)))
-- !query schema
struct<array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN}]
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,43 @@ select array_distinct(array(0.0, -0.0, -0.0, DOUBLE("NaN"), DOUBLE("NaN")))
struct<array_distinct(array(0.0, 0.0, 0.0, NaN, NaN)):array<double>>
-- !query output
[0.0,NaN]


-- !query
select array_union(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN'))),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)))
-- !query schema
struct<array_union(array(named_struct(a, -0.0), named_struct(a, NaN)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 1.0))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN},{"a":1.0}]


-- !query
select array_distinct(
array(named_struct('a', -0.0D), named_struct('a', 0.0D),
named_struct('a', DOUBLE('NaN')), named_struct('a', DOUBLE('NaN'))))
-- !query schema
struct<array_distinct(array(named_struct(a, -0.0), named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, NaN))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN}]


-- !query
select array_except(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN'))))
-- !query schema
struct<array_except(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN))):array<struct<a:double>>>
-- !query output
[{"a":1.0}]


-- !query
select array_intersect(
array(named_struct('a', -0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 1.0D)),
array(named_struct('a', 0.0D), named_struct('a', DOUBLE('NaN')), named_struct('a', 2.0D)))
-- !query schema
struct<array_intersect(array(named_struct(a, -0.0), named_struct(a, NaN), named_struct(a, 1.0)), array(named_struct(a, 0.0), named_struct(a, NaN), named_struct(a, 2.0))):array<struct<a:double>>>
-- !query output
[{"a":-0.0},{"a":NaN}]