Skip to content

Add cardinality function to calculate the total number of elements in an array #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/source/user-guide/common-operations/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,20 @@ This function returns a boolean indicating whether the array is empty.

In this example, the `is_empty` column will contain `True` for the first row and `False` for the second row.

To get the total number of elements in an array, you can use the function :py:func:`datafusion.functions.cardinality`.
This function returns an integer indicating the total number of elements in the array.

.. ipython:: python

from datafusion import SessionContext, col
from datafusion.functions import cardinality

ctx = SessionContext()
df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]})
df.select(cardinality(col("a")).alias("num_elements"))

In this example, the `num_elements` column will contain `3` for both rows.

Structs
-------

Expand Down
6 changes: 6 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
"find_in_set",
"first_value",
"flatten",
"cardinality",
"floor",
"from_unixtime",
"gcd",
Expand Down Expand Up @@ -1516,6 +1517,11 @@ def flatten(array: Expr) -> Expr:
return Expr(f.flatten(array.expr))


def cardinality(array: Expr) -> Expr:
    """Returns the total number of elements in the array.

    Args:
        array: Expression that evaluates to an array.

    Returns:
        An expression wrapping the underlying ``cardinality`` call.
    """
    inner = f.cardinality(array.expr)
    return Expr(inner)


# aggregate functions
def approx_distinct(
expression: Expr,
Expand Down
18 changes: 18 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,24 @@ def test_array_function_flatten():
)


def test_array_function_cardinality():
    """Check that ``cardinality`` reports the element count of every array row.

    The second row repeats a value (``4``); its expected count of 4 shows that
    duplicates are counted rather than collapsed.
    """
    data = [[1, 2, 3], [4, 4, 5, 6]]
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
    df = ctx.create_dataframe([[batch]])

    stmt = f.cardinality(column("arr"))
    py_expr = [len(arr) for arr in data]  # Expected lengths: [3, 4]

    query_result = df.select(stmt).collect()[0].column(0)

    # zip() stops at the shorter input, so verify the row count first;
    # otherwise a truncated result would pass without being compared.
    assert len(query_result) == len(py_expr)

    for a, b in zip(query_result, py_expr):
        np.testing.assert_array_equal(
            np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
        )


@pytest.mark.parametrize(
("stmt", "py_expr"),
[
Expand Down
2 changes: 2 additions & 0 deletions src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ array_fn!(array_intersect, first_array second_array);
array_fn!(array_union, array1 array2);
array_fn!(array_except, first_array second_array);
array_fn!(array_resize, array size value);
array_fn!(cardinality, array);
array_fn!(flatten, array);
array_fn!(range, start stop step);

Expand Down Expand Up @@ -1030,6 +1031,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_sort))?;
m.add_wrapped(wrap_pyfunction!(array_slice))?;
m.add_wrapped(wrap_pyfunction!(flatten))?;
m.add_wrapped(wrap_pyfunction!(cardinality))?;

// Window Functions
m.add_wrapped(wrap_pyfunction!(lead))?;
Expand Down
Loading