@@ -51,7 +51,7 @@ mod roundtrip_tests {
51
51
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
52
52
} ;
53
53
use crate :: logical_plan:: LogicalExtensionCodec ;
54
- use arrow:: datatypes:: Schema ;
54
+ use arrow:: datatypes:: { Schema , SchemaRef } ;
55
55
use arrow:: {
56
56
array:: ArrayRef ,
57
57
datatypes:: {
@@ -65,7 +65,7 @@ mod roundtrip_tests {
65
65
use datafusion:: prelude:: {
66
66
create_udf, CsvReadOptions , SessionConfig , SessionContext ,
67
67
} ;
68
- use datafusion:: test_util:: TestTableFactory ;
68
+ use datafusion:: test_util:: { TestTableFactory , TestTableProvider } ;
69
69
use datafusion_common:: { DFSchemaRef , DataFusionError , ScalarValue } ;
70
70
use datafusion_expr:: create_udaf;
71
71
use datafusion_expr:: expr:: { Between , BinaryExpr , Case , GroupingSet , Like } ;
@@ -81,6 +81,7 @@ mod roundtrip_tests {
81
81
use std:: fmt:: Debug ;
82
82
use std:: fmt:: Formatter ;
83
83
use std:: sync:: Arc ;
84
+ use datafusion:: datasource:: TableProvider ;
84
85
85
86
#[ cfg( feature = "json" ) ]
86
87
fn roundtrip_json_test ( proto : & protobuf:: LogicalExprNode ) {
@@ -135,22 +136,73 @@ mod roundtrip_tests {
135
136
Ok ( ( ) )
136
137
}
137
138
139
+ #[ derive( Clone , PartialEq , :: prost:: Message ) ]
140
+ pub struct TestTableProto {
141
+ /// URL of the table root
142
+ #[ prost( string, tag = "1" ) ]
143
+ pub url : String ,
144
+ }
145
+
146
+ #[ derive( Debug ) ]
147
+ pub struct TestTableProviderCodec { }
148
+
149
+ impl LogicalExtensionCodec for TestTableProviderCodec {
150
+ fn try_decode ( & self , buf : & [ u8 ] , inputs : & [ LogicalPlan ] , ctx : & SessionContext ) -> Result < Extension , DataFusionError > {
151
+ Err ( DataFusionError :: NotImplemented (
152
+ "No extension codec provided" . to_string ( ) ,
153
+ ) )
154
+ }
155
+
156
+ fn try_encode ( & self , node : & Extension , buf : & mut Vec < u8 > ) -> Result < ( ) , DataFusionError > {
157
+ Err ( DataFusionError :: NotImplemented (
158
+ "No extension codec provided" . to_string ( ) ,
159
+ ) )
160
+ }
161
+
162
+ fn try_decode_table_provider ( & self , buf : & [ u8 ] , schema : SchemaRef , ctx : & SessionContext ) -> Result < Arc < dyn TableProvider > , DataFusionError > {
163
+ let msg = TestTableProto :: decode ( buf)
164
+ . map_err ( |_| DataFusionError :: Internal ( "Error encoding test table" . to_string ( ) ) ) ?;
165
+ let state = ctx. state . read ( ) ;
166
+ let factory = state
167
+ . runtime_env
168
+ . table_factories
169
+ . get ( "testtable" )
170
+ . ok_or_else ( || {
171
+ DataFusionError :: Plan ( format ! (
172
+ "Unable to find testtable factory" ,
173
+ ) )
174
+ } ) ?;
175
+ let provider = ( * factory) . with_schema ( schema, msg. url . as_str ( ) ) ?;
176
+ Ok ( provider)
177
+ }
178
+
179
+ fn try_encode_table_provider ( & self , node : Arc < dyn TableProvider > , buf : & mut Vec < u8 > ) -> Result < ( ) , DataFusionError > {
180
+ let table = node. as_ref ( ) . as_any ( ) . downcast_ref :: < TestTableProvider > ( )
181
+ . ok_or ( DataFusionError :: Internal ( "Can't encode non-test tables" . to_string ( ) ) ) ?;
182
+ let msg = TestTableProto {
183
+ url : table. url . clone ( )
184
+ } ;
185
+ msg. encode ( buf) . map_err ( |_| DataFusionError :: Internal ( "Error encoding test table" . to_string ( ) ) )
186
+ }
187
+ }
188
+
138
189
#[ tokio:: test]
139
190
async fn roundtrip_custom_tables ( ) -> Result < ( ) , DataFusionError > {
140
191
let mut table_factories: HashMap < String , Arc < dyn TableProviderFactory > > =
141
192
HashMap :: new ( ) ;
142
- table_factories. insert ( "deltatable " . to_string ( ) , Arc :: new ( TestTableFactory { } ) ) ;
193
+ table_factories. insert ( "testtable " . to_string ( ) , Arc :: new ( TestTableFactory { } ) ) ;
143
194
let cfg = RuntimeConfig :: new ( ) . with_table_factories ( table_factories) ;
144
195
let env = RuntimeEnv :: new ( cfg) . unwrap ( ) ;
145
196
let ses = SessionConfig :: new ( ) ;
146
197
let ctx = SessionContext :: with_config_rt ( ses, Arc :: new ( env) ) ;
147
198
148
- let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';" ;
199
+ let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';" ;
149
200
ctx. sql ( sql) . await . unwrap ( ) ;
150
201
151
- let scan = ctx. table ( "dt" ) ?. to_logical_plan ( ) ?;
152
- let bytes = logical_plan_to_bytes ( & scan) ?;
153
- let logical_round_trip = logical_plan_from_bytes ( & bytes, & ctx) ?;
202
+ let codec = TestTableProviderCodec { } ;
203
+ let scan = ctx. table ( "t" ) ?. to_logical_plan ( ) ?;
204
+ let bytes = logical_plan_to_bytes_with_extension_codec ( & scan, & codec) ?;
205
+ let logical_round_trip = logical_plan_from_bytes_with_extension_codec ( & bytes, & ctx, & codec) ?;
154
206
assert_eq ! ( format!( "{:?}" , scan) , format!( "{:?}" , logical_round_trip) ) ;
155
207
Ok ( ( ) )
156
208
}
@@ -350,6 +402,18 @@ mod roundtrip_tests {
350
402
) )
351
403
}
352
404
}
405
+
406
+ fn try_decode_table_provider ( & self , buf : & [ u8 ] , schema : SchemaRef , ctx : & SessionContext ) -> Result < Arc < dyn TableProvider > , DataFusionError > {
407
+ Err ( DataFusionError :: Internal (
408
+ "unsupported plan type" . to_string ( ) ,
409
+ ) )
410
+ }
411
+
412
+ fn try_encode_table_provider ( & self , node : Arc < dyn TableProvider > , buf : & mut Vec < u8 > ) -> Result < ( ) , DataFusionError > {
413
+ Err ( DataFusionError :: Internal (
414
+ "unsupported plan type" . to_string ( ) ,
415
+ ) )
416
+ }
353
417
}
354
418
355
419
#[ test]
0 commit comments