Skip to content

[SPARK-52709][SQL] Fix parsing of STRUCT<> #51480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ ZONE: 'ZONE';

EQ : '=' | '==';
NSEQ: '<=>';
NEQ : '<>';
NEQ : '<>' {complex_type_level_counter == 0}?;
NEQJ: '!=';
LT : '<';
LTE : '<=' | '!>';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1377,7 +1377,7 @@ primitiveType
dataType
: complex=ARRAY (LT dataType GT)? #complexDataType
| complex=MAP (LT dataType COMMA dataType GT)? #complexDataType
| complex=STRUCT ((LT complexColTypeList? GT) | NEQ)? #complexDataType
| complex=STRUCT (LT complexColTypeList? GT)? #complexDataType
| primitiveType #primitiveDataType
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
* Create a complex DataType. Arrays, Maps and Structures are supported.
*/
override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
if (ctx.LT() == null && ctx.NEQ() == null) {
if (ctx.LT() == null) {
throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, ctx)
}
ctx.complex.getType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.{SparkConf, SparkThrowable}
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedHaving, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Cast, Concat, GreaterThan, Literal, NamedExpression, NullsFirst, ShiftRight, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference}
import org.apache.spark.sql.catalyst.parser.{AbstractParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern._
Expand All @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource}
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType}
import org.apache.spark.util.ArrayImplicits._

/**
Expand Down Expand Up @@ -1164,4 +1164,75 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession {
}
}
}

test("SPARK-52709: Parsing STRUCT (empty,nested,within complex types) followed by shiftRight") {

  // Complex type strings to exercise: empty STRUCT, nested STRUCTs, and
  // STRUCTs embedded inside ARRAY/MAP — all cases affected by the `<>` lexing fix.
  val complexTypeStrings = Seq(
    "STRUCT<>", // Empty struct
    "STRUCT<a: STRUCT<b: INT>>", // Nested struct
    "STRUCT<c: ARRAY<INT>>", // Struct containing an array
    "MAP<STRING, STRUCT<x: STRING, y: INT>>", // Map containing a struct
    "ARRAY<STRUCT<>>", // Array containing empty structs
    "ARRAY<STRUCT<id: INT, name: STRING>>" // Array containing non-empty structs
  )

  // For each type string, build the `CAST(null AS <type>)` SQL fragment and the
  // expression the parser is expected to produce for it. The suite's `parser`
  // instance turns the type string into the corresponding DataType, and the
  // suite compares against UnresolvedAlias-wrapped expressions.
  val fragmentsAndExpected: Seq[(String, NamedExpression)] =
    complexTypeStrings.map { typeString =>
      val parsedType: DataType = parser.parseDataType(typeString)
      val expected: NamedExpression =
        UnresolvedAlias(Cast(Literal(null, NullType), parsedType))
      (s"CAST(null AS $typeString)", expected)
    }

  val (selectFragments, expectedCasts) = fragmentsAndExpected.unzip

  // The trailing `4 >> 1` verifies that `>>` is still lexed as shiftRight and
  // is not consumed by the complex-type grammar after the STRUCT<> change.
  val sql =
    s"""
       |SELECT
       |  ${selectFragments.mkString(",\n  ")},
       |  4 >> 1
    """.stripMargin

  // Full projection list: every CAST expression plus the shiftRight expression.
  val expectedProjectList = expectedCasts :+
    UnresolvedAlias(ShiftRight(Literal(4, IntegerType), Literal(1, IntegerType)))

  // The whole SELECT should parse to a Project over OneRowRelation.
  assertEqual(sql, Project(expectedProjectList, OneRowRelation()))
}

test("SPARK-52709-Invalid: Parsing should fail for empty ARRAY<> type") {
  // Unlike STRUCT<>, ARRAY must always declare an element type, so `ARRAY<>`
  // remains a syntax error after the STRUCT<> parsing fix.
  checkError(
    exception = parseException("SELECT CAST(null AS ARRAY<>)"),
    condition = "PARSE_SYNTAX_ERROR",
    parameters = Map("error" -> "'<'", "hint" -> ": missing ')'")
  )
}

test("SPARK-52709-Invalid: Parsing should fail for empty MAP<> type") {
  // MAP requires both a key and a value type, so `MAP<>` must stay a syntax
  // error even though `STRUCT<>` is now accepted.
  checkError(
    exception = parseException("SELECT CAST(null AS MAP<>)"),
    condition = "PARSE_SYNTAX_ERROR",
    parameters = Map("error" -> "'<'", "hint" -> ": missing ')'")
  )
}
}