Skip to content

Commit da6d0b9

Browse files
amccaskeyschweitzpgi1tnguyen
authored
Add exp_pauli quantum instruction to the runtime and ASTBridge (#660)
* Add exp_pauli quantum instruction to the runtime and ASTBridge Signed-off-by: Alex McCaskey <[email protected]> * Drop the cc::StringType for the moment. Convert the headers to use ASCIIZ string literals for the final argument to exp_pauli(). (Simplifies the AST presented to the bridge and falls back to using code that handles string literals.) Update LowerToQIR for the new types, fix bugs. Fix up tests. * build fixes, still have to fix the wheel validation Signed-off-by: Alex McCaskey <[email protected]> * cleanup, add qpp applyExpPauli impl Signed-off-by: Alex McCaskey <[email protected]> * fix docs gen Signed-off-by: Alex McCaskey <[email protected]> * clean up Signed-off-by: Alex McCaskey <[email protected]> * clean up, provide docs Signed-off-by: Alex McCaskey <[email protected]> * Update runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cu Co-authored-by: Thien Nguyen <[email protected]> --------- Signed-off-by: Alex McCaskey <[email protected]> Co-authored-by: Eric Schweitz <[email protected]> Co-authored-by: Thien Nguyen <[email protected]>
1 parent 8369b4e commit da6d0b9

35 files changed

+842
-438
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ _version.py
8686

8787
# third party integrations
8888
simulators/
89+
apps/
8990

9091
# macOS
9192
.DS_Store

docs/sphinx/api/languages/python_api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Program Construction
3333
.. automethod:: rz
3434
.. automethod:: r1
3535
.. automethod:: swap
36+
.. automethod:: exp_pauli
3637
.. automethod:: mx
3738
.. automethod:: my
3839
.. automethod:: mz

docs/sphinx/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def setup(app):
169169
('cpp:identifier', 'mlir::ImplicitLocOpBuilder'),
170170
('cpp:identifier', 'BinarySymplecticForm'),
171171
('cpp:identifier', 'CountsDictionary'),
172+
('cpp:identifier', 'QuakeValueOrNumericType'),
172173
('py:class', 'function'),
173174
('py:class', 'type'),
174175
('py:class', 'cudaq::spin_op'),

include/cudaq/Frontend/nvqpp/ASTBridge.h

+12
Original file line numberDiff line numberDiff line change
@@ -704,4 +704,16 @@ inline bool isCallOperator(clang::OverloadedOperatorKind kindValue) {
704704
return kindValue == clang::OverloadedOperatorKind::OO_Call;
705705
}
706706

707+
// Is \p t of type `char *`?
708+
inline bool isCharPointerType(mlir::Type t) {
709+
if (auto ptrTy = dyn_cast<cc::PointerType>(t)) {
710+
mlir::Type eleTy = ptrTy.getElementType();
711+
if (auto arrTy = dyn_cast<cc::ArrayType>(eleTy))
712+
eleTy = arrTy.getElementType();
713+
if (auto intTy = dyn_cast<mlir::IntegerType>(eleTy))
714+
return intTy.getWidth() == 8;
715+
}
716+
return false;
717+
}
718+
707719
} // namespace cudaq

include/cudaq/Optimizer/Dialect/CC/CCOps.td

+18
Original file line numberDiff line numberDiff line change
@@ -1369,4 +1369,22 @@ def cc_CallableClosureOp : CCOp<"callable_closure", [Pure]> {
13691369
}];
13701370
}
13711371

1372+
def cc_CreateStringLiteralOp : CCOp<"string_literal"> {
1373+
let summary = "Create a constant string literal.";
1374+
let description = [{
1375+
This operation creates a ASCIIZ string literal value. It's argument is a
1376+
constant MLIR String Attribute. The literal will have a null character
1377+
appended automatically.
1378+
1379+
```mlir
1380+
%0 = cc.string_literal "Quantum Computing" : !cc.ptr<!cc.array<i8 x 18>>
1381+
```
1382+
}];
1383+
1384+
let arguments = (ins StrAttr:$stringLiteral);
1385+
let results = (outs cc_PointerType:$result);
1386+
let assemblyFormat = [{
1387+
$stringLiteral `:` qualified(type(results)) attr-dict
1388+
}];
1389+
}
13721390
#endif // CUDAQ_OPTIMIZER_DIALECT_CC_OPS

include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td

+18-1
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,23 @@ class TwoTargetOp<string mnemonic, list<Trait> traits = []> :
786786
// Quantum operators (gates)
787787
//===----------------------------------------------------------------------===//
788788

789+
def ExpPauliOp : QuakeOp<"exp_pauli", []> {
790+
let summary = "General Pauli tensor product rotation";
791+
let description = [{
792+
This operation affects a general Pauli tensor product rotation on
793+
the input qubits. The number of Pauli characters in the input Pauli word
794+
string must equal the number of qubits in the veq. Mathematically, this operation
795+
applies exp(i theta P) where P is a general Pauli tensor product.
796+
}];
797+
798+
let arguments = (ins AnyFloat:$parameter, VeqType:$qubits, cc_PointerType:$pauli);
799+
let results = (outs );
800+
801+
let assemblyFormat = [{
802+
`(` $parameter `)` $qubits `,` $pauli `:` functional-type(operands, results) attr-dict
803+
}];
804+
}
805+
789806
def HOp : OneTargetOp<"h", [Hermitian]> {
790807
let summary = "Hadamard operation";
791808
let description = [{
@@ -815,7 +832,7 @@ def PhasedRxOp : QuakeOperator<"phased_rx",
815832
Matrix representation:
816833
```
817834
PhasedRx(θ,φ) = | cos(θ/2) -iexp(-iφ) * sin(θ/2) |
818-
| -iexp(iφ)) * sin(θ/2) cos(θ/2) |
835+
| -iexp(iφ) * sin(θ/2) cos(θ/2) |
819836
```
820837

821838
Circuit symbol:

lib/Frontend/nvqpp/ConvertDecl.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ void QuakeBridgeVisitor::addArgumentSymbols(
9090
// Transform pass-by-value arguments to stack slots.
9191
auto loc = toLocation(argVal);
9292
auto parmTy = entryBlock->getArgument(index).getType();
93-
if (isa<cc::CallableType, cc::StdvecType, cc::ArrayType, cc::StructType,
94-
LLVM::LLVMStructType, FunctionType, quake::RefType,
95-
quake::VeqType>(parmTy)) {
93+
if (isa<FunctionType, cc::ArrayType, cc::CallableType, cc::PointerType,
94+
cc::StdvecType, cc::StructType, LLVM::LLVMStructType,
95+
quake::ControlType, quake::RefType, quake::VeqType,
96+
quake::WireType>(parmTy)) {
9697
symbolTable.insert(name, entryBlock->getArgument(index));
9798
} else {
9899
auto stackSlot = builder.create<cc::AllocaOp>(loc, parmTy);

lib/Frontend/nvqpp/ConvertExpr.cpp

+46-3
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,47 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) {
12541254
isAdjoint = structTypeAsRecord->getName() == "adj";
12551255
}
12561256

1257+
if (funcName.equals("exp_pauli")) {
1258+
assert(args.size() > 2);
1259+
SmallVector<Value> processedArgs;
1260+
auto addTheString = [&](Value v) {
1261+
// The C-string argument (char*) may be loaded by an lvalue to rvalue
1262+
// cast. Here, we must pass the pointer and not the first character's
1263+
// value.
1264+
if (isCharPointerType(v.getType())) {
1265+
processedArgs.push_back(v);
1266+
} else if (auto load = v.getDefiningOp<cudaq::cc::LoadOp>()) {
1267+
processedArgs.push_back(load.getPtrvalue());
1268+
} else {
1269+
reportClangError(x, mangler, "could not determine string argument");
1270+
}
1271+
};
1272+
if (args.size() == 3 && isa<quake::VeqType>(args[1].getType())) {
1273+
// Have f64, veq, string
1274+
processedArgs.push_back(args[0]);
1275+
processedArgs.push_back(args[1]);
1276+
addTheString(args[2]);
1277+
} else {
1278+
// should have f64, string, qubits...
1279+
// need f64, veq, string, so process here
1280+
1281+
// add f64 value
1282+
processedArgs.push_back(args[0]);
1283+
1284+
// concat the qubits to a veq
1285+
SmallVector<Value> quantumArgs;
1286+
for (std::size_t i = 2; i < args.size(); i++)
1287+
quantumArgs.push_back(args[i]);
1288+
processedArgs.push_back(builder.create<quake::ConcatOp>(
1289+
loc, quake::VeqType::get(builder.getContext(), quantumArgs.size()),
1290+
quantumArgs));
1291+
addTheString(args[1]);
1292+
}
1293+
1294+
builder.create<quake::ExpPauliOp>(loc, TypeRange{}, processedArgs);
1295+
return true;
1296+
}
1297+
12571298
if (funcName.equals("mx") || funcName.equals("my") ||
12581299
funcName.equals("mz")) {
12591300
// Measurements always return a bool or a std::vector<bool>.
@@ -2140,7 +2181,7 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
21402181

21412182
// TODO: remove this when we can handle ctors more generally.
21422183
if (!ctor->isDefaultConstructor()) {
2143-
LLVM_DEBUG(llvm::dbgs() << "unhandled ctor:\n"; x->dump());
2184+
LLVM_DEBUG(llvm::dbgs() << ctorName << " - unhandled ctor:\n"; x->dump());
21442185
TODO_loc(loc, "C++ ctor (not-default)");
21452186
}
21462187

@@ -2206,8 +2247,10 @@ bool QuakeBridgeVisitor::VisitDeclRefExpr(clang::DeclRefExpr *x) {
22062247
}
22072248

22082249
bool QuakeBridgeVisitor::VisitStringLiteral(clang::StringLiteral *x) {
2209-
TODO_x(toLocation(x->getSourceRange()), x, mangler, "string literal");
2210-
return false;
2250+
auto strLitTy = cc::PointerType::get(cc::ArrayType::get(
2251+
builder.getContext(), builder.getI8Type(), x->getString().size() + 1));
2252+
return pushValue(builder.create<cc::CreateStringLiteralOp>(
2253+
toLocation(x), strLitTy, builder.getStringAttr(x->getString())));
22112254
}
22122255

22132256
} // namespace cudaq::details

lib/Frontend/nvqpp/ConvertType.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ QuakeBridgeVisitor::findCallOperator(const clang::CXXRecordDecl *decl) {
8181

8282
bool QuakeBridgeVisitor::TraverseRecordType(clang::RecordType *t) {
8383
auto *recDecl = t->getDecl();
84+
8485
if (ignoredClass(recDecl))
8586
return true;
8687
auto reci = records.find(t);
@@ -311,7 +312,7 @@ bool QuakeBridgeVisitor::doSyntaxChecks(const clang::FunctionDecl *x) {
311312
// device kernels may take veq and/or ref arguments.
312313
if (isArithmeticType(t) || isArithmeticSequenceType(t) ||
313314
isQuantumType(t) || isKernelCallable(t) || isFunctionCallable(t) ||
314-
isReferenceToCallableRecord(t, p))
315+
isCharPointerType(t) || isReferenceToCallableRecord(t, p))
315316
continue;
316317
reportClangError(p, mangler, "kernel argument type not supported");
317318
return false;

lib/Optimizer/CodeGen/LowerToQIR.cpp

+91-25
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,14 @@ class SubveqOpRewrite : public ConvertOpToLLVMPattern<quake::SubVeqOp> {
316316
};
317317

318318
/// Lower the quake.reset op to QIR
319-
template <typename ResetOpType>
320-
class ResetRewrite : public ConvertOpToLLVMPattern<ResetOpType> {
319+
class ResetRewrite : public ConvertOpToLLVMPattern<quake::ResetOp> {
321320
public:
322-
using Base = ConvertOpToLLVMPattern<ResetOpType>;
323-
using Base::Base;
321+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
324322

325323
LogicalResult
326-
matchAndRewrite(ResetOpType instOp, typename Base::OpAdaptor adaptor,
324+
matchAndRewrite(quake::ResetOp instOp, OpAdaptor adaptor,
327325
ConversionPatternRewriter &rewriter) const override {
328-
auto parentModule = instOp->template getParentOfType<ModuleOp>();
326+
auto parentModule = instOp->getParentOfType<ModuleOp>();
329327
auto context = parentModule->getContext();
330328
std::string qirQisPrefix(cudaq::opt::QIRQISPrefix);
331329
std::string instName = instOp->getName().stripDialect().str();
@@ -348,6 +346,37 @@ class ResetRewrite : public ConvertOpToLLVMPattern<ResetOpType> {
348346
}
349347
};
350348

349+
/// Lower exp_pauli(f64, veq, cc.string) to __quantum__qis__exp_pauli
350+
class ExpPauliRewrite : public ConvertOpToLLVMPattern<quake::ExpPauliOp> {
351+
public:
352+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
353+
354+
LogicalResult
355+
matchAndRewrite(quake::ExpPauliOp instOp, OpAdaptor adaptor,
356+
ConversionPatternRewriter &rewriter) const override {
357+
auto loc = instOp->getLoc();
358+
auto parentModule = instOp->getParentOfType<ModuleOp>();
359+
auto *context = rewriter.getContext();
360+
std::string qirQisPrefix(cudaq::opt::QIRQISPrefix);
361+
auto qirFunctionName = qirQisPrefix + "exp_pauli";
362+
FlatSymbolRefAttr symbolRef = cudaq::opt::factory::createLLVMFunctionSymbol(
363+
qirFunctionName, /*return type=*/LLVM::LLVMVoidType::get(context),
364+
{rewriter.getF64Type(), cudaq::opt::getArrayType(context),
365+
cudaq::opt::factory::getPointerType(context)},
366+
parentModule);
367+
SmallVector<Value> operands = adaptor.getOperands();
368+
// Make sure to drop any length information from the type of the Pauli word.
369+
auto pauliWord = operands.back();
370+
operands.pop_back();
371+
auto castedPauli = rewriter.create<LLVM::BitcastOp>(
372+
loc, cudaq::opt::factory::getPointerType(context), pauliWord);
373+
operands.push_back(castedPauli);
374+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(instOp, TypeRange{}, symbolRef,
375+
operands);
376+
return success();
377+
}
378+
};
379+
351380
/// Lower single target Quantum ops with no parameter to QIR:
352381
/// h, x, y, z, s, t
353382
template <typename OP>
@@ -1310,6 +1339,42 @@ class StdvecSizeOpPattern
13101339
}
13111340
};
13121341

1342+
class CreateStringLiteralOpPattern
1343+
: public ConvertOpToLLVMPattern<cudaq::cc::CreateStringLiteralOp> {
1344+
public:
1345+
using Base = ConvertOpToLLVMPattern<cudaq::cc::CreateStringLiteralOp>;
1346+
using Base::Base;
1347+
1348+
LogicalResult
1349+
matchAndRewrite(cudaq::cc::CreateStringLiteralOp stringLiteralOp,
1350+
OpAdaptor adaptor,
1351+
ConversionPatternRewriter &rewriter) const override {
1352+
auto loc = stringLiteralOp.getLoc();
1353+
auto parentModule = stringLiteralOp->getParentOfType<ModuleOp>();
1354+
StringRef stringLiteral = stringLiteralOp.getStringLiteral();
1355+
1356+
// Write to the module body
1357+
auto insertPoint = rewriter.saveInsertionPoint();
1358+
rewriter.setInsertionPointToStart(parentModule.getBody());
1359+
1360+
// Create the register name global
1361+
auto builder = cudaq::IRBuilder::atBlockEnd(parentModule.getBody());
1362+
auto slGlobal =
1363+
builder.genCStringLiteralAppendNul(loc, parentModule, stringLiteral);
1364+
1365+
// Shift back to the function
1366+
rewriter.restoreInsertionPoint(insertPoint);
1367+
1368+
// Get the string address
1369+
rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
1370+
stringLiteralOp,
1371+
cudaq::opt::factory::getPointerType(slGlobal.getType()),
1372+
slGlobal.getSymName());
1373+
1374+
return success();
1375+
}
1376+
};
1377+
13131378
class StoreOpPattern : public ConvertOpToLLVMPattern<cudaq::cc::StoreOp> {
13141379
public:
13151380
using Base = ConvertOpToLLVMPattern<cudaq::cc::StoreOp>;
@@ -1420,25 +1485,26 @@ class QuakeToQIRRewrite : public cudaq::opt::QuakeToQIRBase<QuakeToQIRRewrite> {
14201485

14211486
patterns.insert<GetVeqSizeOpRewrite, MxToMz, MyToMz, ReturnBitRewrite>(
14221487
context);
1423-
patterns.insert<
1424-
AllocaOpRewrite, AllocaOpPattern, CallableClosureOpPattern,
1425-
CallableFuncOpPattern, CallCallableOpPattern, CastOpPattern,
1426-
ComputePtrOpPattern, ConcatOpRewrite, DeallocOpRewrite,
1427-
ExtractQubitOpRewrite, ExtractValueOpPattern, FuncToPtrOpPattern,
1428-
InsertValueOpPattern, InstantiateCallableOpPattern, LoadOpPattern,
1429-
OneTargetRewrite<quake::HOp>, OneTargetRewrite<quake::XOp>,
1430-
OneTargetRewrite<quake::YOp>, OneTargetRewrite<quake::ZOp>,
1431-
OneTargetRewrite<quake::SOp>, OneTargetRewrite<quake::TOp>,
1432-
OneTargetOneParamRewrite<quake::R1Op>,
1433-
OneTargetTwoParamRewrite<quake::PhasedRxOp>,
1434-
OneTargetOneParamRewrite<quake::RxOp>,
1435-
OneTargetOneParamRewrite<quake::RyOp>,
1436-
OneTargetOneParamRewrite<quake::RzOp>,
1437-
OneTargetTwoParamRewrite<quake::U2Op>,
1438-
OneTargetTwoParamRewrite<quake::U3Op>, ResetRewrite<quake::ResetOp>,
1439-
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
1440-
StoreOpPattern, SubveqOpRewrite, TwoTargetRewrite<quake::SwapOp>,
1441-
UndefOpPattern>(typeConverter);
1488+
patterns
1489+
.insert<AllocaOpRewrite, AllocaOpPattern, CallableClosureOpPattern,
1490+
CallableFuncOpPattern, CallCallableOpPattern, CastOpPattern,
1491+
ComputePtrOpPattern, ConcatOpRewrite, DeallocOpRewrite,
1492+
CreateStringLiteralOpPattern, ExtractQubitOpRewrite,
1493+
ExtractValueOpPattern, FuncToPtrOpPattern, InsertValueOpPattern,
1494+
InstantiateCallableOpPattern, LoadOpPattern, ExpPauliRewrite,
1495+
OneTargetRewrite<quake::HOp>, OneTargetRewrite<quake::XOp>,
1496+
OneTargetRewrite<quake::YOp>, OneTargetRewrite<quake::ZOp>,
1497+
OneTargetRewrite<quake::SOp>, OneTargetRewrite<quake::TOp>,
1498+
OneTargetOneParamRewrite<quake::R1Op>,
1499+
OneTargetTwoParamRewrite<quake::PhasedRxOp>,
1500+
OneTargetOneParamRewrite<quake::RxOp>,
1501+
OneTargetOneParamRewrite<quake::RyOp>,
1502+
OneTargetOneParamRewrite<quake::RzOp>,
1503+
OneTargetTwoParamRewrite<quake::U2Op>,
1504+
OneTargetTwoParamRewrite<quake::U3Op>, ResetRewrite,
1505+
StdvecDataOpPattern, StdvecInitOpPattern, StdvecSizeOpPattern,
1506+
StoreOpPattern, SubveqOpRewrite,
1507+
TwoTargetRewrite<quake::SwapOp>, UndefOpPattern>(typeConverter);
14421508
patterns.insert<MeasureRewrite<quake::MzOp>>(typeConverter, measureCounter);
14431509

14441510
target.addLegalDialect<LLVM::LLVMDialect>();

lib/Optimizer/Dialect/CC/CCTypes.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -133,14 +133,14 @@ void cc::ArrayType::print(AsmPrinter &printer) const {
133133
#define GET_TYPEDEF_CLASSES
134134
#include "cudaq/Optimizer/Dialect/CC/CCTypes.cpp.inc"
135135

136+
//===----------------------------------------------------------------------===//
137+
136138
namespace cudaq {
137139

138140
cc::CallableType cc::CallableType::getNoSignature(MLIRContext *ctx) {
139141
return CallableType::get(ctx, FunctionType::get(ctx, {}, {}));
140142
}
141143

142-
//===----------------------------------------------------------------------===//
143-
144144
void cc::CCDialect::registerTypes() {
145145
addTypes<ArrayType, CallableType, PointerType, StdvecType, StructType>();
146146
}

python/runtime/cudaq/builder/py_kernel_builder.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,24 @@ provided `function` will be applied within `self` at each iteration.
893893
.def("to_quake", &kernel_builder<>::to_quake, "See :func:`__str__`.")
894894
.def("__str__", &kernel_builder<>::to_quake,
895895
"Return the :class:`Kernel` as a string in its MLIR representation "
896-
"using the Quake dialect.\n");
896+
"using the Quake dialect.\n")
897+
.def(
898+
"exp_pauli",
899+
[](kernel_builder<> &self, py::object theta, const QuakeValue &qubits,
900+
const std::string &pauliWord) {
901+
if (py::isinstance<py::float_>(theta))
902+
self.exp_pauli(theta.cast<double>(), qubits, pauliWord);
903+
else if (py::isinstance<QuakeValue>(theta))
904+
self.exp_pauli(theta.cast<QuakeValue &>(), qubits, pauliWord);
905+
else
906+
throw std::runtime_error(
907+
"Invalid `theta` argument type. Must be a "
908+
"`float` or a `QuakeValue`.");
909+
},
910+
"Apply a general Pauli tensor product rotation, `exp(i theta P)`, on "
911+
"the specified qubit register. The Pauli tensor product is provided "
912+
"as a string, e.g. `XXYX` for a 4-qubit term. The angle parameter "
913+
"can be provided as a concrete float or a `QuakeValue`.");
897914
}
898915

899916
void bindBuilder(py::module &mod) {

0 commit comments

Comments
 (0)