@@ -36,7 +36,7 @@ use arrow::compute::kernels::zip::zip;
36
36
use arrow:: compute:: { cast, is_not_null, kernels, sum} ;
37
37
use arrow:: datatypes:: { DataType , Int64Type , Schema , SchemaRef } ;
38
38
use arrow:: record_batch:: RecordBatch ;
39
- use arrow_array:: { Int64Array , Scalar , StructArray } ;
39
+ use arrow_array:: { new_null_array , Int64Array , Scalar , StructArray } ;
40
40
use arrow_ord:: cmp:: lt;
41
41
use datafusion_common:: {
42
42
exec_datafusion_err, exec_err, internal_err, HashMap , HashSet , Result , UnnestOptions ,
@@ -453,16 +453,36 @@ fn list_unnest_at_level(
453
453
454
454
// Create the take indices array for other columns
455
455
let take_indices = create_take_indicies ( unnested_length, total_length) ;
456
-
457
- // Dimension of arrays in batch is untouched, but the values are repeated
458
- // as the side effect of unnesting
459
- let ret = repeat_arrs_from_indices ( batch, & take_indices) ?;
460
456
unnested_temp_arrays
461
457
. into_iter ( )
462
458
. zip ( list_unnest_specs. iter ( ) )
463
459
. for_each ( |( flatten_arr, unnesting) | {
464
460
temp_unnested_arrs. insert ( * unnesting, flatten_arr) ;
465
461
} ) ;
462
+
463
+ let repeat_mask: Vec < bool > = batch
464
+ . iter ( )
465
+ . enumerate ( )
466
+ . map ( |( i, _) | {
467
+ // Check if the column is needed in future levels (levels below the current one)
468
+ let needed_in_future_levels = list_type_unnests. iter ( ) . any ( |unnesting| {
469
+ unnesting. index_in_input_schema == i && unnesting. depth < level_to_unnest
470
+ } ) ;
471
+
472
+ // Check if the column is involved in unnesting at any level
473
+ let is_involved_in_unnesting = list_type_unnests
474
+ . iter ( )
475
+ . any ( |unnesting| unnesting. index_in_input_schema == i) ;
476
+
477
+ // Repeat columns needed in future levels or not unnested.
478
+ needed_in_future_levels || !is_involved_in_unnesting
479
+ } )
480
+ . collect ( ) ;
481
+
482
+ // Dimension of arrays in batch is untouched, but the values are repeated
483
+ // as the side effect of unnesting
484
+ let ret = repeat_arrs_from_indices ( batch, & take_indices, & repeat_mask) ?;
485
+
466
486
Ok ( ( ret, total_length) )
467
487
}
468
488
struct UnnestingResult {
@@ -859,8 +879,11 @@ fn create_take_indicies(
859
879
builder. finish ( )
860
880
}
861
881
862
- /// Create the batch given an arrays and a `indices` array
863
- /// that is used by the take kernel to copy values.
882
+ /// Create a batch of arrays based on an input `batch` and an `indices` array.
883
+ /// The `indices` array is used by the take kernel to repeat values in the arrays
884
+ /// that are marked with `true` in the `repeat_mask`. Arrays marked with `false`
885
+ /// in the `repeat_mask` will be replaced with arrays filled with nulls of the
886
+ /// appropriate length.
864
887
///
865
888
/// For example if we have the following batch:
866
889
///
@@ -890,14 +913,35 @@ fn create_take_indicies(
890
913
/// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd'
891
914
/// ```
892
915
///
916
+ /// The `repeat_mask` determines whether an array's values are repeated or replaced with nulls.
917
+ /// For example, if the `repeat_mask` is:
918
+ ///
919
+ /// ```ignore
920
+ /// [true, false]
921
+ /// ```
922
+ ///
923
+ /// The final batch will look like:
924
+ ///
925
+ /// ```ignore
926
+ /// c1: 1, null, 2, 3, 4, null, 5, 6 // Repeated using `indices`
927
+ /// c2: null, null, null, null, null, null, null, null // Replaced with nulls
928
+ /// ```
893
929
fn repeat_arrs_from_indices (
894
930
batch : & [ ArrayRef ] ,
895
931
indices : & PrimitiveArray < Int64Type > ,
932
+ repeat_mask : & [ bool ] ,
896
933
) -> Result < Vec < Arc < dyn Array > > > {
897
934
batch
898
935
. iter ( )
899
- . map ( |arr| Ok ( kernels:: take:: take ( arr, indices, None ) ?) )
900
- . collect :: < Result < _ > > ( )
936
+ . zip ( repeat_mask. iter ( ) )
937
+ . map ( |( arr, & repeat) | {
938
+ if repeat {
939
+ Ok ( kernels:: take:: take ( arr, indices, None ) ?)
940
+ } else {
941
+ Ok ( new_null_array ( arr. data_type ( ) , arr. len ( ) ) )
942
+ }
943
+ } )
944
+ . collect ( )
901
945
}
902
946
903
947
#[ cfg( test) ]
0 commit comments