Skip to content

Commit 189536b

Browse files
authored
Fix redundant data copying in unnest (#13441)
* Fix redundant data copying in unnest * Add test * fix typo
1 parent 6b0570b commit 189536b

File tree

2 files changed

+59
-9
lines changed

2 files changed

+59
-9
lines changed

datafusion/physical-plan/src/unnest.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use arrow::compute::kernels::zip::zip;
3636
use arrow::compute::{cast, is_not_null, kernels, sum};
3737
use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef};
3838
use arrow::record_batch::RecordBatch;
39-
use arrow_array::{Int64Array, Scalar, StructArray};
39+
use arrow_array::{new_null_array, Int64Array, Scalar, StructArray};
4040
use arrow_ord::cmp::lt;
4141
use datafusion_common::{
4242
exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions,
@@ -453,16 +453,36 @@ fn list_unnest_at_level(
453453

454454
// Create the take indices array for other columns
455455
let take_indices = create_take_indicies(unnested_length, total_length);
456-
457-
// Dimension of arrays in batch is untouched, but the values are repeated
458-
// as the side effect of unnesting
459-
let ret = repeat_arrs_from_indices(batch, &take_indices)?;
460456
unnested_temp_arrays
461457
.into_iter()
462458
.zip(list_unnest_specs.iter())
463459
.for_each(|(flatten_arr, unnesting)| {
464460
temp_unnested_arrs.insert(*unnesting, flatten_arr);
465461
});
462+
463+
let repeat_mask: Vec<bool> = batch
464+
.iter()
465+
.enumerate()
466+
.map(|(i, _)| {
467+
// Check if the column is needed in future levels (levels below the current one)
468+
let needed_in_future_levels = list_type_unnests.iter().any(|unnesting| {
469+
unnesting.index_in_input_schema == i && unnesting.depth < level_to_unnest
470+
});
471+
472+
// Check if the column is involved in unnesting at any level
473+
let is_involved_in_unnesting = list_type_unnests
474+
.iter()
475+
.any(|unnesting| unnesting.index_in_input_schema == i);
476+
477+
// Repeat columns needed in future levels or not unnested.
478+
needed_in_future_levels || !is_involved_in_unnesting
479+
})
480+
.collect();
481+
482+
// Dimension of arrays in batch is untouched, but the values are repeated
483+
// as the side effect of unnesting
484+
let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?;
485+
466486
Ok((ret, total_length))
467487
}
468488
struct UnnestingResult {
@@ -859,8 +879,11 @@ fn create_take_indicies(
859879
builder.finish()
860880
}
861881

862-
/// Create the batch given an arrays and a `indices` array
863-
/// that is used by the take kernel to copy values.
882+
/// Create a batch of arrays based on an input `batch` and a `indices` array.
883+
/// The `indices` array is used by the take kernel to repeat values in the arrays
884+
/// that are marked with `true` in the `repeat_mask`. Arrays marked with `false`
885+
/// in the `repeat_mask` will be replaced with arrays filled with nulls of the
886+
/// appropriate length.
864887
///
865888
/// For example if we have the following batch:
866889
///
@@ -890,14 +913,35 @@ fn create_take_indicies(
890913
/// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd'
891914
/// ```
892915
///
916+
/// The `repeat_mask` determines whether an array's values are repeated or replaced with nulls.
917+
/// For example, if the `repeat_mask` is:
918+
///
919+
/// ```ignore
920+
/// [true, false]
921+
/// ```
922+
///
923+
/// The final batch will look like:
924+
///
925+
/// ```ignore
926+
/// c1: 1, null, 2, 3, 4, null, 5, 6 // Repeated using `indices`
927+
/// c2: null, null, null, null, null, null, null, null // Replaced with nulls
928+
///
893929
fn repeat_arrs_from_indices(
894930
batch: &[ArrayRef],
895931
indices: &PrimitiveArray<Int64Type>,
932+
repeat_mask: &[bool],
896933
) -> Result<Vec<Arc<dyn Array>>> {
897934
batch
898935
.iter()
899-
.map(|arr| Ok(kernels::take::take(arr, indices, None)?))
900-
.collect::<Result<_>>()
936+
.zip(repeat_mask.iter())
937+
.map(|(arr, &repeat)| {
938+
if repeat {
939+
Ok(kernels::take::take(arr, indices, None)?)
940+
} else {
941+
Ok(new_null_array(arr.data_type(), arr.len()))
942+
}
943+
})
944+
.collect()
901945
}
902946

903947
#[cfg(test)]

datafusion/sqllogictest/test_files/unnest.slt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,3 +853,9 @@ select unnest(u.column5), j.* except(column2, column3) from unnest_table u join
853853
1 2 1
854854
3 4 2
855855
NULL NULL 4
856+
857+
## Issue: https://github.com/apache/datafusion/issues/13237
858+
query I
859+
select count(*) from (select unnest(range(0, 100000)) id) t inner join (select unnest(range(0, 100000)) id) t1 on t.id = t1.id;
860+
----
861+
100000

0 commit comments

Comments
 (0)