Skip to content

Commit 76d7fcf

Browse files
authored
feat: udaf: enable multiple column input (#546)
1 parent 2889de0 commit 76d7fcf

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

datafusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None):
213213
)
214214
if name is None:
215215
name = accum.__qualname__.lower()
216+
if isinstance(input_type, pa.lib.DataType):
217+
input_type = [input_type]
216218
return AggregateUDF(
217219
name=name,
218220
accumulator=accum,

src/udaf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,14 @@ impl PyAggregateUDF {
148148
fn new(
149149
name: &str,
150150
accumulator: PyObject,
151-
input_type: PyArrowType<DataType>,
151+
input_type: PyArrowType<Vec<DataType>>,
152152
return_type: PyArrowType<DataType>,
153153
state_type: PyArrowType<Vec<DataType>>,
154154
volatility: &str,
155155
) -> PyResult<Self> {
156156
let function = create_udaf(
157157
name,
158-
vec![input_type.0],
158+
input_type.0,
159159
Arc::new(return_type.0),
160160
parse_volatility(volatility)?,
161161
to_rust_accumulator(accumulator),

0 commit comments

Comments
 (0)