Skip to content

Commit 83d194e

Browse files
committed
Remove macros in favour of LegacyRewriter
1 parent 6a4dd2c commit 83d194e

File tree

1 file changed

+166
-117
lines changed

1 file changed

+166
-117
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 166 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! [`TreeNode`] for visiting and rewriting expression and plan trees
1919
2020
use crate::Result;
21+
use std::marker::PhantomData;
2122
use std::sync::Arc;
2223

2324
/// These macros are used to determine continuation during transforming traversals.
@@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect {
912913
}}
913914
}
914915

915-
macro_rules! rewrite_recursive {
916-
($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => {
917-
let mut queue = vec![ProcessingState::NotStarted($START)];
918-
919-
while let Some(item) = queue.pop() {
920-
match item {
921-
ProcessingState::NotStarted($NAME) => {
922-
let node = $TRANSFORM_DOWN?;
923-
924-
queue.push(match node.tnr {
925-
TreeNodeRecursion::Continue => {
926-
ProcessingState::ProcessingChildren {
927-
non_processed_children: node
928-
.data
929-
.arc_children()
930-
.into_iter()
931-
.cloned()
932-
.rev()
933-
.collect(),
934-
item: node,
935-
processed_children: vec![],
936-
}
937-
}
938-
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
939-
node.with_tnr(TreeNodeRecursion::Continue),
940-
),
941-
TreeNodeRecursion::Stop => {
942-
ProcessingState::ProcessedAllChildren(node)
943-
}
944-
})
945-
}
946-
ProcessingState::ProcessingChildren {
947-
mut item,
948-
mut non_processed_children,
949-
mut processed_children,
950-
} => match item.tnr {
951-
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
952-
if let Some(non_processed_item) = non_processed_children.pop() {
953-
queue.push(ProcessingState::ProcessingChildren {
954-
item,
955-
non_processed_children,
956-
processed_children,
957-
});
958-
queue.push(ProcessingState::NotStarted(non_processed_item));
959-
} else {
960-
item.transformed |=
961-
processed_children.iter().any(|item| item.transformed);
962-
item.data = item.data.with_new_arc_children(
963-
processed_children.into_iter().map(|c| c.data).collect(),
964-
)?;
965-
queue.push(ProcessingState::ProcessedAllChildren(item))
966-
}
967-
}
968-
TreeNodeRecursion::Stop => {
969-
processed_children.extend(
970-
non_processed_children
971-
.into_iter()
972-
.rev()
973-
.map(Transformed::no),
974-
);
975-
item.transformed |=
976-
processed_children.iter().any(|item| item.transformed);
977-
item.data = item.data.with_new_arc_children(
978-
processed_children.into_iter().map(|c| c.data).collect(),
979-
)?;
980-
queue.push(ProcessingState::ProcessedAllChildren(item));
981-
}
982-
},
983-
ProcessingState::ProcessedAllChildren(node) => {
984-
let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?;
985-
986-
if let Some(ProcessingState::ProcessingChildren {
987-
item: mut parent_node,
988-
non_processed_children,
989-
mut processed_children,
990-
..
991-
}) = queue.pop()
992-
{
993-
parent_node.tnr = node.tnr;
994-
processed_children.push(node);
995-
996-
queue.push(ProcessingState::ProcessingChildren {
997-
item: parent_node,
998-
non_processed_children,
999-
processed_children,
1000-
})
1001-
} else {
1002-
debug_assert_eq!(queue.len(), 0);
1003-
return Ok(node);
1004-
}
1005-
}
1006-
}
1007-
}
1008-
1009-
unreachable!();
1010-
};
1011-
}
1012-
1013916
/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
1014917
///
1015918
/// # Example
@@ -1063,6 +966,59 @@ pub trait DynTreeNode {
1063966
) -> Result<Arc<Self>>;
1064967
}
1065968

969+
pub struct LegacyRewriter<
970+
FD: FnMut(Node) -> Result<Transformed<Node>>,
971+
FU: FnMut(Node) -> Result<Transformed<Node>>,
972+
Node: TreeNode,
973+
> {
974+
f_down_func: FD,
975+
f_up_func: FU,
976+
_node: PhantomData<Node>,
977+
}
978+
979+
impl<
980+
FD: FnMut(Node) -> Result<Transformed<Node>>,
981+
FU: FnMut(Node) -> Result<Transformed<Node>>,
982+
Node: TreeNode,
983+
> LegacyRewriter<FD, FU, Node>
984+
{
985+
pub fn new(f_down_func: FD, f_up_func: FU) -> Self {
986+
Self {
987+
f_down_func,
988+
f_up_func,
989+
_node: PhantomData,
990+
}
991+
}
992+
}
993+
impl<
994+
FD: FnMut(Node) -> Result<Transformed<Node>>,
995+
FU: FnMut(Node) -> Result<Transformed<Node>>,
996+
Node: TreeNode,
997+
> TreeNodeRewriter for LegacyRewriter<FD, FU, Node>
998+
{
999+
type Node = Node;
1000+
1001+
fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1002+
(self.f_down_func)(node)
1003+
}
1004+
1005+
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
1006+
(self.f_up_func)(node)
1007+
}
1008+
}
1009+
1010+
macro_rules! update_rec_node {
1011+
($NAME:ident, $CHILDREN:ident) => {{
1012+
$NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed);
1013+
1014+
$NAME.data = $NAME
1015+
.data
1016+
.with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?;
1017+
1018+
$NAME
1019+
}};
1020+
}
1021+
10661022
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
10671023
/// (such as [`Arc<dyn PhysicalExpr>`]).
10681024
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
@@ -1102,43 +1058,134 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11021058
FU: FnMut(Self) -> Result<Transformed<Self>>,
11031059
>(
11041060
self,
1105-
mut f_down: FD,
1106-
mut f_up: FU,
1061+
f_down: FD,
1062+
f_up: FU,
11071063
) -> Result<Transformed<Self>> {
1108-
rewrite_recursive!(self, node, f_up(node), f_down(node));
1064+
self.rewrite(&mut LegacyRewriter::new(f_down, f_up))
11091065
}
11101066

11111067
fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
11121068
self,
11131069
f: F,
11141070
) -> Result<Transformed<Self>> {
1115-
self.transform_down_up(f, |node| Ok(Transformed::no(node)))
1071+
self.rewrite(&mut LegacyRewriter::new(f, |node| {
1072+
Ok(Transformed::no(node))
1073+
}))
11161074
}
11171075

11181076
fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
11191077
self,
11201078
f: F,
11211079
) -> Result<Transformed<Self>> {
1122-
self.transform_down_up(|node| Ok(Transformed::no(node)), f)
1080+
self.rewrite(&mut LegacyRewriter::new(
1081+
|node| Ok(Transformed::no(node)),
1082+
f,
1083+
))
11231084
}
11241085
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
11251086
self,
11261087
rewriter: &mut R,
11271088
) -> Result<Transformed<Self>> {
1128-
rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node));
1089+
let mut stack = vec![ProcessingState::NotStarted(self)];
1090+
1091+
while let Some(item) = stack.pop() {
1092+
match item {
1093+
ProcessingState::NotStarted(node) => {
1094+
let node = rewriter.f_down(node)?;
1095+
1096+
stack.push(match node.tnr {
1097+
TreeNodeRecursion::Continue => {
1098+
ProcessingState::ProcessingChildren {
1099+
non_processed_children: node
1100+
.data
1101+
.arc_children()
1102+
.into_iter()
1103+
.cloned()
1104+
.rev()
1105+
.collect(),
1106+
item: node,
1107+
processed_children: vec![],
1108+
}
1109+
}
1110+
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
1111+
node.with_tnr(TreeNodeRecursion::Continue),
1112+
),
1113+
TreeNodeRecursion::Stop => {
1114+
ProcessingState::ProcessedAllChildren(node)
1115+
}
1116+
})
1117+
}
1118+
ProcessingState::ProcessingChildren {
1119+
mut item,
1120+
mut non_processed_children,
1121+
mut processed_children,
1122+
} => match item.tnr {
1123+
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
1124+
if let Some(non_processed_item) = non_processed_children.pop() {
1125+
stack.push(ProcessingState::ProcessingChildren {
1126+
item,
1127+
non_processed_children,
1128+
processed_children,
1129+
});
1130+
stack.push(ProcessingState::NotStarted(non_processed_item));
1131+
} else {
1132+
stack.push(ProcessingState::ProcessedAllChildren(
1133+
update_rec_node!(item, processed_children),
1134+
))
1135+
}
1136+
}
1137+
TreeNodeRecursion::Stop => {
1138+
processed_children.extend(
1139+
non_processed_children
1140+
.into_iter()
1141+
.rev()
1142+
.map(Transformed::no),
1143+
);
1144+
stack.push(ProcessingState::ProcessedAllChildren(
1145+
update_rec_node!(item, processed_children),
1146+
));
1147+
}
1148+
},
1149+
ProcessingState::ProcessedAllChildren(node) => {
1150+
let node = node.transform_parent(|n| rewriter.f_up(n))?;
1151+
1152+
if let Some(ProcessingState::ProcessingChildren {
1153+
item: mut parent_node,
1154+
non_processed_children,
1155+
mut processed_children,
1156+
..
1157+
}) = stack.pop()
1158+
{
1159+
parent_node.tnr = node.tnr;
1160+
processed_children.push(node);
1161+
1162+
stack.push(ProcessingState::ProcessingChildren {
1163+
item: parent_node,
1164+
non_processed_children,
1165+
processed_children,
1166+
})
1167+
} else {
1168+
debug_assert_eq!(stack.len(), 0);
1169+
return Ok(node);
1170+
}
1171+
}
1172+
}
1173+
}
1174+
1175+
unreachable!();
11291176
}
11301177

11311178
fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
11321179
&'n self,
11331180
visitor: &mut V,
11341181
) -> Result<TreeNodeRecursion> {
1135-
let mut queue = vec![VisitingState::NotStarted(self)];
1182+
let mut stack = vec![VisitingState::NotStarted(self)];
11361183

1137-
while let Some(item) = queue.pop() {
1184+
while let Some(item) = stack.pop() {
11381185
match item {
11391186
VisitingState::NotStarted(item) => {
11401187
let tnr = visitor.f_down(item)?;
1141-
queue.push(match tnr {
1188+
stack.push(match tnr {
11421189
TreeNodeRecursion::Continue => VisitingState::VisitingChildren {
11431190
non_processed_children: item
11441191
.arc_children()
@@ -1165,14 +1212,14 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11651212
} => match tnr {
11661213
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
11671214
if let Some(non_processed_item) = non_processed_children.pop() {
1168-
queue.push(VisitingState::VisitingChildren {
1215+
stack.push(VisitingState::VisitingChildren {
11691216
item,
11701217
non_processed_children,
11711218
tnr,
11721219
});
1173-
queue.push(VisitingState::NotStarted(non_processed_item));
1220+
stack.push(VisitingState::NotStarted(non_processed_item));
11741221
} else {
1175-
queue.push(VisitingState::VisitedAllChildren { item, tnr });
1222+
stack.push(VisitingState::VisitedAllChildren { item, tnr });
11761223
}
11771224
}
11781225
TreeNodeRecursion::Stop => {
@@ -1186,15 +1233,15 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11861233
item,
11871234
non_processed_children,
11881235
..
1189-
}) = queue.pop()
1236+
}) = stack.pop()
11901237
{
1191-
queue.push(VisitingState::VisitingChildren {
1238+
stack.push(VisitingState::VisitingChildren {
11921239
item,
11931240
non_processed_children,
11941241
tnr,
11951242
});
11961243
} else {
1197-
debug_assert_eq!(queue.len(), 0);
1244+
debug_assert_eq!(stack.len(), 0);
11981245
return Ok(tnr);
11991246
}
12001247
}
@@ -1208,30 +1255,32 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
12081255
#[derive(Debug)]
12091256
enum ProcessingState<T> {
12101257
NotStarted(T),
1211-
// f_down is called
1258+
// ← at this point, f_down is called
12121259
ProcessingChildren {
12131260
item: Transformed<T>,
12141261
non_processed_children: Vec<T>,
12151262
processed_children: Vec<Transformed<T>>,
12161263
},
1264+
// ← at this point, all children are processed
12171265
ProcessedAllChildren(Transformed<T>),
1218-
// f_up is called
1266+
// ← at this point, f_up is called
12191267
}
12201268

12211269
#[derive(Debug)]
12221270
enum VisitingState<'a, T> {
12231271
NotStarted(&'a T),
1224-
// f_down is called
1272+
// ← at this point, f_down is called
12251273
VisitingChildren {
12261274
item: &'a T,
12271275
non_processed_children: Vec<&'a T>,
12281276
tnr: TreeNodeRecursion,
12291277
},
1278+
// ← at this point, all children are visited
12301279
VisitedAllChildren {
12311280
item: &'a T,
12321281
tnr: TreeNodeRecursion,
12331282
},
1234-
// f_up is called
1283+
// ← at this point, f_up is called
12351284
}
12361285

12371286
/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for

0 commit comments

Comments
 (0)