-
Notifications
You must be signed in to change notification settings - Fork 218
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve readability and maintainability of Unitary Checker #2585
Open
iqbalbhatti49
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
iqbalbhatti49:patch-1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,98 +21,95 @@ | |
using namespace llvm; | ||
using namespace mlir; | ||
|
||
static cl::opt<std::string> checkFilename(cl::Positional, | ||
cl::desc("<check file>")); | ||
static cl::opt<std::string> | ||
inputFilename("input", cl::desc("File to check (defaults to stdin)"), | ||
cl::init("-")); | ||
|
||
static cl::opt<bool> | ||
upToGlobalPhase("up-to-global-phase", | ||
cl::desc("Check unitaries are equal up to global phase."), | ||
cl::init(false)); | ||
|
||
static cl::opt<bool> | ||
upToMapping("up-to-mapping", | ||
cl::desc("Check unitaries are equal up a known permutation."), | ||
cl::init(false)); | ||
|
||
static cl::opt<bool> | ||
dontCanonicalize("no-canonicalizer", | ||
cl::desc("Disable running the canonicalizer pass."), | ||
cl::init(false)); | ||
|
||
static cl::opt<bool> printUnitary("print-unitary", | ||
cl::desc("Print the unitary of each circuit"), | ||
cl::init(false)); | ||
|
||
static LogicalResult computeUnitary(func::FuncOp func, | ||
cudaq::UnitaryBuilder::UMatrix &unitary, | ||
bool upToMapping = false) { | ||
cudaq::UnitaryBuilder builder(unitary, upToMapping); | ||
return builder.build(func); | ||
// Command-line options | ||
static cl::opt<std::string> checkFilename(cl::Positional, cl::desc("<check file>")); | ||
static cl::opt<std::string> inputFilename("input", cl::desc("File to check (defaults to stdin)"), cl::init("-")); | ||
static cl::opt<bool> upToGlobalPhase("up-to-global-phase", cl::desc("Check unitaries up to global phase."), cl::init(false)); | ||
static cl::opt<bool> upToMapping("up-to-mapping", cl::desc("Check unitaries with known permutation."), cl::init(false)); | ||
static cl::opt<bool> dontCanonicalize("no-canonicalizer", cl::desc("Disable canonicalization pass."), cl::init(false)); | ||
static cl::opt<bool> printUnitary("print-unitary", cl::desc("Print unitary matrices."), cl::init(false)); | ||
|
||
/** | ||
* Computes the unitary matrix for a given function. | ||
* @param func The function operation to analyze. | ||
* @param unitary Output matrix to store the computed unitary. | ||
* @param upToMapping Whether to allow permutations in computation. | ||
* @return LogicalResult indicating success or failure. | ||
*/ | ||
static LogicalResult computeUnitary(func::FuncOp func, cudaq::UnitaryBuilder::UMatrix &unitary, bool upToMapping = false) { | ||
cudaq::UnitaryBuilder builder(unitary, upToMapping); | ||
return builder.build(func); | ||
} | ||
|
||
int main(int argc, char **argv) { | ||
cl::ParseCommandLineOptions(argc, argv); | ||
|
||
MLIRContext context; | ||
context.loadDialect<cudaq::cc::CCDialect, quake::QuakeDialect, | ||
func::FuncDialect>(); | ||
|
||
ParserConfig config(&context); | ||
auto checkMod = parseSourceFile<mlir::ModuleOp>(checkFilename, config); | ||
auto inputMod = parseSourceFile<mlir::ModuleOp>(inputFilename, config); | ||
|
||
// Run canonicalizer to make sure angles in parametrized quantum operations | ||
// are taking constants as inputs. | ||
if (!dontCanonicalize) { | ||
PassManager pm(&context); | ||
OpPassManager &nestedFuncPM = pm.nest<func::FuncOp>(); | ||
nestedFuncPM.addPass(createCanonicalizerPass()); | ||
if (failed(pm.run(*checkMod)) || failed(pm.run(*inputMod))) | ||
return EXIT_FAILURE; | ||
} | ||
|
||
auto applyTolerance = [](cudaq::UnitaryBuilder::UMatrix &m) { | ||
m = (1e-12 < m.array().abs()).select(m, 0.0f); | ||
}; | ||
cudaq::UnitaryBuilder::UMatrix checkUnitary; | ||
cudaq::UnitaryBuilder::UMatrix inputUnitary; | ||
auto exitStatus = EXIT_SUCCESS; | ||
for (auto checkFunc : checkMod->getOps<func::FuncOp>()) { | ||
StringAttr opName = checkFunc.getSymNameAttr(); | ||
checkUnitary.resize(0, 0); | ||
inputUnitary.resize(0, 0); | ||
// We need to check if input also has the same function | ||
auto *inputOp = inputMod->lookupSymbol(opName); | ||
assert(inputOp && "Function not present in input"); | ||
|
||
auto inputFunc = dyn_cast<func::FuncOp>(inputOp); | ||
if (failed(computeUnitary(checkFunc, checkUnitary)) || | ||
failed(computeUnitary(inputFunc, inputUnitary, upToMapping))) { | ||
llvm::errs() << "Cannot compute unitary for " << opName.str() << ".\n"; | ||
exitStatus = EXIT_FAILURE; | ||
continue; | ||
cl::ParseCommandLineOptions(argc, argv); | ||
|
||
// Initialize MLIR context and load required dialects | ||
MLIRContext context; | ||
context.loadDialect<cudaq::cc::CCDialect, quake::QuakeDialect, func::FuncDialect>(); | ||
|
||
// Parse input and check files | ||
ParserConfig config(&context); | ||
auto checkMod = parseSourceFile<mlir::ModuleOp>(checkFilename, config); | ||
auto inputMod = parseSourceFile<mlir::ModuleOp>(inputFilename, config); | ||
if (!checkMod || !inputMod) { | ||
llvm::errs() << "Error parsing input files.\n"; | ||
return EXIT_FAILURE; | ||
} | ||
|
||
// Here we use std streams because Eigen printers don't work with LLVM ones. | ||
if (!cudaq::isApproxEqual(checkUnitary, inputUnitary, upToGlobalPhase)) { | ||
applyTolerance(checkUnitary); | ||
applyTolerance(inputUnitary); | ||
std::cerr << "Circuit: " << opName.str() << '\n'; | ||
std::cerr << "Expected:\n"; | ||
std::cerr << checkUnitary << '\n'; | ||
std::cerr << "Got:\n"; | ||
std::cerr << inputUnitary << '\n'; | ||
exitStatus = EXIT_FAILURE; | ||
// Apply canonicalization if enabled | ||
if (!dontCanonicalize) { | ||
PassManager pm(&context); | ||
OpPassManager &nestedFuncPM = pm.nest<func::FuncOp>(); | ||
nestedFuncPM.addPass(createCanonicalizerPass()); | ||
if (failed(pm.run(*checkMod)) || failed(pm.run(*inputMod))) { | ||
llvm::errs() << "Canonicalization failed.\n"; | ||
return EXIT_FAILURE; | ||
} | ||
} | ||
|
||
if (printUnitary) { | ||
applyTolerance(checkUnitary); | ||
std::cout << "Circuit: " << opName.str() << '\n' | ||
<< checkUnitary << "\n\n"; | ||
// Apply tolerance to small numerical values in matrices | ||
auto applyTolerance = [](cudaq::UnitaryBuilder::UMatrix &m) { | ||
m = (1e-12 < m.array().abs()).select(m, 0.0f); | ||
}; | ||
|
||
// Iterate through functions in check module and compare against input module | ||
int exitStatus = EXIT_SUCCESS; | ||
for (auto checkFunc : checkMod->getOps<func::FuncOp>()) { | ||
StringAttr opName = checkFunc.getSymNameAttr(); | ||
cudaq::UnitaryBuilder::UMatrix checkUnitary, inputUnitary; | ||
|
||
// Look up corresponding function in input module | ||
auto *inputOp = inputMod->lookupSymbol(opName); | ||
if (!inputOp) { | ||
llvm::errs() << "Function " << opName.str() << " not found in input.\n"; | ||
exitStatus = EXIT_FAILURE; | ||
continue; | ||
} | ||
Comment on lines
+84
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also good, since one might compile this without assertions. |
||
|
||
auto inputFunc = dyn_cast<func::FuncOp>(inputOp); | ||
if (!inputFunc || failed(computeUnitary(checkFunc, checkUnitary)) || | ||
failed(computeUnitary(inputFunc, inputUnitary, upToMapping))) { | ||
llvm::errs() << "Cannot compute unitary for " << opName.str() << ".\n"; | ||
exitStatus = EXIT_FAILURE; | ||
continue; | ||
} | ||
iqbalbhatti49 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Compare unitaries and report differences | ||
if (!cudaq::isApproxEqual(checkUnitary, inputUnitary, upToGlobalPhase)) { | ||
applyTolerance(checkUnitary); | ||
applyTolerance(inputUnitary); | ||
std::cerr << "Circuit: " << opName.str() << '\n'; | ||
std::cerr << "Expected:\n" << checkUnitary << '\n'; | ||
std::cerr << "Got:\n" << inputUnitary << '\n'; | ||
exitStatus = EXIT_FAILURE; | ||
} | ||
|
||
// Print unitaries if requested | ||
if (printUnitary) { | ||
applyTolerance(checkUnitary); | ||
std::cout << "Circuit: " << opName.str() << '\n' << checkUnitary << "\n\n"; | ||
} | ||
} | ||
} | ||
return exitStatus; | ||
return exitStatus; | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a good check to have.