Skip to content

Commit 92b093c

Browse files
authored
Add make_list and tests for make_list, make_array (#949)
1 parent 5e32ada commit 92b093c

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

python/datafusion/functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@
184184
"lpad",
185185
"ltrim",
186186
"make_array",
187+
"make_list",
187188
"make_date",
188189
"max",
189190
"md5",
@@ -1044,6 +1045,14 @@ def make_array(*args: Expr) -> Expr:
10441045
return Expr(f.make_array(args))
10451046

10461047

1048+
def make_list(*args: Expr) -> Expr:
1049+
"""Returns an array using the specified input expressions.
1050+
1051+
This is an alias for :py:func:`make_array`.
1052+
"""
1053+
return make_array(*args)
1054+
1055+
10471056
def array(*args: Expr) -> Expr:
10481057
"""Returns an array using the specified input expressions.
10491058

python/tests/test_functions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,37 @@ def test_array_function_cardinality():
576576
)
577577

578578

579+
@pytest.mark.parametrize("make_func", [f.make_array, f.make_list])
580+
def test_make_array_functions(make_func):
581+
ctx = SessionContext()
582+
batch = pa.RecordBatch.from_arrays(
583+
[
584+
pa.array(["Hello", "World", "!"], type=pa.string()),
585+
pa.array([4, 5, 6]),
586+
pa.array(["hello ", " world ", " !"], type=pa.string()),
587+
],
588+
names=["a", "b", "c"],
589+
)
590+
df = ctx.create_dataframe([[batch]])
591+
592+
stmt = make_func(
593+
column("a").cast(pa.string()),
594+
column("b").cast(pa.string()),
595+
column("c").cast(pa.string()),
596+
)
597+
py_expr = [
598+
["Hello", "4", "hello "],
599+
["World", "5", " world "],
600+
["!", "6", " !"],
601+
]
602+
603+
query_result = df.select(stmt).collect()[0].column(0)
604+
for a, b in zip(query_result, py_expr):
605+
np.testing.assert_array_equal(
606+
np.array(a.as_py(), dtype=str), np.array(b, dtype=str)
607+
)
608+
609+
579610
@pytest.mark.parametrize(
580611
("stmt", "py_expr"),
581612
[

0 commit comments

Comments
 (0)