@@ -6,6 +6,9 @@ use std::borrow::Cow;
6
6
use std:: collections:: HashMap ;
7
7
use std:: sync:: Arc ;
8
8
9
+ use handlers:: list_const;
10
+ use hugr_core:: std_extensions:: collections:: array:: array_type_def;
11
+ use hugr_core:: std_extensions:: collections:: list:: list_type_def;
9
12
use thiserror:: Error ;
10
13
11
14
use hugr_core:: builder:: { BuildError , BuildHandle , Dataflow } ;
@@ -19,7 +22,7 @@ use hugr_core::ops::{
19
22
Value , CFG , DFG ,
20
23
} ;
21
24
use hugr_core:: types:: {
22
- CustomType , Signature , Transformable , Type , TypeArg , TypeEnum , TypeTransformer ,
25
+ ConstTypeError , CustomType , Signature , Transformable , Type , TypeArg , TypeEnum , TypeTransformer ,
23
26
} ;
24
27
use hugr_core:: { Hugr , HugrView , Node , Wire } ;
25
28
@@ -125,7 +128,7 @@ impl NodeTemplate {
125
128
/// * See also limitations noted for [Linearizer].
126
129
///
127
130
/// [monomorphization]: super::monomorphize()
128
- #[ derive( Clone , Default ) ]
131
+ #[ derive( Clone ) ]
129
132
pub struct ReplaceTypes {
130
133
type_map : HashMap < CustomType , Type > ,
131
134
param_types : HashMap < ParametricType , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < Type > > > ,
@@ -143,6 +146,16 @@ pub struct ReplaceTypes {
143
146
validation : ValidationLevel ,
144
147
}
145
148
149
+ impl Default for ReplaceTypes {
150
+ fn default ( ) -> Self {
151
+ let mut res = Self :: new_empty ( ) ;
152
+ res. linearize = DelegatingLinearizer :: default ( ) ;
153
+ res. replace_consts_parametrized ( array_type_def ( ) , handlers:: array_const) ;
154
+ res. replace_consts_parametrized ( list_type_def ( ) , list_const) ;
155
+ res
156
+ }
157
+ }
158
+
146
159
impl TypeTransformer for ReplaceTypes {
147
160
type Err = ReplaceTypesError ;
148
161
@@ -173,10 +186,27 @@ pub enum ReplaceTypesError {
173
186
#[ error( transparent) ]
174
187
ValidationError ( #[ from] ValidatePassError ) ,
175
188
#[ error( transparent) ]
189
+ ConstError ( #[ from] ConstTypeError ) ,
190
+ #[ error( transparent) ]
176
191
LinearizeError ( #[ from] LinearizeError ) ,
177
192
}
178
193
179
194
impl ReplaceTypes {
195
+ /// Makes a new instance. Unlike [Self::default], this does not understand
196
+ /// any extension types, even those in the prelude.
197
+ pub fn new_empty ( ) -> Self {
198
+ Self {
199
+ type_map : Default :: default ( ) ,
200
+ param_types : Default :: default ( ) ,
201
+ linearize : DelegatingLinearizer :: new_empty ( ) ,
202
+ op_map : Default :: default ( ) ,
203
+ param_ops : Default :: default ( ) ,
204
+ consts : Default :: default ( ) ,
205
+ param_consts : Default :: default ( ) ,
206
+ validation : Default :: default ( ) ,
207
+ }
208
+ }
209
+
180
210
/// Sets the validation level used before and after the pass is run.
181
211
pub fn validation_level ( mut self , level : ValidationLevel ) -> Self {
182
212
self . validation = level;
@@ -447,38 +477,7 @@ impl ReplaceTypes {
447
477
}
448
478
}
449
479
450
- pub mod handlers {
451
- //! Callbacks for use with [ReplaceTypes::replace_consts_parametrized]
452
- use hugr_core:: ops:: { constant:: OpaqueValue , Value } ;
453
- use hugr_core:: std_extensions:: collections:: list:: ListValue ;
454
- use hugr_core:: types:: Transformable ;
455
-
456
- use super :: { ReplaceTypes , ReplaceTypesError } ;
457
-
458
- /// Handler for [ListValue] constants that recursively [ReplaceTypes::change_value]s
459
- /// the elements of the list
460
- pub fn list_const (
461
- val : & OpaqueValue ,
462
- repl : & ReplaceTypes ,
463
- ) -> Result < Option < Value > , ReplaceTypesError > {
464
- let Some ( lv) = val. value ( ) . downcast_ref :: < ListValue > ( ) else {
465
- return Ok ( None ) ;
466
- } ;
467
- let mut vals: Vec < Value > = lv. get_contents ( ) . to_vec ( ) ;
468
- let mut ch = false ;
469
- for v in vals. iter_mut ( ) {
470
- ch |= repl. change_value ( v) ?;
471
- }
472
- // If none of the values has changed, assume the Type hasn't (Values have a single known type)
473
- if !ch {
474
- return Ok ( None ) ;
475
- } ;
476
-
477
- let mut elem_t = lv. get_element_type ( ) . clone ( ) ;
478
- elem_t. transform ( repl) ?;
479
- Ok ( Some ( ListValue :: new ( elem_t, vals) . into ( ) ) )
480
- }
481
- }
480
+ pub mod handlers;
482
481
483
482
#[ derive( Clone , Hash , PartialEq , Eq ) ]
484
483
struct OpHashWrapper {
@@ -536,20 +535,26 @@ mod test {
536
535
use hugr_core:: extension:: simple_op:: MakeExtensionOp ;
537
536
use hugr_core:: extension:: { TypeDefBound , Version } ;
538
537
538
+ use hugr_core:: ops:: constant:: OpaqueValue ;
539
539
use hugr_core:: ops:: { ExtensionOp , NamedOp , OpTrait , OpType , Tag , Value } ;
540
540
use hugr_core:: std_extensions:: arithmetic:: int_types:: ConstInt ;
541
541
use hugr_core:: std_extensions:: arithmetic:: { conversions:: ConvertOpDef , int_types:: INT_TYPES } ;
542
542
use hugr_core:: std_extensions:: collections:: array:: {
543
- array_type, ArrayOp , ArrayOpDef , ArrayValue ,
543
+ array_type, array_type_def , ArrayOp , ArrayOpDef , ArrayValue ,
544
544
} ;
545
545
use hugr_core:: std_extensions:: collections:: list:: {
546
546
list_type, list_type_def, ListOp , ListValue ,
547
547
} ;
548
548
549
+ use hugr_core:: hugr:: ValidationError ;
549
550
use hugr_core:: types:: { PolyFuncType , Signature , SumType , Type , TypeArg , TypeBound , TypeRow } ;
550
551
use hugr_core:: { hugr:: IdentList , type_row, Extension , HugrView } ;
551
552
use itertools:: Itertools ;
553
+ use rstest:: rstest;
554
+
555
+ use crate :: validation:: ValidatePassError ;
552
556
557
+ use super :: ReplaceTypesError ;
553
558
use super :: { handlers:: list_const, NodeTemplate , ReplaceTypes } ;
554
559
555
560
const PACKED_VEC : & str = "PackedVec" ;
@@ -792,8 +797,6 @@ mod test {
792
797
let backup = tl. finish_hugr ( ) . unwrap ( ) ;
793
798
794
799
let mut lowerer = ReplaceTypes :: default ( ) ;
795
- // Recursively descend into lists
796
- lowerer. replace_consts_parametrized ( list_type_def ( ) , list_const) ;
797
800
798
801
// 1. Lower List<T> to Array<10, T> UNLESS T is usize_t() or i64_t
799
802
lowerer. replace_parametrized_type ( list_type_def ( ) , |args| {
@@ -951,4 +954,38 @@ mod test {
951
954
[ "NoBoundsCheck.read" , "collections.list.get" ]
952
955
) ;
953
956
}
957
+
958
+ #[ rstest]
959
+ #[ case( & [ ] ) ]
960
+ #[ case( & [ 3 ] ) ]
961
+ #[ case( & [ 5 , 7 , 11 , 13 , 17 , 19 ] ) ]
962
+ fn array_const ( #[ case] vals : & [ u64 ] ) {
963
+ use super :: handlers:: array_const;
964
+ let mut dfb = DFGBuilder :: new ( inout_sig (
965
+ type_row ! [ ] ,
966
+ array_type ( vals. len ( ) as _ , usize_t ( ) ) ,
967
+ ) )
968
+ . unwrap ( ) ;
969
+ let c = dfb. add_load_value ( ArrayValue :: new (
970
+ usize_t ( ) ,
971
+ vals. iter ( ) . map ( |u| ConstUsize :: new ( * u) . into ( ) ) ,
972
+ ) ) ;
973
+ let backup = dfb. finish_hugr_with_outputs ( [ c] ) . unwrap ( ) ;
974
+
975
+ let mut repl = ReplaceTypes :: new_empty ( ) ;
976
+ let usize_custom_t = usize_t ( ) . as_extension ( ) . unwrap ( ) . clone ( ) ;
977
+ repl. replace_type ( usize_custom_t. clone ( ) , INT_TYPES [ 6 ] . clone ( ) ) ;
978
+ repl. replace_consts ( usize_custom_t, |cst : & OpaqueValue , _| {
979
+ let cu = cst. value ( ) . downcast_ref :: < ConstUsize > ( ) . unwrap ( ) ;
980
+ Ok ( ConstInt :: new_u ( 6 , cu. value ( ) ) ?. into ( ) )
981
+ } ) ;
982
+ assert ! (
983
+ matches!( repl. run( & mut backup. clone( ) ) , Err ( ReplaceTypesError :: ValidationError ( ValidatePassError :: OutputError {
984
+ err: ValidationError :: IncompatiblePorts { from, to, ..} , ..
985
+ } ) ) if backup. get_optype( from) . is_const( ) && to == c. node( ) )
986
+ ) ;
987
+ repl. replace_consts_parametrized ( array_type_def ( ) , array_const) ;
988
+ let mut h = backup;
989
+ repl. run ( & mut h) . unwrap ( ) ; // Includes validation
990
+ }
954
991
}
0 commit comments