18
18
//! [`TreeNode`] for visiting and rewriting expression and plan trees
19
19
20
20
use crate :: Result ;
21
+ use std:: marker:: PhantomData ;
21
22
use std:: sync:: Arc ;
22
23
23
24
/// These macros are used to determine continuation during transforming traversals.
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect {
912
913
} }
913
914
}
914
915
915
- macro_rules! rewrite_recursive {
916
- ( $START: ident, $NAME: ident, $TRANSFORM_UP: expr, $TRANSFORM_DOWN: expr) => {
917
- let mut queue = vec![ ProcessingState :: NotStarted ( $START) ] ;
918
-
919
- while let Some ( item) = queue. pop( ) {
920
- match item {
921
- ProcessingState :: NotStarted ( $NAME) => {
922
- let node = $TRANSFORM_DOWN?;
923
-
924
- queue. push( match node. tnr {
925
- TreeNodeRecursion :: Continue => {
926
- ProcessingState :: ProcessingChildren {
927
- non_processed_children: node
928
- . data
929
- . arc_children( )
930
- . into_iter( )
931
- . cloned( )
932
- . rev( )
933
- . collect( ) ,
934
- item: node,
935
- processed_children: vec![ ] ,
936
- }
937
- }
938
- TreeNodeRecursion :: Jump => ProcessingState :: ProcessedAllChildren (
939
- node. with_tnr( TreeNodeRecursion :: Continue ) ,
940
- ) ,
941
- TreeNodeRecursion :: Stop => {
942
- ProcessingState :: ProcessedAllChildren ( node)
943
- }
944
- } )
945
- }
946
- ProcessingState :: ProcessingChildren {
947
- mut item,
948
- mut non_processed_children,
949
- mut processed_children,
950
- } => match item. tnr {
951
- TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
952
- if let Some ( non_processed_item) = non_processed_children. pop( ) {
953
- queue. push( ProcessingState :: ProcessingChildren {
954
- item,
955
- non_processed_children,
956
- processed_children,
957
- } ) ;
958
- queue. push( ProcessingState :: NotStarted ( non_processed_item) ) ;
959
- } else {
960
- item. transformed |=
961
- processed_children. iter( ) . any( |item| item. transformed) ;
962
- item. data = item. data. with_new_arc_children(
963
- processed_children. into_iter( ) . map( |c| c. data) . collect( ) ,
964
- ) ?;
965
- queue. push( ProcessingState :: ProcessedAllChildren ( item) )
966
- }
967
- }
968
- TreeNodeRecursion :: Stop => {
969
- processed_children. extend(
970
- non_processed_children
971
- . into_iter( )
972
- . rev( )
973
- . map( Transformed :: no) ,
974
- ) ;
975
- item. transformed |=
976
- processed_children. iter( ) . any( |item| item. transformed) ;
977
- item. data = item. data. with_new_arc_children(
978
- processed_children. into_iter( ) . map( |c| c. data) . collect( ) ,
979
- ) ?;
980
- queue. push( ProcessingState :: ProcessedAllChildren ( item) ) ;
981
- }
982
- } ,
983
- ProcessingState :: ProcessedAllChildren ( node) => {
984
- let node = node. transform_parent( |$NAME| $TRANSFORM_UP) ?;
985
-
986
- if let Some ( ProcessingState :: ProcessingChildren {
987
- item: mut parent_node,
988
- non_processed_children,
989
- mut processed_children,
990
- ..
991
- } ) = queue. pop( )
992
- {
993
- parent_node. tnr = node. tnr;
994
- processed_children. push( node) ;
995
-
996
- queue. push( ProcessingState :: ProcessingChildren {
997
- item: parent_node,
998
- non_processed_children,
999
- processed_children,
1000
- } )
1001
- } else {
1002
- debug_assert_eq!( queue. len( ) , 0 ) ;
1003
- return Ok ( node) ;
1004
- }
1005
- }
1006
- }
1007
- }
1008
-
1009
- unreachable!( ) ;
1010
- } ;
1011
- }
1012
-
1013
916
/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1014
917
///
1015
918
/// # Example
@@ -1063,6 +966,59 @@ pub trait DynTreeNode {
1063
966
) -> Result < Arc < Self > > ;
1064
967
}
1065
968
969
+ pub struct LegacyRewriter <
970
+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
971
+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
972
+ Node : TreeNode ,
973
+ > {
974
+ f_down_func : FD ,
975
+ f_up_func : FU ,
976
+ _node : PhantomData < Node > ,
977
+ }
978
+
979
+ impl <
980
+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
981
+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
982
+ Node : TreeNode ,
983
+ > LegacyRewriter < FD , FU , Node >
984
+ {
985
+ pub fn new ( f_down_func : FD , f_up_func : FU ) -> Self {
986
+ Self {
987
+ f_down_func,
988
+ f_up_func,
989
+ _node : PhantomData ,
990
+ }
991
+ }
992
+ }
993
+ impl <
994
+ FD : FnMut ( Node ) -> Result < Transformed < Node > > ,
995
+ FU : FnMut ( Node ) -> Result < Transformed < Node > > ,
996
+ Node : TreeNode ,
997
+ > TreeNodeRewriter for LegacyRewriter < FD , FU , Node >
998
+ {
999
+ type Node = Node ;
1000
+
1001
+ fn f_down ( & mut self , node : Self :: Node ) -> Result < Transformed < Self :: Node > > {
1002
+ ( self . f_down_func ) ( node)
1003
+ }
1004
+
1005
+ fn f_up ( & mut self , node : Self :: Node ) -> Result < Transformed < Self :: Node > > {
1006
+ ( self . f_up_func ) ( node)
1007
+ }
1008
+ }
1009
+
1010
+ macro_rules! update_rec_node {
1011
+ ( $NAME: ident, $CHILDREN: ident) => { {
1012
+ $NAME. transformed |= $CHILDREN. iter( ) . any( |item| item. transformed) ;
1013
+
1014
+ $NAME. data = $NAME
1015
+ . data
1016
+ . with_new_arc_children( $CHILDREN. into_iter( ) . map( |c| c. data) . collect( ) ) ?;
1017
+
1018
+ $NAME
1019
+ } } ;
1020
+ }
1021
+
1066
1022
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
1067
1023
/// (such as [`Arc<dyn PhysicalExpr>`]).
1068
1024
impl < T : DynTreeNode + ?Sized > TreeNode for Arc < T > {
@@ -1102,43 +1058,134 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1102
1058
FU : FnMut ( Self ) -> Result < Transformed < Self > > ,
1103
1059
> (
1104
1060
self ,
1105
- mut f_down : FD ,
1106
- mut f_up : FU ,
1061
+ f_down : FD ,
1062
+ f_up : FU ,
1107
1063
) -> Result < Transformed < Self > > {
1108
- rewrite_recursive ! ( self , node , f_up( node ) , f_down ( node ) ) ;
1064
+ self . rewrite ( & mut LegacyRewriter :: new ( f_down , f_up) )
1109
1065
}
1110
1066
1111
1067
fn transform_down < F : FnMut ( Self ) -> Result < Transformed < Self > > > (
1112
1068
self ,
1113
1069
f : F ,
1114
1070
) -> Result < Transformed < Self > > {
1115
- self . transform_down_up ( f, |node| Ok ( Transformed :: no ( node) ) )
1071
+ self . rewrite ( & mut LegacyRewriter :: new ( f, |node| {
1072
+ Ok ( Transformed :: no ( node) )
1073
+ } ) )
1116
1074
}
1117
1075
1118
1076
fn transform_up < F : FnMut ( Self ) -> Result < Transformed < Self > > > (
1119
1077
self ,
1120
1078
f : F ,
1121
1079
) -> Result < Transformed < Self > > {
1122
- self . transform_down_up ( |node| Ok ( Transformed :: no ( node) ) , f)
1080
+ self . rewrite ( & mut LegacyRewriter :: new (
1081
+ |node| Ok ( Transformed :: no ( node) ) ,
1082
+ f,
1083
+ ) )
1123
1084
}
1124
1085
fn rewrite < R : TreeNodeRewriter < Node = Self > > (
1125
1086
self ,
1126
1087
rewriter : & mut R ,
1127
1088
) -> Result < Transformed < Self > > {
1128
- rewrite_recursive ! ( self , node, rewriter. f_up( node) , rewriter. f_down( node) ) ;
1089
+ let mut stack = vec ! [ ProcessingState :: NotStarted ( self ) ] ;
1090
+
1091
+ while let Some ( item) = stack. pop ( ) {
1092
+ match item {
1093
+ ProcessingState :: NotStarted ( node) => {
1094
+ let node = rewriter. f_down ( node) ?;
1095
+
1096
+ stack. push ( match node. tnr {
1097
+ TreeNodeRecursion :: Continue => {
1098
+ ProcessingState :: ProcessingChildren {
1099
+ non_processed_children : node
1100
+ . data
1101
+ . arc_children ( )
1102
+ . into_iter ( )
1103
+ . cloned ( )
1104
+ . rev ( )
1105
+ . collect ( ) ,
1106
+ item : node,
1107
+ processed_children : vec ! [ ] ,
1108
+ }
1109
+ }
1110
+ TreeNodeRecursion :: Jump => ProcessingState :: ProcessedAllChildren (
1111
+ node. with_tnr ( TreeNodeRecursion :: Continue ) ,
1112
+ ) ,
1113
+ TreeNodeRecursion :: Stop => {
1114
+ ProcessingState :: ProcessedAllChildren ( node)
1115
+ }
1116
+ } )
1117
+ }
1118
+ ProcessingState :: ProcessingChildren {
1119
+ mut item,
1120
+ mut non_processed_children,
1121
+ mut processed_children,
1122
+ } => match item. tnr {
1123
+ TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
1124
+ if let Some ( non_processed_item) = non_processed_children. pop ( ) {
1125
+ stack. push ( ProcessingState :: ProcessingChildren {
1126
+ item,
1127
+ non_processed_children,
1128
+ processed_children,
1129
+ } ) ;
1130
+ stack. push ( ProcessingState :: NotStarted ( non_processed_item) ) ;
1131
+ } else {
1132
+ stack. push ( ProcessingState :: ProcessedAllChildren (
1133
+ update_rec_node ! ( item, processed_children) ,
1134
+ ) )
1135
+ }
1136
+ }
1137
+ TreeNodeRecursion :: Stop => {
1138
+ processed_children. extend (
1139
+ non_processed_children
1140
+ . into_iter ( )
1141
+ . rev ( )
1142
+ . map ( Transformed :: no) ,
1143
+ ) ;
1144
+ stack. push ( ProcessingState :: ProcessedAllChildren (
1145
+ update_rec_node ! ( item, processed_children) ,
1146
+ ) ) ;
1147
+ }
1148
+ } ,
1149
+ ProcessingState :: ProcessedAllChildren ( node) => {
1150
+ let node = node. transform_parent ( |n| rewriter. f_up ( n) ) ?;
1151
+
1152
+ if let Some ( ProcessingState :: ProcessingChildren {
1153
+ item : mut parent_node,
1154
+ non_processed_children,
1155
+ mut processed_children,
1156
+ ..
1157
+ } ) = stack. pop ( )
1158
+ {
1159
+ parent_node. tnr = node. tnr ;
1160
+ processed_children. push ( node) ;
1161
+
1162
+ stack. push ( ProcessingState :: ProcessingChildren {
1163
+ item : parent_node,
1164
+ non_processed_children,
1165
+ processed_children,
1166
+ } )
1167
+ } else {
1168
+ debug_assert_eq ! ( stack. len( ) , 0 ) ;
1169
+ return Ok ( node) ;
1170
+ }
1171
+ }
1172
+ }
1173
+ }
1174
+
1175
+ unreachable ! ( ) ;
1129
1176
}
1130
1177
1131
1178
fn visit < ' n , V : TreeNodeVisitor < ' n , Node = Self > > (
1132
1179
& ' n self ,
1133
1180
visitor : & mut V ,
1134
1181
) -> Result < TreeNodeRecursion > {
1135
- let mut queue = vec ! [ VisitingState :: NotStarted ( self ) ] ;
1182
+ let mut stack = vec ! [ VisitingState :: NotStarted ( self ) ] ;
1136
1183
1137
- while let Some ( item) = queue . pop ( ) {
1184
+ while let Some ( item) = stack . pop ( ) {
1138
1185
match item {
1139
1186
VisitingState :: NotStarted ( item) => {
1140
1187
let tnr = visitor. f_down ( item) ?;
1141
- queue . push ( match tnr {
1188
+ stack . push ( match tnr {
1142
1189
TreeNodeRecursion :: Continue => VisitingState :: VisitingChildren {
1143
1190
non_processed_children : item
1144
1191
. arc_children ( )
@@ -1165,14 +1212,14 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1165
1212
} => match tnr {
1166
1213
TreeNodeRecursion :: Continue | TreeNodeRecursion :: Jump => {
1167
1214
if let Some ( non_processed_item) = non_processed_children. pop ( ) {
1168
- queue . push ( VisitingState :: VisitingChildren {
1215
+ stack . push ( VisitingState :: VisitingChildren {
1169
1216
item,
1170
1217
non_processed_children,
1171
1218
tnr,
1172
1219
} ) ;
1173
- queue . push ( VisitingState :: NotStarted ( non_processed_item) ) ;
1220
+ stack . push ( VisitingState :: NotStarted ( non_processed_item) ) ;
1174
1221
} else {
1175
- queue . push ( VisitingState :: VisitedAllChildren { item, tnr } ) ;
1222
+ stack . push ( VisitingState :: VisitedAllChildren { item, tnr } ) ;
1176
1223
}
1177
1224
}
1178
1225
TreeNodeRecursion :: Stop => {
@@ -1186,15 +1233,15 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1186
1233
item,
1187
1234
non_processed_children,
1188
1235
..
1189
- } ) = queue . pop ( )
1236
+ } ) = stack . pop ( )
1190
1237
{
1191
- queue . push ( VisitingState :: VisitingChildren {
1238
+ stack . push ( VisitingState :: VisitingChildren {
1192
1239
item,
1193
1240
non_processed_children,
1194
1241
tnr,
1195
1242
} ) ;
1196
1243
} else {
1197
- debug_assert_eq ! ( queue . len( ) , 0 ) ;
1244
+ debug_assert_eq ! ( stack . len( ) , 0 ) ;
1198
1245
return Ok ( tnr) ;
1199
1246
}
1200
1247
}
@@ -1208,30 +1255,32 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
1208
1255
#[ derive( Debug ) ]
1209
1256
enum ProcessingState < T > {
1210
1257
NotStarted ( T ) ,
1211
- // f_down is called
1258
+ // ← at this point, f_down is called
1212
1259
ProcessingChildren {
1213
1260
item : Transformed < T > ,
1214
1261
non_processed_children : Vec < T > ,
1215
1262
processed_children : Vec < Transformed < T > > ,
1216
1263
} ,
1264
+ // ← at this point, all children are processed
1217
1265
ProcessedAllChildren ( Transformed < T > ) ,
1218
- // f_up is called
1266
+ // ← at this point, f_up is called
1219
1267
}
1220
1268
1221
1269
#[ derive( Debug ) ]
1222
1270
enum VisitingState < ' a , T > {
1223
1271
NotStarted ( & ' a T ) ,
1224
- // f_down is called
1272
+ // ← at this point, f_down is called
1225
1273
VisitingChildren {
1226
1274
item : & ' a T ,
1227
1275
non_processed_children : Vec < & ' a T > ,
1228
1276
tnr : TreeNodeRecursion ,
1229
1277
} ,
1278
+ // ← at this point, all children are visited
1230
1279
VisitedAllChildren {
1231
1280
item : & ' a T ,
1232
1281
tnr : TreeNodeRecursion ,
1233
1282
} ,
1234
- // f_up is called
1283
+ // ← at this point, f_up is called
1235
1284
}
1236
1285
1237
1286
/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
0 commit comments