Skip to content

Commit 6a895c6

Browse files
authored
More missing array funcs (#605)
* Add `array_distinct` function
* Add `range` function
* Add `list_distinct` alias
* Add `array_intersect` scalar function
* Add `array_union` scalar function
* Add `array_except` scalar function
* Add `array_resize` scalar function
1 parent 18ac182 commit 6a895c6

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

datafusion/tests/test_functions.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,15 @@ def py_arr_replace(arr, from_, to, n=None):
220220

221221
return new_arr
222222

223+
def py_arr_resize(arr, size, value):
    """Return a copy-like view of ``arr`` resized to exactly ``size`` elements.

    Reference implementation used to check ``array_resize`` / ``list_resize``.

    Args:
        arr: One-dimensional array-like input.
        size: Target length of the result. Must be non-negative.
        value: Fill value appended when the result is longer than ``arr``.

    Returns:
        A NumPy array of length ``size``: the leading elements of ``arr``,
        followed by ``value`` repeated as often as needed.

    Raises:
        ValueError: If ``size`` is negative.
    """
    arr = np.asarray(arr)
    current = arr.shape[0]
    if size < 0:
        # Keep the exception type np.pad raised for a negative pad width.
        raise ValueError("size must be non-negative")
    if size <= current:
        # np.pad rejects negative pad widths, so shrink by slicing instead.
        return arr[:size]
    return np.pad(
        arr,
        [(0, size - current)],
        "constant",
        constant_values=value,
    )
231+
223232
def py_flatten(arr):
224233
result = []
225234
for elem in arr:
@@ -259,6 +268,14 @@ def py_flatten(arr):
259268
f.array_dims(col),
260269
lambda: [[len(r)] for r in data],
261270
],
271+
[
272+
f.array_distinct(col),
273+
lambda: [list(set(r)) for r in data],
274+
],
275+
[
276+
f.list_distinct(col),
277+
lambda: [list(set(r)) for r in data],
278+
],
262279
[
263280
f.list_dims(col),
264281
lambda: [[len(r)] for r in data],
@@ -415,7 +432,43 @@ def py_flatten(arr):
415432
f.list_slice(col, literal(-1), literal(2)),
416433
lambda: [arr[-1:2] for arr in data],
417434
],
435+
[
436+
f.array_intersect(col, literal([3.0, 4.0])),
437+
lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
438+
],
439+
[
440+
f.list_intersect(col, literal([3.0, 4.0])),
441+
lambda: [np.intersect1d(arr, [3.0, 4.0]) for arr in data],
442+
],
443+
[
444+
f.array_union(col, literal([12.0, 999.0])),
445+
lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
446+
],
447+
[
448+
f.list_union(col, literal([12.0, 999.0])),
449+
lambda: [np.union1d(arr, [12.0, 999.0]) for arr in data],
450+
],
451+
[
452+
f.array_except(col, literal([3.0])),
453+
lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
454+
],
455+
[
456+
f.list_except(col, literal([3.0])),
457+
lambda: [np.setdiff1d(arr, [3.0]) for arr in data],
458+
],
459+
[
460+
f.array_resize(col, literal(10), literal(0.0)),
461+
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
462+
],
463+
[
464+
f.list_resize(col, literal(10), literal(0.0)),
465+
lambda: [py_arr_resize(arr, 10, 0.0) for arr in data],
466+
],
418467
[f.flatten(literal(data)), lambda: [py_flatten(data)]],
468+
[
469+
f.range(literal(1), literal(5), literal(2)),
470+
lambda: [np.arange(1, 5, 2)],
471+
],
419472
]
420473

421474
for stmt, py_expr in test_items:

src/functions.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,7 @@ scalar_function!(trunc, Trunc);
391391
scalar_function!(upper, Upper, "Converts the string to all upper case.");
392392
scalar_function!(make_array, MakeArray);
393393
scalar_function!(array, MakeArray);
394+
scalar_function!(range, Range);
394395
scalar_function!(uuid, Uuid);
395396
scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword
396397
scalar_function!(from_unixtime, FromUnixtime);
@@ -405,6 +406,8 @@ scalar_function!(list_push_back, ArrayAppend);
405406
scalar_function!(array_concat, ArrayConcat);
406407
scalar_function!(array_cat, ArrayConcat);
407408
scalar_function!(array_dims, ArrayDims);
409+
scalar_function!(array_distinct, ArrayDistinct);
410+
scalar_function!(list_distinct, ArrayDistinct);
408411
scalar_function!(list_dims, ArrayDims);
409412
scalar_function!(array_element, ArrayElement);
410413
scalar_function!(array_extract, ArrayElement);
@@ -444,6 +447,14 @@ scalar_function!(array_replace_all, ArrayReplaceAll);
444447
scalar_function!(list_replace_all, ArrayReplaceAll);
445448
scalar_function!(array_slice, ArraySlice);
446449
scalar_function!(list_slice, ArraySlice);
450+
scalar_function!(array_intersect, ArrayIntersect);
451+
scalar_function!(list_intersect, ArrayIntersect);
452+
scalar_function!(array_union, ArrayUnion);
453+
scalar_function!(list_union, ArrayUnion);
454+
scalar_function!(array_except, ArrayExcept);
455+
scalar_function!(list_except, ArrayExcept);
456+
scalar_function!(array_resize, ArrayResize);
457+
scalar_function!(list_resize, ArrayResize);
447458
scalar_function!(flatten, Flatten);
448459

449460
aggregate_function!(approx_distinct, ApproxDistinct);
@@ -499,6 +510,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
499510
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
500511
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
501512
m.add_wrapped(wrap_pyfunction!(array))?;
513+
m.add_wrapped(wrap_pyfunction!(range))?;
502514
m.add_wrapped(wrap_pyfunction!(array_agg))?;
503515
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
504516
m.add_wrapped(wrap_pyfunction!(ascii))?;
@@ -644,6 +656,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
644656
m.add_wrapped(wrap_pyfunction!(array_concat))?;
645657
m.add_wrapped(wrap_pyfunction!(array_cat))?;
646658
m.add_wrapped(wrap_pyfunction!(array_dims))?;
659+
m.add_wrapped(wrap_pyfunction!(array_distinct))?;
660+
m.add_wrapped(wrap_pyfunction!(list_distinct))?;
647661
m.add_wrapped(wrap_pyfunction!(list_dims))?;
648662
m.add_wrapped(wrap_pyfunction!(array_element))?;
649663
m.add_wrapped(wrap_pyfunction!(array_extract))?;
@@ -661,6 +675,14 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
661675
m.add_wrapped(wrap_pyfunction!(array_positions))?;
662676
m.add_wrapped(wrap_pyfunction!(list_positions))?;
663677
m.add_wrapped(wrap_pyfunction!(array_to_string))?;
678+
m.add_wrapped(wrap_pyfunction!(array_intersect))?;
679+
m.add_wrapped(wrap_pyfunction!(list_intersect))?;
680+
m.add_wrapped(wrap_pyfunction!(array_union))?;
681+
m.add_wrapped(wrap_pyfunction!(list_union))?;
682+
m.add_wrapped(wrap_pyfunction!(array_except))?;
683+
m.add_wrapped(wrap_pyfunction!(list_except))?;
684+
m.add_wrapped(wrap_pyfunction!(array_resize))?;
685+
m.add_wrapped(wrap_pyfunction!(list_resize))?;
664686
m.add_wrapped(wrap_pyfunction!(array_join))?;
665687
m.add_wrapped(wrap_pyfunction!(list_to_string))?;
666688
m.add_wrapped(wrap_pyfunction!(list_join))?;

0 commit comments

Comments
 (0)