|
| 1 | +// Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +// or more contributor license agreements. See the NOTICE file |
| 3 | +// distributed with this work for additional information |
| 4 | +// regarding copyright ownership. The ASF licenses this file |
| 5 | +// to you under the Apache License, Version 2.0 (the |
| 6 | +// "License"); you may not use this file except in compliance |
| 7 | +// with the License. You may obtain a copy of the License at |
| 8 | +// |
| 9 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +// |
| 11 | +// Unless required by applicable law or agreed to in writing, |
| 12 | +// software distributed under the License is distributed on an |
| 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +// KIND, either express or implied. See the License for the |
| 15 | +// specific language governing permissions and limitations |
| 16 | +// under the License. |
| 17 | + |
| 18 | +use arrow::array::{ |
| 19 | + new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData, |
| 20 | + NullBufferBuilder, OffsetSizeTrait, UInt64Array, |
| 21 | +}; |
| 22 | +use arrow::buffer::OffsetBuffer; |
| 23 | +use arrow::compute; |
| 24 | +use arrow::compute::cast; |
| 25 | +use arrow::datatypes::DataType::{LargeList, List}; |
| 26 | +use arrow::datatypes::{DataType, Field}; |
| 27 | +use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array}; |
| 28 | +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; |
| 29 | +use datafusion::logical_expr::ColumnarValue; |
| 30 | +use std::sync::Arc; |
| 31 | + |
| 32 | +pub fn make_scalar_function<F>( |
| 33 | + inner: F, |
| 34 | +) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> |
| 35 | +where |
| 36 | + F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>, |
| 37 | +{ |
| 38 | + move |args: &[ColumnarValue]| { |
| 39 | + // first, identify if any of the arguments is an Array. If yes, store its `len`, |
| 40 | + // as any scalar will need to be converted to an array of len `len`. |
| 41 | + let len = args |
| 42 | + .iter() |
| 43 | + .fold(Option::<usize>::None, |acc, arg| match arg { |
| 44 | + ColumnarValue::Scalar(_) => acc, |
| 45 | + ColumnarValue::Array(a) => Some(a.len()), |
| 46 | + }); |
| 47 | + |
| 48 | + let is_scalar = len.is_none(); |
| 49 | + |
| 50 | + let args = ColumnarValue::values_to_arrays(args)?; |
| 51 | + |
| 52 | + let result = (inner)(&args); |
| 53 | + |
| 54 | + if is_scalar { |
| 55 | + // If all inputs are scalar, keeps output as scalar |
| 56 | + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); |
| 57 | + result.map(ColumnarValue::Scalar) |
| 58 | + } else { |
| 59 | + result.map(ColumnarValue::Array) |
| 60 | + } |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> { |
| 65 | + make_scalar_function(spark_array_repeat_inner)(args) |
| 66 | +} |
| 67 | + |
| 68 | +/// Array_repeat SQL function |
| 69 | +fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> { |
| 70 | + let element = &args[0]; |
| 71 | + let count_array = &args[1]; |
| 72 | + |
| 73 | + let count_array = match count_array.data_type() { |
| 74 | + DataType::Int64 => &cast(count_array, &DataType::UInt64)?, |
| 75 | + DataType::UInt64 => count_array, |
| 76 | + _ => return exec_err!("count must be an integer type"), |
| 77 | + }; |
| 78 | + |
| 79 | + let count_array = as_uint64_array(count_array)?; |
| 80 | + |
| 81 | + match element.data_type() { |
| 82 | + List(_) => { |
| 83 | + let list_array = as_list_array(element)?; |
| 84 | + general_list_repeat::<i32>(list_array, count_array) |
| 85 | + } |
| 86 | + LargeList(_) => { |
| 87 | + let list_array = as_large_list_array(element)?; |
| 88 | + general_list_repeat::<i64>(list_array, count_array) |
| 89 | + } |
| 90 | + _ => general_repeat::<i32>(element, count_array), |
| 91 | + } |
| 92 | +} |
| 93 | + |
| 94 | +/// For each element of `array[i]` repeat `count_array[i]` times. |
| 95 | +/// |
| 96 | +/// Assumption for the input: |
| 97 | +/// 1. `count[i] >= 0` |
| 98 | +/// 2. `array.len() == count_array.len()` |
| 99 | +/// |
| 100 | +/// For example, |
| 101 | +/// ```text |
| 102 | +/// array_repeat( |
| 103 | +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] |
| 104 | +/// ) |
| 105 | +/// ``` |
| 106 | +fn general_repeat<O: OffsetSizeTrait>( |
| 107 | + array: &ArrayRef, |
| 108 | + count_array: &UInt64Array, |
| 109 | +) -> datafusion::common::Result<ArrayRef> { |
| 110 | + let data_type = array.data_type(); |
| 111 | + let mut new_values = vec![]; |
| 112 | + |
| 113 | + let count_vec = count_array |
| 114 | + .values() |
| 115 | + .to_vec() |
| 116 | + .iter() |
| 117 | + .map(|x| *x as usize) |
| 118 | + .collect::<Vec<_>>(); |
| 119 | + |
| 120 | + let mut nulls = NullBufferBuilder::new(count_array.len()); |
| 121 | + |
| 122 | + for (row_index, &count) in count_vec.iter().enumerate() { |
| 123 | + nulls.append(!count_array.is_null(row_index)); |
| 124 | + let repeated_array = if array.is_null(row_index) { |
| 125 | + new_null_array(data_type, count) |
| 126 | + } else { |
| 127 | + let original_data = array.to_data(); |
| 128 | + let capacity = Capacities::Array(count); |
| 129 | + let mut mutable = |
| 130 | + MutableArrayData::with_capacities(vec![&original_data], false, capacity); |
| 131 | + |
| 132 | + for _ in 0..count { |
| 133 | + mutable.extend(0, row_index, row_index + 1); |
| 134 | + } |
| 135 | + |
| 136 | + let data = mutable.freeze(); |
| 137 | + arrow::array::make_array(data) |
| 138 | + }; |
| 139 | + new_values.push(repeated_array); |
| 140 | + } |
| 141 | + |
| 142 | + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); |
| 143 | + let values = compute::concat(&new_values)?; |
| 144 | + |
| 145 | + Ok(Arc::new(GenericListArray::<O>::try_new( |
| 146 | + Arc::new(Field::new_list_field(data_type.to_owned(), true)), |
| 147 | + OffsetBuffer::from_lengths(count_vec), |
| 148 | + values, |
| 149 | + nulls.finish(), |
| 150 | + )?)) |
| 151 | +} |
| 152 | + |
| 153 | +/// Handle List version of `general_repeat` |
| 154 | +/// |
| 155 | +/// For each element of `list_array[i]` repeat `count_array[i]` times. |
| 156 | +/// |
| 157 | +/// For example, |
| 158 | +/// ```text |
| 159 | +/// array_repeat( |
| 160 | +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] |
| 161 | +/// ) |
| 162 | +/// ``` |
| 163 | +fn general_list_repeat<O: OffsetSizeTrait>( |
| 164 | + list_array: &GenericListArray<O>, |
| 165 | + count_array: &UInt64Array, |
| 166 | +) -> datafusion::common::Result<ArrayRef> { |
| 167 | + let data_type = list_array.data_type(); |
| 168 | + let value_type = list_array.value_type(); |
| 169 | + let mut new_values = vec![]; |
| 170 | + |
| 171 | + let count_vec = count_array |
| 172 | + .values() |
| 173 | + .to_vec() |
| 174 | + .iter() |
| 175 | + .map(|x| *x as usize) |
| 176 | + .collect::<Vec<_>>(); |
| 177 | + |
| 178 | + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { |
| 179 | + let list_arr = match list_array_row { |
| 180 | + Some(list_array_row) => { |
| 181 | + let original_data = list_array_row.to_data(); |
| 182 | + let capacity = Capacities::Array(original_data.len() * count); |
| 183 | + let mut mutable = |
| 184 | + MutableArrayData::with_capacities(vec![&original_data], false, capacity); |
| 185 | + |
| 186 | + for _ in 0..count { |
| 187 | + mutable.extend(0, 0, original_data.len()); |
| 188 | + } |
| 189 | + |
| 190 | + let data = mutable.freeze(); |
| 191 | + let repeated_array = arrow::array::make_array(data); |
| 192 | + |
| 193 | + let list_arr = GenericListArray::<O>::try_new( |
| 194 | + Arc::new(Field::new_list_field(value_type.clone(), true)), |
| 195 | + OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]), |
| 196 | + repeated_array, |
| 197 | + None, |
| 198 | + )?; |
| 199 | + Arc::new(list_arr) as ArrayRef |
| 200 | + } |
| 201 | + None => new_null_array(data_type, count), |
| 202 | + }; |
| 203 | + new_values.push(list_arr); |
| 204 | + } |
| 205 | + |
| 206 | + let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>(); |
| 207 | + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); |
| 208 | + let values = compute::concat(&new_values)?; |
| 209 | + |
| 210 | + Ok(Arc::new(ListArray::try_new( |
| 211 | + Arc::new(Field::new_list_field(data_type.to_owned(), true)), |
| 212 | + OffsetBuffer::<i32>::from_lengths(lengths), |
| 213 | + values, |
| 214 | + None, |
| 215 | + )?)) |
| 216 | +} |
0 commit comments