
Commit 02fd450

fix: parallel parquet can underflow when max_record_batch_rows < execution.batch_size (#9737)
* loop split rb
* add test
* add new test
* fmt
* lower batch size in test
* make test faster
* use path not into_path
1 parent 1dbec3e commit 02fd450
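
The bug: in the serialization loop patched below, the room left in the current row group is computed as max_row_group_rows - current_rg_rows, a usize subtraction. The old code split an oversized RecordBatch only once, so when max_row_group_rows is smaller than execution.batch_size the remainder could still exceed the row group size, leaving current_rg_rows greater than max_row_group_rows and underflowing the subtraction on the next batch. A minimal standalone sketch of the repeated-split arithmetic the fix adopts (illustrative names, not DataFusion's API):

    // Split a batch of `batch_rows` rows into chunks of at most
    // `max_row_group_rows` rows. One split is not enough: a single
    // batch may span several row groups.
    fn split_points(batch_rows: usize, max_row_group_rows: usize) -> Vec<usize> {
        let mut remaining = batch_rows;
        let mut chunks = Vec::new();
        while remaining >= max_row_group_rows {
            chunks.push(max_row_group_rows);
            remaining -= max_row_group_rows;
        }
        if remaining > 0 {
            chunks.push(remaining);
        }
        chunks
    }

    fn main() {
        // batch_size = 10, row groups capped at 3 rows:
        // the batch must become 3 + 3 + 3 + 1, not a single split.
        assert_eq!(split_points(10, 3), vec![3, 3, 3, 1]);
    }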

File tree: 2 files changed, +93 −36 lines

datafusion/core/src/dataframe/parquet.rs (54 additions, 2 deletions)

@@ -74,16 +74,18 @@ impl DataFrame {
 
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
     use std::sync::Arc;
 
     use super::super::Result;
     use super::*;
     use crate::arrow::util::pretty;
     use crate::execution::context::SessionContext;
     use crate::execution::options::ParquetReadOptions;
-    use crate::test_util;
+    use crate::test_util::{self, register_aggregate_csv};
 
     use datafusion_common::file_options::parquet_writer::parse_compression_string;
+    use datafusion_execution::config::SessionConfig;
     use datafusion_expr::{col, lit};
 
     use object_store::local::LocalFileSystem;
@@ -150,7 +152,7 @@ mod tests {
         .await?;
 
         // Check that file actually used the specified compression
-        let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?;
+        let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;
 
         let reader =
             parquet::file::serialized_reader::SerializedFileReader::new(file)
@@ -166,4 +168,54 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn write_parquet_with_small_rg_size() -> Result<()> {
+        // This test verifies writing a parquet file with small rg size
+        // relative to datafusion.execution.batch_size does not panic
+        let mut ctx = SessionContext::new_with_config(
+            SessionConfig::from_string_hash_map(HashMap::from_iter(
+                [("datafusion.execution.batch_size", "10")]
+                    .iter()
+                    .map(|(s1, s2)| (s1.to_string(), s2.to_string())),
+            ))?,
+        );
+        register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
+        let test_df = ctx.table("aggregate_test_100").await?;
+
+        let output_path = "file://local/test.parquet";
+
+        for rg_size in 1..10 {
+            let df = test_df.clone();
+            let tmp_dir = TempDir::new()?;
+            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
+            let local_url = Url::parse("file://local").unwrap();
+            let ctx = &test_df.session_state;
+            ctx.runtime_env().register_object_store(&local_url, local);
+            let mut options = TableParquetOptions::default();
+            options.global.max_row_group_size = rg_size;
+            options.global.allow_single_file_parallelism = true;
+            df.write_parquet(
+                output_path,
+                DataFrameWriteOptions::new().with_single_file_output(true),
+                Some(options),
+            )
+            .await?;
+
+            // Check that file actually used the correct rg size
+            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;
+
+            let reader =
+                parquet::file::serialized_reader::SerializedFileReader::new(file)
+                    .unwrap();
+
+            let parquet_metadata = reader.metadata();
+
+            let written_rows = parquet_metadata.row_group(0).num_rows();
+
+            assert_eq!(written_rows as usize, rg_size);
+        }
+
+        Ok(())
+    }
 }
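
For reference, the same batch-size override can be expressed with the builder API; a minimal sketch (the test above uses the string-keyed form so the value flows through the normal config-string parsing path):

    use datafusion::execution::context::SessionContext;
    use datafusion::prelude::SessionConfig;

    // Build a context whose scans produce 10-row batches, so that row
    // groups capped at 1..10 rows force every batch to be split.
    fn small_batch_ctx() -> SessionContext {
        SessionContext::new_with_config(SessionConfig::new().with_batch_size(10))
    }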

datafusion/core/src/datasource/file_format/parquet.rs (39 additions, 34 deletions)

@@ -876,42 +876,47 @@ fn spawn_parquet_parallel_serialization_task(
     )?;
     let mut current_rg_rows = 0;
 
-    while let Some(rb) = data.recv().await {
-        if current_rg_rows + rb.num_rows() < max_row_group_rows {
-            send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone())
-                .await?;
-            current_rg_rows += rb.num_rows();
-        } else {
-            let rows_left = max_row_group_rows - current_rg_rows;
-            let a = rb.slice(0, rows_left);
-            send_arrays_to_col_writers(&col_array_channels, &a, schema.clone())
-                .await?;
+    while let Some(mut rb) = data.recv().await {
+        // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case
+        // when max_row_group_rows < execution.batch_size as an alternative to a recursive async
+        // function.
+        loop {
+            if current_rg_rows + rb.num_rows() < max_row_group_rows {
+                send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone())
+                    .await?;
+                current_rg_rows += rb.num_rows();
+                break;
+            } else {
+                let rows_left = max_row_group_rows - current_rg_rows;
+                let a = rb.slice(0, rows_left);
+                send_arrays_to_col_writers(&col_array_channels, &a, schema.clone())
+                    .await?;
+
+                // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup
+                // on a separate task, so that we can immediately start on the next RG before waiting
+                // for the current one to finish.
+                drop(col_array_channels);
+                let finalize_rg_task = spawn_rg_join_and_finalize_task(
+                    column_writer_handles,
+                    max_row_group_rows,
+                );
+
+                serialize_tx.send(finalize_rg_task).await.map_err(|_| {
+                    DataFusionError::Internal(
+                        "Unable to send closed RG to concat task!".into(),
+                    )
+                })?;
 
-            // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup
-            // on a separate task, so that we can immediately start on the next RG before waiting
-            // for the current one to finish.
-            drop(col_array_channels);
-            let finalize_rg_task = spawn_rg_join_and_finalize_task(
-                column_writer_handles,
-                max_row_group_rows,
-            );
-
-            serialize_tx.send(finalize_rg_task).await.map_err(|_| {
-                DataFusionError::Internal(
-                    "Unable to send closed RG to concat task!".into(),
-                )
-            })?;
+                current_rg_rows = 0;
+                rb = rb.slice(rows_left, rb.num_rows() - rows_left);
 
-            let b = rb.slice(rows_left, rb.num_rows() - rows_left);
-            (column_writer_handles, col_array_channels) =
-                spawn_column_parallel_row_group_writer(
-                    schema.clone(),
-                    writer_props.clone(),
-                    max_buffer_rb,
-                )?;
-            send_arrays_to_col_writers(&col_array_channels, &b, schema.clone())
-                .await?;
-            current_rg_rows = b.num_rows();
+                (column_writer_handles, col_array_channels) =
+                    spawn_column_parallel_row_group_writer(
+                        schema.clone(),
+                        writer_props.clone(),
+                        max_buffer_rb,
+                    )?;
+            }
         }
     }
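The drop(col_array_channels) signalling in the hunk above relies on standard mpsc semantics: once every sender is dropped, each column writer's recv() returns None and the task can finalize its column chunk. A minimal standalone sketch of that pattern (illustrative, one channel instead of one per column):

    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<Vec<u64>>(8);

        // Stand-in for one column writer task: consume arrays until the
        // channel closes, then "finalize" by returning the row count.
        let writer = tokio::spawn(async move {
            let mut rows = 0usize;
            while let Some(arr) = rx.recv().await {
                rows += arr.len();
            }
            rows
        });

        tx.send(vec![1, 2, 3]).await.unwrap();
        drop(tx); // closes the channel: the writer sees None and finishes
        assert_eq!(writer.await.unwrap(), 3);
    }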