Skip to content

Commit a0b0221

Browse files
authored
Change flatten so it flattens only one level instead of recursing (#15160)
* flatten array in a single step instead of recursively
* clippy
* update flatten type signature to Array
* add fixed list to list coercion to flatten signature
* support LargeList(List) and LargeList(FixedSizeList) in flatten
* add test for LargeList(FixedSizeList)
* handle nulls
* uncomment flatten(NULL) test - it already works
1 parent 112cde8 commit a0b0221

File tree

2 files changed

+115
-77
lines changed

2 files changed

+115
-77
lines changed

datafusion/functions-nested/src/flatten.rs

Lines changed: 91 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,18 @@
1818
//! [`ScalarUDFImpl`] definitions for flatten function.
1919
2020
use crate::utils::make_scalar_function;
21-
use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait};
21+
use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
2222
use arrow::buffer::OffsetBuffer;
2323
use arrow::datatypes::{
2424
DataType,
2525
DataType::{FixedSizeList, LargeList, List, Null},
2626
};
27-
use datafusion_common::cast::{
28-
as_generic_list_array, as_large_list_array, as_list_array,
29-
};
27+
use datafusion_common::cast::{as_large_list_array, as_list_array};
28+
use datafusion_common::utils::ListCoercion;
3029
use datafusion_common::{exec_err, utils::take_function_args, Result};
3130
use datafusion_expr::{
32-
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
33-
TypeSignature, Volatility,
31+
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
32+
ScalarUDFImpl, Signature, TypeSignature, Volatility,
3433
};
3534
use datafusion_macros::user_doc;
3635
use std::any::Any;
@@ -77,9 +76,11 @@ impl Flatten {
7776
pub fn new() -> Self {
7877
Self {
7978
signature: Signature {
80-
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
8179
type_signature: TypeSignature::ArraySignature(
82-
ArrayFunctionSignature::RecursiveArray,
80+
ArrayFunctionSignature::Array {
81+
arguments: vec![ArrayFunctionArgument::Array],
82+
array_coercion: Some(ListCoercion::FixedSizedListToList),
83+
},
8384
),
8485
volatility: Volatility::Immutable,
8586
},
@@ -102,25 +103,23 @@ impl ScalarUDFImpl for Flatten {
102103
}
103104

104105
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
105-
fn get_base_type(data_type: &DataType) -> Result<DataType> {
106-
match data_type {
107-
List(field) | FixedSizeList(field, _)
108-
if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) =>
109-
{
110-
get_base_type(field.data_type())
111-
}
112-
LargeList(field) if matches!(field.data_type(), LargeList(_)) => {
113-
get_base_type(field.data_type())
106+
let data_type = match &arg_types[0] {
107+
List(field) | FixedSizeList(field, _) => match field.data_type() {
108+
List(field) | FixedSizeList(field, _) => List(Arc::clone(field)),
109+
_ => arg_types[0].clone(),
110+
},
111+
LargeList(field) => match field.data_type() {
112+
List(field) | LargeList(field) | FixedSizeList(field, _) => {
113+
LargeList(Arc::clone(field))
114114
}
115-
Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
116-
FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
117-
_ => exec_err!(
118-
"Not reachable, data_type should be List, LargeList or FixedSizeList"
119-
),
120-
}
121-
}
115+
_ => arg_types[0].clone(),
116+
},
117+
Null => Null,
118+
_ => exec_err!(
119+
"Not reachable, data_type should be List, LargeList or FixedSizeList"
120+
)?,
121+
};
122122

123-
let data_type = get_base_type(&arg_types[0])?;
124123
Ok(data_type)
125124
}
126125

@@ -146,14 +145,62 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
146145

147146
match array.data_type() {
148147
List(_) => {
149-
let list_arr = as_list_array(&array)?;
150-
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
151-
Ok(Arc::new(flattened_array) as ArrayRef)
148+
let (field, offsets, values, nulls) =
149+
as_list_array(&array)?.clone().into_parts();
150+
151+
match field.data_type() {
152+
List(_) => {
153+
let (inner_field, inner_offsets, inner_values, _) =
154+
as_list_array(&values)?.clone().into_parts();
155+
let offsets = get_offsets_for_flatten::<i32>(inner_offsets, offsets);
156+
let flattened_array = GenericListArray::<i32>::new(
157+
inner_field,
158+
offsets,
159+
inner_values,
160+
nulls,
161+
);
162+
163+
Ok(Arc::new(flattened_array) as ArrayRef)
164+
}
165+
LargeList(_) => {
166+
exec_err!("flatten does not support type '{:?}'", array.data_type())?
167+
}
168+
_ => Ok(Arc::clone(array) as ArrayRef),
169+
}
152170
}
153171
LargeList(_) => {
154-
let list_arr = as_large_list_array(&array)?;
155-
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
156-
Ok(Arc::new(flattened_array) as ArrayRef)
172+
let (field, offsets, values, nulls) =
173+
as_large_list_array(&array)?.clone().into_parts();
174+
175+
match field.data_type() {
176+
List(_) => {
177+
let (inner_field, inner_offsets, inner_values, _) =
178+
as_list_array(&values)?.clone().into_parts();
179+
let offsets = get_large_offsets_for_flatten(inner_offsets, offsets);
180+
let flattened_array = GenericListArray::<i64>::new(
181+
inner_field,
182+
offsets,
183+
inner_values,
184+
nulls,
185+
);
186+
187+
Ok(Arc::new(flattened_array) as ArrayRef)
188+
}
189+
LargeList(_) => {
190+
let (inner_field, inner_offsets, inner_values, nulls) =
191+
as_large_list_array(&values)?.clone().into_parts();
192+
let offsets = get_offsets_for_flatten::<i64>(inner_offsets, offsets);
193+
let flattened_array = GenericListArray::<i64>::new(
194+
inner_field,
195+
offsets,
196+
inner_values,
197+
nulls,
198+
);
199+
200+
Ok(Arc::new(flattened_array) as ArrayRef)
201+
}
202+
_ => Ok(Arc::clone(array) as ArrayRef),
203+
}
157204
}
158205
Null => Ok(Arc::clone(array)),
159206
_ => {
@@ -162,37 +209,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
162209
}
163210
}
164211

165-
fn flatten_internal<O: OffsetSizeTrait>(
166-
list_arr: GenericListArray<O>,
167-
indexes: Option<OffsetBuffer<O>>,
168-
) -> Result<GenericListArray<O>> {
169-
let (field, offsets, values, _) = list_arr.clone().into_parts();
170-
let data_type = field.data_type();
171-
172-
match data_type {
173-
// Recursively get the base offsets for flattened array
174-
List(_) | LargeList(_) => {
175-
let sub_list = as_generic_list_array::<O>(&values)?;
176-
if let Some(indexes) = indexes {
177-
let offsets = get_offsets_for_flatten(offsets, indexes);
178-
flatten_internal::<O>(sub_list.clone(), Some(offsets))
179-
} else {
180-
flatten_internal::<O>(sub_list.clone(), Some(offsets))
181-
}
182-
}
183-
// Reach the base level, create a new list array
184-
_ => {
185-
if let Some(indexes) = indexes {
186-
let offsets = get_offsets_for_flatten(offsets, indexes);
187-
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
188-
Ok(list_arr)
189-
} else {
190-
Ok(list_arr)
191-
}
192-
}
193-
}
194-
}
195-
196212
// Create new offsets that are equivalent to `flatten` the array.
197213
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
198214
offsets: OffsetBuffer<O>,
@@ -205,3 +221,16 @@ fn get_offsets_for_flatten<O: OffsetSizeTrait>(
205221
.collect();
206222
OffsetBuffer::new(offsets.into())
207223
}
224+
225+
// Create new large offsets that are equivalent to `flatten` the array.
226+
fn get_large_offsets_for_flatten<O: OffsetSizeTrait, P: OffsetSizeTrait>(
227+
offsets: OffsetBuffer<O>,
228+
indexes: OffsetBuffer<P>,
229+
) -> OffsetBuffer<i64> {
230+
let buffer = offsets.into_inner();
231+
let offsets: Vec<i64> = indexes
232+
.iter()
233+
.map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap())
234+
.collect();
235+
OffsetBuffer::new(offsets.into())
236+
}

datafusion/sqllogictest/test_files/array.slt

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7285,34 +7285,32 @@ select array_concat(column1, [7]) from arrays_values_v2;
72857285

72867286
# flatten
72877287

7288-
#TODO: https://github.com/apache/datafusion/issues/7142
7289-
# follow DuckDB
7290-
#query ?
7291-
#select flatten(NULL);
7292-
#----
7293-
#NULL
7288+
query ?
7289+
select flatten(NULL);
7290+
----
7291+
NULL
72947292

72957293
# flatten with scalar values #1
72967294
query ???
72977295
select flatten(make_array(1, 2, 1, 3, 2)),
72987296
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
72997297
flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
73007298
----
7301-
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
7299+
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
73027300

73037301
query ???
73047302
select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
73057303
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')),
73067304
flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))'));
73077305
----
7308-
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
7306+
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
73097307

73107308
query ???
73117309
select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, Int64)')),
73127310
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'FixedSizeList(4, List(Int64))')),
73137311
flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 'FixedSizeList(2, List(List(Float64)))'));
73147312
----
7315-
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
7313+
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]
73167314

73177315
# flatten with column values
73187316
query ????
@@ -7322,8 +7320,8 @@ select flatten(column1),
73227320
flatten(column4)
73237321
from flatten_table;
73247322
----
7325-
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7326-
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
7323+
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7324+
[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
73277325

73287326
query ????
73297327
select flatten(column1),
@@ -7332,8 +7330,8 @@ select flatten(column1),
73327330
flatten(column4)
73337331
from large_flatten_table;
73347332
----
7335-
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7336-
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
7333+
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7334+
[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
73377335

73387336
query ????
73397337
select flatten(column1),
@@ -7342,8 +7340,19 @@ select flatten(column1),
73427340
flatten(column4)
73437341
from fixed_size_flatten_table;
73447342
----
7345-
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7346-
[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
7343+
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
7344+
[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
7345+
7346+
# flatten with different inner list type
7347+
query ??????
7348+
select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, Int64))')),
7349+
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'List(FixedSizeList(1, List(Int64)))')),
7350+
flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')),
7351+
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(List(List(Int64)))')),
7352+
flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(FixedSizeList(2, Int64))')),
7353+
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(FixedSizeList(1, List(Int64)))'))
7354+
----
7355+
[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]]
73477356

73487357
## empty (aliases: `array_empty`, `list_empty`)
73497358
# empty scalar function #1

0 commit comments

Comments
 (0)