Skip to content

Commit 718f058

Browse files
committed
Use Value::sum, adding FromSum::Err; fmt
1 parent c050321 commit 718f058

File tree

4 files changed

+37
-16
lines changed

4 files changed

+37
-16
lines changed

hugr-passes/src/const_fold2/datalog/partial_value.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@ pub trait AbstractValue: Clone + std::fmt::Debug + PartialEq + Eq + Hash {
1414
fn as_sum(&self) -> Option<(usize, impl Iterator<Item = Self> + '_)>;
1515
}
1616

17-
pub trait FromSum {
18-
fn new_sum(tag: usize, items: impl IntoIterator<Item=Self>, st: &SumType) -> Self;
17+
pub trait FromSum: Sized {
18+
type Err: std::error::Error;
19+
fn try_new_sum(
20+
tag: usize,
21+
items: impl IntoIterator<Item = Self>,
22+
st: &SumType,
23+
) -> Result<Self, Self::Err>;
1924
fn debug_check_is_type(&self, _ty: &Type) {}
2025
}
2126

@@ -120,7 +125,7 @@ impl<V: AbstractValue> PartialSum<V> {
120125
.map(|(v, t)| v.clone().try_into_value(t))
121126
.collect::<Result<Vec<_>, _>>()
122127
{
123-
Ok(vs) => Ok(V2::new_sum(*k, vs, &st)),
128+
Ok(vs) => V2::try_new_sum(*k, vs, &st).map_err(|_| self),
124129
Err(_) => Err(self),
125130
}
126131
}

hugr-passes/src/const_fold2/datalog/utils.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
// https://github.com/proptest-rs/proptest/issues/447
44
#![cfg_attr(test, allow(non_local_definitions))]
55

6-
use std::{cmp::Ordering, ops::{Index, IndexMut}};
6+
use std::{
7+
cmp::Ordering,
8+
ops::{Index, IndexMut},
9+
};
710

811
use ascent::lattice::{BoundedLattice, Lattice};
912
use itertools::zip_eq;
@@ -149,8 +152,9 @@ where
149152

150153
impl<V, Idx> IndexMut<Idx> for ValueRow<V>
151154
where
152-
Vec<PartialValue<V>>: IndexMut<Idx> {
153-
fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
155+
Vec<PartialValue<V>>: IndexMut<Idx>,
156+
{
157+
fn index_mut(&mut self, index: Idx) -> &mut Self::Output {
154158
self.0.index_mut(index)
155159
}
156160
}

hugr-passes/src/const_fold2/value_handle.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33

44
use hugr_core::ops::constant::{CustomConst, Sum};
55
use hugr_core::ops::Value;
6-
use hugr_core::types::Type;
6+
use hugr_core::types::{ConstTypeError, Type};
77
use hugr_core::Node;
88

99
use super::datalog::{AbstractValue, FromSum};
@@ -103,8 +103,13 @@ impl AbstractValue for ValueHandle {
103103
}
104104

105105
impl FromSum for Value {
106-
fn new_sum(tag: usize, items: impl IntoIterator<Item=Self>, st: &hugr_core::types::SumType) -> Self {
107-
Value::Sum(Sum {tag, values: items.into_iter().collect(), sum_type: st.clone()})
106+
type Err = ConstTypeError;
107+
fn try_new_sum(
108+
tag: usize,
109+
items: impl IntoIterator<Item = Self>,
110+
st: &hugr_core::types::SumType,
111+
) -> Result<Self, ConstTypeError> {
112+
Self::sum(tag, items, st.clone())
108113
}
109114
}
110115

hugr-passes/src/const_fold2/value_handle/context.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ impl<H: HugrView> DFContext<ValueHandle> for HugrValueContext<H> {
6868
&self,
6969
n: Node,
7070
ins: &[(IncomingPort, Value)],
71-
) -> Vec<(OutgoingPort,ValueHandle)> {
71+
) -> Vec<(OutgoingPort, ValueHandle)> {
7272
match self.0.get_optype(n) {
7373
OpType::LoadConstant(load_op) => {
7474
assert!(ins.is_empty()); // static edge, so need to find constant
@@ -78,14 +78,21 @@ impl<H: HugrView> DFContext<ValueHandle> for HugrValueContext<H> {
7878
.unwrap()
7979
.0;
8080
let const_op = self.0.get_optype(const_node).as_const().unwrap();
81-
vec![(OutgoingPort::from(0), ValueHandle::new(
82-
const_node.into(),
83-
Arc::new(const_op.value().clone()),
84-
))]
81+
vec![(
82+
OutgoingPort::from(0),
83+
ValueHandle::new(const_node.into(), Arc::new(const_op.value().clone())),
84+
)]
8585
}
8686
OpType::CustomOp(CustomOp::Extension(op)) => {
87-
let ins = ins.into_iter().map(|(p,v)|(*p,v.clone())).collect::<Vec<_>>();
88-
op.constant_fold(&ins).map_or(Vec::new(), |outs|outs.into_iter().map(|(p,v)|(p, ValueHandle::new(ValueKey::Node(n), Arc::new(v)))).collect())
87+
let ins = ins
88+
.into_iter()
89+
.map(|(p, v)| (*p, v.clone()))
90+
.collect::<Vec<_>>();
91+
op.constant_fold(&ins).map_or(Vec::new(), |outs| {
92+
outs.into_iter()
93+
.map(|(p, v)| (p, ValueHandle::new(ValueKey::Node(n), Arc::new(v))))
94+
.collect()
95+
})
8996
}
9097
_ => vec![],
9198
}

0 commit comments

Comments
 (0)