Skip to content

Commit a4f5196

Browse files
acl-cqcdoug-q
andauthored
feat: ReplaceTypes: handlers for array constants + linearization (#2023)
Final chapter closes #1948. `default()` includes all the handlers, `new_empty()` does not. --------- Co-authored-by: Douglas Wilson <[email protected]>
1 parent ff0252e commit a4f5196

File tree

4 files changed

+473
-47
lines changed

4 files changed

+473
-47
lines changed

hugr-core/src/std_extensions/collections/array.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ lazy_static! {
178178
};
179179
}
180180

181-
fn array_type_def() -> &'static TypeDef {
181+
/// Gets the [TypeDef] for arrays. Note that instantiations are more easily
182+
/// created via [array_type] and [array_type_parametric]
183+
pub fn array_type_def() -> &'static TypeDef {
182184
EXTENSION.get_type(&ARRAY_TYPENAME).unwrap()
183185
}
184186

hugr-passes/src/replace_types.rs

Lines changed: 74 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ use std::borrow::Cow;
66
use std::collections::HashMap;
77
use std::sync::Arc;
88

9+
use handlers::list_const;
10+
use hugr_core::std_extensions::collections::array::array_type_def;
11+
use hugr_core::std_extensions::collections::list::list_type_def;
912
use thiserror::Error;
1013

1114
use hugr_core::builder::{BuildError, BuildHandle, Dataflow};
@@ -19,7 +22,7 @@ use hugr_core::ops::{
1922
Value, CFG, DFG,
2023
};
2124
use hugr_core::types::{
22-
CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer,
25+
ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer,
2326
};
2427
use hugr_core::{Hugr, HugrView, Node, Wire};
2528

@@ -125,7 +128,7 @@ impl NodeTemplate {
125128
/// * See also limitations noted for [Linearizer].
126129
///
127130
/// [monomorphization]: super::monomorphize()
128-
#[derive(Clone, Default)]
131+
#[derive(Clone)]
129132
pub struct ReplaceTypes {
130133
type_map: HashMap<CustomType, Type>,
131134
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
@@ -143,6 +146,16 @@ pub struct ReplaceTypes {
143146
validation: ValidationLevel,
144147
}
145148

149+
impl Default for ReplaceTypes {
150+
fn default() -> Self {
151+
let mut res = Self::new_empty();
152+
res.linearize = DelegatingLinearizer::default();
153+
res.replace_consts_parametrized(array_type_def(), handlers::array_const);
154+
res.replace_consts_parametrized(list_type_def(), list_const);
155+
res
156+
}
157+
}
158+
146159
impl TypeTransformer for ReplaceTypes {
147160
type Err = ReplaceTypesError;
148161

@@ -173,10 +186,27 @@ pub enum ReplaceTypesError {
173186
#[error(transparent)]
174187
ValidationError(#[from] ValidatePassError),
175188
#[error(transparent)]
189+
ConstError(#[from] ConstTypeError),
190+
#[error(transparent)]
176191
LinearizeError(#[from] LinearizeError),
177192
}
178193

179194
impl ReplaceTypes {
195+
/// Makes a new instance. Unlike [Self::default], this does not understand
196+
/// any extension types, even those in the prelude.
197+
pub fn new_empty() -> Self {
198+
Self {
199+
type_map: Default::default(),
200+
param_types: Default::default(),
201+
linearize: DelegatingLinearizer::new_empty(),
202+
op_map: Default::default(),
203+
param_ops: Default::default(),
204+
consts: Default::default(),
205+
param_consts: Default::default(),
206+
validation: Default::default(),
207+
}
208+
}
209+
180210
/// Sets the validation level used before and after the pass is run.
181211
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
182212
self.validation = level;
@@ -447,38 +477,7 @@ impl ReplaceTypes {
447477
}
448478
}
449479

450-
pub mod handlers {
451-
//! Callbacks for use with [ReplaceTypes::replace_consts_parametrized]
452-
use hugr_core::ops::{constant::OpaqueValue, Value};
453-
use hugr_core::std_extensions::collections::list::ListValue;
454-
use hugr_core::types::Transformable;
455-
456-
use super::{ReplaceTypes, ReplaceTypesError};
457-
458-
/// Handler for [ListValue] constants that recursively [ReplaceTypes::change_value]s
459-
/// the elements of the list
460-
pub fn list_const(
461-
val: &OpaqueValue,
462-
repl: &ReplaceTypes,
463-
) -> Result<Option<Value>, ReplaceTypesError> {
464-
let Some(lv) = val.value().downcast_ref::<ListValue>() else {
465-
return Ok(None);
466-
};
467-
let mut vals: Vec<Value> = lv.get_contents().to_vec();
468-
let mut ch = false;
469-
for v in vals.iter_mut() {
470-
ch |= repl.change_value(v)?;
471-
}
472-
// If none of the values has changed, assume the Type hasn't (Values have a single known type)
473-
if !ch {
474-
return Ok(None);
475-
};
476-
477-
let mut elem_t = lv.get_element_type().clone();
478-
elem_t.transform(repl)?;
479-
Ok(Some(ListValue::new(elem_t, vals).into()))
480-
}
481-
}
480+
pub mod handlers;
482481

483482
#[derive(Clone, Hash, PartialEq, Eq)]
484483
struct OpHashWrapper {
@@ -536,20 +535,26 @@ mod test {
536535
use hugr_core::extension::simple_op::MakeExtensionOp;
537536
use hugr_core::extension::{TypeDefBound, Version};
538537

538+
use hugr_core::ops::constant::OpaqueValue;
539539
use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value};
540540
use hugr_core::std_extensions::arithmetic::int_types::ConstInt;
541541
use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES};
542542
use hugr_core::std_extensions::collections::array::{
543-
array_type, ArrayOp, ArrayOpDef, ArrayValue,
543+
array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue,
544544
};
545545
use hugr_core::std_extensions::collections::list::{
546546
list_type, list_type_def, ListOp, ListValue,
547547
};
548548

549+
use hugr_core::hugr::ValidationError;
549550
use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow};
550551
use hugr_core::{hugr::IdentList, type_row, Extension, HugrView};
551552
use itertools::Itertools;
553+
use rstest::rstest;
554+
555+
use crate::validation::ValidatePassError;
552556

557+
use super::ReplaceTypesError;
553558
use super::{handlers::list_const, NodeTemplate, ReplaceTypes};
554559

555560
const PACKED_VEC: &str = "PackedVec";
@@ -792,8 +797,6 @@ mod test {
792797
let backup = tl.finish_hugr().unwrap();
793798

794799
let mut lowerer = ReplaceTypes::default();
795-
// Recursively descend into lists
796-
lowerer.replace_consts_parametrized(list_type_def(), list_const);
797800

798801
// 1. Lower List<T> to Array<10, T> UNLESS T is usize_t() or i64_t
799802
lowerer.replace_parametrized_type(list_type_def(), |args| {
@@ -951,4 +954,38 @@ mod test {
951954
["NoBoundsCheck.read", "collections.list.get"]
952955
);
953956
}
957+
958+
#[rstest]
959+
#[case(&[])]
960+
#[case(&[3])]
961+
#[case(&[5,7,11,13,17,19])]
962+
fn array_const(#[case] vals: &[u64]) {
963+
use super::handlers::array_const;
964+
let mut dfb = DFGBuilder::new(inout_sig(
965+
type_row![],
966+
array_type(vals.len() as _, usize_t()),
967+
))
968+
.unwrap();
969+
let c = dfb.add_load_value(ArrayValue::new(
970+
usize_t(),
971+
vals.iter().map(|u| ConstUsize::new(*u).into()),
972+
));
973+
let backup = dfb.finish_hugr_with_outputs([c]).unwrap();
974+
975+
let mut repl = ReplaceTypes::new_empty();
976+
let usize_custom_t = usize_t().as_extension().unwrap().clone();
977+
repl.replace_type(usize_custom_t.clone(), INT_TYPES[6].clone());
978+
repl.replace_consts(usize_custom_t, |cst: &OpaqueValue, _| {
979+
let cu = cst.value().downcast_ref::<ConstUsize>().unwrap();
980+
Ok(ConstInt::new_u(6, cu.value())?.into())
981+
});
982+
assert!(
983+
matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError {
984+
err: ValidationError::IncompatiblePorts {from, to, ..}, ..
985+
})) if backup.get_optype(from).is_const() && to == c.node())
986+
);
987+
repl.replace_consts_parametrized(array_type_def(), array_const);
988+
let mut h = backup;
989+
repl.run(&mut h).unwrap(); // Includes validation
990+
}
954991
}

0 commit comments

Comments
 (0)