Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,35 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null

/**
* A mapping from [[ExprId]] to [[ExprCode]] for lambda variables that are currently in scope.
* This is used by [[NamedLambdaVariable]] to look up pre-computed variable bindings set by
* enclosing higher-order functions during code generation.
*
* The enclosing higher-order function registers entries before generating the lambda body code,
* and restores the previous state after. This follows the same save/restore pattern as
* `currentVars`/`INPUT_ROW`.
*
* Note: Like other mutable state in CodegenContext (e.g., `currentVars`, `INPUT_ROW`),
* this is not thread-safe. Callers must ensure single-threaded access during code generation.
*/
var lambdaVariableMap: Map[ExprId, ExprCode] = Map.empty

/**
* Registers lambda variable bindings, executes the given block,
* then restores the previous bindings. This ensures lambda variable scoping is correct
* for nested higher-order functions.
*
* Note: bindings from inner HOFs take precedence over outer ones via `Map.++`.
* This is safe because [[ExprId]]s are globally unique; inner and outer lambda
* variables will never share the same ExprId.
*/
def withLambdaVariableBindings[T](bindings: Map[ExprId, ExprCode])(f: => T): T = {
val oldBindings = lambdaVariableMap
lambdaVariableMap = lambdaVariableMap ++ bindings
try f finally { lambdaVariableMap = oldBindings }
}

/**
* Holding expressions' inlined mutable states like `MonotonicallyIncreasingID.count` as a
* 2-tuple: java type, variable name.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.mutable
import scala.jdk.CollectionConverters.MapHasAsScala

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern._
Expand Down Expand Up @@ -81,8 +84,9 @@ case class NamedLambdaVariable(
exprId: ExprId = NamedExpression.newExprId,
value: AtomicReference[Any] = new AtomicReference())
extends LeafExpression
with NamedExpression
with CodegenFallback {
with NamedExpression {

final override val nodePatterns: Seq[TreePattern] = Seq(LAMBDA_VARIABLE)

override def qualifier: Seq[String] = Seq.empty

Expand All @@ -98,13 +102,59 @@ case class NamedLambdaVariable(

override def eval(input: InternalRow): Any = value.get

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ctx.lambdaVariableMap.get(exprId) match {
case Some(binding) =>
// Lambda variable has been bound by an enclosing higher-order function.
// Return the binding directly -- it already contains the correct code,
// isNull, and value referencing the mutable state fields.
binding
case None =>
// No binding found -- fall back to interpreted eval via references array.
// This is unexpected in normal operation (the enclosing HOF should have registered
// bindings), but we degrade gracefully rather than failing the query.
NamedLambdaVariable.warnNoCodegenBinding(name, exprId)
val idx = ctx.references.length
ctx.references += this
val objectTerm = ctx.freshName("lambdaValue")
val javaType = CodeGenerator.javaType(dataType)
// Pass null as the input row because NamedLambdaVariable.eval() ignores
// the input row entirely -- it reads its value from the AtomicReference
// set by the enclosing HOF's eval loop.
if (nullable) {
ev.copy(code = code"""
Object $objectTerm = ((Expression) references[$idx]).eval(null);
boolean ${ev.isNull} = $objectTerm == null;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = (${CodeGenerator.boxedType(dataType)}) $objectTerm;
}""")
} else {
ev.copy(code = code"""
Object $objectTerm = ((Expression) references[$idx]).eval(null);
$javaType ${ev.value} = (${CodeGenerator.boxedType(dataType)}) $objectTerm;
""", isNull = FalseLiteral)
}
}
}

override def toString: String = s"lambda $name#${exprId.id}$typeSuffix"

override def simpleString(maxFields: Int): String = {
s"lambda $name#${exprId.id}: ${dataType.simpleString(maxFields)}"
}
}

object NamedLambdaVariable extends Logging {
private[expressions] def warnNoCodegenBinding(name: String, exprId: ExprId): Unit = {
logWarning(
s"NamedLambdaVariable '$name#${exprId.id}' has no codegen binding, " +
"falling back to interpreted eval. This warning is emitted during code generation " +
"(not per row at runtime). " +
"Possible cause: missing binding in an enclosing higher-order function's doGenCode.")
}
}

/**
* A lambda function and its arguments. A lambda function can be hidden when a user wants to
* process an completely independent expression in a [[HigherOrderFunction]], the lambda function
Expand All @@ -114,7 +164,7 @@ case class LambdaFunction(
function: Expression,
arguments: Seq[NamedExpression],
hidden: Boolean = false)
extends Expression with CodegenFallback {
extends Expression {

override def children: Seq[Expression] = function +: arguments
override def dataType: DataType = function.dataType
Expand All @@ -132,6 +182,26 @@ case class LambdaFunction(

override def eval(input: InternalRow): Any = function.eval(input)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// LambdaFunction is a thin wrapper. The enclosing HOF is responsible for
// registering lambda variable bindings before this is called.
arguments.foreach {
case nlv: NamedLambdaVariable =>
if (!ctx.lambdaVariableMap.contains(nlv.exprId)) {
throw SparkException.internalError(
s"Lambda variable '${nlv.name}#${nlv.exprId.id}' has no codegen binding. " +
s"Bound ids: [${ctx.lambdaVariableMap.keys.map(_.id).mkString(", ")}]")
}
case other =>
// arguments should always be NamedLambdaVariable instances (bound by
// HigherOrderFunction.bind). When hidden=true, arguments is empty and
// this branch is unreachable.
throw SparkException.internalError(
s"Expected NamedLambdaVariable but got ${other.getClass.getName}")
}
function.genCode(ctx)
}

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): LambdaFunction =
copy(
Expand Down Expand Up @@ -312,7 +382,7 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
case class ArrayTransform(
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
extends ArrayBasedSimpleHigherOrderFunction {

override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

Expand Down Expand Up @@ -354,6 +424,174 @@ case class ArrayTransform(
result
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val argumentGen = argument.genCode(ctx)
val resultArray = ctx.freshName("resultArray")
val numElements = ctx.freshName("numElements")
val loopIndex = ctx.freshName("i")
val arrData = ctx.freshName("arrData")

val elementType = elementVar.dataType
val javaElementType = CodeGenerator.javaType(elementType)
val elementDefault = CodeGenerator.defaultValue(elementType)

// Use mutable state (class fields) for lambda variable bindings instead of local
// variables. This is critical because Expression.reduceCodeSize() may extract
// lambda body code into a separate private method, and local loop variables would
// not be accessible from such extracted methods.
// Concurrency note: each Spark task runs in its own thread with a separate generated
// class instance, so these fields are not shared across tasks.
val elementIsNull = ctx.addMutableState(
CodeGenerator.JAVA_BOOLEAN, "elementIsNull")
val elementValue = ctx.addMutableState(javaElementType, "elementValue")

val elementExtract = if (elementVar.nullable) {
// For primitives, elementDefault provides a valid zero value (e.g. 0 for int) to avoid
// uninitialized reads. For non-primitives, it returns "null" -- logically redundant
// but harmless, and keeps the generated code pattern uniform across all types.
// The isNull flag guards all downstream reads (setElemAtomicRef checks isNull before
// boxing, and lambdaBodyGen propagates isNull through the lambda variable binding).
s"""
|$elementIsNull = $arrData.isNullAt($loopIndex);
|$elementValue = $elementIsNull ?
| $elementDefault : (${CodeGenerator.getValue(arrData, elementType, loopIndex)});
""".stripMargin
} else {
s"""
|$elementIsNull = false;
|$elementValue =
| ${CodeGenerator.getValue(arrData, elementType, loopIndex)};
""".stripMargin
}

// Recursively check whether any sub-expression in the lambda body is a CodegenFallback.
// LambdaFunction and NamedLambdaVariable themselves no longer extend CodegenFallback,
// so this targets genuinely un-codegen'd sub-expressions (e.g., ArrayFilter).
// If none found, we can skip AtomicReference writes entirely, avoiding per-element
// boxing overhead.
val lambdaBodyHasFallback = function.exists(_.isInstanceOf[CodegenFallback])

// Also set the AtomicReference on the lambda variable so that any CodegenFallback
// expressions nested inside the lambda body (e.g., ArrayExists, ArrayFilter that
// haven't been given codegen yet) can read the correct value via
// NamedLambdaVariable.eval(). This is NOT redundant with the mutable state bindings
// above -- the mutable state is for the codegen path, while AtomicReference is for
// CodegenFallback sub-expressions that call eval() at runtime.
val setElemAtomicRef = if (lambdaBodyHasFallback) {
val elemAtomicRefTerm = ctx.addReferenceObj(
"elementVarRef", elementVar.value,
"java.util.concurrent.atomic.AtomicReference")
// Explicitly box primitive values to ensure the AtomicReference contains the
// correct boxed type (e.g., Byte for ByteType, Short for ShortType), matching
// what ArrayData.get() returns in the interpreted path.
val boxedElementType = CodeGenerator.boxedType(elementType)
if (elementVar.nullable) {
s"$elemAtomicRefTerm.set($elementIsNull ? null : ($boxedElementType) $elementValue);"
} else {
s"$elemAtomicRefTerm.set(($boxedElementType) $elementValue);"
}
} else {
""
}

// Build lambda variable bindings using the mutable state variables.
val elementCode = ExprCode(
code = EmptyBlock,
isNull = if (elementVar.nullable) JavaCode.isNullVariable(elementIsNull)
else FalseLiteral,
value = JavaCode.variable(elementValue, elementType))

val (indexExtract, indexBinding) = indexVar match {
case Some(iv) =>
val indexValue = ctx.addMutableState(CodeGenerator.JAVA_INT, "indexValue")
val indexCode = ExprCode(
code = EmptyBlock,
isNull = FalseLiteral,
value = JavaCode.variable(indexValue, IntegerType))
val idxAtomicRefUpdate = if (lambdaBodyHasFallback) {
val idxAtomicRefTerm = ctx.addReferenceObj(
"indexVarRef", iv.value,
"java.util.concurrent.atomic.AtomicReference")
val boxedIndexType = CodeGenerator.boxedType(iv.dataType)
s"\n$idxAtomicRefTerm.set(($boxedIndexType) $loopIndex);"
} else {
""
}
val extract =
s"""
|$indexValue = $loopIndex;$idxAtomicRefUpdate
""".stripMargin
(extract, Some(iv.exprId -> indexCode))
case None =>
("", None)
}

val bindings = Map(elementVar.exprId -> elementCode) ++ indexBinding

// Generate code for the lambda body with bindings registered.
// Call function.genCode (not lf.function.genCode) so that LambdaFunction.doGenCode
// is exercised, including its binding validation.
val lambdaBodyGen = ctx.withLambdaVariableBindings(bindings) {
function.genCode(ctx)
}

// Determine the output element type and write strategy.
val outputElementType = function.dataType
val isPrimitive = CodeGenerator.isPrimitiveType(outputElementType)
val isNullOpt = if (function.nullable) Some(lambdaBodyGen.isNull.toString) else None

// For primitives, setArrayElement handles null check internally.
// For non-primitives, we must copy to avoid memory aliasing with mutable types
// (e.g., UnsafeRow, GenericArrayData). copyValue is a no-op for immutable types
// (e.g., UTF8String, Decimal), so the overhead is negligible.
val setResultElement = if (isPrimitive) {
CodeGenerator.setArrayElement(
resultArray, outputElementType, loopIndex, lambdaBodyGen.value.toString,
isNullOpt)
} else if (function.nullable) {
s"""
|if (${lambdaBodyGen.isNull}) {
| $resultArray.setNullAt($loopIndex);
|} else {
| $resultArray.update($loopIndex,
| InternalRow.copyValue(${lambdaBodyGen.value}));
|}
""".stripMargin
} else {
s"$resultArray.update($loopIndex, InternalRow.copyValue(${lambdaBodyGen.value}));"
}

val allocation = CodeGenerator.createArrayData(
resultArray, outputElementType, numElements,
" ArrayTransform failed.")

// argumentGen.value is guaranteed to be ArrayData for ArrayType expressions.
val loopCode =
s"""
|ArrayData $arrData = (ArrayData) ${argumentGen.value};
|int $numElements = $arrData.numElements();
|$allocation
|for (int $loopIndex = 0; $loopIndex < $numElements; $loopIndex++) {
| $elementExtract
| $setElemAtomicRef
| $indexExtract
| ${lambdaBodyGen.code}
| $setResultElement
|}
""".stripMargin

// Null safety: if argument is null, output is null.
ev.copy(code = code"""
${argumentGen.code}
boolean ${ev.isNull} = ${argumentGen.isNull};
ArrayData ${ev.value} = null;
if (!${ev.isNull}) {
$loopCode
${ev.value} = $resultArray;
}
""")
}

override def nodeName: String = "transform"

override protected def withNewChildrenInternal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -885,4 +885,19 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
"actualType" -> toSQLType(StringType)
)))
}

test("LambdaFunction.doGenCode requires bindings for all lambda variables") {
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val lv = NamedLambdaVariable("x", IntegerType, nullable = false)
val lf = LambdaFunction(lv + Literal(1), Seq(lv))
val ctx = new CodegenContext()

// genCode without registering bindings should fail with SparkException
val e = intercept[SparkException] {
lf.genCode(ctx)
}
assert(e.getMessage.contains("has no codegen binding"))
assert(e.getMessage.contains("x#"), "Error message should include the variable name")
}
}
Loading