Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3133,14 +3133,23 @@ object SQLConf {
"State between versions tends to be incompatible, so state format version shouldn't " +
"be modified after running. Version 3 uses a single state store with virtual column " +
"families instead of four stores and is only supported with RocksDB. NOTE: version " +
"1 is DEPRECATED and should not be explicitly set by users.")
"1 is DEPRECATED and should not be explicitly set by users. " +
"Version 4 is under development and only available for testing.")
.version("3.0.0")
.intConf
// TODO: [SPARK-55628] Add version 4 once we integrate the state format version 4 into
// stream-stream join operator.
.checkValue(v => Set(1, 2, 3).contains(v), "Valid versions are 1, 2, and 3")
.checkValue(v => Set(1, 2, 3, 4).contains(v), "Valid versions are 1, 2, 3, and 4")
.createWithDefault(2)

val STREAMING_JOIN_STATE_FORMAT_V4_ENABLED =
buildConf("spark.sql.streaming.join.stateFormatV4.enabled")
.internal()
.doc("When true, enables state format version 4 for stream-stream joins. " +
"This config will be removed once V4 is complete.")
.version("4.2.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefaultFunction(() => Utils.isTesting)

val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
.doc("When true, streaming session window sorts and merges sessions in local partition " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ case class StreamingSymmetricHashJoinExec(
private val allowMultipleStatefulOperators =
conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)

private val useVirtualColumnFamilies = stateFormatVersion == 3
private val useVirtualColumnFamilies = stateFormatVersion >= 3

// Determine the store names and metadata version based on format version
private val (numStoresPerPartition, _stateStoreNames, _operatorStateMetadataVersion) =
Expand Down Expand Up @@ -292,8 +292,12 @@ case class StreamingSymmetricHashJoinExec(
val info = getStateInfo
val stateSchemaDir = stateSchemaDirPath()

// V4 uses VCF like V3, which requires schema version 3. The stateSchemaVersion
// parameter may carry the stateFormatVersion (e.g. 4) from IncrementalExecution,
// so we hardcode 3 here for the VCF path.
val effectiveSchemaVersion = 3
validateAndWriteStateSchema(
hadoopConf, batchId, stateSchemaVersion, info, stateSchemaDir, session
hadoopConf, batchId, effectiveSchemaVersion, info, stateSchemaDir, session
)
} else {
var result: Map[String, (StructType, StructType)] = Map.empty
Expand Down Expand Up @@ -437,7 +441,7 @@ case class StreamingSymmetricHashJoinExec(
removedRowIter.filterNot { kv =>
stateFormatVersion match {
case 1 => matchesWithRightSideState(new UnsafeRowPair(kv.key, kv.value))
case 2 | 3 => kv.matched
case 2 | 3 | 4 => kv.matched
case _ => throwBadStateFormatVersionException()
}
}.map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
Expand All @@ -463,7 +467,7 @@ case class StreamingSymmetricHashJoinExec(
removedRowIter.filterNot { kv =>
stateFormatVersion match {
case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
case 2 | 3 => kv.matched
case 2 | 3 | 4 => kv.matched
case _ => throwBadStateFormatVersionException()
}
}.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
Expand All @@ -479,7 +483,7 @@ case class StreamingSymmetricHashJoinExec(
case FullOuter =>
lazy val isKeyToValuePairMatched = (kv: KeyToValuePair) =>
stateFormatVersion match {
case 2 | 3 => kv.matched
case 2 | 3 | 4 => kv.matched
case _ => throwBadStateFormatVersionException()
}

Expand Down Expand Up @@ -801,15 +805,15 @@ case class StreamingSymmetricHashJoinExec(
s.evictByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(stateWatermark)
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictByTimestamp(stateWatermark)
s.evictByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
}
case _ => 0L
}
Expand All @@ -833,20 +837,27 @@ case class StreamingSymmetricHashJoinExec(
s.evictAndReturnByKeyCondition(stateKeyWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(stateWatermark)
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
}
case Some(JoinStateValueWatermarkPredicate(_, stateWatermark)) =>
joinStateManager match {
case s: SupportsEvictByCondition =>
s.evictAndReturnByValueCondition(stateValueWatermarkPredicateFunc)

case s: SupportsEvictByTimestamp =>
s.evictAndReturnByTimestamp(stateWatermark)
s.evictAndReturnByTimestamp(watermarkMsToStateTimestamp(stateWatermark))
}
case _ => Iterator.empty
}
}

/**
* V4 stores timestamps in microseconds (TimestampType) while the watermark
* is tracked in milliseconds. Convert ms to microseconds for eviction calls.
*/
private def watermarkMsToStateTimestamp(watermarkMs: Long): Long =
watermarkMs * 1000

/** Commit changes to the buffer state */
def commitState(): Unit = {
joinStateManager.commit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, WatermarkSupport}
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper._
import org.apache.spark.sql.execution.streaming.state.{DropLastNFieldsStatePartitionKeyExtractor, KeyStateEncoderSpec, NoopStatePartitionKeyExtractor, NoPrefixKeyStateEncoderSpec, StatePartitionKeyExtractor, StateSchemaBroadcast, StateStore, StateStoreCheckpointInfo, StateStoreColFamilySchema, StateStoreConf, StateStoreErrors, StateStoreId, StateStoreMetrics, StateStoreProvider, StateStoreProviderId, SupportsFineGrainedReplay, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec, TimestampKeyStateEncoder}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType, LongType, NullType, StructField, StructType}
import org.apache.spark.util.NextIterator

Expand Down Expand Up @@ -252,9 +253,9 @@ class SymmetricHashJoinStateManagerV4(
Seq(StructField("dummy", NullType, nullable = true))
)

// TODO: [SPARK-55628] Below two fields need to be handled properly during integration with
// the operator.
private val stateStoreCkptId: Option[String] = None
// V4 uses a single store with VCFs (not separate keyToNumValues/keyWithIndexToValue stores).
// Use the keyToNumValues checkpoint ID for loading the correct committed version.
private val stateStoreCkptId: Option[String] = keyToNumValuesStateStoreCkptId
private val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None

private var stateStoreProvider: StateStoreProvider = _
Expand Down Expand Up @@ -496,7 +497,7 @@ class SymmetricHashJoinStateManagerV4(
private val attachTimestampProjection: UnsafeProjection =
TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)

// Create the specific column family in the store for this join side's KeyWithIndexToValueStore
// Create the specific column family in the store for this join side's KeyWithTsToValuesStore.
stateStore.createColFamilyIfAbsent(
colFamilyName,
keySchema,
Expand Down Expand Up @@ -648,13 +649,15 @@ class SymmetricHashJoinStateManagerV4(
private val attachTimestampProjection: UnsafeProjection =
TimestampKeyStateEncoder.getAttachTimestampProjection(keySchema)

// Create the specific column family in the store for this join side's KeyWithIndexToValueStore
// Create the specific column family in the store for this join side's TsWithKeyStore.
// Mark as internal so that numKeys counts only primary data, not the secondary index.
stateStore.createColFamilyIfAbsent(
colFamilyName,
keySchema,
valueStructType,
TimestampAsPrefixKeyStateEncoderSpec(keySchemaWithTimestamp),
useMultipleValuesPerKey = true
useMultipleValuesPerKey = true,
isInternal = true
)

private def createKeyRow(key: UnsafeRow, timestamp: Long): UnsafeRow = {
Expand Down Expand Up @@ -1311,8 +1314,8 @@ abstract class SymmetricHashJoinStateManagerBase(
val handlerSnapshotOptions: Option[HandlerSnapshotOptions] = None)
extends StateStoreHandler(
KeyToNumValuesType, keyToNumValuesStateStoreCkptId, handlerSnapshotOptions) {
SnapshotOptions
private val useVirtualColumnFamilies = stateFormatVersion == 3

private val useVirtualColumnFamilies = stateFormatVersion >= 3
private val longValueSchema = new StructType().add("value", "long")
private val longToUnsafeRow = UnsafeProjection.create(longValueSchema)
private val valueRow = longToUnsafeRow(new SpecificInternalRow(longValueSchema))
Expand Down Expand Up @@ -1411,7 +1414,7 @@ SnapshotOptions
extends StateStoreHandler(
KeyWithIndexToValueType, keyWithIndexToValueStateStoreCkptId, handlerSnapshotOptions) {

private val useVirtualColumnFamilies = stateFormatVersion == 3
private val useVirtualColumnFamilies = stateFormatVersion >= 3
private val keyWithIndexExprs = keyAttributes :+ Literal(1L)
private val keyWithIndexSchema = keySchema.add("index", LongType)
private val indexOrdinalInKeyWithIndexRow = keyAttributes.size
Expand Down Expand Up @@ -1744,6 +1747,8 @@ object SymmetricHashJoinStateManager {
snapshotOptions: Option[SnapshotOptions] = None,
joinStoreGenerator: JoinStateManagerStoreGenerator): SymmetricHashJoinStateManager = {
if (stateFormatVersion == 4) {
require(SQLConf.get.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_V4_ENABLED),
"State format version 4 is under development.")
new SymmetricHashJoinStateManagerV4(
joinSide, inputValueAttributes, joinKeys, stateInfo, storeConf, hadoopConf,
partitionId, keyToNumValuesStateStoreCkptId, keyWithIndexToValueStateStoreCkptId,
Expand Down Expand Up @@ -1780,28 +1785,44 @@ object SymmetricHashJoinStateManager {
inputValueAttributes: Seq[Attribute],
joinKeys: Seq[Expression],
stateFormatVersion: Int): Map[String, (StructType, StructType)] = {
var result: Map[String, (StructType, StructType)] = Map.empty

// get the key and value schema for the KeyToNumValues state store
val keySchema = StructType(
joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) })
val longValueSchema = new StructType().add("value", "long")
result += (getStateStoreName(joinSide, KeyToNumValuesType) -> (keySchema, longValueSchema))

// get the key and value schema for the KeyWithIndexToValue state store
val keyWithIndexSchema = keySchema.add("index", LongType)
val valueSchema = if (stateFormatVersion == 1) {
inputValueAttributes
} else if (stateFormatVersion == 2 || stateFormatVersion == 3) {
inputValueAttributes :+ AttributeReference("matched", BooleanType)()

if (stateFormatVersion == 4) {
// V4 uses two column families: KeyWithTsToValues and TsWithKey
val keySchemaWithTimestamp =
TimestampKeyStateEncoder.keySchemaWithTimestamp(keySchema)
val valueWithMatchedSchema =
(inputValueAttributes :+ AttributeReference("matched", BooleanType)()).toStructType
val dummyValueSchema = StructType(Array(StructField("__dummy__", NullType)))

Map(
getStateStoreName(joinSide, KeyWithTsToValuesType) ->
(keySchemaWithTimestamp, valueWithMatchedSchema),
getStateStoreName(joinSide, TsWithKeyType) ->
(keySchemaWithTimestamp, dummyValueSchema))
} else {
throw new IllegalArgumentException("Incorrect state format version! " +
s"version=$stateFormatVersion")
}
result += (getStateStoreName(joinSide, KeyWithIndexToValueType) ->
(keyWithIndexSchema, valueSchema.toStructType))
var result: Map[String, (StructType, StructType)] = Map.empty

// get the key and value schema for the KeyToNumValues state store
val longValueSchema = new StructType().add("value", "long")
result += (getStateStoreName(joinSide, KeyToNumValuesType) -> (keySchema, longValueSchema))

// get the key and value schema for the KeyWithIndexToValue state store
val keyWithIndexSchema = keySchema.add("index", LongType)
val valueSchema = if (stateFormatVersion == 1) {
inputValueAttributes
} else if (stateFormatVersion == 2 || stateFormatVersion == 3) {
inputValueAttributes :+ AttributeReference("matched", BooleanType)()
} else {
throw new IllegalArgumentException("Incorrect state format version! " +
s"version=$stateFormatVersion")
}
result += (getStateStoreName(joinSide, KeyWithIndexToValueType) ->
(keyWithIndexSchema, valueSchema.toStructType))

result
result
}
}

/** Retrieves the schemas used for join operator state stores that use column families */
Expand All @@ -1816,9 +1837,18 @@ object SymmetricHashJoinStateManager {

schemas.map {
case (colFamilyName, (keySchema, valueSchema)) =>
val keyStateEncoderSpec = if (stateFormatVersion == 4) {
if (colFamilyName == getStateStoreName(joinSide, KeyWithTsToValuesType)) {
TimestampAsPostfixKeyStateEncoderSpec(keySchema)
} else {
TimestampAsPrefixKeyStateEncoderSpec(keySchema)
}
} else {
NoPrefixKeyStateEncoderSpec(keySchema)
}
colFamilyName -> StateStoreColFamilySchema(
colFamilyName, 0, keySchema, 0, valueSchema,
Some(NoPrefixKeyStateEncoderSpec(keySchema))
Some(keyStateEncoderSpec)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ object KeyStateEncoderSpec {
case "PrefixKeyScanStateEncoderSpec" =>
val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[BigInt].toInt
PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
case "TimestampAsPostfixKeyStateEncoderSpec" =>
TimestampAsPostfixKeyStateEncoderSpec(keySchema)
case "TimestampAsPrefixKeyStateEncoderSpec" =>
TimestampAsPrefixKeyStateEncoderSpec(keySchema)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
.select($"key", $"window.end".cast("long"), $"leftValue", $"rightValue")

val useVirtualColumnFamilies =
spark.sessionState.conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) == 3
spark.sessionState.conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) >= 3
// Number of shuffle partitions being used is 3
val numStateStoreInstances = if (useVirtualColumnFamilies) {
// Only one state store is created per partition if we're using virtual column families
Expand Down
Loading