Skip to content

Commit e9f9a23

Browse files
authored
Minor: Add routine to debug join fuzz tests (#10970)
* Minor: Add routine to debug join fuzz tests
1 parent a2c9d1a commit e9f9a23

File tree

1 file changed

+162
-40
lines changed

1 file changed

+162
-40
lines changed

datafusion/core/tests/fuzz_cases/join_fuzz.rs

Lines changed: 162 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec;
4343
use datafusion::prelude::{SessionConfig, SessionContext};
4444
use test_utils::stagger_batch_with_seed;
4545

46+
// Determines which fuzz tests need to run
47+
// Ideally all tests should match, but in reality some tests
48+
// pass only a subset of the cases
49+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
50+
enum JoinTestType {
51+
// compare NestedLoopJoin and HashJoin
52+
NljHj,
53+
// compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin
54+
// because if both existing variants pass, SortMergeJoin and NestedLoopJoin also agree
55+
HjSmj,
56+
}
4657
#[tokio::test]
4758
async fn test_inner_join_1k() {
4859
JoinFuzzTestCase::new(
@@ -51,7 +62,7 @@ async fn test_inner_join_1k() {
5162
JoinType::Inner,
5263
None,
5364
)
54-
.run_test()
65+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
5566
.await
5667
}
5768

@@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> Joi
7182
JoinFilter::new(less_than_100, column_indices, intermediate_schema)
7283
}
7384

85+
fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
86+
let less_than_100 = Arc::new(BinaryExpr::new(
87+
Arc::new(Column::new("x", 1)),
88+
Operator::Lt,
89+
Arc::new(Column::new("x", 0)),
90+
)) as _;
91+
let column_indices = vec![
92+
ColumnIndex {
93+
index: 2,
94+
side: JoinSide::Left,
95+
},
96+
ColumnIndex {
97+
index: 2,
98+
side: JoinSide::Right,
99+
},
100+
];
101+
let intermediate_schema = Schema::new(vec![
102+
schema1.field_with_name("x").unwrap().to_owned(),
103+
schema2.field_with_name("x").unwrap().to_owned(),
104+
]);
105+
106+
JoinFilter::new(less_than_100, column_indices, intermediate_schema)
107+
}
108+
74109
#[tokio::test]
75110
async fn test_inner_join_1k_filtered() {
76111
JoinFuzzTestCase::new(
@@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
79114
JoinType::Inner,
80115
Some(Box::new(less_than_100_join_filter)),
81116
)
82-
.run_test()
117+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
83118
.await
84119
}
85120

@@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
91126
JoinType::Inner,
92127
None,
93128
)
94-
.run_test()
129+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
95130
.await
96131
}
97132

@@ -103,7 +138,7 @@ async fn test_left_join_1k() {
103138
JoinType::Left,
104139
None,
105140
)
106-
.run_test()
141+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
107142
.await
108143
}
109144

@@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() {
115150
JoinType::Left,
116151
Some(Box::new(less_than_100_join_filter)),
117152
)
118-
.run_test()
153+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
119154
.await
120155
}
121156

@@ -127,7 +162,7 @@ async fn test_right_join_1k() {
127162
JoinType::Right,
128163
None,
129164
)
130-
.run_test()
165+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
131166
.await
132167
}
133168
// Add support for Right filtered joins
@@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() {
140175
JoinType::Right,
141176
Some(Box::new(less_than_100_join_filter)),
142177
)
143-
.run_test()
178+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
144179
.await
145180
}
146181

@@ -152,7 +187,7 @@ async fn test_full_join_1k() {
152187
JoinType::Full,
153188
None,
154189
)
155-
.run_test()
190+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
156191
.await
157192
}
158193

@@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() {
164199
JoinType::Full,
165200
Some(Box::new(less_than_100_join_filter)),
166201
)
167-
.run_test()
202+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
168203
.await
169204
}
170205

@@ -176,22 +211,23 @@ async fn test_semi_join_1k() {
176211
JoinType::LeftSemi,
177212
None,
178213
)
179-
.run_test()
214+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
180215
.await
181216
}
182217

183218
// The test is flaky
184219
// https://github.com/apache/datafusion/issues/10886
220+
// SMJ produces 1 more row in the output
185221
#[ignore]
186222
#[tokio::test]
187223
async fn test_semi_join_1k_filtered() {
188224
JoinFuzzTestCase::new(
189225
make_staggered_batches(1000),
190226
make_staggered_batches(1000),
191227
JoinType::LeftSemi,
192-
Some(Box::new(less_than_100_join_filter)),
228+
Some(Box::new(col_lt_col_filter)),
193229
)
194-
.run_test()
230+
.run_test(&[JoinTestType::HjSmj], false)
195231
.await
196232
}
197233

@@ -203,7 +239,7 @@ async fn test_anti_join_1k() {
203239
JoinType::LeftAnti,
204240
None,
205241
)
206-
.run_test()
242+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
207243
.await
208244
}
209245

@@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
217253
JoinType::LeftAnti,
218254
Some(Box::new(less_than_100_join_filter)),
219255
)
220-
.run_test()
256+
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
221257
.await
222258
}
223259

@@ -331,7 +367,7 @@ impl JoinFuzzTestCase {
331367
self.on_columns().clone(),
332368
self.join_filter(),
333369
self.join_type,
334-
vec![SortOptions::default(), SortOptions::default()],
370+
vec![SortOptions::default(); self.on_columns().len()],
335371
false,
336372
)
337373
.unwrap(),
@@ -381,9 +417,11 @@ impl JoinFuzzTestCase {
381417
)
382418
}
383419

384-
/// Perform sort-merge join and hash join on same input
385-
/// and verify two outputs are equal
386-
async fn run_test(&self) {
420+
/// Perform join tests on the same inputs and verify the outputs are equal
421+
/// `join_tests` - identifies what join types to test
422+
/// if the `debug` flag is set, the test saves the randomly generated inputs and the join outputs as parquet files under `fuzz_test_debug_batch_size_<batch_size>/`,
423+
/// so it is easy to debug a test on top of the failed data
424+
async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) {
387425
for batch_size in self.batch_sizes {
388426
let session_config = SessionConfig::new().with_batch_size(*batch_size);
389427
let ctx = SessionContext::new_with_config(session_config);
@@ -394,17 +432,30 @@ impl JoinFuzzTestCase {
394432
let hj = self.hash_join();
395433
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
396434

435+
let nlj = self.nested_loop_join();
436+
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
437+
397438
// Get actual row counts (without formatting overhead) for HJ, SMJ and NLJ
398439
let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
399440
let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
441+
let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
400442

401-
assert_eq!(
402-
hj_rows, smj_rows,
403-
"SortMergeJoinExec and HashJoinExec produced different row counts"
404-
);
443+
if debug {
444+
println!("The debug is ON. Input data will be saved");
445+
let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}");
446+
Self::save_as_parquet(&self.input1, out_dir_name, "input1");
447+
Self::save_as_parquet(&self.input2, out_dir_name, "input2");
405448

406-
let nlj = self.nested_loop_join();
407-
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
449+
if join_tests.contains(&JoinTestType::NljHj) {
450+
Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj");
451+
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
452+
}
453+
454+
if join_tests.contains(&JoinTestType::HjSmj) {
455+
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
456+
Self::save_as_parquet(&smj_collected, out_dir_name, "smj");
457+
}
458+
}
408459

409460
// compare
410461
let smj_formatted =
@@ -425,35 +476,106 @@ impl JoinFuzzTestCase {
425476
nlj_formatted.trim().lines().collect();
426477
nlj_formatted_sorted.sort_unstable();
427478

428-
// row level compare if any of joins returns the result
429-
// the reason is different formatting when there is no rows
430-
if smj_rows > 0 || hj_rows > 0 {
431-
for (i, (smj_line, hj_line)) in smj_formatted_sorted
479+
if join_tests.contains(&JoinTestType::NljHj) {
480+
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
481+
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());
482+
483+
let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size);
484+
// row level compare if any of joins returns the result
485+
// the reason is different formatting when there is no rows
486+
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
432487
.iter()
433488
.zip(&hj_formatted_sorted)
434489
.enumerate()
435490
{
436491
assert_eq!(
437-
(i, smj_line),
492+
(i, nlj_line),
438493
(i, hj_line),
439-
"SortMergeJoinExec and HashJoinExec produced different results"
494+
"{}",
495+
err_msg_contents.as_str()
440496
);
441497
}
442498
}
443499

444-
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
445-
.iter()
446-
.zip(&hj_formatted_sorted)
447-
.enumerate()
448-
{
449-
assert_eq!(
450-
(i, nlj_line),
451-
(i, hj_line),
452-
"NestedLoopJoinExec and HashJoinExec produced different results"
453-
);
500+
if join_tests.contains(&JoinTestType::HjSmj) {
501+
let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
502+
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());
503+
504+
let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);
505+
// row level compare if any of joins returns the result
506+
// the reason is different formatting when there are no rows
507+
if smj_rows > 0 || hj_rows > 0 {
508+
for (i, (smj_line, hj_line)) in smj_formatted_sorted
509+
.iter()
510+
.zip(&hj_formatted_sorted)
511+
.enumerate()
512+
{
513+
assert_eq!(
514+
(i, smj_line),
515+
(i, hj_line),
516+
"{}",
517+
err_msg_contents.as_str()
518+
);
519+
}
520+
}
454521
}
455522
}
456523
}
524+
525+
/// This method is useful for debugging fuzz tests
526+
/// It helps to save randomly generated input test data for both join inputs into the user folder
527+
/// as parquet files, preserving partitioning.
528+
/// Once the data is saved it is possible to run a custom test on top of the saved data and debug
529+
///
530+
/// let ctx: SessionContext = SessionContext::new();
531+
/// let df = ctx
532+
/// .read_parquet(
533+
/// "/tmp/input1/*.parquet",
534+
/// ParquetReadOptions::default(),
535+
/// )
536+
/// .await
537+
/// .unwrap();
538+
/// let left = df.collect().await.unwrap();
539+
///
540+
/// let df = ctx
541+
/// .read_parquet(
542+
/// "/tmp/input2/*.parquet",
543+
/// ParquetReadOptions::default(),
544+
/// )
545+
/// .await
546+
/// .unwrap();
547+
///
548+
/// let right = df.collect().await.unwrap();
549+
/// JoinFuzzTestCase::new(
550+
/// left,
551+
/// right,
552+
/// JoinType::LeftSemi,
553+
/// Some(Box::new(less_than_100_join_filter)),
554+
/// )
555+
/// .run_test(&[JoinTestType::HjSmj], false)
556+
/// .await
557+
/// }
558+
fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) {
559+
let out_path = &format!("{output_dir}/{out_name}");
560+
std::fs::remove_dir_all(out_path).unwrap_or(());
561+
std::fs::create_dir_all(out_path).unwrap();
562+
563+
input.iter().enumerate().for_each(|(idx, batch)| {
564+
let mut file =
565+
std::fs::File::create(format!("{out_path}/file_{}.parquet", idx))
566+
.unwrap();
567+
let mut writer = parquet::arrow::ArrowWriter::try_new(
568+
&mut file,
569+
input.first().unwrap().schema(),
570+
None,
571+
)
572+
.expect("creating writer");
573+
writer.write(batch).unwrap();
574+
writer.close().unwrap();
575+
});
576+
577+
println!("The data {out_name} saved as parquet into {out_path}");
578+
}
457579
}
458580

459581
/// Return randomly sized record batches with:

0 commit comments

Comments
 (0)