Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
72e8922
add benchmark
LuciferYang Mar 10, 2026
bc49e32
fix
LuciferYang Mar 10, 2026
c46c933
add times
LuciferYang Mar 11, 2026
f965a0c
revert collectionOperations.scala
LuciferYang Mar 11, 2026
4a61bc3
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 11, 2026
bcaba7c
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 11, 2026
8e5b795
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 11, 2026
b5d0d38
add more tests
LuciferYang Mar 11, 2026
8b33dc7
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 11, 2026
f7d29cb
use putIfAbsent
LuciferYang Mar 12, 2026
e3ac77e
add config
LuciferYang Mar 12, 2026
dcad5e5
refactor test
LuciferYang Mar 12, 2026
4ac347b
refactor benchmark
LuciferYang Mar 12, 2026
6757907
fix doc
LuciferYang Mar 12, 2026
0917346
add more memory
LuciferYang Mar 12, 2026
782d7a5
init
LuciferYang Mar 12, 2026
9d21fae
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 12, 2026
c721967
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 12, 2026
3bc2c2a
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 12, 2026
fa38f0b
try add 1M back
LuciferYang Mar 13, 2026
7c01c08
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 13, 2026
2e7bf82
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 13, 2026
756ccbc
Benchmark results for org.apache.spark.sql.execution.benchmark.MapLoo…
LuciferYang Mar 13, 2026
85495d3
Merge branch 'upmaster' into issue-54646
LuciferYang Mar 13, 2026
a6eb0f3
add ConfigBindingPolicy
LuciferYang Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,14 +441,60 @@ trait GetArrayItemUtil {
*/
trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {

// Single-entry cache of the most recently indexed map: `lastIndex` maps each key to the
// position of its first occurrence in `lastMap`'s key array (first occurrence wins, via
// putIfAbsent, matching linear-scan semantics for duplicate keys). Rebuilt only when a
// different MapData instance is looked up; @transient keeps the cache out of serialization.
@transient private var lastMap: MapData = _
@transient private var lastIndex: java.util.HashMap[Any, Int] = _

/**
 * The threshold to determine whether to use hash lookup for map lookup expressions.
 * If the map size is small, the cost of building hash map exceeds the cost of a linear scan.
 * This is configured by `spark.sql.optimizer.mapLookupHashThreshold`.
 */
@transient private lazy val hashLookupThreshold =
  SQLConf.get.getConf(SQLConf.MAP_LOOKUP_HASH_THRESHOLD)

/**
 * Returns a key -> position index for `map`, building it on demand.
 *
 * The index is cached per expression instance and rebuilt only when a different
 * `MapData` object (reference comparison) is seen. Duplicate keys keep the
 * position of their first occurrence, mirroring what a linear scan would find.
 */
private def getOrBuildIndex(map: MapData, keyType: DataType): java.util.HashMap[Any, Int] = {
  // Fast path: same MapData instance as last time, reuse the cached index.
  if (lastMap ne map) {
    val keyArray = map.keyArray()
    val numKeys = keyArray.numElements()
    // Size the table above numKeys so the default load factor avoids a rehash.
    val freshIndex = new java.util.HashMap[Any, Int]((numKeys * 1.5).toInt)
    var pos = 0
    while (pos < numKeys) {
      // putIfAbsent keeps the earliest position for duplicated keys.
      freshIndex.putIfAbsent(keyArray.get(pos, keyType), pos)
      pos += 1
    }
    lastIndex = freshIndex
    lastMap = map
  }
  lastIndex
}

/**
 * Interpreted-mode lookup of `ordinal` (the key) in a map `value`.
 *
 * Strategy: for maps smaller than the configured threshold, or for key types
 * without proper equals semantics (where hashing is unsafe), fall back to a
 * linear scan; otherwise build (or reuse) the cached hash index over the map's
 * key array and do an O(1) lookup.
 *
 * @param value    the map, as a [[MapData]]
 * @param ordinal  the key to look up
 * @param keyType  data type of the map keys
 * @param ordering ordering used by the linear-scan fallback for key equality
 * @return the matched value, or null when the key is absent or the matched
 *         value slot is null
 */
def getValueEval(
    value: Any,
    ordinal: Any,
    keyType: DataType,
    ordering: Ordering[Any]): Any = {
  val map = value.asInstanceOf[MapData]
  val length = map.numElements()

  if (length < hashLookupThreshold || !TypeUtils.typeWithProperEquals(keyType)) {
    // Small map, or key type where hashing is not well-defined: scan linearly.
    getValueEvalLinear(map, ordinal, keyType, ordering)
  } else {
    // -1 sentinel: key not present in the index.
    val idx = getOrBuildIndex(map, keyType).getOrDefault(ordinal, -1)
    if (idx == -1 || map.valueArray().isNullAt(idx)) {
      null
    } else {
      map.valueArray().get(idx, dataType)
    }
  }
}

private def getValueEvalLinear(
map: MapData,
ordinal: Any,
keyType: DataType,
ordering: Ordering[Any]): Any = {
val length = map.numElements()
val keys = map.keyArray()
val values = map.valueArray()

Expand All @@ -473,38 +519,178 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
ctx: CodegenContext,
ev: ExprCode,
mapType: MapType): ExprCode = {
val keyType = mapType.keyType
if (supportsHashLookup(keyType)) {
doGetValueGenCodeWithHashOpt(ctx, ev, mapType)
} else {
doGetValueGenCodeLinear(ctx, ev, mapType)
}
}

/**
 * Whether codegen can use the hash-indexed lookup for this key type.
 * Restricted to types for which an inline Java hash expression can be
 * generated: fixed-width primitive-backed types, and strings that support
 * binary equality.
 */
private def supportsHashLookup(keyType: DataType): Boolean = keyType match {
  case BooleanType | ByteType | ShortType | IntegerType | LongType => true
  case FloatType | DoubleType => true
  case DateType | TimestampType | TimestampNTZType => true
  case _: YearMonthIntervalType | _: DayTimeIntervalType => true
  case st: StringType => st.supportsBinaryEquality
  case _ => false
}

/**
 * Generates Java code that resolves a map value with a linear scan over the
 * key array. Used for key types that do not support hash lookup (see
 * `supportsHashLookup`); the scan is O(n) in the number of map entries.
 */
private def doGetValueGenCodeLinear(
    ctx: CodegenContext,
    ev: ExprCode,
    mapType: MapType): ExprCode = {
  // Fresh names avoid collisions with other generated code in the same method.
  val index = ctx.freshName("index")
  val length = ctx.freshName("length")
  val keys = ctx.freshName("keys")
  val values = ctx.freshName("values")
  val keyType = mapType.keyType

  val keyJavaType = CodeGenerator.javaType(keyType)
  val loopKey = ctx.freshName("loopKey")
  val i = ctx.freshName("i")

  // Only emit the value-null branch when the map may actually hold null values.
  val nullValueCheck = if (mapType.valueContainsNull) {
    s"""
|else if ($values.isNullAt($index)) {
| ${ev.isNull} = true;
|}
""".stripMargin
  } else {
    ""
  }

  nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
    // eval1 is the map operand, eval2 the key operand; index stays -1 on miss.
    s"""
|final int $length = $eval1.numElements();
|final ArrayData $keys = $eval1.keyArray();
|final ArrayData $values = $eval1.valueArray();
|int $index = -1;
|
|for (int $i = 0; $i < $length; $i++) {
| $keyJavaType $loopKey = ${CodeGenerator.getValue(keys, keyType, i)};
| if (${ctx.genEqual(keyType, loopKey, eval2)}) {
| $index = $i;
| break;
| }
|}
|
|if ($index < 0) {
| ${ev.isNull} = true;
|} $nullValueCheck else {
| ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
|}
""".stripMargin
  })
}

/**
 * Generates code for map lookups.
 * If the map size is below `spark.sql.optimizer.mapLookupHashThreshold`, a linear scan is
 * emitted. Otherwise an open-addressing hash index (an `int[]` of key positions with linear
 * probing, power-of-two capacity) is built over the key array for O(1) lookup. The index is
 * cached in generated mutable state and rebuilt only when a different key array is seen.
 *
 * NOTE(review): this span of the diff interleaved removed old lines with the new
 * implementation; this body reconstructs the coherent new version.
 */
private def doGetValueGenCodeWithHashOpt(
    ctx: CodegenContext,
    ev: ExprCode,
    mapType: MapType): ExprCode = {
  val index = ctx.freshName("index")
  val length = ctx.freshName("length")
  val keys = ctx.freshName("keys")
  val values = ctx.freshName("values")
  val keyType = mapType.keyType

  // Only emit the value-null branch when the map may actually hold null values.
  val nullValueCheck = if (mapType.valueContainsNull) {
    s"""
|else if ($values.isNullAt($index)) {
| ${ev.isNull} = true;
|}
""".stripMargin
  } else {
    ""
  }

  val keyJavaType = CodeGenerator.javaType(keyType)
  // Mutable state persists across row evaluations, so the index survives between calls.
  val lastKeyArray = ctx.addMutableState("ArrayData", "lastKeyArray", v => s"$v = null;")
  val hashBuckets = ctx.addMutableState("int[]", "hashBuckets", v => s"$v = null;")
  val hashMask = ctx.addMutableState("int", "hashMask", v => s"$v = 0;")

  // Inline Java hash expression for the key type; must agree with genEqual for
  // the same type (e.g. long/double fold high bits like java.lang.Long.hashCode).
  def genHash(v: String): String = keyType match {
    case BooleanType => s"($v ? 1 : 0)"
    case ByteType | ShortType | IntegerType | DateType | _: YearMonthIntervalType => s"$v"
    case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType =>
      s"(int)($v ^ ($v >>> 32))"
    case FloatType => s"Float.floatToIntBits($v)"
    case DoubleType =>
      s"(int)(Double.doubleToLongBits($v) ^ (Double.doubleToLongBits($v) >>> 32))"
    case _ => s"$v.hashCode()"
  }

  nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
    val i = ctx.freshName("i")
    val h = ctx.freshName("h")
    val cap = ctx.freshName("cap")
    val idx = ctx.freshName("idx")
    val candidate = ctx.freshName("candidate")
    val loopKey = ctx.freshName("loopKey")

    // Build the probe table: capacity is the next power of two >= 2 * length
    // (min 4) so the mask trick works and the load factor stays <= 0.5.
    // Buckets hold key positions; -1 marks an empty slot. First occurrence of
    // a duplicate key wins because later insertions probe past filled slots.
    val buildIndex =
      s"""
|int $cap = Math.max(Integer.highestOneBit(Math.max($length * 2 - 1, 1)) << 1, 4);
|if ($hashBuckets == null || $hashBuckets.length < $cap) {
| $hashBuckets = new int[$cap];
|}
|java.util.Arrays.fill($hashBuckets, 0, $cap, -1);
|$hashMask = $cap - 1;
|for (int $i = 0; $i < $length; $i++) {
| $keyJavaType $loopKey = ${CodeGenerator.getValue(keys, keyType, i)};
| int $h = (${genHash(loopKey)}) & $hashMask;
| while ($hashBuckets[$h] != -1) {
| $h = ($h + 1) & $hashMask;
| }
| $hashBuckets[$h] = $i;
|}
|$lastKeyArray = $keys;
""".stripMargin

    // Probe with the same hash + linear probing; an empty slot means a miss.
    val lookup =
      s"""
|int $h = (${genHash(eval2)}) & $hashMask;
|$index = -1;
|while ($hashBuckets[$h] != -1) {
| int $idx = $hashBuckets[$h];
| $keyJavaType $candidate = ${CodeGenerator.getValue(keys, keyType, idx)};
| if (${ctx.genEqual(keyType, candidate, eval2)}) {
| $index = $idx;
| break;
| }
| $h = ($h + 1) & $hashMask;
|}
""".stripMargin

    // eval1 is the map operand, eval2 the key operand. Rebuild the index only
    // when the key array instance changed since the previous row.
    s"""
|final int $length = $eval1.numElements();
|final ArrayData $keys = $eval1.keyArray();
|final ArrayData $values = $eval1.valueArray();
|int $index = -1;
|
|if ($length >= $hashLookupThreshold) {
| if ($keys != $lastKeyArray) {
| $buildIndex
| }
| $lookup
|} else {
| for (int $i = 0; $i < $length; $i++) {
| $keyJavaType $loopKey = ${CodeGenerator.getValue(keys, keyType, i)};
| if (${ctx.genEqual(keyType, loopKey, eval2)}) {
| $index = $i;
| break;
| }
| }
|}
|
|if ($index < 0) {
| ${ev.isNull} = true;
|} $nullValueCheck else {
| ${ev.value} = ${CodeGenerator.getValue(values, dataType, index)};
|}
""".stripMargin
  })
}
Expand Down Expand Up @@ -547,15 +733,10 @@ case class GetMapValue(child: Expression, key: Expression)

/**
 * `Null` is returned for invalid ordinals.
 */
override def nullable: Boolean = true

// The result type is the value type of the child map.
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType

// Delegates to GetMapValueUtil.getValueEval, which chooses between a linear
// scan and a cached hash-index lookup based on map size and key type.
override def nullSafeEval(value: Any, ordinal: Any): Any = {
  getValueEval(value, ordinal, keyType, ordering)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2501,6 +2501,18 @@ object SQLConf {
.intConf
.createWithDefault(-1)

// Session-bound internal knob consumed by GetMapValueUtil (both interpreted and
// codegen paths): maps with at least this many entries use a hash index for key
// lookup; smaller maps, and non-hashable key types, use a linear scan.
val MAP_LOOKUP_HASH_THRESHOLD =
  buildConf("spark.sql.optimizer.mapLookupHashThreshold")
    .internal()
    .doc("The minimum number of map entries to attempt hash-based lookup in `element_at` and " +
      "the `[]` operator. Below this threshold, linear scan is used. For key types that do not " +
      "support hashing (e.g. arrays, structs), linear scan is always used regardless of map size.")
    .version("4.2.0")
    .withBindingPolicy(ConfigBindingPolicy.SESSION)
    .intConf
    .checkValue(_ >= 0, "The threshold must be non-negative.")
    .createWithDefault(1000)

val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files. " +
"This configuration is effective only when using file-based sources such as Parquet, JSON " +
Expand Down
Loading