Skip to content

Commit 7312c62

Browse files
authored
fix: const-folding Module keeps at least "main" (#1901)
Minimal, non-breaking, fix for #1797, this seems consistent with what dataflow analysis does.
1 parent d6b8681 commit 7312c62

File tree

2 files changed

+74
-11
lines changed

2 files changed

+74
-11
lines changed

hugr-passes/src/const_fold.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use hugr_core::{
1313
},
1414
ops::{
1515
constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant,
16-
Value,
16+
OpType, Value,
1717
},
1818
types::{EdgeKind, TypeArg},
1919
HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire,
@@ -88,8 +88,7 @@ impl ConstantFoldPass {
8888
});
8989

9090
let results = Machine::new(&hugr).run(ConstFoldContext(hugr), inputs);
91-
let mut keep_nodes = HashSet::new();
92-
self.find_needed_nodes(&results, &mut keep_nodes);
91+
let keep_nodes = self.find_needed_nodes(&results);
9392
let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i);
9493

9594
let remove_nodes = hugr
@@ -145,17 +144,30 @@ impl ConstantFoldPass {
145144
fn find_needed_nodes<H: HugrView>(
146145
&self,
147146
results: &AnalysisResults<ValueHandle, H>,
148-
needed: &mut HashSet<Node>,
149-
) {
150-
let mut q = VecDeque::new();
147+
) -> HashSet<Node> {
148+
let mut needed = HashSet::new();
151149
let h = results.hugr();
152-
q.push_back(h.root());
150+
let mut q = VecDeque::from_iter([h.root()]);
153151
while let Some(n) = q.pop_front() {
154152
if !needed.insert(n) {
155153
continue;
156154
};
157-
158-
if h.get_optype(n).is_cfg() {
155+
if h.get_optype(n).is_module() {
156+
for ch in h.children(n) {
157+
match h.get_optype(ch) {
158+
OpType::AliasDecl(_) | OpType::AliasDefn(_) => {
159+
// Use of these is done via names, rather than following edges.
160+
// We could track these as well but for now be conservative.
161+
q.push_back(ch);
162+
}
163+
OpType::FuncDefn(f) if f.name == "main" => {
164+
// Dataflow analysis will have applied any inputs the 'main' function, so assume reachable.
165+
q.push_back(ch);
166+
}
167+
_ => (),
168+
}
169+
}
170+
} else if h.get_optype(n).is_cfg() {
159171
for bb in h.children(n) {
160172
//if results.bb_reachable(bb).unwrap() { // no, we'd need to patch up predicates
161173
q.push_back(bb);
@@ -192,6 +204,7 @@ impl ConstantFoldPass {
192204
}
193205
}
194206
}
207+
needed
195208
}
196209
}
197210

hugr-passes/src/const_fold/test.rs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
use std::collections::hash_map::RandomState;
22
use std::collections::HashSet;
33

4+
use hugr_core::ops::handle::NodeHandle;
5+
use hugr_core::ops::Const;
6+
use hugr_core::std_extensions::arithmetic::{int_ops, int_types};
47
use itertools::Itertools;
58
use lazy_static::lazy_static;
69
use rstest::rstest;
710

811
use hugr_core::builder::{
912
endo_sig, inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
10-
SubContainer,
13+
HugrBuilder, ModuleBuilder, SubContainer,
1114
};
1215
use hugr_core::extension::prelude::{
1316
bool_t, const_ok, error_type, string_type, sum_with_error, ConstError, ConstString, MakeTuple,
@@ -25,7 +28,7 @@ use hugr_core::std_extensions::arithmetic::{
2528
int_types::{ConstInt, INT_TYPES},
2629
};
2730
use hugr_core::std_extensions::logic::LogicOp;
28-
use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV};
31+
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
2932
use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node};
3033

3134
use crate::dataflow::{partial_from_const, DFContext, PartialValue};
@@ -1580,3 +1583,50 @@ fn test_cfg(
15801583
assert_eq!(output_src, nested);
15811584
}
15821585
}
1586+
1587+
#[test]
1588+
fn test_module() -> Result<(), Box<dyn std::error::Error>> {
1589+
let mut mb = ModuleBuilder::new();
1590+
// Define a top-level constant, (only) the second of which can be removed
1591+
let c7 = mb.add_constant(Value::from(ConstInt::new_u(5, 7)?));
1592+
let c17 = mb.add_constant(Value::from(ConstInt::new_u(5, 17)?));
1593+
let ad1 = mb.add_alias_declare("unused", TypeBound::Any)?;
1594+
let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?;
1595+
let mut main = mb.define_function(
1596+
"main",
1597+
Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2])
1598+
.with_extension_delta(int_types::EXTENSION_ID)
1599+
.with_extension_delta(int_ops::EXTENSION_ID),
1600+
)?;
1601+
let lc7 = main.load_const(&c7);
1602+
let lc17 = main.load_const(&c17);
1603+
let [add] = main
1604+
.add_dataflow_op(IntOpDef::iadd.with_log_width(5), [lc7, lc17])?
1605+
.outputs_arr();
1606+
let main = main.finish_with_outputs([lc7, add])?;
1607+
let mut hugr = mb.finish_hugr()?;
1608+
constant_fold_pass(&mut hugr);
1609+
assert!(hugr.get_optype(hugr.root()).is_module());
1610+
assert_eq!(
1611+
hugr.children(hugr.root()).collect_vec(),
1612+
[c7.node(), ad1.node(), ad2.node(), main.node()]
1613+
);
1614+
let tags = hugr
1615+
.children(main.node())
1616+
.map(|n| hugr.get_optype(n).tag())
1617+
.collect_vec();
1618+
for (tag, expected_count) in [
1619+
(OpTag::Input, 1),
1620+
(OpTag::Output, 1),
1621+
(OpTag::Const, 1),
1622+
(OpTag::LoadConst, 2),
1623+
] {
1624+
assert_eq!(tags.iter().filter(|t| **t == tag).count(), expected_count);
1625+
}
1626+
assert_eq!(
1627+
hugr.children(main.node())
1628+
.find_map(|n| hugr.get_optype(n).as_const()),
1629+
Some(&Const::new(ConstInt::new_u(5, 24).unwrap().into()))
1630+
);
1631+
Ok(())
1632+
}

0 commit comments

Comments
 (0)