Skip to content

Commit ff55eff

Browse files
INSERT INTO support for MemTable (#5520)
* Insert into memory table * Code simplifications * Minor comment refactor * Revamping tests and refactor code --------- Co-authored-by: Mehmet Ozan Kabak <[email protected]>
1 parent 3df1cb3 commit ff55eff

File tree

7 files changed

+292
-111
lines changed

7 files changed

+292
-111
lines changed

datafusion/core/src/datasource/datasource.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::any::Any;
2121
use std::sync::Arc;
2222

2323
use async_trait::async_trait;
24-
use datafusion_common::Statistics;
24+
use datafusion_common::{DataFusionError, Statistics};
2525
use datafusion_expr::{CreateExternalTable, LogicalPlan};
2626
pub use datafusion_expr::{TableProviderFilterPushDown, TableType};
2727

@@ -97,6 +97,16 @@ pub trait TableProvider: Sync + Send {
9797
fn statistics(&self) -> Option<Statistics> {
9898
None
9999
}
100+
101+
/// Insert into this table
102+
async fn insert_into(
103+
&self,
104+
_state: &SessionState,
105+
_input: &LogicalPlan,
106+
) -> Result<()> {
107+
let msg = "Insertion not implemented for this table".to_owned();
108+
Err(DataFusionError::NotImplemented(msg))
109+
}
100110
}
101111

102112
/// A factory which creates [`TableProvider`]s at runtime given a URL.

datafusion/core/src/datasource/memory.rs

Lines changed: 219 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,22 @@
1919
//! queried by DataFusion. This allows data to be pre-loaded into memory and then
2020
//! repeatedly queried without incurring additional file I/O overhead.
2121
22-
use futures::StreamExt;
22+
use futures::{StreamExt, TryStreamExt};
2323
use std::any::Any;
2424
use std::sync::Arc;
2525

2626
use arrow::datatypes::SchemaRef;
2727
use arrow::record_batch::RecordBatch;
2828
use async_trait::async_trait;
29+
use datafusion_expr::LogicalPlan;
30+
use tokio::sync::RwLock;
31+
use tokio::task;
2932

3033
use crate::datasource::{TableProvider, TableType};
3134
use crate::error::{DataFusionError, Result};
3235
use crate::execution::context::SessionState;
3336
use crate::logical_expr::Expr;
37+
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
3438
use crate::physical_plan::common;
3539
use crate::physical_plan::common::AbortOnDropSingle;
3640
use crate::physical_plan::memory::MemoryExec;
@@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
4145
#[derive(Debug)]
4246
pub struct MemTable {
4347
schema: SchemaRef,
44-
batches: Vec<Vec<RecordBatch>>,
48+
batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
4549
}
4650

4751
impl MemTable {
@@ -54,7 +58,7 @@ impl MemTable {
5458
{
5559
Ok(Self {
5660
schema,
57-
batches: partitions,
61+
batches: Arc::new(RwLock::new(partitions)),
5862
})
5963
} else {
6064
Err(DataFusionError::Plan(
@@ -143,22 +147,102 @@ impl TableProvider for MemTable {
143147
_filters: &[Expr],
144148
_limit: Option<usize>,
145149
) -> Result<Arc<dyn ExecutionPlan>> {
150+
let batches = &self.batches.read().await;
146151
Ok(Arc::new(MemoryExec::try_new(
147-
&self.batches.clone(),
152+
batches,
148153
self.schema(),
149154
projection.cloned(),
150155
)?))
151156
}
157+
158+
/// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
159+
/// The `LogicalPlan` must have the same schema as this `MemTable`.
160+
///
161+
/// # Arguments
162+
///
163+
/// * `state` - The [SessionState] containing the context for executing the plan.
164+
/// * `input` - The [LogicalPlan] to execute and insert.
165+
///
166+
/// # Returns
167+
///
168+
/// * A `Result` indicating success or failure.
169+
async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
170+
// Create a physical plan from the logical plan.
171+
let plan = state.create_physical_plan(input).await?;
172+
173+
// Check that the schema of the plan matches the schema of this table.
174+
if !plan.schema().eq(&self.schema) {
175+
return Err(DataFusionError::Plan(
176+
"Inserting query must have the same schema with the table.".to_string(),
177+
));
178+
}
179+
180+
// Get the number of partitions in the plan and the table.
181+
let plan_partition_count = plan.output_partitioning().partition_count();
182+
let table_partition_count = self.batches.read().await.len();
183+
184+
// Adjust the plan as necessary to match the number of partitions in the table.
185+
let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
186+
== table_partition_count
187+
|| table_partition_count == 0
188+
{
189+
plan
190+
} else if table_partition_count == 1 {
191+
// If the table has only one partition, coalesce the partitions in the plan.
192+
Arc::new(CoalescePartitionsExec::new(plan))
193+
} else {
194+
// Otherwise, repartition the plan using a round-robin partitioning scheme.
195+
Arc::new(RepartitionExec::try_new(
196+
plan,
197+
Partitioning::RoundRobinBatch(table_partition_count),
198+
)?)
199+
};
200+
201+
// Get the task context from the session state.
202+
let task_ctx = state.task_ctx();
203+
204+
// Execute the plan and collect the results into batches.
205+
let mut tasks = vec![];
206+
for idx in 0..plan.output_partitioning().partition_count() {
207+
let stream = plan.execute(idx, task_ctx.clone())?;
208+
let handle = task::spawn(async move {
209+
stream.try_collect().await.map_err(DataFusionError::from)
210+
});
211+
tasks.push(AbortOnDropSingle::new(handle));
212+
}
213+
let results = futures::future::join_all(tasks)
214+
.await
215+
.into_iter()
216+
.map(|result| {
217+
result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
218+
})
219+
.collect::<Result<Vec<Vec<RecordBatch>>>>()?;
220+
221+
// Write the results into the table.
222+
let mut all_batches = self.batches.write().await;
223+
224+
if all_batches.is_empty() {
225+
*all_batches = results
226+
} else {
227+
for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
228+
batches.extend(result);
229+
}
230+
}
231+
232+
Ok(())
233+
}
152234
}
153235

154236
#[cfg(test)]
155237
mod tests {
156238
use super::*;
239+
use crate::datasource::provider_as_source;
157240
use crate::from_slice::FromSlice;
158241
use crate::prelude::SessionContext;
159242
use arrow::array::Int32Array;
160243
use arrow::datatypes::{DataType, Field, Schema};
161244
use arrow::error::ArrowError;
245+
use datafusion_expr::LogicalPlanBuilder;
162246
use futures::StreamExt;
163247
use std::collections::HashMap;
164248

@@ -388,4 +472,135 @@ mod tests {
388472

389473
Ok(())
390474
}
475+
476+
fn create_mem_table_scan(
477+
schema: SchemaRef,
478+
data: Vec<Vec<RecordBatch>>,
479+
) -> Result<Arc<LogicalPlan>> {
480+
// Convert the table into a provider so that it can be used in a query
481+
let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
482+
// Create a table scan logical plan to read from the table
483+
Ok(Arc::new(
484+
LogicalPlanBuilder::scan("source", provider, None)?.build()?,
485+
))
486+
}
487+
488+
fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
489+
// Create a new session context
490+
let session_ctx = SessionContext::new();
491+
// Create a new schema with one field called "a" of type Int32
492+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
493+
494+
// Create a new batch of data to insert into the table
495+
let batch = RecordBatch::try_new(
496+
schema.clone(),
497+
vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
498+
)?;
499+
Ok((session_ctx, schema, batch))
500+
}
501+
502+
#[tokio::test]
503+
async fn test_insert_into_single_partition() -> Result<()> {
504+
let (session_ctx, schema, batch) = create_initial_ctx()?;
505+
let initial_table = Arc::new(MemTable::try_new(
506+
schema.clone(),
507+
vec![vec![batch.clone()]],
508+
)?);
509+
// Create a table scan logical plan to read from the table
510+
let single_partition_table_scan =
511+
create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
512+
// Insert the data from the provider into the table
513+
initial_table
514+
.insert_into(&session_ctx.state(), &single_partition_table_scan)
515+
.await?;
516+
// Ensure that the table now contains two batches of data in the same partition
517+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
518+
519+
// Create a new provider with 2 partitions
520+
let multi_partition_table_scan = create_mem_table_scan(
521+
schema.clone(),
522+
vec![vec![batch.clone()], vec![batch]],
523+
)?;
524+
525+
// Insert the data from the provider into the table. We expect coalescing partitions.
526+
initial_table
527+
.insert_into(&session_ctx.state(), &multi_partition_table_scan)
528+
.await?;
529+
// Ensure that the table now contains 4 batches of data with only 1 partition
530+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
531+
assert_eq!(initial_table.batches.read().await.len(), 1);
532+
Ok(())
533+
}
534+
535+
#[tokio::test]
536+
async fn test_insert_into_multiple_partition() -> Result<()> {
537+
let (session_ctx, schema, batch) = create_initial_ctx()?;
538+
// create a memory table with two partitions, each having one batch with the same data
539+
let initial_table = Arc::new(MemTable::try_new(
540+
schema.clone(),
541+
vec![vec![batch.clone()], vec![batch.clone()]],
542+
)?);
543+
544+
// scan a data source provider from a memory table with a single partition
545+
let single_partition_table_scan = create_mem_table_scan(
546+
schema.clone(),
547+
vec![vec![batch.clone(), batch.clone()]],
548+
)?;
549+
550+
// insert the data from the 1 partition data source provider into the initial table
551+
initial_table
552+
.insert_into(&session_ctx.state(), &single_partition_table_scan)
553+
.await?;
554+
555+
// We expect round robin repartition here, each partition gets 1 batch.
556+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
557+
assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);
558+
559+
// scan a data source provider from a memory table with 2 partition
560+
let multi_partition_table_scan = create_mem_table_scan(
561+
schema.clone(),
562+
vec![vec![batch.clone()], vec![batch]],
563+
)?;
564+
// We expect one-to-one partition mapping.
565+
initial_table
566+
.insert_into(&session_ctx.state(), &multi_partition_table_scan)
567+
.await?;
568+
// Ensure that the table now contains 3 batches of data with 2 partitions.
569+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
570+
assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
571+
Ok(())
572+
}
573+
574+
#[tokio::test]
575+
async fn test_insert_into_empty_table() -> Result<()> {
576+
let (session_ctx, schema, batch) = create_initial_ctx()?;
577+
// create empty memory table
578+
let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
579+
580+
// scan a data source provider from a memory table with a single partition
581+
let single_partition_table_scan = create_mem_table_scan(
582+
schema.clone(),
583+
vec![vec![batch.clone(), batch.clone()]],
584+
)?;
585+
586+
// insert the data from the 1 partition data source provider into the initial table
587+
initial_table
588+
.insert_into(&session_ctx.state(), &single_partition_table_scan)
589+
.await?;
590+
591+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
592+
593+
// scan a data source provider from a memory table with 2 partition
594+
let single_partition_table_scan = create_mem_table_scan(
595+
schema.clone(),
596+
vec![vec![batch.clone()], vec![batch]],
597+
)?;
598+
// We expect coalesce partitions here.
599+
initial_table
600+
.insert_into(&session_ctx.state(), &single_partition_table_scan)
601+
.await?;
602+
// Ensure that the table now contains 3 batches of data with 2 partitions.
603+
assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
604+
Ok(())
605+
}
391606
}

datafusion/core/src/execution/context.rs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use crate::{
3131
optimizer::PhysicalOptimizerRule,
3232
},
3333
};
34-
use datafusion_expr::{DescribeTable, StringifiedPlan};
34+
use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp};
3535
pub use datafusion_physical_expr::execution_props::ExecutionProps;
3636
use datafusion_physical_expr::var_provider::is_system_variables;
3737
use parking_lot::RwLock;
@@ -308,7 +308,8 @@ impl SessionContext {
308308

309309
/// Creates a [`DataFrame`] that will execute a SQL query.
310310
///
311-
/// Note: This API implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in-memory
311+
/// Note: This API implements DDL statements such as `CREATE TABLE` and
312+
/// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
312313
/// default implementations.
313314
///
314315
/// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which
@@ -318,6 +319,24 @@ impl SessionContext {
318319
let plan = self.state().create_logical_plan(sql).await?;
319320

320321
match plan {
322+
LogicalPlan::Dml(DmlStatement {
323+
table_name,
324+
op: WriteOp::Insert,
325+
input,
326+
..
327+
}) => {
328+
if self.table_exist(&table_name)? {
329+
let name = table_name.table();
330+
let provider = self.table_provider(name).await?;
331+
provider.insert_into(&self.state(), &input).await?;
332+
} else {
333+
return Err(DataFusionError::Execution(format!(
334+
"Table '{}' does not exist",
335+
table_name
336+
)));
337+
}
338+
self.return_empty_dataframe()
339+
}
321340
LogicalPlan::CreateExternalTable(cmd) => {
322341
self.create_external_table(&cmd).await
323342
}

0 commit comments

Comments
 (0)