 use futures::StreamExt;
 use std::any::Any;
+use std::fmt::{self, Debug, Display};
 use std::sync::Arc;
 
 use arrow::datatypes::SchemaRef;
@@ -30,11 +31,11 @@ use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::common;
 use crate::physical_plan::common::AbortOnDropSingle;
+use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::memory::MemoryExec;
-use crate::physical_plan::memory::MemoryWriteExec;
 use crate::physical_plan::ExecutionPlan;
+use crate::physical_plan::{common, SendableRecordBatchStream};
 use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
 
 /// Type alias for partition data
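Note: the `PartitionData` alias named by this context line is not shown in the hunk. Given the `target.write().await.append(..)` call in `write_all` below, it is presumably a tokio-locked batch vector; a sketch of that assumed shape:

```rust
// Assumed definition of PartitionData (not part of this diff): each
// partition is a shared, async-lockable vector of Arrow record batches.
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use tokio::sync::RwLock;

pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
```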
@@ -164,7 +165,8 @@ impl TableProvider for MemTable {
         )?))
     }
 
-    /// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    /// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    ///
     /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
     ///
     /// # Arguments
@@ -174,7 +176,7 @@ impl TableProvider for MemTable {
     ///
     /// # Returns
     ///
-    /// * A `Result` indicating success or failure.
+    /// * A plan that returns the number of rows written.
     async fn insert_into(
         &self,
         _state: &SessionState,
@@ -187,27 +189,61 @@ impl TableProvider for MemTable {
                 "Inserting query must have the same schema with the table.".to_string(),
             ));
         }
+        let sink = Arc::new(MemSink::new(self.batches.clone()));
+        Ok(Arc::new(InsertExec::new(input, sink)))
+    }
+}
 
-        if self.batches.is_empty() {
-            return Err(DataFusionError::Plan(
-                "The table must have partitions.".to_string(),
-            ));
+/// Implements writing to a [`MemTable`]
+struct MemSink {
+    /// Target locations for writing data
+    batches: Vec<PartitionData>,
+}
+
+impl Debug for MemSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("MemSink")
+            .field("num_partitions", &self.batches.len())
+            .finish()
+    }
+}
+
+impl Display for MemSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let partition_count = self.batches.len();
+        write!(f, "MemoryTable (partitions={partition_count})")
+    }
+}
+
+impl MemSink {
+    fn new(batches: Vec<PartitionData>) -> Self {
+        Self { batches }
+    }
+}
+
+#[async_trait]
+impl DataSink for MemSink {
+    async fn write_all(&self, mut data: SendableRecordBatchStream) -> Result<u64> {
+        let num_partitions = self.batches.len();
+
+        // Buffer the incoming batches round-robin into num_partitions buckets
+        let mut new_batches = vec![vec![]; num_partitions];
+        let mut i = 0;
+        let mut row_count = 0;
+        while let Some(batch) = data.next().await.transpose()? {
+            row_count += batch.num_rows();
+            new_batches[i].push(batch);
+            i = (i + 1) % num_partitions;
         }
 
-        let input = if self.batches.len() > 1 {
-            Arc::new(RepartitionExec::try_new(
-                input,
-                Partitioning::RoundRobinBatch(self.batches.len()),
-            )?)
-        } else {
-            input
-        };
+        // Write the buffered batches out to their target partitions
+        for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) {
+            // Append all the new batches in one go to minimize locking overhead
+            target.write().await.append(&mut batches);
+        }
 
-        Ok(Arc::new(MemoryWriteExec::try_new(
-            input,
-            self.batches.clone(),
-            self.schema.clone(),
-        )?))
+        Ok(row_count as u64)
     }
 }
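For readers without `crate::physical_plan::insert` at hand, the trait `MemSink` implements can be inferred from the impl above: an async, object-safe sink that drains a record batch stream and reports the rows written. A sketch of that assumed shape, not the crate's exact definition:

```rust
// Assumed shape of the DataSink trait, reconstructed from the MemSink
// impl above; the real definition lives in crate::physical_plan::insert.
use std::fmt::{Debug, Display};

use async_trait::async_trait;

use crate::error::Result;
use crate::physical_plan::SendableRecordBatchStream;

#[async_trait]
pub trait DataSink: Debug + Display + Send + Sync {
    /// Drain `data`, writing every batch to the sink; return the row count.
    async fn write_all(&self, data: SendableRecordBatchStream) -> Result<u64>;
}
```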
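The heart of `write_all` is round-robin distribution of the incoming batches across the table's partitions, replacing the `RepartitionExec` round-robin the old code set up. A self-contained illustration with integers standing in for batches (all names invented for the demo):

```rust
// Demo of the round-robin buffering write_all performs: item k lands in
// bucket k % num_partitions, matching the `i = (i + 1) % num_partitions`
// counter in the patch.
fn round_robin<T>(items: impl IntoIterator<Item = T>, num_partitions: usize) -> Vec<Vec<T>> {
    let mut buckets: Vec<Vec<T>> = (0..num_partitions).map(|_| Vec::new()).collect();
    for (k, item) in items.into_iter().enumerate() {
        buckets[k % num_partitions].push(item);
    }
    buckets
}

fn main() {
    // Six "batches" across two partitions land as [0, 2, 4] and [1, 3, 5].
    assert_eq!(round_robin(0..6, 2), vec![vec![0, 2, 4], vec![1, 3, 5]]);
}
```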
@@ -218,8 +254,8 @@ mod tests {
     use crate::from_slice::FromSlice;
     use crate::physical_plan::collect;
     use crate::prelude::SessionContext;
-    use arrow::array::Int32Array;
-    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::array::{AsArray, Int32Array};
+    use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
     use arrow::error::ArrowError;
     use datafusion_expr::LogicalPlanBuilder;
     use futures::StreamExt;
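The new `AsArray` and `UInt64Type` imports exist for the count extraction in the `extract_count` helper added further down. In isolation, the downcast pattern they enable looks like this (toy array, not from the tests):

```rust
// Downcasting a dynamically typed Arrow array to a concrete primitive
// array via the AsArray extension trait, as extract_count does below.
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, UInt64Array};
use arrow::datatypes::UInt64Type;

fn main() {
    let dyn_array: ArrayRef = Arc::new(UInt64Array::from(vec![6u64]));
    // as_primitive panics if the underlying type is not UInt64.
    let col = dyn_array.as_primitive::<UInt64Type>();
    assert_eq!(col.value(0), 6);
}
```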
@@ -457,6 +493,11 @@ mod tests {
         initial_data: Vec<Vec<RecordBatch>>,
         inserted_data: Vec<Vec<RecordBatch>>,
     ) -> Result<Vec<Vec<RecordBatch>>> {
+        let expected_count: u64 = inserted_data
+            .iter()
+            .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
+            .sum();
+
         // Create a new session context
         let session_ctx = SessionContext::new();
         // Create and register the initial table with the provided schema and data
@@ -480,8 +521,8 @@ mod tests {
 
         // Execute the physical plan and collect the results
         let res = collect(plan, session_ctx.task_ctx()).await?;
-        // Ensure the result is empty after the insert operation
-        assert!(res.is_empty());
+        assert_eq!(extract_count(res), expected_count);
+
         // Read the data from the initial table and store it in a vector of partitions
         let mut partitions = vec![];
         for partition in initial_table.batches.iter() {
@@ -491,6 +532,34 @@ mod tests {
         Ok(partitions)
     }
 
+    /// Returns the count value from the results. For example, returns 6
+    /// given the following:
+    ///
+    /// ```text
+    /// +-------+
+    /// | count |
+    /// +-------+
+    /// | 6     |
+    /// +-------+
+    /// ```
+    fn extract_count(res: Vec<RecordBatch>) -> u64 {
+        assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
+        let batch = &res[0];
+        assert_eq!(
+            batch.num_columns(),
+            1,
+            "expected 1 column, got {}",
+            batch.num_columns()
+        );
+        let col = batch.column(0).as_primitive::<UInt64Type>();
+        assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
+        col.iter()
+            .next()
+            .expect("had value")
+            .expect("expected non null")
+    }
+
     // Test inserting a single batch of data into a single partition
     #[tokio::test]
     async fn test_insert_into_single_partition() -> Result<()> {
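Stepping back from the tests: the user-visible effect of this change is that an insert into a `MemTable` now returns a one-row count instead of an empty result. A hedged end-to-end sketch via DataFusion's SQL API (table name and values invented; assumes a version with SQL `INSERT INTO` wired to this path):

```rust
// Hypothetical end-to-end use of the new insert path: collecting the
// INSERT plan should yield one batch with a single UInt64 "count"
// column holding 3.
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t AS VALUES (1), (2), (3)").await?;
    let df = ctx.sql("INSERT INTO t VALUES (4), (5), (6)").await?;
    let results = df.collect().await?;
    datafusion::arrow::util::pretty::print_batches(&results)?;
    Ok(())
}
```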