Skip to content

Commit 5b5c8d8

Browse files
committed
chore: leave tests in core module
1 parent d093582 commit 5b5c8d8

15 files changed

+7305
-3854
lines changed
Lines changed: 382 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,382 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#[cfg(test)]
19+
pub(crate) mod tests {
20+
21+
use crate::error::Result;
22+
use std::sync::Arc;
23+
24+
use crate::prelude::SessionContext;
25+
use datafusion_common::config::ConfigOptions;
26+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
27+
use datafusion_expr::Operator;
28+
use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
29+
use datafusion_physical_optimizer::PhysicalOptimizerRule;
30+
use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
31+
use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
32+
use datafusion_physical_plan::filter::FilterExec;
33+
use datafusion_physical_plan::memory::MemoryExec;
34+
use datafusion_physical_plan::projection::ProjectionExec;
35+
use datafusion_physical_plan::{common, ExecutionPlan};
36+
37+
use datafusion_common::arrow::array::Int32Array;
38+
use datafusion_common::arrow::datatypes::{DataType, Field, Schema};
39+
use datafusion_common::arrow::record_batch::RecordBatch;
40+
use datafusion_common::cast::as_int64_array;
41+
use datafusion_functions_aggregate::count::count_udaf;
42+
use datafusion_physical_expr::expressions::{self, cast};
43+
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
44+
use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
45+
use datafusion_physical_plan::aggregates::AggregateMode;
46+
47+
/// Mock data using a MemoryExec which has an exact count statistic
48+
fn mock_data() -> Result<Arc<MemoryExec>> {
49+
let schema = Arc::new(Schema::new(vec![
50+
Field::new("a", DataType::Int32, true),
51+
Field::new("b", DataType::Int32, true),
52+
]));
53+
54+
let batch = RecordBatch::try_new(
55+
Arc::clone(&schema),
56+
vec![
57+
Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
58+
Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
59+
],
60+
)?;
61+
62+
Ok(Arc::new(MemoryExec::try_new(
63+
&[vec![batch]],
64+
Arc::clone(&schema),
65+
None,
66+
)?))
67+
}
68+
69+
/// Checks that the count optimization was applied and we still get the right result
70+
async fn assert_count_optim_success(
71+
plan: AggregateExec,
72+
agg: TestAggregate,
73+
) -> Result<()> {
74+
let session_ctx = SessionContext::new();
75+
let state = session_ctx.state();
76+
let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
77+
78+
let optimized = AggregateStatistics::new()
79+
.optimize(Arc::clone(&plan), state.config_options())?;
80+
81+
// A ProjectionExec is a sign that the count optimization was applied
82+
assert!(optimized.as_any().is::<ProjectionExec>());
83+
84+
// run both the optimized and nonoptimized plan
85+
let optimized_result =
86+
common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
87+
let nonoptimized_result =
88+
common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
89+
assert_eq!(optimized_result.len(), nonoptimized_result.len());
90+
91+
// and validate the results are the same and expected
92+
assert_eq!(optimized_result.len(), 1);
93+
check_batch(optimized_result.into_iter().next().unwrap(), &agg);
94+
// check the non optimized one too to ensure types and names remain the same
95+
assert_eq!(nonoptimized_result.len(), 1);
96+
check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
97+
98+
Ok(())
99+
}
100+
101+
fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
102+
let schema = batch.schema();
103+
let fields = schema.fields();
104+
assert_eq!(fields.len(), 1);
105+
106+
let field = &fields[0];
107+
assert_eq!(field.name(), agg.column_name());
108+
assert_eq!(field.data_type(), &DataType::Int64);
109+
// note that nullabiolity differs
110+
111+
assert_eq!(
112+
as_int64_array(batch.column(0)).unwrap().values(),
113+
&[agg.expected_count()]
114+
);
115+
}
116+
117+
/// Describe the type of aggregate being tested
118+
pub(crate) enum TestAggregate {
119+
/// Testing COUNT(*) type aggregates
120+
CountStar,
121+
122+
/// Testing for COUNT(column) aggregate
123+
ColumnA(Arc<Schema>),
124+
}
125+
126+
impl TestAggregate {
127+
pub(crate) fn new_count_star() -> Self {
128+
Self::CountStar
129+
}
130+
131+
fn new_count_column(schema: &Arc<Schema>) -> Self {
132+
Self::ColumnA(schema.clone())
133+
}
134+
135+
// Return appropriate expr depending if COUNT is for col or table (*)
136+
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
137+
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
138+
.schema(Arc::new(schema.clone()))
139+
.name(self.column_name())
140+
.build()
141+
.unwrap()
142+
}
143+
144+
/// what argument would this aggregate need in the plan?
145+
fn column(&self) -> Arc<dyn PhysicalExpr> {
146+
match self {
147+
Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
148+
Self::ColumnA(s) => expressions::col("a", s).unwrap(),
149+
}
150+
}
151+
152+
/// What name would this aggregate produce in a plan?
153+
fn column_name(&self) -> &'static str {
154+
match self {
155+
Self::CountStar => "COUNT(*)",
156+
Self::ColumnA(_) => "COUNT(a)",
157+
}
158+
}
159+
160+
/// What is the expected count?
161+
fn expected_count(&self) -> i64 {
162+
match self {
163+
TestAggregate::CountStar => 3,
164+
TestAggregate::ColumnA(_) => 2,
165+
}
166+
}
167+
}
168+
169+
#[tokio::test]
170+
async fn test_count_partial_direct_child() -> Result<()> {
171+
// basic test case with the aggregation applied on a source with exact statistics
172+
let source = mock_data()?;
173+
let schema = source.schema();
174+
let agg = TestAggregate::new_count_star();
175+
176+
let partial_agg = AggregateExec::try_new(
177+
AggregateMode::Partial,
178+
PhysicalGroupBy::default(),
179+
vec![agg.count_expr(&schema)],
180+
vec![None],
181+
source,
182+
Arc::clone(&schema),
183+
)?;
184+
185+
let final_agg = AggregateExec::try_new(
186+
AggregateMode::Final,
187+
PhysicalGroupBy::default(),
188+
vec![agg.count_expr(&schema)],
189+
vec![None],
190+
Arc::new(partial_agg),
191+
Arc::clone(&schema),
192+
)?;
193+
194+
assert_count_optim_success(final_agg, agg).await?;
195+
196+
Ok(())
197+
}
198+
199+
#[tokio::test]
200+
async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
201+
// basic test case with the aggregation applied on a source with exact statistics
202+
let source = mock_data()?;
203+
let schema = source.schema();
204+
let agg = TestAggregate::new_count_column(&schema);
205+
206+
let partial_agg = AggregateExec::try_new(
207+
AggregateMode::Partial,
208+
PhysicalGroupBy::default(),
209+
vec![agg.count_expr(&schema)],
210+
vec![None],
211+
source,
212+
Arc::clone(&schema),
213+
)?;
214+
215+
let final_agg = AggregateExec::try_new(
216+
AggregateMode::Final,
217+
PhysicalGroupBy::default(),
218+
vec![agg.count_expr(&schema)],
219+
vec![None],
220+
Arc::new(partial_agg),
221+
Arc::clone(&schema),
222+
)?;
223+
224+
assert_count_optim_success(final_agg, agg).await?;
225+
226+
Ok(())
227+
}
228+
229+
#[tokio::test]
230+
async fn test_count_partial_indirect_child() -> Result<()> {
231+
let source = mock_data()?;
232+
let schema = source.schema();
233+
let agg = TestAggregate::new_count_star();
234+
235+
let partial_agg = AggregateExec::try_new(
236+
AggregateMode::Partial,
237+
PhysicalGroupBy::default(),
238+
vec![agg.count_expr(&schema)],
239+
vec![None],
240+
source,
241+
Arc::clone(&schema),
242+
)?;
243+
244+
// We introduce an intermediate optimization step between the partial and final aggregtator
245+
let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
246+
247+
let final_agg = AggregateExec::try_new(
248+
AggregateMode::Final,
249+
PhysicalGroupBy::default(),
250+
vec![agg.count_expr(&schema)],
251+
vec![None],
252+
Arc::new(coalesce),
253+
Arc::clone(&schema),
254+
)?;
255+
256+
assert_count_optim_success(final_agg, agg).await?;
257+
258+
Ok(())
259+
}
260+
261+
#[tokio::test]
262+
async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
263+
let source = mock_data()?;
264+
let schema = source.schema();
265+
let agg = TestAggregate::new_count_column(&schema);
266+
267+
let partial_agg = AggregateExec::try_new(
268+
AggregateMode::Partial,
269+
PhysicalGroupBy::default(),
270+
vec![agg.count_expr(&schema)],
271+
vec![None],
272+
source,
273+
Arc::clone(&schema),
274+
)?;
275+
276+
// We introduce an intermediate optimization step between the partial and final aggregtator
277+
let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
278+
279+
let final_agg = AggregateExec::try_new(
280+
AggregateMode::Final,
281+
PhysicalGroupBy::default(),
282+
vec![agg.count_expr(&schema)],
283+
vec![None],
284+
Arc::new(coalesce),
285+
Arc::clone(&schema),
286+
)?;
287+
288+
assert_count_optim_success(final_agg, agg).await?;
289+
290+
Ok(())
291+
}
292+
293+
#[tokio::test]
294+
async fn test_count_inexact_stat() -> Result<()> {
295+
let source = mock_data()?;
296+
let schema = source.schema();
297+
let agg = TestAggregate::new_count_star();
298+
299+
// adding a filter makes the statistics inexact
300+
let filter = Arc::new(FilterExec::try_new(
301+
expressions::binary(
302+
expressions::col("a", &schema)?,
303+
Operator::Gt,
304+
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
305+
&schema,
306+
)?,
307+
source,
308+
)?);
309+
310+
let partial_agg = AggregateExec::try_new(
311+
AggregateMode::Partial,
312+
PhysicalGroupBy::default(),
313+
vec![agg.count_expr(&schema)],
314+
vec![None],
315+
filter,
316+
Arc::clone(&schema),
317+
)?;
318+
319+
let final_agg = AggregateExec::try_new(
320+
AggregateMode::Final,
321+
PhysicalGroupBy::default(),
322+
vec![agg.count_expr(&schema)],
323+
vec![None],
324+
Arc::new(partial_agg),
325+
Arc::clone(&schema),
326+
)?;
327+
328+
let conf = ConfigOptions::new();
329+
let optimized =
330+
AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
331+
332+
// check that the original ExecutionPlan was not replaced
333+
assert!(optimized.as_any().is::<AggregateExec>());
334+
335+
Ok(())
336+
}
337+
338+
#[tokio::test]
339+
async fn test_count_with_nulls_inexact_stat() -> Result<()> {
340+
let source = mock_data()?;
341+
let schema = source.schema();
342+
let agg = TestAggregate::new_count_column(&schema);
343+
344+
// adding a filter makes the statistics inexact
345+
let filter = Arc::new(FilterExec::try_new(
346+
expressions::binary(
347+
expressions::col("a", &schema)?,
348+
Operator::Gt,
349+
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
350+
&schema,
351+
)?,
352+
source,
353+
)?);
354+
355+
let partial_agg = AggregateExec::try_new(
356+
AggregateMode::Partial,
357+
PhysicalGroupBy::default(),
358+
vec![agg.count_expr(&schema)],
359+
vec![None],
360+
filter,
361+
Arc::clone(&schema),
362+
)?;
363+
364+
let final_agg = AggregateExec::try_new(
365+
AggregateMode::Final,
366+
PhysicalGroupBy::default(),
367+
vec![agg.count_expr(&schema)],
368+
vec![None],
369+
Arc::new(partial_agg),
370+
Arc::clone(&schema),
371+
)?;
372+
373+
let conf = ConfigOptions::new();
374+
let optimized =
375+
AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
376+
377+
// check that the original ExecutionPlan was not replaced
378+
assert!(optimized.as_any().is::<AggregateExec>());
379+
380+
Ok(())
381+
}
382+
}

0 commit comments

Comments
 (0)