feat: support array_repeat #1680


Merged · 4 commits · Apr 29, 2025
6 changes: 3 additions & 3 deletions native/core/src/execution/operators/scan.rs
@@ -497,14 +497,14 @@ pub enum InputBatch {
/// The end of input batches.
EOF,

- /// A normal batch with columns and number of rows.
- /// It is possible to have zero-column batch with non-zero number of rows,
+ /// A normal batch with columns and a number of rows.
+ /// It is possible to have a zero-column batch with a non-zero number of rows,
/// i.e. reading empty schema from scan.
Batch(Vec<ArrayRef>, usize),
}

impl InputBatch {
- /// Constructs a `InputBatch` from columns and optional number of rows.
+ /// Constructs an `InputBatch` from columns and an optional number of rows.
/// If `num_rows` is none, this function will calculate it from given
/// columns.
pub fn new(columns: Vec<ArrayRef>, num_rows: Option<usize>) -> Self {
132 changes: 130 additions & 2 deletions native/core/src/execution/planner.rs
@@ -2505,7 +2505,7 @@ mod tests {

use futures::{poll, StreamExt};

- use arrow::array::{DictionaryArray, Int32Array, StringArray};
+ use arrow::array::{Array, DictionaryArray, Int32Array, StringArray};
use arrow::datatypes::DataType;
use datafusion::logical_expr::ScalarUDF;
use datafusion::{assert_batches_eq, physical_plan::common::collect, prelude::SessionContext};
@@ -2912,7 +2912,6 @@ mod tests {

// Separate thread to send the EOF signal once we've processed the only input batch
runtime.spawn(async move {
- // Create a dictionary array with 100 values, and use it as input to the execution.
let a = Int32Array::from(vec![0, 3]);
let b = Int32Array::from(vec![1, 4]);
let c = Int32Array::from(vec![2, 5]);
@@ -2953,4 +2952,133 @@
}
});
}

#[test]
fn test_array_repeat() {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let planner = PhysicalPlanner::new(Arc::from(session_ctx));

// Mock scan operator with 3 INT32 columns
let op_scan = Operator {
plan_id: 0,
children: vec![],
op_struct: Some(OpStruct::Scan(spark_operator::Scan {
fields: vec![
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
spark_expression::DataType {
type_id: 3, // Int32
type_info: None,
},
],
source: "".to_string(),
})),
};

// Mock expression to read an INT32 column at position 0
let array_col = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 0,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};

// Mock expression to read an INT32 column at position 1
let array_col_1 = spark_expression::Expr {
expr_struct: Some(Bound(spark_expression::BoundReference {
index: 1,
datatype: Some(spark_expression::DataType {
type_id: 3,
type_info: None,
}),
})),
};

// Make a projection operator with array_repeat(array_col, array_col_1)
let projection = Operator {
children: vec![op_scan],
plan_id: 0,
op_struct: Some(OpStruct::Projection(spark_operator::Projection {
project_list: vec![spark_expression::Expr {
expr_struct: Some(ExprStruct::ScalarFunc(spark_expression::ScalarFunc {
func: "array_repeat".to_string(),
args: vec![array_col, array_col_1],
return_type: None,
})),
}],
})),
};

// Create a physical plan
let (mut scans, datafusion_plan) =
planner.create_plan(&projection, &mut vec![], 1).unwrap();

// Input batches are fed to the scan via the channel created below

// Start executing the plan in a separate thread.
// The plan waits for incoming batches and emits results as input arrives.
let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap();

let runtime = tokio::runtime::Runtime::new().unwrap();
// create async channel
let (tx, mut rx) = mpsc::channel(1);

// Send data as input to the plan being executed in a separate thread
runtime.spawn(async move {
// create data batch
// 0, 1, 2
// 3, 4, 5
// 6, null, null
let a = Int32Array::from(vec![Some(0), Some(3), Some(6)]);
let b = Int32Array::from(vec![Some(1), Some(4), None]);
let c = Int32Array::from(vec![Some(2), Some(5), None]);
let input_batch1 = InputBatch::Batch(vec![Arc::new(a), Arc::new(b), Arc::new(c)], 3);
let input_batch2 = InputBatch::EOF;

let batches = vec![input_batch1, input_batch2];

for batch in batches.into_iter() {
tx.send(batch).await.unwrap();
}
});

// Wait for the plan to finish executing and assert the result
runtime.block_on(async move {
loop {
let batch = rx.recv().await.unwrap();
scans[0].set_input_batch(batch);
match poll!(stream.next()) {
Poll::Ready(Some(batch)) => {
assert!(batch.is_ok(), "got error {}", batch.unwrap_err());
let batch = batch.unwrap();
let expected = [
"+--------------+",
"| col_0        |",
"+--------------+",
"| [0]          |",
"| [3, 3, 3, 3] |",
"|              |",
"+--------------+",
];
assert_batches_eq!(expected, &[batch]);
}
Poll::Ready(None) => {
break;
}
_ => {}
}
}
});
}
}
216 changes: 216 additions & 0 deletions native/spark-expr/src/array_funcs/array_repeat.rs
@@ -0,0 +1,216 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{
new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
NullBufferBuilder, OffsetSizeTrait, UInt64Array,
};
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::compute::cast;
use arrow::datatypes::DataType::{LargeList, List};
use arrow::datatypes::{DataType, Field};
use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion::common::{exec_err, DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use std::sync::Arc;

pub fn make_scalar_function<F>(
inner: F,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
{
move |args: &[ColumnarValue]| {
// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let is_scalar = len.is_none();

let args = ColumnarValue::values_to_arrays(args)?;

let result = (inner)(&args);

if is_scalar {
// If all inputs are scalar, keep the output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
}
}
}

pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
make_scalar_function(spark_array_repeat_inner)(args)
}

/// `array_repeat` SQL function
fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
let element = &args[0];
let count_array = &args[1];

let count_array = match count_array.data_type() {
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
DataType::UInt64 => count_array,
_ => return exec_err!("count must be an integer type"),
};

let count_array = as_uint64_array(count_array)?;

match element.data_type() {
List(_) => {
let list_array = as_list_array(element)?;
general_list_repeat::<i32>(list_array, count_array)
}
LargeList(_) => {
let list_array = as_large_list_array(element)?;
general_list_repeat::<i64>(list_array, count_array)
}
_ => general_repeat::<i32>(element, count_array),
}
}

/// Repeats each element `array[i]` `count_array[i]` times.
///
/// Assumptions for the input:
/// 1. `count_array[i] >= 0`
/// 2. `array.len() == count_array.len()`
///
/// For example,
/// ```text
/// array_repeat(
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
/// )
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.iter()
.map(|&x| x as usize)
.collect::<Vec<_>>();

let mut nulls = NullBufferBuilder::new(count_array.len());
Review comment (Contributor Author): This is the actual fix: keep a nulls buffer so the response is null when the count is null.

for (row_index, &count) in count_vec.iter().enumerate() {
nulls.append(!count_array.is_null(row_index));
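// Hedged illustration of the intended behavior (it mirrors the planner test
// above): the validity of each output row tracks the validity of its count,
// e.g. array_repeat(3, 4) => [3, 3, 3, 3], but array_repeat(6, NULL) => NULL.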
let repeated_array = if array.is_null(row_index) {
Review comment (Contributor): Should the result be a null array if the count is zero?

Reply (Contributor Author): It would be an empty array; a test case for this was added as well.
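// Note from the thread above: a zero count yields an empty (not null) list
// entry, since the repeat loop runs zero times and its offset length is 0.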

new_null_array(data_type, count)
} else {
let original_data = array.to_data();
let capacity = Capacities::Array(count);
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);

for _ in 0..count {
mutable.extend(0, row_index, row_index + 1);
}

let data = mutable.freeze();
arrow::array::make_array(data)
};
new_values.push(repeated_array);
}

let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(data_type.to_owned(), true)),
OffsetBuffer::from_lengths(count_vec),
values,
nulls.finish(),
)?))
}

/// Handles the `List` version of `general_repeat`.
///
/// Repeats each list element `list_array[i]` `count_array[i]` times.
///
/// For example,
/// ```text
/// array_repeat(
/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
/// )
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
let data_type = list_array.data_type();
let value_type = list_array.value_type();
let mut new_values = vec![];

let count_vec = count_array
.values()
.iter()
.map(|&x| x as usize)
.collect::<Vec<_>>();

for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
let list_arr = match list_array_row {
Some(list_array_row) => {
let original_data = list_array_row.to_data();
let capacity = Capacities::Array(original_data.len() * count);
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);

for _ in 0..count {
mutable.extend(0, 0, original_data.len());
}

let data = mutable.freeze();
let repeated_array = arrow::array::make_array(data);

let list_arr = GenericListArray::<O>::try_new(
Arc::new(Field::new_list_field(value_type.clone(), true)),
OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
repeated_array,
None,
)?;
Arc::new(list_arr) as ArrayRef
}
None => new_null_array(data_type, count),
};
new_values.push(list_arr);
}

let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
let values = compute::concat(&new_values)?;

Ok(Arc::new(ListArray::try_new(
Arc::new(Field::new_list_field(data_type.to_owned(), true)),
OffsetBuffer::<i32>::from_lengths(lengths),
values,
None,
)?))
}
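
For reference, below is a minimal usage sketch of the new entry point (not part of the PR; `example` is a hypothetical driver, and the import path is an assumption based on the `pub use` re-export added in mod.rs below). It passes an Int32 element column with Int64 counts, which the function casts to UInt64 internally.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, Int64Array};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::spark_array_repeat; // assumed crate path

fn example() -> datafusion::common::Result<()> {
    // Elements [1, 2, null] repeated [2, 0, 1] times.
    let elements: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None]));
    let counts: ArrayRef = Arc::new(Int64Array::from(vec![2i64, 0, 1]));

    let result = spark_array_repeat(&[
        ColumnarValue::Array(elements),
        ColumnarValue::Array(counts),
    ])?;

    // Expected rows: [1, 1], [], [null]
    if let ColumnarValue::Array(arr) = result {
        println!("{arr:?}");
    }
    Ok(())
}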
2 changes: 2 additions & 0 deletions native/spark-expr/src/array_funcs/mod.rs
@@ -16,9 +16,11 @@
// under the License.

mod array_insert;
mod array_repeat;
mod get_array_struct_fields;
mod list_extract;

pub use array_insert::ArrayInsert;
pub use array_repeat::spark_array_repeat;
pub use get_array_struct_fields::GetArrayStructFields;
pub use list_extract::ListExtract;