Skip to content

Commit e475464

Browse files
committed
Improve readability
1 parent 4e1b81e commit e475464

File tree

1 file changed

+70
-68
lines changed

1 file changed

+70
-68
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 70 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ impl<T> Transformed<T> {
779779
}
780780
}
781781

782+
/// Replaces recursion state with the given one
782783
pub fn with_tnr(mut self, tnr: TreeNodeRecursion) -> Self {
783784
self.tnr = tnr;
784785
self
@@ -956,7 +957,8 @@ pub trait DynTreeNode {
956957
) -> Result<Arc<Self>>;
957958
}
958959

959-
pub struct LegacyRewriter<
960+
/// Adapter from the old function-based rewriter to the new Transformer one
961+
struct FuncRewriter<
960962
FD: FnMut(Node) -> Result<Transformed<Node>>,
961963
FU: FnMut(Node) -> Result<Transformed<Node>>,
962964
Node: TreeNode,
@@ -970,7 +972,7 @@ impl<
970972
FD: FnMut(Node) -> Result<Transformed<Node>>,
971973
FU: FnMut(Node) -> Result<Transformed<Node>>,
972974
Node: TreeNode,
973-
> LegacyRewriter<FD, FU, Node>
975+
> FuncRewriter<FD, FU, Node>
974976
{
975977
pub fn new(f_down_func: FD, f_up_func: FU) -> Self {
976978
Self {
@@ -984,7 +986,7 @@ impl<
984986
FD: FnMut(Node) -> Result<Transformed<Node>>,
985987
FU: FnMut(Node) -> Result<Transformed<Node>>,
986988
Node: TreeNode,
987-
> TreeNodeRewriter for LegacyRewriter<FD, FU, Node>
989+
> TreeNodeRewriter for FuncRewriter<FD, FU, Node>
988990
{
989991
type Node = Node;
990992

@@ -997,7 +999,8 @@ impl<
997999
}
9981000
}
9991001

1000-
macro_rules! update_rec_node {
1002+
/// Replaces node's children and recomputes the state
1003+
macro_rules! update_node_after_recursion {
10011004
($NAME:ident, $CHILDREN:ident) => {{
10021005
$NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed);
10031006

@@ -1011,6 +1014,7 @@ macro_rules! update_rec_node {
10111014

10121015
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
10131016
/// (such as [`Arc<dyn PhysicalExpr>`]).
1017+
/// Unlike [`TreeNode`], performs node traversal iteratively rather than recursively to avoid stack overflow
10141018
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
10151019
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
10161020
&'n self,
@@ -1051,41 +1055,36 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
10511055
f_down: FD,
10521056
f_up: FU,
10531057
) -> Result<Transformed<Self>> {
1054-
self.rewrite(&mut LegacyRewriter::new(f_down, f_up))
1058+
self.rewrite(&mut FuncRewriter::new(f_down, f_up))
10551059
}
10561060

10571061
fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
10581062
self,
10591063
f: F,
10601064
) -> Result<Transformed<Self>> {
1061-
self.rewrite(&mut LegacyRewriter::new(f, |node| {
1062-
Ok(Transformed::no(node))
1063-
}))
1065+
self.rewrite(&mut FuncRewriter::new(f, |node| Ok(Transformed::no(node))))
10641066
}
10651067

10661068
fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
10671069
self,
10681070
f: F,
10691071
) -> Result<Transformed<Self>> {
1070-
self.rewrite(&mut LegacyRewriter::new(
1071-
|node| Ok(Transformed::no(node)),
1072-
f,
1073-
))
1072+
self.rewrite(&mut FuncRewriter::new(|node| Ok(Transformed::no(node)), f))
10741073
}
10751074
fn rewrite<R: TreeNodeRewriter<Node = Self>>(
10761075
self,
10771076
rewriter: &mut R,
10781077
) -> Result<Transformed<Self>> {
1079-
let mut stack = vec![ProcessingState::NotStarted(self)];
1078+
let mut stack = vec![TransformingState::NotStarted(self)];
10801079

10811080
while let Some(item) = stack.pop() {
10821081
match item {
1083-
ProcessingState::NotStarted(node) => {
1082+
TransformingState::NotStarted(node) => {
10841083
let node = rewriter.f_down(node)?;
10851084

10861085
stack.push(match node.tnr {
10871086
TreeNodeRecursion::Continue => {
1088-
ProcessingState::ProcessingChildren {
1087+
TransformingState::ProcessingChildren {
10891088
non_processed_children: node
10901089
.data
10911090
.arc_children()
@@ -1097,59 +1096,68 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
10971096
processed_children: vec![],
10981097
}
10991098
}
1100-
TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren(
1101-
node.with_tnr(TreeNodeRecursion::Continue),
1102-
),
1099+
TreeNodeRecursion::Jump => {
1100+
TransformingState::ProcessedAllChildren(
1101+
// No need to process children, we can just this stage
1102+
node.with_tnr(TreeNodeRecursion::Continue),
1103+
)
1104+
}
11031105
TreeNodeRecursion::Stop => {
1104-
ProcessingState::ProcessedAllChildren(node)
1106+
TransformingState::ProcessedAllChildren(node)
11051107
}
11061108
})
11071109
}
1108-
ProcessingState::ProcessingChildren {
1110+
TransformingState::ProcessingChildren {
11091111
mut item,
11101112
mut non_processed_children,
11111113
mut processed_children,
11121114
} => match item.tnr {
11131115
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
11141116
if let Some(non_processed_item) = non_processed_children.pop() {
1115-
stack.push(ProcessingState::ProcessingChildren {
1116-
item,
1117-
non_processed_children,
1118-
processed_children,
1119-
});
1120-
stack.push(ProcessingState::NotStarted(non_processed_item));
1117+
stack.extend([
1118+
// This node still has children, so put it back in the stack
1119+
TransformingState::ProcessingChildren {
1120+
item,
1121+
non_processed_children,
1122+
processed_children,
1123+
},
1124+
// Also put the child which will be processed first
1125+
TransformingState::NotStarted(non_processed_item),
1126+
]);
11211127
} else {
1122-
stack.push(ProcessingState::ProcessedAllChildren(
1123-
update_rec_node!(item, processed_children),
1128+
stack.push(TransformingState::ProcessedAllChildren(
1129+
update_node_after_recursion!(item, processed_children),
11241130
))
11251131
}
11261132
}
11271133
TreeNodeRecursion::Stop => {
1134+
// At this point, we might have some children we haven't yet processed
11281135
processed_children.extend(
11291136
non_processed_children
11301137
.into_iter()
11311138
.rev()
11321139
.map(Transformed::no),
11331140
);
1134-
stack.push(ProcessingState::ProcessedAllChildren(
1135-
update_rec_node!(item, processed_children),
1141+
stack.push(TransformingState::ProcessedAllChildren(
1142+
update_node_after_recursion!(item, processed_children),
11361143
));
11371144
}
11381145
},
1139-
ProcessingState::ProcessedAllChildren(node) => {
1146+
TransformingState::ProcessedAllChildren(node) => {
11401147
let node = node.transform_parent(|n| rewriter.f_up(n))?;
11411148

1142-
if let Some(ProcessingState::ProcessingChildren {
1149+
if let Some(TransformingState::ProcessingChildren {
11431150
item: mut parent_node,
11441151
non_processed_children,
11451152
mut processed_children,
11461153
..
11471154
}) = stack.pop()
11481155
{
1156+
// We need use returned recursion state when processing the remaining children
11491157
parent_node.tnr = node.tnr;
11501158
processed_children.push(node);
11511159

1152-
stack.push(ProcessingState::ProcessingChildren {
1160+
stack.push(TransformingState::ProcessingChildren {
11531161
item: parent_node,
11541162
non_processed_children,
11551163
processed_children,
@@ -1189,10 +1197,7 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
11891197
item,
11901198
tnr: TreeNodeRecursion::Continue,
11911199
},
1192-
TreeNodeRecursion::Stop => VisitingState::VisitedAllChildren {
1193-
item,
1194-
tnr: TreeNodeRecursion::Stop,
1195-
},
1200+
TreeNodeRecursion::Stop => return Ok(tnr),
11961201
});
11971202
}
11981203
VisitingState::VisitingChildren {
@@ -1202,12 +1207,15 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
12021207
} => match tnr {
12031208
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
12041209
if let Some(non_processed_item) = non_processed_children.pop() {
1205-
stack.push(VisitingState::VisitingChildren {
1206-
item,
1207-
non_processed_children,
1208-
tnr,
1209-
});
1210-
stack.push(VisitingState::NotStarted(non_processed_item));
1210+
stack.extend([
1211+
// Returning the node on the stack because there are more children to process
1212+
VisitingState::VisitingChildren {
1213+
item,
1214+
non_processed_children,
1215+
tnr,
1216+
},
1217+
VisitingState::NotStarted(non_processed_item),
1218+
]);
12111219
} else {
12121220
stack.push(VisitingState::VisitedAllChildren { item, tnr });
12131221
}
@@ -1220,10 +1228,10 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
12201228
let tnr = tnr.visit_parent(|| visitor.f_up(item))?;
12211229

12221230
if let Some(VisitingState::VisitingChildren {
1223-
item,
1224-
non_processed_children,
1225-
..
1226-
}) = stack.pop()
1231+
item,
1232+
non_processed_children,
1233+
.. // we don't care about the parent recursion state, because it will be replaced with the current state anyway
1234+
}) = stack.pop()
12271235
{
12281236
stack.push(VisitingState::VisitingChildren {
12291237
item,
@@ -1242,35 +1250,32 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
12421250
}
12431251
}
12441252

1245-
#[derive(Debug)]
1246-
enum ProcessingState<T> {
1253+
/// Node iterative transformation state. Each node on the stack is visited several times, state determines the operation that is going to run next
1254+
enum TransformingState<T> {
1255+
/// Node was just created. When executed, f_down will be called.
12471256
NotStarted(T),
1248-
// ← at this point, f_down is called
1257+
/// DFS over node's children
12491258
ProcessingChildren {
12501259
item: Transformed<T>,
12511260
non_processed_children: Vec<T>,
12521261
processed_children: Vec<Transformed<T>>,
12531262
},
1254-
// ← at this point, all children are processed
1263+
/// All children are processed (or jumped through). When executed, f_up may be called
12551264
ProcessedAllChildren(Transformed<T>),
1256-
// ← at this point, f_up is called
12571265
}
12581266

1259-
#[derive(Debug)]
1267+
/// Node iterative visit state. Each node on the stack is visited several times, state determines the operation that is going to run next
12601268
enum VisitingState<'a, T> {
1269+
/// Node was just created. When executed, f_down will be called.
12611270
NotStarted(&'a T),
1262-
// ← at this point, f_down is called
1271+
/// DFS over node's children. During processing, reference to children are removed from the inner stack
12631272
VisitingChildren {
12641273
item: &'a T,
12651274
non_processed_children: Vec<&'a T>,
12661275
tnr: TreeNodeRecursion,
12671276
},
1268-
// ← at this point, all children are visited
1269-
VisitedAllChildren {
1270-
item: &'a T,
1271-
tnr: TreeNodeRecursion,
1272-
},
1273-
// ← at this point, f_up is called
1277+
/// All children are processed (or jumped through). When executed, f_up may be called
1278+
VisitedAllChildren { item: &'a T, tnr: TreeNodeRecursion },
12741279
}
12751280

12761281
/// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for
@@ -1409,11 +1414,8 @@ pub(crate) mod tests {
14091414
.collect()
14101415
}
14111416

1412-
/// Implement this train in order for your struct to be supported in parametrisation
1413-
pub(crate) trait TestTree<T: PartialEq>
1414-
where
1415-
T: Sized,
1416-
{
1417+
/// [`node_tests`] uses methods when generating tests
1418+
pub(crate) trait TestTree<T: PartialEq> {
14171419
fn new_with_children(children: Vec<Self>, data: T) -> Self
14181420
where
14191421
Self: Sized;
@@ -1427,7 +1429,7 @@ pub(crate) mod tests {
14271429
fn with_children_from(data: T, other: Self) -> Self;
14281430
}
14291431

1430-
macro_rules! gen_tests {
1432+
macro_rules! node_tests {
14311433
($TYPE:ident) => {
14321434
fn visit_continue<T: PartialEq>(_: &$TYPE<T>) -> Result<TreeNodeRecursion> {
14331435
Ok(TreeNodeRecursion::Continue)
@@ -2546,7 +2548,7 @@ pub(crate) mod tests {
25462548
}
25472549
}
25482550

2549-
gen_tests!(TestTreeNode);
2551+
node_tests!(TestTreeNode);
25502552
}
25512553

25522554
pub mod test_dyn_tree_node {
@@ -2607,7 +2609,7 @@ pub(crate) mod tests {
26072609

26082610
type ArcTestNode<T> = Arc<DynTestNode<T>>;
26092611

2610-
gen_tests!(ArcTestNode);
2612+
node_tests!(ArcTestNode);
26112613

26122614
#[test]
26132615
fn test_large_tree() {

0 commit comments

Comments
 (0)