|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use std::{ffi::CString, sync::Arc}; |
| 18 | +use crate::table_function::MyTableFunction; |
| 19 | +use crate::table_provider::MyTableProvider; |
| 20 | +use pyo3::prelude::*; |
19 | 21 |
|
20 | | -use arrow_array::ArrayRef; |
21 | | -use datafusion::catalog::{TableFunctionImpl, TableProvider}; |
22 | | -use datafusion::logical_expr::Expr; |
23 | | -use datafusion::{ |
24 | | - arrow::{ |
25 | | - array::RecordBatch, |
26 | | - datatypes::{DataType, Field, Schema}, |
27 | | - }, |
28 | | - datasource::MemTable, |
29 | | - error::{DataFusionError, Result}, |
30 | | -}; |
31 | | -use datafusion_ffi::table_provider::FFI_TableProvider; |
32 | | -use datafusion_ffi::udtf::FFI_TableFunction; |
33 | | -use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule}; |
34 | | - |
35 | | -/// In order to provide a test that demonstrates different sized record batches, |
36 | | -/// the first batch will have num_rows, the second batch num_rows+1, and so on. |
37 | | -#[pyclass(name = "MyTableProvider", module = "ffi_table_provider", subclass)] |
38 | | -#[derive(Clone)] |
39 | | -struct MyTableProvider { |
40 | | - num_cols: usize, |
41 | | - num_rows: usize, |
42 | | - num_batches: usize, |
43 | | -} |
44 | | - |
45 | | -fn create_record_batch( |
46 | | - schema: &Arc<Schema>, |
47 | | - num_cols: usize, |
48 | | - start_value: i32, |
49 | | - num_values: usize, |
50 | | -) -> Result<RecordBatch> { |
51 | | - let end_value = start_value + num_values as i32; |
52 | | - let row_values: Vec<i32> = (start_value..end_value).collect(); |
53 | | - |
54 | | - let columns: Vec<_> = (0..num_cols) |
55 | | - .map(|_| { |
56 | | - std::sync::Arc::new(arrow::array::Int32Array::from(row_values.clone())) as ArrayRef |
57 | | - }) |
58 | | - .collect(); |
59 | | - |
60 | | - RecordBatch::try_new(Arc::clone(schema), columns).map_err(DataFusionError::from) |
61 | | -} |
62 | | - |
63 | | -impl MyTableProvider { |
64 | | - fn create_table(&self) -> Result<MemTable> { |
65 | | - let fields: Vec<_> = (0..self.num_cols) |
66 | | - .map(|idx| (b'A' + idx as u8) as char) |
67 | | - .map(|col_name| Field::new(col_name, DataType::Int32, true)) |
68 | | - .collect(); |
69 | | - |
70 | | - let schema = Arc::new(Schema::new(fields)); |
71 | | - |
72 | | - let batches: Result<Vec<_>> = (0..self.num_batches) |
73 | | - .map(|batch_idx| { |
74 | | - let start_value = batch_idx * self.num_rows; |
75 | | - create_record_batch( |
76 | | - &schema, |
77 | | - self.num_cols, |
78 | | - start_value as i32, |
79 | | - self.num_rows + batch_idx, |
80 | | - ) |
81 | | - }) |
82 | | - .collect(); |
83 | | - |
84 | | - MemTable::try_new(schema, vec![batches?]) |
85 | | - } |
86 | | -} |
87 | | - |
88 | | -#[pymethods] |
89 | | -impl MyTableProvider { |
90 | | - #[new] |
91 | | - fn new(num_cols: usize, num_rows: usize, num_batches: usize) -> Self { |
92 | | - Self { |
93 | | - num_cols, |
94 | | - num_rows, |
95 | | - num_batches, |
96 | | - } |
97 | | - } |
98 | | - |
99 | | - fn __datafusion_table_provider__<'py>( |
100 | | - &self, |
101 | | - py: Python<'py>, |
102 | | - ) -> PyResult<Bound<'py, PyCapsule>> { |
103 | | - let name = CString::new("datafusion_table_provider").unwrap(); |
104 | | - |
105 | | - let provider = self |
106 | | - .create_table() |
107 | | - .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; |
108 | | - let provider = FFI_TableProvider::new(Arc::new(provider), false, None); |
109 | | - |
110 | | - PyCapsule::new_bound(py, provider, Some(name.clone())) |
111 | | - } |
112 | | -} |
113 | | - |
114 | | -#[pyclass(name = "MyTableFunction", module = "ffi_table_provider", subclass)] |
115 | | -#[derive(Debug, Clone)] |
116 | | -struct MyTableFunction {} |
117 | | - |
118 | | -#[pymethods] |
119 | | -impl MyTableFunction { |
120 | | - #[new] |
121 | | - fn new() -> Self { |
122 | | - Self {} |
123 | | - } |
124 | | - |
125 | | - fn __datafusion_table_function__<'py>( |
126 | | - &self, |
127 | | - py: Python<'py>, |
128 | | - ) -> PyResult<Bound<'py, PyCapsule>> { |
129 | | - let name = cr"datafusion_table_function".into(); |
130 | | - |
131 | | - let func = self.clone(); |
132 | | - let provider = FFI_TableFunction::new(Arc::new(func), None); |
133 | | - |
134 | | - PyCapsule::new(py, provider, Some(name)) |
135 | | - } |
136 | | -} |
137 | | - |
138 | | -impl TableFunctionImpl for MyTableFunction { |
139 | | - fn call(&self, args: &[Expr]) -> Result<Arc<dyn TableProvider>> { |
140 | | - let provider = MyTableProvider::new(10, 3, 2).create_table()?; |
141 | | - Ok(Arc::new(provider)) |
142 | | - } |
143 | | -} |
| 22 | +pub(crate) mod table_function; |
| 23 | +pub(crate) mod table_provider; |
144 | 24 |
|
145 | 25 | #[pymodule] |
146 | | -fn ffi_table_provider(m: &Bound<'_, PyModule>) -> PyResult<()> { |
| 26 | +fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { |
147 | 27 | m.add_class::<MyTableProvider>()?; |
148 | 28 | m.add_class::<MyTableFunction>()?; |
149 | 29 | Ok(()) |
|
0 commit comments