Skip to content

Commit 3be6774

Browse files
committed
[HSTACK] Building blocks for Ray DataFusionDatasource
1 parent 211f16d commit 3be6774

File tree

3 files changed

+85
-2
lines changed

3 files changed

+85
-2
lines changed

python/datafusion/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,9 @@ def count(self) -> int:
797797
"""
798798
return self.df.count()
799799

800+
def distributed_plan(self, num_shards: int):
801+
return self.df.distributed_plan(num_shards)
802+
800803
@deprecated("Use :py:func:`unnest_columns` instead.")
801804
def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
802805
"""See :py:func:`unnest_columns`."""

src/dataframe.rs

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,27 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
2727
use datafusion::arrow::datatypes::Schema;
2828
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2929
use datafusion::arrow::util::pretty;
30-
use datafusion::common::UnnestOptions;
31-
use datafusion::config::{CsvOptions, TableParquetOptions};
30+
use datafusion::common::stats::Precision;
31+
use datafusion::common::{DFSchema, DataFusionError, UnnestOptions};
32+
use datafusion::config::{ConfigOptions, CsvOptions, TableParquetOptions};
3233
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
34+
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
3335
use datafusion::execution::SendableRecordBatchStream;
3436
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
37+
use datafusion::physical_plan::ExecutionPlan;
3538
use datafusion::prelude::*;
39+
40+
use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
41+
use datafusion_proto::protobuf::PhysicalPlanNode;
42+
use deltalake::delta_datafusion::DeltaPhysicalCodec;
43+
use prost::Message;
3644
use pyo3::exceptions::PyValueError;
3745
use pyo3::prelude::*;
3846
use pyo3::pybacked::PyBackedStr;
3947
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
4048
use tokio::task::JoinHandle;
4149

50+
use crate::common::df_schema::PyDFSchema;
4251
use crate::errors::{py_datafusion_err, PyDataFusionError};
4352
use crate::expr::sort_expr::to_sort_expressions;
4453
use crate::physical_plan::PyExecutionPlan;
@@ -659,6 +668,75 @@ impl PyDataFrame {
659668
fn count(&self, py: Python) -> PyDataFusionResult<usize> {
660669
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
661670
}
671+
672+
fn distributed_plan(&self, parallelism: usize, py: Python<'_>) -> PyResult<DistributedPlan> {
673+
let future_plan = self.df.as_ref().clone().create_physical_plan();
674+
let physical_plan = wait_for_future(py, future_plan).map_err(py_datafusion_err)?;
675+
DistributedPlan::try_new(physical_plan, parallelism).map_err(py_datafusion_err)
676+
}
677+
678+
}
679+
680+
#[pyclass(get_all)]
681+
#[derive(Debug, Clone)]
682+
pub struct DistributedPlan {
683+
physical_plan: Vec<u8>,
684+
schema: PyDFSchema,
685+
partitions: usize,
686+
num_bytes: Option<usize>,
687+
num_rows: Option<usize>,
688+
}
689+
690+
fn codec() -> &'static dyn PhysicalExtensionCodec {
691+
static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
692+
&CODEC
693+
}
694+
695+
impl DistributedPlan {
696+
fn try_new(plan: Arc<dyn ExecutionPlan>, parallelism: usize) -> Result<Self, DataFusionError> {
697+
fn extract(prec: Precision<usize>) -> Option<usize> {
698+
match prec {
699+
Precision::Exact(n) => Some(n),
700+
_ => None,
701+
}
702+
}
703+
let (num_bytes, num_rows) = if let Ok(stats) = plan.statistics() {
704+
let num_bytes = extract(stats.total_byte_size);
705+
let num_rows = extract(stats.num_rows);
706+
(num_bytes, num_rows)
707+
} else {
708+
(None, None)
709+
};
710+
711+
let schema = DFSchema::try_from(plan.schema())
712+
.map(PyDFSchema::from)?;
713+
let plan = plan.repartitioned(parallelism, &ConfigOptions::default())?
714+
.unwrap_or(plan);
715+
let partitions = plan.properties().partitioning.partition_count();
716+
let physical_plan = PhysicalPlanNode::try_from_physical_plan(plan, codec())?
717+
.encode_to_vec();
718+
Ok(Self { physical_plan, schema, partitions, num_bytes, num_rows })
719+
}
720+
721+
}
722+
723+
#[pyfunction]
724+
pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
725+
deltalake::ensure_initialized();
726+
let ctx = SessionContext::new();
727+
let runtime = RuntimeEnvBuilder::new().build().map_err(py_datafusion_err)?;
728+
let node = PhysicalPlanNode::decode(serialized_plan)
729+
.map_err(|e| DataFusionError::External(Box::new(e)))
730+
.map_err(py_datafusion_err)?;
731+
let plan = node.try_into_physical_plan(&ctx, &runtime, codec())
732+
.map_err(py_datafusion_err)?;
733+
let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
734+
plan.execute(partition, ctx.task_ctx())
735+
});
736+
wait_for_future(py, stream_with_runtime)
737+
.map_err(py_datafusion_err)?
738+
.map(PyRecordBatchStream::new)
739+
.map_err(py_datafusion_err)
662740
}
663741

664742
/// Print DataFrame

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
116116
#[cfg(feature = "substrait")]
117117
setup_substrait_module(py, &m)?;
118118

119+
m.add_class::<dataframe::DistributedPlan>()?;
120+
m.add_wrapped(wrap_pyfunction!(dataframe::partition_stream))?;
119121
Ok(())
120122
}
121123

0 commit comments

Comments
 (0)