Skip to content

Commit da58c44

Browse files
ccciudatuadragomir
authored and committed
fixup! [HSTACK] Building blocks for Ray DataFusionDatasource
1 parent fa67997 commit da58c44

File tree

1 file changed

+94
-38
lines changed

1 file changed

+94
-38
lines changed

src/dataframe.rs

Lines changed: 94 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@ use datafusion::arrow::datatypes::Schema;
2828
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2929
use datafusion::arrow::util::pretty;
3030
use datafusion::common::stats::Precision;
31-
use datafusion::common::{DFSchema, DataFusionError, UnnestOptions};
32-
use datafusion::config::{CsvOptions, TableParquetOptions};
31+
use datafusion::common::{DFSchema, DataFusionError, Statistics, UnnestOptions};
32+
use datafusion::common::tree_node::{Transformed, TreeNode};
33+
use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
3334
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
34-
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
35+
use datafusion::datasource::physical_plan::ParquetExec;
3536
use datafusion::execution::SendableRecordBatchStream;
3637
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
37-
use datafusion::physical_plan::ExecutionPlan;
38+
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
3839
use datafusion::prelude::*;
3940

4041
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
@@ -54,10 +55,7 @@ use crate::physical_plan::PyExecutionPlan;
5455
use crate::record_batch::PyRecordBatchStream;
5556
use crate::sql::logical::PyLogicalPlan;
5657
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
57-
use crate::{
58-
errors::PyDataFusionResult,
59-
expr::{sort_expr::PySortExpr, PyExpr},
60-
};
58+
use crate::{errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}};
6159

6260
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
6361
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -671,62 +669,115 @@ impl PyDataFrame {
671669

672670
fn distributed_plan(&self, py: Python<'_>) -> PyResult<DistributedPlan> {
673671
let future_plan = self.df.as_ref().clone().create_physical_plan();
674-
let physical_plan = wait_for_future(py, future_plan).map_err(py_datafusion_err)?;
675-
DistributedPlan::try_new(physical_plan).map_err(py_datafusion_err)
672+
wait_for_future(py, future_plan)
673+
.map(DistributedPlan::new)
674+
.map_err(py_datafusion_err)
676675
}
677676

678677
}
679678

680679
#[pyclass(get_all)]
681680
#[derive(Debug, Clone)]
682681
pub struct DistributedPlan {
683-
physical_plan: Vec<u8>,
684-
schema: PyDFSchema,
685-
num_partitions: usize,
686-
num_bytes: Option<usize>,
687-
num_rows: Option<usize>,
682+
physical_plan: PyExecutionPlan,
688683
}
689684

690-
fn codec() -> &'static dyn PhysicalExtensionCodec {
691-
static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
692-
&CODEC
685+
#[pymethods]
686+
impl DistributedPlan {
687+
688+
fn serialize(&self) -> PyResult<Vec<u8>> {
689+
PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
690+
.map(|node| node.encode_to_vec())
691+
.map_err(py_datafusion_err)
692+
}
693+
694+
fn partition_count(&self) -> usize {
695+
self.plan().output_partitioning().partition_count()
696+
}
697+
698+
fn num_bytes(&self) -> Option<usize> {
699+
self.stats_field(|stats| stats.total_byte_size)
700+
}
701+
702+
fn num_rows(&self) -> Option<usize> {
703+
self.stats_field(|stats| stats.num_rows)
704+
}
705+
706+
fn schema(&self) -> PyResult<PyDFSchema> {
707+
DFSchema::try_from(self.plan().schema())
708+
.map(PyDFSchema::from)
709+
.map_err(py_datafusion_err)
710+
}
711+
712+
fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
713+
if self.plan().output_partitioning().partition_count() == desired_parallelism {
714+
return Ok(())
715+
}
716+
let updated_plan = self.plan().clone().transform_up(|node| {
717+
if let Some(parquet) = node.as_any().downcast_ref::<ParquetExec>() {
718+
// Remove redundant ranges from partition files because ParquetExec refuses to repartition
719+
// if any file has a range defined (even when the range actually covers the entire file).
720+
// The EnforceDistribution optimizer rule adds ranges for both full and partial files,
721+
// so this tries to revert that to trigger a repartition when no files are actually split.
722+
let mut file_groups = parquet.base_config().file_groups.clone();
723+
for group in file_groups.iter_mut() {
724+
for file in group.iter_mut() {
725+
if let Some(range) = &file.range {
726+
if range.start == 0 && range.end == file.object_meta.size as i64 {
727+
file.range = None; // remove redundant range
728+
}
729+
}
730+
}
731+
}
732+
if let Some(repartitioned) = parquet.clone().into_builder().with_file_groups(file_groups)
733+
.build_arc()
734+
.repartitioned(desired_parallelism, &ConfigOptions::default())? {
735+
Ok(Transformed::yes(repartitioned))
736+
} else {
737+
Ok(Transformed::no(node))
738+
}
739+
} else {
740+
Ok(Transformed::no(node))
741+
}
742+
}).map_err(py_datafusion_err)?.data;
743+
self.physical_plan = PyExecutionPlan::new(updated_plan);
744+
Ok(())
745+
}
693746
}
694747

695748
impl DistributedPlan {
696-
fn try_new(plan: Arc<dyn ExecutionPlan>) -> Result<Self, DataFusionError> {
697-
fn extract(prec: Precision<usize>) -> Option<usize> {
698-
match prec {
749+
750+
fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
751+
Self {
752+
physical_plan: PyExecutionPlan::new(plan)
753+
}
754+
}
755+
756+
fn plan(&self) -> &Arc<dyn ExecutionPlan> {
757+
&self.physical_plan.plan
758+
}
759+
760+
fn stats_field(&self, field: fn(Statistics) -> Precision<usize>) -> Option<usize> {
761+
if let Ok(stats) = self.physical_plan.plan.statistics() {
762+
match field(stats) {
699763
Precision::Exact(n) => Some(n),
700764
_ => None,
701765
}
702-
}
703-
let (num_bytes, num_rows) = if let Ok(stats) = plan.statistics() {
704-
let bytes = extract(stats.total_byte_size);
705-
let rows = extract(stats.num_rows);
706-
(bytes, rows)
707766
} else {
708-
(None, None)
709-
};
710-
711-
let schema = DFSchema::try_from(plan.schema())
712-
.map(PyDFSchema::from)?;
713-
let num_partitions = plan.properties().partitioning.partition_count();
714-
let physical_plan = PhysicalPlanNode::try_from_physical_plan(plan, codec())?
715-
.encode_to_vec();
716-
Ok(Self { physical_plan, schema, num_partitions, num_bytes, num_rows })
767+
None
768+
}
717769
}
718770

719771
}
720772

721773
#[pyfunction]
722774
pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
723775
deltalake::ensure_initialized();
724-
let ctx = SessionContext::new();
725-
let runtime = RuntimeEnvBuilder::new().build().map_err(py_datafusion_err)?;
726776
let node = PhysicalPlanNode::decode(serialized_plan)
727777
.map_err(|e| DataFusionError::External(Box::new(e)))
728778
.map_err(py_datafusion_err)?;
729-
let plan = node.try_into_physical_plan(&ctx, &runtime, codec())
779+
let ctx = SessionContext::new();
780+
let plan = node.try_into_physical_plan(&ctx, ctx.runtime_env().as_ref(), codec())
730781
.map_err(py_datafusion_err)?;
731782
let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
732783
plan.execute(partition, ctx.task_ctx())
@@ -737,6 +788,11 @@ pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) ->
737788
.map_err(py_datafusion_err)
738789
}
739790

791+
fn codec() -> &'static dyn PhysicalExtensionCodec {
792+
static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
793+
&CODEC
794+
}
795+
740796
/// Print DataFrame
741797
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
742798
// Get string representation of record batches

0 commit comments

Comments
 (0)