1616// under the License.
1717
1818use std:: collections:: HashMap ;
19+ use std:: cmp:: { max, min} ;
1920use std:: ffi:: CString ;
21+ use std:: ops:: IndexMut ;
2022use std:: sync:: Arc ;
2123
2224use arrow:: array:: { new_null_array, RecordBatch , RecordBatchIterator , RecordBatchReader } ;
@@ -28,26 +30,35 @@ use arrow::pyarrow::FromPyArrow;
2830use datafusion:: arrow:: datatypes:: Schema ;
2931use datafusion:: arrow:: pyarrow:: { PyArrowType , ToPyArrow } ;
3032use datafusion:: arrow:: util:: pretty;
31- use datafusion:: common:: UnnestOptions ;
33+ use datafusion:: common:: { DFSchema , Statistics , UnnestOptions } ;
34+ use datafusion:: common:: stats:: Precision ;
35+ use datafusion:: common:: tree_node:: { Transformed , TreeNode } ;
3236use datafusion:: config:: { CsvOptions , ParquetColumnOptions , ParquetOptions , TableParquetOptions } ;
3337use datafusion:: dataframe:: { DataFrame , DataFrameWriteOptions } ;
38+ use datafusion:: datasource:: physical_plan:: FileScanConfig ;
39+ use datafusion:: datasource:: source:: { DataSource , DataSourceExec } ;
3440use datafusion:: datasource:: TableProvider ;
3541use datafusion:: error:: DataFusionError ;
3642use datafusion:: execution:: SendableRecordBatchStream ;
3743use datafusion:: parquet:: basic:: { BrotliLevel , Compression , GzipLevel , ZstdLevel } ;
44+ use datafusion:: physical_plan:: ExecutionPlan ;
3845use datafusion:: prelude:: * ;
3946use datafusion_ffi:: table_provider:: FFI_TableProvider ;
47+ use datafusion:: sql:: unparser:: plan_to_sql;
48+ use datafusion_proto:: physical_plan:: AsExecutionPlan ;
49+ use datafusion_proto:: protobuf:: PhysicalPlanNode ;
4050use futures:: { StreamExt , TryStreamExt } ;
51+ use prost:: Message ;
4152use pyo3:: exceptions:: PyValueError ;
4253use pyo3:: prelude:: * ;
4354use pyo3:: pybacked:: PyBackedStr ;
44- use pyo3:: types:: { PyCapsule , PyList , PyTuple , PyTupleMethods } ;
55+ use pyo3:: types:: { PyBytes , PyCapsule , PyDict , PyList , PyString , PyTuple , PyTupleMethods } ;
4556use tokio:: task:: JoinHandle ;
4657
4758use crate :: catalog:: PyTable ;
4859use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionError } ;
4960use crate :: expr:: sort_expr:: to_sort_expressions;
50- use crate :: physical_plan:: PyExecutionPlan ;
61+ use crate :: physical_plan:: { codec , PyExecutionPlan } ;
5162use crate :: record_batch:: PyRecordBatchStream ;
5263use crate :: sql:: logical:: PyLogicalPlan ;
5364use crate :: utils:: {
@@ -57,6 +68,7 @@ use crate::{
5768 errors:: PyDataFusionResult ,
5869 expr:: { sort_expr:: PySortExpr , PyExpr } ,
5970} ;
71+ use crate :: common:: df_schema:: PyDFSchema ;
6072
6173// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
6274// - we have not decided on the table_provider approach yet
@@ -992,6 +1004,168 @@ impl PyDataFrame {
9921004 let df = self . df . as_ref ( ) . clone ( ) . fill_null ( scalar_value, cols) ?;
9931005 Ok ( Self :: new ( df) )
9941006 }
1007+
1008+ fn distributed_plan ( & self , py : Python < ' _ > ) -> PyResult < DistributedPlan > {
1009+ let future_plan = DistributedPlan :: try_new ( self . df . as_ref ( ) ) ;
1010+ wait_for_future ( py, future_plan) ?. map_err ( py_datafusion_err)
1011+ }
1012+
1013+ fn plan_sql ( & self , py : Python < ' _ > ) -> PyResult < PyObject > {
1014+ let logical_plan = self . df . logical_plan ( ) ;
1015+
1016+ let sql = plan_to_sql ( logical_plan) . map_err ( py_datafusion_err) ?;
1017+ Ok ( PyString :: new ( py, sql. to_string ( ) . as_ref ( ) ) . into ( ) )
1018+ }
1019+ }
1020+
1021+ #[ pyclass( get_all) ]
1022+ #[ derive( Debug , Clone ) ]
1023+ pub struct DistributedPlan {
1024+ min_size : usize ,
1025+ physical_plan : PyExecutionPlan ,
1026+ }
1027+
1028+ #[ pymethods]
1029+ impl DistributedPlan {
1030+ #[ new]
1031+ fn unmarshal ( state : Bound < PyDict > ) -> PyResult < Self > {
1032+ let ctx = SessionContext :: new ( ) ;
1033+ let serialized_plan = state
1034+ . get_item ( "plan" ) ?
1035+ . expect ( "missing key `plan` from state" ) ;
1036+ let serialized_plan = serialized_plan. downcast :: < PyBytes > ( ) ?. as_bytes ( ) ;
1037+ let min_size = state
1038+ . get_item ( "min_size" ) ?
1039+ . expect ( "missing key `min_size` from state" )
1040+ . extract :: < usize > ( ) ?;
1041+ let plan = deserialize_plan ( serialized_plan, & ctx) ?;
1042+ Ok ( Self {
1043+ min_size,
1044+ physical_plan : PyExecutionPlan :: new ( plan) ,
1045+ } )
1046+ }
1047+
1048+ fn partition_count ( & self ) -> usize {
1049+ self . physical_plan . partition_count ( )
1050+ }
1051+
1052+ fn num_bytes ( & self ) -> Option < usize > {
1053+ self . stats_field ( |stats| stats. total_byte_size )
1054+ }
1055+
1056+ fn num_rows ( & self ) -> Option < usize > {
1057+ self . stats_field ( |stats| stats. num_rows )
1058+ }
1059+
1060+ fn schema ( & self ) -> PyResult < PyDFSchema > {
1061+ DFSchema :: try_from ( self . plan ( ) . schema ( ) )
1062+ . map ( PyDFSchema :: from)
1063+ . map_err ( py_datafusion_err)
1064+ }
1065+
1066+ fn set_desired_parallelism ( & mut self , desired_parallelism : usize ) -> PyResult < ( ) > {
1067+ let updated_plan = self
1068+ . plan ( )
1069+ . clone ( )
1070+ . transform_up ( |node| {
1071+ if let Some ( exec) = node. as_any ( ) . downcast_ref :: < DataSourceExec > ( ) {
1072+ // Remove redundant ranges from partition files because FileScanConfig refuses to repartition
1073+ // if any file has a range defined (even when the range actually covers the entire file).
1074+ // The EnforceDistribution optimizer rule adds ranges for both full and partial files,
1075+ // so this tries to revert that in order to trigger a repartition when no files are actually split.
1076+ // TODO: check whether EnforceDistribution is still adding redundant ranges and remove this
1077+ // workaround if no longer needed.
1078+ if let Some ( file_scan) =
1079+ exec. data_source ( ) . as_any ( ) . downcast_ref :: < FileScanConfig > ( )
1080+ {
1081+ let mut range_free_file_scan = file_scan. clone ( ) ;
1082+ let mut total_size: usize = 0 ;
1083+ for group in range_free_file_scan. file_groups . iter_mut ( ) {
1084+ for group_idx in 0 ..group. len ( ) {
1085+ let file = group. index_mut ( group_idx) ;
1086+ if let Some ( range) = & file. range {
1087+ total_size += ( range. end - range. start ) as usize ;
1088+ if range. start == 0 && range. end == file. object_meta . size as i64
1089+ {
1090+ file. range = None ; // remove redundant range
1091+ }
1092+ } else {
1093+ total_size += file. object_meta . size as usize ;
1094+ }
1095+
1096+ }
1097+ }
1098+ let min_size_buckets = max ( 1 , total_size. div_ceil ( self . min_size ) ) ;
1099+ let partitions = min ( min_size_buckets, desired_parallelism) ;
1100+ let ordering = range_free_file_scan. eq_properties ( ) . output_ordering ( ) ;
1101+ if let Some ( repartitioned) =
1102+ range_free_file_scan. repartitioned ( partitions, 1 , ordering) ?
1103+ {
1104+ return Ok ( Transformed :: yes ( Arc :: new ( DataSourceExec :: new (
1105+ repartitioned,
1106+ ) ) ) ) ;
1107+ }
1108+ }
1109+ }
1110+ Ok ( Transformed :: no ( node) )
1111+ } )
1112+ . map_err ( py_datafusion_err) ?
1113+ . data ;
1114+ self . physical_plan = PyExecutionPlan :: new ( updated_plan) ;
1115+ Ok ( ( ) )
1116+ }
1117+ }
1118+
1119+ impl DistributedPlan {
1120+ async fn try_new ( df : & DataFrame ) -> Result < Self , DataFusionError > {
1121+ let ( mut session_state, logical_plan) = df. clone ( ) . into_parts ( ) ;
1122+ let min_size = session_state
1123+ . config_options ( )
1124+ . optimizer
1125+ . repartition_file_min_size ;
1126+ // Create the physical plan with a single partition, to ensure that no files are split into ranges.
1127+ // Otherwise, any subsequent repartition attempt would fail (see the comment in `set_desired_parallelism`)
1128+ session_state
1129+ . config_mut ( )
1130+ . options_mut ( )
1131+ . execution
1132+ . target_partitions = 1 ;
1133+ let physical_plan = session_state. create_physical_plan ( & logical_plan) . await ?;
1134+ let physical_plan = PyExecutionPlan :: new ( physical_plan) ;
1135+ Ok ( Self {
1136+ min_size,
1137+ physical_plan,
1138+ } )
1139+ }
1140+
1141+ fn plan ( & self ) -> & Arc < dyn ExecutionPlan > {
1142+ & self . physical_plan . plan
1143+ }
1144+
1145+ fn stats_field ( & self , field : fn ( Statistics ) -> Precision < usize > ) -> Option < usize > {
1146+ if let Ok ( stats) = self . plan ( ) . partition_statistics ( None ) {
1147+ match field ( stats) {
1148+ Precision :: Exact ( n) => Some ( n) ,
1149+ _ => None ,
1150+ }
1151+ } else {
1152+ None
1153+ }
1154+ }
1155+ }
1156+
1157+ fn deserialize_plan (
1158+ serialized_plan : & [ u8 ] ,
1159+ ctx : & SessionContext ,
1160+ ) -> PyResult < Arc < dyn ExecutionPlan > > {
1161+ deltalake:: ensure_initialized ( ) ;
1162+ let node = PhysicalPlanNode :: decode ( serialized_plan)
1163+ . map_err ( |e| DataFusionError :: External ( Box :: new ( e) ) )
1164+ . map_err ( py_datafusion_err) ?;
1165+ let plan = node
1166+ . try_into_physical_plan ( ctx, ctx. runtime_env ( ) . as_ref ( ) , codec ( ) )
1167+ . map_err ( py_datafusion_err) ?;
1168+ Ok ( plan)
9951169}
9961170
9971171/// Print DataFrame
0 commit comments