|
18 | 18 | //! [`TreeNode`] for visiting and rewriting expression and plan trees
|
19 | 19 |
|
20 | 20 | use crate::Result;
|
| 21 | +use std::marker::PhantomData; |
21 | 22 | use std::sync::Arc;
|
22 | 23 |
|
23 | 24 | /// These macros are used to determine continuation during transforming traversals.
|
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect {
|
912 | 913 | }}
|
913 | 914 | }
|
914 | 915 |
|
915 |
| -macro_rules! rewrite_recursive { |
916 |
| - ($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => { |
917 |
| - let mut queue = vec![ProcessingState::NotStarted($START)]; |
918 |
| - |
919 |
| - while let Some(item) = queue.pop() { |
920 |
| - match item { |
921 |
| - ProcessingState::NotStarted($NAME) => { |
922 |
| - let node = $TRANSFORM_DOWN?; |
923 |
| - |
924 |
| - queue.push(match node.tnr { |
925 |
| - TreeNodeRecursion::Continue => { |
926 |
| - ProcessingState::ProcessingChildren { |
927 |
| - non_processed_children: node |
928 |
| - .data |
929 |
| - .arc_children() |
930 |
| - .into_iter() |
931 |
| - .cloned() |
932 |
| - .rev() |
933 |
| - .collect(), |
934 |
| - item: node, |
935 |
| - processed_children: vec![], |
936 |
| - } |
937 |
| - } |
938 |
| - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( |
939 |
| - node.with_tnr(TreeNodeRecursion::Continue), |
940 |
| - ), |
941 |
| - TreeNodeRecursion::Stop => { |
942 |
| - ProcessingState::ProcessedAllChildren(node) |
943 |
| - } |
944 |
| - }) |
945 |
| - } |
946 |
| - ProcessingState::ProcessingChildren { |
947 |
| - mut item, |
948 |
| - mut non_processed_children, |
949 |
| - mut processed_children, |
950 |
| - } => match item.tnr { |
951 |
| - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { |
952 |
| - if let Some(non_processed_item) = non_processed_children.pop() { |
953 |
| - queue.push(ProcessingState::ProcessingChildren { |
954 |
| - item, |
955 |
| - non_processed_children, |
956 |
| - processed_children, |
957 |
| - }); |
958 |
| - queue.push(ProcessingState::NotStarted(non_processed_item)); |
959 |
| - } else { |
960 |
| - item.transformed |= |
961 |
| - processed_children.iter().any(|item| item.transformed); |
962 |
| - item.data = item.data.with_new_arc_children( |
963 |
| - processed_children.into_iter().map(|c| c.data).collect(), |
964 |
| - )?; |
965 |
| - queue.push(ProcessingState::ProcessedAllChildren(item)) |
966 |
| - } |
967 |
| - } |
968 |
| - TreeNodeRecursion::Stop => { |
969 |
| - processed_children.extend( |
970 |
| - non_processed_children |
971 |
| - .into_iter() |
972 |
| - .rev() |
973 |
| - .map(Transformed::no), |
974 |
| - ); |
975 |
| - item.transformed |= |
976 |
| - processed_children.iter().any(|item| item.transformed); |
977 |
| - item.data = item.data.with_new_arc_children( |
978 |
| - processed_children.into_iter().map(|c| c.data).collect(), |
979 |
| - )?; |
980 |
| - queue.push(ProcessingState::ProcessedAllChildren(item)); |
981 |
| - } |
982 |
| - }, |
983 |
| - ProcessingState::ProcessedAllChildren(node) => { |
984 |
| - let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?; |
985 |
| - |
986 |
| - if let Some(ProcessingState::ProcessingChildren { |
987 |
| - item: mut parent_node, |
988 |
| - non_processed_children, |
989 |
| - mut processed_children, |
990 |
| - .. |
991 |
| - }) = queue.pop() |
992 |
| - { |
993 |
| - parent_node.tnr = node.tnr; |
994 |
| - processed_children.push(node); |
995 |
| - |
996 |
| - queue.push(ProcessingState::ProcessingChildren { |
997 |
| - item: parent_node, |
998 |
| - non_processed_children, |
999 |
| - processed_children, |
1000 |
| - }) |
1001 |
| - } else { |
1002 |
| - debug_assert_eq!(queue.len(), 0); |
1003 |
| - return Ok(node); |
1004 |
| - } |
1005 |
| - } |
1006 |
| - } |
1007 |
| - } |
1008 |
| - |
1009 |
| - unreachable!(); |
1010 |
| - }; |
1011 |
| -} |
1012 |
| - |
1013 | 916 | /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
|
1014 | 917 | ///
|
1015 | 918 | /// # Example
|
@@ -1063,6 +966,59 @@ pub trait DynTreeNode {
|
1063 | 966 | ) -> Result<Arc<Self>>;
|
1064 | 967 | }
|
1065 | 968 |
|
| 969 | +pub struct LegacyRewriter< |
| 970 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 971 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 972 | + Node: TreeNode, |
| 973 | +> { |
| 974 | + f_down_func: FD, |
| 975 | + f_up_func: FU, |
| 976 | + _node: PhantomData<Node>, |
| 977 | +} |
| 978 | + |
| 979 | +impl< |
| 980 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 981 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 982 | + Node: TreeNode, |
| 983 | + > LegacyRewriter<FD, FU, Node> |
| 984 | +{ |
| 985 | + pub fn new(f_down_func: FD, f_up_func: FU) -> Self { |
| 986 | + Self { |
| 987 | + f_down_func, |
| 988 | + f_up_func, |
| 989 | + _node: PhantomData, |
| 990 | + } |
| 991 | + } |
| 992 | +} |
| 993 | +impl< |
| 994 | + FD: FnMut(Node) -> Result<Transformed<Node>>, |
| 995 | + FU: FnMut(Node) -> Result<Transformed<Node>>, |
| 996 | + Node: TreeNode, |
| 997 | + > TreeNodeRewriter for LegacyRewriter<FD, FU, Node> |
| 998 | +{ |
| 999 | + type Node = Node; |
| 1000 | + |
| 1001 | + fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { |
| 1002 | + (self.f_down_func)(node) |
| 1003 | + } |
| 1004 | + |
| 1005 | + fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> { |
| 1006 | + (self.f_up_func)(node) |
| 1007 | + } |
| 1008 | +} |
| 1009 | + |
| 1010 | +macro_rules! update_rec_node { |
| 1011 | + ($NAME:ident, $CHILDREN:ident) => {{ |
| 1012 | + $NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed); |
| 1013 | + |
| 1014 | + $NAME.data = $NAME |
| 1015 | + .data |
| 1016 | + .with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?; |
| 1017 | + |
| 1018 | + $NAME |
| 1019 | + }}; |
| 1020 | +} |
| 1021 | + |
1066 | 1022 | /// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
|
1067 | 1023 | /// (such as [`Arc<dyn PhysicalExpr>`]).
|
1068 | 1024 | impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
|
@@ -1102,30 +1058,121 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
|
1102 | 1058 | FU: FnMut(Self) -> Result<Transformed<Self>>,
|
1103 | 1059 | >(
|
1104 | 1060 | self,
|
1105 |
| - mut f_down: FD, |
1106 |
| - mut f_up: FU, |
| 1061 | + f_down: FD, |
| 1062 | + f_up: FU, |
1107 | 1063 | ) -> Result<Transformed<Self>> {
|
1108 |
| - rewrite_recursive!(self, node, f_up(node), f_down(node)); |
| 1064 | + self.rewrite(&mut LegacyRewriter::new(f_down, f_up)) |
1109 | 1065 | }
|
1110 | 1066 |
|
1111 | 1067 | fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
|
1112 | 1068 | self,
|
1113 | 1069 | f: F,
|
1114 | 1070 | ) -> Result<Transformed<Self>> {
|
1115 |
| - self.transform_down_up(f, |node| Ok(Transformed::no(node))) |
| 1071 | + self.rewrite(&mut LegacyRewriter::new(f, |node| { |
| 1072 | + Ok(Transformed::no(node)) |
| 1073 | + })) |
1116 | 1074 | }
|
1117 | 1075 |
|
1118 | 1076 | fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
|
1119 | 1077 | self,
|
1120 | 1078 | f: F,
|
1121 | 1079 | ) -> Result<Transformed<Self>> {
|
1122 |
| - self.transform_down_up(|node| Ok(Transformed::no(node)), f) |
| 1080 | + self.rewrite(&mut LegacyRewriter::new( |
| 1081 | + |node| Ok(Transformed::no(node)), |
| 1082 | + f, |
| 1083 | + )) |
1123 | 1084 | }
|
1124 | 1085 | fn rewrite<R: TreeNodeRewriter<Node = Self>>(
|
1125 | 1086 | self,
|
1126 | 1087 | rewriter: &mut R,
|
1127 | 1088 | ) -> Result<Transformed<Self>> {
|
1128 |
| - rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node)); |
| 1089 | + let mut queue = vec![ProcessingState::NotStarted(self)]; |
| 1090 | + |
| 1091 | + while let Some(item) = queue.pop() { |
| 1092 | + match item { |
| 1093 | + ProcessingState::NotStarted(node) => { |
| 1094 | + let node = rewriter.f_down(node)?; |
| 1095 | + |
| 1096 | + queue.push(match node.tnr { |
| 1097 | + TreeNodeRecursion::Continue => { |
| 1098 | + ProcessingState::ProcessingChildren { |
| 1099 | + non_processed_children: node |
| 1100 | + .data |
| 1101 | + .arc_children() |
| 1102 | + .into_iter() |
| 1103 | + .cloned() |
| 1104 | + .rev() |
| 1105 | + .collect(), |
| 1106 | + item: node, |
| 1107 | + processed_children: vec![], |
| 1108 | + } |
| 1109 | + } |
| 1110 | + TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( |
| 1111 | + node.with_tnr(TreeNodeRecursion::Continue), |
| 1112 | + ), |
| 1113 | + TreeNodeRecursion::Stop => { |
| 1114 | + ProcessingState::ProcessedAllChildren(node) |
| 1115 | + } |
| 1116 | + }) |
| 1117 | + } |
| 1118 | + ProcessingState::ProcessingChildren { |
| 1119 | + mut item, |
| 1120 | + mut non_processed_children, |
| 1121 | + mut processed_children, |
| 1122 | + } => match item.tnr { |
| 1123 | + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { |
| 1124 | + if let Some(non_processed_item) = non_processed_children.pop() { |
| 1125 | + queue.push(ProcessingState::ProcessingChildren { |
| 1126 | + item, |
| 1127 | + non_processed_children, |
| 1128 | + processed_children, |
| 1129 | + }); |
| 1130 | + queue.push(ProcessingState::NotStarted(non_processed_item)); |
| 1131 | + } else { |
| 1132 | + queue.push(ProcessingState::ProcessedAllChildren( |
| 1133 | + update_rec_node!(item, processed_children), |
| 1134 | + )) |
| 1135 | + } |
| 1136 | + } |
| 1137 | + TreeNodeRecursion::Stop => { |
| 1138 | + processed_children.extend( |
| 1139 | + non_processed_children |
| 1140 | + .into_iter() |
| 1141 | + .rev() |
| 1142 | + .map(Transformed::no), |
| 1143 | + ); |
| 1144 | + queue.push(ProcessingState::ProcessedAllChildren( |
| 1145 | + update_rec_node!(item, processed_children), |
| 1146 | + )); |
| 1147 | + } |
| 1148 | + }, |
| 1149 | + ProcessingState::ProcessedAllChildren(node) => { |
| 1150 | + let node = node.transform_parent(|n| rewriter.f_up(n))?; |
| 1151 | + |
| 1152 | + if let Some(ProcessingState::ProcessingChildren { |
| 1153 | + item: mut parent_node, |
| 1154 | + non_processed_children, |
| 1155 | + mut processed_children, |
| 1156 | + .. |
| 1157 | + }) = queue.pop() |
| 1158 | + { |
| 1159 | + parent_node.tnr = node.tnr; |
| 1160 | + processed_children.push(node); |
| 1161 | + |
| 1162 | + queue.push(ProcessingState::ProcessingChildren { |
| 1163 | + item: parent_node, |
| 1164 | + non_processed_children, |
| 1165 | + processed_children, |
| 1166 | + }) |
| 1167 | + } else { |
| 1168 | + debug_assert_eq!(queue.len(), 0); |
| 1169 | + return Ok(node); |
| 1170 | + } |
| 1171 | + } |
| 1172 | + } |
| 1173 | + } |
| 1174 | + |
| 1175 | + unreachable!(); |
1129 | 1176 | }
|
1130 | 1177 |
|
1131 | 1178 | fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
|
|
0 commit comments