Skip to content

Commit b22f82f

Browse files
authored
Add missing array functions (#551)
* Add array_append, array_concat and array_cat * Add tests for array functions array_append, array_concat and array_cat * Add array_dims and list_dims * Add tests for array_dims and list_dims * Add array_element, array_extract, list_element and list_extract * Add tests for array_element, array_extract, list_element and list_extract * Add array_length and list_length
1 parent 76d7fcf commit b22f82f

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed

datafusion/tests/test_functions.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from datafusion import functions as f
2626
from datafusion import literal
2727

28+
np.seterr(invalid="ignore")
29+
2830

2931
@pytest.fixture
3032
def df():
@@ -197,6 +199,68 @@ def test_math_functions():
197199
)
198200

199201

202+
def test_array_functions():
203+
data = [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]]
204+
ctx = SessionContext()
205+
batch = pa.RecordBatch.from_arrays(
206+
[np.array(data, dtype=object)], names=["arr"]
207+
)
208+
df = ctx.create_dataframe([[batch]])
209+
210+
col = column("arr")
211+
test_items = [
212+
[
213+
f.array_append(col, literal(99.0)),
214+
lambda: [np.append(arr, 99.0) for arr in data],
215+
],
216+
[
217+
f.array_concat(col, col),
218+
lambda: [np.concatenate([arr, arr]) for arr in data],
219+
],
220+
[
221+
f.array_cat(col, col),
222+
lambda: [np.concatenate([arr, arr]) for arr in data],
223+
],
224+
[
225+
f.array_dims(col),
226+
lambda: [[len(r)] for r in data],
227+
],
228+
[
229+
f.list_dims(col),
230+
lambda: [[len(r)] for r in data],
231+
],
232+
[
233+
f.array_element(col, literal(1)),
234+
lambda: [r[0] for r in data],
235+
],
236+
[
237+
f.array_extract(col, literal(1)),
238+
lambda: [r[0] for r in data],
239+
],
240+
[
241+
f.list_element(col, literal(1)),
242+
lambda: [r[0] for r in data],
243+
],
244+
[
245+
f.list_extract(col, literal(1)),
246+
lambda: [r[0] for r in data],
247+
],
248+
[
249+
f.array_length(col),
250+
lambda: [len(r) for r in data],
251+
],
252+
[
253+
f.list_length(col),
254+
lambda: [len(r) for r in data],
255+
],
256+
]
257+
258+
for stmt, py_expr in test_items:
259+
query_result = df.select(stmt).collect()[0].column(0).tolist()
260+
for a, b in zip(query_result, py_expr()):
261+
np.testing.assert_array_almost_equal(a, b)
262+
263+
200264
def test_string_functions(df):
201265
df = df.select(
202266
f.ascii(column("a")),

src/functions.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,19 @@ scalar_function!(random, Random);
357357
scalar_function!(encode, Encode);
358358
scalar_function!(decode, Decode);
359359

360+
// Array Functions
361+
scalar_function!(array_append, ArrayAppend);
362+
scalar_function!(array_concat, ArrayConcat);
363+
scalar_function!(array_cat, ArrayConcat);
364+
scalar_function!(array_dims, ArrayDims);
365+
scalar_function!(list_dims, ArrayDims);
366+
scalar_function!(array_element, ArrayElement);
367+
scalar_function!(array_extract, ArrayElement);
368+
scalar_function!(list_element, ArrayElement);
369+
scalar_function!(list_extract, ArrayElement);
370+
scalar_function!(array_length, ArrayLength);
371+
scalar_function!(list_length, ArrayLength);
372+
360373
aggregate_function!(approx_distinct, ApproxDistinct);
361374
aggregate_function!(approx_median, ApproxMedian);
362375
aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
@@ -546,5 +559,19 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
546559
//Binary String Functions
547560
m.add_wrapped(wrap_pyfunction!(encode))?;
548561
m.add_wrapped(wrap_pyfunction!(decode))?;
562+
563+
// Array Functions
564+
m.add_wrapped(wrap_pyfunction!(array_append))?;
565+
m.add_wrapped(wrap_pyfunction!(array_concat))?;
566+
m.add_wrapped(wrap_pyfunction!(array_cat))?;
567+
m.add_wrapped(wrap_pyfunction!(array_dims))?;
568+
m.add_wrapped(wrap_pyfunction!(list_dims))?;
569+
m.add_wrapped(wrap_pyfunction!(array_element))?;
570+
m.add_wrapped(wrap_pyfunction!(array_extract))?;
571+
m.add_wrapped(wrap_pyfunction!(list_element))?;
572+
m.add_wrapped(wrap_pyfunction!(list_extract))?;
573+
m.add_wrapped(wrap_pyfunction!(array_length))?;
574+
m.add_wrapped(wrap_pyfunction!(list_length))?;
575+
549576
Ok(())
550577
}

0 commit comments

Comments
 (0)