Skip to content

Commit 7ce574f

Browse files
committed
Add array_resize scalar function
1 parent 3e79905 commit 7ce574f

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

datafusion/tests/test_functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def py_arr_replace(arr, from_, to, n=None):
220220

221221
return new_arr
222222

223+
def py_arr_resize(arr, size, value):
224+
arr = np.asarray(arr)
225+
return np.pad(
226+
arr,
227+
[(0, size - arr.shape[0])],
228+
"constant",
229+
constant_values=value,
230+
)
231+
223232
def py_flatten(arr):
224233
result = []
225234
for elem in arr:
@@ -447,6 +456,14 @@ def py_flatten(arr):
447456
f.list_except(col, literal([3.0])),
448457
lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
449458
],
459+
[
460+
f.array_resize(col, literal(10), literal(0.0)),
461+
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
462+
],
463+
[
464+
f.list_resize(col, literal(10), literal(0.0)),
465+
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
466+
],
450467
[f.flatten(literal(data)), lambda: [py_flatten(data)]],
451468
[
452469
f.range(literal(1), literal(5), literal(2)),

src/functions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,8 @@ scalar_function!(array_union, ArrayUnion);
453453
scalar_function!(list_union, ArrayUnion);
454454
scalar_function!(array_except, ArrayExcept);
455455
scalar_function!(list_except, ArrayExcept);
456+
scalar_function!(array_resize, ArrayResize);
457+
scalar_function!(list_resize, ArrayResize);
456458
scalar_function!(flatten, Flatten);
457459

458460
aggregate_function!(approx_distinct, ApproxDistinct);
@@ -679,6 +681,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
679681
m.add_wrapped(wrap_pyfunction!(list_union))?;
680682
m.add_wrapped(wrap_pyfunction!(array_except))?;
681683
m.add_wrapped(wrap_pyfunction!(list_except))?;
684+
m.add_wrapped(wrap_pyfunction!(array_resize))?;
685+
m.add_wrapped(wrap_pyfunction!(list_resize))?;
682686
m.add_wrapped(wrap_pyfunction!(array_join))?;
683687
m.add_wrapped(wrap_pyfunction!(list_to_string))?;
684688
m.add_wrapped(wrap_pyfunction!(list_join))?;

0 commit comments

Comments
 (0)