// specific language governing permissions and limitations
// under the License.

+use std::cmp::{max, min};
use std::ffi::CString;
use std::sync::Arc;

@@ -48,7 +49,7 @@ use prost::Message;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
+use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;

use crate::catalog::PyTable;
@@ -717,17 +718,39 @@ impl PyDataFrame {
#[pyclass(get_all)]
#[derive(Debug, Clone)]
pub struct DistributedPlan {
-    repartition_file_min_size: usize,
+    min_size: usize,
    physical_plan: PyExecutionPlan,
}

#[pymethods]
impl DistributedPlan {

-    fn serialize(&self) -> PyResult<Vec<u8>> {
-        PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
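+    /// Captures the plan as a state dict: the protobuf-encoded physical plan under
+    /// "plan" plus "min_size", so it can cross process boundaries (e.g. via pickle).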
+    fn marshal(&self, py: Python) -> PyResult<PyObject> {
+        let bytes = PhysicalPlanNode::try_from_physical_plan(self.plan().clone(), codec())
            .map(|node| node.encode_to_vec())
-            .map_err(py_datafusion_err)
+            .map_err(py_datafusion_err)?;
+        let state = PyDict::new(py);
+        state.set_item("plan", PyBytes::new(py, bytes.as_slice()))?;
+        state.set_item("min_size", self.min_size)?;
+        Ok(state.into())
+    }
+
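+    /// Constructor counterpart of `marshal`: rebuilds the plan from the state dict,
+    /// decoding the protobuf bytes inside a fresh `SessionContext`.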
+    #[new]
+    fn unmarshal(state: Bound<PyDict>) -> PyResult<Self> {
+        let ctx = SessionContext::new();
+        let serialized_plan = state.get_item("plan")?
+            .expect("missing key `plan` from state");
+        let serialized_plan = serialized_plan
+            .downcast::<PyBytes>()?
+            .as_bytes();
+        let min_size = state.get_item("min_size")?
+            .expect("missing key `min_size` from state")
+            .extract::<usize>()?;
+        let plan = deserialize_plan(serialized_plan, &ctx)?;
+        Ok(Self {
+            min_size,
+            physical_plan: PyExecutionPlan::new(plan),
+        })
    }

    fn partition_count(&self) -> usize {
@@ -749,9 +772,6 @@ impl DistributedPlan {
    }

    fn set_desired_parallelism(&mut self, desired_parallelism: usize) -> PyResult<()> {
-        if self.plan().output_partitioning().partition_count() == desired_parallelism {
-            return Ok(())
-        }
        let updated_plan = self.plan().clone().transform_up(|node| {
            if let Some(exec) = node.as_any().downcast_ref::<DataSourceExec>() {
                // Remove redundant ranges from partition files because FileScanConfig refuses to repartition
@@ -760,18 +780,24 @@ impl DistributedPlan {
                // so this tries to revert that in order to trigger a repartition when no files are actually split.
                if let Some(file_scan) = exec.data_source().as_any().downcast_ref::<FileScanConfig>() {
                    let mut range_free_file_scan = file_scan.clone();
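+                    // Sum the bytes this scan covers so the partition count can be
+                    // capped by data size below.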
+                    let mut total_size: usize = 0;
                    for group in range_free_file_scan.file_groups.iter_mut() {
                        for file in group.iter_mut() {
                            if let Some(range) = &file.range {
+                                total_size += (range.end - range.start) as usize;
                                if range.start == 0 && range.end == file.object_meta.size as i64 {
                                    file.range = None; // remove redundant range
                                }
+                            } else {
+                                total_size += file.object_meta.size;
                            }
                        }
                    }
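+                    // Size-based cap: e.g. with min_size = 1 GiB and 10 GiB of input,
+                    // at most 10 partitions are produced regardless of desired_parallelism.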
+                    let min_size_buckets = max(1, total_size.div_ceil(self.min_size));
+                    let partitions = min(min_size_buckets, desired_parallelism);
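+                    // `min_size` is already folded into `partitions`, so pass 1 as the
+                    // per-file minimum to let `repartitioned` split freely.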
                    let ordering = range_free_file_scan.eq_properties().output_ordering();
                    if let Some(repartitioned) = range_free_file_scan
-                        .repartitioned(desired_parallelism, self.repartition_file_min_size, ordering)? {
+                        .repartitioned(partitions, 1, ordering)? {
                        return Ok(Transformed::yes(Arc::new(DataSourceExec::new(repartitioned))))
                    }
                }
@@ -787,14 +813,14 @@ impl DistributedPlan {

    async fn try_new(df: &DataFrame) -> Result<Self, DataFusionError> {
        let (mut session_state, logical_plan) = df.clone().into_parts();
-        let repartition_file_min_size = session_state.config_options().optimizer.repartition_file_min_size;
+        let min_size = session_state.config_options().optimizer.repartition_file_min_size;
        // Create the physical plan with a single partition, to ensure that no files are split into ranges.
        // Otherwise, any subsequent repartition attempt would fail (see the comment in `set_desired_parallelism`)
        session_state.config_mut().options_mut().execution.target_partitions = 1;
        let physical_plan = session_state.create_physical_plan(&logical_plan).await?;
        let physical_plan = PyExecutionPlan::new(physical_plan);
        Ok(Self {
-            repartition_file_min_size,
+            min_size,
            physical_plan,
        })
    }
@@ -816,15 +842,20 @@ impl DistributedPlan {

}

-#[pyfunction]
-pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
+fn deserialize_plan(serialized_plan: &[u8], ctx: &SessionContext) -> PyResult<Arc<dyn ExecutionPlan>> {
    deltalake::ensure_initialized();
    let node = PhysicalPlanNode::decode(serialized_plan)
        .map_err(|e| DataFusionError::External(Box::new(e)))
        .map_err(py_datafusion_err)?;
-    let ctx = SessionContext::new();
-    let plan = node.try_into_physical_plan(&ctx, ctx.runtime_env().as_ref(), codec())
+    let plan = node.try_into_physical_plan(ctx, ctx.runtime_env().as_ref(), codec())
        .map_err(py_datafusion_err)?;
+    Ok(plan)
+}
+
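+/// Runs a single partition of a protobuf-serialized plan and returns its record
+/// batch stream; intended to be invoked once per (worker, partition) pair.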
+#[pyfunction]
+pub fn partition_stream(serialized_plan: &[u8], partition: usize, py: Python) -> PyResult<PyRecordBatchStream> {
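+    // Each call builds its own SessionContext; the serialized plan carries everything
+    // needed, keeping workers independent of the driver's session.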
+    let ctx = SessionContext::new();
+    let plan = deserialize_plan(serialized_plan, &ctx)?;
    let stream_with_runtime = get_tokio_runtime().0.spawn(async move {
        plan.execute(partition, ctx.task_ctx())
    });