|
1 | 1 | use std::{collections::HashMap, sync::Arc};
|
2 | 2 |
|
3 | 3 | use hugr_core::{
|
4 |
| - builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, |
5 |
| - extension::TypeDef, |
| 4 | + builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}, |
| 5 | + extension::{ |
| 6 | + prelude::{option_type, UnwrapBuilder}, |
| 7 | + ExtensionSet, TypeDef, |
| 8 | + }, |
6 | 9 | hugr::hugrmut::HugrMut,
|
7 |
| - ops::Value, |
8 |
| - std_extensions::collections::array::{array_type, ArrayScan}, |
| 10 | + ops::{OpTrait, OpType, Tag, Value}, |
| 11 | + std_extensions::{ |
| 12 | + arithmetic::{ |
| 13 | + conversions::ConvertOpDef, |
| 14 | + int_ops::IntOpDef, |
| 15 | + int_types::{ConstInt, INT_TYPES}, |
| 16 | + }, |
| 17 | + collections::array::{array_type, ArrayOpDef, ArrayRepeat, ArrayScan}, |
| 18 | + }, |
9 | 19 | type_row,
|
10 |
| - types::{Type, TypeArg, TypeEnum}, |
11 |
| - HugrView, IncomingPort, Node, OutgoingPort, |
| 20 | + types::{SumType, Type, TypeArg, TypeEnum}, |
| 21 | + Hugr, HugrView, IncomingPort, Node, OutgoingPort, |
12 | 22 | };
|
13 | 23 |
|
14 | 24 | use super::{OpReplacement, ParametricType};
|
@@ -180,17 +190,139 @@ pub fn discard_array(args: &[TypeArg], lin: &Linearizer) -> Result<OpReplacement
|
180 | 190 | })))
|
181 | 191 | }
|
182 | 192 |
|
| 193 | +pub fn copy_array(args: &[TypeArg], lin: &Linearizer) -> Result<OpReplacement, LinearizeError> { |
| 194 | + // Require known length i.e. usable only after monomorphization, due to no-variables limitation |
| 195 | + // restriction on OpReplacement::CompoundOp |
| 196 | + let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = args else { |
| 197 | + panic!("Illegal TypeArgs to array: {:?}", args) |
| 198 | + }; |
| 199 | + let i64_t = INT_TYPES[6].to_owned(); |
| 200 | + let option_sty = option_type(ty.clone()); |
| 201 | + let option_ty = Type::from(option_sty.clone()); |
| 202 | + let option_array = array_type(*n, option_ty.clone()); |
| 203 | + let copy_elem = { |
| 204 | + let mut dfb = DFGBuilder::new(endo_sig(vec![ |
| 205 | + ty.clone(), |
| 206 | + i64_t.clone(), |
| 207 | + option_array.clone(), |
| 208 | + ])) |
| 209 | + .unwrap(); |
| 210 | + let [elem, idx, opt_array] = dfb.input_wires_arr(); |
| 211 | + let [idx_usz] = dfb |
| 212 | + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) |
| 213 | + .unwrap() |
| 214 | + .outputs_arr(); |
| 215 | + let [copy0, copy1] = lin |
| 216 | + .copy_op(ty)? |
| 217 | + .add(&mut dfb, [elem]) |
| 218 | + .unwrap() |
| 219 | + .outputs_arr(); |
| 220 | + |
| 221 | + //let variants = option_ty.variants().cloned().map(|row| row.try_into().unwrap()).collect(); |
| 222 | + let [tag] = dfb |
| 223 | + .add_dataflow_op(Tag::new(1, vec![type_row![], ty.clone().into()]), [copy1]) |
| 224 | + .unwrap() |
| 225 | + .outputs_arr(); |
| 226 | + let set_op = OpType::from(ArrayOpDef::set.to_concrete(option_ty.clone(), *n)); |
| 227 | + let TypeEnum::Sum(either_st) = set_op.dataflow_signature().unwrap().output[0] |
| 228 | + .as_type_enum() |
| 229 | + .clone() |
| 230 | + else { |
| 231 | + panic!("ArrayOp::set has one output of sum type") |
| 232 | + }; |
| 233 | + let [set_result] = dfb |
| 234 | + .add_dataflow_op(set_op, [opt_array, idx_usz, tag]) |
| 235 | + .unwrap() |
| 236 | + .outputs_arr(); |
| 237 | + // set should always be successful |
| 238 | + let [none, opt_array] = dfb.build_unwrap_sum(1, either_st, set_result).unwrap(); |
| 239 | + //the removed element is an option, which should always be none (and thus discardable) |
| 240 | + let [] = dfb |
| 241 | + .build_unwrap_sum(0, SumType::new_option(ty.clone()), none) |
| 242 | + .unwrap(); |
| 243 | + |
| 244 | + let cst1 = dfb.add_load_value(ConstInt::new_u(6, 1).unwrap()); |
| 245 | + let [new_idx] = dfb |
| 246 | + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [idx, cst1]) |
| 247 | + .unwrap() |
| 248 | + .outputs_arr(); |
| 249 | + dfb.finish_hugr_with_outputs([copy0, new_idx, opt_array]) |
| 250 | + .unwrap() |
| 251 | + }; |
| 252 | + let unwrap_elem = { |
| 253 | + let mut dfb = |
| 254 | + DFGBuilder::new(inout_sig(Type::from(option_ty.clone()), ty.clone())).unwrap(); |
| 255 | + let [opt] = dfb.input_wires_arr(); |
| 256 | + let [val] = dfb.build_unwrap_sum(1, option_sty.clone(), opt).unwrap(); |
| 257 | + dfb.finish_hugr_with_outputs([val]).unwrap() |
| 258 | + }; |
| 259 | + let array_ty = array_type(*n, ty.clone()); |
| 260 | + let mut dfb = DFGBuilder::new(inout_sig(array_ty.clone(), vec![array_ty.clone(); 2])).unwrap(); |
| 261 | + let [in_array] = dfb.input_wires_arr(); |
| 262 | + // First, Scan `copy_elem` to get the original array back and an array of Some's of the copies |
| 263 | + let scan1 = ArrayScan::new( |
| 264 | + ty.clone(), |
| 265 | + ty.clone(), |
| 266 | + vec![i64_t, option_array], |
| 267 | + *n, |
| 268 | + runtime_reqs(©_elem), |
| 269 | + ); |
| 270 | + let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap()); |
| 271 | + let cst0 = dfb.add_load_value(ConstInt::new_u(6, 0).unwrap()); |
| 272 | + let [array_of_none] = { |
| 273 | + let fn_none = { |
| 274 | + let mut dfb = DFGBuilder::new(inout_sig(vec![], option_ty.clone())).unwrap(); |
| 275 | + let none = dfb |
| 276 | + .add_dataflow_op(Tag::new(0, vec![type_row![], ty.clone().into()]), []) |
| 277 | + .unwrap(); |
| 278 | + dfb.finish_hugr_with_outputs(none.outputs()).unwrap() |
| 279 | + }; |
| 280 | + let rpt = ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); |
| 281 | + let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); |
| 282 | + dfb.add_dataflow_op(rpt, [fn_none]).unwrap() |
| 283 | + } |
| 284 | + .outputs_arr(); |
| 285 | + let [out_array1, _idx_out, opt_array] = dfb |
| 286 | + .add_dataflow_op(scan1, [in_array, copy_elem, cst0, array_of_none]) |
| 287 | + .unwrap() |
| 288 | + .outputs_arr(); |
| 289 | + |
| 290 | + let scan2 = ArrayScan::new( |
| 291 | + option_ty.clone(), |
| 292 | + ty.clone(), |
| 293 | + vec![], |
| 294 | + *n, |
| 295 | + runtime_reqs(&unwrap_elem), |
| 296 | + ); |
| 297 | + let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap()); |
| 298 | + let [out_array2] = dfb |
| 299 | + .add_dataflow_op(scan2, [opt_array, unwrap_elem]) |
| 300 | + .unwrap() |
| 301 | + .outputs_arr(); |
| 302 | + |
| 303 | + Ok(OpReplacement::CompoundOp(Box::new( |
| 304 | + dfb.finish_hugr_with_outputs([out_array1, out_array2]) |
| 305 | + .unwrap(), |
| 306 | + ))) |
| 307 | +} |
| 308 | + |
183 | 309 | #[cfg(test)]
|
184 | 310 | mod test {
|
185 | 311 | use std::collections::HashMap;
|
| 312 | + use std::iter::successors; |
186 | 313 | use std::sync::Arc;
|
187 | 314 |
|
188 |
| - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; |
| 315 | + use hugr_core::builder::{ |
| 316 | + endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, |
| 317 | + }; |
| 318 | + use hugr_core::extension::prelude::option_type; |
189 | 319 | use hugr_core::extension::simple_op::MakeExtensionOp;
|
190 | 320 | use hugr_core::extension::{prelude::usize_t, TypeDefBound, Version};
|
191 |
| - use hugr_core::ops::{ExtensionOp, NamedOp, OpName}; |
| 321 | + use hugr_core::ops::handle::NodeHandle; |
| 322 | + use hugr_core::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpName}; |
| 323 | + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; |
192 | 324 | use hugr_core::std_extensions::collections::array::{
|
193 |
| - self, array_type, ArrayOpDef, ArrayScanDef, ARRAY_TYPENAME, |
| 325 | + self, array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayScanDef, ARRAY_TYPENAME, |
194 | 326 | };
|
195 | 327 | use hugr_core::types::{Signature, Type, TypeEnum, TypeRow};
|
196 | 328 | use hugr_core::{hugr::IdentList, type_row, Extension, HugrView};
|
@@ -282,32 +414,83 @@ mod test {
|
282 | 414 | }
|
283 | 415 |
|
284 | 416 | #[test]
|
285 |
| - fn discard_array() { |
286 |
| - let (_e, mut lowerer) = ext_lowerer(); |
| 417 | + fn array() { |
| 418 | + let (e, mut lowerer) = ext_lowerer(); |
287 | 419 |
|
288 | 420 | lowerer.linearize_parametric(
|
289 | 421 | array::EXTENSION.get_type(ARRAY_TYPENAME.as_str()).unwrap(),
|
290 |
| - Box::new(|_, _| panic!("No copy yet")), |
| 422 | + Box::new(super::copy_array), |
291 | 423 | Box::new(super::discard_array),
|
292 | 424 | );
|
293 |
| - |
294 |
| - let mut h = DFGBuilder::new(Signature::new(array_type(5, usize_t()), type_row![])) |
| 425 | + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); |
| 426 | + let opt_lin_ty = Type::from(option_type(lin_t.clone())); |
| 427 | + let mut dfb = DFGBuilder::new(endo_sig(array_type(5, usize_t()))).unwrap(); |
| 428 | + let [array_in] = dfb.input_wires_arr(); |
| 429 | + // The outer DFG passes the input array into (1) a DFG that discards it |
| 430 | + let discard = dfb |
| 431 | + .dfg_builder( |
| 432 | + Signature::new(array_type(5, usize_t()), type_row![]), |
| 433 | + [array_in], |
| 434 | + ) |
295 | 435 | .unwrap()
|
296 |
| - .finish_hugr_with_outputs([]) |
| 436 | + .finish_with_outputs([]) |
297 | 437 | .unwrap();
|
| 438 | + // and (2) its own output |
| 439 | + let mut h = dfb.finish_hugr_with_outputs([array_in]).unwrap(); |
298 | 440 |
|
299 | 441 | assert!(lowerer.run(&mut h).unwrap());
|
300 | 442 |
|
301 |
| - let ext_ops = h |
| 443 | + let (discard_ops, copy_ops): (Vec<_>, Vec<_>) = h |
302 | 444 | .nodes()
|
303 | 445 | .filter_map(|n| h.get_optype(n).as_extension_op().map(|e| (n, e)))
|
304 |
| - .collect_vec(); |
305 |
| - let [(n, ext_op)] = ext_ops.try_into().unwrap(); |
306 |
| - assert!(ArrayScanDef::from_extension_op(ext_op).is_ok()); |
| 446 | + .partition(|(n, _)| { |
| 447 | + successors(Some(*n), |n| h.get_parent(*n)).contains(&discard.node()) |
| 448 | + }); |
| 449 | + { |
| 450 | + let [(n, ext_op)] = discard_ops.try_into().unwrap(); |
| 451 | + assert!(ArrayScanDef::from_extension_op(ext_op).is_ok()); |
| 452 | + assert_eq!( |
| 453 | + ext_op.signature().output, |
| 454 | + TypeRow::from(vec![array_type(5, Type::UNIT)]) |
| 455 | + ); |
| 456 | + assert_eq!(h.linked_inputs(n, 0).next(), None); |
| 457 | + } |
| 458 | + assert_eq!(copy_ops.len(), 3); |
| 459 | + let copy_ops = copy_ops.into_iter().map(|(_, e)| e).collect_vec(); |
| 460 | + let rpt = *copy_ops |
| 461 | + .iter() |
| 462 | + .find(|e| ArrayRepeat::from_extension_op(e).is_ok()) |
| 463 | + .unwrap(); |
307 | 464 | assert_eq!(
|
308 |
| - ext_op.clone().signature_mut().output(), |
309 |
| - &TypeRow::from(vec![array_type(5, Type::UNIT)]) |
| 465 | + rpt.signature().output(), |
| 466 | + &TypeRow::from(array_type(5, opt_lin_ty.clone())) |
| 467 | + ); |
| 468 | + let scan0 = copy_ops |
| 469 | + .iter() |
| 470 | + .find_map(|e| { |
| 471 | + ArrayScan::from_extension_op(e) |
| 472 | + .ok() |
| 473 | + .filter(|sc| sc.acc_tys.is_empty()) |
| 474 | + }) |
| 475 | + .unwrap(); |
| 476 | + assert_eq!(scan0.src_ty, opt_lin_ty); |
| 477 | + assert_eq!(scan0.tgt_ty, lin_t); |
| 478 | + |
| 479 | + let scan2 = *copy_ops |
| 480 | + .iter() |
| 481 | + .find(|e| ArrayScan::from_extension_op(e).is_ok_and(|sc| !sc.acc_tys.is_empty())) |
| 482 | + .unwrap(); |
| 483 | + let sig = scan2.signature().into_owned(); |
| 484 | + assert_eq!( |
| 485 | + sig.output, |
| 486 | + TypeRow::from(vec![ |
| 487 | + array_type(5, lin_t.clone()), |
| 488 | + INT_TYPES[6].to_owned(), |
| 489 | + array_type(5, option_type(lin_t.clone()).into()) |
| 490 | + ]) |
310 | 491 | );
|
311 |
| - assert_eq!(h.linked_inputs(n, 0).next(), None); |
| 492 | + assert_eq!(sig.input[0], sig.output[0]); |
| 493 | + assert!(matches!(sig.input[1].as_type_enum(), TypeEnum::Function(_))); |
| 494 | + assert_eq!(sig.input[2..], sig.output[1..]); |
312 | 495 | }
|
313 | 496 | }
|
0 commit comments