@@ -42,19 +42,20 @@ class StatePartitionReaderFactory(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends PartitionReaderFactory {

override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val stateStoreInputPartition = partition.asInstanceOf[StateStoreInputPartition]
if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
} else {
new StatePartitionReader(storeConf, hadoopConf,
stateStoreInputPartition, schema, keyStateEncoderSpec, stateVariableInfoOpt,
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
stateStoreColFamilySchemaOpt, stateSchemaProviderOpt, joinColFamilyOpt)
}
}
}
@@ -71,7 +72,8 @@ abstract class StatePartitionReaderBase(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends PartitionReader[InternalRow] with Logging {
// Used primarily as a placeholder for the value schema in the context of
// state variables used within the transformWithState operator.
@@ -98,11 +100,7 @@ abstract class StatePartitionReaderBase(
partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)

val useColFamilies = if (stateVariableInfoOpt.isDefined) {
true
} else {
false
}
val useColFamilies = stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined

val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt,
StateVariableType.ListState)
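Note: the rewritten flag enables column-family support whenever either a transformWithState variable or a stream-stream join column family is being read. A minimal standalone sketch of the same decision (not the reader's code; the `Option[String]` parameters stand in for the real types):

```scala
// Column families are needed when reading either a transformWithState state
// variable or a stream-stream join column family (state format version 3).
def useColFamilies(
    stateVariableInfoOpt: Option[String],
    joinColFamilyOpt: Option[String]): Boolean =
  stateVariableInfoOpt.isDefined || joinColFamilyOpt.isDefined

assert(useColFamilies(Some("countState"), None))
assert(useColFamilies(None, Some("left-keyToNumValues")))
assert(!useColFamilies(None, None))
```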
@@ -164,10 +162,11 @@ class StatePartitionReader(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt) {
stateSchemaProviderOpt, joinColFamilyOpt) {

private lazy val store: ReadStateStore = {
partition.sourceOptions.fromSnapshotOptions match {
@@ -186,17 +185,18 @@
}

override lazy val iter: Iterator[InternalRow] = {
val stateVarName = stateVariableInfoOpt
.map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
val colFamilyName = stateStoreColFamilySchemaOpt
.map(_.colFamilyName).getOrElse(
joinColFamilyOpt.getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME))

if (stateVariableInfoOpt.isDefined) {
val stateVariableInfo = stateVariableInfoOpt.get
val stateVarType = stateVariableInfo.stateVariableType
SchemaUtil.processStateEntries(stateVarType, stateVarName, store,
SchemaUtil.processStateEntries(stateVarType, colFamilyName, store,
keySchema, partition.partition, partition.sourceOptions)
} else {
store
.iterator(stateVarName)
.iterator(colFamilyName)
.map { pair =>
SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition)
}
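Note: the column family to scan is now resolved with an explicit precedence: an explicit TWS column family schema first, then the join column family, then the default family. A hypothetical helper expressing that precedence:

```scala
// Precedence mirrored from the code above:
// TWS column family schema > join column family > default store family.
val DefaultColFamilyName = "default" // mirrors StateStore.DEFAULT_COL_FAMILY_NAME

def resolveColFamilyName(
    colFamilySchemaNameOpt: Option[String],
    joinColFamilyOpt: Option[String]): String =
  colFamilySchemaNameOpt.getOrElse(joinColFamilyOpt.getOrElse(DefaultColFamilyName))
```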
@@ -221,10 +221,11 @@ class StateStoreChangeDataPartitionReader(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt) {
stateSchemaProviderOpt, joinColFamilyOpt) {

private lazy val changeDataReader:
NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = {
@@ -235,6 +236,8 @@

val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) {
Some(stateVariableInfoOpt.get.stateName)
} else if (joinColFamilyOpt.isDefined) {
Some(joinColFamilyOpt.get)
} else {
None
}
@@ -45,9 +45,11 @@ class StateScanBuilder(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider]) extends ScanBuilder {
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String]) extends ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt)
}

/** An implementation of [[InputPartition]] for State Store data source. */
@@ -65,7 +67,8 @@ class StateScan(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends Scan with Batch {

// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -120,24 +123,28 @@ class StateScan(
override def createReaderFactory(): PartitionReaderFactory = sourceOptions.joinSide match {
case JoinSideValues.left =>
val userFacingSchema = schema
val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions,
hadoopConfBroadcast.value.value)
val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, LeftSide,
excludeAuxColumns = false)
oldSchemaFilePaths, excludeAuxColumns = false)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)

case JoinSideValues.right =>
val userFacingSchema = schema
val oldSchemaFilePaths = StateDataSource.getOldSchemaFilePaths(sourceOptions,
hadoopConfBroadcast.value.value)
val stateSchema = StreamStreamJoinStateHelper.readSchema(session,
sourceOptions.stateCheckpointLocation.toString, sourceOptions.operatorId, RightSide,
excludeAuxColumns = false)
oldSchemaFilePaths, excludeAuxColumns = false)
new StreamStreamJoinStatePartitionReaderFactory(stateStoreConf,
hadoopConfBroadcast.value, userFacingSchema, stateSchema)

case JoinSideValues.none =>
new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt)
stateSchemaProviderOpt, joinColFamilyOpt)
}
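For context, a hedged usage sketch of how a query would reach the left branch above through the state data source (assumes an active SparkSession named `spark`; the checkpoint path is a placeholder):

```scala
// Read the left side of a stream-stream join's state from a checkpoint.
// "statestore" is the state data source; joinSide selects the branch above.
val leftJoinState = spark.read
  .format("statestore")
  .option("path", "/tmp/query-checkpoint") // placeholder checkpoint location
  .option("joinSide", "left")              // or "right"
  .load()
leftJoinState.printSchema()                // key/value structs from readSchema
```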

override def toBatch: Batch = this
@@ -44,7 +44,8 @@ class StateTable(
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
stateSchemaProviderOpt: Option[StateSchemaProvider])
stateSchemaProviderOpt: Option[StateSchemaProvider],
joinColFamilyOpt: Option[String])
extends Table with SupportsRead with SupportsMetadataColumns {

import StateTable._
@@ -85,7 +86,8 @@

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt)
stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt)

override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

@@ -18,8 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2.state

import java.util.UUID

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinSide
import org.apache.spark.sql.execution.streaming.CheckpointFileManager
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{JoinSide, LeftSide}
import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreId, StateStoreProviderId, SymmetricHashJoinStateManager}
import org.apache.spark.sql.types.{BooleanType, StructType}

@@ -35,52 +39,92 @@ object StreamStreamJoinStateHelper {
stateCheckpointLocation: String,
operatorId: Int,
side: JoinSide,
oldSchemaFilePaths: List[Path],
excludeAuxColumns: Boolean = true): StructType = {
val (keySchema, valueSchema) = readKeyValueSchema(session, stateCheckpointLocation,
operatorId, side, excludeAuxColumns)
operatorId, side, oldSchemaFilePaths, excludeAuxColumns)

new StructType()
.add("key", keySchema)
.add("value", valueSchema)
}
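For illustration, a sketch of the shape readSchema returns for one join side; the field names here are invented, and only the key/value nesting and the trailing `matched` column reflect the code above:

```scala
import org.apache.spark.sql.types._

// Illustrative only: the real fields come from the checkpointed key/value
// schemas. "matched" is the auxiliary column kept when excludeAuxColumns =
// false (join state format versions >= 2).
val exampleJoinStateSchema = new StructType()
  .add("key", new StructType().add("id", StringType))
  .add("value", new StructType()
    .add("count", IntegerType)
    .add("matched", BooleanType))
```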

// Returns whether the checkpoint uses stateFormatVersion 3, which backs the join
// with virtual column families (VCF).
def usesVirtualColumnFamilies(
hadoopConf: Configuration,
stateCheckpointLocation: String,
operatorId: Int): Boolean = {
// If the schema exists for operatorId/partitionId/left-keyToNumValues, it is not
// stateFormatVersion 3.
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, SymmetricHashJoinStateManager.allStateStoreNames(LeftSide).toList.head)
val schemaFilePath = StateSchemaCompatibilityChecker.schemaFile(
storeId.storeCheckpointLocation())
val fm = CheckpointFileManager.create(schemaFilePath, hadoopConf)
!fm.exists(schemaFilePath)
}
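The probe works off the layout difference between format versions: versions 1 and 2 write a separate state store (with its own schema file) per side and component, while version 3 keeps everything in the default store as virtual column families. A sketch of the names involved, assuming they match what SymmetricHashJoinStateManager.allStateStoreNames produces:

```scala
// Format versions 1-2: one state store per side and component, each with its
// own schema file under the checkpoint.
val v1v2StoreNames = Seq(
  "left-keyToNumValues", "left-keyWithIndexToValue",
  "right-keyToNumValues", "right-keyWithIndexToValue")

// Format version 3: a single default store; the same components live inside it
// as virtual column families, so no per-component schema file exists.
val v3StoreNames = Seq("default") // StateStoreId.DEFAULT_STORE_NAME
```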

def readKeyValueSchema(
session: SparkSession,
stateCheckpointLocation: String,
operatorId: Int,
side: JoinSide,
oldSchemaFilePaths: List[Path],
excludeAuxColumns: Boolean = true): (StructType, StructType) = {

val newHadoopConf = session.sessionState.newHadoopConf()
val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
// KeyToNumValuesType, KeyWithIndexToValueType
val storeNames = SymmetricHashJoinStateManager.allStateStoreNames(side).toList

val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, storeNames(0))
val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues,
UUID.randomUUID())
val (keySchema, valueSchema) =
if (!usesVirtualColumnFamilies(
newHadoopConf, stateCheckpointLocation, operatorId)) {
val storeIdForKeyToNumValues = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, storeNames(0))
val providerIdForKeyToNumValues = new StateStoreProviderId(storeIdForKeyToNumValues,
UUID.randomUUID())

val storeIdForKeyWithIndexToValue = new StateStoreId(stateCheckpointLocation,
operatorId, partitionId, storeNames(1))
val providerIdForKeyWithIndexToValue = new StateStoreProviderId(storeIdForKeyWithIndexToValue,
UUID.randomUUID())
val storeIdForKeyWithIndexToValue = new StateStoreId(stateCheckpointLocation,
operatorId, partitionId, storeNames(1))
val providerIdForKeyWithIndexToValue = new StateStoreProviderId(
storeIdForKeyWithIndexToValue, UUID.randomUUID())

val newHadoopConf = session.sessionState.newHadoopConf()
// read the key schema from the keyToNumValues store for the join keys
val manager = new StateSchemaCompatibilityChecker(
providerIdForKeyToNumValues, newHadoopConf, oldSchemaFilePaths)
val kSchema = manager.readSchemaFile().head.keySchema

// read the value schema from the keyWithIndexToValue store for the values
val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
newHadoopConf, oldSchemaFilePaths)
val vSchema = manager2.readSchemaFile().head.valueSchema

(kSchema, vSchema)
} else {
val storeId = new StateStoreId(stateCheckpointLocation, operatorId,
partitionId, StateStoreId.DEFAULT_STORE_NAME)
val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())

val manager = new StateSchemaCompatibilityChecker(
providerId, newHadoopConf, oldSchemaFilePaths)
val kSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(0)
}.map(_.keySchema).get

// read the key schema from the keyToNumValues store for the join keys
val manager = new StateSchemaCompatibilityChecker(providerIdForKeyToNumValues, newHadoopConf)
val keySchema = manager.readSchemaFile().head.keySchema
val vSchema = manager.readSchemaFile().find { schema =>
schema.colFamilyName == storeNames(1)
}.map(_.valueSchema).get

// read the value schema from the keyWithIndexToValue store for the values
val manager2 = new StateSchemaCompatibilityChecker(providerIdForKeyWithIndexToValue,
newHadoopConf)
val valueSchema = manager2.readSchemaFile().head.valueSchema
(kSchema, vSchema)
}

val maybeMatchedColumn = valueSchema.last

if (excludeAuxColumns
&& maybeMatchedColumn.name == "matched"
&& maybeMatchedColumn.dataType == BooleanType) {
&& maybeMatchedColumn.name == "matched"
&& maybeMatchedColumn.dataType == BooleanType) {
// remove internal column `matched` for format version 2
(keySchema, StructType(valueSchema.dropRight(1)))
} else {
@@ -80,8 +80,18 @@ class StreamStreamJoinStatePartitionReader(
private val (inputAttributes, formatVersion) = {
val maybeMatchedColumn = valueSchema.last
val (fields, version) = {
// If there is a matched column, version is either 2 or 3. We need to drop the matched
// column from the value schema to get the actual fields.
if (maybeMatchedColumn.name == "matched" && maybeMatchedColumn.dataType == BooleanType) {
(valueSchema.dropRight(1), 2)
// If the checkpoint uses a single store with virtual column families, the version is 3
if (StreamStreamJoinStateHelper.usesVirtualColumnFamilies(
hadoopConf.value,
partition.sourceOptions.stateCheckpointLocation.toString,
partition.sourceOptions.operatorId)) {
(valueSchema.dropRight(1), 3)
} else {
(valueSchema.dropRight(1), 2)
}
} else {
(valueSchema, 1)
}
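In short: a trailing `matched` Boolean column means format version 2 or 3, and virtual column families distinguish 3 from 2; otherwise it is version 1. A compact sketch of that mapping (hypothetical helper, not the reader's code):

```scala
// hasMatchedCol: the value schema ends with a Boolean "matched" column.
// usesVcf: result of StreamStreamJoinStateHelper.usesVirtualColumnFamilies(...).
def joinStateFormatVersion(hasMatchedCol: Boolean, usesVcf: Boolean): Int =
  (hasMatchedCol, usesVcf) match {
    case (true, true)  => 3
    case (true, false) => 2
    case (false, _)    => 1
  }
```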
@@ -137,7 +147,7 @@ class StreamStreamJoinStatePartitionReader(
inputAttributes)

joinStateManager.iterator.map { pair =>
if (formatVersion == 2) {
if (formatVersion >= 2) {
val row = valueWithMatchedRowGenerator(pair.value)
row.setBoolean(indexOrdinalInValueWithMatchedRow, pair.matched)
unifyStateRowPair(pair.key, row)
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowPythonRunner
import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets}
import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulOperatorsUtils, StatefulProcessorHandleImpl, TransformWithStateExecBase, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
import org.apache.spark.sql.internal.SQLConf
@@ -95,9 +95,9 @@ case class TransformWithStateInPySparkExec(
override def shortName: String = if (
userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
) {
"transformWithStateInPandasExec"
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME
} else {
"transformWithStateInPySparkExec"
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME
}

private val pythonUDF = functionExpr.asInstanceOf[PythonUDF]
@@ -236,7 +236,7 @@ case class StreamingSymmetricHashJoinExec(
case _ => throwBadJoinTypeException()
}

override def shortName: String = "symmetricHashJoin"
override def shortName: String = StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME

override val stateStoreNames: Seq[String] = _stateStoreNames

@@ -737,9 +737,9 @@ abstract class SymmetricHashJoinStateManager(
if (useVirtualColumnFamilies) {
stateStore.createColFamilyIfAbsent(
colFamilyName,
keySchema,
keyWithIndexSchema,
valueRowConverter.valueAttributes.toStructType,
NoPrefixKeyStateEncoderSpec(keySchema)
NoPrefixKeyStateEncoderSpec(keyWithIndexSchema)
)
}
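The fix is needed because rows in the KeyWithIndexToValue store are keyed by the join key plus a row index, so the column family has to be registered with that composite schema for the encoder to match what is written. A sketch of the assumed relationship between the two schemas (illustrative field names):

```scala
import org.apache.spark.sql.types._

// Illustrative join key schema.
val keySchema = new StructType().add("id", StringType)

// Rows in this store are keyed by (join key, index), so the column family is
// registered with the composite schema rather than the bare join key schema.
val keyWithIndexSchema = keySchema.add("index", LongType)
```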

@@ -1546,3 +1546,16 @@ trait SchemaValidationUtils extends Logging {
schemaEvolutionEnabled = usingAvro && schemaEvolutionEnabledForOperator))
}
}

object StatefulOperatorsUtils {
val TRANSFORM_WITH_STATE_EXEC_OP_NAME = "transformWithStateExec"
val TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME = "transformWithStateInPandasExec"
val TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME = "transformWithStateInPySparkExec"
// Seq of operator names that use state schema v3 and TWS-related options.
val TRANSFORM_WITH_STATE_OP_NAMES: Seq[String] = Seq(
TRANSFORM_WITH_STATE_EXEC_OP_NAME,
TRANSFORM_WITH_STATE_IN_PANDAS_EXEC_OP_NAME,
TRANSFORM_WITH_STATE_IN_PYSPARK_EXEC_OP_NAME
)
val SYMMETRIC_HASH_JOIN_EXEC_OP_NAME = "symmetricHashJoin"
}
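A small usage sketch for the new constants (hypothetical call site, not in the diff):

```scala
// Check whether an operator's shortName is one of the transformWithState
// variants that use state schema v3 and TWS-related options.
def isTransformWithStateOperator(shortName: String): Boolean =
  StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(shortName)
```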