Commit edff957

[HSTACK] FIXUP distributed plan serde cleanup
1 parent c597a4e commit edff957

3 files changed: 61 additions & 93 deletions

python/datafusion/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -29,7 +29,7 @@
 from . import functions, object_store, substrait
 
 # The following imports are okay to remain as opaque to the user.
-from ._internal import Config, partition_stream
+from ._internal import Config
 from .catalog import Catalog, Database, Table
 from .common import (
     DFSchema,
@@ -86,7 +86,6 @@
     "read_avro",
     "read_csv",
     "read_json",
-    "partition_stream",
 ]
 
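The user-visible effect of this file is the import surface: `Config` is still re-exported from the internal module, while `partition_stream` disappears from the public package (and from `datafusion._internal`, see src/lib.rs below). A minimal sketch, assuming a source build of this branch:

# Minimal sketch (not part of the diff) of the import surface after this commit.
from datafusion import Config  # still re-exported from datafusion._internal

try:
    from datafusion import partition_stream  # removed by this commit
except ImportError:
    partition_stream = None  # no top-level replacement is introduced in this diff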

src/dataframe.rs

Lines changed: 60 additions & 90 deletions
@@ -29,38 +29,37 @@ use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
 use datafusion::common::stats::Precision;
-use datafusion::common::{DFSchema, DataFusionError, Statistics, UnnestOptions};
 use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::common::{DFSchema, DataFusionError, Statistics, UnnestOptions};
 use datafusion::config::{CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
 use datafusion::datasource::memory::DataSourceExec;
-use datafusion::datasource::TableProvider;
 use datafusion::datasource::physical_plan::FileScanConfig;
 use datafusion::datasource::source::DataSource;
-use datafusion::execution::{SendableRecordBatchStream};
+use datafusion::datasource::TableProvider;
+use datafusion::execution::SendableRecordBatchStream;
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
 use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 use datafusion::prelude::*;
 
-use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
-use datafusion_proto::protobuf::PhysicalPlanNode;
-use deltalake::delta_datafusion::DeltaPhysicalCodec;
-use prost::Message;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyTupleMethods};
+use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
 
 use crate::catalog::PyTable;
 use crate::common::df_schema::PyDFSchema;
 use crate::errors::{py_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
-use crate::physical_plan::{ codec, PyExecutionPlan } ;
+use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
-use crate::{errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}};
+use crate::{
+    errors::PyDataFusionResult,
+    expr::{sort_expr::PySortExpr, PyExpr},
+};
 
 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
 // - we have not decided on the table_provider approach yet
@@ -712,7 +711,6 @@ impl PyDataFrame {
         let future_plan = DistributedPlan::try_new(self.df.as_ref());
         wait_for_future(py, future_plan).map_err(py_datafusion_err)
     }
-
 }
 
 #[pyclass(get_all)]
@@ -724,37 +722,16 @@ pub struct DistributedPlan {
 
 #[pymethods]
 impl DistributedPlan {
-
-    fn marshal(&self, py: Python) -> PyResult<PyObject> {
-        let bytes = PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
-            .map(|node| node.encode_to_vec())
-            .map_err(py_datafusion_err)?;
-        let state = PyDict::new(py);
-        state.set_item("plan", PyBytes::new(py, bytes.as_slice()))?;
-        state.set_item("min_size", self.min_size)?;
-        Ok(state.into())
-    }
-
     #[new]
-    fn unmarshal(state: Bound<PyDict>) -> PyResult<Self>{
-        let ctx = SessionContext::new();
-        let serialized_plan = state.get_item("plan")?
-            .expect("missing key `plan` from state");
-        let serialized_plan = serialized_plan
-            .downcast::<PyBytes>()?
-            .as_bytes();
-        let min_size = state.get_item("min_size")?
-            .expect("missing key `min_size` from state")
-            .extract::<usize>()?;
-        let plan = deserialize_plan(serialized_plan, &ctx)?;
+    fn new(physical_plan: PyExecutionPlan, min_size: usize) -> PyResult<Self> {
         Ok(Self {
             min_size,
-            physical_plan: PyExecutionPlan::new(plan),
+            physical_plan,
         })
     }
 
     fn partition_count(&self) -> usize {
-        self.plan().output_partitioning().partition_count()
+        self.physical_plan.partition_count()
     }
 
     fn num_bytes(&self) -> Option<usize> {
@@ -772,51 +749,68 @@ impl DistributedPlan {
     }
 
     fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
-        let updated_plan = self.plan().clone().transform_up(|node| {
-            if let Some(exec) = node.as_any().downcast_ref::<DataSourceExec>() {
-                // Remove redundant ranges from partition files because FileScanConfig refuses to repartition
-                // if any file has a range defined (even when the range actually covers the entire file).
-                // The EnforceDistribution optimizer rule adds ranges for both full and partial files,
-                // so this tries to revert that in order to trigger a repartition when no files are actually split.
-                if let Some(file_scan) = exec.data_source().as_any().downcast_ref::<FileScanConfig>() {
-                    let mut range_free_file_scan = file_scan.clone();
-                    let mut total_size: usize = 0;
-                    for group in range_free_file_scan.file_groups.iter_mut() {
-                        for file in group.iter_mut() {
-                            if let Some(range) = &file.range {
-                                total_size += (range.end - range.start) as usize;
-                                if range.start == 0 && range.end == file.object_meta.size as i64 {
-                                    file.range = None; // remove redundant range
+        let updated_plan = self
+            .plan()
+            .clone()
+            .transform_up(|node| {
+                if let Some(exec) = node.as_any().downcast_ref::<DataSourceExec>() {
+                    // Remove redundant ranges from partition files because FileScanConfig refuses to repartition
+                    // if any file has a range defined (even when the range actually covers the entire file).
+                    // The EnforceDistribution optimizer rule adds ranges for both full and partial files,
+                    // so this tries to revert that in order to trigger a repartition when no files are actually split.
+                    if let Some(file_scan) =
+                        exec.data_source().as_any().downcast_ref::<FileScanConfig>()
+                    {
+                        let mut range_free_file_scan = file_scan.clone();
+                        let mut total_size: usize = 0;
+                        for group in range_free_file_scan.file_groups.iter_mut() {
+                            for file in group.iter_mut() {
+                                if let Some(range) = &file.range {
+                                    total_size += (range.end - range.start) as usize;
+                                    if range.start == 0 && range.end == file.object_meta.size as i64
+                                    {
+                                        file.range = None; // remove redundant range
+                                    }
+                                } else {
+                                    total_size += file.object_meta.size;
                                 }
-                            } else {
-                                total_size += file.object_meta.size;
                             }
                         }
-                    }
-                    let min_size_buckets = max(1, total_size.div_ceil(self.min_size));
-                    let partitions = min(min_size_buckets, desired_parallelism);
-                    let ordering = range_free_file_scan.eq_properties().output_ordering();
-                    if let Some(repartitioned) = range_free_file_scan
-                        .repartitioned(partitions, 1, ordering)? {
-                        return Ok(Transformed::yes(Arc::new(DataSourceExec::new(repartitioned))))
+                        let min_size_buckets = max(1, total_size.div_ceil(self.min_size));
+                        let partitions = min(min_size_buckets, desired_parallelism);
+                        let ordering = range_free_file_scan.eq_properties().output_ordering();
+                        if let Some(repartitioned) =
+                            range_free_file_scan.repartitioned(partitions, 1, ordering)?
+                        {
+                            return Ok(Transformed::yes(Arc::new(DataSourceExec::new(
+                                repartitioned,
+                            ))));
+                        }
                     }
                 }
-            }
-            Ok(Transformed::no(node))
-        }).map_err(py_datafusion_err)?.data;
+                Ok(Transformed::no(node))
+            })
+            .map_err(py_datafusion_err)?
+            .data;
         self.physical_plan = PyExecutionPlan::new(updated_plan);
         Ok(())
     }
 }
 
 impl DistributedPlan {
-
     async fn try_new(df: &DataFrame) -> Result<Self, DataFusionError> {
         let (mut session_state, logical_plan) = df.clone().into_parts();
-        let min_size = session_state.config_options().optimizer.repartition_file_min_size;
+        let min_size = session_state
+            .config_options()
+            .optimizer
+            .repartition_file_min_size;
         // Create the physical plan with a single partition, to ensure that no files are split into ranges.
         // Otherwise, any subsequent repartition attempt would fail (see the comment in `set_desired_parallelism`)
-        session_state.config_mut().options_mut().execution.target_partitions = 1;
+        session_state
+            .config_mut()
+            .options_mut()
+            .execution
+            .target_partitions = 1;
         let physical_plan = session_state.create_physical_plan(&logical_plan).await?;
         let physical_plan = PyExecutionPlan::new(physical_plan);
         Ok(Self {
@@ -830,7 +824,7 @@ impl DistributedPlan {
     }
 
     fn stats_field(&self, field: fn(Statistics) -> Precision<usize>) -> Option<usize> {
-        if let Ok(stats) = self.physical_plan.plan.statistics() {
+        if let Ok(stats) = self.plan().statistics() {
             match field(stats) {
                 Precision::Exact(n) => Some(n),
                 _ => None,
@@ -839,30 +833,6 @@ impl DistributedPlan {
             None
         }
     }
-
-}
-
-fn deserialize_plan(serialized_plan: &[u8], ctx: &SessionContext) -> PyResult<Arc<dyn ExecutionPlan>> {
-    deltalake::ensure_initialized();
-    let node = PhysicalPlanNode::decode(serialized_plan)
-        .map_err(|e| DataFusionError::External(Box::new(e)))
-        .map_err(py_datafusion_err)?;
-    let plan = node.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), codec())
-        .map_err(py_datafusion_err)?;
-    Ok(plan)
-}
-
-#[pyfunction]
-pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
-    let ctx = SessionContext::new();
-    let plan = deserialize_plan(serialized_plan, &ctx)?;
-    let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
-        plan.execute(partition, ctx.task_ctx())
-    });
-    wait_for_future(py, stream_with_runtime)
-        .map_err(py_datafusion_err)?
-        .map(PyRecordBatchStream::new)
-        .map_err(py_datafusion_err)
 }
 
 /// Print DataFrame
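With `marshal`/`unmarshal` removed, `DistributedPlan` is now built directly from an execution plan plus the optimizer's `repartition_file_min_size`, and `set_desired_parallelism` caps the file-scan partition count at `min(max(1, ceil(total_size / min_size)), desired_parallelism)`; for example, 300 MiB of scanned files with a 64 MiB `min_size` yields at most 5 partitions even when 8 are requested. A hedged Python sketch of exercising this from the bindings follows; the `distributed_plan()` accessor name and the dataset path are assumptions, neither appears in this diff:

# Hedged sketch of driving DistributedPlan from Python after this commit.
# Assumptions: this branch exposes a DataFrame method (called distributed_plan()
# here) wrapping DistributedPlan::try_new, and "data/events.parquet" is a
# placeholder dataset.
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.read_parquet("data/events.parquet")

plan = df.distributed_plan()       # physical plan built with target_partitions = 1
print(plan.partition_count())      # 1, since no file was split into ranges
print(plan.min_size)               # optimizer.repartition_file_min_size (get_all attribute)

plan.set_desired_parallelism(8)    # repartitions file scans, capped by the
print(plan.partition_count())      # ceil(total_size / min_size) bucket count

Because the new `#[new]` constructor takes an execution plan and a `min_size` directly, a plan deserialized by other means can still be wrapped again via `DistributedPlan(physical_plan, min_size)`.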

src/lib.rs

Lines changed: 0 additions & 1 deletion
@@ -117,7 +117,6 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     setup_substrait_module(py, &m)?;
 
     m.add_class::<dataframe::DistributedPlan>()?;
-    m.add_wrapped(wrap_pyfunction!(dataframe::partition_stream))?;
     Ok(())
 }
 