Skip to content

[Python] [refactoring] Replace MidCircuitMeasurementAnalyzer with existing quake-add-metadata pass #2610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/cudaq/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,17 @@
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from .photonics_kernel import PhotonicsHandler
from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from .target_handler import DefaultTargetHandler, PhotonicsTargetHandler

# Registry of target handlers
TARGET_HANDLERS = {'orca-photonics': PhotonicsTargetHandler()}


def get_target_handler():
"""Get the appropriate target handler based on current target"""
try:
target_name = cudaq_runtime.get_target().name
return TARGET_HANDLERS.get(target_name, DefaultTargetHandler())
except RuntimeError:
return DefaultTargetHandler()
55 changes: 55 additions & 0 deletions python/cudaq/handlers/target_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ============================================================================ #
# Copyright (c) 2025 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from .photonics_kernel import PhotonicsHandler


class TargetHandler:
"""Base class for target-specific behavior"""

def skip_compilation(self):
# By default, perform compilation on the kernel
return False

def call_processed(self, decorator, args):
# `None` indicates standard call should be used
return None


class DefaultTargetHandler(TargetHandler):
"""Standard target handler"""
pass


class PhotonicsTargetHandler(TargetHandler):
"""Handler for `orca-photonics` target"""

def skip_compilation(self):
return True

def call_processed(self, kernel, args):
if kernel is None:
raise RuntimeError(
"The 'orca-photonics' target must be used with a valid function."
)
# NOTE: Since this handler does not support MLIR mode (yet), just
# invoke the kernel. If calling from a bound function, need to
# unpack the arguments, for example, see `pyGetStateLibraryMode`
try:
context_name = cudaq_runtime.getExecutionContextName()
except RuntimeError:
context_name = None

callable_args = args
if "extract-state" == context_name and len(args) == 1:
callable_args = args[0]

PhotonicsHandler(kernel)(*callable_args)
# `True` indicates call was handled
return True
98 changes: 0 additions & 98 deletions python/cudaq/kernel/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,104 +16,6 @@
from .utils import globalAstRegistry, globalKernelRegistry, mlirTypeFromAnnotation


class MidCircuitMeasurementAnalyzer(ast.NodeVisitor):
"""
The `MidCircuitMeasurementAnalyzer` is a utility class searches for
common measurement - conditional patterns to indicate to the runtime
that we have a circuit with mid-circuit measurement and subsequent conditional
quantum operation application.
"""

def __init__(self):
self.measureResultsVars = []
self.hasMidCircuitMeasures = False

def isMeasureCallOp(self, node):
return isinstance(
node, ast.Call) and node.__dict__['func'].id in ['mx', 'my', 'mz']

def visit_Assign(self, node):
target = node.targets[0]
# Check if a variable is assigned from result(s) of measurement
if hasattr(node, 'value') and hasattr(
node.value, 'id') and node.value.id in self.measureResultsVars:
self.measureResultsVars.append(target.id)
return
# Check if the new variable is assigned from a measurement result
if hasattr(node, 'value') and isinstance(
node.value,
ast.Name) and node.value.id in self.measureResultsVars:
self.measureResultsVars.append(target.id)
return
# Check if the new variable uses measurement results
if hasattr(node, 'value') and isinstance(
node.value, ast.BoolOp) and 'values' in node.value.__dict__:
for value in node.value.__dict__['values']:
if hasattr(value, 'id') and value.id in self.measureResultsVars:
self.measureResultsVars.append(target.id)
return
if not 'func' in node.value.__dict__:
return
creatorFunc = node.value.func
if 'id' in creatorFunc.__dict__ and creatorFunc.id in [
'mx', 'my', 'mz'
]:
self.measureResultsVars.append(target.id)

# Get the variable name from a variable node.
# Returns an empty string if not something we know how to get a variable name from.
def getVariableName(self, node):
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Subscript):
return self.getVariableName(node.value)
return ''

def checkForMeasureResult(self, value):
if self.isMeasureCallOp(value):
return True
if self.getVariableName(value) in self.measureResultsVars:
return True
if isinstance(value, ast.BoolOp) and 'values' in value.__dict__:
for val in value.__dict__['values']:
if self.getVariableName(val) in self.measureResultsVars:
return True

def visit_If(self, node):
condition = node.test

# Catch `if mz(q)`, `if val`, where `val = mz(q)` or `if var[k]`, where `var = mz(qvec)`
if self.checkForMeasureResult(condition):
self.hasMidCircuitMeasures = True
return

# Compare: look at left expression.
# Catch `if var == True/False` and `if var[k] == True/False:` or `if mz(q) == True/False`
if isinstance(condition, ast.Compare) and self.checkForMeasureResult(
condition.left):
self.hasMidCircuitMeasures = True
return

# Catch `if UnaryOp mz(q)` or `if UnaryOp var` (`var = mz(q)`)
if isinstance(condition, ast.UnaryOp) and self.checkForMeasureResult(
condition.operand):
self.hasMidCircuitMeasures = True
return

# Catch `if something BoolOp mz(q)` or `something BoolOp var` (`var = mz(q)`)
if isinstance(condition, ast.BoolOp) and 'values' in condition.__dict__:

for value in condition.__dict__['values']:
if self.checkForMeasureResult(value):
self.hasMidCircuitMeasures = True
return
if isinstance(value,
ast.Compare) and self.checkForMeasureResult(
value.left):
self.hasMidCircuitMeasures = True
return


class FindDepKernelsVisitor(ast.NodeVisitor):

def __init__(self, ctx):
Expand Down
17 changes: 6 additions & 11 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4248,8 +4248,8 @@ def visit_Name(self, node):
node)


def compile_to_mlir(astModule, metadata,
capturedDataStorage: CapturedDataStorage, **kwargs):
def compile_to_mlir(astModule, capturedDataStorage: CapturedDataStorage,
**kwargs):
"""
Compile the given Python AST Module for the CUDA-Q
kernel FunctionDef to an MLIR `ModuleOp`.
Expand Down Expand Up @@ -4332,22 +4332,17 @@ def compile_to_mlir(astModule, metadata,
if verbose:
print(bridge.module)

# Canonicalize the code
pm = PassManager.parse("builtin.module(canonicalize,cse)",
context=bridge.ctx)
# Canonicalize the code, check for measurement(s) readout
pm = PassManager.parse(
"builtin.module(canonicalize,cse,func.func(quake-add-metadata))",
context=bridge.ctx)

try:
pm.run(bridge.module)
except:
raise RuntimeError("could not compile code for '{}'.".format(
bridge.name))

if metadata['conditionalOnMeasure']:
SymbolTable(
bridge.module.operation)[nvqppPrefix +
bridge.name].attributes.__setitem__(
'qubitMeasurementFeedback',
BoolAttr.get(True, context=bridge.ctx))
extraMetaData = {}
if len(bridge.dependentCaptureVars):
extraMetaData['dependent_captures'] = bridge.dependentCaptureVars
Expand Down
4 changes: 2 additions & 2 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(self, argTypeList):
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)

self.metadata = {'conditionalOnMeasure': False}
self.conditionalOnMeasure = False
self.regCounter = 0
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(loc=self.loc)
Expand Down Expand Up @@ -1366,7 +1366,7 @@ def then_function():
function()
self.insertPoint = tmpIp
cc.ContinueOp([])
self.metadata['conditionalOnMeasure'] = True
self.conditionalOnMeasure = True

def for_loop(self, start, stop, function):
"""Add a for loop that starts from the given `start` index,
Expand Down
40 changes: 9 additions & 31 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import json
import numpy as np

from cudaq.handlers import PhotonicsHandler
from cudaq.handlers import get_target_handler
from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from cudaq.mlir.dialects import cc, func
from cudaq.mlir.ir import (ComplexType, F32Type, F64Type, IntegerType,
SymbolTable)
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
from .analysis import HasReturnNodeVisitor
from .ast_bridge import compile_to_mlir, PyASTBridge
from .captured_data import CapturedDataStorage
from .utils import (emitFatalError, emitErrorIfInvalidPauli, globalAstRegistry,
Expand Down Expand Up @@ -160,11 +160,6 @@ def __init__(self,
'CUDA-Q kernel has return statement but no return type annotation.'
)

# Run analyzers and attach metadata (only have 1 right now)
analyzer = MidCircuitMeasurementAnalyzer()
analyzer.visit(self.astModule)
self.metadata = {'conditionalOnMeasure': analyzer.hasMidCircuitMeasures}

# Store the AST for this kernel, it is needed for
# building up call graphs. We also must retain
# the source code location for error diagnostics
Expand All @@ -176,6 +171,10 @@ def compile(self):
if the kernel is already compiled.
"""

handler = get_target_handler()
if handler.skip_compilation() is True:
return

# Before we can execute, we need to make sure
# variables from the parent frame that we captured
# have not changed. If they have changed, we need to
Expand Down Expand Up @@ -205,7 +204,7 @@ def compile(self):
break
s = s.f_back

if self.module != None:
if self.module is not None:
return

# Cleanup up the captured data if the module needs recompilation.
Expand All @@ -214,7 +213,6 @@ def compile(self):
# Caches the module and stores captured data into `self.capturedDataStorage`.
self.module, self.argTypes, extraMetadata = compile_to_mlir(
self.astModule,
self.metadata,
self.capturedDataStorage,
verbose=self.verbose,
returnType=self.returnType,
Expand Down Expand Up @@ -397,28 +395,8 @@ def __call__(self, *args):
requires custom handling.
"""

# Check if target is set
try:
target_name = cudaq_runtime.get_target().name
except RuntimeError:
target_name = None

if 'orca-photonics' == target_name:
if self.kernelFunction is None:
raise RuntimeError(
"The 'orca-photonics' target must be used with a valid function."
)
# NOTE: Since this handler does not support MLIR mode (yet), just
# invoke the kernel. If calling from a bound function, need to
# unpack the arguments, for example, see `pyGetStateLibraryMode`
try:
context_name = cudaq_runtime.getExecutionContextName()
except RuntimeError:
context_name = None
callable_args = args
if "extract-state" == context_name and len(args) == 1:
callable_args = args[0]
PhotonicsHandler(self.kernelFunction)(*callable_args)
handler = get_target_handler()
if handler.call_processed(self.kernelFunction, args) is True:
return

# Compile, no-op if the module is not None
Expand Down
16 changes: 14 additions & 2 deletions python/cudaq/runtime/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
from cudaq.mlir._mlir_libs._quakeDialects import cudaq_runtime
from cudaq.kernel.kernel_builder import PyKernel
from cudaq.kernel.kernel_decorator import PyKernelDecorator
from cudaq.kernel.utils import nvqppPrefix
from .utils import __isBroadcast, __createArgumentSet


Expand Down Expand Up @@ -66,8 +69,17 @@ def sample(kernel,
or a list of such results in the case of `sample` function broadcasting."""

has_conditionals_on_measure_result = False
if hasattr(kernel, 'metadata') and kernel.metadata.get(
'conditionalOnMeasure', False):

if isinstance(kernel, PyKernelDecorator):
kernel.compile()
if kernel.module is not None:
for operation in kernel.module.body.operations:
if not hasattr(operation, 'name'):
continue
if nvqppPrefix + kernel.name == operation.name.value:
has_conditionals_on_measure_result = 'qubitMeasurementFeedback' in operation.attributes
break
elif isinstance(kernel, PyKernel) and kernel.conditionalOnMeasure:
has_conditionals_on_measure_result = True

if explicit_measurements:
Expand Down
2 changes: 1 addition & 1 deletion python/tests/kernel/test_explicit_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def kernel_with_conditional_on_measure():
e)

## NOTE: The following does not fail.
## See: https://github.com/NVIDIA/cuda-quantum/issues/2000
## Needs inlining of the function calls.
# @cudaq.kernel
# def measure(q: cudaq.qubit) -> bool:
# return mz(q)
Expand Down
Loading