Skip to content

Commit cad7484

Browse files
authored
feat: Add CallGraph struct, and dead-function-removal pass (#1796)
Closes #1753. * `remove_polyfuncs` preserved but deprecated, some uses in tests replaced to give coverage here. * Future (breaking) release to remove the automatic-`remove_polyfuncs` that currently follows monomorphization.
1 parent 33dd8fd commit cad7484

File tree

4 files changed

+329
-12
lines changed

4 files changed

+329
-12
lines changed

hugr-passes/src/call_graph.rs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#![warn(missing_docs)]
2+
//! Data structure for call graphs of a Hugr
3+
use std::collections::HashMap;
4+
5+
use hugr_core::{ops::OpType, HugrView, Node};
6+
use petgraph::{graph::NodeIndex, Graph};
7+
8+
/// Weight for an edge in a [CallGraph]
9+
pub enum CallGraphEdge {
10+
/// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
11+
Call(Node),
12+
/// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr
13+
LoadFunction(Node),
14+
}
15+
16+
/// Weight for a petgraph-node in a [CallGraph]
17+
pub enum CallGraphNode {
18+
/// petgraph-node corresponds to a [FuncDecl](OpType::FuncDecl) node (specified) in the Hugr
19+
FuncDecl(Node),
20+
/// petgraph-node corresponds to a [FuncDefn](OpType::FuncDefn) node (specified) in the Hugr
21+
FuncDefn(Node),
22+
/// petgraph-node corresponds to the root node of the hugr, that is not
23+
/// a [FuncDefn](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
24+
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
25+
NonFuncRoot,
26+
}
27+
28+
/// Details the [Call]s and [LoadFunction]s in a Hugr.
29+
/// Each node in the `CallGraph` corresponds to a [FuncDefn] in the Hugr; each edge corresponds
30+
/// to a [Call]/[LoadFunction] of the edge's target, contained in the edge's source.
31+
///
32+
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph
33+
/// will have an additional [CallGraphNode::NonFuncRoot] corresponding to the Hugr's root, with no incoming edges.
34+
///
35+
/// [Call]: OpType::Call
36+
/// [FuncDefn]: OpType::FuncDefn
37+
/// [LoadFunction]: OpType::LoadFunction
38+
pub struct CallGraph {
39+
g: Graph<CallGraphNode, CallGraphEdge>,
40+
node_to_g: HashMap<Node, NodeIndex<u32>>,
41+
}
42+
43+
impl CallGraph {
44+
/// Makes a new CallGraph for a specified (subview) of a Hugr.
45+
/// Calls to functions outside the view will be dropped.
46+
pub fn new(hugr: &impl HugrView) -> Self {
47+
let mut g = Graph::default();
48+
let non_func_root = (!hugr.get_optype(hugr.root()).is_module()).then_some(hugr.root());
49+
let node_to_g = hugr
50+
.nodes()
51+
.filter_map(|n| {
52+
let weight = match hugr.get_optype(n) {
53+
OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
54+
OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n),
55+
_ => (Some(n) == non_func_root).then_some(CallGraphNode::NonFuncRoot)?,
56+
};
57+
Some((n, g.add_node(weight)))
58+
})
59+
.collect::<HashMap<_, _>>();
60+
for (func, cg_node) in node_to_g.iter() {
61+
traverse(hugr, *cg_node, *func, &mut g, &node_to_g)
62+
}
63+
fn traverse(
64+
h: &impl HugrView,
65+
enclosing_func: NodeIndex<u32>,
66+
node: Node, // Nonstrict-descendant of `enclosing_func``
67+
g: &mut Graph<CallGraphNode, CallGraphEdge>,
68+
node_to_g: &HashMap<Node, NodeIndex<u32>>,
69+
) {
70+
for ch in h.children(node) {
71+
if h.get_optype(ch).is_func_defn() {
72+
continue;
73+
};
74+
traverse(h, enclosing_func, ch, g, node_to_g);
75+
let weight = match h.get_optype(ch) {
76+
OpType::Call(_) => CallGraphEdge::Call(ch),
77+
OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(ch),
78+
_ => continue,
79+
};
80+
if let Some(target) = h.static_source(ch) {
81+
g.add_edge(enclosing_func, *node_to_g.get(&target).unwrap(), weight);
82+
}
83+
}
84+
}
85+
CallGraph { g, node_to_g }
86+
}
87+
88+
/// Allows access to the petgraph
89+
pub fn graph(&self) -> &Graph<CallGraphNode, CallGraphEdge> {
90+
&self.g
91+
}
92+
93+
/// Convert a Hugr [Node] into a petgraph node index.
94+
/// Result will be `None` if `n` is not a [FuncDefn](OpType::FuncDefn),
95+
/// [FuncDecl](OpType::FuncDecl) or the hugr root.
96+
pub fn node_index(&self, n: Node) -> Option<NodeIndex<u32>> {
97+
self.node_to_g.get(&n).copied()
98+
}
99+
}

hugr-passes/src/dead_funcs.rs

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
#![warn(missing_docs)]
2+
//! Pass for removing statically-unreachable functions from a Hugr
3+
4+
use std::collections::HashSet;
5+
6+
use hugr_core::{
7+
hugr::hugrmut::HugrMut,
8+
ops::{OpTag, OpTrait},
9+
HugrView, Node,
10+
};
11+
use petgraph::visit::{Dfs, Walker};
12+
13+
use crate::validation::{ValidatePassError, ValidationLevel};
14+
15+
use super::call_graph::{CallGraph, CallGraphNode};
16+
17+
#[derive(Debug, thiserror::Error)]
18+
#[non_exhaustive]
19+
/// Errors produced by [ConstantFoldPass].
20+
pub enum RemoveDeadFuncsError {
21+
#[error("Node {0} was not a FuncDefn child of the Module root")]
22+
InvalidEntryPoint(Node),
23+
#[error(transparent)]
24+
#[allow(missing_docs)]
25+
ValidationError(#[from] ValidatePassError),
26+
}
27+
28+
fn reachable_funcs<'a>(
29+
cg: &'a CallGraph,
30+
h: &'a impl HugrView,
31+
entry_points: impl IntoIterator<Item = Node>,
32+
) -> Result<impl Iterator<Item = Node> + 'a, RemoveDeadFuncsError> {
33+
let g = cg.graph();
34+
let mut entry_points = entry_points.into_iter();
35+
let searcher = if h.get_optype(h.root()).is_module() {
36+
let mut d = Dfs::new(g, 0.into());
37+
d.stack.clear();
38+
for n in entry_points {
39+
if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.root()) {
40+
return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
41+
}
42+
d.stack.push(cg.node_index(n).unwrap())
43+
}
44+
d
45+
} else {
46+
if let Some(n) = entry_points.next() {
47+
// Can't be a child of the module root as there isn't a module root!
48+
return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
49+
}
50+
Dfs::new(g, cg.node_index(h.root()).unwrap())
51+
};
52+
Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
53+
CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
54+
CallGraphNode::NonFuncRoot => h.root(),
55+
}))
56+
}
57+
58+
#[derive(Debug, Clone, Default)]
59+
/// A configuration for the Dead Function Removal pass.
60+
pub struct RemoveDeadFuncsPass {
61+
validation: ValidationLevel,
62+
entry_points: Vec<Node>,
63+
}
64+
65+
impl RemoveDeadFuncsPass {
66+
/// Sets the validation level used before and after the pass is run
67+
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
68+
self.validation = level;
69+
self
70+
}
71+
72+
/// Adds new entry points - these must be [FuncDefn] nodes
73+
/// that are children of the [Module] at the root of the Hugr.
74+
///
75+
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
76+
/// [Module]: hugr_core::ops::OpType::Module
77+
pub fn with_module_entry_points(
78+
mut self,
79+
entry_points: impl IntoIterator<Item = Node>,
80+
) -> Self {
81+
self.entry_points.extend(entry_points);
82+
self
83+
}
84+
85+
/// Runs the pass (see [remove_dead_funcs]) with this configuration
86+
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
87+
self.validation.run_validated_pass(hugr, |hugr: &mut H, _| {
88+
remove_dead_funcs(hugr, self.entry_points.iter().cloned())
89+
})
90+
}
91+
}
92+
93+
/// Delete from the Hugr any functions that are not used by either [Call] or
94+
/// [LoadFunction] nodes in reachable parts.
95+
///
96+
/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
97+
/// which must be children of the root. Note that if `entry_points` is empty, this will
98+
/// result in all functions in the module being removed.
99+
///
100+
/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
101+
///
102+
/// # Errors
103+
/// * If there are any `entry_points` but the root of the hugr is not a [Module]
104+
/// * If any node in `entry_points` is
105+
/// * not a [FuncDefn], or
106+
/// * not a child of the root
107+
///
108+
/// [Call]: hugr_core::ops::OpType::Call
109+
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
110+
/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
111+
/// [Module]: hugr_core::ops::OpType::Module
112+
pub fn remove_dead_funcs(
113+
h: &mut impl HugrMut,
114+
entry_points: impl IntoIterator<Item = Node>,
115+
) -> Result<(), RemoveDeadFuncsError> {
116+
let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::<HashSet<_>>();
117+
let unreachable = h
118+
.nodes()
119+
.filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n))
120+
.collect::<Vec<_>>();
121+
for n in unreachable {
122+
h.remove_subtree(n);
123+
}
124+
Ok(())
125+
}
126+
127+
#[cfg(test)]
128+
mod test {
129+
use std::collections::HashMap;
130+
131+
use itertools::Itertools;
132+
use rstest::rstest;
133+
134+
use hugr_core::builder::{
135+
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
136+
};
137+
use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};
138+
139+
use super::RemoveDeadFuncsPass;
140+
141+
#[rstest]
142+
#[case([], vec![])] // No entry_points removes everything!
143+
#[case(["main"], vec!["from_main", "main"])]
144+
#[case(["from_main"], vec!["from_main"])]
145+
#[case(["other1"], vec!["other1", "other2"])]
146+
#[case(["other2"], vec!["other2"])]
147+
#[case(["other1", "other2"], vec!["other1", "other2"])]
148+
fn remove_dead_funcs_entry_points(
149+
#[case] entry_points: impl IntoIterator<Item = &'static str>,
150+
#[case] retained_funcs: Vec<&'static str>,
151+
) -> Result<(), Box<dyn std::error::Error>> {
152+
let mut hb = ModuleBuilder::new();
153+
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
154+
let o2inp = o2.input_wires();
155+
let o2 = o2.finish_with_outputs(o2inp)?;
156+
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;
157+
158+
let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
159+
o1.finish_with_outputs(o1c.outputs())?;
160+
161+
let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
162+
let f_inp = fm.input_wires();
163+
let fm = fm.finish_with_outputs(f_inp)?;
164+
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
165+
let mc = m.call(fm.handle(), &[], m.input_wires())?;
166+
m.finish_with_outputs(mc.outputs())?;
167+
168+
let mut hugr = hb.finish_hugr()?;
169+
170+
let avail_funcs = hugr
171+
.nodes()
172+
.filter_map(|n| {
173+
hugr.get_optype(n)
174+
.as_func_defn()
175+
.map(|fd| (fd.name.clone(), n))
176+
})
177+
.collect::<HashMap<_, _>>();
178+
179+
RemoveDeadFuncsPass::default()
180+
.with_module_entry_points(
181+
entry_points
182+
.into_iter()
183+
.map(|name| *avail_funcs.get(name).unwrap())
184+
.collect::<Vec<_>>(),
185+
)
186+
.run(&mut hugr)
187+
.unwrap();
188+
189+
let remaining_funcs = hugr
190+
.nodes()
191+
.filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str()))
192+
.sorted()
193+
.collect_vec();
194+
assert_eq!(remaining_funcs, retained_funcs);
195+
Ok(())
196+
}
197+
}

hugr-passes/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
//! Compilation passes acting on the HUGR program representation.
22
3+
pub mod call_graph;
34
pub mod const_fold;
45
pub mod dataflow;
6+
mod dead_funcs;
7+
pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsPass};
58
pub mod force_order;
69
mod half_node;
710
pub mod lower;
811
pub mod merge_bbs;
912
mod monomorphize;
1013
// TODO: Deprecated re-export. Remove on a breaking release.
14+
#[deprecated(
15+
since = "0.14.1",
16+
note = "Use `hugr::algorithms::call_graph::RemoveDeadFuncsPass` instead."
17+
)]
18+
#[allow(deprecated)]
19+
pub use monomorphize::remove_polyfuncs;
20+
// TODO: Deprecated re-export. Remove on a breaking release.
1121
#[deprecated(
1222
since = "0.14.1",
1323
note = "Use `hugr::algorithms::MonomorphizePass` instead."
1424
)]
1525
#[allow(deprecated)]
1626
pub use monomorphize::monomorphize;
17-
pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
27+
pub use monomorphize::{MonomorphizeError, MonomorphizePass};
1828
pub mod nest_cfgs;
1929
pub mod non_local;
2030
pub mod validation;

0 commit comments

Comments
 (0)