INSERT INTO support for MemTable #5520

Merged: 5 commits, Mar 14, 2023
12 changes: 11 additions & 1 deletion datafusion/core/src/datasource/datasource.rs
@@ -21,7 +21,7 @@ use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;
-use datafusion_common::Statistics;
+use datafusion_common::{DataFusionError, Statistics};
use datafusion_expr::{CreateExternalTable, LogicalPlan};
pub use datafusion_expr::{TableProviderFilterPushDown, TableType};

@@ -97,6 +97,16 @@ pub trait TableProvider: Sync + Send {
    fn statistics(&self) -> Option<Statistics> {
        None
    }

    /// Insert into this table
    async fn insert_into(
        &self,
        _state: &SessionState,
        _input: &LogicalPlan,
    ) -> Result<()> {
        let msg = "Insertion not implemented for this table".to_owned();
[Review comment, Contributor]: If the table has a name, it's likely better to return it to the user in the error message instead of this.

[Review comment, Contributor]: I don't think TableProvider has a way to fetch its own name yet 🤔

[Review comment, Contributor @comphead, Mar 9, 2023]: Imho, TableProvider should likely have its own name. I can investigate the purpose of having a name.

[Review comment, Contributor]: I think the name for a particular table provider is currently managed by its containing schema; this allows, for example, the same provider to be registered under different table names.

        Err(DataFusionError::NotImplemented(msg))
    }
}

/// A factory which creates [`TableProvider`]s at runtime given a URL.
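A note on the pattern: `insert_into` ships as a defaulted trait method, so existing read-only providers keep compiling and simply inherit the NotImplemented error, while writable providers such as MemTable override it. Below is a minimal, self-contained sketch of that pattern, assuming a binary crate with the async-trait and tokio dependencies; `Provider`, `Plan`, and `Error` are illustrative stand-ins, not DataFusion types.

use async_trait::async_trait;

#[derive(Debug)]
enum Error {
    NotImplemented(String),
}

type Result<T> = std::result::Result<T, Error>;

struct Plan; // stand-in for a logical plan

#[async_trait]
trait Provider: Sync + Send {
    /// Writable providers override this; the default refuses the write.
    async fn insert_into(&self, _input: &Plan) -> Result<()> {
        Err(Error::NotImplemented(
            "Insertion not implemented for this table".to_owned(),
        ))
    }
}

struct ReadOnly;

#[async_trait]
impl Provider for ReadOnly {} // inherits the default error

#[tokio::main]
async fn main() {
    let table = ReadOnly;
    // The un-overridden default surfaces as a NotImplemented error.
    assert!(matches!(
        table.insert_into(&Plan).await,
        Err(Error::NotImplemented(_))
    ));
}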
223 changes: 219 additions & 4 deletions datafusion/core/src/datasource/memory.rs
@@ -19,18 +19,22 @@
//! queried by DataFusion. This allows data to be pre-loaded into memory and then
//! repeatedly queried without incurring additional file I/O overhead.

-use futures::StreamExt;
+use futures::{StreamExt, TryStreamExt};
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::LogicalPlan;
use tokio::sync::RwLock;
use tokio::task;

use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::common;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::memory::MemoryExec;
@@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
#[derive(Debug)]
pub struct MemTable {
    schema: SchemaRef,
-    batches: Vec<Vec<RecordBatch>>,
+    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
}

impl MemTable {
@@ -54,7 +58,7 @@ impl MemTable {
        {
            Ok(Self {
                schema,
-                batches: partitions,
+                batches: Arc::new(RwLock::new(partitions)),
            })
        } else {
            Err(DataFusionError::Plan(
@@ -143,22 +147,102 @@ impl TableProvider for MemTable {
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let batches = &self.batches.read().await;
        Ok(Arc::new(MemoryExec::try_new(
-            &self.batches.clone(),
+            batches,
            self.schema(),
            projection.cloned(),
        )?))
    }

    /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
    /// The `LogicalPlan` must have the same schema as this `MemTable`.
    ///
    /// # Arguments
    ///
    /// * `state` - The [SessionState] containing the context for executing the plan.
    /// * `input` - The [LogicalPlan] to execute and insert.
    ///
    /// # Returns
    ///
    /// * A `Result` indicating success or failure.
    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
        // Create a physical plan from the logical plan.
        let plan = state.create_physical_plan(input).await?;

        // Check that the schema of the plan matches the schema of this table.
[Review comment, Contributor]: Could you please elaborate why exactly this check is needed?

[Review comment, Contributor]: What would the alternate behavior be? Pad missing columns with nulls?

[Review comment, Author @metesynnada, Mar 9, 2023]: The check was done with a defensive approach in mind. I am not sure how the two schemas would be different.

[Review comment, Contributor]: I think whatever is creating the plan should ensure that the incoming rows match the record batch (as in, I prefer this error approach).

        if !plan.schema().eq(&self.schema) {
            return Err(DataFusionError::Plan(
                "Inserting query must have the same schema as the table.".to_string(),
            ));
        }

        // Get the number of partitions in the plan and the table.
        let plan_partition_count = plan.output_partitioning().partition_count();
        let table_partition_count = self.batches.read().await.len();

        // Adjust the plan as necessary to match the number of partitions in the table.
        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
            == table_partition_count
            || table_partition_count == 0
        {
            plan
        } else if table_partition_count == 1 {
            // If the table has only one partition, coalesce the partitions in the plan.
            Arc::new(CoalescePartitionsExec::new(plan))
        } else {
            // Otherwise, repartition the plan using a round-robin partitioning scheme.
            Arc::new(RepartitionExec::try_new(
                plan,
                Partitioning::RoundRobinBatch(table_partition_count),
            )?)
        };

        // Get the task context from the session state.
        let task_ctx = state.task_ctx();

        // Execute the plan and collect the results into batches.
        let mut tasks = vec![];
        for idx in 0..plan.output_partitioning().partition_count() {
            let stream = plan.execute(idx, task_ctx.clone())?;
            let handle = task::spawn(async move {
                stream.try_collect().await.map_err(DataFusionError::from)
            });
            tasks.push(AbortOnDropSingle::new(handle));
        }
        let results = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|result| {
                result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
            })
            .collect::<Result<Vec<Vec<RecordBatch>>>>()?;

        // Write the results into the table.
        let mut all_batches = self.batches.write().await;

        if all_batches.is_empty() {
            *all_batches = results
        } else {
            for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
                batches.extend(result);
            }
        }

        Ok(())
    }
}
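The partition-matching logic above reduces to a three-way rule: leave the plan alone when the partition counts already match or the table is empty (the empty table adopts the plan's layout), coalesce when the table has exactly one partition, and round-robin repartition otherwise. Here is a standalone sketch of just that decision; `Adjustment` and `adjust` are illustrative names, not DataFusion APIs.

enum Adjustment {
    /// Counts already match, or the table is empty and adopts the plan's layout.
    None,
    /// Merge all plan partitions into the table's single partition.
    Coalesce,
    /// Spread batches round-robin across the table's n partitions.
    RoundRobin(usize),
}

fn adjust(plan_partitions: usize, table_partitions: usize) -> Adjustment {
    if plan_partitions == table_partitions || table_partitions == 0 {
        Adjustment::None
    } else if table_partitions == 1 {
        Adjustment::Coalesce
    } else {
        Adjustment::RoundRobin(table_partitions)
    }
}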

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::provider_as_source;
    use crate::from_slice::FromSlice;
    use crate::prelude::SessionContext;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::StreamExt;
    use std::collections::HashMap;

@@ -388,4 +472,135 @@ mod tests {

        Ok(())
    }

    fn create_mem_table_scan(
        schema: SchemaRef,
        data: Vec<Vec<RecordBatch>>,
    ) -> Result<Arc<LogicalPlan>> {
        // Convert the table into a provider so that it can be used in a query
        let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
        // Create a table scan logical plan to read from the table
        Ok(Arc::new(
            LogicalPlanBuilder::scan("source", provider, None)?.build()?,
        ))
    }

    fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
        // Create a new session context
        let session_ctx = SessionContext::new();
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
        )?;
        Ok((session_ctx, schema, batch))
    }

    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()]],
        )?);
        // Create a table scan logical plan to read from the table
        let single_partition_table_scan =
            create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
        // Insert the data from the provider into the table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        // Ensure that the table now contains two batches of data in the same partition
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        // Create a new provider with 2 partitions
        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;

        // Insert the data from the provider into the table. We expect the plan's
        // partitions to be coalesced into the table's single partition.
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        // Ensure that the table now contains 4 batches of data with only 1 partition
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        assert_eq!(initial_table.batches.read().await.len(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_insert_into_multiple_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        // create a memory table with two partitions, each having one batch with the same data
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch.clone()]],
        )?);

        // scan a data source provider from a memory table with a single partition
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;

        // insert the data from the single-partition data source provider into the initial table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;

        // We expect round-robin repartitioning here; each partition gets 1 batch.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);

        // scan a data source provider from a memory table with 2 partitions
        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        // We expect a one-to-one partition mapping here.
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        // Ensure that the table now contains 3 batches of data with 2 partitions.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn test_insert_into_empty_table() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        // create an empty memory table
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);

        // scan a data source provider from a memory table with a single partition
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;

        // insert the data from the single-partition data source provider into the initial table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;

        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        // scan a data source provider from a memory table with 2 partitions
        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        // We expect the partitions to be coalesced here, since the table now has a single partition.
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        // Ensure that the table now contains 4 batches of data in a single partition.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        Ok(())
    }
}
23 changes: 21 additions & 2 deletions datafusion/core/src/execution/context.rs
@@ -31,7 +31,7 @@ use crate::{
        optimizer::PhysicalOptimizerRule,
    },
};
-use datafusion_expr::{DescribeTable, StringifiedPlan};
+use datafusion_expr::{DescribeTable, DmlStatement, StringifiedPlan, WriteOp};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
use parking_lot::RwLock;
@@ -308,7 +308,8 @@ impl SessionContext {

    /// Creates a [`DataFrame`] that will execute a SQL query.
    ///
-    /// Note: This API implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in-memory
+    /// Note: This API implements DDL statements such as `CREATE TABLE` and
+    /// `CREATE VIEW` and DML statements such as `INSERT INTO` with in-memory
    /// default implementations.
    ///
    /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which
@@ -318,6 +319,24 @@
        let plan = self.state().create_logical_plan(sql).await?;

        match plan {
            LogicalPlan::Dml(DmlStatement {
                table_name,
                op: WriteOp::Insert,
                input,
                ..
            }) => {
                if self.table_exist(&table_name)? {
                    let name = table_name.table();
                    let provider = self.table_provider(name).await?;
                    provider.insert_into(&self.state(), &input).await?;
                } else {
                    return Err(DataFusionError::Execution(format!(
                        "Table '{}' does not exist",
                        table_name
                    )));
                }
                self.return_empty_dataframe()
            }
            LogicalPlan::CreateExternalTable(cmd) => {
                self.create_external_table(&cmd).await
            }
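With this dispatch in place, an INSERT INTO issued through `SessionContext::sql` reaches the provider's `insert_into` and returns an empty DataFrame. Below is an end-to-end sketch under two assumptions: a binary crate with tokio and datafusion as dependencies, and that the SQL front end of this vintage plans `INSERT INTO ... VALUES` into the `DmlStatement` handled above (the PR's own tests call `insert_into` directly).

use std::sync::Arc;

use datafusion::arrow::array::Int32Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // A single-partition in-memory table named "t" with one nullable Int32 column.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;
    let table = MemTable::try_new(schema, vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(table))?;

    // Routed through the new LogicalPlan::Dml arm to MemTable::insert_into.
    ctx.sql("INSERT INTO t VALUES (4), (5)").await?;

    // The table now contains the appended rows.
    ctx.sql("SELECT a FROM t").await?.show().await?;
    Ok(())
}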