
Commit 232d7f6

[HSTACK] Ray datasource support
1 parent e0194e2 commit 232d7f6

File tree

5 files changed: +154 -8 lines changed

pyproject.toml
python/datafusion/__init__.py
python/datafusion/dataframe.py
src/dataframe.rs
src/lib.rs

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ classifiers = [
 ]
 dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]
 #dynamic = ["version"]
-version = "46.0.0+adobe.1"
+version = "46.0.0+adobe.2"

 [project.urls]
 homepage = "https://datafusion.apache.org/python"

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,7 @@
 from . import functions, object_store, substrait

 # The following imports are okay to remain as opaque to the user.
-from ._internal import Config
+from ._internal import Config, partition_stream
 from .catalog import Catalog, Database, Table
 from .common import (
     DFSchema,
@@ -86,6 +86,7 @@
     "read_avro",
     "read_csv",
     "read_json",
+    "partition_stream",
 ]
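Because partition_stream is re-exported from _internal and added to __all__, it should be importable straight from the top-level package. A minimal sanity check (illustrative, not part of this commit):

    import datafusion
    from datafusion import partition_stream  # re-exported from datafusion._internal

    assert "partition_stream" in datafusion.__all__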

python/datafusion/dataframe.py

Lines changed: 3 additions & 0 deletions
@@ -805,6 +805,9 @@ def count(self) -> int:
         """
         return self.df.count()

+    def distributed_plan(self):
+        return self.df.distributed_plan()
+
     @deprecated("Use :py:func:`unnest_columns` instead.")
     def unnest_column(self, column: str, preserve_nulls: bool = True) -> DataFrame:
         """See :py:func:`unnest_columns`."""

src/dataframe.rs

Lines changed: 146 additions & 6 deletions
@@ -27,30 +27,39 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
-use datafusion::common::UnnestOptions;
+use datafusion::common::stats::Precision;
+use datafusion::common::{DFSchema, DataFusionError, Statistics, UnnestOptions};
+use datafusion::common::tree_node::{Transformed, TreeNode};
 use datafusion::config::{CsvOptions, TableParquetOptions};
 use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
+use datafusion::datasource::memory::DataSourceExec;
 use datafusion::datasource::TableProvider;
-use datafusion::execution::SendableRecordBatchStream;
+use datafusion::datasource::physical_plan::FileScanConfig;
+use datafusion::datasource::source::DataSource;
+use datafusion::execution::{SendableRecordBatchStream};
 use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 use datafusion::prelude::*;
+
+use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec};
+use datafusion_proto::protobuf::PhysicalPlanNode;
+use deltalake::delta_datafusion::DeltaPhysicalCodec;
+use prost::Message;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;

 use crate::catalog::PyTable;
+use crate::common::df_schema::PyDFSchema;
 use crate::errors::{py_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
-use crate::{
-    errors::PyDataFusionResult,
-    expr::{sort_expr::PySortExpr, PyExpr},
-};
+use crate::{errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}};

 // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
 // - we have not decided on the table_provider approach yet
@@ -697,6 +706,137 @@ impl PyDataFrame {
     fn count(&self, py: Python) -> PyDataFusionResult<usize> {
         Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
     }
+
+    fn distributed_plan(&self, py: Python<'_>) -> PyResult<DistributedPlan> {
+        let future_plan = DistributedPlan::try_new(self.df.as_ref());
+        wait_for_future(py, future_plan).map_err(py_datafusion_err)
+    }
+}
+
+#[pyclass(get_all)]
+#[derive(Debug, Clone)]
+pub struct DistributedPlan {
+    repartition_file_min_size: usize,
+    physical_plan: PyExecutionPlan,
+}
+
+#[pymethods]
+impl DistributedPlan {
+    fn serialize(&self) -> PyResult<Vec<u8>> {
+        PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
+            .map(|node| node.encode_to_vec())
+            .map_err(py_datafusion_err)
+    }
+
+    fn partition_count(&self) -> usize {
+        self.plan().output_partitioning().partition_count()
+    }
+
+    fn num_bytes(&self) -> Option<usize> {
+        self.stats_field(|stats| stats.total_byte_size)
+    }
+
+    fn num_rows(&self) -> Option<usize> {
+        self.stats_field(|stats| stats.num_rows)
+    }
+
+    fn schema(&self) -> PyResult<PyDFSchema> {
+        DFSchema::try_from(self.plan().schema())
+            .map(PyDFSchema::from)
+            .map_err(py_datafusion_err)
+    }
+
+    fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
+        if self.plan().output_partitioning().partition_count() == desired_parallelism {
+            return Ok(());
+        }
+        let updated_plan = self
+            .plan()
+            .clone()
+            .transform_up(|node| {
+                if let Some(exec) = node.as_any().downcast_ref::<DataSourceExec>() {
+                    // Remove redundant ranges from partition files because FileScanConfig refuses to
+                    // repartition if any file has a range defined (even when the range actually covers
+                    // the entire file). The EnforceDistribution optimizer rule adds ranges for both
+                    // full and partial files, so this tries to revert that in order to trigger a
+                    // repartition when no files are actually split.
+                    if let Some(file_scan) =
+                        exec.data_source().as_any().downcast_ref::<FileScanConfig>()
+                    {
+                        let mut range_free_file_scan = file_scan.clone();
+                        for group in range_free_file_scan.file_groups.iter_mut() {
+                            for file in group.iter_mut() {
+                                if let Some(range) = &file.range {
+                                    if range.start == 0 && range.end == file.object_meta.size as i64 {
+                                        file.range = None; // remove redundant range
+                                    }
+                                }
+                            }
+                        }
+                        let ordering = range_free_file_scan.eq_properties().output_ordering();
+                        if let Some(repartitioned) = range_free_file_scan.repartitioned(
+                            desired_parallelism,
+                            self.repartition_file_min_size,
+                            ordering,
+                        )? {
+                            return Ok(Transformed::yes(Arc::new(DataSourceExec::new(repartitioned))));
+                        }
+                    }
+                }
+                Ok(Transformed::no(node))
+            })
+            .map_err(py_datafusion_err)?
+            .data;
+        self.physical_plan = PyExecutionPlan::new(updated_plan);
+        Ok(())
+    }
+}
+
+impl DistributedPlan {
+    async fn try_new(df: &DataFrame) -> Result<Self, DataFusionError> {
+        let (mut session_state, logical_plan) = df.clone().into_parts();
+        let repartition_file_min_size =
+            session_state.config_options().optimizer.repartition_file_min_size;
+        // Create the physical plan with a single partition, to ensure that no files are split into
+        // ranges. Otherwise, any subsequent repartition attempt would fail (see the comment in
+        // `set_desired_parallelism`).
+        session_state.config_mut().options_mut().execution.target_partitions = 1;
+        let physical_plan = session_state.create_physical_plan(&logical_plan).await?;
+        let physical_plan = PyExecutionPlan::new(physical_plan);
+        Ok(Self {
+            repartition_file_min_size,
+            physical_plan,
+        })
+    }
+
+    fn plan(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.physical_plan.plan
+    }
+
+    fn stats_field(&self, field: fn(Statistics) -> Precision<usize>) -> Option<usize> {
+        if let Ok(stats) = self.physical_plan.plan.statistics() {
+            match field(stats) {
+                Precision::Exact(n) => Some(n),
+                _ => None,
+            }
+        } else {
+            None
+        }
+    }
+}
+
+#[pyfunction]
+pub fn partition_stream(
+    serialized_plan: &[u8],
+    partition: usize,
+    py: Python,
+) -> PyResult<PyRecordBatchStream> {
+    deltalake::ensure_initialized();
+    let node = PhysicalPlanNode::decode(serialized_plan)
+        .map_err(|e| DataFusionError::External(Box::new(e)))
+        .map_err(py_datafusion_err)?;
+    let ctx = SessionContext::new();
+    let plan = node
+        .try_into_physical_plan(&ctx, ctx.runtime_env().as_ref(), codec())
+        .map_err(py_datafusion_err)?;
+    let stream_with_runtime = get_tokio_runtime()
+        .0
+        .spawn(async move { plan.execute(partition, ctx.task_ctx()) });
+    wait_for_future(py, stream_with_runtime)
+        .map_err(py_datafusion_err)?
+        .map(PyRecordBatchStream::new)
+        .map_err(py_datafusion_err)
+}
+
+fn codec() -> &'static dyn PhysicalExtensionCodec {
+    static CODEC: DeltaPhysicalCodec = DeltaPhysicalCodec {};
+    &CODEC
 }

 /// Print DataFrame
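Given the commit's Ray focus, a hedged end-to-end sketch of the intended driver/worker split (the Ray wiring, the S3 path, and the to_pyarrow() conversion are illustrative assumptions, not code from this commit; it also assumes the returned record batch stream is iterated synchronously):

    import ray

    from datafusion import SessionContext, partition_stream

    @ray.remote
    def run_partition(plan_bytes, partition: int):
        # Worker: decode the protobuf-encoded plan and stream one partition's batches.
        return [batch.to_pyarrow() for batch in partition_stream(plan_bytes, partition)]

    # Driver: plan once, repartition to the desired fan-out, then serialize.
    ctx = SessionContext()
    plan = ctx.read_parquet("s3://bucket/table/").distributed_plan()  # illustrative source
    plan.set_desired_parallelism(16)  # best effort; the scan may keep fewer partitions
    plan_bytes = plan.serialize()

    futures = [run_partition.remote(plan_bytes, i) for i in range(plan.partition_count())]
    batches = [b for worker_result in ray.get(futures) for b in worker_result]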

src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     #[cfg(feature = "substrait")]
     setup_substrait_module(py, &m)?;

+    m.add_class::<dataframe::DistributedPlan>()?;
+    m.add_wrapped(wrap_pyfunction!(dataframe::partition_stream))?;
     Ok(())
 }
