@@ -22,127 +22,173 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType

trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
+  protected val supportsSamplePushdown: Boolean = true
+
+  protected val supportsFilterPushdown: Boolean = true
+
+  protected val supportsLimitPushdown: Boolean = true
+
+  protected val supportsAggregatePushdown: Boolean = true
+
+  protected val supportsSortPushdown: Boolean = true
+
+  protected val supportsOffsetPushdown: Boolean = true
+
+  protected val supportsColumnPruning: Boolean = true
+
+  protected val supportsJoinPushdown: Boolean = true
+
+
  protected def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = {
-    val sample = df.queryExecution.optimizedPlan.collect {
-      case s: Sample => s
-    }
-    if (pushed) {
-      assert(sample.isEmpty)
-    } else {
-      assert(sample.nonEmpty)
+    if (supportsSamplePushdown) {
+      val sample = df.queryExecution.optimizedPlan.collect {
+        case s: Sample => s
+      }
+      if (pushed) {
+        assert(sample.isEmpty)
+      } else {
+        assert(sample.nonEmpty)
+      }
    }
  }

  protected def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = {
-    val filter = df.queryExecution.optimizedPlan.collect {
-      case f: Filter => f
-    }
-    if (pushed) {
-      assert(filter.isEmpty)
-    } else {
-      assert(filter.nonEmpty)
+    if (supportsFilterPushdown) {
+      val filter = df.queryExecution.optimizedPlan.collect {
+        case f: Filter => f
+      }
+      if (pushed) {
+        assert(filter.isEmpty)
+      } else {
+        assert(filter.nonEmpty)
+      }
    }
  }

  protected def checkLimitRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val limit = df.queryExecution.optimizedPlan.collect {
-      case l: LocalLimit => l
-      case g: GlobalLimit => g
-    }
-    if (pushed) {
-      assert(limit.isEmpty)
-    } else {
-      assert(limit.nonEmpty)
+    if (supportsLimitPushdown) {
+      val limit = df.queryExecution.optimizedPlan.collect {
+        case l: LocalLimit => l
+        case g: GlobalLimit => g
+      }
+      if (pushed) {
+        assert(limit.isEmpty)
+      } else {
+        assert(limit.nonEmpty)
+      }
    }
  }

  protected def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.limit == limit)
+    if (supportsLimitPushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.limit == limit)
+        }
      }
    }
  }

  protected def checkColumnPruned(df: DataFrame, col: String): Unit = {
-    val scan = df.queryExecution.optimizedPlan.collectFirst {
-      case s: DataSourceV2ScanRelation => s
-    }.get
-    assert(scan.schema.names.sameElements(Seq(col)))
+    if (supportsColumnPruning) {
+      val scan = df.queryExecution.optimizedPlan.collectFirst {
+        case s: DataSourceV2ScanRelation => s
+      }.get
+      assert(scan.schema.names.sameElements(Seq(col)))
+    }
  }

  protected def checkAggregateRemoved(df: DataFrame): Unit = {
-    val aggregates = df.queryExecution.optimizedPlan.collect {
-      case agg: Aggregate => agg
+    if (supportsAggregatePushdown) {
+      val aggregates = df.queryExecution.optimizedPlan.collect {
+        case agg: Aggregate => agg
+      }
+      assert(aggregates.isEmpty)
    }
-    assert(aggregates.isEmpty)
  }

  protected def checkAggregatePushed(df: DataFrame, funcName: String): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case DataSourceV2ScanRelation(_, scan, _, _, _) =>
-        assert(scan.isInstanceOf[V1ScanWrapper])
-        val wrapper = scan.asInstanceOf[V1ScanWrapper]
-        assert(wrapper.pushedDownOperators.aggregation.isDefined)
-        val aggregationExpressions =
-          wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
-        assert(aggregationExpressions.exists { expr =>
-          expr.isInstanceOf[GeneralAggregateFunc] &&
-            expr.asInstanceOf[GeneralAggregateFunc].name() == funcName
-        })
+    if (supportsAggregatePushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case DataSourceV2ScanRelation(_, scan, _, _, _) =>
+          assert(scan.isInstanceOf[V1ScanWrapper])
+          val wrapper = scan.asInstanceOf[V1ScanWrapper]
+          assert(wrapper.pushedDownOperators.aggregation.isDefined)
+          val aggregationExpressions =
+            wrapper.pushedDownOperators.aggregation.get.aggregateExpressions()
+          assert(aggregationExpressions.exists { expr =>
+            expr.isInstanceOf[GeneralAggregateFunc] &&
+              expr.asInstanceOf[GeneralAggregateFunc].name() == funcName
+          })
+      }
    }
  }

-  protected def checkSortRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val sorts = df.queryExecution.optimizedPlan.collect {
-      case s: Sort => s
-    }
+  protected def checkSortRemoved(
+      df: DataFrame,
+      pushed: Boolean = true): Unit = {
+    if (supportsSortPushdown) {
+      val sorts = df.queryExecution.optimizedPlan.collect {
+        case s: Sort => s
+      }

-    if (pushed) {
-      assert(sorts.isEmpty)
-    } else {
-      assert(sorts.nonEmpty)
+      if (pushed) {
+        assert(sorts.isEmpty)
+      } else {
+        assert(sorts.nonEmpty)
+      }
    }
  }

-  protected def checkOffsetRemoved(df: DataFrame, pushed: Boolean = true): Unit = {
-    val offsets = df.queryExecution.optimizedPlan.collect {
-      case o: Offset => o
-    }
+  protected def checkOffsetRemoved(
+      df: DataFrame,
+      pushed: Boolean = true): Unit = {
+    if (supportsOffsetPushdown) {
+      val offsets = df.queryExecution.optimizedPlan.collect {
+        case o: Offset => o
+      }

-    if (pushed) {
-      assert(offsets.isEmpty)
-    } else {
-      assert(offsets.nonEmpty)
+      if (pushed) {
+        assert(offsets.isEmpty)
+      } else {
+        assert(offsets.nonEmpty)
+      }
    }
  }

  protected def checkOffsetPushed(df: DataFrame, offset: Option[Int]): Unit = {
-    df.queryExecution.optimizedPlan.collect {
-      case relation: DataSourceV2ScanRelation => relation.scan match {
-        case v1: V1ScanWrapper =>
-          assert(v1.pushedDownOperators.offset == offset)
+    if (supportsOffsetPushdown) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            assert(v1.pushedDownOperators.offset == offset)
+        }
      }
    }
  }

  protected def checkJoinNotPushed(df: DataFrame): Unit = {
-    val joinNodes = df.queryExecution.optimizedPlan.collect {
-      case j: Join => j
+    if (supportsJoinPushdown) {
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+      assert(joinNodes.nonEmpty, "Join should not be pushed down")
    }
-    assert(joinNodes.nonEmpty, "Join should not be pushed down")
  }

  protected def checkJoinPushed(df: DataFrame, expectedTables: String*): Unit = {
-    val joinNodes = df.queryExecution.optimizedPlan.collect {
-      case j: Join => j
-    }
-    assert(joinNodes.isEmpty, "Join should be pushed down")
-    if (expectedTables.nonEmpty) {
-      checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
+    if (supportsJoinPushdown) {
+      val joinNodes = df.queryExecution.optimizedPlan.collect {
+        case j: Join => j
+      }
+      assert(joinNodes.isEmpty, "Join should be pushed down")
+      if (expectedTables.nonEmpty) {
+        checkPushedInfo(df, s"PushedJoins: [${expectedTables.mkString(", ")}]")
+      }
    }
  }
@@ -154,4 +200,34 @@ trait DataSourcePushdownTestUtils extends ExplainSuiteHelper {
      }
    }
  }
+
+  /**
+   * Checks that the output schema of the dataframe {@code df} is the same as {@code schema},
+   * with one limitation: if an expected field name is empty, the assertion on names is skipped
+   * for that field.
+   * <br>
+   * For example, it is not really possible to use {@code checkPrunedColumns} for join pushdown,
+   * because in case of duplicate names, columns get random UUID suffixes. For this reason, the
+   * best we can do is check that the sizes are the same and that the remaining fields, apart
+   * from names, match.
+   */
+  protected def checkPrunedColumnsDataTypeAndNullability(
+      df: DataFrame,
+      schema: StructType): Unit = {
+    if (supportsColumnPruning) {
+      df.queryExecution.optimizedPlan.collect {
+        case relation: DataSourceV2ScanRelation => relation.scan match {
+          case v1: V1ScanWrapper =>
+            val dfSchema = v1.readSchema()
+
+            assert(dfSchema.length == schema.length)
+            dfSchema.fields.zip(schema.fields).foreach { case (f1, f2) =>
+              if (f2.name.nonEmpty) {
+                assert(f1.name == f2.name)
+              }
+              assert(f1.dataType == f2.dataType)
+              assert(f1.nullable == f2.nullable)
+            }
+        }
+      }
+    }
+  }
}
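With this change, a connector suite that mixes in `DataSourcePushdownTestUtils` can opt out of individual capabilities, turning the corresponding `check*` helpers into no-ops instead of failing assertions. A minimal sketch of how a suite might do that (the suite name, table names, and test body are hypothetical, not part of this patch):

```scala
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical suite for a dialect that cannot push down joins or TABLESAMPLE.
class MyDialectPushdownSuite extends QueryTest
  with SharedSparkSession
  with DataSourcePushdownTestUtils {

  // The flags default to true; overriding them makes the gated checks no-ops.
  override protected val supportsJoinPushdown: Boolean = false
  override protected val supportsSamplePushdown: Boolean = false

  test("join check degrades gracefully") {
    val df = spark.table("a").join(spark.table("b"), "id")
    // A no-op here because supportsJoinPushdown = false, so shared tests
    // can still run against this dialect without failing.
    checkJoinPushed(df, "a", "b")
  }
}
```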
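The empty-name escape hatch in `checkPrunedColumnsDataTypeAndNullability` would then be used along these lines (illustrative schema and `joinedDf` assumed, not from this patch): when both sides of a pushed-down join expose an `id` column, the duplicate comes back with a random UUID suffix, so its expected name is left empty and only the data type and nullability are asserted:

```scala
import org.apache.spark.sql.types._

// Both sides of the join expose "id"; after pushdown one copy is renamed
// with a UUID suffix, so the name assertion is skipped via the empty name.
val expectedSchema = StructType(Seq(
  StructField("", IntegerType, nullable = false),    // left "id"
  StructField("", IntegerType, nullable = false),    // right "id_<uuid>"
  StructField("name", StringType, nullable = true))) // unique, name is checked

checkPrunedColumnsDataTypeAndNullability(joinedDf, expectedSchema)
```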