Skip to content

Commit aaeffa6

Browse files
committed
Include in DeadCode/DeadFunc elimination
1 parent 5b90887 commit aaeffa6

File tree

4 files changed

+90
-50
lines changed

4 files changed

+90
-50
lines changed

hugr-passes/src/composable.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,16 @@ mod test {
246246
.define_function("id1", Signature::new_endo(usize_t()))
247247
.unwrap();
248248
let inps = id1.input_wires();
249-
let id1 = id1.finish_with_outputs(inps).unwrap();
249+
id1.finish_with_outputs(inps).unwrap();
250+
250251
let id2 = mb
251-
.define_function("id2", Signature::new_endo(usize_t()))
252+
.define_function_link_name("id2", Signature::new_endo(usize_t()), None)
252253
.unwrap();
253254
let inps = id2.input_wires();
254255
let id2 = id2.finish_with_outputs(inps).unwrap();
255256
let hugr = mb.finish_hugr().unwrap();
256257

257-
let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]);
258+
let dce = DeadCodeElimPass::default();
258259
let cfold =
259260
ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]);
260261

hugr-passes/src/dead_code.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ pub struct DeadCodeElimPass<H: HugrView> {
2020
/// Callback identifying nodes that must be preserved even if their
2121
/// results are not used. Defaults to [PreserveNode::default_for].
2222
preserve_callback: Arc<PreserveCallback<H>>,
23+
include_exports: bool,
2324
}
2425

2526
impl<H: HugrView + 'static> Default for DeadCodeElimPass<H> {
2627
fn default() -> Self {
2728
Self {
2829
entry_points: Default::default(),
2930
preserve_callback: Arc::new(PreserveNode::default_for),
31+
include_exports: true,
3032
}
3133
}
3234
}
@@ -94,18 +96,33 @@ impl<H: HugrView> DeadCodeElimPass<H> {
9496
/// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
9597
/// used to evaluate these nodes.
9698
/// [`HugrView::entrypoint`] is assumed to be an entry point;
97-
/// for Module roots the client will want to mark some of the FuncDefn children
98-
/// as entry points too.
99+
/// if the entrypoint is a Module, then any public
100+
/// [FuncDefn](OpType::FuncDefn)s and [Const](OpType::Const)s are also considered entry points
101+
/// by default, but these can be removed by [Self::include_module_exports]`(false)`.
99102
pub fn with_entry_points(mut self, entry_points: impl IntoIterator<Item = H::Node>) -> Self {
100103
self.entry_points.extend(entry_points);
101104
self
102105
}
103106

107+
/// Sets whether, for Module-rooted Hugrs, the exported [FuncDefn](OpType::FuncDefn)s
108+
/// and [Const](OpType::Const)s are included as entry points (they are by default)
109+
pub fn include_module_exports(mut self, include: bool) -> Self {
110+
self.include_exports = include;
111+
self
112+
}
113+
104114
fn find_needed_nodes(&self, h: &H) -> HashSet<H::Node> {
105115
let mut must_preserve = HashMap::new();
106116
let mut needed = HashSet::new();
107117
let mut q = VecDeque::from_iter(self.entry_points.iter().cloned());
108118
q.push_front(h.entrypoint());
119+
if self.include_exports && h.entrypoint() == h.module_root() {
120+
q.extend(h.children(h.module_root()).filter(|ch| {
121+
h.get_optype(*ch)
122+
.as_func_defn()
123+
.is_some_and(|fd| fd.link_name.is_some())
124+
}))
125+
}
109126
while let Some(n) = q.pop_front() {
110127
if !needed.insert(n) {
111128
continue;
@@ -120,8 +137,8 @@ impl<H: HugrView> DeadCodeElimPass<H> {
120137
| OpType::AliasDecl(_) // and all Aliases (we do not track their uses in types)
121138
| OpType::AliasDefn(_)
122139
| OpType::Input(_) // Also Dataflow input/output, these are necessary for legality
123-
| OpType::Output(_) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges
124-
// (from Call/LoadConst/LoadFunction):
140+
| OpType::Output(_) // Do not include FuncDecl / Const unless reachable by static edges
141+
// (from Call/LoadConst/LoadFunction)
125142
)
126143
{
127144
q.push_back(ch);

hugr-passes/src/dead_funcs.rs

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,20 @@ fn reachable_funcs<'a, H: HugrView>(
6161
}))
6262
}
6363

64-
#[derive(Debug, Clone, Default)]
64+
#[derive(Debug, Clone)]
6565
/// A configuration for the Dead Function Removal pass.
6666
pub struct RemoveDeadFuncsPass {
6767
entry_points: Vec<Node>,
68+
include_exports: bool,
69+
}
70+
71+
impl Default for RemoveDeadFuncsPass {
72+
fn default() -> Self {
73+
Self {
74+
entry_points: Default::default(),
75+
include_exports: true,
76+
}
77+
}
6878
}
6979

7080
impl RemoveDeadFuncsPass {
@@ -80,16 +90,34 @@ impl RemoveDeadFuncsPass {
8090
self.entry_points.extend(entry_points);
8191
self
8292
}
93+
94+
/// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children of a
95+
/// [Module](hugr_core::ops::Module) are included as entry points (yes by default)
96+
pub fn include_module_exports(mut self, include: bool) -> Self {
97+
self.include_exports = include;
98+
self
99+
}
83100
}
84101

85102
impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
86103
type Error = RemoveDeadFuncsError;
87104
type Result = ();
88105
fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
106+
let exports = if hugr.entrypoint() == hugr.module_root() && self.include_exports {
107+
hugr.children(hugr.module_root())
108+
.filter(|ch| {
109+
hugr.get_optype(*ch)
110+
.as_func_defn()
111+
.is_some_and(|fd| fd.link_name.is_some())
112+
})
113+
.collect()
114+
} else {
115+
vec![]
116+
};
89117
let reachable = reachable_funcs(
90118
&CallGraph::new(hugr),
91119
hugr,
92-
self.entry_points.iter().cloned(),
120+
self.entry_points.iter().cloned().chain(exports),
93121
)?
94122
.collect::<HashSet<_>>();
95123
let unreachable = hugr
@@ -108,30 +136,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
108136
/// Deletes from the Hugr any functions that are not used by either [Call] or
109137
/// [LoadFunction] nodes in reachable parts.
110138
///
111-
/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
112-
/// which must be children of the root. Note that if `entry_points` is empty, this will
113-
/// result in all functions in the module being removed.
114-
///
115-
/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
116-
///
117-
/// # Errors
118-
/// * If there are any `entry_points` but the root of the hugr is not a [Module]
119-
/// * If any node in `entry_points` is
120-
/// * not a [FuncDefn], or
121-
/// * not a child of the root
139+
/// For [Module]-rooted Hugrs, all top-level functions with [FuncDefn::link_name] set,
140+
/// will be used as entry points.
122141
///
123-
/// [Call]: hugr_core::ops::OpType::Call
124-
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
125-
/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
126-
/// [Module]: hugr_core::ops::OpType::Module
127142
pub fn remove_dead_funcs(
128143
h: &mut impl HugrMut<Node = Node>,
129-
entry_points: impl IntoIterator<Item = Node>,
130144
) -> Result<(), ValidatePassError<Node, RemoveDeadFuncsError>> {
131-
validate_if_test(
132-
RemoveDeadFuncsPass::default().with_module_entry_points(entry_points),
133-
h,
134-
)
145+
validate_if_test(RemoveDeadFuncsPass::default(), h)
135146
}
136147

137148
#[cfg(test)]
@@ -146,29 +157,34 @@ mod test {
146157
};
147158
use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};
148159

149-
use super::remove_dead_funcs;
160+
use super::RemoveDeadFuncsPass;
161+
use crate::ComposablePass;
150162

151163
#[rstest]
152-
#[case([], vec![])] // No entry_points removes everything!
153-
#[case(["main"], vec!["from_main", "main"])]
154-
#[case(["from_main"], vec!["from_main"])]
155-
#[case(["other1"], vec!["other1", "other2"])]
156-
#[case(["other2"], vec!["other2"])]
157-
#[case(["other1", "other2"], vec!["other1", "other2"])]
164+
#[case(false, [], vec![])] // No entry_points removes everything!
165+
#[case(false, ["main"], vec!["from_main", "main"])]
166+
#[case(false, ["from_main"], vec!["from_main"])]
167+
#[case(false, ["other1"], vec!["other1", "other2"])]
168+
#[case(false, ["other2"], vec!["other2"])]
169+
#[case(false, ["other1", "other2"], vec!["other1", "other2"])]
170+
#[case(true, [], vec!["from_main", "main", "other2"])]
171+
#[case(true, ["other1"], vec!["from_main", "main", "other1", "other2"])]
158172
fn remove_dead_funcs_entry_points(
173+
#[case] include_exports: bool,
159174
#[case] entry_points: impl IntoIterator<Item = &'static str>,
160175
#[case] retained_funcs: Vec<&'static str>,
161176
) -> Result<(), Box<dyn std::error::Error>> {
162177
let mut hb = ModuleBuilder::new();
163178
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
164179
let o2inp = o2.input_wires();
165180
let o2 = o2.finish_with_outputs(o2inp)?;
166-
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;
181+
let mut o1 =
182+
hb.define_function_link_name("other1", Signature::new_endo(usize_t()), None)?;
167183

168184
let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
169185
o1.finish_with_outputs(o1c.outputs())?;
170186

171-
let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
187+
let fm = hb.define_function_link_name("from_main", Signature::new_endo(usize_t()), None)?;
172188
let f_inp = fm.input_wires();
173189
let fm = fm.finish_with_outputs(f_inp)?;
174190
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
@@ -186,14 +202,16 @@ mod test {
186202
})
187203
.collect::<HashMap<_, _>>();
188204

189-
remove_dead_funcs(
190-
&mut hugr,
191-
entry_points
192-
.into_iter()
193-
.map(|name| *avail_funcs.get(name).unwrap())
194-
.collect::<Vec<_>>(),
195-
)
196-
.unwrap();
205+
RemoveDeadFuncsPass::default()
206+
.include_module_exports(include_exports)
207+
.with_module_entry_points(
208+
entry_points
209+
.into_iter()
210+
.map(|name| *avail_funcs.get(name).unwrap())
211+
.collect::<Vec<_>>(),
212+
)
213+
.run(&mut hugr)
214+
.unwrap();
197215

198216
let remaining_funcs = hugr
199217
.nodes()

hugr-passes/src/monomorphize.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ mod test {
307307
use hugr_core::{Hugr, HugrView, Node};
308308
use rstest::rstest;
309309

310-
use crate::{monomorphize, remove_dead_funcs};
310+
use crate::{monomorphize, remove_dead_funcs, ComposablePass, RemoveDeadFuncsPass};
311311

312312
use super::{is_polymorphic, mangle_inner_func, mangle_name};
313313

@@ -414,7 +414,10 @@ mod test {
414414
assert_eq!(mono2, mono); // Idempotent
415415

416416
let mut nopoly = mono;
417-
remove_dead_funcs(&mut nopoly, [mn.node()])?;
417+
RemoveDeadFuncsPass::default()
418+
.include_module_exports(false)
419+
.with_module_entry_points([mn.node()])
420+
.run(&mut nopoly)?;
418421
let mut funcs = list_funcs_link_name(&nopoly);
419422

420423
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
@@ -617,12 +620,13 @@ mod test {
617620
let mut module_builder = ModuleBuilder::new();
618621
let foo = {
619622
let builder = module_builder
620-
.define_function(
623+
.define_function_link_name(
621624
"foo",
622625
PolyFuncType::new(
623626
[TypeBound::Any.into()],
624627
Signature::new_endo(Type::new_var_use(0, TypeBound::Any)),
625628
),
629+
None,
626630
)
627631
.unwrap();
628632
let inputs = builder.input_wires();
@@ -653,7 +657,7 @@ mod test {
653657
};
654658

655659
monomorphize(&mut hugr).unwrap();
656-
remove_dead_funcs(&mut hugr, []).unwrap();
660+
remove_dead_funcs(&mut hugr).unwrap();
657661

658662
let funcs = list_funcs(&hugr);
659663
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));

0 commit comments

Comments
 (0)