Skip to content

Commit 0f6277a

Browse files
committed
Add python table function example and test
1 parent 9b9f5a2 commit 0f6277a

File tree

3 files changed

+110
-2
lines changed

3 files changed

+110
-2
lines changed
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from typing import TYPE_CHECKING
21+
22+
import pyarrow as pa
23+
from datafusion import SessionContext, udtf
24+
from datafusion_ffi_example import MyTableFunction, MyTableProvider
25+
26+
if TYPE_CHECKING:
27+
from datafusion.context import TableProviderExportable
28+
29+
30+
def test_ffi_table_function_register():
31+
ctx = SessionContext()
32+
table_func = MyTableFunction()
33+
table_udtf = udtf(table_func, "my_table_func")
34+
ctx.register_udtf(table_udtf)
35+
result = ctx.sql("select * from my_table_func()").collect()
36+
37+
assert len(result) == 2
38+
assert result[0].num_columns == 4
39+
print(result)
40+
41+
result = [r.column(0) for r in result]
42+
expected = [
43+
pa.array([0, 1, 2], type=pa.int32()),
44+
pa.array([3, 4, 5, 6], type=pa.int32()),
45+
]
46+
47+
assert result == expected
48+
49+
50+
def test_ffi_table_function_call_directly():
51+
ctx = SessionContext()
52+
table_func = MyTableFunction()
53+
table_udtf = udtf(table_func, "my_table_func")
54+
55+
my_table = table_udtf()
56+
ctx.register_table_provider("t", my_table)
57+
result = ctx.table("t").collect()
58+
59+
assert len(result) == 2
60+
assert result[0].num_columns == 4
61+
print(result)
62+
63+
result = [r.column(0) for r in result]
64+
expected = [
65+
pa.array([0, 1, 2], type=pa.int32()),
66+
pa.array([3, 4, 5, 6], type=pa.int32()),
67+
]
68+
69+
assert result == expected
70+
71+
72+
class PythonTableFunction:
73+
"""Python based table function.
74+
75+
This class is used as a Python implementation of a table function.
76+
We use the existing TableProvider to create the underlying
77+
provider, and this function takes no arguments
78+
"""
79+
80+
def __init__(self) -> None:
81+
self.table_provider = MyTableProvider(3, 2, 4)
82+
83+
def __call__(self) -> TableProviderExportable:
84+
return self.table_provider
85+
86+
87+
def test_python_table_function():
88+
ctx = SessionContext()
89+
table_func = PythonTableFunction()
90+
table_udtf = udtf(table_func, "my_table_func")
91+
ctx.register_udtf(table_udtf)
92+
result = ctx.sql("select * from my_table_func()").collect()
93+
94+
assert len(result) == 4
95+
assert result[0].num_columns == 3
96+
print(result)
97+
98+
result = [r.column(0) for r in result]
99+
expected = [
100+
pa.array([0, 1], type=pa.int32()),
101+
pa.array([2, 3, 4], type=pa.int32()),
102+
pa.array([4, 5, 6, 7], type=pa.int32()),
103+
pa.array([6, 7, 8, 9, 10], type=pa.int32()),
104+
]
105+
106+
assert result == expected

examples/datafusion-ffi-example/python/tests/_test_table_provider.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from __future__ import annotations
19+
1820
import pyarrow as pa
1921
from datafusion import SessionContext
20-
from ffi_table_provider import MyTableProvider
22+
from datafusion_ffi_example import MyTableProvider
2123

2224

2325
def test_table_loading():

examples/datafusion-ffi-example/src/table_function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ impl MyTableFunction {
5050

5151
impl TableFunctionImpl for MyTableFunction {
5252
fn call(&self, _args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
53-
let provider = MyTableProvider::new(10, 3, 2).create_table()?;
53+
let provider = MyTableProvider::new(4, 3, 2).create_table()?;
5454
Ok(Arc::new(provider))
5555
}
5656
}

0 commit comments

Comments
 (0)