Skip to content

Commit 18627a0

Browse files
committed
fixup! [HSTACK] Ray datasource support
1 parent 232d7f6 commit 18627a0

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ classifiers = [
4646
]
4747
dependencies = ["pyarrow>=11.0.0", "typing-extensions;python_version<'3.13'"]
4848
#dynamic = ["version"]
49-
version = "46.0.0+adobe.2"
49+
version = "46.0.0+adobe.3"
5050

5151
[project.urls]
5252
homepage = "https://datafusion.apache.org/python"

src/dataframe.rs

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::cmp::{max, min};
1819
use std::ffi::CString;
1920
use std::sync::Arc;
2021

@@ -48,7 +49,7 @@ use prost::Message;
4849
use pyo3::exceptions::PyValueError;
4950
use pyo3::prelude::*;
5051
use pyo3::pybacked::PyBackedStr;
51-
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
52+
use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyTupleMethods};
5253
use tokio::task::JoinHandle;
5354

5455
use crate::catalog::PyTable;
@@ -717,17 +718,39 @@ impl PyDataFrame {
717718
#[pyclass(get_all)]
718719
#[derive(Debug, Clone)]
719720
pub struct DistributedPlan {
720-
repartition_file_min_size: usize,
721+
min_size: usize,
721722
physical_plan: PyExecutionPlan,
722723
}
723724

724725
#[pymethods]
725726
impl DistributedPlan {
726727

727-
fn serialize(&self) -> PyResult<Vec<u8>> {
728-
PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
728+
fn marshal(&self, py: Python) -> PyResult<PyObject> {
729+
let bytes = PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
729730
.map(|node| node.encode_to_vec())
730-
.map_err(py_datafusion_err)
731+
.map_err(py_datafusion_err)?;
732+
let state = PyDict::new(py);
733+
state.set_item("plan", PyBytes::new(py, bytes.as_slice()))?;
734+
state.set_item("min_size", self.min_size)?;
735+
Ok(state.into())
736+
}
737+
738+
#[new]
739+
fn unmarshal(state: Bound<PyDict>) -> PyResult<Self>{
740+
let ctx = SessionContext::new();
741+
let serialized_plan = state.get_item("plan")?
742+
.expect("missing key `plan` from state");
743+
let serialized_plan = serialized_plan
744+
.downcast::<PyBytes>()?
745+
.as_bytes();
746+
let min_size = state.get_item("min_size")?
747+
.expect("missing key `min_size` from state")
748+
.extract::<usize>()?;
749+
let plan = deserialize_plan(serialized_plan, &ctx)?;
750+
Ok(Self {
751+
min_size,
752+
physical_plan: PyExecutionPlan::new(plan),
753+
})
731754
}
732755

733756
fn partition_count(&self) -> usize {
@@ -749,9 +772,6 @@ impl DistributedPlan {
749772
}
750773

751774
fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
752-
if self.plan().output_partitioning().partition_count() == desired_parallelism {
753-
return Ok(())
754-
}
755775
let updated_plan = self.plan().clone().transform_up(|node| {
756776
if let Some(exec) = node.as_any().downcast_ref::<DataSourceExec>() {
757777
// Remove redundant ranges from partition files because FileScanConfig refuses to repartition
@@ -760,18 +780,24 @@ impl DistributedPlan {
760780
// so this tries to revert that in order to trigger a repartition when no files are actually split.
761781
if let Some(file_scan) = exec.data_source().as_any().downcast_ref::<FileScanConfig>() {
762782
let mut range_free_file_scan = file_scan.clone();
783+
let mut total_size: usize = 0;
763784
for group in range_free_file_scan.file_groups.iter_mut() {
764785
for file in group.iter_mut() {
765786
if let Some(range) = &file.range {
787+
total_size += (range.end - range.start) as usize;
766788
if range.start == 0 && range.end == file.object_meta.size as i64 {
767789
file.range = None; // remove redundant range
768790
}
791+
} else {
792+
total_size += file.object_meta.size;
769793
}
770794
}
771795
}
796+
let min_size_buckets = max(1, total_size.div_ceil(self.min_size));
797+
let partitions = min(min_size_buckets, desired_parallelism);
772798
let ordering = range_free_file_scan.eq_properties().output_ordering();
773799
if let Some(repartitioned) = range_free_file_scan
774-
.repartitioned(desired_parallelism, self.repartition_file_min_size, ordering)? {
800+
.repartitioned(partitions, 1, ordering)? {
775801
return Ok(Transformed::yes(Arc::new(DataSourceExec::new(repartitioned))))
776802
}
777803
}
@@ -787,14 +813,14 @@ impl DistributedPlan {
787813

788814
async fn try_new(df: &DataFrame) -> Result<Self, DataFusionError> {
789815
let (mut session_state, logical_plan) = df.clone().into_parts();
790-
let repartition_file_min_size = session_state.config_options().optimizer.repartition_file_min_size;
816+
let min_size = session_state.config_options().optimizer.repartition_file_min_size;
791817
// Create the physical plan with a single partition, to ensure that no files are split into ranges.
792818
// Otherwise, any subsequent repartition attempt would fail (see the comment in `set_desired_parallelism`)
793819
session_state.config_mut().options_mut().execution.target_partitions = 1;
794820
let physical_plan = session_state.create_physical_plan(&logical_plan).await?;
795821
let physical_plan = PyExecutionPlan::new(physical_plan);
796822
Ok(Self {
797-
repartition_file_min_size,
823+
min_size,
798824
physical_plan,
799825
})
800826
}
@@ -816,15 +842,20 @@ impl DistributedPlan {
816842

817843
}
818844

819-
#[pyfunction]
820-
pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
845+
fn deserialize_plan(serialized_plan: &[u8], ctx: &SessionContext) -> PyResult<Arc<dyn ExecutionPlan>> {
821846
deltalake::ensure_initialized();
822847
let node = PhysicalPlanNode::decode(serialized_plan)
823848
.map_err(|e| DataFusionError::External(Box::new(e)))
824849
.map_err(py_datafusion_err)?;
825-
let ctx = SessionContext::new();
826-
let plan = node.try_into_physical_plan(&ctx, ctx.runtime_env().as_ref(), codec())
850+
let plan = node.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), codec())
827851
.map_err(py_datafusion_err)?;
852+
Ok(plan)
853+
}
854+
855+
#[pyfunction]
856+
pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
857+
let ctx = SessionContext::new();
858+
let plan = deserialize_plan(serialized_plan, &ctx)?;
828859
let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
829860
plan.execute(partition, ctx.task_ctx())
830861
});

0 commit comments

Comments
 (0)