Skip to content

Commit 9b9f5a2

Browse files
committed
Update udtf to handle python based table functions
1 parent dd430eb commit 9b9f5a2

File tree

3 files changed

+54
-15
lines changed

3 files changed

+54
-15
lines changed

examples/datafusion-ffi-example/python/tests/_test_table_function.py

Whitespace-only changes.

python/datafusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
WindowUDF,
5858
udaf,
5959
udf,
60+
udtf,
6061
udwf,
6162
)
6263

@@ -100,6 +101,7 @@
100101
"substrait",
101102
"udaf",
102103
"udf",
104+
"udtf",
103105
"udwf",
104106
"unparser",
105107
]

src/udtf.rs

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,20 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
2018
use pyo3::prelude::*;
19+
use std::sync::Arc;
2120

2221
use crate::dataframe::PyTableProvider;
23-
use crate::errors::py_datafusion_err;
22+
use crate::errors::{py_datafusion_err, to_datafusion_err};
2423
use crate::expr::PyExpr;
2524
use crate::utils::validate_pycapsule;
2625
use datafusion::catalog::{TableFunctionImpl, TableProvider};
27-
use datafusion::common::exec_err;
26+
use datafusion::error::Result as DataFusionResult;
2827
use datafusion::logical_expr::Expr;
28+
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2929
use datafusion_ffi::udtf::{FFI_TableFunction, ForeignTableFunction};
30-
use pyo3::types::PyCapsule;
30+
use pyo3::exceptions::PyNotImplementedError;
31+
use pyo3::types::{PyCapsule, PyTuple};
3132

3233
/// Represents a user defined table function
3334
#[pyclass(name = "TableFunction", module = "datafusion")]
@@ -40,7 +41,7 @@ pub struct PyTableFunction {
4041
// TODO: Implement pure python based user defined table functions
4142
#[derive(Debug, Clone)]
4243
pub(crate) enum PyTableFunctionInner {
43-
// PythonFunction(Arc<PyObject>),
44+
PythonFunction(Arc<PyObject>),
4445
FFIFunction(Arc<dyn TableFunctionImpl>),
4546
}
4647

@@ -49,22 +50,24 @@ impl PyTableFunction {
4950
#[new]
5051
#[pyo3(signature=(name, func))]
5152
pub fn new(name: &str, func: Bound<'_, PyAny>) -> PyResult<Self> {
52-
if func.hasattr("__datafusion_table_function__")? {
53+
let inner = if func.hasattr("__datafusion_table_function__")? {
5354
let capsule = func.getattr("__datafusion_table_function__")?.call0()?;
5455
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
5556
validate_pycapsule(capsule, "datafusion_table_function")?;
5657

5758
let ffi_func = unsafe { capsule.reference::<FFI_TableFunction>() };
5859
let foreign_func: ForeignTableFunction = ffi_func.to_owned().into();
5960

60-
Ok(Self {
61-
name: name.to_string(),
62-
inner: PyTableFunctionInner::FFIFunction(Arc::new(foreign_func)),
63-
})
61+
PyTableFunctionInner::FFIFunction(Arc::new(foreign_func))
6462
} else {
65-
exec_err!("Python based Table Functions are not yet implemented")
66-
.map_err(py_datafusion_err)
67-
}
63+
let py_obj = Arc::new(func.unbind());
64+
PyTableFunctionInner::PythonFunction(py_obj)
65+
};
66+
67+
Ok(Self {
68+
name: name.to_string(),
69+
inner,
70+
})
6871
}
6972

7073
#[pyo3(signature = (*args))]
@@ -80,10 +83,44 @@ impl PyTableFunction {
8083
}
8184
}
8285

86+
fn call_python_table_function(
87+
func: &Arc<PyObject>,
88+
args: &[Expr],
89+
) -> DataFusionResult<Arc<dyn TableProvider>> {
90+
let args = args
91+
.iter()
92+
.map(|arg| PyExpr::from(arg.clone()))
93+
.collect::<Vec<_>>();
94+
95+
// move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
96+
Python::with_gil(|py| {
97+
let py_args = PyTuple::new(py, args)?;
98+
let provider_obj = func.call1(py, py_args)?;
99+
let provider = provider_obj.bind(py);
100+
101+
if provider.hasattr("__datafusion_table_provider__")? {
102+
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
103+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
104+
validate_pycapsule(capsule, "datafusion_table_provider")?;
105+
106+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
107+
let provider: ForeignTableProvider = provider.into();
108+
109+
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
110+
} else {
111+
Err(PyNotImplementedError::new_err(
112+
"__datafusion_table_provider__ does not exist on Table Provider object.",
113+
))
114+
}
115+
})
116+
.map_err(to_datafusion_err)
117+
}
118+
83119
impl TableFunctionImpl for PyTableFunction {
84-
fn call(&self, args: &[Expr]) -> datafusion::common::Result<Arc<dyn TableProvider>> {
120+
fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
85121
match &self.inner {
86122
PyTableFunctionInner::FFIFunction(func) => func.call(args),
123+
PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
87124
}
88125
}
89126
}

0 commit comments

Comments
 (0)