@@ -56,14 +56,21 @@ static std::string getGateFunctionPrefix(Operation *op) {
56
56
57
57
constexpr std::array<std::string_view, 2 > filterAdjointNames = {" s" , " t" };
58
58
59
- template <typename OP>
59
+ template <typename M, typename OP>
60
60
std::pair<std::string, bool > generateGateFunctionName (OP op) {
61
61
auto prefix = getGateFunctionPrefix (op.getOperation ());
62
62
auto gateName = getGateName (op.getOperation ());
63
63
if (op.isAdj ()) {
64
64
if (std::find (filterAdjointNames.begin (), filterAdjointNames.end (),
65
- gateName) != filterAdjointNames.end ())
66
- prefix += " dg" ;
65
+ gateName) != filterAdjointNames.end ()) {
66
+ if constexpr (M::dgSuffix) {
67
+ prefix += " dg" ;
68
+ } else {
69
+ if (!op.getControls ().empty ())
70
+ return {prefix + " dg__ctl" , false };
71
+ return {prefix + " __adj" , false };
72
+ }
73
+ }
67
74
}
68
75
if (!op.getControls ().empty ())
69
76
return {prefix + " __ctl" , false };
@@ -429,8 +436,14 @@ struct DiscriminateOpToCallRewrite
429
436
cudaq::opt::QIRReadResultBody,
430
437
adaptor.getOperands ());
431
438
} else {
432
- rewriter.replaceOpWithNewOp <cudaq::cc::PoisonOp>(disc,
433
- rewriter.getI1Type ());
439
+ auto loc = disc.getLoc ();
440
+ // NB: the double cast here is to avoid folding the pointer casts.
441
+ auto i64Ty = rewriter.getI64Type ();
442
+ auto unu =
443
+ rewriter.create <cudaq::cc::CastOp>(loc, i64Ty, adaptor.getOperands ());
444
+ auto ptrI1Ty = cudaq::cc::PointerType::get (rewriter.getI1Type ());
445
+ auto du = rewriter.create <cudaq::cc::CastOp>(loc, ptrI1Ty, unu);
446
+ rewriter.replaceOpWithNewOp <cudaq::cc::LoadOp>(disc, du);
434
447
}
435
448
return success ();
436
449
}
@@ -1465,7 +1478,7 @@ struct FullQIR {
1465
1478
1466
1479
template <typename QuakeOp>
1467
1480
static std::string quakeToFuncName (QuakeOp op) {
1468
- auto [prefix, _] = generateGateFunctionName (op);
1481
+ auto [prefix, _] = generateGateFunctionName<Self> (op);
1469
1482
return prefix;
1470
1483
}
1471
1484
@@ -1523,19 +1536,15 @@ struct FullQIR {
1523
1536
static Type getLLVMPointerType (MLIRContext *ctx) {
1524
1537
return GetLLVMPointerType<opaquePtr>(ctx);
1525
1538
}
1539
+
1540
+ static constexpr bool dgSuffix = true ;
1526
1541
};
1527
1542
1528
1543
// / The base modifier class for the "profile QIR" APIs.
1529
1544
template <bool opaquePtr>
1530
1545
struct AnyProfileQIR {
1531
1546
using Self = AnyProfileQIR;
1532
1547
1533
- template <typename QuakeOp>
1534
- static std::string quakeToFuncName (QuakeOp op) {
1535
- auto [prefix, isBarePrefix] = generateGateFunctionName (op);
1536
- return isBarePrefix ? prefix + " __body" : prefix;
1537
- }
1538
-
1539
1548
static void populateRewritePatterns (RewritePatternSet &patterns,
1540
1549
TypeConverter &typeConverter) {
1541
1550
auto *ctx = patterns.getContext ();
@@ -1546,23 +1555,8 @@ struct AnyProfileQIR {
1546
1555
SubveqOpRewrite<Self>,
1547
1556
1548
1557
/* Irregular quantum operators. */
1549
- CustomUnitaryOpPattern<Self>, ExpPauliOpPattern, ResetOpPattern<Self>,
1550
-
1551
- /* Regular quantum operators. */
1552
- QuantumGatePattern<Self, quake::HOp>,
1553
- QuantumGatePattern<Self, quake::PhasedRxOp>,
1554
- QuantumGatePattern<Self, quake::R1Op>,
1555
- QuantumGatePattern<Self, quake::RxOp>,
1556
- QuantumGatePattern<Self, quake::RyOp>,
1557
- QuantumGatePattern<Self, quake::RzOp>,
1558
- QuantumGatePattern<Self, quake::SOp>,
1559
- QuantumGatePattern<Self, quake::SwapOp>,
1560
- QuantumGatePattern<Self, quake::TOp>,
1561
- QuantumGatePattern<Self, quake::U2Op>,
1562
- QuantumGatePattern<Self, quake::U3Op>,
1563
- QuantumGatePattern<Self, quake::XOp>,
1564
- QuantumGatePattern<Self, quake::YOp>,
1565
- QuantumGatePattern<Self, quake::ZOp>>(typeConverter, ctx);
1558
+ CustomUnitaryOpPattern<Self>, ExpPauliOpPattern, ResetOpPattern<Self>>(
1559
+ typeConverter, ctx);
1566
1560
commonQuakeHandlingPatterns (patterns, typeConverter, ctx);
1567
1561
commonClassicalHandlingPatterns (patterns, typeConverter, ctx);
1568
1562
}
@@ -1597,15 +1591,38 @@ struct BaseProfileQIR : public AnyProfileQIR<opaquePtr> {
1597
1591
using Self = BaseProfileQIR;
1598
1592
using Base = AnyProfileQIR<opaquePtr>;
1599
1593
1594
+ template <typename QuakeOp>
1595
+ static std::string quakeToFuncName (QuakeOp op) {
1596
+ auto [prefix, isBarePrefix] = generateGateFunctionName<Self>(op);
1597
+ return isBarePrefix ? prefix + " __body" : prefix;
1598
+ }
1599
+
1600
1600
static void populateRewritePatterns (RewritePatternSet &patterns,
1601
1601
TypeConverter &typeConverter) {
1602
1602
Base::populateRewritePatterns (patterns, typeConverter);
1603
1603
patterns
1604
- .insert <DiscriminateOpToCallRewrite<Self>, MeasurementOpPattern<Self>>(
1605
- typeConverter, patterns.getContext ());
1604
+ .insert <DiscriminateOpToCallRewrite<Self>, MeasurementOpPattern<Self>,
1605
+
1606
+ /* Regular quantum operators. */
1607
+ QuantumGatePattern<Self, quake::HOp>,
1608
+ QuantumGatePattern<Self, quake::PhasedRxOp>,
1609
+ QuantumGatePattern<Self, quake::R1Op>,
1610
+ QuantumGatePattern<Self, quake::RxOp>,
1611
+ QuantumGatePattern<Self, quake::RyOp>,
1612
+ QuantumGatePattern<Self, quake::RzOp>,
1613
+ QuantumGatePattern<Self, quake::SOp>,
1614
+ QuantumGatePattern<Self, quake::SwapOp>,
1615
+ QuantumGatePattern<Self, quake::TOp>,
1616
+ QuantumGatePattern<Self, quake::U2Op>,
1617
+ QuantumGatePattern<Self, quake::U3Op>,
1618
+ QuantumGatePattern<Self, quake::XOp>,
1619
+ QuantumGatePattern<Self, quake::YOp>,
1620
+ QuantumGatePattern<Self, quake::ZOp>>(typeConverter,
1621
+ patterns.getContext ());
1606
1622
}
1607
1623
1608
1624
static constexpr bool discriminateToClassical = false ;
1625
+ static constexpr bool dgSuffix = false ;
1609
1626
};
1610
1627
1611
1628
// / The QIR adaptive profile modifier class.
@@ -1614,15 +1631,38 @@ struct AdaptiveProfileQIR : public AnyProfileQIR<opaquePtr> {
1614
1631
using Self = AdaptiveProfileQIR;
1615
1632
using Base = AnyProfileQIR<opaquePtr>;
1616
1633
1634
+ template <typename QuakeOp>
1635
+ static std::string quakeToFuncName (QuakeOp op) {
1636
+ auto [prefix, isBarePrefix] = generateGateFunctionName<Self>(op);
1637
+ return isBarePrefix ? prefix + " __body" : prefix;
1638
+ }
1639
+
1617
1640
static void populateRewritePatterns (RewritePatternSet &patterns,
1618
1641
TypeConverter &typeConverter) {
1619
1642
Base::populateRewritePatterns (patterns, typeConverter);
1620
1643
patterns
1621
- .insert <DiscriminateOpToCallRewrite<Self>, MeasurementOpPattern<Self>>(
1622
- typeConverter, patterns.getContext ());
1644
+ .insert <DiscriminateOpToCallRewrite<Self>, MeasurementOpPattern<Self>,
1645
+
1646
+ /* Regular quantum operators. */
1647
+ QuantumGatePattern<Self, quake::HOp>,
1648
+ QuantumGatePattern<Self, quake::PhasedRxOp>,
1649
+ QuantumGatePattern<Self, quake::R1Op>,
1650
+ QuantumGatePattern<Self, quake::RxOp>,
1651
+ QuantumGatePattern<Self, quake::RyOp>,
1652
+ QuantumGatePattern<Self, quake::RzOp>,
1653
+ QuantumGatePattern<Self, quake::SOp>,
1654
+ QuantumGatePattern<Self, quake::SwapOp>,
1655
+ QuantumGatePattern<Self, quake::TOp>,
1656
+ QuantumGatePattern<Self, quake::U2Op>,
1657
+ QuantumGatePattern<Self, quake::U3Op>,
1658
+ QuantumGatePattern<Self, quake::XOp>,
1659
+ QuantumGatePattern<Self, quake::YOp>,
1660
+ QuantumGatePattern<Self, quake::ZOp>>(typeConverter,
1661
+ patterns.getContext ());
1623
1662
}
1624
1663
1625
1664
static constexpr bool discriminateToClassical = true ;
1665
+ static constexpr bool dgSuffix = true ;
1626
1666
};
1627
1667
1628
1668
// ===----------------------------------------------------------------------===//
0 commit comments