Skip to content

Commit 74dc419

Browse files
authored
Make aggr fuzzer query builder more configurable (#15851)
* refactor and make `QueryBuilder` more configurable. * fix tests. * fix clippy. * extract `QueryBuilder` to a dedicated module. * add `min_group_by_columns`, and fix some bugs.
1 parent 96a2086 commit 74dc419

File tree

4 files changed

+404
-240
lines changed

4 files changed

+404
-240
lines changed

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
use std::sync::Arc;
1919

20+
use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder;
2021
use crate::fuzz_cases::aggregation_fuzzer::{
21-
AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder,
22+
AggregationFuzzerBuilder, DatasetGeneratorConfig,
2223
};
2324

2425
use arrow::array::{
@@ -85,6 +86,7 @@ async fn test_min() {
8586
.with_aggregate_function("min")
8687
// min works on all column types
8788
.with_aggregate_arguments(data_gen_config.all_columns())
89+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
8890
.set_group_by_columns(data_gen_config.all_columns());
8991

9092
AggregationFuzzerBuilder::from(data_gen_config)
@@ -111,6 +113,7 @@ async fn test_first_val() {
111113
.with_table_name("fuzz_table")
112114
.with_aggregate_function("first_value")
113115
.with_aggregate_arguments(data_gen_config.all_columns())
116+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
114117
.set_group_by_columns(data_gen_config.all_columns());
115118

116119
AggregationFuzzerBuilder::from(data_gen_config)
@@ -137,6 +140,7 @@ async fn test_last_val() {
137140
.with_table_name("fuzz_table")
138141
.with_aggregate_function("last_value")
139142
.with_aggregate_arguments(data_gen_config.all_columns())
143+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
140144
.set_group_by_columns(data_gen_config.all_columns());
141145

142146
AggregationFuzzerBuilder::from(data_gen_config)
@@ -156,6 +160,7 @@ async fn test_max() {
156160
.with_aggregate_function("max")
157161
// max works on all column types
158162
.with_aggregate_arguments(data_gen_config.all_columns())
163+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
159164
.set_group_by_columns(data_gen_config.all_columns());
160165

161166
AggregationFuzzerBuilder::from(data_gen_config)
@@ -176,6 +181,7 @@ async fn test_sum() {
176181
.with_distinct_aggregate_function("sum")
177182
// sum only works on numeric columns
178183
.with_aggregate_arguments(data_gen_config.numeric_columns())
184+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
179185
.set_group_by_columns(data_gen_config.all_columns());
180186

181187
AggregationFuzzerBuilder::from(data_gen_config)
@@ -196,6 +202,7 @@ async fn test_count() {
196202
.with_distinct_aggregate_function("count")
197203
// count work for all arguments
198204
.with_aggregate_arguments(data_gen_config.all_columns())
205+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
199206
.set_group_by_columns(data_gen_config.all_columns());
200207

201208
AggregationFuzzerBuilder::from(data_gen_config)
@@ -216,6 +223,7 @@ async fn test_median() {
216223
.with_distinct_aggregate_function("median")
217224
// median only works on numeric columns
218225
.with_aggregate_arguments(data_gen_config.numeric_columns())
226+
.with_dataset_sort_keys(data_gen_config.sort_keys_set.clone())
219227
.set_group_by_columns(data_gen_config.all_columns());
220228

221229
AggregationFuzzerBuilder::from(data_gen_config)

datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs

Lines changed: 10 additions & 239 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616
// under the License.
1717

1818
use std::sync::Arc;
19-
use std::{collections::HashSet, str::FromStr};
2019

2120
use arrow::array::RecordBatch;
2221
use arrow::util::pretty::pretty_format_batches;
2322
use datafusion_common::{DataFusionError, Result};
2423
use datafusion_common_runtime::JoinSet;
25-
use rand::seq::SliceRandom;
2624
use rand::{thread_rng, Rng};
2725

26+
use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder;
2827
use crate::fuzz_cases::aggregation_fuzzer::{
2928
check_equality_of_batches,
3029
context_generator::{SessionContextGenerator, SessionContextWithParams},
@@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder {
6968
/// - 3 random queries
7069
/// - 3 random queries for each group by selected from the sort keys
7170
/// - 1 random query with no grouping
72-
pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self {
73-
const NUM_QUERIES: usize = 3;
74-
for _ in 0..NUM_QUERIES {
75-
let sql = query_builder.generate_query();
76-
self.candidate_sqls.push(Arc::from(sql));
77-
}
78-
// also add several queries limited to grouping on the group by columns only, if any
79-
// So if the data is sorted on `a,b` only group by `a,b` or`a` or `b`
80-
if let Some(data_gen_config) = &self.data_gen_config {
81-
for sort_keys in &data_gen_config.sort_keys_set {
82-
let group_by_columns = sort_keys.iter().map(|s| s.as_str());
83-
query_builder = query_builder.set_group_by_columns(group_by_columns);
84-
for _ in 0..NUM_QUERIES {
85-
let sql = query_builder.generate_query();
86-
self.candidate_sqls.push(Arc::from(sql));
87-
}
88-
}
89-
}
90-
// also add a query with no grouping
91-
query_builder = query_builder.set_group_by_columns(vec![]);
92-
let sql = query_builder.generate_query();
93-
self.candidate_sqls.push(Arc::from(sql));
71+
pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self {
72+
self = self.table_name(query_builder.table_name());
9473

95-
self.table_name(query_builder.table_name())
74+
let sqls = query_builder
75+
.generate_queries()
76+
.into_iter()
77+
.map(|sql| Arc::from(sql.as_str()));
78+
self.candidate_sqls.extend(sqls);
79+
80+
self
9681
}
9782

9883
pub fn table_name(mut self, table_name: &str) -> Self {
@@ -371,217 +356,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display
371356

372357
pretty_format_batches(&to_print).unwrap()
373358
}
374-
375-
/// Random aggregate query builder
376-
///
377-
/// Creates queries like
378-
/// ```sql
379-
/// SELECT AGG(..) FROM table_name GROUP BY <group_by_columns>
380-
///```
381-
#[derive(Debug, Default, Clone)]
382-
pub struct QueryBuilder {
383-
/// The name of the table to query
384-
table_name: String,
385-
/// Aggregate functions to be used in the query
386-
/// (function_name, is_distinct)
387-
aggregate_functions: Vec<(String, bool)>,
388-
/// Columns to be used in group by
389-
group_by_columns: Vec<String>,
390-
/// Possible columns for arguments in the aggregate functions
391-
///
392-
/// Assumes each
393-
arguments: Vec<String>,
394-
}
395-
impl QueryBuilder {
396-
pub fn new() -> Self {
397-
Default::default()
398-
}
399-
400-
/// return the table name if any
401-
pub fn table_name(&self) -> &str {
402-
&self.table_name
403-
}
404-
405-
/// Set the table name for the query builder
406-
pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
407-
self.table_name = table_name.into();
408-
self
409-
}
410-
411-
/// Add a new possible aggregate function to the query builder
412-
pub fn with_aggregate_function(
413-
mut self,
414-
aggregate_function: impl Into<String>,
415-
) -> Self {
416-
self.aggregate_functions
417-
.push((aggregate_function.into(), false));
418-
self
419-
}
420-
421-
/// Add a new possible `DISTINCT` aggregate function to the query
422-
///
423-
/// This is different than `with_aggregate_function` because only certain
424-
/// aggregates support `DISTINCT`
425-
pub fn with_distinct_aggregate_function(
426-
mut self,
427-
aggregate_function: impl Into<String>,
428-
) -> Self {
429-
self.aggregate_functions
430-
.push((aggregate_function.into(), true));
431-
self
432-
}
433-
434-
/// Set the columns to be used in the group bys clauses
435-
pub fn set_group_by_columns<'a>(
436-
mut self,
437-
group_by: impl IntoIterator<Item = &'a str>,
438-
) -> Self {
439-
self.group_by_columns = group_by.into_iter().map(String::from).collect();
440-
self
441-
}
442-
443-
/// Add one or more columns to be used as an argument in the aggregate functions
444-
pub fn with_aggregate_arguments<'a>(
445-
mut self,
446-
arguments: impl IntoIterator<Item = &'a str>,
447-
) -> Self {
448-
let arguments = arguments.into_iter().map(String::from);
449-
self.arguments.extend(arguments);
450-
self
451-
}
452-
453-
pub fn generate_query(&self) -> String {
454-
let group_by = self.random_group_by();
455-
let mut query = String::from("SELECT ");
456-
query.push_str(&group_by.join(", "));
457-
if !group_by.is_empty() {
458-
query.push_str(", ");
459-
}
460-
query.push_str(&self.random_aggregate_functions(&group_by).join(", "));
461-
query.push_str(" FROM ");
462-
query.push_str(&self.table_name);
463-
if !group_by.is_empty() {
464-
query.push_str(" GROUP BY ");
465-
query.push_str(&group_by.join(", "));
466-
}
467-
query
468-
}
469-
470-
/// Generate a some random aggregate function invocations (potentially repeating).
471-
///
472-
/// Each aggregate function invocation is of the form
473-
///
474-
/// ```sql
475-
/// function_name(<DISTINCT> argument) as alias
476-
/// ```
477-
///
478-
/// where
479-
/// * `function_names` are randomly selected from [`Self::aggregate_functions`]
480-
/// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`]
481-
/// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names)
482-
fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> {
483-
const MAX_NUM_FUNCTIONS: usize = 5;
484-
let mut rng = thread_rng();
485-
let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS);
486-
487-
let mut alias_gen = 1;
488-
489-
let mut aggregate_functions = vec![];
490-
491-
let mut order_by_black_list: HashSet<String> =
492-
group_by_cols.iter().cloned().collect();
493-
// remove one random col
494-
if let Some(first) = order_by_black_list.iter().next().cloned() {
495-
order_by_black_list.remove(&first);
496-
}
497-
498-
while aggregate_functions.len() < num_aggregate_functions {
499-
let idx = rng.gen_range(0..self.aggregate_functions.len());
500-
let (function_name, is_distinct) = &self.aggregate_functions[idx];
501-
let argument = self.random_argument();
502-
let alias = format!("col{}", alias_gen);
503-
let distinct = if *is_distinct { "DISTINCT " } else { "" };
504-
alias_gen += 1;
505-
506-
let (order_by, null_opt) = if function_name.eq("first_value")
507-
|| function_name.eq("last_value")
508-
{
509-
(
510-
self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
511-
self.null_opt(),
512-
)
513-
} else {
514-
("".to_string(), "".to_string())
515-
};
516-
517-
let function = format!(
518-
"{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}"
519-
);
520-
aggregate_functions.push(function);
521-
}
522-
aggregate_functions
523-
}
524-
525-
/// Pick a random aggregate function argument
526-
fn random_argument(&self) -> String {
527-
let mut rng = thread_rng();
528-
let idx = rng.gen_range(0..self.arguments.len());
529-
self.arguments[idx].clone()
530-
}
531-
532-
fn order_by(&self, black_list: &HashSet<String>) -> String {
533-
let mut available_columns: Vec<String> = self
534-
.arguments
535-
.iter()
536-
.filter(|col| !black_list.contains(*col))
537-
.cloned()
538-
.collect();
539-
540-
available_columns.shuffle(&mut thread_rng());
541-
542-
let num_of_order_by_col = 12;
543-
let column_count = std::cmp::min(num_of_order_by_col, available_columns.len());
544-
545-
let selected_columns = &available_columns[0..column_count];
546-
547-
let mut rng = thread_rng();
548-
let mut result = String::from_str(" order by ").unwrap();
549-
for col in selected_columns {
550-
let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" };
551-
result.push_str(&format!("{} {},", col, order));
552-
}
553-
554-
result.strip_suffix(",").unwrap().to_string()
555-
}
556-
557-
fn null_opt(&self) -> String {
558-
if thread_rng().gen_bool(0.5) {
559-
"RESPECT NULLS".to_string()
560-
} else {
561-
"IGNORE NULLS".to_string()
562-
}
563-
}
564-
565-
/// Pick a random number of fields to group by (non-repeating)
566-
///
567-
/// Limited to 3 group by columns to ensure coverage for large groups. With
568-
/// larger numbers of columns, each group has many fewer values.
569-
fn random_group_by(&self) -> Vec<String> {
570-
let mut rng = thread_rng();
571-
const MAX_GROUPS: usize = 3;
572-
let max_groups = self.group_by_columns.len().max(MAX_GROUPS);
573-
let num_group_by = rng.gen_range(1..max_groups);
574-
575-
let mut already_used = HashSet::new();
576-
let mut group_by = vec![];
577-
while group_by.len() < num_group_by
578-
&& already_used.len() != self.group_by_columns.len()
579-
{
580-
let idx = rng.gen_range(0..self.group_by_columns.len());
581-
if already_used.insert(idx) {
582-
group_by.push(self.group_by_columns[idx].clone());
583-
}
584-
}
585-
group_by
586-
}
587-
}

datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use datafusion_common::error::Result;
4343
mod context_generator;
4444
mod data_generator;
4545
mod fuzzer;
46+
pub mod query_builder;
4647

4748
pub use crate::fuzz_cases::record_batch_generator::ColumnDescr;
4849
pub use data_generator::DatasetGeneratorConfig;

0 commit comments

Comments
 (0)