Skip to content

Commit f297239

Browse files
committed
copy_array and test
1 parent b942088 commit f297239

File tree

1 file changed

+205
-22
lines changed

1 file changed

+205
-22
lines changed

Diff for: hugr-passes/src/lower_types/linearize.rs

+205-22
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
use std::{collections::HashMap, sync::Arc};
22

33
use hugr_core::{
4-
builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
5-
extension::TypeDef,
4+
builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr},
5+
extension::{
6+
prelude::{option_type, UnwrapBuilder},
7+
ExtensionSet, TypeDef,
8+
},
69
hugr::hugrmut::HugrMut,
7-
ops::Value,
8-
std_extensions::collections::array::{array_type, ArrayScan},
10+
ops::{OpTrait, OpType, Tag, Value},
11+
std_extensions::{
12+
arithmetic::{
13+
conversions::ConvertOpDef,
14+
int_ops::IntOpDef,
15+
int_types::{ConstInt, INT_TYPES},
16+
},
17+
collections::array::{array_type, ArrayOpDef, ArrayRepeat, ArrayScan},
18+
},
919
type_row,
10-
types::{Type, TypeArg, TypeEnum},
11-
HugrView, IncomingPort, Node, OutgoingPort,
20+
types::{SumType, Type, TypeArg, TypeEnum},
21+
Hugr, HugrView, IncomingPort, Node, OutgoingPort,
1222
};
1323

1424
use super::{OpReplacement, ParametricType};
@@ -180,17 +190,139 @@ pub fn discard_array(args: &[TypeArg], lin: &Linearizer) -> Result<OpReplacement
180190
})))
181191
}
182192

193+
pub fn copy_array(args: &[TypeArg], lin: &Linearizer) -> Result<OpReplacement, LinearizeError> {
194+
// Require known length i.e. usable only after monomorphization, due to no-variables limitation
195+
// restriction on OpReplacement::CompoundOp
196+
let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else {
197+
panic!("Illegal TypeArgs to array: {:?}", args)
198+
};
199+
let i64_t = INT_TYPES[6].to_owned();
200+
let option_sty = option_type(ty.clone());
201+
let option_ty = Type::from(option_sty.clone());
202+
let option_array = array_type(*n, option_ty.clone());
203+
let copy_elem = {
204+
let mut dfb = DFGBuilder::new(endo_sig(vec![
205+
ty.clone(),
206+
i64_t.clone(),
207+
option_array.clone(),
208+
]))
209+
.unwrap();
210+
let [elem, idx, opt_array] = dfb.input_wires_arr();
211+
let [idx_usz] = dfb
212+
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx])
213+
.unwrap()
214+
.outputs_arr();
215+
let [copy0, copy1] = lin
216+
.copy_op(ty)?
217+
.add(&mut dfb, [elem])
218+
.unwrap()
219+
.outputs_arr();
220+
221+
//let variants = option_ty.variants().cloned().map(|row| row.try_into().unwrap()).collect();
222+
let [tag] = dfb
223+
.add_dataflow_op(Tag::new(1, vec![type_row![], ty.clone().into()]), [copy1])
224+
.unwrap()
225+
.outputs_arr();
226+
let set_op = OpType::from(ArrayOpDef::set.to_concrete(option_ty.clone(), *n));
227+
let TypeEnum::Sum(either_st) = set_op.dataflow_signature().unwrap().output[0]
228+
.as_type_enum()
229+
.clone()
230+
else {
231+
panic!("ArrayOp::set has one output of sum type")
232+
};
233+
let [set_result] = dfb
234+
.add_dataflow_op(set_op, [opt_array, idx_usz, tag])
235+
.unwrap()
236+
.outputs_arr();
237+
// set should always be successful
238+
let [none, opt_array] = dfb.build_unwrap_sum(1, either_st, set_result).unwrap();
239+
//the removed element is an option, which should always be none (and thus discardable)
240+
let [] = dfb
241+
.build_unwrap_sum(0, SumType::new_option(ty.clone()), none)
242+
.unwrap();
243+
244+
let cst1 = dfb.add_load_value(ConstInt::new_u(6, 1).unwrap());
245+
let [new_idx] = dfb
246+
.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [idx, cst1])
247+
.unwrap()
248+
.outputs_arr();
249+
dfb.finish_hugr_with_outputs([copy0, new_idx, opt_array])
250+
.unwrap()
251+
};
252+
let unwrap_elem = {
253+
let mut dfb =
254+
DFGBuilder::new(inout_sig(Type::from(option_ty.clone()), ty.clone())).unwrap();
255+
let [opt] = dfb.input_wires_arr();
256+
let [val] = dfb.build_unwrap_sum(1, option_sty.clone(), opt).unwrap();
257+
dfb.finish_hugr_with_outputs([val]).unwrap()
258+
};
259+
let array_ty = array_type(*n, ty.clone());
260+
let mut dfb = DFGBuilder::new(inout_sig(array_ty.clone(), vec![array_ty.clone(); 2])).unwrap();
261+
let [in_array] = dfb.input_wires_arr();
262+
// First, Scan `copy_elem` to get the original array back and an array of Some's of the copies
263+
let scan1 = ArrayScan::new(
264+
ty.clone(),
265+
ty.clone(),
266+
vec![i64_t, option_array],
267+
*n,
268+
runtime_reqs(&copy_elem),
269+
);
270+
let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap());
271+
let cst0 = dfb.add_load_value(ConstInt::new_u(6, 0).unwrap());
272+
let [array_of_none] = {
273+
let fn_none = {
274+
let mut dfb = DFGBuilder::new(inout_sig(vec![], option_ty.clone())).unwrap();
275+
let none = dfb
276+
.add_dataflow_op(Tag::new(0, vec![type_row![], ty.clone().into()]), [])
277+
.unwrap();
278+
dfb.finish_hugr_with_outputs(none.outputs()).unwrap()
279+
};
280+
let rpt = ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none));
281+
let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap());
282+
dfb.add_dataflow_op(rpt, [fn_none]).unwrap()
283+
}
284+
.outputs_arr();
285+
let [out_array1, _idx_out, opt_array] = dfb
286+
.add_dataflow_op(scan1, [in_array, copy_elem, cst0, array_of_none])
287+
.unwrap()
288+
.outputs_arr();
289+
290+
let scan2 = ArrayScan::new(
291+
option_ty.clone(),
292+
ty.clone(),
293+
vec![],
294+
*n,
295+
runtime_reqs(&unwrap_elem),
296+
);
297+
let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap());
298+
let [out_array2] = dfb
299+
.add_dataflow_op(scan2, [opt_array, unwrap_elem])
300+
.unwrap()
301+
.outputs_arr();
302+
303+
Ok(OpReplacement::CompoundOp(Box::new(
304+
dfb.finish_hugr_with_outputs([out_array1, out_array2])
305+
.unwrap(),
306+
)))
307+
}
308+
183309
#[cfg(test)]
184310
mod test {
185311
use std::collections::HashMap;
312+
use std::iter::successors;
186313
use std::sync::Arc;
187314

188-
use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr};
315+
use hugr_core::builder::{
316+
endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
317+
};
318+
use hugr_core::extension::prelude::option_type;
189319
use hugr_core::extension::simple_op::MakeExtensionOp;
190320
use hugr_core::extension::{prelude::usize_t, TypeDefBound, Version};
191-
use hugr_core::ops::{ExtensionOp, NamedOp, OpName};
321+
use hugr_core::ops::handle::NodeHandle;
322+
use hugr_core::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpName};
323+
use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES;
192324
use hugr_core::std_extensions::collections::array::{
193-
self, array_type, ArrayOpDef, ArrayScanDef, ARRAY_TYPENAME,
325+
self, array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayScanDef, ARRAY_TYPENAME,
194326
};
195327
use hugr_core::types::{Signature, Type, TypeEnum, TypeRow};
196328
use hugr_core::{hugr::IdentList, type_row, Extension, HugrView};
@@ -282,32 +414,83 @@ mod test {
282414
}
283415

284416
#[test]
285-
fn discard_array() {
286-
let (_e, mut lowerer) = ext_lowerer();
417+
fn array() {
418+
let (e, mut lowerer) = ext_lowerer();
287419

288420
lowerer.linearize_parametric(
289421
array::EXTENSION.get_type(ARRAY_TYPENAME.as_str()).unwrap(),
290-
Box::new(|_, _| panic!("No copy yet")),
422+
Box::new(super::copy_array),
291423
Box::new(super::discard_array),
292424
);
293-
294-
let mut h = DFGBuilder::new(Signature::new(array_type(5, usize_t()), type_row![]))
425+
let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap());
426+
let opt_lin_ty = Type::from(option_type(lin_t.clone()));
427+
let mut dfb = DFGBuilder::new(endo_sig(array_type(5, usize_t()))).unwrap();
428+
let [array_in] = dfb.input_wires_arr();
429+
// The outer DFG passes the input array into (1) a DFG that discards it
430+
let discard = dfb
431+
.dfg_builder(
432+
Signature::new(array_type(5, usize_t()), type_row![]),
433+
[array_in],
434+
)
295435
.unwrap()
296-
.finish_hugr_with_outputs([])
436+
.finish_with_outputs([])
297437
.unwrap();
438+
// and (2) its own output
439+
let mut h = dfb.finish_hugr_with_outputs([array_in]).unwrap();
298440

299441
assert!(lowerer.run(&mut h).unwrap());
300442

301-
let ext_ops = h
443+
let (discard_ops, copy_ops): (Vec<_>, Vec<_>) = h
302444
.nodes()
303445
.filter_map(|n| h.get_optype(n).as_extension_op().map(|e| (n, e)))
304-
.collect_vec();
305-
let [(n, ext_op)] = ext_ops.try_into().unwrap();
306-
assert!(ArrayScanDef::from_extension_op(ext_op).is_ok());
446+
.partition(|(n, _)| {
447+
successors(Some(*n), |n| h.get_parent(*n)).contains(&discard.node())
448+
});
449+
{
450+
let [(n, ext_op)] = discard_ops.try_into().unwrap();
451+
assert!(ArrayScanDef::from_extension_op(ext_op).is_ok());
452+
assert_eq!(
453+
ext_op.signature().output,
454+
TypeRow::from(vec![array_type(5, Type::UNIT)])
455+
);
456+
assert_eq!(h.linked_inputs(n, 0).next(), None);
457+
}
458+
assert_eq!(copy_ops.len(), 3);
459+
let copy_ops = copy_ops.into_iter().map(|(_, e)| e).collect_vec();
460+
let rpt = *copy_ops
461+
.iter()
462+
.find(|e| ArrayRepeat::from_extension_op(e).is_ok())
463+
.unwrap();
307464
assert_eq!(
308-
ext_op.clone().signature_mut().output(),
309-
&TypeRow::from(vec![array_type(5, Type::UNIT)])
465+
rpt.signature().output(),
466+
&TypeRow::from(array_type(5, opt_lin_ty.clone()))
467+
);
468+
let scan0 = copy_ops
469+
.iter()
470+
.find_map(|e| {
471+
ArrayScan::from_extension_op(e)
472+
.ok()
473+
.filter(|sc| sc.acc_tys.is_empty())
474+
})
475+
.unwrap();
476+
assert_eq!(scan0.src_ty, opt_lin_ty);
477+
assert_eq!(scan0.tgt_ty, lin_t);
478+
479+
let scan2 = *copy_ops
480+
.iter()
481+
.find(|e| ArrayScan::from_extension_op(e).is_ok_and(|sc| !sc.acc_tys.is_empty()))
482+
.unwrap();
483+
let sig = scan2.signature().into_owned();
484+
assert_eq!(
485+
sig.output,
486+
TypeRow::from(vec![
487+
array_type(5, lin_t.clone()),
488+
INT_TYPES[6].to_owned(),
489+
array_type(5, option_type(lin_t.clone()).into())
490+
])
310491
);
311-
assert_eq!(h.linked_inputs(n, 0).next(), None);
492+
assert_eq!(sig.input[0], sig.output[0]);
493+
assert!(matches!(sig.input[1].as_type_enum(), TypeEnum::Function(_)));
494+
assert_eq!(sig.input[2..], sig.output[1..]);
312495
}
313496
}

0 commit comments

Comments
 (0)