Skip to content

Commit 2f52688

Browse files
authored
Add decorator for udwf (#1061)
* feat: Introduce create_udwf method for User-Defined Window Functions - Added `create_udwf` static method to `WindowUDF` class, allowing users to create User-Defined Window Functions (UDWF) as both a function and a decorator. - Updated type hinting for `_R` using `TypeAlias` for better clarity. - Enhanced documentation with usage examples for both function and decorator styles, improving usability and understanding. * refactor: Simplify UDWF test suite and introduce SimpleWindowCount evaluator - Removed multiple exponential smoothing classes to streamline the code. - Introduced SimpleWindowCount class for basic row counting functionality. - Updated test cases to validate the new SimpleWindowCount evaluator. - Refactored fixture and test functions for clarity and consistency. - Enhanced error handling in UDWF creation tests. * fix: Update type alias import to use typing_extensions for compatibility * Add udwf tests for multiple input types and decorator syntax * replace old def udwf * refactor: Simplify df fixture by passing ctx as an argument * refactor: Rename DataFrame fixtures and update test functions - Renamed `df` fixture to `complex_window_df` for clarity. - Renamed `simple_df` fixture to `count_window_df` to better reflect its purpose. - Updated test functions to use the new fixture names, enhancing readability and maintainability. * refactor: Update udwf calls in WindowUDF to use BiasedNumbers directly - Changed udwf1 to use BiasedNumbers instead of bias_10. - Added udwf2 to call udwf with bias_10. - Introduced udwf3 to demonstrate a lambda function returning BiasedNumbers(20). * feat: Add overloads for udwf function to support multiple input types and decorator syntax * refactor: Simplify udwf method signature by removing redundant type hints * refactor: Remove state_type from udwf method signature and update return type handling - Eliminated the state_type parameter from the udwf method to simplify the function signature. - Updated return type handling in the _function and _decorator methods to use a generic type _R for better type flexibility. - Enhanced the decorator to wrap the original function, allowing for improved argument handling and expression return. * refactor: Update volatility parameter type in udwf method signature to support Volatility enum * Fix ruff errors * fix C901 for def udwf * refactor: Update udwf method signature and simplify input handling - Changed the type hint for the return type in the _create_window_udf_decorator method to use pa.DataType directly instead of a TypeVar. - Simplified the handling of input types by removing redundant checks and directly using the input types list. - Removed unnecessary comments and cleaned up the code for better readability. - Updated the test for udwf to use parameterized tests for better coverage and maintainability. * refactor: Rename input_type to input_types in udwf method signature for clarity * refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization * Revert "refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization" This reverts commit 16dbe5f.
1 parent 4f45703 commit 2f52688

File tree

2 files changed

+264
-29
lines changed

2 files changed

+264
-29
lines changed

python/datafusion/udf.py

+99-24
Original file line numberDiff line numberDiff line change
@@ -621,31 +621,48 @@ def __call__(self, *args: Expr) -> Expr:
621621
args_raw = [arg.expr for arg in args]
622622
return Expr(self._udwf.__call__(*args_raw))
623623

624+
@overload
625+
@staticmethod
626+
def udwf(
627+
input_types: pa.DataType | list[pa.DataType],
628+
return_type: pa.DataType,
629+
volatility: Volatility | str,
630+
name: Optional[str] = None,
631+
) -> Callable[..., WindowUDF]: ...
632+
633+
@overload
624634
@staticmethod
625635
def udwf(
626636
func: Callable[[], WindowEvaluator],
627637
input_types: pa.DataType | list[pa.DataType],
628638
return_type: pa.DataType,
629639
volatility: Volatility | str,
630640
name: Optional[str] = None,
631-
) -> WindowUDF:
632-
"""Create a new User-Defined Window Function.
641+
) -> WindowUDF: ...
633642

634-
If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
635-
can simply pass it's type as ``func``. If you need to pass additional arguments
636-
to it's constructor, you can define a lambda or a factory method. During runtime
637-
the :py:class:`WindowEvaluator` will be constructed for every instance in
638-
which this UDWF is used. The following examples are all valid.
643+
@staticmethod
644+
def udwf(*args: Any, **kwargs: Any): # noqa: D417
645+
"""Create a new User-Defined Window Function (UDWF).
639646
640-
.. code-block:: python
647+
This class can be used both as a **function** and as a **decorator**.
648+
649+
Usage:
650+
- **As a function**: Call `udwf(func, input_types, return_type, volatility,
651+
name)`.
652+
- **As a decorator**: Use `@udwf(input_types, return_type, volatility,
653+
name)`. When using `udwf` as a decorator, **do not pass `func`
654+
explicitly**.
641655
656+
**Function example:**
657+
```
642658
import pyarrow as pa
643659
644660
class BiasedNumbers(WindowEvaluator):
645661
def __init__(self, start: int = 0) -> None:
646662
self.start = start
647663
648-
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
664+
def evaluate_all(self, values: list[pa.Array],
665+
num_rows: int) -> pa.Array:
649666
return pa.array([self.start + i for i in range(num_rows)])
650667
651668
def bias_10() -> BiasedNumbers:
@@ -655,35 +672,93 @@ def bias_10() -> BiasedNumbers:
655672
udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
656673
udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
657674
675+
```
676+
677+
**Decorator example:**
678+
```
679+
@udwf(pa.int64(), pa.int64(), "immutable")
680+
def biased_numbers() -> BiasedNumbers:
681+
return BiasedNumbers(10)
682+
```
683+
658684
Args:
659-
func: A callable to create the window function.
660-
input_types: The data types of the arguments to ``func``.
685+
func: **Only needed when calling as a function. Skip this argument when
686+
using `udwf` as a decorator.**
687+
input_types: The data types of the arguments.
661688
return_type: The data type of the return value.
662689
volatility: See :py:class:`Volatility` for allowed values.
663-
arguments: A list of arguments to pass in to the __init__ method for accum.
664690
name: A descriptive name for the function.
665691
666692
Returns:
667-
A user-defined window function.
668-
""" # noqa: W505, E501
693+
A user-defined window function that can be used in window function calls.
694+
"""
695+
if args and callable(args[0]):
696+
# Case 1: Used as a function, require the first parameter to be callable
697+
return WindowUDF._create_window_udf(*args, **kwargs)
698+
# Case 2: Used as a decorator with parameters
699+
return WindowUDF._create_window_udf_decorator(*args, **kwargs)
700+
701+
@staticmethod
702+
def _create_window_udf(
703+
func: Callable[[], WindowEvaluator],
704+
input_types: pa.DataType | list[pa.DataType],
705+
return_type: pa.DataType,
706+
volatility: Volatility | str,
707+
name: Optional[str] = None,
708+
) -> WindowUDF:
709+
"""Create a WindowUDF instance from function arguments."""
669710
if not callable(func):
670711
msg = "`func` must be callable."
671712
raise TypeError(msg)
672713
if not isinstance(func(), WindowEvaluator):
673714
msg = "`func` must implement the abstract base class WindowEvaluator"
674715
raise TypeError(msg)
675-
if name is None:
676-
name = func().__class__.__qualname__.lower()
677-
if isinstance(input_types, pa.DataType):
678-
input_types = [input_types]
679-
return WindowUDF(
680-
name=name,
681-
func=func,
682-
input_types=input_types,
683-
return_type=return_type,
684-
volatility=volatility,
716+
717+
name = name or func.__qualname__.lower()
718+
input_types = (
719+
[input_types] if isinstance(input_types, pa.DataType) else input_types
685720
)
686721

722+
return WindowUDF(name, func, input_types, return_type, volatility)
723+
724+
@staticmethod
725+
def _get_default_name(func: Callable) -> str:
726+
"""Get the default name for a function based on its attributes."""
727+
if hasattr(func, "__qualname__"):
728+
return func.__qualname__.lower()
729+
return func.__class__.__name__.lower()
730+
731+
@staticmethod
732+
def _normalize_input_types(
733+
input_types: pa.DataType | list[pa.DataType],
734+
) -> list[pa.DataType]:
735+
"""Convert a single DataType to a list if needed."""
736+
if isinstance(input_types, pa.DataType):
737+
return [input_types]
738+
return input_types
739+
740+
@staticmethod
741+
def _create_window_udf_decorator(
742+
input_types: pa.DataType | list[pa.DataType],
743+
return_type: pa.DataType,
744+
volatility: Volatility | str,
745+
name: Optional[str] = None,
746+
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
747+
"""Create a decorator for a WindowUDF."""
748+
749+
def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
750+
udwf_caller = WindowUDF._create_window_udf(
751+
func, input_types, return_type, volatility, name
752+
)
753+
754+
@functools.wraps(func)
755+
def wrapper(*args: Any, **kwargs: Any) -> Expr:
756+
return udwf_caller(*args, **kwargs)
757+
758+
return wrapper
759+
760+
return decorator
761+
687762

688763
# Convenience exports so we can import instead of treating as
689764
# variables at the package root

python/tests/test_udwf.py

+165-5
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,27 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
162162
return pa.array(results)
163163

164164

165+
class SimpleWindowCount(WindowEvaluator):
166+
"""A simple window evaluator that counts rows."""
167+
168+
def __init__(self, base: int = 0) -> None:
169+
self.base = base
170+
171+
def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
172+
return pa.array([self.base + i for i in range(num_rows)])
173+
174+
165175
class NotSubclassOfWindowEvaluator:
166176
pass
167177

168178

169179
@pytest.fixture
170-
def df():
171-
ctx = SessionContext()
180+
def ctx():
181+
return SessionContext()
182+
172183

184+
@pytest.fixture
185+
def complex_window_df(ctx):
173186
# create a RecordBatch and a new DataFrame from it
174187
batch = pa.RecordBatch.from_arrays(
175188
[
@@ -182,7 +195,17 @@ def df():
182195
return ctx.create_dataframe([[batch]])
183196

184197

185-
def test_udwf_errors(df):
198+
@pytest.fixture
199+
def count_window_df(ctx):
200+
# create a RecordBatch and a new DataFrame from it
201+
batch = pa.RecordBatch.from_arrays(
202+
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
203+
names=["a", "b"],
204+
)
205+
return ctx.create_dataframe([[batch]], name="test_table")
206+
207+
208+
def test_udwf_errors(complex_window_df):
186209
with pytest.raises(TypeError):
187210
udwf(
188211
NotSubclassOfWindowEvaluator,
@@ -192,6 +215,103 @@ def test_udwf_errors(df):
192215
)
193216

194217

218+
def test_udwf_errors_with_message():
219+
"""Test error cases for UDWF creation."""
220+
with pytest.raises(
221+
TypeError, match="`func` must implement the abstract base class WindowEvaluator"
222+
):
223+
udwf(
224+
NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable"
225+
)
226+
227+
228+
def test_udwf_basic_usage(count_window_df):
229+
"""Test basic UDWF usage with a simple counting window function."""
230+
simple_count = udwf(
231+
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
232+
)
233+
234+
df = count_window_df.select(
235+
simple_count(column("a"))
236+
.window_frame(WindowFrame("rows", None, None))
237+
.build()
238+
.alias("count")
239+
)
240+
result = df.collect()[0]
241+
assert result.column(0) == pa.array([0, 1, 2])
242+
243+
244+
def test_udwf_with_args(count_window_df):
245+
"""Test UDWF with constructor arguments."""
246+
count_base10 = udwf(
247+
lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable"
248+
)
249+
250+
df = count_window_df.select(
251+
count_base10(column("a"))
252+
.window_frame(WindowFrame("rows", None, None))
253+
.build()
254+
.alias("count")
255+
)
256+
result = df.collect()[0]
257+
assert result.column(0) == pa.array([10, 11, 12])
258+
259+
260+
def test_udwf_decorator_basic(count_window_df):
261+
"""Test UDWF used as a decorator."""
262+
263+
@udwf([pa.int64()], pa.int64(), "immutable")
264+
def window_count() -> WindowEvaluator:
265+
return SimpleWindowCount()
266+
267+
df = count_window_df.select(
268+
window_count(column("a"))
269+
.window_frame(WindowFrame("rows", None, None))
270+
.build()
271+
.alias("count")
272+
)
273+
result = df.collect()[0]
274+
assert result.column(0) == pa.array([0, 1, 2])
275+
276+
277+
def test_udwf_decorator_with_args(count_window_df):
278+
"""Test UDWF decorator with constructor arguments."""
279+
280+
@udwf([pa.int64()], pa.int64(), "immutable")
281+
def window_count_base10() -> WindowEvaluator:
282+
return SimpleWindowCount(10)
283+
284+
df = count_window_df.select(
285+
window_count_base10(column("a"))
286+
.window_frame(WindowFrame("rows", None, None))
287+
.build()
288+
.alias("count")
289+
)
290+
result = df.collect()[0]
291+
assert result.column(0) == pa.array([10, 11, 12])
292+
293+
294+
def test_register_udwf(ctx, count_window_df):
295+
"""Test registering and using UDWF in SQL context."""
296+
window_count = udwf(
297+
SimpleWindowCount,
298+
[pa.int64()],
299+
pa.int64(),
300+
volatility="immutable",
301+
name="window_count",
302+
)
303+
304+
ctx.register_udwf(window_count)
305+
result = ctx.sql(
306+
"""
307+
SELECT window_count(a)
308+
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
309+
FOLLOWING) FROM test_table
310+
"""
311+
).collect()[0]
312+
assert result.column(0) == pa.array([0, 1, 2])
313+
314+
195315
smooth_default = udwf(
196316
ExponentialSmoothDefault,
197317
pa.float64(),
@@ -299,10 +419,50 @@ def test_udwf_errors(df):
299419

300420

301421
@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions)
302-
def test_udwf_functions(df, name, expr, expected):
303-
df = df.select("a", "b", f.round(expr, lit(3)).alias(name))
422+
def test_udwf_functions(complex_window_df, name, expr, expected):
423+
df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name))
304424

305425
# execute and collect the first (and only) batch
306426
result = df.sort(column("a")).select(column(name)).collect()[0]
307427

308428
assert result.column(0) == pa.array(expected)
429+
430+
431+
@pytest.mark.parametrize(
432+
"udwf_func",
433+
[
434+
udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
435+
udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
436+
udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()),
437+
udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
438+
],
439+
)
440+
def test_udwf_overloads(udwf_func, count_window_df):
441+
df = count_window_df.select(
442+
udwf_func(column("a"))
443+
.window_frame(WindowFrame("rows", None, None))
444+
.build()
445+
.alias("count")
446+
)
447+
result = df.collect()[0]
448+
assert result.column(0) == pa.array([0, 1, 2])
449+
450+
451+
def test_udwf_named_function(ctx, count_window_df):
452+
"""Test UDWF with explicit name parameter."""
453+
window_count = udwf(
454+
SimpleWindowCount,
455+
pa.int64(),
456+
pa.int64(),
457+
volatility="immutable",
458+
name="my_custom_counter",
459+
)
460+
461+
ctx.register_udwf(window_count)
462+
result = ctx.sql(
463+
"""
464+
SELECT my_custom_counter(a)
465+
OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
466+
FOLLOWING) FROM test_table"""
467+
).collect()[0]
468+
assert result.column(0) == pa.array([0, 1, 2])

0 commit comments

Comments
 (0)