1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: sync:: Arc ;
19-
2018use pyo3:: prelude:: * ;
19+ use std:: sync:: Arc ;
2120
2221use crate :: dataframe:: PyTableProvider ;
23- use crate :: errors:: py_datafusion_err;
22+ use crate :: errors:: { py_datafusion_err, to_datafusion_err } ;
2423use crate :: expr:: PyExpr ;
2524use crate :: utils:: validate_pycapsule;
2625use datafusion:: catalog:: { TableFunctionImpl , TableProvider } ;
27- use datafusion:: common :: exec_err ;
26+ use datafusion:: error :: Result as DataFusionResult ;
2827use datafusion:: logical_expr:: Expr ;
28+ use datafusion_ffi:: table_provider:: { FFI_TableProvider , ForeignTableProvider } ;
2929use datafusion_ffi:: udtf:: { FFI_TableFunction , ForeignTableFunction } ;
30- use pyo3:: types:: PyCapsule ;
30+ use pyo3:: exceptions:: PyNotImplementedError ;
31+ use pyo3:: types:: { PyCapsule , PyTuple } ;
3132
3233/// Represents a user defined table function
3334#[ pyclass( name = "TableFunction" , module = "datafusion" ) ]
@@ -40,7 +41,7 @@ pub struct PyTableFunction {
4041// TODO: Implement pure python based user defined table functions
4142#[ derive( Debug , Clone ) ]
4243pub ( crate ) enum PyTableFunctionInner {
43- // PythonFunction(Arc<PyObject>),
44+ PythonFunction ( Arc < PyObject > ) ,
4445 FFIFunction ( Arc < dyn TableFunctionImpl > ) ,
4546}
4647
@@ -49,22 +50,24 @@ impl PyTableFunction {
4950 #[ new]
5051 #[ pyo3( signature=( name, func) ) ]
5152 pub fn new ( name : & str , func : Bound < ' _ , PyAny > ) -> PyResult < Self > {
52- if func. hasattr ( "__datafusion_table_function__" ) ? {
53+ let inner = if func. hasattr ( "__datafusion_table_function__" ) ? {
5354 let capsule = func. getattr ( "__datafusion_table_function__" ) ?. call0 ( ) ?;
5455 let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
5556 validate_pycapsule ( capsule, "datafusion_table_function" ) ?;
5657
5758 let ffi_func = unsafe { capsule. reference :: < FFI_TableFunction > ( ) } ;
5859 let foreign_func: ForeignTableFunction = ffi_func. to_owned ( ) . into ( ) ;
5960
60- Ok ( Self {
61- name : name. to_string ( ) ,
62- inner : PyTableFunctionInner :: FFIFunction ( Arc :: new ( foreign_func) ) ,
63- } )
61+ PyTableFunctionInner :: FFIFunction ( Arc :: new ( foreign_func) )
6462 } else {
65- exec_err ! ( "Python based Table Functions are not yet implemented" )
66- . map_err ( py_datafusion_err)
67- }
63+ let py_obj = Arc :: new ( func. unbind ( ) ) ;
64+ PyTableFunctionInner :: PythonFunction ( py_obj)
65+ } ;
66+
67+ Ok ( Self {
68+ name : name. to_string ( ) ,
69+ inner,
70+ } )
6871 }
6972
7073 #[ pyo3( signature = ( * args) ) ]
@@ -80,10 +83,44 @@ impl PyTableFunction {
8083 }
8184}
8285
86+ fn call_python_table_function (
87+ func : & Arc < PyObject > ,
88+ args : & [ Expr ] ,
89+ ) -> DataFusionResult < Arc < dyn TableProvider > > {
90+ let args = args
91+ . iter ( )
92+ . map ( |arg| PyExpr :: from ( arg. clone ( ) ) )
93+ . collect :: < Vec < _ > > ( ) ;
94+
95+ // move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
96+ Python :: with_gil ( |py| {
97+ let py_args = PyTuple :: new ( py, args) ?;
98+ let provider_obj = func. call1 ( py, py_args) ?;
99+ let provider = provider_obj. bind ( py) ;
100+
101+ if provider. hasattr ( "__datafusion_table_provider__" ) ? {
102+ let capsule = provider. getattr ( "__datafusion_table_provider__" ) ?. call0 ( ) ?;
103+ let capsule = capsule. downcast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) ?;
104+ validate_pycapsule ( capsule, "datafusion_table_provider" ) ?;
105+
106+ let provider = unsafe { capsule. reference :: < FFI_TableProvider > ( ) } ;
107+ let provider: ForeignTableProvider = provider. into ( ) ;
108+
109+ Ok ( Arc :: new ( provider) as Arc < dyn TableProvider > )
110+ } else {
111+ Err ( PyNotImplementedError :: new_err (
112+ "__datafusion_table_provider__ does not exist on Table Provider object." ,
113+ ) )
114+ }
115+ } )
116+ . map_err ( to_datafusion_err)
117+ }
118+
83119impl TableFunctionImpl for PyTableFunction {
84- fn call ( & self , args : & [ Expr ] ) -> datafusion :: common :: Result < Arc < dyn TableProvider > > {
120+ fn call ( & self , args : & [ Expr ] ) -> DataFusionResult < Arc < dyn TableProvider > > {
85121 match & self . inner {
86122 PyTableFunctionInner :: FFIFunction ( func) => func. call ( args) ,
123+ PyTableFunctionInner :: PythonFunction ( obj) => call_python_table_function ( obj, args) ,
87124 }
88125 }
89126}
0 commit comments