|
16 | 16 | // under the License.
|
17 | 17 |
|
18 | 18 | use std::sync::Arc;
|
19 |
| -use std::{collections::HashSet, str::FromStr}; |
20 | 19 |
|
21 | 20 | use arrow::array::RecordBatch;
|
22 | 21 | use arrow::util::pretty::pretty_format_batches;
|
23 | 22 | use datafusion_common::{DataFusionError, Result};
|
24 | 23 | use datafusion_common_runtime::JoinSet;
|
25 |
| -use rand::seq::SliceRandom; |
26 | 24 | use rand::{thread_rng, Rng};
|
27 | 25 |
|
| 26 | +use crate::fuzz_cases::aggregation_fuzzer::query_builder::QueryBuilder; |
28 | 27 | use crate::fuzz_cases::aggregation_fuzzer::{
|
29 | 28 | check_equality_of_batches,
|
30 | 29 | context_generator::{SessionContextGenerator, SessionContextWithParams},
|
@@ -69,30 +68,16 @@ impl AggregationFuzzerBuilder {
|
69 | 68 | /// - 3 random queries
|
70 | 69 | /// - 3 random queries for each group by selected from the sort keys
|
71 | 70 | /// - 1 random query with no grouping
|
72 |
| - pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { |
73 |
| - const NUM_QUERIES: usize = 3; |
74 |
| - for _ in 0..NUM_QUERIES { |
75 |
| - let sql = query_builder.generate_query(); |
76 |
| - self.candidate_sqls.push(Arc::from(sql)); |
77 |
| - } |
78 |
| - // also add several queries limited to grouping on the group by columns only, if any |
79 |
| - // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` |
80 |
| - if let Some(data_gen_config) = &self.data_gen_config { |
81 |
| - for sort_keys in &data_gen_config.sort_keys_set { |
82 |
| - let group_by_columns = sort_keys.iter().map(|s| s.as_str()); |
83 |
| - query_builder = query_builder.set_group_by_columns(group_by_columns); |
84 |
| - for _ in 0..NUM_QUERIES { |
85 |
| - let sql = query_builder.generate_query(); |
86 |
| - self.candidate_sqls.push(Arc::from(sql)); |
87 |
| - } |
88 |
| - } |
89 |
| - } |
90 |
| - // also add a query with no grouping |
91 |
| - query_builder = query_builder.set_group_by_columns(vec![]); |
92 |
| - let sql = query_builder.generate_query(); |
93 |
| - self.candidate_sqls.push(Arc::from(sql)); |
| 71 | + pub fn add_query_builder(mut self, query_builder: QueryBuilder) -> Self { |
| 72 | + self = self.table_name(query_builder.table_name()); |
94 | 73 |
|
95 |
| - self.table_name(query_builder.table_name()) |
| 74 | + let sqls = query_builder |
| 75 | + .generate_queries() |
| 76 | + .into_iter() |
| 77 | + .map(|sql| Arc::from(sql.as_str())); |
| 78 | + self.candidate_sqls.extend(sqls); |
| 79 | + |
| 80 | + self |
96 | 81 | }
|
97 | 82 |
|
98 | 83 | pub fn table_name(mut self, table_name: &str) -> Self {
|
@@ -371,217 +356,3 @@ fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display
|
371 | 356 |
|
372 | 357 | pretty_format_batches(&to_print).unwrap()
|
373 | 358 | }
|
374 |
| - |
375 |
/// Random aggregate query builder
///
/// Creates queries like
/// ```sql
/// SELECT AGG(..) FROM table_name GROUP BY <group_by_columns>
/// ```
#[derive(Debug, Default, Clone)]
pub struct QueryBuilder {
    /// The name of the table to query
    table_name: String,
    /// Aggregate functions to be used in the query
    /// (function_name, is_distinct)
    aggregate_functions: Vec<(String, bool)>,
    /// Columns to be used in group by
    group_by_columns: Vec<String>,
    /// Possible columns for arguments in the aggregate functions
    ///
    /// Assumes each column is a valid argument for every registered
    /// aggregate function, because functions and arguments are paired at
    /// random. The `ORDER BY` columns of `first_value` / `last_value` are
    /// drawn from this list as well.
    arguments: Vec<String>,
}
395 |
| -impl QueryBuilder { |
396 |
| - pub fn new() -> Self { |
397 |
| - Default::default() |
398 |
| - } |
399 |
| - |
400 |
| - /// return the table name if any |
401 |
| - pub fn table_name(&self) -> &str { |
402 |
| - &self.table_name |
403 |
| - } |
404 |
| - |
405 |
| - /// Set the table name for the query builder |
406 |
| - pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self { |
407 |
| - self.table_name = table_name.into(); |
408 |
| - self |
409 |
| - } |
410 |
| - |
411 |
| - /// Add a new possible aggregate function to the query builder |
412 |
| - pub fn with_aggregate_function( |
413 |
| - mut self, |
414 |
| - aggregate_function: impl Into<String>, |
415 |
| - ) -> Self { |
416 |
| - self.aggregate_functions |
417 |
| - .push((aggregate_function.into(), false)); |
418 |
| - self |
419 |
| - } |
420 |
| - |
421 |
| - /// Add a new possible `DISTINCT` aggregate function to the query |
422 |
| - /// |
423 |
| - /// This is different than `with_aggregate_function` because only certain |
424 |
| - /// aggregates support `DISTINCT` |
425 |
| - pub fn with_distinct_aggregate_function( |
426 |
| - mut self, |
427 |
| - aggregate_function: impl Into<String>, |
428 |
| - ) -> Self { |
429 |
| - self.aggregate_functions |
430 |
| - .push((aggregate_function.into(), true)); |
431 |
| - self |
432 |
| - } |
433 |
| - |
434 |
| - /// Set the columns to be used in the group bys clauses |
435 |
| - pub fn set_group_by_columns<'a>( |
436 |
| - mut self, |
437 |
| - group_by: impl IntoIterator<Item = &'a str>, |
438 |
| - ) -> Self { |
439 |
| - self.group_by_columns = group_by.into_iter().map(String::from).collect(); |
440 |
| - self |
441 |
| - } |
442 |
| - |
443 |
| - /// Add one or more columns to be used as an argument in the aggregate functions |
444 |
| - pub fn with_aggregate_arguments<'a>( |
445 |
| - mut self, |
446 |
| - arguments: impl IntoIterator<Item = &'a str>, |
447 |
| - ) -> Self { |
448 |
| - let arguments = arguments.into_iter().map(String::from); |
449 |
| - self.arguments.extend(arguments); |
450 |
| - self |
451 |
| - } |
452 |
| - |
453 |
| - pub fn generate_query(&self) -> String { |
454 |
| - let group_by = self.random_group_by(); |
455 |
| - let mut query = String::from("SELECT "); |
456 |
| - query.push_str(&group_by.join(", ")); |
457 |
| - if !group_by.is_empty() { |
458 |
| - query.push_str(", "); |
459 |
| - } |
460 |
| - query.push_str(&self.random_aggregate_functions(&group_by).join(", ")); |
461 |
| - query.push_str(" FROM "); |
462 |
| - query.push_str(&self.table_name); |
463 |
| - if !group_by.is_empty() { |
464 |
| - query.push_str(" GROUP BY "); |
465 |
| - query.push_str(&group_by.join(", ")); |
466 |
| - } |
467 |
| - query |
468 |
| - } |
469 |
| - |
470 |
| - /// Generate a some random aggregate function invocations (potentially repeating). |
471 |
| - /// |
472 |
| - /// Each aggregate function invocation is of the form |
473 |
| - /// |
474 |
| - /// ```sql |
475 |
| - /// function_name(<DISTINCT> argument) as alias |
476 |
| - /// ``` |
477 |
| - /// |
478 |
| - /// where |
479 |
| - /// * `function_names` are randomly selected from [`Self::aggregate_functions`] |
480 |
| - /// * `<DISTINCT> argument` is randomly selected from [`Self::arguments`] |
481 |
| - /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) |
482 |
| - fn random_aggregate_functions(&self, group_by_cols: &[String]) -> Vec<String> { |
483 |
| - const MAX_NUM_FUNCTIONS: usize = 5; |
484 |
| - let mut rng = thread_rng(); |
485 |
| - let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); |
486 |
| - |
487 |
| - let mut alias_gen = 1; |
488 |
| - |
489 |
| - let mut aggregate_functions = vec![]; |
490 |
| - |
491 |
| - let mut order_by_black_list: HashSet<String> = |
492 |
| - group_by_cols.iter().cloned().collect(); |
493 |
| - // remove one random col |
494 |
| - if let Some(first) = order_by_black_list.iter().next().cloned() { |
495 |
| - order_by_black_list.remove(&first); |
496 |
| - } |
497 |
| - |
498 |
| - while aggregate_functions.len() < num_aggregate_functions { |
499 |
| - let idx = rng.gen_range(0..self.aggregate_functions.len()); |
500 |
| - let (function_name, is_distinct) = &self.aggregate_functions[idx]; |
501 |
| - let argument = self.random_argument(); |
502 |
| - let alias = format!("col{}", alias_gen); |
503 |
| - let distinct = if *is_distinct { "DISTINCT " } else { "" }; |
504 |
| - alias_gen += 1; |
505 |
| - |
506 |
| - let (order_by, null_opt) = if function_name.eq("first_value") |
507 |
| - || function_name.eq("last_value") |
508 |
| - { |
509 |
| - ( |
510 |
| - self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */ |
511 |
| - self.null_opt(), |
512 |
| - ) |
513 |
| - } else { |
514 |
| - ("".to_string(), "".to_string()) |
515 |
| - }; |
516 |
| - |
517 |
| - let function = format!( |
518 |
| - "{function_name}({distinct}{argument}{order_by}) {null_opt} as {alias}" |
519 |
| - ); |
520 |
| - aggregate_functions.push(function); |
521 |
| - } |
522 |
| - aggregate_functions |
523 |
| - } |
524 |
| - |
525 |
| - /// Pick a random aggregate function argument |
526 |
| - fn random_argument(&self) -> String { |
527 |
| - let mut rng = thread_rng(); |
528 |
| - let idx = rng.gen_range(0..self.arguments.len()); |
529 |
| - self.arguments[idx].clone() |
530 |
| - } |
531 |
| - |
532 |
| - fn order_by(&self, black_list: &HashSet<String>) -> String { |
533 |
| - let mut available_columns: Vec<String> = self |
534 |
| - .arguments |
535 |
| - .iter() |
536 |
| - .filter(|col| !black_list.contains(*col)) |
537 |
| - .cloned() |
538 |
| - .collect(); |
539 |
| - |
540 |
| - available_columns.shuffle(&mut thread_rng()); |
541 |
| - |
542 |
| - let num_of_order_by_col = 12; |
543 |
| - let column_count = std::cmp::min(num_of_order_by_col, available_columns.len()); |
544 |
| - |
545 |
| - let selected_columns = &available_columns[0..column_count]; |
546 |
| - |
547 |
| - let mut rng = thread_rng(); |
548 |
| - let mut result = String::from_str(" order by ").unwrap(); |
549 |
| - for col in selected_columns { |
550 |
| - let order = if rng.gen_bool(0.5) { "ASC" } else { "DESC" }; |
551 |
| - result.push_str(&format!("{} {},", col, order)); |
552 |
| - } |
553 |
| - |
554 |
| - result.strip_suffix(",").unwrap().to_string() |
555 |
| - } |
556 |
| - |
557 |
| - fn null_opt(&self) -> String { |
558 |
| - if thread_rng().gen_bool(0.5) { |
559 |
| - "RESPECT NULLS".to_string() |
560 |
| - } else { |
561 |
| - "IGNORE NULLS".to_string() |
562 |
| - } |
563 |
| - } |
564 |
| - |
565 |
| - /// Pick a random number of fields to group by (non-repeating) |
566 |
| - /// |
567 |
| - /// Limited to 3 group by columns to ensure coverage for large groups. With |
568 |
| - /// larger numbers of columns, each group has many fewer values. |
569 |
| - fn random_group_by(&self) -> Vec<String> { |
570 |
| - let mut rng = thread_rng(); |
571 |
| - const MAX_GROUPS: usize = 3; |
572 |
| - let max_groups = self.group_by_columns.len().max(MAX_GROUPS); |
573 |
| - let num_group_by = rng.gen_range(1..max_groups); |
574 |
| - |
575 |
| - let mut already_used = HashSet::new(); |
576 |
| - let mut group_by = vec![]; |
577 |
| - while group_by.len() < num_group_by |
578 |
| - && already_used.len() != self.group_by_columns.len() |
579 |
| - { |
580 |
| - let idx = rng.gen_range(0..self.group_by_columns.len()); |
581 |
| - if already_used.insert(idx) { |
582 |
| - group_by.push(self.group_by_columns[idx].clone()); |
583 |
| - } |
584 |
| - } |
585 |
| - group_by |
586 |
| - } |
587 |
| -} |
0 commit comments