19
19
//! queried by DataFusion. This allows data to be pre-loaded into memory and then
20
20
//! repeatedly queried without incurring additional file I/O overhead.
21
21
22
- use futures:: StreamExt ;
22
+ use futures:: { StreamExt , TryStreamExt } ;
23
23
use std:: any:: Any ;
24
24
use std:: sync:: Arc ;
25
25
26
26
use arrow:: datatypes:: SchemaRef ;
27
27
use arrow:: record_batch:: RecordBatch ;
28
28
use async_trait:: async_trait;
29
+ use datafusion_expr:: LogicalPlan ;
30
+ use tokio:: sync:: RwLock ;
31
+ use tokio:: task;
29
32
30
33
use crate :: datasource:: { TableProvider , TableType } ;
31
34
use crate :: error:: { DataFusionError , Result } ;
32
35
use crate :: execution:: context:: SessionState ;
33
36
use crate :: logical_expr:: Expr ;
37
+ use crate :: physical_plan:: coalesce_partitions:: CoalescePartitionsExec ;
34
38
use crate :: physical_plan:: common;
35
39
use crate :: physical_plan:: common:: AbortOnDropSingle ;
36
40
use crate :: physical_plan:: memory:: MemoryExec ;
@@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
41
45
#[ derive( Debug ) ]
42
46
pub struct MemTable {
43
47
schema : SchemaRef ,
44
- batches : Vec < Vec < RecordBatch > > ,
48
+ batches : Arc < RwLock < Vec < Vec < RecordBatch > > > > ,
45
49
}
46
50
47
51
impl MemTable {
@@ -54,7 +58,7 @@ impl MemTable {
54
58
{
55
59
Ok ( Self {
56
60
schema,
57
- batches : partitions,
61
+ batches : Arc :: new ( RwLock :: new ( partitions) ) ,
58
62
} )
59
63
} else {
60
64
Err ( DataFusionError :: Plan (
@@ -143,22 +147,102 @@ impl TableProvider for MemTable {
143
147
_filters : & [ Expr ] ,
144
148
_limit : Option < usize > ,
145
149
) -> Result < Arc < dyn ExecutionPlan > > {
150
+ let batches = & self . batches . read ( ) . await ;
146
151
Ok ( Arc :: new ( MemoryExec :: try_new (
147
- & self . batches . clone ( ) ,
152
+ batches,
148
153
self . schema ( ) ,
149
154
projection. cloned ( ) ,
150
155
) ?) )
151
156
}
157
+
158
+ /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
159
+ /// The `LogicalPlan` must have the same schema as this `MemTable`.
160
+ ///
161
+ /// # Arguments
162
+ ///
163
+ /// * `state` - The [SessionState] containing the context for executing the plan.
164
+ /// * `input` - The [LogicalPlan] to execute and insert.
165
+ ///
166
+ /// # Returns
167
+ ///
168
+ /// * A `Result` indicating success or failure.
169
+ async fn insert_into ( & self , state : & SessionState , input : & LogicalPlan ) -> Result < ( ) > {
170
+ // Create a physical plan from the logical plan.
171
+ let plan = state. create_physical_plan ( input) . await ?;
172
+
173
+ // Check that the schema of the plan matches the schema of this table.
174
+ if !plan. schema ( ) . eq ( & self . schema ) {
175
+ return Err ( DataFusionError :: Plan (
176
+ "Inserting query must have the same schema with the table." . to_string ( ) ,
177
+ ) ) ;
178
+ }
179
+
180
+ // Get the number of partitions in the plan and the table.
181
+ let plan_partition_count = plan. output_partitioning ( ) . partition_count ( ) ;
182
+ let table_partition_count = self . batches . read ( ) . await . len ( ) ;
183
+
184
+ // Adjust the plan as necessary to match the number of partitions in the table.
185
+ let plan: Arc < dyn ExecutionPlan > = if plan_partition_count
186
+ == table_partition_count
187
+ || table_partition_count == 0
188
+ {
189
+ plan
190
+ } else if table_partition_count == 1 {
191
+ // If the table has only one partition, coalesce the partitions in the plan.
192
+ Arc :: new ( CoalescePartitionsExec :: new ( plan) )
193
+ } else {
194
+ // Otherwise, repartition the plan using a round-robin partitioning scheme.
195
+ Arc :: new ( RepartitionExec :: try_new (
196
+ plan,
197
+ Partitioning :: RoundRobinBatch ( table_partition_count) ,
198
+ ) ?)
199
+ } ;
200
+
201
+ // Get the task context from the session state.
202
+ let task_ctx = state. task_ctx ( ) ;
203
+
204
+ // Execute the plan and collect the results into batches.
205
+ let mut tasks = vec ! [ ] ;
206
+ for idx in 0 ..plan. output_partitioning ( ) . partition_count ( ) {
207
+ let stream = plan. execute ( idx, task_ctx. clone ( ) ) ?;
208
+ let handle = task:: spawn ( async move {
209
+ stream. try_collect ( ) . await . map_err ( DataFusionError :: from)
210
+ } ) ;
211
+ tasks. push ( AbortOnDropSingle :: new ( handle) ) ;
212
+ }
213
+ let results = futures:: future:: join_all ( tasks)
214
+ . await
215
+ . into_iter ( )
216
+ . map ( |result| {
217
+ result. map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) ) ?
218
+ } )
219
+ . collect :: < Result < Vec < Vec < RecordBatch > > > > ( ) ?;
220
+
221
+ // Write the results into the table.
222
+ let mut all_batches = self . batches . write ( ) . await ;
223
+
224
+ if all_batches. is_empty ( ) {
225
+ * all_batches = results
226
+ } else {
227
+ for ( batches, result) in all_batches. iter_mut ( ) . zip ( results. into_iter ( ) ) {
228
+ batches. extend ( result) ;
229
+ }
230
+ }
231
+
232
+ Ok ( ( ) )
233
+ }
152
234
}
153
235
154
236
#[ cfg( test) ]
155
237
mod tests {
156
238
use super :: * ;
239
+ use crate :: datasource:: provider_as_source;
157
240
use crate :: from_slice:: FromSlice ;
158
241
use crate :: prelude:: SessionContext ;
159
242
use arrow:: array:: Int32Array ;
160
243
use arrow:: datatypes:: { DataType , Field , Schema } ;
161
244
use arrow:: error:: ArrowError ;
245
+ use datafusion_expr:: LogicalPlanBuilder ;
162
246
use futures:: StreamExt ;
163
247
use std:: collections:: HashMap ;
164
248
@@ -388,4 +472,135 @@ mod tests {
388
472
389
473
Ok ( ( ) )
390
474
}
475
+
476
+ fn create_mem_table_scan (
477
+ schema : SchemaRef ,
478
+ data : Vec < Vec < RecordBatch > > ,
479
+ ) -> Result < Arc < LogicalPlan > > {
480
+ // Convert the table into a provider so that it can be used in a query
481
+ let provider = provider_as_source ( Arc :: new ( MemTable :: try_new ( schema, data) ?) ) ;
482
+ // Create a table scan logical plan to read from the table
483
+ Ok ( Arc :: new (
484
+ LogicalPlanBuilder :: scan ( "source" , provider, None ) ?. build ( ) ?,
485
+ ) )
486
+ }
487
+
488
+ fn create_initial_ctx ( ) -> Result < ( SessionContext , SchemaRef , RecordBatch ) > {
489
+ // Create a new session context
490
+ let session_ctx = SessionContext :: new ( ) ;
491
+ // Create a new schema with one field called "a" of type Int32
492
+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ) ;
493
+
494
+ // Create a new batch of data to insert into the table
495
+ let batch = RecordBatch :: try_new (
496
+ schema. clone ( ) ,
497
+ vec ! [ Arc :: new( Int32Array :: from_slice( [ 1 , 2 , 3 ] ) ) ] ,
498
+ ) ?;
499
+ Ok ( ( session_ctx, schema, batch) )
500
+ }
501
+
502
+ #[ tokio:: test]
503
+ async fn test_insert_into_single_partition ( ) -> Result < ( ) > {
504
+ let ( session_ctx, schema, batch) = create_initial_ctx ( ) ?;
505
+ let initial_table = Arc :: new ( MemTable :: try_new (
506
+ schema. clone ( ) ,
507
+ vec ! [ vec![ batch. clone( ) ] ] ,
508
+ ) ?) ;
509
+ // Create a table scan logical plan to read from the table
510
+ let single_partition_table_scan =
511
+ create_mem_table_scan ( schema. clone ( ) , vec ! [ vec![ batch. clone( ) ] ] ) ?;
512
+ // Insert the data from the provider into the table
513
+ initial_table
514
+ . insert_into ( & session_ctx. state ( ) , & single_partition_table_scan)
515
+ . await ?;
516
+ // Ensure that the table now contains two batches of data in the same partition
517
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 2 ) ;
518
+
519
+ // Create a new provider with 2 partitions
520
+ let multi_partition_table_scan = create_mem_table_scan (
521
+ schema. clone ( ) ,
522
+ vec ! [ vec![ batch. clone( ) ] , vec![ batch] ] ,
523
+ ) ?;
524
+
525
+ // Insert the data from the provider into the table. We expect coalescing partitions.
526
+ initial_table
527
+ . insert_into ( & session_ctx. state ( ) , & multi_partition_table_scan)
528
+ . await ?;
529
+ // Ensure that the table now contains 4 batches of data with only 1 partition
530
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 4 ) ;
531
+ assert_eq ! ( initial_table. batches. read( ) . await . len( ) , 1 ) ;
532
+ Ok ( ( ) )
533
+ }
534
+
535
+ #[ tokio:: test]
536
+ async fn test_insert_into_multiple_partition ( ) -> Result < ( ) > {
537
+ let ( session_ctx, schema, batch) = create_initial_ctx ( ) ?;
538
+ // create a memory table with two partitions, each having one batch with the same data
539
+ let initial_table = Arc :: new ( MemTable :: try_new (
540
+ schema. clone ( ) ,
541
+ vec ! [ vec![ batch. clone( ) ] , vec![ batch. clone( ) ] ] ,
542
+ ) ?) ;
543
+
544
+ // scan a data source provider from a memory table with a single partition
545
+ let single_partition_table_scan = create_mem_table_scan (
546
+ schema. clone ( ) ,
547
+ vec ! [ vec![ batch. clone( ) , batch. clone( ) ] ] ,
548
+ ) ?;
549
+
550
+ // insert the data from the 1 partition data source provider into the initial table
551
+ initial_table
552
+ . insert_into ( & session_ctx. state ( ) , & single_partition_table_scan)
553
+ . await ?;
554
+
555
+ // We expect round robin repartition here, each partition gets 1 batch.
556
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 2 ) ;
557
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 1 ) . unwrap( ) . len( ) , 2 ) ;
558
+
559
+ // scan a data source provider from a memory table with 2 partition
560
+ let multi_partition_table_scan = create_mem_table_scan (
561
+ schema. clone ( ) ,
562
+ vec ! [ vec![ batch. clone( ) ] , vec![ batch] ] ,
563
+ ) ?;
564
+ // We expect one-to-one partition mapping.
565
+ initial_table
566
+ . insert_into ( & session_ctx. state ( ) , & multi_partition_table_scan)
567
+ . await ?;
568
+ // Ensure that the table now contains 3 batches of data with 2 partitions.
569
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 3 ) ;
570
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 1 ) . unwrap( ) . len( ) , 3 ) ;
571
+ Ok ( ( ) )
572
+ }
573
+
574
+ #[ tokio:: test]
575
+ async fn test_insert_into_empty_table ( ) -> Result < ( ) > {
576
+ let ( session_ctx, schema, batch) = create_initial_ctx ( ) ?;
577
+ // create empty memory table
578
+ let initial_table = Arc :: new ( MemTable :: try_new ( schema. clone ( ) , vec ! [ ] ) ?) ;
579
+
580
+ // scan a data source provider from a memory table with a single partition
581
+ let single_partition_table_scan = create_mem_table_scan (
582
+ schema. clone ( ) ,
583
+ vec ! [ vec![ batch. clone( ) , batch. clone( ) ] ] ,
584
+ ) ?;
585
+
586
+ // insert the data from the 1 partition data source provider into the initial table
587
+ initial_table
588
+ . insert_into ( & session_ctx. state ( ) , & single_partition_table_scan)
589
+ . await ?;
590
+
591
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 2 ) ;
592
+
593
+ // scan a data source provider from a memory table with 2 partition
594
+ let single_partition_table_scan = create_mem_table_scan (
595
+ schema. clone ( ) ,
596
+ vec ! [ vec![ batch. clone( ) ] , vec![ batch] ] ,
597
+ ) ?;
598
+ // We expect coalesce partitions here.
599
+ initial_table
600
+ . insert_into ( & session_ctx. state ( ) , & single_partition_table_scan)
601
+ . await ?;
602
+ // Ensure that the table now contains 3 batches of data with 2 partitions.
603
+ assert_eq ! ( initial_table. batches. read( ) . await . get( 0 ) . unwrap( ) . len( ) , 4 ) ;
604
+ Ok ( ( ) )
605
+ }
391
606
}
0 commit comments