Commit b16a403

INSERT returns number of rows written, add InsertExec to handle common case. (#6354)
* Add InsertExec, port in memory insert to use DataSink
* fix: clippy
* Add Display to Sink and update plans
* Add additional verification that insert made it to the table
* Add test to ensure the sort order is maintained for insert query
* Ensure the sort order is maintained for insert query, test for same
1 parent 98839db commit b16a403
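
With this change, collecting the results of an INSERT plan yields the number of rows written instead of an empty result set. A minimal sketch of the user-visible behavior, assuming a `SessionContext` with a registered table `t` (the table name and values are illustrative, not from this commit):

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch: an INSERT now produces a single batch with one UInt64
// column named "count" holding the number of rows written.
async fn insert_reports_count(ctx: &SessionContext) -> Result<()> {
    let df = ctx.sql("INSERT INTO t VALUES (1), (2), (3)").await?;
    let batches = df.collect().await?;
    // +-------+
    // | count |
    // +-------+
    // | 3     |
    // +-------+
    assert_eq!(batches.len(), 1);
    Ok(())
}
```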

File tree

6 files changed: +568 -575 lines changed

datafusion/core/src/datasource/datasource.rs

Lines changed: 20 additions & 1 deletion
@@ -98,7 +98,26 @@ pub trait TableProvider: Sync + Send {
         None
     }

-    /// Insert into this table
+    /// Return an [`ExecutionPlan`] to insert data into this table, if
+    /// supported.
+    ///
+    /// The returned plan should return a single row in a UInt64
+    /// column called "count" such as the following
+    ///
+    /// ```text
+    /// +-------+,
+    /// | count |,
+    /// +-------+,
+    /// | 6     |,
+    /// +-------+,
+    /// ```
+    ///
+    /// # See Also
+    ///
+    /// See [`InsertExec`] for the common pattern of inserting a
+    /// single stream of `RecordBatch`es.
+    ///
+    /// [`InsertExec`]: crate::physical_plan::insert::InsertExec
     async fn insert_into(
         &self,
         _state: &SessionState,
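
The "common pattern" referenced in the new documentation, sketched for a hypothetical third-party `TableProvider`: wrap the input plan and a `DataSink` in `InsertExec`. The constructor shape is taken from the memory.rs changes below; `sink` stands in for any `DataSink` implementation:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::SessionState;
use datafusion::physical_plan::insert::{DataSink, InsertExec};
use datafusion::physical_plan::ExecutionPlan;

// Hypothetical insert_into body for a custom provider: InsertExec
// drives `input`, streams its batches into `sink`, and returns the
// single-row "count" plan described in the docs above.
fn make_insert_plan(
    _state: &SessionState,
    input: Arc<dyn ExecutionPlan>,
    sink: Arc<dyn DataSink>,
) -> Result<Arc<dyn ExecutionPlan>> {
    Ok(Arc::new(InsertExec::new(input, sink)))
}
```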

datafusion/core/src/datasource/memory.rs

Lines changed: 94 additions & 25 deletions
@@ -19,6 +19,7 @@

 use futures::StreamExt;
 use std::any::Any;
+use std::fmt::{self, Debug, Display};
 use std::sync::Arc;

 use arrow::datatypes::SchemaRef;
@@ -30,11 +31,11 @@ use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::common;
 use crate::physical_plan::common::AbortOnDropSingle;
+use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::memory::MemoryExec;
-use crate::physical_plan::memory::MemoryWriteExec;
 use crate::physical_plan::ExecutionPlan;
+use crate::physical_plan::{common, SendableRecordBatchStream};
 use crate::physical_plan::{repartition::RepartitionExec, Partitioning};

 /// Type alias for partition data
@@ -164,7 +165,8 @@ impl TableProvider for MemTable {
         )?))
     }

-    /// Inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    /// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
+    ///
     /// The [`ExecutionPlan`] must have the same schema as this [`MemTable`].
     ///
     /// # Arguments
@@ -174,7 +176,7 @@ impl TableProvider for MemTable {
     ///
     /// # Returns
     ///
-    /// * A `Result` indicating success or failure.
+    /// * A plan that returns the number of rows written.
     async fn insert_into(
         &self,
         _state: &SessionState,
@@ -187,27 +189,61 @@ impl TableProvider for MemTable {
                 "Inserting query must have the same schema with the table.".to_string(),
             ));
         }
+        let sink = Arc::new(MemSink::new(self.batches.clone()));
+        Ok(Arc::new(InsertExec::new(input, sink)))
+    }
+}

-        if self.batches.is_empty() {
-            return Err(DataFusionError::Plan(
-                "The table must have partitions.".to_string(),
-            ));
+/// Implements for writing to a [`MemTable`]
+struct MemSink {
+    /// Target locations for writing data
+    batches: Vec<PartitionData>,
+}
+
+impl Debug for MemSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("MemSink")
+            .field("num_partitions", &self.batches.len())
+            .finish()
+    }
+}
+
+impl Display for MemSink {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let partition_count = self.batches.len();
+        write!(f, "MemoryTable (partitions={partition_count})")
+    }
+}
+
+impl MemSink {
+    fn new(batches: Vec<PartitionData>) -> Self {
+        Self { batches }
+    }
+}
+
+#[async_trait]
+impl DataSink for MemSink {
+    async fn write_all(&self, mut data: SendableRecordBatchStream) -> Result<u64> {
+        let num_partitions = self.batches.len();
+
+        // buffer up the data round robin style into num_partitions
+
+        let mut new_batches = vec![vec![]; num_partitions];
+        let mut i = 0;
+        let mut row_count = 0;
+        while let Some(batch) = data.next().await.transpose()? {
+            row_count += batch.num_rows();
+            new_batches[i].push(batch);
+            i = (i + 1) % num_partitions;
         }

-        let input = if self.batches.len() > 1 {
-            Arc::new(RepartitionExec::try_new(
-                input,
-                Partitioning::RoundRobinBatch(self.batches.len()),
-            )?)
-        } else {
-            input
-        };
+        // write the outputs into the batches
+        for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) {
+            // Append all the new batches in one go to minimize locking overhead
+            target.write().await.append(&mut batches);
+        }

-        Ok(Arc::new(MemoryWriteExec::try_new(
-            input,
-            self.batches.clone(),
-            self.schema.clone(),
-        )?))
+        Ok(row_count as u64)
     }
 }
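
Two pieces of context for the hunk above, both inferred from usage rather than shown in this truncated diff: the `PartitionData` alias behind `target.write().await`, and the shape of the `DataSink` trait that `MemSink` implements:

```rust
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::error::Result;
use datafusion::physical_plan::SendableRecordBatchStream;
use tokio::sync::RwLock;

// Inferred: each MemTable partition is a shared, async-locked vector of
// batches, so write_all appends each partition's new batches under a
// single write lock.
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;

// Inferred trait shape (defined in physical_plan/insert.rs, one of the
// six changed files not shown here): consume the entire input stream
// and report the number of rows written.
#[async_trait]
pub trait DataSink: std::fmt::Display + std::fmt::Debug + Send + Sync {
    async fn write_all(&self, data: SendableRecordBatchStream) -> Result<u64>;
}
```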

@@ -218,8 +254,8 @@ mod tests {
     use crate::from_slice::FromSlice;
     use crate::physical_plan::collect;
     use crate::prelude::SessionContext;
-    use arrow::array::Int32Array;
-    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::array::{AsArray, Int32Array};
+    use arrow::datatypes::{DataType, Field, Schema, UInt64Type};
     use arrow::error::ArrowError;
     use datafusion_expr::LogicalPlanBuilder;
     use futures::StreamExt;
@@ -457,6 +493,11 @@
         initial_data: Vec<Vec<RecordBatch>>,
         inserted_data: Vec<Vec<RecordBatch>>,
     ) -> Result<Vec<Vec<RecordBatch>>> {
+        let expected_count: u64 = inserted_data
+            .iter()
+            .flat_map(|batches| batches.iter().map(|batch| batch.num_rows() as u64))
+            .sum();
+
         // Create a new session context
         let session_ctx = SessionContext::new();
         // Create and register the initial table with the provided schema and data
@@ -480,8 +521,8 @@

         // Execute the physical plan and collect the results
         let res = collect(plan, session_ctx.task_ctx()).await?;
-        // Ensure the result is empty after the insert operation
-        assert!(res.is_empty());
+        assert_eq!(extract_count(res), expected_count);
+
         // Read the data from the initial table and store it in a vector of partitions
         let mut partitions = vec![];
         for partition in initial_table.batches.iter() {
@@ -491,6 +532,34 @@
         Ok(partitions)
     }

+    /// Returns the value of results. For example, returns 6 given the following
+    ///
+    /// ```text
+    /// +-------+,
+    /// | count |,
+    /// +-------+,
+    /// | 6     |,
+    /// +-------+,
+    /// ```
+    fn extract_count(res: Vec<RecordBatch>) -> u64 {
+        assert_eq!(res.len(), 1, "expected one batch, got {}", res.len());
+        let batch = &res[0];
+        assert_eq!(
+            batch.num_columns(),
+            1,
+            "expected 1 column, got {}",
+            batch.num_columns()
+        );
+        let col = batch.column(0).as_primitive::<UInt64Type>();
+        assert_eq!(col.len(), 1, "expected 1 row, got {}", col.len());
+        let val = col
+            .iter()
+            .next()
+            .expect("had value")
+            .expect("expected non null");
+        val
+    }
+
     // Test inserting a single batch of data into a single partition
     #[tokio::test]
     async fn test_insert_into_single_partition() -> Result<()> {
