@@ -24,6 +24,7 @@ use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
@@ -99,9 +100,12 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
-    fn create_dataframe(&mut self, partitions: Vec<Vec<RecordBatch>>) -> PyResult<PyDataFrame> {
-        let table = MemTable::try_new(partitions[0][0].schema(), partitions)
-            .map_err(DataFusionError::from)?;
+    fn create_dataframe(
+        &mut self,
+        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
+    ) -> PyResult<PyDataFrame> {
+        let schema = partitions.0[0][0].schema();
+        let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table
         // table name cannot start with numeric digit
@@ -136,10 +140,10 @@ impl PySessionContext {
     fn register_record_batches(
         &mut self,
         name: &str,
-        partitions: Vec<Vec<RecordBatch>>,
+        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
     ) -> PyResult<()> {
-        let schema = partitions[0][0].schema();
-        let table = MemTable::try_new(schema, partitions)?;
+        let schema = partitions.0[0][0].schema();
+        let table = MemTable::try_new(schema, partitions.0)?;
         self.ctx
             .register_table(name, Arc::new(table))
             .map_err(DataFusionError::from)?;
@@ -182,7 +186,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: PathBuf,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
         schema_infer_max_records: usize,
@@ -204,7 +208,7 @@ impl PySessionContext {
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension);
-        options.schema = schema.as_ref();
+        options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_csv(name, path, options);
         wait_for_future(py, result).map_err(DataFusionError::from)?;
@@ -277,7 +281,7 @@ impl PySessionContext {
     fn read_csv(
         &self,
         path: PathBuf,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
@@ -302,12 +306,17 @@ impl PySessionContext {
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension)
             .table_partition_cols(table_partition_cols);
-        options.schema = schema.as_ref();
 
-        let result = self.ctx.read_csv(path, options);
-        let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
-
-        Ok(df)
+        if let Some(py_schema) = schema {
+            options.schema = Some(&py_schema.0);
+            let result = self.ctx.read_csv(path, options);
+            let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+            Ok(df)
+        } else {
+            let result = self.ctx.read_csv(path, options);
+            let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+            Ok(df)
+        }
     }
 
     #[allow(clippy::too_many_arguments)]
@@ -346,14 +355,14 @@ impl PySessionContext {
     fn read_avro(
         &self,
         path: &str,
-        schema: Option<Schema>,
+        schema: Option<PyArrowType<Schema>>,
         table_partition_cols: Vec<String>,
         file_extension: &str,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let mut options = AvroReadOptions::default().table_partition_cols(table_partition_cols);
         options.file_extension = file_extension;
-        options.schema = schema.map(Arc::new);
+        options.schema = schema.map(|s| Arc::new(s.0));
 
         let result = self.ctx.read_avro(path, options);
         let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
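Note on the `.0` field access that appears throughout this change: `PyArrowType`, imported from `datafusion::arrow::pyarrow` in the first hunk, is a newtype wrapper (a one-field tuple struct) around an Arrow value, so the wrapped `Schema` or `Vec<Vec<RecordBatch>>` is reached through its `.0` field. The sketch below only illustrates that newtype pattern; `PyArrowLike`, `first_item`, and the string data are hypothetical stand-ins, not the real `PyArrowType`, whose pyo3 conversion impls live in the arrow crate.

// Hypothetical stand-in for the newtype pattern used by PyArrowType:
// a one-field tuple struct whose inner value is reached with `.0`.
struct PyArrowLike<T>(pub T);

// Mirrors `partitions.0[0][0].schema()` in the diff: unwrap the wrapper
// with `.0`, then index into the nested partitions as before.
fn first_item(partitions: PyArrowLike<Vec<Vec<String>>>) -> Option<String> {
    partitions.0.first()?.first().cloned()
}

fn main() {
    let wrapped = PyArrowLike(vec![vec!["schema-a".to_string()]]);
    assert_eq!(first_item(wrapped).as_deref(), Some("schema-a"));
}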