Skip to content

Change flatten so it does only a level, not recursively #15160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 91 additions & 62 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@
//! [`ScalarUDFImpl`] definitions for flatten function.

use crate::utils::make_scalar_function;
use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait};
use arrow::array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{
DataType,
DataType::{FixedSizeList, LargeList, List, Null},
};
use datafusion_common::cast::{
as_generic_list_array, as_large_list_array, as_list_array,
};
use datafusion_common::cast::{as_large_list_array, as_list_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{exec_err, utils::take_function_args, Result};
use datafusion_expr::{
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
Expand Down Expand Up @@ -77,9 +76,11 @@ impl Flatten {
pub fn new() -> Self {
Self {
signature: Signature {
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::RecursiveArray,
ArrayFunctionSignature::Array {
arguments: vec![ArrayFunctionArgument::Array],
array_coercion: Some(ListCoercion::FixedSizedListToList),
},
),
volatility: Volatility::Immutable,
},
Expand All @@ -102,25 +103,23 @@ impl ScalarUDFImpl for Flatten {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
List(field) | FixedSizeList(field, _)
if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) =>
{
get_base_type(field.data_type())
}
LargeList(field) if matches!(field.data_type(), LargeList(_)) => {
get_base_type(field.data_type())
let data_type = match &arg_types[0] {
List(field) | FixedSizeList(field, _) => match field.data_type() {
List(field) | FixedSizeList(field, _) => List(Arc::clone(field)),
_ => arg_types[0].clone(),
},
LargeList(field) => match field.data_type() {
List(field) | LargeList(field) | FixedSizeList(field, _) => {
LargeList(Arc::clone(field))
}
Null | List(_) | LargeList(_) => Ok(data_type.to_owned()),
FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
),
}
}
_ => arg_types[0].clone(),
},
Null => Null,
_ => exec_err!(
"Not reachable, data_type should be List, LargeList or FixedSizeList"
)?,
};

let data_type = get_base_type(&arg_types[0])?;
Ok(data_type)
}

Expand All @@ -146,14 +145,62 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {

match array.data_type() {
List(_) => {
let list_arr = as_list_array(&array)?;
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
let (field, offsets, values, nulls) =
as_list_array(&array)?.clone().into_parts();

match field.data_type() {
List(_) => {
let (inner_field, inner_offsets, inner_values, _) =
as_list_array(&values)?.clone().into_parts();
let offsets = get_offsets_for_flatten::<i32>(inner_offsets, offsets);
let flattened_array = GenericListArray::<i32>::new(
inner_field,
offsets,
inner_values,
nulls,
);

Ok(Arc::new(flattened_array) as ArrayRef)
}
LargeList(_) => {
exec_err!("flatten does not support type '{:?}'", array.data_type())?
}
_ => Ok(Arc::clone(array) as ArrayRef),
}
}
LargeList(_) => {
let list_arr = as_large_list_array(&array)?;
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
let (field, offsets, values, nulls) =
as_large_list_array(&array)?.clone().into_parts();

match field.data_type() {
List(_) => {
let (inner_field, inner_offsets, inner_values, _) =
as_list_array(&values)?.clone().into_parts();
let offsets = get_large_offsets_for_flatten(inner_offsets, offsets);
let flattened_array = GenericListArray::<i64>::new(
inner_field,
offsets,
inner_values,
nulls,
);

Ok(Arc::new(flattened_array) as ArrayRef)
}
LargeList(_) => {
let (inner_field, inner_offsets, inner_values, nulls) =
as_large_list_array(&values)?.clone().into_parts();
let offsets = get_offsets_for_flatten::<i64>(inner_offsets, offsets);
let flattened_array = GenericListArray::<i64>::new(
inner_field,
offsets,
inner_values,
nulls,
);

Ok(Arc::new(flattened_array) as ArrayRef)
}
_ => Ok(Arc::clone(array) as ArrayRef),
}
}
Null => Ok(Arc::clone(array)),
_ => {
Expand All @@ -162,37 +209,6 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

fn flatten_internal<O: OffsetSizeTrait>(
list_arr: GenericListArray<O>,
indexes: Option<OffsetBuffer<O>>,
) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
List(_) | LargeList(_) => {
let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
Ok(list_arr)
} else {
Ok(list_arr)
}
}
}
}

// Create new offsets that are equivalent to `flatten` the array.
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
Expand All @@ -205,3 +221,16 @@ fn get_offsets_for_flatten<O: OffsetSizeTrait>(
.collect();
OffsetBuffer::new(offsets.into())
}

// Create new large offsets that are equivalent to `flatten` the array.
fn get_large_offsets_for_flatten<O: OffsetSizeTrait, P: OffsetSizeTrait>(
offsets: OffsetBuffer<O>,
indexes: OffsetBuffer<P>,
) -> OffsetBuffer<i64> {
let buffer = offsets.into_inner();
let offsets: Vec<i64> = indexes
.iter()
.map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap())
.collect();
OffsetBuffer::new(offsets.into())
}
39 changes: 24 additions & 15 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -7090,34 +7090,32 @@ select array_concat(column1, [7]) from arrays_values_v2;

# flatten

#TODO: https://github.com/apache/datafusion/issues/7142
# follow DuckDB
#query ?
#select flatten(NULL);
#----
#NULL
query ?
select flatten(NULL);
----
NULL

# flatten with scalar values #1
query ???
select flatten(make_array(1, 2, 1, 3, 2)),
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
----
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]

query ???
select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')),
flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))'));
----
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]

query ???
select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'FixedSizeList(5, Int64)')),
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'FixedSizeList(4, List(Int64))')),
flatten(arrow_cast(make_array([[1.1], [2.2]], [[3.3], [4.4]]), 'FixedSizeList(2, List(List(Float64)))'));
----
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [1.1, 2.2, 3.3, 4.4]
[1, 2, 1, 3, 2] [1, 2, 3, NULL, 4, NULL, 5] [[1.1], [2.2], [3.3], [4.4]]

# flatten with column values
query ????
Expand All @@ -7127,8 +7125,8 @@ select flatten(column1),
flatten(column4)
from flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

query ????
select flatten(column1),
Expand All @@ -7137,8 +7135,8 @@ select flatten(column1),
flatten(column4)
from large_flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [[8]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

query ????
select flatten(column1),
Expand All @@ -7147,8 +7145,19 @@ select flatten(column1),
flatten(column4)
from fixed_size_flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
[1, 2, 3] [[1, 2, 3], [4, 5], [6]] [[[1]], [[2, 3]]] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [[8], [9, 10], [11, 12, 13]] [[[1, 2]], [[3]]] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# flatten with different inner list type
query ??????
select flatten(arrow_cast(make_array([1, 2], [3, 4]), 'List(FixedSizeList(2, Int64))')),
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'List(FixedSizeList(1, List(Int64)))')),
flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(List(Int64))')),
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(List(List(Int64)))')),
flatten(arrow_cast(make_array([1, 2], [3, 4]), 'LargeList(FixedSizeList(2, Int64))')),
flatten(arrow_cast(make_array([[1, 2]], [[3, 4]]), 'LargeList(FixedSizeList(1, List(Int64)))'))
----
[1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]] [1, 2, 3, 4] [[1, 2], [3, 4]]

## empty (aliases: `array_empty`, `list_empty`)
# empty scalar function #1
Expand Down