INSERT INTO support for MemTable #5520
@@ -19,18 +19,22 @@
//! queried by DataFusion. This allows data to be pre-loaded into memory and then
//! repeatedly queried without incurring additional file I/O overhead.

-use futures::StreamExt;
+use futures::{StreamExt, TryStreamExt};
use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_expr::LogicalPlan;
use tokio::sync::RwLock;
use tokio::task;

use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::common;
use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::memory::MemoryExec;
@@ -41,7 +45,7 @@ use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
#[derive(Debug)]
pub struct MemTable {
    schema: SchemaRef,
-    batches: Vec<Vec<RecordBatch>>,
+    batches: Arc<RwLock<Vec<Vec<RecordBatch>>>>,
}

impl MemTable {
@@ -54,7 +58,7 @@ impl MemTable {
        {
            Ok(Self {
                schema,
-                batches: partitions,
+                batches: Arc::new(RwLock::new(partitions)),
            })
        } else {
            Err(DataFusionError::Plan(
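These two hunks are what make an insert through &self possible: the partitions now live behind Arc<RwLock<...>>, so scan takes a shared read lock while insert_into takes an exclusive write lock (interior mutability instead of &mut self). Below is a minimal, self-contained sketch of that locking shape using tokio's RwLock; it is a toy type for illustration, not the actual MemTable API.

use std::sync::Arc;

use tokio::sync::RwLock;

// Toy stand-in for MemTable's `batches` field: a vector of partitions, where
// each partition is a vector of "batches" (plain i32 values stand in for
// RecordBatch here).
#[derive(Debug, Default)]
struct SharedPartitions {
    batches: Arc<RwLock<Vec<Vec<i32>>>>,
}

impl SharedPartitions {
    // Analogous to `scan`: shared, read-only access through `&self`.
    async fn total_batches(&self) -> usize {
        self.batches.read().await.iter().map(Vec::len).sum()
    }

    // Analogous to `insert_into`: appends data through `&self` by taking the
    // exclusive write lock, so no `&mut self` is needed.
    async fn append_partition(&self, partition: Vec<i32>) {
        self.batches.write().await.push(partition);
    }
}

#[tokio::main]
async fn main() {
    let table = SharedPartitions::default();
    table.append_partition(vec![1, 2, 3]).await;
    table.append_partition(vec![4, 5]).await;
    assert_eq!(table.total_batches().await, 5);
}

The async RwLock from tokio is used (rather than std::sync::RwLock) so that waiting for the lock does not block the executor thread inside these async TableProvider methods.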
@@ -143,22 +147,102 @@
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let batches = &self.batches.read().await;
        Ok(Arc::new(MemoryExec::try_new(
-            &self.batches.clone(),
+            batches,
            self.schema(),
            projection.cloned(),
        )?))
    }

    /// Inserts the execution results of a given [LogicalPlan] into this [MemTable].
    /// The `LogicalPlan` must have the same schema as this `MemTable`.
    ///
    /// # Arguments
    ///
    /// * `state` - The [SessionState] containing the context for executing the plan.
    /// * `input` - The [LogicalPlan] to execute and insert.
    ///
    /// # Returns
    ///
    /// * A `Result` indicating success or failure.
    async fn insert_into(&self, state: &SessionState, input: &LogicalPlan) -> Result<()> {
        // Create a physical plan from the logical plan.
        let plan = state.create_physical_plan(input).await?;

        // Check that the schema of the plan matches the schema of this table.
        if !plan.schema().eq(&self.schema) {
            return Err(DataFusionError::Plan(
                "Inserting query must have the same schema with the table.".to_string(),
            ));
        }

        // Get the number of partitions in the plan and the table.
        let plan_partition_count = plan.output_partitioning().partition_count();
        let table_partition_count = self.batches.read().await.len();

        // Adjust the plan as necessary to match the number of partitions in the table.
        let plan: Arc<dyn ExecutionPlan> = if plan_partition_count
            == table_partition_count
            || table_partition_count == 0
        {
            plan
        } else if table_partition_count == 1 {
            // If the table has only one partition, coalesce the partitions in the plan.
            Arc::new(CoalescePartitionsExec::new(plan))
        } else {
            // Otherwise, repartition the plan using a round-robin partitioning scheme.
            Arc::new(RepartitionExec::try_new(
                plan,
                Partitioning::RoundRobinBatch(table_partition_count),
            )?)
        };

        // Get the task context from the session state.
        let task_ctx = state.task_ctx();

        // Execute the plan and collect the results into batches.
        let mut tasks = vec![];
        for idx in 0..plan.output_partitioning().partition_count() {
            let stream = plan.execute(idx, task_ctx.clone())?;
            let handle = task::spawn(async move {
                stream.try_collect().await.map_err(DataFusionError::from)
            });
            tasks.push(AbortOnDropSingle::new(handle));
        }
        let results = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|result| {
                result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
            })
            .collect::<Result<Vec<Vec<RecordBatch>>>>()?;

        // Write the results into the table.
        let mut all_batches = self.batches.write().await;

        if all_batches.is_empty() {
            *all_batches = results
        } else {
            for (batches, result) in all_batches.iter_mut().zip(results.into_iter()) {
                batches.extend(result);
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datasource::provider_as_source;
    use crate::from_slice::FromSlice;
    use crate::prelude::SessionContext;
    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::error::ArrowError;
    use datafusion_expr::LogicalPlanBuilder;
    use futures::StreamExt;
    use std::collections::HashMap;

Inline review discussion on the schema check in insert_into:

could you please elaborate why exactly this check is needed?

What would the alternate behavior be? Pad missing columns with nulls?

The check was done with a defensive approach in mind. I am not sure how the two schemas would be different.

I think whatever is creating the plan should ensure that the incoming rows match the record batch (as in I prefer this error approach).

@@ -388,4 +472,135 @@ mod tests {

        Ok(())
    }

    fn create_mem_table_scan(
        schema: SchemaRef,
        data: Vec<Vec<RecordBatch>>,
    ) -> Result<Arc<LogicalPlan>> {
        // Convert the table into a provider so that it can be used in a query
        let provider = provider_as_source(Arc::new(MemTable::try_new(schema, data)?));
        // Create a table scan logical plan to read from the table
        Ok(Arc::new(
            LogicalPlanBuilder::scan("source", provider, None)?.build()?,
        ))
    }

    fn create_initial_ctx() -> Result<(SessionContext, SchemaRef, RecordBatch)> {
        // Create a new session context
        let session_ctx = SessionContext::new();
        // Create a new schema with one field called "a" of type Int32
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

        // Create a new batch of data to insert into the table
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from_slice([1, 2, 3]))],
        )?;
        Ok((session_ctx, schema, batch))
    }

    #[tokio::test]
    async fn test_insert_into_single_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()]],
        )?);
        // Create a table scan logical plan to read from the table
        let single_partition_table_scan =
            create_mem_table_scan(schema.clone(), vec![vec![batch.clone()]])?;
        // Insert the data from the provider into the table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        // Ensure that the table now contains two batches of data in the same partition
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        // Create a new provider with 2 partitions
        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;

        // Insert the data from the provider into the table. We expect coalescing partitions.
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        // Ensure that the table now contains 4 batches of data with only 1 partition
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        assert_eq!(initial_table.batches.read().await.len(), 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_insert_into_multiple_partition() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        // create a memory table with two partitions, each having one batch with the same data
        let initial_table = Arc::new(MemTable::try_new(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch.clone()]],
        )?);

        // scan a data source provider from a memory table with a single partition
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;

        // insert the data from the 1-partition data source provider into the initial table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;

        // We expect round robin repartition here, each partition gets 1 batch.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 2);

        // scan a data source provider from a memory table with 2 partitions
        let multi_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        // We expect one-to-one partition mapping.
        initial_table
            .insert_into(&session_ctx.state(), &multi_partition_table_scan)
            .await?;
        // Ensure that the table now contains 3 batches of data in each of its 2 partitions.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 3);
        assert_eq!(initial_table.batches.read().await.get(1).unwrap().len(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn test_insert_into_empty_table() -> Result<()> {
        let (session_ctx, schema, batch) = create_initial_ctx()?;
        // create empty memory table
        let initial_table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);

        // scan a data source provider from a memory table with a single partition
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone(), batch.clone()]],
        )?;

        // insert the data from the 1-partition data source provider into the initial table
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;

        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 2);

        // scan a data source provider from a memory table with 2 partitions
        let single_partition_table_scan = create_mem_table_scan(
            schema.clone(),
            vec![vec![batch.clone()], vec![batch]],
        )?;
        // We expect coalesce partitions here.
        initial_table
            .insert_into(&session_ctx.state(), &single_partition_table_scan)
            .await?;
        // Ensure that the table now contains 4 batches of data in its single partition.
        assert_eq!(initial_table.batches.read().await.get(0).unwrap().len(), 4);
        Ok(())
    }
}
Review discussion on the schema-mismatch error message:

if the table has a name, it's likely better to return it to the user in the error message instead of this generic message.

I don't think TableProvider has a way to fetch its own name yet 🤔

Imho TableProvider likely should have its own name... I can investigate the purpose of having name.

I think the name for a particular table provider is currently managed by its containing schema - this allows, for example, the same provider to be registered as different table names.
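That last point can be seen directly with the public API: the same provider instance can be registered under several names, which is why the provider itself cannot reliably report "its" name in an error message. A small sketch using SessionContext::register_table; the exact signature is as in recent DataFusion releases and may vary by version.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

fn main() -> Result<()> {
    let ctx = SessionContext::new();
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // One provider instance...
    let provider = Arc::new(MemTable::try_new(schema, vec![])?);

    // ...registered twice: "t1" and "t2" now resolve to the same underlying table,
    // so the name lives in the catalog/schema, not in the TableProvider itself.
    let _ = ctx.register_table("t1", provider.clone())?;
    let _ = ctx.register_table("t2", provider)?;

    Ok(())
}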