Commit f26b1a1

resolve comments
1 parent: dab39fb · commit: f26b1a1

4 files changed: 302 additions & 171 deletions

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/join/OracleJoinPushdownIntegrationSuite.scala
Lines changed: 10 additions & 2 deletions

@@ -22,6 +22,7 @@ import java.util.Locale

 import org.apache.spark.sql.jdbc.{DockerJDBCIntegrationSuite, JdbcDialect, OracleDatabaseOnDocker, OracleDialect}
 import org.apache.spark.sql.jdbc.v2.JDBCV2JoinPushdownIntegrationSuiteBase
+import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.tags.DockerTest

 /**
@@ -55,15 +56,22 @@ import org.apache.spark.tags.DockerTest
 class OracleJoinPushdownIntegrationSuite
   extends DockerJDBCIntegrationSuite
   with JDBCV2JoinPushdownIntegrationSuiteBase {
-  override val namespaceOpt: Option[String] = Some("SYSTEM")
+  override val namespace: String = "SYSTEM"

   override val db = new OracleDatabaseOnDocker

   override val url = db.getJdbcUrl(dockerIp, externalPort)

   override val jdbcDialect: JdbcDialect = OracleDialect()

+  override val integerType = DataTypes.createDecimalType(10, 0)
+
   override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT)

-  override def schemaPreparation(connection: Connection): Unit = {}
+  override def schemaPreparation(): Unit = {}
+
+  // This method comes from DockerJDBCIntegrationSuite
+  override def dataPreparation(connection: Connection): Unit = {
+    super.dataPreparation()
+  }
 }
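Editor's note: the dataPreparation bridge above resolves two hooks coming from different parents — the Docker harness calls dataPreparation(Connection) once the container is up, while the shared pushdown base now seeds its tables from a parameterless dataPreparation(). A minimal, self-contained sketch of that pattern follows; DockerHarness, PushdownSuiteBase, and ExampleSuite are illustrative stand-ins, not the real Spark traits.

import java.sql.Connection

// Stand-in for DockerJDBCIntegrationSuite: the harness hands the suite a live
// JDBC connection after the container starts.
trait DockerHarness {
  def dataPreparation(connection: Connection): Unit
}

// Stand-in for JDBCV2JoinPushdownIntegrationSuiteBase: the shared base seeds
// its own test tables and needs no explicit connection parameter.
trait PushdownSuiteBase {
  def dataPreparation(): Unit = println("seeding join-pushdown test tables")
}

// The concrete suite bridges the two overloads, as the Oracle suite does
// above: the harness hook delegates to the base implementation.
class ExampleSuite extends DockerHarness with PushdownSuiteBase {
  override def dataPreparation(connection: Connection): Unit = {
    super.dataPreparation()
  }
}

object Demo extends App {
  new ExampleSuite().dataPreparation(null) // prints: seeding join-pushdown test tables
}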

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourcePushdownTestUtils.scala
Lines changed: 149 additions & 73 deletions

@@ -22,127 +22,173 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType

 trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
+  protected val supportsSamplePushdown: Boolean = true
+
+  protected val supportsFilterPushdown: Boolean = true
+
+  protected val supportsLimitPushdown: Boolean = true
+
+  protected val supportsAggregatePushdown: Boolean = true
+
+  protected val supportsSortPushdown: Boolean = true
+
+  protected val supportsOffsetPushdown: Boolean = true
+
+  protected val supportsColumnPruning: Boolean = true
+
+  protected val supportsJoinPushdown: Boolean = true
+
+
   protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = {
-    val sample = df.queryExecution.optimizedPlan.collect {
-      case s: Sample => s
-    }
-    if (pushed) {
-      assert(sample.isEmpty)
-    } else {
-      assert(sample.nonEmpty)
+    if (supportsSamplePushdown) {
+      val sample = df.queryExecution.optimizedPlan.collect {
+        case s: Sample => s
+      }
+      if (pushed) {
+        assert(sample.isEmpty)
+      } else {
+        assert(sample.nonEmpty)
+      }
     }
   }

   protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
-    val filter = df.queryExecution.optimizedPlan.collect {
-      case f: Filter => f
-    }
-    if (pushed) {
-      assert(filter.isEmpty)
-    } else {
-      assert(filter.nonEmpty)
+    if (supportsFilterPushdown) {
+      val filter = df.queryExecution.optimizedPlan.collect {
+        case f: Filter => f
+      }
+      if (pushed) {
+        assert(filter.isEmpty)
+      } else {
+        assert(filter.nonEmpty)
+      }
     }
   }

   protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val limit = df.queryExecution.optimizedPlan.collect {
-      case l: LocalLimit => l
-      case g: GlobalLimit => g
-    }
-    if (pushed) {
-      assert(limit.isEmpty)
-    } else {
-      assert(limit.nonEmpty)
+    if (supportsLimitPushdown) {
+      val limit = df.queryExecution.optimizedPlan.collect {
+        case l: LocalLimit => l
+        case g: GlobalLimit => g
+      }
+      if (pushed) {
+        assert(limit.isEmpty)
+      } else {
+        assert(limit.nonEmpty)
+      }
     }
   }

   protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.limit == limit)
+    if (supportsLimitPushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.limit == limit)
+        }
       }
     }
   }

   protected def checkColumnPruned(df: DataFrame, col: String): Unit = {
-    val scan = df.queryExecution.optimizedPlan.collectFirst {
-      case s: DataSourceV2ScanRelation => s
-    }.get
-    assert(scan.schema.names.sameElements(Seq(col)))
+    if (supportsColumnPruning) {
+      val scan = df.queryExecution.optimizedPlan.collectFirst {
+        case s: DataSourceV2ScanRelation => s
+      }.get
+      assert(scan.schema.names.sameElements(Seq(col)))
+    }
   }

   protected def checkAggregateRemoved(df: DataFrame): Unit = {
-    val aggregates = df.queryExecution.optimizedPlan.collect {
-      case agg: Aggregate => agg
+    if (supportsAggregatePushdown) {
+      val aggregates = df.queryExecution.optimizedPlan.collect {
+        case agg: Aggregate => agg
+      }
+      assert(aggregates.isEmpty)
     }
-    assert(aggregates.isEmpty)
   }

   protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case DataSourceV2ScanRelation(_, scan, _, _, _) =>
-        assert(scan.isInstanceOf[V1ScanWrapper])
-        val wrapper = scan.asInstanceOf[V1ScanWrapper]
-        assert(wrapper.pushedDownOperators.aggregation.isDefined)
-        val aggregationExpressions =
-          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
-        assert(aggregationExpressions.exists { expr =>
-          expr.isInstanceOf[GeneralAggregateFunc] &&
-            expr.asInstanceOf[GeneralAggregateFunc].name() == funcName
-        })
+    if (supportsAggregatePushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case DataSourceV2ScanRelation(_, scan, _, _, _) =>
+          assert(scan.isInstanceOf[V1ScanWrapper])
+          val wrapper = scan.asInstanceOf[V1ScanWrapper]
+          assert(wrapper.pushedDownOperators.aggregation.isDefined)
+          val aggregationExpressions =
+            wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+          assert(aggregationExpressions.exists { expr =>
+            expr.isInstanceOf[GeneralAggregateFunc] &&
+              expr.asInstanceOf[GeneralAggregateFunc].name() == funcName
+          })
+      }
     }
   }

-  protected def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val sorts = df.queryExecution.optimizedPlan.collect {
-      case s: Sort => s
-    }
+  protected def checkSortRemoved(
+      df: DataFrame,
+      pushed: Boolean = true): Unit = {
+    if (supportsSortPushdown) {
+      val sorts = df.queryExecution.optimizedPlan.collect {
+        case s: Sort => s
+      }

-    if (pushed) {
-      assert(sorts.isEmpty)
-    } else {
-      assert(sorts.nonEmpty)
+      if (pushed) {
+        assert(sorts.isEmpty)
+      } else {
+        assert(sorts.nonEmpty)
+      }
     }
   }

-  protected def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val offsets = df.queryExecution.optimizedPlan.collect {
-      case o: Offset => o
-    }
+  protected def checkOffsetRemoved(
+      df: DataFrame,
+      pushed: Boolean = true): Unit = {
+    if (supportsOffsetPushdown) {
+      val offsets = df.queryExecution.optimizedPlan.collect {
+        case o: Offset => o
+      }

-    if (pushed) {
-      assert(offsets.isEmpty)
-    } else {
-      assert(offsets.nonEmpty)
+      if (pushed) {
+        assert(offsets.isEmpty)
+      } else {
+        assert(offsets.nonEmpty)
+      }
     }
   }

   protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.offset == offset)
+    if (supportsOffsetPushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.offset == offset)
+        }
       }
     }
   }

   protected def checkJoinNotPushed(df: DataFrame): Unit = {
-    val joinNodes = df.queryExecution.optimizedPlan.collect {
-      case j: Join => j
+    if (supportsJoinPushdown) {
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+      assert(joinNodes.nonEmpty, "Join should not be pushed down")
     }
-    assert(joinNodes.nonEmpty, "Join should not be pushed down")
   }

   protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = {
-    val joinNodes = df.queryExecution.optimizedPlan.collect {
-      case j: Join => j
-    }
-    assert(joinNodes.isEmpty, "Join should be pushed down")
-    if (expectedTables.nonEmpty) {
-      checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
+    if (supportsJoinPushdown) {
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+      assert(joinNodes.isEmpty, "Join should be pushed down")
+      if (expectedTables.nonEmpty) {
+        checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
+      }
     }
   }

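Editor's note: the pattern in the hunk above is uniform — every helper is now gated on a capability flag that defaults to true, so a suite for a source lacking some pushdown overrides the flag and the corresponding checks become no-ops. A minimal, self-contained sketch of the idea; the names below (PushdownChecks, planHasOffset) are illustrative stand-ins, not Spark's.

// Capability-gated checks: the default trait asserts, an opted-out suite skips.
trait PushdownChecks {
  val supportsOffsetPushdown: Boolean = true

  def checkOffsetRemoved(planHasOffset: Boolean, pushed: Boolean = true): Unit = {
    if (supportsOffsetPushdown) {
      // Mirrors the real helper: when pushed, no Offset node may remain in the plan.
      if (pushed) assert(!planHasOffset) else assert(planHasOffset)
    }
    // Otherwise the check silently passes, keeping shared test bodies dialect-agnostic.
  }
}

object FlagDemo extends App {
  val checks = new PushdownChecks {}
  checks.checkOffsetRemoved(planHasOffset = false) // asserts: OFFSET was pushed down

  val noOffset = new PushdownChecks {
    override val supportsOffsetPushdown: Boolean = false
  }
  noOffset.checkOffsetRemoved(planHasOffset = true) // no-op: capability opted out
}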
@@ -154,4 +200,34 @@ trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
       }
     }
   }
+
+  /**
+   * Checks that the output schema of the dataframe {@code df} is the same as {@code schema},
+   * with one limitation: if an expected field name is empty, the assertion on names is skipped.
+   * <br>
+   * For example, it is not really possible to use {@code checkPrunedColumns} for join pushdown,
+   * because duplicate column names get random UUID suffixes. The best we can do is check that
+   * the sizes match and that the other fields, besides names, do match.
+   */
+  protected def checkPrunedColumnsDataTypeAndNullability(
+      df: DataFrame,
+      schema: StructType): Unit = {
+    if (supportsColumnPruning) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            val dfSchema = v1.readSchema()
+
+            assert(dfSchema.length == schema.length)
+            dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) =>
+              if (f2.name.nonEmpty) {
+                assert(f1.name == f2.name)
+              }
+              assert(f1.dataType == f2.dataType)
+              assert(f1.nullable == f2.nullable)
+            }
+        }
+      }
+    }
+  }
 }
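Editor's note: the empty-name escape hatch can be seen in isolation on plain StructTypes — an expected schema with empty field names asserts only data type and nullability, which is what a join-pushdown test needs when output columns carry non-deterministic UUID suffixes. A sketch using Spark's public types; the column names and suffixes below are invented for illustration.

import org.apache.spark.sql.types._

object PrunedSchemaCheckDemo extends App {
  // What a pushed-down self-join might expose: duplicate-name columns
  // disambiguated with random suffixes (values here are made up).
  val actual = StructType(Seq(
    StructField("ID_1a2b3c", IntegerType, nullable = false),
    StructField("ID_4d5e6f", IntegerType, nullable = false)))

  // Expected schema with empty names: only type and nullability are compared.
  val expected = StructType(Seq(
    StructField("", IntegerType, nullable = false),
    StructField("", IntegerType, nullable = false)))

  // The same comparison loop as checkPrunedColumnsDataTypeAndNullability above.
  assert(actual.length == expected.length)
  actual.fields.zip(expected.fields).foreach { case (f1, f2) =>
    if (f2.name.nonEmpty) {
      assert(f1.name == f2.name) // skipped here, since expected names are empty
    }
    assert(f1.dataType == f2.dataType)
    assert(f1.nullable == f2.nullable)
  }
}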
