@@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec;
43
43
use datafusion:: prelude:: { SessionConfig , SessionContext } ;
44
44
use test_utils:: stagger_batch_with_seed;
45
45
46
+ // Determines what Fuzz tests needs to run
47
+ // Ideally all tests should match, but in reality some tests
48
+ // passes only partial cases
49
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash ) ]
50
+ enum JoinTestType {
51
+ // compare NestedLoopJoin and HashJoin
52
+ NljHj ,
53
+ // compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin
54
+ // because if existing variants both passed that means SortMergeJoin and NestedLoopJoin also passes
55
+ HjSmj ,
56
+ }
46
57
#[ tokio:: test]
47
58
async fn test_inner_join_1k ( ) {
48
59
JoinFuzzTestCase :: new (
@@ -51,7 +62,7 @@ async fn test_inner_join_1k() {
51
62
JoinType :: Inner ,
52
63
None ,
53
64
)
54
- . run_test ( )
65
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
55
66
. await
56
67
}
57
68
@@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> Joi
71
82
JoinFilter :: new ( less_than_100, column_indices, intermediate_schema)
72
83
}
73
84
85
+ fn col_lt_col_filter ( schema1 : Arc < Schema > , schema2 : Arc < Schema > ) -> JoinFilter {
86
+ let less_than_100 = Arc :: new ( BinaryExpr :: new (
87
+ Arc :: new ( Column :: new ( "x" , 1 ) ) ,
88
+ Operator :: Lt ,
89
+ Arc :: new ( Column :: new ( "x" , 0 ) ) ,
90
+ ) ) as _ ;
91
+ let column_indices = vec ! [
92
+ ColumnIndex {
93
+ index: 2 ,
94
+ side: JoinSide :: Left ,
95
+ } ,
96
+ ColumnIndex {
97
+ index: 2 ,
98
+ side: JoinSide :: Right ,
99
+ } ,
100
+ ] ;
101
+ let intermediate_schema = Schema :: new ( vec ! [
102
+ schema1. field_with_name( "x" ) . unwrap( ) . to_owned( ) ,
103
+ schema2. field_with_name( "x" ) . unwrap( ) . to_owned( ) ,
104
+ ] ) ;
105
+
106
+ JoinFilter :: new ( less_than_100, column_indices, intermediate_schema)
107
+ }
108
+
74
109
#[ tokio:: test]
75
110
async fn test_inner_join_1k_filtered ( ) {
76
111
JoinFuzzTestCase :: new (
@@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
79
114
JoinType :: Inner ,
80
115
Some ( Box :: new ( less_than_100_join_filter) ) ,
81
116
)
82
- . run_test ( )
117
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
83
118
. await
84
119
}
85
120
@@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
91
126
JoinType :: Inner ,
92
127
None ,
93
128
)
94
- . run_test ( )
129
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
95
130
. await
96
131
}
97
132
@@ -103,7 +138,7 @@ async fn test_left_join_1k() {
103
138
JoinType :: Left ,
104
139
None ,
105
140
)
106
- . run_test ( )
141
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
107
142
. await
108
143
}
109
144
@@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() {
115
150
JoinType :: Left ,
116
151
Some ( Box :: new ( less_than_100_join_filter) ) ,
117
152
)
118
- . run_test ( )
153
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
119
154
. await
120
155
}
121
156
@@ -127,7 +162,7 @@ async fn test_right_join_1k() {
127
162
JoinType :: Right ,
128
163
None ,
129
164
)
130
- . run_test ( )
165
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
131
166
. await
132
167
}
133
168
// Add support for Right filtered joins
@@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() {
140
175
JoinType :: Right ,
141
176
Some ( Box :: new ( less_than_100_join_filter) ) ,
142
177
)
143
- . run_test ( )
178
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
144
179
. await
145
180
}
146
181
@@ -152,7 +187,7 @@ async fn test_full_join_1k() {
152
187
JoinType :: Full ,
153
188
None ,
154
189
)
155
- . run_test ( )
190
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
156
191
. await
157
192
}
158
193
@@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() {
164
199
JoinType :: Full ,
165
200
Some ( Box :: new ( less_than_100_join_filter) ) ,
166
201
)
167
- . run_test ( )
202
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
168
203
. await
169
204
}
170
205
@@ -176,22 +211,23 @@ async fn test_semi_join_1k() {
176
211
JoinType :: LeftSemi ,
177
212
None ,
178
213
)
179
- . run_test ( )
214
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
180
215
. await
181
216
}
182
217
183
218
// The test is flaky
184
219
// https://github.com/apache/datafusion/issues/10886
220
+ // SMJ produces 1 more row in the output
185
221
#[ ignore]
186
222
#[ tokio:: test]
187
223
async fn test_semi_join_1k_filtered ( ) {
188
224
JoinFuzzTestCase :: new (
189
225
make_staggered_batches ( 1000 ) ,
190
226
make_staggered_batches ( 1000 ) ,
191
227
JoinType :: LeftSemi ,
192
- Some ( Box :: new ( less_than_100_join_filter ) ) ,
228
+ Some ( Box :: new ( col_lt_col_filter ) ) ,
193
229
)
194
- . run_test ( )
230
+ . run_test ( & [ JoinTestType :: HjSmj ] , false )
195
231
. await
196
232
}
197
233
@@ -203,7 +239,7 @@ async fn test_anti_join_1k() {
203
239
JoinType :: LeftAnti ,
204
240
None ,
205
241
)
206
- . run_test ( )
242
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
207
243
. await
208
244
}
209
245
@@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
217
253
JoinType :: LeftAnti ,
218
254
Some ( Box :: new ( less_than_100_join_filter) ) ,
219
255
)
220
- . run_test ( )
256
+ . run_test ( & [ JoinTestType :: HjSmj , JoinTestType :: NljHj ] , false )
221
257
. await
222
258
}
223
259
@@ -331,7 +367,7 @@ impl JoinFuzzTestCase {
331
367
self . on_columns ( ) . clone ( ) ,
332
368
self . join_filter ( ) ,
333
369
self . join_type ,
334
- vec ! [ SortOptions :: default ( ) , SortOptions :: default ( ) ] ,
370
+ vec ! [ SortOptions :: default ( ) ; self . on_columns ( ) . len ( ) ] ,
335
371
false ,
336
372
)
337
373
. unwrap ( ) ,
@@ -381,9 +417,11 @@ impl JoinFuzzTestCase {
381
417
)
382
418
}
383
419
384
- /// Perform sort-merge join and hash join on same input
385
- /// and verify two outputs are equal
386
- async fn run_test ( & self ) {
420
+ /// Perform joins tests on same inputs and verify outputs are equal
421
+ /// `join_tests` - identifies what join types to test
422
+ /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders,
423
+ /// so it is easy to debug a test on top of the failed data
424
+ async fn run_test ( & self , join_tests : & [ JoinTestType ] , debug : bool ) {
387
425
for batch_size in self . batch_sizes {
388
426
let session_config = SessionConfig :: new ( ) . with_batch_size ( * batch_size) ;
389
427
let ctx = SessionContext :: new_with_config ( session_config) ;
@@ -394,17 +432,30 @@ impl JoinFuzzTestCase {
394
432
let hj = self . hash_join ( ) ;
395
433
let hj_collected = collect ( hj, task_ctx. clone ( ) ) . await . unwrap ( ) ;
396
434
435
+ let nlj = self . nested_loop_join ( ) ;
436
+ let nlj_collected = collect ( nlj, task_ctx. clone ( ) ) . await . unwrap ( ) ;
437
+
397
438
// Get actual row counts(without formatting overhead) for HJ and SMJ
398
439
let hj_rows = hj_collected. iter ( ) . fold ( 0 , |acc, b| acc + b. num_rows ( ) ) ;
399
440
let smj_rows = smj_collected. iter ( ) . fold ( 0 , |acc, b| acc + b. num_rows ( ) ) ;
441
+ let nlj_rows = nlj_collected. iter ( ) . fold ( 0 , |acc, b| acc + b. num_rows ( ) ) ;
400
442
401
- assert_eq ! (
402
- hj_rows, smj_rows,
403
- "SortMergeJoinExec and HashJoinExec produced different row counts"
404
- ) ;
443
+ if debug {
444
+ println ! ( "The debug is ON. Input data will be saved" ) ;
445
+ let out_dir_name = & format ! ( "fuzz_test_debug_batch_size_{batch_size}" ) ;
446
+ Self :: save_as_parquet ( & self . input1 , out_dir_name, "input1" ) ;
447
+ Self :: save_as_parquet ( & self . input2 , out_dir_name, "input2" ) ;
405
448
406
- let nlj = self . nested_loop_join ( ) ;
407
- let nlj_collected = collect ( nlj, task_ctx. clone ( ) ) . await . unwrap ( ) ;
449
+ if join_tests. contains ( & JoinTestType :: NljHj ) {
450
+ Self :: save_as_parquet ( & nlj_collected, out_dir_name, "nlj" ) ;
451
+ Self :: save_as_parquet ( & hj_collected, out_dir_name, "hj" ) ;
452
+ }
453
+
454
+ if join_tests. contains ( & JoinTestType :: HjSmj ) {
455
+ Self :: save_as_parquet ( & hj_collected, out_dir_name, "hj" ) ;
456
+ Self :: save_as_parquet ( & smj_collected, out_dir_name, "smj" ) ;
457
+ }
458
+ }
408
459
409
460
// compare
410
461
let smj_formatted =
@@ -425,35 +476,106 @@ impl JoinFuzzTestCase {
425
476
nlj_formatted. trim ( ) . lines ( ) . collect ( ) ;
426
477
nlj_formatted_sorted. sort_unstable ( ) ;
427
478
428
- // row level compare if any of joins returns the result
429
- // the reason is different formatting when there is no rows
430
- if smj_rows > 0 || hj_rows > 0 {
431
- for ( i, ( smj_line, hj_line) ) in smj_formatted_sorted
479
+ if join_tests. contains ( & JoinTestType :: NljHj ) {
480
+ let err_msg_rowcnt = format ! ( "NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}" , batch_size) ;
481
+ assert_eq ! ( nlj_rows, hj_rows, "{}" , err_msg_rowcnt. as_str( ) ) ;
482
+
483
+ let err_msg_contents = format ! ( "NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}" , batch_size) ;
484
+ // row level compare if any of joins returns the result
485
+ // the reason is different formatting when there is no rows
486
+ for ( i, ( nlj_line, hj_line) ) in nlj_formatted_sorted
432
487
. iter ( )
433
488
. zip ( & hj_formatted_sorted)
434
489
. enumerate ( )
435
490
{
436
491
assert_eq ! (
437
- ( i, smj_line ) ,
492
+ ( i, nlj_line ) ,
438
493
( i, hj_line) ,
439
- "SortMergeJoinExec and HashJoinExec produced different results"
494
+ "{}" ,
495
+ err_msg_contents. as_str( )
440
496
) ;
441
497
}
442
498
}
443
499
444
- for ( i, ( nlj_line, hj_line) ) in nlj_formatted_sorted
445
- . iter ( )
446
- . zip ( & hj_formatted_sorted)
447
- . enumerate ( )
448
- {
449
- assert_eq ! (
450
- ( i, nlj_line) ,
451
- ( i, hj_line) ,
452
- "NestedLoopJoinExec and HashJoinExec produced different results"
453
- ) ;
500
+ if join_tests. contains ( & JoinTestType :: HjSmj ) {
501
+ let err_msg_row_cnt = format ! ( "HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}" , & batch_size) ;
502
+ assert_eq ! ( hj_rows, smj_rows, "{}" , err_msg_row_cnt. as_str( ) ) ;
503
+
504
+ let err_msg_contents = format ! ( "SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}" , & batch_size) ;
505
+ // row level compare if any of joins returns the result
506
+ // the reason is different formatting when there is no rows
507
+ if smj_rows > 0 || hj_rows > 0 {
508
+ for ( i, ( smj_line, hj_line) ) in smj_formatted_sorted
509
+ . iter ( )
510
+ . zip ( & hj_formatted_sorted)
511
+ . enumerate ( )
512
+ {
513
+ assert_eq ! (
514
+ ( i, smj_line) ,
515
+ ( i, hj_line) ,
516
+ "{}" ,
517
+ err_msg_contents. as_str( )
518
+ ) ;
519
+ }
520
+ }
454
521
}
455
522
}
456
523
}
524
+
525
+ /// This method useful for debugging fuzz tests
526
+ /// It helps to save randomly generated input test data for both join inputs into the user folder
527
+ /// as a parquet files preserving partitioning.
528
+ /// Once the data is saved it is possible to run a custom test on top of the saved data and debug
529
+ ///
530
+ /// let ctx: SessionContext = SessionContext::new();
531
+ /// let df = ctx
532
+ /// .read_parquet(
533
+ /// "/tmp/input1/*.parquet",
534
+ /// ParquetReadOptions::default(),
535
+ /// )
536
+ /// .await
537
+ /// .unwrap();
538
+ /// let left = df.collect().await.unwrap();
539
+ ///
540
+ /// let df = ctx
541
+ /// .read_parquet(
542
+ /// "/tmp/input2/*.parquet",
543
+ /// ParquetReadOptions::default(),
544
+ /// )
545
+ /// .await
546
+ /// .unwrap();
547
+ ///
548
+ /// let right = df.collect().await.unwrap();
549
+ /// JoinFuzzTestCase::new(
550
+ /// left,
551
+ /// right,
552
+ /// JoinType::LeftSemi,
553
+ /// Some(Box::new(less_than_100_join_filter)),
554
+ /// )
555
+ /// .run_test()
556
+ /// .await
557
+ /// }
558
+ fn save_as_parquet ( input : & [ RecordBatch ] , output_dir : & str , out_name : & str ) {
559
+ let out_path = & format ! ( "{output_dir}/{out_name}" ) ;
560
+ std:: fs:: remove_dir_all ( out_path) . unwrap_or ( ( ) ) ;
561
+ std:: fs:: create_dir_all ( out_path) . unwrap ( ) ;
562
+
563
+ input. iter ( ) . enumerate ( ) . for_each ( |( idx, batch) | {
564
+ let mut file =
565
+ std:: fs:: File :: create ( format ! ( "{out_path}/file_{}.parquet" , idx) )
566
+ . unwrap ( ) ;
567
+ let mut writer = parquet:: arrow:: ArrowWriter :: try_new (
568
+ & mut file,
569
+ input. first ( ) . unwrap ( ) . schema ( ) ,
570
+ None ,
571
+ )
572
+ . expect ( "creating writer" ) ;
573
+ writer. write ( batch) . unwrap ( ) ;
574
+ writer. close ( ) . unwrap ( ) ;
575
+ } ) ;
576
+
577
+ println ! ( "The data {out_name} saved as parquet into {out_path}" ) ;
578
+ }
457
579
}
458
580
459
581
/// Return randomly sized record batches with:
0 commit comments