
Commit dbf2fb7

feat: support array_repeat (#1680)
* feat: support `array_repeat`
1 parent 2dd887b commit dbf2fb7
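
`array_repeat(element, count)` is the Spark SQL function that builds an array by repeating `element` `count` times, e.g. `array_repeat(1, 3)` returns `[1, 1, 1]`. This commit adds a native implementation of the function on the DataFusion side, plus a planner test exercising it.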

8 files changed (+399 −9 lines)

native/core/src/execution/operators/scan.rs

Lines changed: 3 additions & 3 deletions

```diff
@@ -497,14 +497,14 @@ pub enum InputBatch {
     /// The end of input batches.
     EOF,
 
-    /// A normal batch with columns and number of rows.
-    /// It is possible to have zero-column batch with non-zero number of rows,
+    /// A normal batch with columns and a number of rows.
+    /// It is possible to have a zero-column batch with a non-zero number of rows,
     /// i.e. reading empty schema from scan.
     Batch(Vec<ArrayRef>, usize),
 }
 
 impl InputBatch {
-    /// Constructs a `InputBatch` from columns and optional number of rows.
+    /// Constructs an `InputBatch` from columns and an optional number of rows.
     /// If `num_rows` is none, this function will calculate it from given
     /// columns.
     pub fn new(columns: Vec<ArrayRef>, num_rows: Option<usize>) -> Self {
```
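As the updated doc comment notes, `InputBatch::new` computes the row count itself when `num_rows` is `None`. A minimal sketch of that behavior, assuming the same arrow imports used elsewhere in this commit:

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};

// With `num_rows: None`, the constructor derives the count from the columns,
// so this is equivalent to `InputBatch::Batch(vec![col], 3)`.
let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
let batch = InputBatch::new(vec![col], None);
```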

native/core/src/execution/planner.rs

Lines changed: 130 additions & 2 deletions

```diff
@@ -2505,7 +2505,7 @@ mod tests {
 
     use futures::{poll, StreamExt};
 
-    use arrow::array::{DictionaryArray, Int32Array, StringArray};
+    use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
     use arrow::datatypes::DataType;
     use datafusion::logical_expr::ScalarUDF;
     use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
@@ -2912,7 +2912,6 @@ mod tests {
 
         // Separate thread to send the EOF signal once we've processed the only input batch
         runtime.spawn(async move {
-            // Create a dictionary array with 100 values, and use it as input to the execution.
             let a = Int32Array::from(vec![0, 3]);
             let b = Int32Array::from(vec![1, 4]);
             let c = Int32Array::from(vec![2, 5]);
@@ -2953,4 +2952,133 @@
             }
         });
     }
+
+    #[test]
+    fn test_array_repeat() {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let planner = PhysicalPlanner::new(Arc::from(session_ctx));
+
+        // Mock scan operator with 3 INT32 columns
+        let op_scan = Operator {
+            plan_id: 0,
+            children: vec![],
+            op_struct: Some(OpStruct::Scan(spark_operator::Scan {
+                fields: vec![
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                    spark_expression::DataType {
+                        type_id: 3, // Int32
+                        type_info: None,
+                    },
+                ],
+                source: "".to_string(),
+            })),
+        };
+
+        // Mock expression to read an INT32 column at position 0
+        let array_col = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 0,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        // Mock expression to read an INT32 column at position 1
+        let array_col_1 = spark_expression::Expr {
+            expr_struct: Some(Bound(spark_expression::BoundReference {
+                index: 1,
+                datatype: Some(spark_expression::DataType {
+                    type_id: 3,
+                    type_info: None,
+                }),
+            })),
+        };
+
+        // Make a projection operator with array_repeat(array_col, array_col_1)
+        let projection = Operator {
+            children: vec![op_scan],
+            plan_id: 0,
+            op_struct: Some(OpStruct::Projection(spark_operator::Projection {
+                project_list: vec![spark_expression::Expr {
+                    expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
+                        func: "array_repeat".to_string(),
+                        args: vec![array_col, array_col_1],
+                        return_type: None,
+                    })),
+                }],
+            })),
+        };
+
+        // Create a physical plan
+        let (mut scans, datafusion_plan) =
+            planner.create_plan(&projection, &mut vec![], 1).unwrap();
+
+        // Feed the data into plan
+        //scans[0].set_input_batch(input_batch);
+
+        // Start executing the plan in a separate thread.
+        // The plan waits for incoming batches and emits results as input arrives.
+        let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();
+
+        let runtime = tokio::runtime::Runtime::new().unwrap();
+        // Create an async channel
+        let (tx, mut rx) = mpsc::channel(1);
+
+        // Send data as input to the plan being executed in a separate thread
+        runtime.spawn(async move {
+            // Create a data batch:
+            // 0, 1,    2
+            // 3, 4,    5
+            // 6, null, null
+            let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]);
+            let b = Int32Array::from(vec![Some(1), Some(4), None]);
+            let c = Int32Array::from(vec![Some(2), Some(5), None]);
+            let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3);
+            let input_batch2 = InputBatch::EOF;
+
+            let batches = vec![input_batch1, input_batch2];
+
+            for batch in batches.into_iter() {
+                tx.send(batch).await.unwrap();
+            }
+        });
+
+        // Wait for the plan to finish executing and assert the result
+        runtime.block_on(async move {
+            loop {
+                let batch = rx.recv().await.unwrap();
+                scans[0].set_input_batch(batch);
+                match poll!(stream.next()) {
+                    Poll::Ready(Some(batch)) => {
+                        assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
+                        let batch = batch.unwrap();
+                        let expected = [
+                            "+--------------+",
+                            "| col_0        |",
+                            "+--------------+",
+                            "| [0]          |",
+                            "| [3, 3, 3, 3] |",
+                            "|              |",
+                            "+--------------+",
+                        ];
+                        assert_batches_eq!(expected, &[batch]);
+                    }
+                    Poll::Ready(None) => {
+                        break;
+                    }
+                    _ => {}
+                }
+            }
+        });
+    }
 }
```
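
The expected output above also pins down the null semantics: the second row repeats `3` four times because the count column holds `4`, while the third row is `NULL` because its count is `NULL`. The same behavior can be seen by calling the new kernel directly; a minimal sketch, assuming the `spark_array_repeat` export added below (test scaffolding omitted):

```rust
use std::sync::Arc;
use arrow::array::{Int32Array, Int64Array};
use datafusion::logical_expr::ColumnarValue;

// element = [0, 3, 6], count = [1, 4, NULL]
let elements = ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(0), Some(3), Some(6)])));
let counts = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(1), Some(4), None])));

// Rows: [0], [3, 3, 3, 3], NULL (Int64 counts are cast to UInt64 internally)
let repeated = spark_array_repeat(&[elements, counts]).unwrap();
```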
native/spark-expr/src/array_funcs/array_repeat.rs (new file)

Lines changed: 216 additions & 0 deletions

```rust
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
    new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
    NullBufferBuilder, OffsetSizeTrait, UInt64Array,
};
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::compute::cast;
use arrow::datatypes::DataType::{LargeList, List};
use arrow::datatypes::{DataType, Field};
use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion::common::{exec_err, DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use std::sync::Arc;

pub fn make_scalar_function<F>(
    inner: F,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
{
    move |args: &[ColumnarValue]| {
        // First, check whether any of the arguments is an array. If so, store its `len`,
        // as any scalar will need to be converted to an array of length `len`.
        let len = args
            .iter()
            .fold(Option::<usize>::None, |acc, arg| match arg {
                ColumnarValue::Scalar(_) => acc,
                ColumnarValue::Array(a) => Some(a.len()),
            });

        let is_scalar = len.is_none();

        let args = ColumnarValue::values_to_arrays(args)?;

        let result = (inner)(&args);

        if is_scalar {
            // If all inputs are scalar, keep the output as a scalar
            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
            result.map(ColumnarValue::Scalar)
        } else {
            result.map(ColumnarValue::Array)
        }
    }
}

pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    make_scalar_function(spark_array_repeat_inner)(args)
}

/// `array_repeat` SQL function
fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
    let element = &args[0];
    let count_array = &args[1];

    let count_array = match count_array.data_type() {
        DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
        DataType::UInt64 => count_array,
        _ => return exec_err!("count must be an integer type"),
    };

    let count_array = as_uint64_array(count_array)?;

    match element.data_type() {
        List(_) => {
            let list_array = as_list_array(element)?;
            general_list_repeat::<i32>(list_array, count_array)
        }
        LargeList(_) => {
            let list_array = as_large_list_array(element)?;
            general_list_repeat::<i64>(list_array, count_array)
        }
        _ => general_repeat::<i32>(element, count_array),
    }
}

/// For each element `array[i]`, repeat it `count_array[i]` times.
///
/// Assumptions about the input:
/// 1. `count[i] >= 0`
/// 2. `array.len() == count_array.len()`
///
/// For example,
/// ```text
/// array_repeat(
///     [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat<O: OffsetSizeTrait>(
    array: &ArrayRef,
    count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
    let data_type = array.data_type();
    let mut new_values = vec![];

    let count_vec = count_array
        .values()
        .to_vec()
        .iter()
        .map(|x| *x as usize)
        .collect::<Vec<_>>();

    let mut nulls = NullBufferBuilder::new(count_array.len());

    for (row_index, &count) in count_vec.iter().enumerate() {
        nulls.append(!count_array.is_null(row_index));
        let repeated_array = if array.is_null(row_index) {
            new_null_array(data_type, count)
        } else {
            let original_data = array.to_data();
            let capacity = Capacities::Array(count);
            let mut mutable =
                MutableArrayData::with_capacities(vec![&original_data], false, capacity);

            for _ in 0..count {
                mutable.extend(0, row_index, row_index + 1);
            }

            let data = mutable.freeze();
            arrow::array::make_array(data)
        };
        new_values.push(repeated_array);
    }

    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
    let values = compute::concat(&new_values)?;

    Ok(Arc::new(GenericListArray::<O>::try_new(
        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
        OffsetBuffer::from_lengths(count_vec),
        values,
        nulls.finish(),
    )?))
}

/// Handles the `List` version of `general_repeat`.
///
/// For each list element `list_array[i]`, repeat it `count_array[i]` times.
///
/// For example,
/// ```text
/// array_repeat(
///     [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
/// )
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
    let data_type = list_array.data_type();
    let value_type = list_array.value_type();
    let mut new_values = vec![];

    let count_vec = count_array
        .values()
        .to_vec()
        .iter()
        .map(|x| *x as usize)
        .collect::<Vec<_>>();

    for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
        let list_arr = match list_array_row {
            Some(list_array_row) => {
                let original_data = list_array_row.to_data();
                let capacity = Capacities::Array(original_data.len() * count);
                let mut mutable =
                    MutableArrayData::with_capacities(vec![&original_data], false, capacity);

                for _ in 0..count {
                    mutable.extend(0, 0, original_data.len());
                }

                let data = mutable.freeze();
                let repeated_array = arrow::array::make_array(data);

                let list_arr = GenericListArray::<O>::try_new(
                    Arc::new(Field::new_list_field(value_type.clone(), true)),
                    OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
                    repeated_array,
                    None,
                )?;
                Arc::new(list_arr) as ArrayRef
            }
            None => new_null_array(data_type, count),
        };
        new_values.push(list_arr);
    }

    let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
    let values = compute::concat(&new_values)?;

    Ok(Arc::new(ListArray::try_new(
        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
        OffsetBuffer::<i32>::from_lengths(lengths),
        values,
        None,
    )?))
}
```
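
When the element itself is a list, `general_list_repeat` repeats whole list values rather than scalars. A small sketch of that branch through the public entry point, under the same assumptions as the earlier example:

```rust
use std::sync::Arc;
use arrow::array::{Int64Array, ListArray};
use arrow::datatypes::Int32Type;
use datafusion::logical_expr::ColumnarValue;

// element = [[1, 2, 3], [4, 5]], count = [2, 1]
let lists = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
    Some(vec![Some(1), Some(2), Some(3)]),
    Some(vec![Some(4), Some(5)]),
]);
let counts = Int64Array::from(vec![2, 1]);

// Expected rows: [[1, 2, 3], [1, 2, 3]] and [[4, 5]]
let repeated = spark_array_repeat(&[
    ColumnarValue::Array(Arc::new(lists)),
    ColumnarValue::Array(Arc::new(counts)),
])
.unwrap();
```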

native/spark-expr/src/array_funcs/mod.rs

Lines changed: 2 additions & 0 deletions

```diff
@@ -16,9 +16,11 @@
 // under the License.
 
 mod array_insert;
+mod array_repeat;
 mod get_array_struct_fields;
 mod list_extract;
 
 pub use array_insert::ArrayInsert;
+pub use array_repeat::spark_array_repeat;
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
```
