@@ -20,14 +20,13 @@ use crate::protobuf::{RayShuffleReaderExecNode, RayShuffleWriterExecNode, RaySql
 use crate::shuffle::{RayShuffleReaderExec, RayShuffleWriterExec};
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
-use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
 use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
+use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
-use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
-use datafusion_proto::protobuf::{self, PhysicalHashRepartition, PhysicalPlanNode};
+use datafusion_proto::protobuf::{self, PhysicalHashRepartition};
 use prost::Message;
 use std::sync::Arc;
 
@@ -38,48 +37,62 @@ impl PhysicalExtensionCodec for ShuffleCodec {
     fn try_decode(
         &self,
         buf: &[u8],
-        _inputs: &[Arc<dyn ExecutionPlan>],
+        inputs: &[Arc<dyn ExecutionPlan>],
         registry: &dyn FunctionRegistry,
     ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
         // decode bytes to protobuf struct
         let node = RaySqlExecNode::decode(buf)
             .map_err(|e| DataFusionError::Internal(format!("failed to decode plan: {e:?}")))?;
         let extension_codec = DefaultPhysicalExtensionCodec {};
-        match node.plan_type {
-            Some(PlanType::RayShuffleReader(reader)) => {
-                let schema = reader.schema.as_ref().unwrap();
-                let schema: SchemaRef = Arc::new(schema.try_into().unwrap());
-                let hash_part = parse_protobuf_hash_partitioning(
-                    reader.partitioning.as_ref(),
-                    registry,
-                    &schema,
-                    &extension_codec,
-                )?;
-                Ok(Arc::new(RayShuffleReaderExec::new(
-                    reader.stage_id as usize,
-                    schema,
-                    hash_part.unwrap(),
-                )))
+        if let Some(plan_type) = node.plan_type {
+            match plan_type {
+                PlanType::RayShuffleReader(reader) => {
+                    let schema = reader.schema.as_ref().ok_or_else(|| {
+                        DataFusionError::Execution("invalid encoded schema".into())
+                    })?;
+                    let schema: SchemaRef = Arc::new(schema.try_into()?);
+                    let hash_part = parse_protobuf_hash_partitioning(
+                        reader.partitioning.as_ref(),
+                        registry,
+                        &schema,
+                        &extension_codec,
+                    )?
+                    .ok_or_else(|| {
+                        DataFusionError::Execution("missing partitioning info".into())
+                    })?;
+                    Ok(Arc::new(RayShuffleReaderExec::new(
+                        reader.stage_id as usize,
+                        schema,
+                        hash_part,
+                    )))
+                }
+                PlanType::RayShuffleWriter(writer) => {
+                    let plan = inputs
+                        .first()
+                        .ok_or_else(|| {
+                            DataFusionError::Execution("No inputs for shuffle writer".into())
+                        })?
+                        .to_owned();
+                    let hash_part = parse_protobuf_hash_partitioning(
+                        writer.partitioning.as_ref(),
+                        registry,
+                        plan.schema().as_ref(),
+                        &extension_codec,
+                    )?
+                    .ok_or_else(|| {
+                        DataFusionError::Execution("missing partitioning info".into())
+                    })?;
+                    Ok(Arc::new(RayShuffleWriterExec::new(
+                        writer.stage_id as usize,
+                        plan,
+                        hash_part,
+                    )))
+                }
             }
-            Some(PlanType::RayShuffleWriter(writer)) => {
-                let plan = writer.plan.unwrap().try_into_physical_plan(
-                    registry,
-                    &RuntimeEnv::default(),
-                    self,
-                )?;
-                let hash_part = parse_protobuf_hash_partitioning(
-                    writer.partitioning.as_ref(),
-                    registry,
-                    plan.schema().as_ref(),
-                    &extension_codec,
-                )?;
-                Ok(Arc::new(RayShuffleWriterExec::new(
-                    writer.stage_id as usize,
-                    plan,
-                    hash_part.unwrap(),
-                )))
-            }
-            _ => unreachable!(),
+        } else {
+            Err(DataFusionError::Execution(
+                "RaySqlExecNode with no plan_type".into(),
+            ))
         }
     }
 
@@ -88,7 +101,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
         node: Arc<dyn ExecutionPlan>,
         buf: &mut Vec<u8>,
     ) -> Result<(), DataFusionError> {
-        let plan = if let Some(reader) = node.as_any().downcast_ref::<RayShuffleReaderExec>() {
+        if let Some(reader) = node.as_any().downcast_ref::<RayShuffleReaderExec>() {
            let schema: protobuf::Schema = reader.schema().try_into().unwrap();
            let partitioning =
                encode_partitioning_scheme(reader.properties().output_partitioning())?;
@@ -97,22 +110,27 @@ impl PhysicalExtensionCodec for ShuffleCodec {
                 schema: Some(schema),
                 partitioning: Some(partitioning),
             };
-            PlanType::RayShuffleReader(reader)
+            PlanType::RayShuffleReader(reader).encode(buf);
+            Ok(())
         } else if let Some(writer) = node.as_any().downcast_ref::<RayShuffleWriterExec>() {
-            let plan = PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?;
             let partitioning =
                 encode_partitioning_scheme(writer.properties().output_partitioning())?;
             let writer = RayShuffleWriterExecNode {
                 stage_id: writer.stage_id as u32,
-                plan: Some(plan),
+                // No need to redundantly serialize the child plan, as input plan(s) are recursively
+                // serialized by PhysicalPlanNode and will be available as `inputs` in `try_decode`.
+                // TODO: remove this field from the proto definition?
+                plan: None,
                 partitioning: Some(partitioning),
             };
-            PlanType::RayShuffleWriter(writer)
+            PlanType::RayShuffleWriter(writer).encode(buf);
+            Ok(())
         } else {
-            unreachable!()
-        };
-        plan.encode(buf);
-        Ok(())
+            Err(DataFusionError::Execution(format!(
+                "Unsupported plan node: {}",
+                node.name()
+            )))
+        }
     }
 }
 
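For reviewers, here is a minimal round-trip sketch (not part of this change) of the flow the diff relies on: `PhysicalPlanNode` serializes the plan tree recursively, so a writer's child plan travels as the node's inputs rather than inside `RayShuffleWriterExecNode`, and it comes back as the `inputs` argument of `try_decode`. The sketch assumes `ShuffleCodec` (defined in this file) is a unit-like struct constructible as `ShuffleCodec {}` and that a `SessionContext` supplies the `FunctionRegistry`; names like `roundtrip` are illustrative only.

```rust
// Sketch only: round-tripping a physical plan through the ShuffleCodec.
// Assumes `ShuffleCodec {}` constructs the codec and the plan contains
// RayShuffleReaderExec / RayShuffleWriterExec nodes.
use std::sync::Arc;

use datafusion::common::{DataFusionError, Result};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf::PhysicalPlanNode;
use prost::Message;

fn roundtrip(
    ctx: &SessionContext,
    plan: Arc<dyn ExecutionPlan>,
) -> Result<Arc<dyn ExecutionPlan>> {
    let codec = ShuffleCodec {};

    // Encode: PhysicalPlanNode walks the plan tree and delegates the Ray
    // shuffle extension nodes to ShuffleCodec::try_encode. Child plans are
    // serialized as the node's inputs, so try_encode only needs to emit the
    // shuffle-specific fields (stage id, schema, partitioning).
    let proto = PhysicalPlanNode::try_from_physical_plan(plan, &codec)?;
    let mut bytes = vec![];
    proto
        .encode(&mut bytes)
        .map_err(|e| DataFusionError::Internal(format!("failed to encode plan: {e:?}")))?;

    // Decode: child plans are reconstructed first and handed to
    // ShuffleCodec::try_decode via `inputs`, which is where the shuffle
    // writer now picks up its child plan.
    let proto = PhysicalPlanNode::decode(bytes.as_slice())
        .map_err(|e| DataFusionError::Internal(format!("failed to decode plan: {e:?}")))?;
    proto.try_into_physical_plan(&ctx.state(), &RuntimeEnv::default(), &codec)
}
```

Under that assumption, dropping the embedded `plan` field avoids serializing each child subtree twice.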