Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,62 @@ case class FilterExec(condition: Expression, child: SparkPlan)
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")

val predicateCode = generatePredicateCode(
ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
// Subexpression elimination for filter predicates in whole-stage codegen.
// Only collect otherPreds for CSE -- notNullPreds are simple IsNotNull checks
// (guaranteed by the partition logic in FilterExec's constructor) with no CSE value;
// including them would interfere with equivalence analysis.
//
// Note: CSE evaluation code is placed BEFORE predicate short-circuit checks.
// This means common subexpressions are evaluated unconditionally even if an earlier
// notNull check would have short-circuited. This is an intentional tradeoff:
// for expensive shared expressions (e.g., from_json appearing in 500 predicates),
// the benefit of evaluating once vs N times far outweighs the cost of losing
// short-circuit on the CSE portion. When there are no common subexpressions,
// subExprsCode is empty and this path has zero overhead.
// This is safe because Spark SQL expressions handle null inputs gracefully (returning
// null rather than throwing), so evaluating them before notNull guards does not
// introduce new exceptions.
val (inputVarsCode, subExprsCode, predicateCode) =
if (conf.subexpressionEliminationEnabled && otherPreds.nonEmpty) {
val boundOtherPreds = otherPreds.map(
BindReferences.bindReference(_, output))
// Pre-evaluate input variables referenced by otherPreds before CSE analysis.
// FilterExec sets usedInputs = AttributeSet.empty to defer input evaluation
// for short-circuit optimization. However, subexpressionEliminationForWholeStageCodegen's
// internal getLocalInputVariableValues has a side effect: it clears
// ctx.currentVars[i].code for input variables referenced by common subexpressions.
// In the non-split path, the cleared codes are baked into the subexpression eval code,
// and exprCodesNeedEvaluate is not returned. If notNullPreds reference the same input
// variables, generatePredicateCode's evaluateRequiredVariables would find empty code
// and skip their declarations, causing "is not an rvalue" compilation errors.
// By pre-evaluating here, we ensure input variable codes are already EmptyBlock before
// CSE analysis runs, avoiding the conflict.
val otherPredInputAttrs = AttributeSet(otherPreds.flatMap(_.references))
val inputVarsEvalCode = evaluateRequiredVariables(
child.output, input, otherPredInputAttrs)

val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundOtherPreds)
// withSubExprEliminationExprs requires Seq[ExprCode] return type, but we need
// the String result from generatePredicateCode. Use var + side-effect capture
// as a workaround for this API constraint.
val predCode: String = {
var code = ""
ctx.withSubExprEliminationExprs(subExprs.states) {
code = generatePredicateCode(
ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
Seq.empty
}
code
}
// evaluateSubExprEliminationState must be called after predicate code generation;
// it emits the pre-computation code and marks states as consumed.
(inputVarsEvalCode, ctx.evaluateSubExprEliminationState(subExprs.states.values), predCode)
} else {
// CSE disabled or no other predicates: fall back to original codegen path
// with no overhead.
("", "", generatePredicateCode(
ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes))
}

// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
Expand All @@ -268,6 +322,8 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// Note: wrap in "do { } while (false);", so the generated checks can jump out with "continue;"
s"""
|do {
| $inputVarsCode
| $subExprsCode
| $predicateCode
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAnd
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.debug.codegenString
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Disable AQE because the WholeStageCodegenExec is added when running QueryStageExec
class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
Expand Down Expand Up @@ -944,4 +945,92 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
}
}
}

test("SPARK-56032: subexpression elimination in FilterExec codegen") {
  // Exercise the CSE path in FilterExec.doConsume: a subexpression shared by several
  // filter predicates must be generated once and reused, instead of being inlined once
  // per predicate. Repeated inlining of expensive shared expressions (e.g. from_json)
  // was the code-bloat/performance problem behind the from_json codegen revert.
  def collectWithCodegen(cseEnabled: Boolean): (Seq[Row], String) = {
    withSQLConf(
      SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString,
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
      // A low split threshold forces the split code path in CSE, where common
      // subexpressions are extracted into separate helper functions.
      SQLConf.CODEGEN_METHOD_SPLIT_THRESHOLD.key -> "1") {
      val base = spark.range(10).selectExpr("id", "id as a", "id as b")
      // (a + b) is the subexpression shared by all three predicates below.
      val filtered = base.where("(a + b) > 3 AND (a + b) < 17 AND (a + b) != 10")
      val plan = filtered.queryExecution.executedPlan
      assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]),
        "Filter should be in whole-stage codegen")
      (filtered.collect().toSeq, codegenString(plan))
    }
  }

  val (withCse, codeWithCse) = collectWithCodegen(cseEnabled = true)
  val (withoutCse, codeWithoutCse) = collectWithCodegen(cseEnabled = false)

  // Functional correctness: both modes must agree on the query result.
  val expected = (2L to 8L).filter(_ * 2 != 10).map(i => Row(i, i, i))
  assert(withCse === expected)
  assert(withoutCse === expected)

  // Count how often the long-addition helper (MathUtils.addExact, used by codegen for
  // overflow-checked addition) appears: with CSE the shared (a + b) is computed once
  // and reused; without CSE the addition is emitted inside each predicate separately.
  def countAddExact(code: String): Int = "addExact".r.findAllIn(code).length
  val cseCount = countAddExact(codeWithCse)
  val plainCount = countAddExact(codeWithoutCse)
  assert(cseCount < plainCount,
    s"CSE should reduce repeated evaluation: addExact appears $cseCount times with CSE " +
      s"vs $plainCount times without")
}

test("SPARK-56032: FilterExec CSE with notNullPreds sharing input variables") {
  // Regression test: in the non-split path of
  // CodeGenerator.subexpressionEliminationForWholeStageCodegen,
  // getLocalInputVariableValues clears ctx.currentVars[i].code for input variables
  // referenced by common subexpressions (a side effect), but the saved
  // exprCodesNeedEvaluate used to be discarded (Seq.empty was returned). When
  // generatePredicateCode later processed notNullPreds referencing the same inputs,
  // evaluateRequiredVariables found empty code and skipped the variable declarations,
  // producing "is not an rvalue" compilation errors (e.g. TPC-DS q85).
  withSQLConf(
    SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> "true",
    SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
    // Nullable columns keep the IsNotNull predicates meaningful (not optimized away).
    val schema = StructType(Seq(
      StructField("a", IntegerType, nullable = true),
      StructField("b", IntegerType, nullable = true)))
    val rows = Seq(Row(1, 5), Row(null, 3), Row(4, null), Row(5, 6), Row(7, 8), Row(2, 3))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    // The filter condition splits into:
    //   notNullPreds: IsNotNull(a), IsNotNull(b) (inferred by the optimizer)
    //   otherPreds:   (a + b) > 3, (a + b) < 15 -- common subexpression (a + b)
    // The common subexpression references the same input variables (a, b) as the
    // notNullPreds, which is exactly the configuration that triggered the bug.
    val result = df.where("a IS NOT NULL AND (a + b) > 3 AND (a + b) < 15")

    val plan = result.queryExecution.executedPlan
    assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]),
      "Filter should be in whole-stage codegen")

    // Before the fix, generating the code failed with
    // "Expression ... is not an rvalue" because the input variable declarations were
    // lost when CSE cleared ctx.currentVars but dropped exprCodesNeedEvaluate.
    assert(codegenString(plan).nonEmpty, "Should generate valid code")

    // a+b per row: 6 (keep) | null a (drop) | null b (drop) | 11 (keep)
    //              | 15 (drop, not strictly < 15) | 5 (keep)
    checkAnswer(result, Seq(Row(1, 5), Row(5, 6), Row(2, 3)))
  }
}
}