Skip to content

Commit 2e10e88

Browse files
committed
Add table decorator and unit test
1 parent a55969c commit 2e10e88

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

examples/datafusion-ffi-example/python/tests/_test_table_function.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,8 @@ def __call__(
8888
return MyTableProvider(*args)
8989

9090

91-
def test_python_table_function():
92-
ctx = SessionContext()
93-
table_func = PythonTableFunction()
94-
table_udtf = udtf(table_func, "my_table_func")
95-
ctx.register_udtf(table_udtf)
96-
result = ctx.sql("select * from my_table_func(3,2,4)").collect()
91+
def common_table_function_test(test_ctx: SessionContext) -> None:
92+
result = test_ctx.sql("select * from my_table_func(3,2,4)").collect()
9793

9894
assert len(result) == 4
9995
assert result[0].num_columns == 3
@@ -108,3 +104,31 @@ def test_python_table_function():
108104
]
109105

110106
assert result == expected
107+
108+
109+
def test_python_table_function():
110+
ctx = SessionContext()
111+
table_func = PythonTableFunction()
112+
table_udtf = udtf(table_func, "my_table_func")
113+
ctx.register_udtf(table_udtf)
114+
115+
common_table_function_test(ctx)
116+
117+
118+
def test_python_table_function_decorator():
119+
ctx = SessionContext()
120+
121+
@udtf("my_table_func")
122+
def my_udtf(
123+
num_cols: Expr, num_rows: Expr, num_batches: Expr
124+
) -> TableProviderExportable:
125+
args = [
126+
num_cols.to_variant().value_i64(),
127+
num_rows.to_variant().value_i64(),
128+
num_batches.to_variant().value_i64(),
129+
]
130+
return MyTableProvider(*args)
131+
132+
ctx.register_udtf(my_udtf)
133+
134+
common_table_function_test(ctx)

python/datafusion/udf.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,9 @@ def udtf(*args: Any, **kwargs: Any):
804804
# Case 1: Used as a function, require the first parameter to be callable
805805
return TableFunction._create_table_udf(*args, **kwargs)
806806
if args and hasattr(args[0], "__datafusion_table_function__"):
807+
# Case 2: We have a datafusion FFI provided function
807808
return TableFunction(args[1], args[0])
808-
# Case 2: Used as a decorator with parameters
809+
# Case 3: Used as a decorator with parameters
809810
return TableFunction._create_table_udf_decorator(*args, **kwargs)
810811

811812
@staticmethod
@@ -820,6 +821,17 @@ def _create_table_udf(
820821

821822
return TableFunction(name, func)
822823

824+
@staticmethod
825+
def _create_table_udf_decorator(
826+
name: Optional[str] = None,
827+
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
828+
"""Create a decorator for a WindowUDF."""
829+
830+
def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
831+
return TableFunction._create_table_udf(func, name)
832+
833+
return decorator
834+
823835
def __repr__(self) -> str:
824836
"""User printable representation."""
825837
return self._udtf.__repr__()

0 commit comments

Comments
 (0)