Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] [refactoring] Replace MidCircuitMeasurementAnalyzer with existing quake-add-metadata pass #2610

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions python/cudaq/kernel/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,85 +13,6 @@
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime


class MidCircuitMeasurementAnalyzer(ast.NodeVisitor):
"""
The `MidCircuitMeasurementAnalyzer` is a utility class searches for
common measurement - conditional patterns to indicate to the runtime
that we have a circuit with mid-circuit measurement and subsequent conditional
quantum operation application.
"""

def __init__(self):
self.measureResultsVars = []
self.hasMidCircuitMeasures = False

def isMeasureCallOp(self, node):
return isinstance(
node, ast.Call) and node.__dict__['func'].id in ['mx', 'my', 'mz']

def visit_Assign(self, node):
target = node.targets[0]
# Check if a variable is assigned from result(s) of measurement
if hasattr(node, 'value') and hasattr(
node.value, 'id') and node.value.id in self.measureResultsVars:
self.measureResultsVars.append(target.id)
return
if not 'func' in node.value.__dict__:
return
creatorFunc = node.value.func
if 'id' in creatorFunc.__dict__ and creatorFunc.id in [
'mx', 'my', 'mz'
]:
self.measureResultsVars.append(target.id)

# Get the variable name from a variable node.
# Returns an empty string if not something we know how to get a variable name from.
def getVariableName(self, node):
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Subscript):
return self.getVariableName(node.value)
return ''

def checkForMeasureResult(self, value):
return self.isMeasureCallOp(value) or self.getVariableName(
value) in self.measureResultsVars

def visit_If(self, node):
condition = node.test

# Catch `if mz(q)`, `if val`, where `val = mz(q)` or `if var[k]`, where `var = mz(qvec)`
if self.checkForMeasureResult(condition):
self.hasMidCircuitMeasures = True
return

# Compare: look at left expression.
# Catch `if var == True/False` and `if var[k] == True/False:` or `if mz(q) == True/False`
if isinstance(condition, ast.Compare) and self.checkForMeasureResult(
condition.left):
self.hasMidCircuitMeasures = True
return

# Catch `if UnaryOp mz(q)` or `if UnaryOp var` (`var = mz(q)`)
if isinstance(condition, ast.UnaryOp) and self.checkForMeasureResult(
condition.operand):
self.hasMidCircuitMeasures = True
return

# Catch `if something BoolOp mz(q)` or `something BoolOp var` (`var = mz(q)`)
if isinstance(condition, ast.BoolOp) and 'values' in condition.__dict__:

for value in condition.__dict__['values']:
if self.checkForMeasureResult(value):
self.hasMidCircuitMeasures = True
return
if isinstance(value,
ast.Compare) and self.checkForMeasureResult(
value.left):
self.hasMidCircuitMeasures = True
return


class FindDepKernelsVisitor(ast.NodeVisitor):

def __init__(self, ctx):
Expand Down
17 changes: 6 additions & 11 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -4232,8 +4232,8 @@ def visit_Name(self, node):
node)


def compile_to_mlir(astModule, metadata,
capturedDataStorage: CapturedDataStorage, **kwargs):
def compile_to_mlir(astModule, capturedDataStorage: CapturedDataStorage,
**kwargs):
"""
Compile the given Python AST Module for the CUDA-Q
kernel FunctionDef to an MLIR `ModuleOp`.
Expand Down Expand Up @@ -4316,22 +4316,17 @@ def compile_to_mlir(astModule, metadata,
if verbose:
print(bridge.module)

# Canonicalize the code
pm = PassManager.parse("builtin.module(canonicalize,cse)",
context=bridge.ctx)
# Canonicalize the code, check for measurement(s) readout
pm = PassManager.parse(
"builtin.module(canonicalize,cse,func.func(quake-add-metadata))",
context=bridge.ctx)

try:
pm.run(bridge.module)
except:
raise RuntimeError("could not compile code for '{}'.".format(
bridge.name))

if metadata['conditionalOnMeasure']:
SymbolTable(
bridge.module.operation)[nvqppPrefix +
bridge.name].attributes.__setitem__(
'qubitMeasurementFeedback',
BoolAttr.get(True, context=bridge.ctx))
extraMetaData = {}
if len(bridge.dependentCaptureVars):
extraMetaData['dependent_captures'] = bridge.dependentCaptureVars
Expand Down
4 changes: 2 additions & 2 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __init__(self, argTypeList):
cc.register_dialect(self.ctx)
cudaq_runtime.registerLLVMDialectTranslation(self.ctx)

self.metadata = {'conditionalOnMeasure': False}
self.conditionalOnMeasure = False
self.regCounter = 0
self.loc = Location.unknown(context=self.ctx)
self.module = Module.create(loc=self.loc)
Expand Down Expand Up @@ -1341,7 +1341,7 @@ def then_function():
function()
self.insertPoint = tmpIp
cc.ContinueOp([])
self.metadata['conditionalOnMeasure'] = True
self.conditionalOnMeasure = True

def for_loop(self, start, stop, function):
"""Add a for loop that starts from the given `start` index,
Expand Down
38 changes: 23 additions & 15 deletions python/cudaq/kernel/kernel_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..mlir.dialects import quake, cc, func
from .ast_bridge import compile_to_mlir, PyASTBridge
from .utils import mlirTypeFromPyType, nvqppPrefix, mlirTypeToPyType, globalAstRegistry, emitFatalError, emitErrorIfInvalidPauli, globalRegisteredTypes
from .analysis import MidCircuitMeasurementAnalyzer, HasReturnNodeVisitor
from .analysis import HasReturnNodeVisitor
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime
from .captured_data import CapturedDataStorage
from ..handlers import PhotonicsHandler
Expand Down Expand Up @@ -159,11 +159,6 @@ def __init__(self,
'CUDA-Q kernel has return statement but no return type annotation.'
)

# Run analyzers and attach metadata (only have 1 right now)
analyzer = MidCircuitMeasurementAnalyzer()
analyzer.visit(self.astModule)
self.metadata = {'conditionalOnMeasure': analyzer.hasMidCircuitMeasures}

# Store the AST for this kernel, it is needed for
# building up call graphs. We also must retain
# the source code location for error diagnostics
Expand All @@ -175,6 +170,15 @@ def compile(self):
if the kernel is already compiled.
"""

# TODO: Refactor this check
try:
target_name = cudaq_runtime.get_target().name
except RuntimeError:
target_name = None

if 'orca-photonics' == target_name:
return

# Before we can execute, we need to make sure
# variables from the parent frame that we captured
# have not changed. If they have changed, we need to
Expand Down Expand Up @@ -204,12 +208,17 @@ def compile(self):
break
s = s.f_back

if self.module != None:
if self.module is not None:
return

# Cleanup up the captured data if the module needs recompilation.
if self.capturedDataStorage is not None:
self.capturedDataStorage.__del__()
self.capturedDataStorage = self.createStorage()

# Caches the module and stores captured data into `self.capturedDataStorage`.
self.module, self.argTypes, extraMetadata = compile_to_mlir(
self.astModule,
self.metadata,
self.capturedDataStorage,
verbose=self.verbose,
returnType=self.returnType,
Expand Down Expand Up @@ -416,9 +425,6 @@ def __call__(self, *args):
PhotonicsHandler(self.kernelFunction)(*callable_args)
return

# Prepare captured state storage for the run
self.capturedDataStorage = self.createStorage()

# Compile, no-op if the module is not None
self.compile()

Expand Down Expand Up @@ -498,8 +504,7 @@ def __call__(self, *args):
self.module,
*processedArgs,
callable_names=callableNames)
self.capturedDataStorage.__del__()
self.capturedDataStorage = None

else:
result = cudaq_runtime.pyAltLaunchKernelR(
self.name,
Expand All @@ -508,10 +513,13 @@ def __call__(self, *args):
*processedArgs,
callable_names=callableNames)

self.capturedDataStorage.__del__()
self.capturedDataStorage = None
return result

def __del__(self):
if self.capturedDataStorage is not None:
self.capturedDataStorage.__del__()
self.capturedDataStorage = None


def kernel(function=None, **kwargs):
"""
Expand Down
17 changes: 15 additions & 2 deletions python/cudaq/runtime/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

from ..mlir._mlir_libs._quakeDialects import cudaq_runtime
from .utils import __isBroadcast, __createArgumentSet
from ..kernel.kernel_builder import PyKernel
from ..kernel.kernel_decorator import PyKernelDecorator
from ..kernel.utils import nvqppPrefix


def __broadcastSample(kernel,
Expand Down Expand Up @@ -66,8 +70,17 @@ def sample(kernel,
or a list of such results in the case of `sample` function broadcasting."""

has_conditionals_on_measure_result = False
if hasattr(kernel, 'metadata') and kernel.metadata.get(
'conditionalOnMeasure', False):

if isinstance(kernel, PyKernelDecorator):
kernel.compile()
if kernel.module is not None:
for operation in kernel.module.body.operations:
if not hasattr(operation, 'name'):
continue
if nvqppPrefix + kernel.name == operation.name.value and 'qubitMeasurementFeedback' in operation.attributes:
has_conditionals_on_measure_result = True
break
elif isinstance(kernel, PyKernel) and kernel.conditionalOnMeasure:
has_conditionals_on_measure_result = True

if explicit_measurements:
Expand Down
4 changes: 2 additions & 2 deletions python/tests/builder/test_qalloc_init_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ def test_kernel_complex64_capture_f32():
state = cudaq.State.from_data(c)

@cudaq.kernel
def kernel():
def kernel_foo():
q = cudaq.qvector(state)

counts = cudaq.sample(kernel)
counts = cudaq.sample(kernel_foo)
print(counts)
assert '11' in counts
assert '00' in counts
Expand Down
Loading
Loading