Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve readability and maintainability of Unitary Checker #2585

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 83 additions & 86 deletions utils/CircuitCheck/CircuitCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,98 +21,95 @@
using namespace llvm;
using namespace mlir;

static cl::opt<std::string> checkFilename(cl::Positional,
cl::desc("<check file>"));
static cl::opt<std::string>
inputFilename("input", cl::desc("File to check (defaults to stdin)"),
cl::init("-"));

static cl::opt<bool>
upToGlobalPhase("up-to-global-phase",
cl::desc("Check unitaries are equal up to global phase."),
cl::init(false));

static cl::opt<bool>
upToMapping("up-to-mapping",
cl::desc("Check unitaries are equal up a known permutation."),
cl::init(false));

static cl::opt<bool>
dontCanonicalize("no-canonicalizer",
cl::desc("Disable running the canonicalizer pass."),
cl::init(false));

static cl::opt<bool> printUnitary("print-unitary",
cl::desc("Print the unitary of each circuit"),
cl::init(false));

static LogicalResult computeUnitary(func::FuncOp func,
cudaq::UnitaryBuilder::UMatrix &unitary,
bool upToMapping = false) {
cudaq::UnitaryBuilder builder(unitary, upToMapping);
return builder.build(func);
// Command-line options
static cl::opt<std::string> checkFilename(cl::Positional, cl::desc("<check file>"));
static cl::opt<std::string> inputFilename("input", cl::desc("File to check (defaults to stdin)"), cl::init("-"));
static cl::opt<bool> upToGlobalPhase("up-to-global-phase", cl::desc("Check unitaries up to global phase."), cl::init(false));
static cl::opt<bool> upToMapping("up-to-mapping", cl::desc("Check unitaries with known permutation."), cl::init(false));
static cl::opt<bool> dontCanonicalize("no-canonicalizer", cl::desc("Disable canonicalization pass."), cl::init(false));
static cl::opt<bool> printUnitary("print-unitary", cl::desc("Print unitary matrices."), cl::init(false));

/**
* Computes the unitary matrix for a given function.
* @param func The function operation to analyze.
* @param unitary Output matrix to store the computed unitary.
* @param upToMapping Whether to allow permutations in computation.
* @return LogicalResult indicating success or failure.
*/
static LogicalResult computeUnitary(func::FuncOp func, cudaq::UnitaryBuilder::UMatrix &unitary, bool upToMapping = false) {
cudaq::UnitaryBuilder builder(unitary, upToMapping);
return builder.build(func);
}

int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv);

MLIRContext context;
context.loadDialect<cudaq::cc::CCDialect, quake::QuakeDialect,
func::FuncDialect>();

ParserConfig config(&context);
auto checkMod = parseSourceFile<mlir::ModuleOp>(checkFilename, config);
auto inputMod = parseSourceFile<mlir::ModuleOp>(inputFilename, config);

// Run canonicalizer to make sure angles in parametrized quantum operations
// are taking constants as inputs.
if (!dontCanonicalize) {
PassManager pm(&context);
OpPassManager &nestedFuncPM = pm.nest<func::FuncOp>();
nestedFuncPM.addPass(createCanonicalizerPass());
if (failed(pm.run(*checkMod)) || failed(pm.run(*inputMod)))
return EXIT_FAILURE;
}

auto applyTolerance = [](cudaq::UnitaryBuilder::UMatrix &m) {
m = (1e-12 < m.array().abs()).select(m, 0.0f);
};
cudaq::UnitaryBuilder::UMatrix checkUnitary;
cudaq::UnitaryBuilder::UMatrix inputUnitary;
auto exitStatus = EXIT_SUCCESS;
for (auto checkFunc : checkMod->getOps<func::FuncOp>()) {
StringAttr opName = checkFunc.getSymNameAttr();
checkUnitary.resize(0, 0);
inputUnitary.resize(0, 0);
// We need to check if input also has the same function
auto *inputOp = inputMod->lookupSymbol(opName);
assert(inputOp && "Function not present in input");

auto inputFunc = dyn_cast<func::FuncOp>(inputOp);
if (failed(computeUnitary(checkFunc, checkUnitary)) ||
failed(computeUnitary(inputFunc, inputUnitary, upToMapping))) {
llvm::errs() << "Cannot compute unitary for " << opName.str() << ".\n";
exitStatus = EXIT_FAILURE;
continue;
cl::ParseCommandLineOptions(argc, argv);

// Initialize MLIR context and load required dialects
MLIRContext context;
context.loadDialect<cudaq::cc::CCDialect, quake::QuakeDialect, func::FuncDialect>();

// Parse input and check files
ParserConfig config(&context);
auto checkMod = parseSourceFile<mlir::ModuleOp>(checkFilename, config);
auto inputMod = parseSourceFile<mlir::ModuleOp>(inputFilename, config);
if (!checkMod || !inputMod) {
llvm::errs() << "Error parsing input files.\n";
return EXIT_FAILURE;
Comment on lines +55 to +57
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good check to have.

}

// Here we use std streams because Eigen printers don't work with LLVM ones.
if (!cudaq::isApproxEqual(checkUnitary, inputUnitary, upToGlobalPhase)) {
applyTolerance(checkUnitary);
applyTolerance(inputUnitary);
std::cerr << "Circuit: " << opName.str() << '\n';
std::cerr << "Expected:\n";
std::cerr << checkUnitary << '\n';
std::cerr << "Got:\n";
std::cerr << inputUnitary << '\n';
exitStatus = EXIT_FAILURE;
// Apply canonicalization if enabled
if (!dontCanonicalize) {
PassManager pm(&context);
OpPassManager &nestedFuncPM = pm.nest<func::FuncOp>();
nestedFuncPM.addPass(createCanonicalizerPass());
if (failed(pm.run(*checkMod)) || failed(pm.run(*inputMod))) {
llvm::errs() << "Canonicalization failed.\n";
return EXIT_FAILURE;
}
}

if (printUnitary) {
applyTolerance(checkUnitary);
std::cout << "Circuit: " << opName.str() << '\n'
<< checkUnitary << "\n\n";
// Apply tolerance to small numerical values in matrices
auto applyTolerance = [](cudaq::UnitaryBuilder::UMatrix &m) {
m = (1e-12 < m.array().abs()).select(m, 0.0f);
};

// Iterate through functions in check module and compare against input module
int exitStatus = EXIT_SUCCESS;
for (auto checkFunc : checkMod->getOps<func::FuncOp>()) {
StringAttr opName = checkFunc.getSymNameAttr();
cudaq::UnitaryBuilder::UMatrix checkUnitary, inputUnitary;

// Look up corresponding function in input module
auto *inputOp = inputMod->lookupSymbol(opName);
if (!inputOp) {
llvm::errs() << "Function " << opName.str() << " not found in input.\n";
exitStatus = EXIT_FAILURE;
continue;
}
Comment on lines +84 to +88
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also good, since one might compile this without assertions.


auto inputFunc = dyn_cast<func::FuncOp>(inputOp);
if (!inputFunc || failed(computeUnitary(checkFunc, checkUnitary)) ||
failed(computeUnitary(inputFunc, inputUnitary, upToMapping))) {
llvm::errs() << "Cannot compute unitary for " << opName.str() << ".\n";
exitStatus = EXIT_FAILURE;
continue;
}

// Compare unitaries and report differences
if (!cudaq::isApproxEqual(checkUnitary, inputUnitary, upToGlobalPhase)) {
applyTolerance(checkUnitary);
applyTolerance(inputUnitary);
std::cerr << "Circuit: " << opName.str() << '\n';
std::cerr << "Expected:\n" << checkUnitary << '\n';
std::cerr << "Got:\n" << inputUnitary << '\n';
exitStatus = EXIT_FAILURE;
}

// Print unitaries if requested
if (printUnitary) {
applyTolerance(checkUnitary);
std::cout << "Circuit: " << opName.str() << '\n' << checkUnitary << "\n\n";
}
}
}
return exitStatus;
return exitStatus;
}
Loading