Skip to content

Commit 3602887

Browse files
authored
[core] enable the new QIR codegen for python (#2588)
* Remove the Python hooks to the old codegen. This exposes all the Python problems in the tests.
  Eliminate the expansion of Python enumerate().
  Remove use of empty labels for all measurements. This eliminates a loop, a data structure, and the invalid mixing of quantum and classical data values in classical memory.
  Fix bugs in AST bridge.
  Fix #2538 - measurement register name cannot be empty.
  Signed-off-by: Eric Schweitz <[email protected]>
* Work around use of empty labels in kernel builder.
  Signed-off-by: Eric Schweitz <[email protected]>
* Remove empty names. We've made them illegal.
  Signed-off-by: Eric Schweitz <[email protected]>
* Fix tests.
  Signed-off-by: Eric Schweitz <[email protected]>
---------
Signed-off-by: Eric Schweitz <[email protected]>
1 parent 4f24197 commit 3602887

26 files changed: +549 additions, -501 deletions (lines changed)

include/cudaq/Optimizer/CodeGen/Pipelines.h

-17
Original file line number | Diff line number | Diff line change
@@ -30,35 +30,18 @@ void commonPipelineConvertToQIR(mlir::PassManager &pm,
3030
mlir::StringRef codeGenFor = "qir",
3131
mlir::StringRef passConfigAs = "qir");
3232

33-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
34-
void commonPipelineConvertToQIR_PythonWorkaround(
35-
mlir::PassManager &pm, const std::optional<mlir::StringRef> &convertTo);
36-
3733
/// \brief Pipeline builder to convert Quake to QIR.
3834
/// Does not specify a particular QIR profile.
3935
inline void addPipelineConvertToQIR(mlir::PassManager &pm) {
4036
commonPipelineConvertToQIR(pm);
4137
}
4238

43-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
44-
inline void addPipelineConvertToQIR_PythonWorkaround(mlir::PassManager &pm) {
45-
commonPipelineConvertToQIR_PythonWorkaround(pm, std::nullopt);
46-
}
47-
4839
/// \brief Pipeline builder to convert Quake to QIR.
4940
/// Specifies a particular QIR profile in \p convertTo.
5041
/// \p pm Pass manager to append passes to
5142
/// \p convertTo name of QIR profile (e.g., `qir-base`, `qir-adaptive`, ...)
5243
void addPipelineConvertToQIR(mlir::PassManager &pm, mlir::StringRef convertTo);
5344

54-
/// \deprecated{Only for Python, since it can't use the new QIR codegen.}
55-
inline void
56-
addPipelineConvertToQIR_PythonWorkaround(mlir::PassManager &pm,
57-
mlir::StringRef convertTo) {
58-
commonPipelineConvertToQIR_PythonWorkaround(pm, convertTo);
59-
addQIRProfilePipeline(pm, convertTo);
60-
}
61-
6245
void addLowerToCCPipeline(mlir::OpPassManager &pm);
6346

6447
void addPipelineTranslateToOpenQASM(mlir::PassManager &pm);

lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp

+14 -4
Original file line number | Diff line number | Diff line change
@@ -1092,7 +1092,7 @@ struct QuantumGatePattern : public OpConversionPattern<OP> {
10921092

10931093
// Process the controls, sorting them by type.
10941094
for (auto pr : llvm::zip(op.getControls(), adaptor.getControls())) {
1095-
if (isa<quake::VeqType>(std::get<0>(pr).getType())) {
1095+
if (isaVeqArgument(std::get<0>(pr).getType())) {
10961096
numArrayCtrls++;
10971097
auto sizeCall = rewriter.create<func::CallOp>(
10981098
loc, i64Ty, cudaq::opt::QIRArrayGetSize,
@@ -1155,6 +1155,18 @@ struct QuantumGatePattern : public OpConversionPattern<OP> {
11551155
return forwardOrEraseOp();
11561156
}
11571157

1158+
static bool isaVeqArgument(Type ty) {
1159+
// TODO: Need a way to identify arrays when using the opaque pointer
1160+
// variant. (In Python, the arguments may already be converted.)
1161+
auto alreadyConverted = [](Type ty) {
1162+
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(ty))
1163+
if (auto strTy = dyn_cast<LLVM::LLVMStructType>(ptrTy.getElementType()))
1164+
return strTy.isIdentified() && strTy.getName() == "Array";
1165+
return false;
1166+
};
1167+
return isa<quake::VeqType>(ty) || alreadyConverted(ty);
1168+
}
1169+
11581170
static bool conformsToIntendedCall(std::size_t numControls, Value ctrl, OP op,
11591171
StringRef qirFunctionName) {
11601172
if (numControls != 1)
@@ -1819,9 +1831,7 @@ struct QuakeToQIRAPIPrepPass
18191831
}
18201832

18211833
void guaranteeMzIsLabeled(quake::MzOp mz, int &counter, OpBuilder &builder) {
1822-
if (mz.getRegisterNameAttr() &&
1823-
/* FIXME: issue 2538: the name should never be empty. */
1824-
!mz.getRegisterNameAttr().getValue().empty()) {
1834+
if (mz.getRegisterNameAttr()) {
18251835
mz->setAttr(cudaq::opt::MzAssignedNameAttrName, builder.getUnitAttr());
18261836
return;
18271837
}

lib/Optimizer/CodeGen/Pipelines.cpp

-31
Original file line number | Diff line number | Diff line change
@@ -51,37 +51,6 @@ void cudaq::opt::commonPipelineConvertToQIR(PassManager &pm,
5151
pm.addPass(createCCToLLVM());
5252
}
5353

54-
void cudaq::opt::commonPipelineConvertToQIR_PythonWorkaround(
55-
PassManager &pm, const std::optional<StringRef> &convertTo) {
56-
pm.addNestedPass<func::FuncOp>(createApplyControlNegations());
57-
addAggressiveEarlyInlining(pm);
58-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
59-
pm.addNestedPass<func::FuncOp>(createUnwindLoweringPass());
60-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
61-
pm.addPass(createApplyOpSpecializationPass());
62-
pm.addNestedPass<func::FuncOp>(createExpandMeasurementsPass());
63-
pm.addNestedPass<func::FuncOp>(createClassicalMemToReg());
64-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
65-
pm.addNestedPass<func::FuncOp>(createCSEPass());
66-
pm.addNestedPass<func::FuncOp>(createQuakeAddDeallocs());
67-
pm.addNestedPass<func::FuncOp>(createQuakeAddMetadata());
68-
pm.addNestedPass<func::FuncOp>(createLoopNormalize());
69-
LoopUnrollOptions luo;
70-
luo.allowBreak = convertTo && (*convertTo == "qir-adaptive");
71-
pm.addNestedPass<func::FuncOp>(createLoopUnroll(luo));
72-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
73-
pm.addNestedPass<func::FuncOp>(createCSEPass());
74-
pm.addNestedPass<func::FuncOp>(createLowerToCFGPass());
75-
pm.addNestedPass<func::FuncOp>(createCombineQuantumAllocations());
76-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
77-
pm.addNestedPass<func::FuncOp>(createCSEPass());
78-
if (convertTo && (*convertTo == "qir-base"))
79-
pm.addNestedPass<func::FuncOp>(createDelayMeasurementsPass());
80-
pm.addPass(createConvertMathToFuncs());
81-
pm.addPass(createSymbolDCEPass());
82-
pm.addPass(createConvertToQIR());
83-
}
84-
8554
void cudaq::opt::addPipelineTranslateToOpenQASM(PassManager &pm) {
8655
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8756
pm.addNestedPass<func::FuncOp>(createCSEPass());

lib/Optimizer/Dialect/Quake/QuakeOps.cpp

+13 -10
Original file line number | Diff line number | Diff line change
@@ -517,38 +517,41 @@ void quake::WrapOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
517517
//===----------------------------------------------------------------------===//
518518

519519
// Common verification for measurement operations.
520-
static LogicalResult verifyMeasurements(Operation *const op,
521-
TypeRange targetsType,
522-
const Type bitsType) {
520+
template <typename MEAS>
521+
LogicalResult verifyMeasurements(MEAS op, TypeRange targetsType,
522+
const Type bitsType) {
523523
if (failed(verifyWireResultsAreLinear(op)))
524524
return failure();
525525
bool mustBeStdvec =
526526
targetsType.size() > 1 ||
527527
(targetsType.size() == 1 && isa<quake::VeqType>(targetsType[0]));
528528
if (mustBeStdvec) {
529-
if (!isa<cudaq::cc::StdvecType>(op->getResult(0).getType()))
530-
return op->emitOpError("must return `!cc.stdvec<!quake.measure>`, when "
531-
"measuring a qreg, a series of qubits, or both");
529+
if (!isa<cudaq::cc::StdvecType>(op.getMeasOut().getType()))
530+
return op.emitOpError("must return `!cc.stdvec<!quake.measure>`, when "
531+
"measuring a qreg, a series of qubits, or both");
532532
} else {
533-
if (!isa<quake::MeasureType>(op->getResult(0).getType()))
533+
if (!isa<quake::MeasureType>(op.getMeasOut().getType()))
534534
return op->emitOpError(
535535
"must return `!quake.measure` when measuring exactly one qubit");
536536
}
537+
if (op.getRegisterName())
538+
if (op.getRegisterName()->empty())
539+
return op->emitError("quake measurement name cannot be empty.");
537540
return success();
538541
}
539542

540543
LogicalResult quake::MxOp::verify() {
541-
return verifyMeasurements(getOperation(), getTargets().getType(),
544+
return verifyMeasurements(*this, getTargets().getType(),
542545
getMeasOut().getType());
543546
}
544547

545548
LogicalResult quake::MyOp::verify() {
546-
return verifyMeasurements(getOperation(), getTargets().getType(),
549+
return verifyMeasurements(*this, getTargets().getType(),
547550
getMeasOut().getType());
548551
}
549552

550553
LogicalResult quake::MzOp::verify() {
551-
return verifyMeasurements(getOperation(), getTargets().getType(),
554+
return verifyMeasurements(*this, getTargets().getType(),
552555
getMeasOut().getType());
553556
}
554557

lib/Optimizer/Transforms/GlobalizeArrayValues.cpp

+69 -16
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,23 @@ convertArrayAttrToGlobalConstant(MLIRContext *ctx, Location loc,
8787
}
8888

8989
namespace {
90+
91+
// This pattern replaces a cc.const_array with a global constant. It can
92+
// recognize a couple of usage patterns and will generate efficient IR in those
93+
// cases.
94+
//
95+
// Pattern 1: The entire constant array is stored to a stack variable(s). Here
96+
// we can eliminate the stack allocation and use the global constant.
97+
//
98+
// Pattern 2: Individual elements at dynamic offsets are extracted from the
99+
// constant array and used. This can be replaced with a compute pointer
100+
// operation using the global constant and a load of the element at the computed
101+
// offset.
102+
//
103+
// Default: If the usage is not recognized, the constant array value is replaced
104+
// with a load of the entire global variable. In this case, LLVM's optimizations
105+
// are counted on to help demote the (large?) sequence value to primitive memory
106+
// address arithmetic.
90107
struct ConstantArrayPattern
91108
: public OpRewritePattern<cudaq::cc::ConstantArrayOp> {
92109
explicit ConstantArrayPattern(MLIRContext *ctx, ModuleOp module,
@@ -95,21 +112,30 @@ struct ConstantArrayPattern
95112

96113
LogicalResult matchAndRewrite(cudaq::cc::ConstantArrayOp conarr,
97114
PatternRewriter &rewriter) const override {
115+
auto func = conarr->getParentOfType<func::FuncOp>();
116+
if (!func)
117+
return failure();
118+
98119
SmallVector<cudaq::cc::AllocaOp> allocas;
99120
SmallVector<cudaq::cc::StoreOp> stores;
121+
SmallVector<cudaq::cc::ExtractValueOp> extracts;
122+
bool loadAsValue = false;
100123
for (auto *usr : conarr->getUsers()) {
101124
auto store = dyn_cast<cudaq::cc::StoreOp>(usr);
102-
if (!store)
103-
return failure();
104-
auto alloca = store.getPtrvalue().getDefiningOp<cudaq::cc::AllocaOp>();
105-
if (!alloca)
106-
return failure();
107-
stores.push_back(store);
108-
allocas.push_back(alloca);
125+
auto extract = dyn_cast<cudaq::cc::ExtractValueOp>(usr);
126+
if (store) {
127+
auto alloca = store.getPtrvalue().getDefiningOp<cudaq::cc::AllocaOp>();
128+
if (alloca) {
129+
stores.push_back(store);
130+
allocas.push_back(alloca);
131+
continue;
132+
}
133+
} else if (extract) {
134+
extracts.push_back(extract);
135+
continue;
136+
}
137+
loadAsValue = true;
109138
}
110-
auto func = conarr->getParentOfType<func::FuncOp>();
111-
if (!func)
112-
return failure();
113139
std::string globalName =
114140
func.getName().str() + ".rodata_" + std::to_string(counter++);
115141
auto *ctx = rewriter.getContext();
@@ -118,12 +144,39 @@ struct ConstantArrayPattern
118144
if (failed(convertArrayAttrToGlobalConstant(ctx, conarr.getLoc(), valueAttr,
119145
module, globalName, eleTy)))
120146
return failure();
121-
for (auto alloca : allocas)
122-
rewriter.replaceOpWithNewOp<cudaq::cc::AddressOfOp>(
123-
alloca, alloca.getType(), globalName);
124-
for (auto store : stores)
125-
rewriter.eraseOp(store);
126-
rewriter.eraseOp(conarr);
147+
auto loc = conarr.getLoc();
148+
if (!extracts.empty()) {
149+
auto base = rewriter.create<cudaq::cc::AddressOfOp>(
150+
loc, cudaq::cc::PointerType::get(conarr.getType()), globalName);
151+
auto elePtrTy = cudaq::cc::PointerType::get(eleTy);
152+
for (auto extract : extracts) {
153+
SmallVector<cudaq::cc::ComputePtrArg> args;
154+
unsigned i = 0;
155+
for (auto arg : extract.getRawConstantIndices()) {
156+
if (arg == cudaq::cc::ExtractValueOp::getDynamicIndexValue())
157+
args.push_back(extract.getDynamicIndices()[i++]);
158+
else
159+
args.push_back(arg);
160+
}
161+
OpBuilder::InsertionGuard guard(rewriter);
162+
rewriter.setInsertionPoint(extract);
163+
auto addrVal =
164+
rewriter.create<cudaq::cc::ComputePtrOp>(loc, elePtrTy, base, args);
165+
rewriter.replaceOpWithNewOp<cudaq::cc::LoadOp>(extract, addrVal);
166+
}
167+
}
168+
if (!stores.empty()) {
169+
for (auto alloca : allocas)
170+
rewriter.replaceOpWithNewOp<cudaq::cc::AddressOfOp>(
171+
alloca, alloca.getType(), globalName);
172+
for (auto store : stores)
173+
rewriter.eraseOp(store);
174+
}
175+
if (loadAsValue) {
176+
auto base = rewriter.create<cudaq::cc::AddressOfOp>(
177+
loc, cudaq::cc::PointerType::get(conarr.getType()), globalName);
178+
rewriter.replaceOpWithNewOp<cudaq::cc::LoadOp>(conarr, base);
179+
}
127180
return success();
128181
}
129182

python/cudaq/kernel/ast_bridge.py

+72 -3
Original file line number | Diff line number | Diff line change
@@ -1749,9 +1749,11 @@ def bodyBuilder(iterVal):
17491749
self.ctx) if len(qubits) == 1 and quake.RefType.isinstance(
17501750
qubits[0].type) else cc.StdvecType.get(
17511751
self.ctx, quake.MeasureType.get(self.ctx))
1752-
measureResult = opCtor(measTy, [],
1753-
qubits,
1754-
registerName=registerName).result
1752+
label = registerName
1753+
if not label:
1754+
label = None
1755+
measureResult = opCtor(measTy, [], qubits,
1756+
registerName=label).result
17551757
if pushResultToStack:
17561758
self.pushValue(
17571759
quake.DiscriminateOp(resTy, measureResult).result)
@@ -3152,6 +3154,73 @@ def bodyBuilder(iterVar):
31523154
isDecrementing=isDecrementing)
31533155
return
31543156

3157+
# We can simplify `for i,j in enumerate(L)` MLIR code immensely
3158+
# by just building a for loop over the iterable object L and using
3159+
# the index into that iterable and the element.
3160+
if isinstance(node.iter, ast.Call):
3161+
if node.iter.func.id == 'enumerate':
3162+
[self.visit(arg) for arg in node.iter.args]
3163+
if len(self.valueStack) == 2:
3164+
iterable = self.popValue()
3165+
self.popValue()
3166+
else:
3167+
assert len(self.valueStack) == 1
3168+
iterable = self.popValue()
3169+
iterable = self.ifPointerThenLoad(iterable)
3170+
totalSize = None
3171+
extractFunctor = None
3172+
varNames = []
3173+
for elt in node.target.elts:
3174+
varNames.append(elt.id)
3175+
3176+
beEfficient = False
3177+
if quake.VeqType.isinstance(iterable.type):
3178+
totalSize = quake.VeqSizeOp(self.getIntegerType(),
3179+
iterable).result
3180+
3181+
def functor(seq, idx):
3182+
q = quake.ExtractRefOp(self.getRefType(),
3183+
seq,
3184+
-1,
3185+
index=idx).result
3186+
return [idx, q]
3187+
3188+
extractFunctor = functor
3189+
beEfficient = True
3190+
elif cc.StdvecType.isinstance(iterable.type):
3191+
totalSize = cc.StdvecSizeOp(self.getIntegerType(),
3192+
iterable).result
3193+
3194+
def functor(seq, idx):
3195+
vecTy = cc.StdvecType.getElementType(seq.type)
3196+
dataTy = cc.PointerType.get(self.ctx, vecTy)
3197+
arrTy = vecTy
3198+
if not cc.ArrayType.isinstance(arrTy):
3199+
arrTy = cc.ArrayType.get(self.ctx, vecTy)
3200+
dataArrTy = cc.PointerType.get(self.ctx, arrTy)
3201+
data = cc.StdvecDataOp(dataArrTy, seq).result
3202+
v = cc.ComputePtrOp(
3203+
dataTy, data, [idx],
3204+
DenseI32ArrayAttr.get([kDynamicPtrIndex],
3205+
context=self.ctx)).result
3206+
return [idx, v]
3207+
3208+
extractFunctor = functor
3209+
beEfficient = True
3210+
3211+
if beEfficient:
3212+
3213+
def bodyBuilder(iterVar):
3214+
self.symbolTable.pushScope()
3215+
values = extractFunctor(iterable, iterVar)
3216+
for i, v in enumerate(values):
3217+
self.symbolTable[varNames[i]] = v
3218+
[self.visit(b) for b in node.body]
3219+
self.symbolTable.popScope()
3220+
3221+
self.createInvariantForLoop(totalSize, bodyBuilder)
3222+
return
3223+
31553224
self.visit(node.iter)
31563225
assert len(self.valueStack) > 0 and len(self.valueStack) < 3
31573226

0 commit comments

Comments
 (0)