@@ -61,10 +61,20 @@ fn reachable_funcs<'a, H: HugrView>(
61
61
} ) )
62
62
}
63
63
64
- #[ derive( Debug , Clone , Default ) ]
64
+ #[ derive( Debug , Clone ) ]
65
65
/// A configuration for the Dead Function Removal pass.
66
66
pub struct RemoveDeadFuncsPass {
67
67
entry_points : Vec < Node > ,
68
+ include_exports : bool ,
69
+ }
70
+
71
+ impl Default for RemoveDeadFuncsPass {
72
+ fn default ( ) -> Self {
73
+ Self {
74
+ entry_points : Default :: default ( ) ,
75
+ include_exports : true ,
76
+ }
77
+ }
68
78
}
69
79
70
80
impl RemoveDeadFuncsPass {
@@ -80,16 +90,34 @@ impl RemoveDeadFuncsPass {
80
90
self . entry_points . extend ( entry_points) ;
81
91
self
82
92
}
93
+
94
+ /// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children of a
95
+ /// [Module](hugr_core::ops::Module) are included as entry points (yes by default)
96
+ pub fn include_module_exports ( mut self , include : bool ) -> Self {
97
+ self . include_exports = include;
98
+ self
99
+ }
83
100
}
84
101
85
102
impl < H : HugrMut < Node = Node > > ComposablePass < H > for RemoveDeadFuncsPass {
86
103
type Error = RemoveDeadFuncsError ;
87
104
type Result = ( ) ;
88
105
fn run ( & self , hugr : & mut H ) -> Result < ( ) , RemoveDeadFuncsError > {
106
+ let exports = if hugr. entrypoint ( ) == hugr. module_root ( ) && self . include_exports {
107
+ hugr. children ( hugr. module_root ( ) )
108
+ . filter ( |ch| {
109
+ hugr. get_optype ( * ch)
110
+ . as_func_defn ( )
111
+ . is_some_and ( |fd| fd. link_name . is_some ( ) )
112
+ } )
113
+ . collect ( )
114
+ } else {
115
+ vec ! [ ]
116
+ } ;
89
117
let reachable = reachable_funcs (
90
118
& CallGraph :: new ( hugr) ,
91
119
hugr,
92
- self . entry_points . iter ( ) . cloned ( ) ,
120
+ self . entry_points . iter ( ) . cloned ( ) . chain ( exports ) ,
93
121
) ?
94
122
. collect :: < HashSet < _ > > ( ) ;
95
123
let unreachable = hugr
@@ -108,30 +136,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
108
136
/// Deletes from the Hugr any functions that are not used by either [Call] or
109
137
/// [LoadFunction] nodes in reachable parts.
110
138
///
111
- /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
112
- /// which must be children of the root. Note that if `entry_points` is empty, this will
113
- /// result in all functions in the module being removed.
114
- ///
115
- /// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
116
- ///
117
- /// # Errors
118
- /// * If there are any `entry_points` but the root of the hugr is not a [Module]
119
- /// * If any node in `entry_points` is
120
- /// * not a [FuncDefn], or
121
- /// * not a child of the root
139
+ /// For [Module]-rooted Hugrs, all top-level functions with [FuncDefn::link_name] set,
140
+ /// will be used as entry points.
122
141
///
123
- /// [Call]: hugr_core::ops::OpType::Call
124
- /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
125
- /// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
126
- /// [Module]: hugr_core::ops::OpType::Module
127
142
pub fn remove_dead_funcs (
128
143
h : & mut impl HugrMut < Node = Node > ,
129
- entry_points : impl IntoIterator < Item = Node > ,
130
144
) -> Result < ( ) , ValidatePassError < Node , RemoveDeadFuncsError > > {
131
- validate_if_test (
132
- RemoveDeadFuncsPass :: default ( ) . with_module_entry_points ( entry_points) ,
133
- h,
134
- )
145
+ validate_if_test ( RemoveDeadFuncsPass :: default ( ) , h)
135
146
}
136
147
137
148
#[ cfg( test) ]
@@ -146,29 +157,34 @@ mod test {
146
157
} ;
147
158
use hugr_core:: { extension:: prelude:: usize_t, types:: Signature , HugrView } ;
148
159
149
- use super :: remove_dead_funcs;
160
+ use super :: RemoveDeadFuncsPass ;
161
+ use crate :: ComposablePass ;
150
162
151
163
#[ rstest]
152
- #[ case( [ ] , vec![ ] ) ] // No entry_points removes everything!
153
- #[ case( [ "main" ] , vec![ "from_main" , "main" ] ) ]
154
- #[ case( [ "from_main" ] , vec![ "from_main" ] ) ]
155
- #[ case( [ "other1" ] , vec![ "other1" , "other2" ] ) ]
156
- #[ case( [ "other2" ] , vec![ "other2" ] ) ]
157
- #[ case( [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
164
+ #[ case( false , [ ] , vec![ ] ) ] // No entry_points removes everything!
165
+ #[ case( false , [ "main" ] , vec![ "from_main" , "main" ] ) ]
166
+ #[ case( false , [ "from_main" ] , vec![ "from_main" ] ) ]
167
+ #[ case( false , [ "other1" ] , vec![ "other1" , "other2" ] ) ]
168
+ #[ case( false , [ "other2" ] , vec![ "other2" ] ) ]
169
+ #[ case( false , [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
170
+ #[ case( true , [ ] , vec![ "from_main" , "main" , "other2" ] ) ]
171
+ #[ case( true , [ "other1" ] , vec![ "from_main" , "main" , "other1" , "other2" ] ) ]
158
172
fn remove_dead_funcs_entry_points (
173
+ #[ case] include_exports : bool ,
159
174
#[ case] entry_points : impl IntoIterator < Item = & ' static str > ,
160
175
#[ case] retained_funcs : Vec < & ' static str > ,
161
176
) -> Result < ( ) , Box < dyn std:: error:: Error > > {
162
177
let mut hb = ModuleBuilder :: new ( ) ;
163
178
let o2 = hb. define_function ( "other2" , Signature :: new_endo ( usize_t ( ) ) ) ?;
164
179
let o2inp = o2. input_wires ( ) ;
165
180
let o2 = o2. finish_with_outputs ( o2inp) ?;
166
- let mut o1 = hb. define_function ( "other1" , Signature :: new_endo ( usize_t ( ) ) ) ?;
181
+ let mut o1 =
182
+ hb. define_function_link_name ( "other1" , Signature :: new_endo ( usize_t ( ) ) , None ) ?;
167
183
168
184
let o1c = o1. call ( o2. handle ( ) , & [ ] , o1. input_wires ( ) ) ?;
169
185
o1. finish_with_outputs ( o1c. outputs ( ) ) ?;
170
186
171
- let fm = hb. define_function ( "from_main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
187
+ let fm = hb. define_function_link_name ( "from_main" , Signature :: new_endo ( usize_t ( ) ) , None ) ?;
172
188
let f_inp = fm. input_wires ( ) ;
173
189
let fm = fm. finish_with_outputs ( f_inp) ?;
174
190
let mut m = hb. define_function ( "main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
@@ -186,14 +202,16 @@ mod test {
186
202
} )
187
203
. collect :: < HashMap < _ , _ > > ( ) ;
188
204
189
- remove_dead_funcs (
190
- & mut hugr,
191
- entry_points
192
- . into_iter ( )
193
- . map ( |name| * avail_funcs. get ( name) . unwrap ( ) )
194
- . collect :: < Vec < _ > > ( ) ,
195
- )
196
- . unwrap ( ) ;
205
+ RemoveDeadFuncsPass :: default ( )
206
+ . include_module_exports ( include_exports)
207
+ . with_module_entry_points (
208
+ entry_points
209
+ . into_iter ( )
210
+ . map ( |name| * avail_funcs. get ( name) . unwrap ( ) )
211
+ . collect :: < Vec < _ > > ( ) ,
212
+ )
213
+ . run ( & mut hugr)
214
+ . unwrap ( ) ;
197
215
198
216
let remaining_funcs = hugr
199
217
. nodes ( )
0 commit comments