@@ -27,18 +27,27 @@ use arrow::util::display::{ArrayFormatter, FormatOptions};
2727use datafusion:: arrow:: datatypes:: Schema ;
2828use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
2929use datafusion:: arrow:: util:: pretty;
30- use datafusion:: common:: UnnestOptions ;
31- use datafusion:: config:: { CsvOptions , TableParquetOptions } ;
30+ use datafusion:: common:: stats:: Precision ;
31+ use datafusion:: common:: { DFSchema , DataFusionError , UnnestOptions } ;
32+ use datafusion:: config:: { ConfigOptions , CsvOptions , TableParquetOptions } ;
3233use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
34+ use datafusion:: execution:: runtime_env:: RuntimeEnvBuilder ;
3335use datafusion:: execution:: SendableRecordBatchStream ;
3436use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
37+ use datafusion:: physical_plan:: ExecutionPlan ;
3538use datafusion:: prelude:: * ;
39+
40+ use datafusion_proto:: physical_plan:: { AsExecutionPlan , PhysicalExtensionCodec } ;
41+ use datafusion_proto:: protobuf:: PhysicalPlanNode ;
42+ use deltalake:: delta_datafusion:: DeltaPhysicalCodec ;
43+ use prost:: Message ;
3644use pyo3:: exceptions:: PyValueError ;
3745use pyo3:: prelude:: * ;
3846use pyo3:: pybacked:: PyBackedStr ;
3947use pyo3:: types:: { PyCapsule , PyTuple , PyTupleMethods } ;
4048use tokio:: task:: JoinHandle ;
4149
50+ use crate :: common:: df_schema:: PyDFSchema ;
4251use crate :: errors:: { py_datafusion_err, PyDataFusionError } ;
4352use crate :: expr:: sort_expr:: to_sort_expressions;
4453use crate :: physical_plan:: PyExecutionPlan ;
@@ -659,6 +668,75 @@ impl PyDataFrame {
659668 fn count ( & self , py : Python ) -> PyDataFusionResult < usize > {
660669 Ok ( wait_for_future ( py, self . df . as_ref ( ) . clone ( ) . count ( ) ) ?)
661670 }
671+
672+ fn distributed_plan ( & self , parallelism : usize , py : Python < ' _ > ) -> PyResult < DistributedPlan > {
673+ let future_plan = self . df . as_ref ( ) . clone ( ) . create_physical_plan ( ) ;
674+ let physical_plan = wait_for_future ( py, future_plan) . map_err ( py_datafusion_err) ?;
675+ DistributedPlan :: try_new ( physical_plan, parallelism) . map_err ( py_datafusion_err)
676+ }
677+
678+ }
679+
680+ #[ pyclass( get_all) ]
681+ #[ derive( Debug , Clone ) ]
682+ pub struct DistributedPlan {
683+ physical_plan : Vec < u8 > ,
684+ schema : PyDFSchema ,
685+ partitions : usize ,
686+ num_bytes : Option < usize > ,
687+ num_rows : Option < usize > ,
688+ }
689+
690+ fn codec ( ) -> & ' static dyn PhysicalExtensionCodec {
691+ static CODEC : DeltaPhysicalCodec = DeltaPhysicalCodec { } ;
692+ & CODEC
693+ }
694+
695+ impl DistributedPlan {
696+ fn try_new ( plan : Arc < dyn ExecutionPlan > , parallelism : usize ) -> Result < Self , DataFusionError > {
697+ fn extract ( prec : Precision < usize > ) -> Option < usize > {
698+ match prec {
699+ Precision :: Exact ( n) => Some ( n) ,
700+ _ => None ,
701+ }
702+ }
703+ let ( num_bytes, num_rows) = if let Ok ( stats) = plan. statistics ( ) {
704+ let num_bytes = extract ( stats. total_byte_size ) ;
705+ let num_rows = extract ( stats. num_rows ) ;
706+ ( num_bytes, num_rows)
707+ } else {
708+ ( None , None )
709+ } ;
710+
711+ let schema = DFSchema :: try_from ( plan. schema ( ) )
712+ . map ( PyDFSchema :: from) ?;
713+ let plan = plan. repartitioned ( parallelism, & ConfigOptions :: default ( ) ) ?
714+ . unwrap_or ( plan) ;
715+ let partitions = plan. properties ( ) . partitioning . partition_count ( ) ;
716+ let physical_plan = PhysicalPlanNode :: try_from_physical_plan ( plan, codec ( ) ) ?
717+ . encode_to_vec ( ) ;
718+ Ok ( Self { physical_plan, schema, partitions, num_bytes, num_rows } )
719+ }
720+
721+ }
722+
723+ #[ pyfunction]
724+ pub fn partition_stream ( serialized_plan : & [ u8 ] , partition : usize , py : Python ) -> PyResult < PyRecordBatchStream > {
725+ deltalake:: ensure_initialized ( ) ;
726+ let ctx = SessionContext :: new ( ) ;
727+ let runtime = RuntimeEnvBuilder :: new ( ) . build ( ) . map_err ( py_datafusion_err) ?;
728+ let node = PhysicalPlanNode :: decode ( serialized_plan)
729+ . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) )
730+ . map_err ( py_datafusion_err) ?;
731+ let plan = node. try_into_physical_plan ( & ctx, & runtime, codec ( ) )
732+ . map_err ( py_datafusion_err) ?;
733+ let stream_with_runtime = get_tokio_runtime ( ) . 0 . spawn ( async move {
734+ plan. execute ( partition, ctx. task_ctx ( ) )
735+ } ) ;
736+ wait_for_future ( py, stream_with_runtime)
737+ . map_err ( py_datafusion_err) ?
738+ . map ( PyRecordBatchStream :: new)
739+ . map_err ( py_datafusion_err)
662740}
663741
664742/// Print DataFrame
0 commit comments