Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
(conf.isUseGlutenShuffleManager || conf.shuffleManagerSupportsColumnarShuffle)
}

/** Returns the `enableBroadcastBuildOncePerExecutor` value read from the current [[VeloxConfig]]. */
override def enableHashTableBuildOncePerExecutor(): Boolean =
  VeloxConfig.get.enableBroadcastBuildOncePerExecutor

override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
t =>
if (super.supportHashBuildJoinTypeOnLeft(t)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,19 +888,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
}
}

override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this part still needed when buildHashTableOncePerExecutor is disabled?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doCanonicalizeForBroadcastMode() is not invoked anywhere; broadcast canonicalization goes through ColumnarBroadcastExchangeExec.doCanonicalize().

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is actually useful — when buildHashTableOncePerExecutor is disabled, it provides more opportunities to reuse broadcast exchanges. Moreover, I now think that even when buildHashTableOncePerExecutor is enabled, the comment in doCanonicalizeForBroadcastMode still holds true: we still broadcast byte arrays and build HashRelation at the executor side.
@JkSelf Can you explain why this was removed? This allows us to reuse broadcast exchanges for different build keys with the same data.
We should either restore the code before ColumnarBroadcastExchangeExec.doCanonicalize, or at least follow the original logic when buildHashTableOncePerExecutor is disabled.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JkSelf Thank you for the explanation, I understand now. @wecharyu According to the instructions here, you can restore the original behavior of doCanonicalizeForBroadcastMode when enableBroadcastBuildOncePerExecutor=false.

mode match {
case hash: HashedRelationBroadcastMode =>
// Note: This differs from vanilla Spark.
// Vanilla Spark builds the HashRelation on the driver side, so it is sensitive to the build keys.
// But we broadcast a byte array and build the HashRelation on the executor side, so
// the build keys are actually meaningless for the broadcast value.
// This change allows us to reuse a broadcast exchange for different build keys on the same table.
hash.copy(key = Seq.empty)
case _ => mode.canonicalized
}
}

/**
* * Expressions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,17 @@ case class BroadcastHashJoinExecTransformer(
isNullAwareAntiJoin) {

// Unique ID for built table
lazy val buildBroadcastTableId: String = buildPlan.id.toString
/**
 * Identifier of the hash table built for this join.
 *
 * Layout: `BuiltHashTable-<buildPlan.id>-[<key hashes>]#<join type>#<null-aware flag>`, where the
 * key hashes are the semantic hashes of the build key expressions, comma separated.
 */
override lazy val buildHashTableId: String = {
  // Build keys, join type and the null-aware flag together define the hash table contents,
  // so each of them becomes part of the identifier.
  val keyHashes = buildKeyExprs.map(_.semanticHash()).mkString("[", ",", "]")
  val joinTypeName = substraitJoinType.name()
  s"BuiltHashTable-${buildPlan.id}-$keyHashes#$joinTypeName#$isNullAwareAntiJoin"
}

override protected lazy val substraitJoinType: JoinRel.JoinType = joinType match {
case _: InnerLike =>
Expand Down Expand Up @@ -136,10 +146,10 @@ case class BroadcastHashJoinExecTransformer(
val streamedRDD = getColumnarInputRDDs(streamedPlan)
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
if (executionId != null) {
GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId)
GlutenDriverEndpoint.collectResources(executionId, buildHashTableId)
} else {
logWarning(
s"Can not trace broadcast table data $buildBroadcastTableId" +
s"Can not trace broadcast table data $buildHashTableId" +
s" because execution id is null." +
s" Will clean up until expire time.")
}
Expand Down Expand Up @@ -174,7 +184,7 @@ case class BroadcastHashJoinExecTransformer(
buildPlan.output,
filterBuildColumns,
filterPropagatesNulls,
buildBroadcastTableId,
buildHashTableId,
isNullAwareAntiJoin,
bloomFilterPushdownSize,
metrics.get("buildHashTableTime")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import org.apache.arrow.c.ArrowSchema

import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, Map}

object ColumnarBuildSideRelation {
// Keep constructor with BroadcastMode for compatibility
Expand Down Expand Up @@ -153,12 +153,12 @@ case class ColumnarBuildSideRelation(

override def asReadOnlyCopy(): ColumnarBuildSideRelation = this

private var hashTableData: Long = 0L
private val hashTableData = Map.empty[String, Long]

def buildHashTable(
broadcastContext: BroadcastHashJoinContext): (Long, ColumnarBuildSideRelation) =
synchronized {
if (hashTableData == 0) {
if (!hashTableData.contains(broadcastContext.buildHashTableId)) {
val startTime = System.nanoTime()
val runtime = Runtimes.contextInstance(
BackendsApiManager.getBackendName,
Expand Down Expand Up @@ -210,7 +210,7 @@ case class ColumnarBuildSideRelation(
val hashJoinBuilder = HashJoinBuilder.create(runtime)

// Build the hash table
hashTableData = hashJoinBuilder
hashTableData(broadcastContext.buildHashTableId) = hashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
Expand All @@ -232,14 +232,14 @@ case class ColumnarBuildSideRelation(
val elapsedTime = System.nanoTime() - startTime
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)

(hashTableData, this)
(hashTableData(broadcastContext.buildHashTableId), this)
} else {
(HashJoinBuilder.cloneHashTable(hashTableData), null)
(HashJoinBuilder.cloneHashTable(hashTableData(broadcastContext.buildHashTableId)), null)
}
}

def reset(): Unit = synchronized {
hashTableData = 0
hashTableData.clear()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}

import scala.collection.JavaConverters._
import scala.collection.JavaConverters.asScalaIteratorConverter
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, Map}

object UnsafeColumnarBuildSideRelation {
def apply(
Expand Down Expand Up @@ -123,11 +123,11 @@ class UnsafeColumnarBuildSideRelation(
batches
}

private var hashTableData: Long = 0L
private val hashTableData = Map.empty[String, Long]

def buildHashTable(broadcastContext: BroadcastHashJoinContext): (Long, BuildSideRelation) =
synchronized {
if (hashTableData == 0) {
if (!hashTableData.contains(broadcastContext.buildHashTableId)) {
val startTime = System.nanoTime()
val runtime = Runtimes.contextInstance(
BackendsApiManager.getBackendName,
Expand Down Expand Up @@ -180,7 +180,7 @@ class UnsafeColumnarBuildSideRelation(
val hashJoinBuilder = HashJoinBuilder.create(runtime)

// Build the hash table
hashTableData = hashJoinBuilder
hashTableData(broadcastContext.buildHashTableId) = hashJoinBuilder
.nativeBuild(
broadcastContext.buildHashTableId,
batchArray.toArray,
Expand All @@ -202,14 +202,14 @@ class UnsafeColumnarBuildSideRelation(
val elapsedTime = System.nanoTime() - startTime
broadcastContext.buildHashTableTimeMetric.foreach(_ += elapsedTime / 1000000)

(hashTableData, this)
(hashTableData(broadcastContext.buildHashTableId), this)
} else {
(HashJoinBuilder.cloneHashTable(hashTableData), null)
(HashJoinBuilder.cloneHashTable(hashTableData(broadcastContext.buildHashTableId)), null)
}
}

def reset(): Unit = synchronized {
hashTableData = 0
hashTableData.clear()
}

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.execution.{ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSubqueryBroadcastExec, InputIteratorTransformer}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
override protected val resourcePath: String = "/tpch-data-parquet"
Expand Down Expand Up @@ -322,4 +323,44 @@ class VeloxHashJoinSuite extends VeloxWholeStageTransformerSuite {
}
}
}

// Checks that two broadcast joins against the same table share one
// ColumnarBroadcastExchangeExec even though the joins are of different types.
test("Reuse broadcast exchange with different hash table") {
// Run with adaptive query execution turned off.
withSQLConf(
("spark.sql.adaptive.enabled", "false")
) {
withTable("t1", "t2") {
// t1: 100 rows, key == value == id.
spark
.range(100)
.selectExpr("id as key", "id as value")
.write
.saveAsTable("t1")

// t2: 100 rows, key is id % 7, so keys repeat.
spark
.range(100)
.selectExpr("id % 7 as key", "id as value")
.write
.saveAsTable("t2")

// Both UNION ALL branches broadcast t2 and join on t1.key = t2.key;
// the first branch is a LEFT SEMI JOIN, the second a plain JOIN.
val query = """
SELECT /*+ BROADCAST(t2) */ t1.key, t1.value
FROM t1
LEFT SEMI JOIN t2 ON t1.key = t2.key
UNION ALL
SELECT /*+ BROADCAST(t2) */ t1.key, t1.value
from t1
JOIN t2 on t1.key = t2.key
"""

runQueryAndCompare(query) {
df =>
// Check that columnar broadcast exchange is reused.
val plan = df.queryExecution.executedPlan
// Exactly one ColumnarBroadcastExchangeExec is collected from the plan ...
assert(collect(plan) { case b: ColumnarBroadcastExchangeExec => b }.size == 1)
// ... and exactly one ReusedExchangeExec wrapping a ColumnarBroadcastExchangeExec.
assert(collect(plan) {
case r @ ReusedExchangeExec(_, _: ColumnarBroadcastExchangeExec) => r
}.size == 1)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ trait BackendSettingsApi {

/** Default setting: returns `true`. */
def enableJoinKeysRewrite(): Boolean = true

/** Default setting: returns `true`; a backend may override this (the Velox backend reads it from its config). */
def enableHashTableBuildOncePerExecutor(): Boolean = true

def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
case _: InnerLike | RightOuter | FullOuter => true
case _ => false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,6 @@ trait SparkPlanExecApi {
dataSize: SQLMetric,
buildThreads: SQLMetric = null): BuildSideRelation

/**
 * Canonicalizes a broadcast mode.
 *
 * @param mode the broadcast mode to canonicalize
 * @return `mode.canonicalized` in this default implementation
 */
def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode =
  mode.canonicalized

/** Create ColumnarWriteFilesExec */
def createColumnarWriteFilesExec(
child: WriteFilesExecTransformer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
joinParams.isBHJ = true
}

val (_, _, hashTableId) = genJoinParametersInternal()
val joinRel = JoinUtils.createJoinRel(
streamedKeyExprs,
buildKeyExprs,
Expand All @@ -267,7 +268,7 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
inputBuildOutput,
context,
operatorId,
buildPlan.id.toString
hashTableId
)

context.registerJoinParam(operatorId, joinParams)
Expand Down
Loading