@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.SerializerBuildHelper
3636import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
3737import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders
3838import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
39- import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
4039import org.apache.spark.sql.catalyst.encoders.OuterScopes
4140import org.apache.spark.sql.catalyst.expressions.objects.Invoke
4241import org.apache.spark.sql.types.DataType
@@ -69,14 +68,15 @@ fun <T : Any> kotlinEncoderFor(
6968 arguments : List <KTypeProjection > = emptyList(),
7069 nullable : Boolean = false,
7170 annotations : List <Annotation > = emptyList()
72- ): Encoder <T > = ExpressionEncoder .apply (
73- KotlinTypeInference .encoderFor(
74- kClass = kClass,
75- arguments = arguments,
76- nullable = nullable,
77- annotations = annotations,
71+ ): Encoder <T > =
72+ applyEncoder(
73+ KotlinTypeInference .encoderFor(
74+ kClass = kClass,
75+ arguments = arguments,
76+ nullable = nullable,
77+ annotations = annotations,
78+ )
7879 )
79- )
8080
8181/* *
8282 * Main method of API, which gives you seamless integration with Spark:
@@ -88,15 +88,26 @@ fun <T : Any> kotlinEncoderFor(
8888 * @return generated encoder
8989 */
inline fun <reified T> kotlinEncoderFor(): Encoder<T> {
    // Capture the reified type argument as a KType and hand it to the KType-based overload.
    // The encoder is built there: `applyEncoder` is private, and a public inline function
    // cannot call private declarations, so this function only delegates.
    val reifiedType = typeOf<T>()
    return kotlinEncoderFor(reifiedType)
}
9494
/**
 * Creates an [Encoder] for the type described by [kType].
 *
 * The encoder is first inferred by [KotlinTypeInference.encoderFor] and then passed
 * through `applyEncoder` to obtain the [Encoder] that is returned.
 *
 * NOTE(review): [T] is not reified here, so nothing visible in this function checks that
 * [T] actually matches [kType] — presumably the caller is responsible for that; confirm.
 *
 * @param kType the Kotlin type to build an encoder for
 * @return the encoder for [kType], typed as `Encoder<T>`
 */
fun <T> kotlinEncoderFor(kType: KType): Encoder<T> {
    val inferredEncoder: AgnosticEncoder<T> = KotlinTypeInference.encoderFor(kType)
    return applyEncoder(inferredEncoder)
}
9999
/**
 * Converts an inferred [AgnosticEncoder] into the [Encoder] that is handed back to callers.
 *
 * Which `return` statement is compiled is selected by the `#if sparkConnect` comment
 * directives inside the body:
 * - `sparkConnect == false`: the encoder is wrapped via `ExpressionEncoder.apply`.
 * - otherwise: the [AgnosticEncoder] itself is returned, because for spark-connect no
 *   ExpressionEncoder is needed.
 *
 * `ExpressionEncoder` is referenced by its fully qualified name — presumably so that the
 * file needs no import of it in the spark-connect variant; confirm against the build setup.
 *
 * @param agnosticEncoder the encoder produced by type inference
 * @return the encoder to expose to callers
 */
private fun <T> applyEncoder(agnosticEncoder: AgnosticEncoder<T>): Encoder<T> {
    // NOTE(review): the `#if` / `#else` / `#endif` comment lines below look like build-time
    // preprocessor directives, not ordinary comments. Keep their text and order intact, and
    // confirm the exact spelling the preprocessor expects before reformatting them.
    // #if sparkConnect == false
    return org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.apply(agnosticEncoder)
    // #else
    // $return agnosticEncoder
    // #endif
}
110+
100111
/**
 * Deprecated alias kept for existing callers; use [kotlinEncoderFor] instead.
 *
 * @return the encoder for the reified type [T]
 */
@Deprecated(
    message = " Use kotlinEncoderFor instead",
    replaceWith = ReplaceWith(expression = " kotlinEncoderFor<T>()"),
)
inline fun <reified T> encoder(): Encoder<T> {
    // Same path as the reified kotlinEncoderFor: resolve T to a KType, then build from it.
    val requestedType = typeOf<T>()
    return kotlinEncoderFor(requestedType)
}
@@ -112,7 +123,7 @@ object KotlinTypeInference {
112123 // TODO this hack is a WIP and can give errors
113124 // TODO it's to make data classes get column names like "age" with functions like "getAge"
114125 // TODO instead of column names like "getAge"
115- var DO_NAME_HACK = true
126+ var DO_NAME_HACK = false
116127
117128 /* *
118129 * @param kClass the class for which to infer the encoder.
@@ -151,7 +162,6 @@ object KotlinTypeInference {
151162 currentType = kType,
152163 seenTypeSet = emptySet(),
153164 typeVariables = emptyMap(),
154- isTopLevel = true ,
155165 ) as AgnosticEncoder <T >
156166
157167
@@ -218,7 +228,6 @@ object KotlinTypeInference {
218228
219229 // how the generic types of the data class (like T, S) are filled in for this instance of the class
220230 typeVariables : Map <String , KType >,
221- isTopLevel : Boolean = false,
222231 ): AgnosticEncoder <* > {
223232 val kClass =
224233 currentType.classifier as ? KClass <* > ? : throw IllegalArgumentException (" Unsupported type $currentType " )
@@ -328,7 +337,7 @@ object KotlinTypeInference {
328337 AgnosticEncoders .UDTEncoder (udt, udt.javaClass)
329338 }
330339
331- currentType.isSubtypeOf< scala.Option <* >>() -> {
340+ currentType.isSubtypeOf< scala.Option <* >? > () -> {
332341 val elementEncoder = encoderFor(
333342 currentType = tArguments.first().type!! ,
334343 seenTypeSet = seenTypeSet,
@@ -506,7 +515,6 @@ object KotlinTypeInference {
506515
507516 DirtyProductEncoderField (
508517 doNameHack = DO_NAME_HACK ,
509- isTopLevel = isTopLevel,
510518 columnName = paramName,
511519 readMethodName = readMethodName,
512520 writeMethodName = writeMethodName,
@@ -525,7 +533,7 @@ object KotlinTypeInference {
525533 if (currentType in seenTypeSet) throw IllegalStateException (" Circular reference detected for type $currentType " )
526534 val constructorParams = currentType.getScalaConstructorParameters(typeVariables, kClass)
527535
528- val params: List < AgnosticEncoders . EncoderField > = constructorParams.map { (paramName, paramType) ->
536+ val params = constructorParams.map { (paramName, paramType) ->
529537 val encoder = encoderFor(
530538 currentType = paramType,
531539 seenTypeSet = seenTypeSet + currentType,
@@ -564,7 +572,6 @@ internal open class DirtyProductEncoderField(
564572 private val readMethodName : String , // the name of the method used to read the value
565573 private val writeMethodName : String? ,
566574 private val doNameHack : Boolean ,
567- private val isTopLevel : Boolean ,
568575 encoder : AgnosticEncoder <* >,
569576 nullable : Boolean ,
570577 metadata : Metadata = Metadata .empty(),
@@ -577,18 +584,18 @@ internal open class DirtyProductEncoderField(
577584 /* writeMethod = */ writeMethodName.toOption(),
578585), Serializable {
579586
580- private var isFirstNameCall = true
587+ private var noNameCalls = 0
581588
582589 /* *
583590 * This dirty trick only works because in [SerializerBuildHelper], [ProductEncoder]
584591 * creates an [Invoke] using [name] first and then calls [name] again to retrieve
585592 * the name of the column. This way, we can alternate between the two names.
586593 */
587594 override fun name (): String =
588- if (doNameHack && ! isFirstNameCall ) {
595+ if (doNameHack && noNameCalls > 0 ) {
589596 columnName
590597 } else {
591- isFirstNameCall = false
598+ noNameCalls ++
592599 readMethodName
593600 }
594601
0 commit comments