
Replace multiple calls to withColumn with single select to simplify query plans #888

Open · wants to merge 4 commits into base: main
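Each `withColumn` / `withColumnRenamed` call wraps the previous logical plan in another `Project` node, so renaming or adding many columns in a loop produces a deeply nested analyzed plan that Catalyst has to re-analyze and later collapse. Expressing all the renames and additions in a single `select` keeps the plan to one projection. A minimal sketch of that difference, using hypothetical column names (`v1`, `v2`, `prefix_`) that are not taken from this diff:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object ProjectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("projection-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a", "b"), (2, "c", "d")).toDF("key", "v1", "v2")
    val valueColumns = Seq("v1", "v2")

    // Before: one Project node is stacked per withColumnRenamed call.
    val renamedPerColumn = valueColumns.foldLeft(df) { (acc, c) =>
      acc.withColumnRenamed(c, s"prefix_$c")
    }
    renamedPerColumn.explain(extended = true)

    // After: the same renames expressed once, yielding a single projection.
    val renamedOnce = df.select(df.columns.map { c =>
      if (valueColumns.contains(c)) col(c).as(s"prefix_$c") else col(c)
    }: _*)
    renamedOnce.explain(extended = true)
  }
}
```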
57 changes: 34 additions & 23 deletions spark/src/main/scala/ai/chronon/spark/Extensions.scala
@@ -16,17 +16,19 @@

package ai.chronon.spark

import org.slf4j.LoggerFactory
import ai.chronon.api
import ai.chronon.api.Constants
import ai.chronon.api.Extensions.JoinPartOps
import ai.chronon.api.{Constants, JoinPart}
import ai.chronon.online.{AvroCodec, AvroConversions, SparkConversions}
import org.apache.avro.Schema
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.util.sketch.BloomFilter
import org.slf4j.LoggerFactory

import java.util
import scala.collection.Seq
@@ -154,10 +156,36 @@ object Extensions {
TableUtils(df.sparkSession).insertUnPartitioned(df, tableName, tableProperties)
}

def prefixColumnNames(prefix: String, columns: Seq[String]): DataFrame = {
columns.foldLeft(df) { (renamedDf, key) =>
renamedDf.withColumnRenamed(key, s"${prefix}_$key")
/**
* Pads fields in a dataframe according to a schema.
* Fields in the schema that are not present in the dataframe
* are filled with null values.
*/
def padFields(structType: sql.types.StructType): DataFrame = {
val existingColumns = df.columns.toSet
val paddedColumns = structType
.filterNot(field => existingColumns.contains(field.name))
.map(field => lit(null).cast(field.dataType).as(field.name))
val columnsWithPadding = df.columns.map(col) ++ paddedColumns
df.select(columnsWithPadding: _*)
}

def renameRightColumnsForJoin(joinPart: JoinPart, timeColumns: Set[String]): DataFrame = {
val nonValueColumns = joinPart.rightToLeft.keys.toSet ++ timeColumns
val valueColumns = df.schema.names.toSet.diff(nonValueColumns)

val renamedColumns = df.columns.map { columnName =>
val column = col(columnName)
if (joinPart.rightToLeft.contains(columnName)) {
column.as(joinPart.rightToLeft(columnName))
} else if (valueColumns.contains(columnName)) {
column.as(s"${joinPart.fullPrefix}_$columnName")
} else {
column
}
}

df.select(renamedColumns: _*)
}

def validateJoinKeys(right: DataFrame, keys: Seq[String]): Unit = {
@@ -228,23 +256,6 @@ object Extensions {
df.filter(cols.map(_ + " IS NOT NULL").mkString(" AND "))
}

def nullSafeJoin(right: DataFrame, keys: Seq[String], joinType: String): DataFrame = {
validateJoinKeys(right, keys)
val prefixedLeft = df.prefixColumnNames("left", keys)
val prefixedRight = right.prefixColumnNames("right", keys)
val joinExpr = keys
.map(key => prefixedLeft(s"left_$key") <=> prefixedRight(s"right_$key"))
.reduce((col1, col2) => col1.and(col2))
val joined = prefixedLeft.join(
prefixedRight,
joinExpr,
joinType = joinType
)
keys.foldLeft(joined) { (renamedJoin, key) =>
renamedJoin.withColumnRenamed(s"left_$key", key).drop(s"right_$key")
}
}

// convert a millisecond timestamp to string with the specified format
def withTimeBasedColumn(columnName: String,
timeColumn: String = Constants.TimeColumn,
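As an illustration of the two helpers introduced above, a hedged usage sketch follows; the schema, column names, and JoinPart below are hypothetical (loosely mirroring the tests further down) and only meant to show the call shape:

```scala
import ai.chronon.api.Builders
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object ExtensionsUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("extensions-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical right-side dataframe of a join part.
    val rightDf = Seq((1, "a", "2023-05-01")).toDF("right_key1", "value1", "ds")

    // padFields: every field of the target schema missing from the dataframe is
    // appended as a typed null column, all in a single select.
    val targetSchema = StructType(Array(
      StructField("value1", StringType),
      StructField("value2", IntegerType)
    ))
    val padded = rightDf.padFields(targetSchema) // gains a null "value2" column

    // renameRightColumnsForJoin: key columns are renamed right -> left, remaining value
    // columns receive the join part's full prefix, and time/partition columns stay as-is.
    val joinPart = Builders.JoinPart(
      groupBy = Builders.GroupBy(metaData = Builders.MetaData(name = "example_gb"), keyColumns = Seq("right_key1")),
      keyMapping = Map("left_key1" -> "right_key1"),
      prefix = "example_prefix"
    )
    val renamed = rightDf.renameRightColumnsForJoin(joinPart, Set("ts", "ds"))

    padded.printSchema()
    renamed.printSchema()
  }
}
```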
15 changes: 2 additions & 13 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -71,17 +71,6 @@ class Join(joinConf: api.Join,
unsetSemanticHash: Boolean = false)
extends JoinBase(joinConf, endPartition, tableUtils, skipFirstHole, showDf, selectedJoinParts, unsetSemanticHash) {

private def padFields(df: DataFrame, structType: sql.types.StructType): DataFrame = {
structType.foldLeft(df) {
case (df, field) =>
if (df.columns.contains(field.name)) {
df
} else {
df.withColumn(field.name, lit(null).cast(field.dataType))
}
}
}

private def toSparkSchema(fields: Seq[StructField]): sql.types.StructType =
SparkConversions.fromChrononSchema(StructType("", fields.toArray))

@@ -99,7 +88,7 @@ class Join(joinConf: api.Join,
val contextualFields = toSparkSchema(
bootstrapInfo.externalParts.filter(_.externalPart.isContextual).flatMap(_.keySchema))

def withNonContextualFields(df: DataFrame): DataFrame = padFields(df, nonContextualFields)
def withNonContextualFields(df: DataFrame): DataFrame = df.padFields(nonContextualFields)

// Ensure keys and values for contextual fields are consistent even if only one of them is explicitly bootstrapped
def withContextualFields(df: DataFrame): DataFrame =
@@ -129,7 +118,7 @@
*/
private def padGroupByFields(baseJoinDf: DataFrame, bootstrapInfo: BootstrapInfo): DataFrame = {
val groupByFields = toSparkSchema(bootstrapInfo.joinParts.flatMap(_.valueSchema))
padFields(baseJoinDf, groupByFields)
baseJoinDf.padFields(groupByFields)
}

private def findBootstrapSetCoverings(bootstrapDf: DataFrame,
23 changes: 5 additions & 18 deletions spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -83,36 +83,23 @@ abstract class JoinBase(joinConf: api.Join,
}
val keys = partLeftKeys ++ additionalKeys

// apply prefix to value columns
val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,
tableUtils.partitionColumn,
Constants.TimePartitionColumn)
val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains)
val prefixedRightDf = rightDf.prefixColumnNames(joinPart.fullPrefix, valueColumns)

// apply key-renaming to key columns
val newColumns = prefixedRightDf.columns.map { column =>
if (joinPart.rightToLeft.contains(column)) {
col(column).as(joinPart.rightToLeft(column))
} else {
col(column)
}
}
val keyRenamedRightDf = prefixedRightDf.select(newColumns: _*)
val renamedRightDf = rightDf.renameRightColumnsForJoin(
joinPart,
Set(Constants.TimeColumn, tableUtils.partitionColumn, Constants.TimePartitionColumn))

// adjust join keys
val joinableRightDf = if (additionalKeys.contains(Constants.TimePartitionColumn)) {
// increment one day to align with left side ts_ds
// because one day was decremented from the partition range for snapshot accuracy
keyRenamedRightDf
renamedRightDf
.withColumn(
Constants.TimePartitionColumn,
date_format(date_add(to_date(col(tableUtils.partitionColumn), tableUtils.partitionSpec.format), 1),
tableUtils.partitionSpec.format)
)
.drop(tableUtils.partitionColumn)
} else {
keyRenamedRightDf
renamedRightDf
}

logger.info(s"""
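For the ts_ds adjustment in the hunk above, a small sketch (assuming the common "yyyy-MM-dd" partition format, which is not spelled out in this diff) of the one-day shift performed by the `date_format(date_add(to_date(...)))` expression:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, date_add, date_format, to_date}

object TimePartitionShiftSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ts-ds-shift").getOrCreate()
    import spark.implicits._

    // One day was decremented from the snapshot partition range, so the right side's
    // partition is shifted forward one day to line up with the left side's ts_ds.
    val right = Seq(("k1", "2023-05-01")).toDF("key", "ds")
    val shifted = right
      .withColumn("ts_ds", date_format(date_add(to_date(col("ds"), "yyyy-MM-dd"), 1), "yyyy-MM-dd"))
      .drop("ds")
    shifted.show() // k1 | 2023-05-02
  }
}
```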
26 changes: 10 additions & 16 deletions spark/src/main/scala/ai/chronon/spark/LabelJoin.scala
@@ -16,16 +16,16 @@

package ai.chronon.spark

import org.slf4j.LoggerFactory
import ai.chronon.api
import ai.chronon.api.DataModel.{Entities, Events}
import ai.chronon.api.Extensions._
import ai.chronon.api.{Constants, JoinPart, TimeUnit, Window}
import ai.chronon.spark.Extensions._
import ai.chronon.online.Metrics
import ai.chronon.spark.Extensions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit
import org.apache.spark.util.sketch.BloomFilter
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.collection.Seq
@@ -255,17 +255,11 @@ class LabelJoin(joinConf: api.Join, tableUtils: TableUtils, labelDS: String) {
leftDf
}

// apply key-renaming to key columns
val keyRenamedRight = joinPart.rightToLeft.foldLeft(rightDf) {
case (rightDf, (rightKey, leftKey)) => rightDf.withColumnRenamed(rightKey, leftKey)
}

val nonValueColumns = joinPart.rightToLeft.keys.toArray ++ Array(Constants.TimeColumn,
tableUtils.partitionColumn,
Constants.TimePartitionColumn,
Constants.LabelPartitionColumn)
val valueColumns = rightDf.schema.names.filterNot(nonValueColumns.contains)
val prefixedRight = keyRenamedRight.prefixColumnNames(joinPart.fullPrefix, valueColumns)
val renamedRightDf = rightDf.renameRightColumnsForJoin(joinPart,
Set(Constants.TimeColumn,
tableUtils.partitionColumn,
Constants.TimePartitionColumn,
Constants.LabelPartitionColumn))

val partName = joinPart.groupBy.metaData.name

@@ -274,11 +268,11 @@
|${updatedLeftDf.schema.pretty}
|
|Right Schema:
|${prefixedRight.schema.pretty}
|${renamedRightDf.schema.pretty}
|
|""".stripMargin)

updatedLeftDf.validateJoinKeys(prefixedRight, partLeftKeys)
updatedLeftDf.join(prefixedRight, partLeftKeys, "left_outer")
updatedLeftDf.validateJoinKeys(renamedRightDf, partLeftKeys)
updatedLeftDf.join(renamedRightDf, partLeftKeys, "left_outer")
}
}
71 changes: 68 additions & 3 deletions spark/src/test/scala/ai/chronon/spark/test/ExtensionsTest.scala
@@ -1,18 +1,49 @@
package ai.chronon.spark.test

import ai.chronon.api.Builders
import ai.chronon.spark.Extensions._
import ai.chronon.spark.test.TestUtils.compareDfSchemas
import ai.chronon.spark.{Comparison, PartitionRange, SparkSessionBuilder, TableUtils}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.Assert.assertEquals
import org.junit.Test

class ExtensionsTest {

lazy val spark: SparkSession = SparkSessionBuilder.build("ExtensionsTest", local = true)

import spark.implicits._

private implicit val tableUtils = TableUtils(spark)
private implicit val tableUtils: TableUtils = TableUtils(spark)

@Test
def testPadFields(): Unit = {
val df = Seq(
(1, "foo", null, 1.1),
(2, "bar", null, 2.2),
(3, "baz", null, 3.3)
).toDF("field1", "field2", "field3", "field4")
val paddedSchema = new StructType(Array(
StructField("field2", IntegerType),
StructField("field3", IntegerType),
StructField("field4", IntegerType),
StructField("field5", IntegerType),
StructField("field6", IntegerType),
))

val paddedDf = df.padFields(paddedSchema)

val expectedDf = Seq(
(1, "foo", null, 1.1, null, null),
(2, "bar", null, 2.2, null, null),
(3, "baz", null, 3.3, null, null)
).toDF("field1", "field2", "field3", "field4", "field5", "field6")
val diff = Comparison.sideBySide(expectedDf, paddedDf, List("field1"))
if (diff.count() > 0) {
diff.show()
}
assertEquals(0, diff.count())
}

@Test
def testPrunePartitionTest(): Unit = {
@@ -43,4 +74,38 @@ class ExtensionsTest {
}
assertEquals(0, diff.count())
}

@Test
def testRenameRightColumnsForJoin(): Unit = {
val schema = new StructType(Array(
StructField("key1", IntegerType),
StructField("key2", IntegerType),
StructField("right_key1", IntegerType),
StructField("right_key2", IntegerType),
StructField("value1", StringType),
StructField("value2", StringType),
StructField("ds", StringType),
StructField("ts", TimestampType),
))
val joinPart = Builders.JoinPart(
groupBy = Builders.GroupBy(metaData = Builders.MetaData(name = "test_gb"), keyColumns = Seq("key1", "key2")),
keyMapping = Map("left_key1" -> "right_key1", "left_key2" -> "right_key2"),
prefix = "test_prefix"
)
val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema)

val renamedDf = df.renameRightColumnsForJoin(joinPart, Set("ts", "ds"))

val expectedSchema = new StructType(Array(
StructField("key1", IntegerType),
StructField("key2", IntegerType),
StructField("left_key1", IntegerType),
StructField("left_key2", IntegerType),
StructField("test_prefix_test_gb_value1", StringType),
StructField("test_prefix_test_gb_value2", StringType),
StructField("ds", StringType),
StructField("ts", TimestampType),
))
assertEquals(compareDfSchemas(renamedDf.schema, expectedSchema), Seq())
}
}
26 changes: 26 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/TestUtils.scala
@@ -18,13 +18,17 @@ package ai.chronon.spark.test

import ai.chronon.aggregator.test.Column
import ai.chronon.api
import ai.chronon.api.Constants.ChrononMetadataKey
import ai.chronon.api.Extensions.MetadataOps
import ai.chronon.api._
import ai.chronon.online.SparkConversions
import ai.chronon.spark.Extensions._
import ai.chronon.spark.TableUtils
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StructType => SparkStructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

import scala.collection.JavaConverters._
import scala.util.ScalaJavaConversions.JListOps

object TestUtils {
@@ -342,6 +346,28 @@ object TestUtils {
groupBy
}

/** Compares two schemas, returns a list of differences if they are not equal. */
def compareDfSchemas(expectedSchema: SparkStructType, actualSchema: SparkStructType): Seq[String] = {
val expectedSchemaMap = expectedSchema.fields.map(field => field.name -> (field.dataType, field.nullable)).toMap
val expectedSchemaKeys = expectedSchemaMap.keys.toSet
val actualSchemaMap = actualSchema.fields.map(field => field.name -> (field.dataType, field.nullable)).toMap
val actualSchemaKeys = actualSchemaMap.keys.toSet
val added = actualSchemaKeys.diff(expectedSchemaKeys)
.map(field => s"unexpected field added $field: ${actualSchemaMap(field)}").toSeq
val removed = expectedSchemaKeys.diff(actualSchemaKeys)
.map(field => s"expected field not found $field: ${expectedSchemaMap(field)}").toSeq
val diffs = expectedSchemaKeys.intersect(actualSchemaKeys).map(key => {
val expectedField = expectedSchemaMap(key)
val actualField = actualSchemaMap(key)
if (expectedField != actualField) {
Some(s"unexpected field definition $key: expected: $expectedField found: $actualField")
} else {
None
}
}).filter(_.nonEmpty).map(_.get).toSeq
added ++ removed ++ diffs
}

def createSampleLabelTableDf(spark: SparkSession, tableName: String = "listing_labels"): DataFrame = {
val schema = StructType(
tableName,