Skip to content

Commit e54c742

Browse files
committed
TailLoopTermination just examine whatever PartialValue's we have, remove most
1 parent 41820f2 commit e54c742

File tree

3 files changed

+11
-129
lines changed

3 files changed

+11
-129
lines changed

hugr-passes/src/dataflow/datalog.rs

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use ascent::lattice::BoundedLattice;
21
use hugr_core::extension::prelude::{MakeTuple, UnpackTuple};
32
use std::collections::HashMap;
43
use std::hash::Hash;
@@ -108,15 +107,6 @@ ascent::ascent! {
108107
if let Some(fields) = out_in_row.unpack_first(1, tailloop.just_outputs.len()), // if it is possible for the tag to be 1
109108
for (out_p, v) in (0..).map(OutgoingPort::from).zip(fields);
110109

111-
lattice tail_loop_termination(C,Node,TailLoopTermination);
112-
tail_loop_termination(c,tl_n,TailLoopTermination::bottom()) <--
113-
tail_loop_node(c,tl_n);
114-
tail_loop_termination(c,tl_n,TailLoopTermination::from_control_value(v)) <--
115-
tail_loop_node(c,tl_n),
116-
io_node(c,tl,out_n, IO::Output),
117-
in_wire_value(c, out_n, IncomingPort::from(0), v);
118-
119-
120110
// Conditional
121111
relation conditional_node(C, Node);
122112
relation case_node(C,Node,usize, Node);
@@ -221,11 +211,14 @@ impl<V: AbstractValue, C: DFContext<V>> Machine<V, C> {
221211

222212
pub fn tail_loop_terminates(&self, hugr: impl HugrView, node: Node) -> TailLoopTermination {
223213
assert!(hugr.get_optype(node).is_tail_loop());
224-
self.0
225-
.tail_loop_termination
226-
.iter()
227-
.find_map(|(_, n, v)| (n == &node).then_some(*v))
228-
.unwrap()
214+
let [_, out] = hugr.get_io(node).unwrap();
215+
TailLoopTermination::from_control_value(
216+
self.0
217+
.in_wire_value
218+
.iter()
219+
.find_map(|(_, n, p, v)| (*n == out && p.index() == 0).then_some(v))
220+
.unwrap(),
221+
)
229222
}
230223

231224
pub fn case_reachable(&self, hugr: impl HugrView, case: Node) -> bool {

hugr-passes/src/dataflow/datalog/test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ fn test_tail_loop_always_iterates() {
125125
let o_r2 = machine.read_out_wire_partial_value(tl_o2).unwrap();
126126
assert_eq!(o_r2, PartialValue::bottom().into());
127127
assert_eq!(
128-
TailLoopTermination::bottom(),
128+
TailLoopTermination::Bottom,
129129
machine.tail_loop_terminates(&hugr, tail_loop.node())
130130
)
131131
}

hugr-passes/src/dataflow/datalog/utils.rs

Lines changed: 2 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33
// https://github.com/proptest-rs/proptest/issues/447
44
#![cfg_attr(test, allow(non_local_definitions))]
55

6-
use std::cmp::Ordering;
7-
86
use ascent::lattice::{BoundedLattice, Lattice};
97

108
use super::super::partial_value::{AbstractValue, PartialValue};
119
use hugr_core::{types::Signature, HugrView, IncomingPort, Node, OutgoingPort};
1210

13-
#[cfg(test)]
14-
use proptest_derive::Arbitrary;
15-
1611
impl<V: AbstractValue> Lattice for PartialValue<V> {
1712
fn meet(self, other: Self) -> Self {
1813
self.meet(other)
@@ -57,7 +52,6 @@ pub(super) fn value_outputs(h: &impl HugrView, n: Node) -> impl Iterator<Item =
5752
}
5853

5954
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
60-
#[cfg_attr(test, derive(Arbitrary))]
6155
pub enum TailLoopTermination {
6256
Bottom,
6357
ExactlyZeroContinues,
@@ -70,114 +64,9 @@ impl TailLoopTermination {
7064
if may_break && !may_continue {
7165
Self::ExactlyZeroContinues
7266
} else if may_break && may_continue {
73-
Self::top()
67+
Self::Top
7468
} else {
75-
Self::bottom()
76-
}
77-
}
78-
}
79-
80-
impl PartialOrd for TailLoopTermination {
81-
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
82-
if self == other {
83-
return Some(std::cmp::Ordering::Equal);
84-
};
85-
match (self, other) {
86-
(Self::Bottom, _) => Some(Ordering::Less),
87-
(_, Self::Bottom) => Some(Ordering::Greater),
88-
(Self::Top, _) => Some(Ordering::Greater),
89-
(_, Self::Top) => Some(Ordering::Less),
90-
_ => None,
91-
}
92-
}
93-
}
94-
95-
impl Lattice for TailLoopTermination {
96-
fn meet(mut self, other: Self) -> Self {
97-
self.meet_mut(other);
98-
self
99-
}
100-
101-
fn join(mut self, other: Self) -> Self {
102-
self.join_mut(other);
103-
self
104-
}
105-
106-
fn meet_mut(&mut self, other: Self) -> bool {
107-
// let new_self = &mut self;
108-
match (*self).partial_cmp(&other) {
109-
Some(Ordering::Greater) => {
110-
*self = other;
111-
true
112-
}
113-
Some(_) => false,
114-
_ => {
115-
*self = Self::Bottom;
116-
true
117-
}
118-
}
119-
}
120-
121-
fn join_mut(&mut self, other: Self) -> bool {
122-
match (*self).partial_cmp(&other) {
123-
Some(Ordering::Less) => {
124-
*self = other;
125-
true
126-
}
127-
Some(_) => false,
128-
_ => {
129-
*self = Self::Top;
130-
true
131-
}
132-
}
133-
}
134-
}
135-
136-
impl BoundedLattice for TailLoopTermination {
137-
fn bottom() -> Self {
138-
Self::Bottom
139-
}
140-
141-
fn top() -> Self {
142-
Self::Top
143-
}
144-
}
145-
146-
#[cfg(test)]
147-
#[cfg_attr(test, allow(non_local_definitions))]
148-
mod test {
149-
use super::*;
150-
use proptest::prelude::*;
151-
152-
proptest! {
153-
#[test]
154-
fn bounded_lattice(v: TailLoopTermination) {
155-
prop_assert!(v <= TailLoopTermination::top());
156-
prop_assert!(v >= TailLoopTermination::bottom());
157-
}
158-
159-
#[test]
160-
fn meet_join_self_noop(v1: TailLoopTermination) {
161-
let mut subject = v1.clone();
162-
163-
assert_eq!(v1.clone(), v1.clone().join(v1.clone()));
164-
assert!(!subject.join_mut(v1.clone()));
165-
assert_eq!(subject, v1);
166-
167-
assert_eq!(v1.clone(), v1.clone().meet(v1.clone()));
168-
assert!(!subject.meet_mut(v1.clone()));
169-
assert_eq!(subject, v1);
170-
}
171-
172-
#[test]
173-
fn lattice(v1: TailLoopTermination, v2: TailLoopTermination) {
174-
let meet = v1.clone().meet(v2.clone());
175-
prop_assert!(meet <= v1, "meet not less <=: {:#?}", &meet);
176-
prop_assert!(meet <= v2, "meet not less <=: {:#?}", &meet);
177-
178-
let join = v1.clone().join(v2.clone());
179-
prop_assert!(join >= v1, "join not >=: {:#?}", &join);
180-
prop_assert!(join >= v2, "join not >=: {:#?}", &join);
69+
Self::Bottom
18170
}
18271
}
18372
}

0 commit comments

Comments
 (0)