Skip to content

Commit 697ca2c

Browse files
authored
Add array functions (#560)
* Add array_has, array_has_all and array_has_any * Add array_position, array_indexof, list_position and list_indexof * Add array_to_string, array_join, list_to_string and list_join * Add array_ndims and list_ndims * Add array_push_back, list_append and list_push_back * Add array_prepend, array_push_front, list_prepend and list_push_front * Add array_pop_back and array_pop_front * Add array_positions and list_positions * Add array_remove, list_remove, array_remove_n, list_remove_n, array_remove_all and list_remove_all * Add array_repeat * Add array_replace, list_replace, array_replace_n, list_replace_n, array_replace_all, list_replace_all * Add array_slice and list_slice
1 parent 476ca22 commit 697ca2c

File tree

2 files changed

+286
-3
lines changed

2 files changed

+286
-3
lines changed

datafusion/tests/test_functions.py

Lines changed: 208 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,19 +200,62 @@ def test_math_functions():
200200

201201

202202
def test_array_functions():
203-
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
203+
data = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 3.0], [6.0]]
204204
ctx = SessionContext()
205205
batch = pa.RecordBatch.from_arrays(
206206
[np.array(data, dtype=object)], names=["arr"]
207207
)
208208
df = ctx.create_dataframe([[batch]])
209209

210+
def py_indexof(arr, v):
211+
try:
212+
return arr.index(v) + 1
213+
except ValueError:
214+
return np.nan
215+
216+
def py_arr_remove(arr, v, n=None):
217+
new_arr = arr[:]
218+
found = 0
219+
while found != n:
220+
try:
221+
new_arr.remove(v)
222+
found += 1
223+
except ValueError:
224+
break
225+
226+
return new_arr
227+
228+
def py_arr_replace(arr, from_, to, n=None):
229+
new_arr = arr[:]
230+
found = 0
231+
while found != n:
232+
try:
233+
idx = new_arr.index(from_)
234+
new_arr[idx] = to
235+
found += 1
236+
except ValueError:
237+
break
238+
239+
return new_arr
240+
210241
col = column("arr")
211242
test_items = [
212243
[
213244
f.array_append(col, literal(99.0)),
214245
lambda: [np.append(arr, 99.0) for arr in data],
215246
],
247+
[
248+
f.array_push_back(col, literal(99.0)),
249+
lambda: [np.append(arr, 99.0) for arr in data],
250+
],
251+
[
252+
f.list_append(col, literal(99.0)),
253+
lambda: [np.append(arr, 99.0) for arr in data],
254+
],
255+
[
256+
f.list_push_back(col, literal(99.0)),
257+
lambda: [np.append(arr, 99.0) for arr in data],
258+
],
216259
[
217260
f.array_concat(col, col),
218261
lambda: [np.concatenate([arr, arr]) for arr in data],
@@ -253,12 +296,174 @@ def test_array_functions():
253296
f.list_length(col),
254297
lambda: [len(r) for r in data],
255298
],
299+
[
300+
f.array_has(col, literal(1.0)),
301+
lambda: [1.0 in r for r in data],
302+
],
303+
[
304+
f.array_has_all(
305+
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
306+
),
307+
lambda: [np.all([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
308+
],
309+
[
310+
f.array_has_any(
311+
col, f.make_array(*[literal(v) for v in [1.0, 3.0, 5.0]])
312+
),
313+
lambda: [np.any([v in r for v in [1.0, 3.0, 5.0]]) for r in data],
314+
],
315+
[
316+
f.array_position(col, literal(1.0)),
317+
lambda: [py_indexof(r, 1.0) for r in data],
318+
],
319+
[
320+
f.array_indexof(col, literal(1.0)),
321+
lambda: [py_indexof(r, 1.0) for r in data],
322+
],
323+
[
324+
f.list_position(col, literal(1.0)),
325+
lambda: [py_indexof(r, 1.0) for r in data],
326+
],
327+
[
328+
f.list_indexof(col, literal(1.0)),
329+
lambda: [py_indexof(r, 1.0) for r in data],
330+
],
331+
[
332+
f.array_positions(col, literal(1.0)),
333+
lambda: [
334+
[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
335+
],
336+
],
337+
[
338+
f.list_positions(col, literal(1.0)),
339+
lambda: [
340+
[i + 1 for i, _v in enumerate(r) if _v == 1.0] for r in data
341+
],
342+
],
343+
[
344+
f.array_ndims(col),
345+
lambda: [np.array(r).ndim for r in data],
346+
],
347+
[
348+
f.list_ndims(col),
349+
lambda: [np.array(r).ndim for r in data],
350+
],
351+
[
352+
f.array_prepend(literal(99.0), col),
353+
lambda: [np.insert(arr, 0, 99.0) for arr in data],
354+
],
355+
[
356+
f.array_push_front(literal(99.0), col),
357+
lambda: [np.insert(arr, 0, 99.0) for arr in data],
358+
],
359+
[
360+
f.list_prepend(literal(99.0), col),
361+
lambda: [np.insert(arr, 0, 99.0) for arr in data],
362+
],
363+
[
364+
f.list_push_front(literal(99.0), col),
365+
lambda: [np.insert(arr, 0, 99.0) for arr in data],
366+
],
367+
[
368+
f.array_pop_back(col),
369+
lambda: [arr[:-1] for arr in data],
370+
],
371+
[
372+
f.array_pop_front(col),
373+
lambda: [arr[1:] for arr in data],
374+
],
375+
[
376+
f.array_remove(col, literal(3.0)),
377+
lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
378+
],
379+
[
380+
f.list_remove(col, literal(3.0)),
381+
lambda: [py_arr_remove(arr, 3.0, 1) for arr in data],
382+
],
383+
[
384+
f.array_remove_n(col, literal(3.0), literal(2)),
385+
lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
386+
],
387+
[
388+
f.list_remove_n(col, literal(3.0), literal(2)),
389+
lambda: [py_arr_remove(arr, 3.0, 2) for arr in data],
390+
],
391+
[
392+
f.array_remove_all(col, literal(3.0)),
393+
lambda: [py_arr_remove(arr, 3.0) for arr in data],
394+
],
395+
[
396+
f.list_remove_all(col, literal(3.0)),
397+
lambda: [py_arr_remove(arr, 3.0) for arr in data],
398+
],
399+
[
400+
f.array_repeat(col, literal(2)),
401+
lambda: [[arr] * 2 for arr in data],
402+
],
403+
[
404+
f.array_replace(col, literal(3.0), literal(4.0)),
405+
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
406+
],
407+
[
408+
f.list_replace(col, literal(3.0), literal(4.0)),
409+
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
410+
],
411+
[
412+
f.array_replace_n(col, literal(3.0), literal(4.0), literal(1)),
413+
lambda: [py_arr_replace(arr, 3.0, 4.0, 1) for arr in data],
414+
],
415+
[
416+
f.list_replace_n(col, literal(3.0), literal(4.0), literal(2)),
417+
lambda: [py_arr_replace(arr, 3.0, 4.0, 2) for arr in data],
418+
],
419+
[
420+
f.array_replace_all(col, literal(3.0), literal(4.0)),
421+
lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
422+
],
423+
[
424+
f.list_replace_all(col, literal(3.0), literal(4.0)),
425+
lambda: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
426+
],
427+
[
428+
f.array_slice(col, literal(2), literal(4)),
429+
lambda: [arr[1:4] for arr in data],
430+
],
431+
[
432+
f.list_slice(col, literal(-1), literal(2)),
433+
lambda: [arr[-1:2] for arr in data],
434+
],
256435
]
257436

258437
for stmt, py_expr in test_items:
259-
query_result = df.select(stmt).collect()[0].column(0).tolist()
438+
query_result = df.select(stmt).collect()[0].column(0)
439+
for a, b in zip(query_result, py_expr()):
440+
np.testing.assert_array_almost_equal(
441+
np.array(a.as_py(), dtype=float), np.array(b, dtype=float)
442+
)
443+
444+
obj_test_items = [
445+
[
446+
f.array_to_string(col, literal(",")),
447+
lambda: [",".join([str(int(v)) for v in r]) for r in data],
448+
],
449+
[
450+
f.array_join(col, literal(",")),
451+
lambda: [",".join([str(int(v)) for v in r]) for r in data],
452+
],
453+
[
454+
f.list_to_string(col, literal(",")),
455+
lambda: [",".join([str(int(v)) for v in r]) for r in data],
456+
],
457+
[
458+
f.list_join(col, literal(",")),
459+
lambda: [",".join([str(int(v)) for v in r]) for r in data],
460+
],
461+
]
462+
463+
for stmt, py_expr in obj_test_items:
464+
query_result = np.array(df.select(stmt).collect()[0].column(0))
260465
for a, b in zip(query_result, py_expr()):
261-
np.testing.assert_array_almost_equal(a, b)
466+
assert a == b
262467

263468

264469
def test_string_functions(df):

src/functions.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ scalar_function!(decode, Decode);
360360

361361
// Array Functions
362362
scalar_function!(array_append, ArrayAppend);
363+
scalar_function!(array_push_back, ArrayAppend);
364+
scalar_function!(list_append, ArrayAppend);
365+
scalar_function!(list_push_back, ArrayAppend);
363366
scalar_function!(array_concat, ArrayConcat);
364367
scalar_function!(array_cat, ArrayConcat);
365368
scalar_function!(array_dims, ArrayDims);
@@ -370,6 +373,42 @@ scalar_function!(list_element, ArrayElement);
370373
scalar_function!(list_extract, ArrayElement);
371374
scalar_function!(array_length, ArrayLength);
372375
scalar_function!(list_length, ArrayLength);
376+
scalar_function!(array_has, ArrayHas);
377+
scalar_function!(array_has_all, ArrayHasAll);
378+
scalar_function!(array_has_any, ArrayHasAny);
379+
scalar_function!(array_position, ArrayPosition);
380+
scalar_function!(array_indexof, ArrayPosition);
381+
scalar_function!(list_position, ArrayPosition);
382+
scalar_function!(list_indexof, ArrayPosition);
383+
scalar_function!(array_positions, ArrayPositions);
384+
scalar_function!(list_positions, ArrayPositions);
385+
scalar_function!(array_to_string, ArrayToString);
386+
scalar_function!(array_join, ArrayToString);
387+
scalar_function!(list_to_string, ArrayToString);
388+
scalar_function!(list_join, ArrayToString);
389+
scalar_function!(array_ndims, ArrayNdims);
390+
scalar_function!(list_ndims, ArrayNdims);
391+
scalar_function!(array_prepend, ArrayPrepend);
392+
scalar_function!(array_push_front, ArrayPrepend);
393+
scalar_function!(list_prepend, ArrayPrepend);
394+
scalar_function!(list_push_front, ArrayPrepend);
395+
scalar_function!(array_pop_back, ArrayPopBack);
396+
scalar_function!(array_pop_front, ArrayPopFront);
397+
scalar_function!(array_remove, ArrayRemove);
398+
scalar_function!(list_remove, ArrayRemove);
399+
scalar_function!(array_remove_n, ArrayRemoveN);
400+
scalar_function!(list_remove_n, ArrayRemoveN);
401+
scalar_function!(array_remove_all, ArrayRemoveAll);
402+
scalar_function!(list_remove_all, ArrayRemoveAll);
403+
scalar_function!(array_repeat, ArrayRepeat);
404+
scalar_function!(array_replace, ArrayReplace);
405+
scalar_function!(list_replace, ArrayReplace);
406+
scalar_function!(array_replace_n, ArrayReplaceN);
407+
scalar_function!(list_replace_n, ArrayReplaceN);
408+
scalar_function!(array_replace_all, ArrayReplaceAll);
409+
scalar_function!(list_replace_all, ArrayReplaceAll);
410+
scalar_function!(array_slice, ArraySlice);
411+
scalar_function!(list_slice, ArraySlice);
373412

374413
aggregate_function!(approx_distinct, ApproxDistinct);
375414
aggregate_function!(approx_median, ApproxMedian);
@@ -563,6 +602,9 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
563602

564603
// Array Functions
565604
m.add_wrapped(wrap_pyfunction!(array_append))?;
605+
m.add_wrapped(wrap_pyfunction!(array_push_back))?;
606+
m.add_wrapped(wrap_pyfunction!(list_append))?;
607+
m.add_wrapped(wrap_pyfunction!(list_push_back))?;
566608
m.add_wrapped(wrap_pyfunction!(array_concat))?;
567609
m.add_wrapped(wrap_pyfunction!(array_cat))?;
568610
m.add_wrapped(wrap_pyfunction!(array_dims))?;
@@ -573,6 +615,42 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
573615
m.add_wrapped(wrap_pyfunction!(list_extract))?;
574616
m.add_wrapped(wrap_pyfunction!(array_length))?;
575617
m.add_wrapped(wrap_pyfunction!(list_length))?;
618+
m.add_wrapped(wrap_pyfunction!(array_has))?;
619+
m.add_wrapped(wrap_pyfunction!(array_has_all))?;
620+
m.add_wrapped(wrap_pyfunction!(array_has_any))?;
621+
m.add_wrapped(wrap_pyfunction!(array_position))?;
622+
m.add_wrapped(wrap_pyfunction!(array_indexof))?;
623+
m.add_wrapped(wrap_pyfunction!(list_position))?;
624+
m.add_wrapped(wrap_pyfunction!(list_indexof))?;
625+
m.add_wrapped(wrap_pyfunction!(array_positions))?;
626+
m.add_wrapped(wrap_pyfunction!(list_positions))?;
627+
m.add_wrapped(wrap_pyfunction!(array_to_string))?;
628+
m.add_wrapped(wrap_pyfunction!(array_join))?;
629+
m.add_wrapped(wrap_pyfunction!(list_to_string))?;
630+
m.add_wrapped(wrap_pyfunction!(list_join))?;
631+
m.add_wrapped(wrap_pyfunction!(array_ndims))?;
632+
m.add_wrapped(wrap_pyfunction!(list_ndims))?;
633+
m.add_wrapped(wrap_pyfunction!(array_prepend))?;
634+
m.add_wrapped(wrap_pyfunction!(array_push_front))?;
635+
m.add_wrapped(wrap_pyfunction!(list_prepend))?;
636+
m.add_wrapped(wrap_pyfunction!(list_push_front))?;
637+
m.add_wrapped(wrap_pyfunction!(array_pop_back))?;
638+
m.add_wrapped(wrap_pyfunction!(array_pop_front))?;
639+
m.add_wrapped(wrap_pyfunction!(array_remove))?;
640+
m.add_wrapped(wrap_pyfunction!(list_remove))?;
641+
m.add_wrapped(wrap_pyfunction!(array_remove_n))?;
642+
m.add_wrapped(wrap_pyfunction!(list_remove_n))?;
643+
m.add_wrapped(wrap_pyfunction!(array_remove_all))?;
644+
m.add_wrapped(wrap_pyfunction!(list_remove_all))?;
645+
m.add_wrapped(wrap_pyfunction!(array_repeat))?;
646+
m.add_wrapped(wrap_pyfunction!(array_replace))?;
647+
m.add_wrapped(wrap_pyfunction!(list_replace))?;
648+
m.add_wrapped(wrap_pyfunction!(array_replace_n))?;
649+
m.add_wrapped(wrap_pyfunction!(list_replace_n))?;
650+
m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
651+
m.add_wrapped(wrap_pyfunction!(list_replace_all))?;
652+
m.add_wrapped(wrap_pyfunction!(array_slice))?;
653+
m.add_wrapped(wrap_pyfunction!(list_slice))?;
576654

577655
Ok(())
578656
}

0 commit comments

Comments
 (0)