Skip to content

Commit 11dd02d

Browse files
1tnguyensacpis
and authored
Fix a bug in Python run, run_async, sample_async, and observe_async: callable args are dropped (#3545)
* Fix a bug in Python cudaq.run: callable args are dropped — Signed-off-by: Thien Nguyen <[email protected]>
* Fix sample_async and observe_async as well — Signed-off-by: Thien Nguyen <[email protected]>
* Code format — Signed-off-by: Thien Nguyen <[email protected]>
* Code review: remove test print statements — Signed-off-by: Thien Nguyen <[email protected]>
---------
Signed-off-by: Thien Nguyen <[email protected]>
Co-authored-by: Sachin Pisal <[email protected]>
1 parent 156f80c commit 11dd02d

File tree

10 files changed

+205
-54
lines changed

10 files changed

+205
-54
lines changed

python/cudaq/kernel/kernel_decorator.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,25 @@ def __convertStringsToPauli__(self, arg):
453453

454454
return arg
455455

456+
def processCallableArg(self, arg):
457+
"""
458+
Process a callable argument
459+
"""
460+
if not isinstance(arg, PyKernelDecorator):
461+
emitFatalError(
462+
"Callable argument provided is not a cudaq.kernel decorated function."
463+
)
464+
# It may be that the provided input callable kernel
465+
# is not currently in the ModuleOp. Need to add it
466+
# if that is the case, we have to use the AST
467+
# so that it shares self.module's MLIR Context
468+
symbols = SymbolTable(self.module.operation)
469+
if nvqppPrefix + arg.name not in symbols:
470+
tmpBridge = PyASTBridge(self.capturedDataStorage,
471+
existingModule=self.module,
472+
disableEntryPointTag=True)
473+
tmpBridge.visit(globalAstRegistry[arg.name][0])
474+
456475
def __call__(self, *args):
457476
"""
458477
Invoke the CUDA-Q kernel. JIT compilation of the kernel AST to MLIR
@@ -500,16 +519,7 @@ def __call__(self, *args):
500519
if cc.CallableType.isinstance(mlirType):
501520
# Assume this is a PyKernelDecorator
502521
callableNames.append(arg.name)
503-
# It may be that the provided input callable kernel
504-
# is not currently in the ModuleOp. Need to add it
505-
# if that is the case, we have to use the AST
506-
# so that it shares self.module's MLIR Context
507-
symbols = SymbolTable(self.module.operation)
508-
if nvqppPrefix + arg.name not in symbols:
509-
tmpBridge = PyASTBridge(self.capturedDataStorage,
510-
existingModule=self.module,
511-
disableEntryPointTag=True)
512-
tmpBridge.visit(globalAstRegistry[arg.name][0])
522+
self.processCallableArg(arg)
513523

514524
# Convert `numpy` arrays to lists
515525
if cc.StdvecType.isinstance(mlirType) and hasattr(arg, "tolist"):

python/runtime/cudaq/algorithms/py_observe_async.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,18 +80,20 @@ async_observe_result pyObserveAsync(py::object &kernel,
8080
if (py::len(kernelBlockArgs) != args.size())
8181
throw std::runtime_error(
8282
"Invalid number of arguments passed to observe_async.");
83-
83+
// Process any callable args
84+
const auto callableNames = getCallableNames(kernel, args);
8485
auto &platform = cudaq::get_platform();
8586
auto kernelName = kernel.attr("name").cast<std::string>();
8687
auto kernelMod = kernel.attr("module").cast<MlirModule>();
8788
args = simplifiedValidateInputArguments(args);
88-
auto *argData = toOpaqueArgs(args, kernelMod, kernelName);
89+
auto *argData =
90+
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler());
8991

9092
// Launch the asynchronous execution.
9193
py::gil_scoped_release release;
9294
return details::runObservationAsync(
93-
[argData, kernelName, kernelMod]() mutable {
94-
pyAltLaunchKernel(kernelName, kernelMod, *argData, {});
95+
[argData, kernelName, kernelMod, callableNames]() mutable {
96+
pyAltLaunchKernel(kernelName, kernelMod, *argData, callableNames);
9597
delete argData;
9698
},
9799
spin_operator, platform, shots, kernelName, qpu_id);

python/runtime/cudaq/algorithms/py_run.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ static std::vector<py::object> readRunResults(mlir::ModuleOp module,
3939
}
4040

4141
static std::tuple<std::string, MlirModule, OpaqueArguments *,
42-
mlir::func::FuncOp, std::string, mlir::func::FuncOp>
42+
mlir::func::FuncOp, std::string, mlir::func::FuncOp,
43+
std::vector<std::string>>
4344
getKernelLaunchParameters(py::object &kernel, py::args args) {
4445
if (!py::hasattr(kernel, "arguments"))
4546
throw std::runtime_error(
@@ -52,6 +53,9 @@ getKernelLaunchParameters(py::object &kernel, py::args args) {
5253
if (py::hasattr(kernel, "compile"))
5354
kernel.attr("compile")();
5455

56+
// Process any callable args
57+
const auto callableNames = getCallableNames(kernel, args);
58+
5559
auto origKernName = kernel.attr("name").cast<std::string>();
5660
auto kernelName = origKernName + ".run";
5761
if (!py::hasattr(kernel, "module") || kernel.attr("module").is_none())
@@ -76,16 +80,19 @@ getKernelLaunchParameters(py::object &kernel, py::args args) {
7680
throw std::runtime_error(
7781
"failed to autogenerate the runnable variant of the kernel.");
7882
}
79-
auto *argData = toOpaqueArgs(args, kernelMod, kernelName);
83+
auto *argData =
84+
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler());
8085
auto funcOp = getKernelFuncOp(kernelMod, kernelName);
81-
return {kernelName, kernelMod, argData, funcOp, origKernName, origKern};
86+
return {kernelName, kernelMod, argData, funcOp,
87+
origKernName, origKern, callableNames};
8288
}
8389

8490
static details::RunResultSpan
8591
pyRunTheKernel(const std::string &name, const std::string &origName,
8692
MlirModule module, mlir::func::FuncOp funcOp,
8793
mlir::func::FuncOp origKernel, OpaqueArguments &runtimeArgs,
8894
quantum_platform &platform, std::size_t shots_count,
95+
const std::vector<std::string> &callableNames,
8996
std::size_t qpu_id = 0) {
9097
auto returnTypes = origKernel.getResultTypes();
9198
if (returnTypes.empty() || returnTypes.size() > 1)
@@ -101,13 +108,13 @@ pyRunTheKernel(const std::string &name, const std::string &origName,
101108

102109
auto mod = unwrap(module);
103110

104-
auto [rawArgs, size, returnOffset, thunk] =
105-
pyAltLaunchKernelBase(name, module, returnTy, runtimeArgs, {}, 0, false);
111+
auto [rawArgs, size, returnOffset, thunk] = pyAltLaunchKernelBase(
112+
name, module, returnTy, runtimeArgs, callableNames, 0, false);
106113

107114
auto results = details::runTheKernel(
108115
[&]() mutable {
109116
pyLaunchKernel(name, thunk, mod, runtimeArgs, rawArgs, size,
110-
returnOffset, {});
117+
returnOffset, callableNames);
111118
},
112119
platform, name, origName, shots_count, qpu_id);
113120

@@ -133,7 +140,7 @@ std::vector<py::object> pyRun(py::object &kernel, py::args args,
133140
if (shots_count == 0)
134141
return {};
135142

136-
auto [name, module, argData, func, origName, origKern] =
143+
auto [name, module, argData, func, origName, origKern, callableNames] =
137144
getKernelLaunchParameters(kernel, args);
138145

139146
auto mod = unwrap(module);
@@ -149,7 +156,7 @@ std::vector<py::object> pyRun(py::object &kernel, py::args args,
149156
}
150157

151158
auto span = pyRunTheKernel(name, origName, module, func, origKern, *argData,
152-
platform, shots_count);
159+
platform, shots_count, callableNames);
153160
delete argData;
154161
auto results = pyReadResults(span, module, func, origKern, shots_count);
155162

@@ -184,7 +191,7 @@ async_run_result pyRunAsync(py::object &kernel, py::args args,
184191
") exceeds the number of available QPUs (" +
185192
std::to_string(numQPUs) + ")");
186193

187-
auto [name, module, argData, func, origName, origKern] =
194+
auto [name, module, argData, func, origName, origKern, callableNames] =
188195
getKernelLaunchParameters(kernel, args);
189196

190197
auto mod = unwrap(module);
@@ -219,16 +226,17 @@ async_run_result pyRunAsync(py::object &kernel, py::args args,
219226
QuantumTask wrapped = detail::make_copyable_function(
220227
[sp = std::move(spanPromise), ep = std::move(errorPromise), shots_count,
221228
qpu_id, argData, name, module, func, origKern, origName,
222-
noise_model = std::move(noise_model)]() mutable {
229+
noise_model = std::move(noise_model), callableNames]() mutable {
223230
auto &platform = get_platform();
224231

225232
// Launch the kernel in the appropriate context.
226233
if (noise_model.has_value())
227234
platform.set_noise(&noise_model.value());
228235

229236
try {
230-
auto span = pyRunTheKernel(name, origName, module, func, origKern,
231-
*argData, platform, shots_count, qpu_id);
237+
auto span =
238+
pyRunTheKernel(name, origName, module, func, origKern, *argData,
239+
platform, shots_count, callableNames, qpu_id);
232240
delete argData;
233241
sp.set_value(span);
234242
ep.set_value("");

python/runtime/cudaq/algorithms/py_sample_async.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ for more information on this programming pattern.)#")
8888
auto &platform = cudaq::get_platform();
8989
if (py::hasattr(kernel, "compile"))
9090
kernel.attr("compile")();
91+
// Process any callable args
92+
const auto callableNames = getCallableNames(kernel, args);
9193
auto kernelName = kernel.attr("name").cast<std::string>();
9294
// Clone the kernel module
9395
auto kernelMod = mlirModuleFromOperation(
@@ -118,7 +120,7 @@ for more information on this programming pattern.)#")
118120
// Hence, pass it as a unique_ptr for the functor to manage its
119121
// lifetime.
120122
std::unique_ptr<OpaqueArguments> argData(
121-
toOpaqueArgs(args, kernelMod, kernelName));
123+
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler()));
122124

123125
// Should only have C++ going on here, safe to release the GIL
124126
py::gil_scoped_release release;
@@ -129,9 +131,10 @@ for more information on this programming pattern.)#")
129131
// (2) This lambda might be executed multiple times, e.g, when
130132
// the kernel contains measurement feedback.
131133
cudaq::detail::make_copyable_function(
132-
[argData = std::move(argData), kernelName,
133-
kernelMod]() mutable {
134-
pyAltLaunchKernel(kernelName, kernelMod, *argData, {});
134+
[argData = std::move(argData), kernelName, kernelMod,
135+
callableNames]() mutable {
136+
pyAltLaunchKernel(kernelName, kernelMod, *argData,
137+
callableNames);
135138
}),
136139
platform, kernelName, shots, explicitMeasurements, qpu_id),
137140
std::move(mlirCtx));

python/runtime/cudaq/platform/py_alt_launch_kernel.cpp

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -109,16 +109,37 @@ void setDataLayout(MlirModule module) {
109109
}
110110
}
111111

112+
std::function<bool(OpaqueArguments &argData, py::object &arg)>
113+
getCallableArgHandler() {
114+
return [](cudaq::OpaqueArguments &argData, py::object &arg) {
115+
if (py::hasattr(arg, "module")) {
116+
// Just give it some dummy data that will not be used.
117+
// We synthesize away all callables, the block argument
118+
// remains but it is not used, so just give argsCreator
119+
// something, and we'll make sure its cleaned up.
120+
long *ourAllocatedArg = new long();
121+
argData.emplace_back(ourAllocatedArg,
122+
[](void *ptr) { delete static_cast<long *>(ptr); });
123+
return true;
124+
}
125+
return false;
126+
};
127+
}
128+
112129
/// @brief Create a new OpaqueArguments pointer and pack the python arguments
113130
/// in it. Clients must delete the memory.
114-
OpaqueArguments *toOpaqueArgs(py::args &args, MlirModule mod,
115-
const std::string &name) {
131+
OpaqueArguments *
132+
toOpaqueArgs(py::args &args, MlirModule mod, const std::string &name,
133+
const std::optional<
134+
std::function<bool(OpaqueArguments &argData, py::object &arg)>>
135+
&optionalBackupHandler) {
116136
auto kernelFunc = getKernelFuncOp(mod, name);
117137
auto *argData = new cudaq::OpaqueArguments();
118138
args = simplifiedValidateInputArguments(args);
119139
setDataLayout(mod);
120-
cudaq::packArgs(*argData, args, kernelFunc,
121-
[](OpaqueArguments &, py::object &) { return false; });
140+
auto backupHandler = optionalBackupHandler.value_or(
141+
[](OpaqueArguments &, py::object &) { return false; });
142+
cudaq::packArgs(*argData, args, kernelFunc, backupHandler);
122143
return argData;
123144
}
124145

@@ -998,26 +1019,32 @@ std::string getASM(const std::string &name, MlirModule module,
9981019
return str;
9991020
}
10001021

1022+
std::vector<std::string> getCallableNames(py::object &kernel, py::args &args) {
1023+
// Handle callable arguments, if any, similar to `PyKernelDecorator.__call__`,
1024+
// so that the callable arguments are properly packed for `pyAltLaunchKernel`
1025+
// as if it's launched from Python.
1026+
std::vector<std::string> callableNames;
1027+
for (std::size_t i = 0; i < args.size(); ++i) {
1028+
auto arg = args[i];
1029+
// If this is a `PyKernelDecorator` callable:
1030+
if (py::hasattr(arg, "__call__") && py::hasattr(arg, "module") &&
1031+
py::hasattr(arg, "name")) {
1032+
if (py::hasattr(arg, "compile"))
1033+
arg.attr("compile")();
1034+
1035+
if (py::hasattr(kernel, "processCallableArg"))
1036+
kernel.attr("processCallableArg")(arg);
1037+
callableNames.push_back(arg.attr("name").cast<std::string>());
1038+
}
1039+
}
1040+
return callableNames;
1041+
}
1042+
10011043
void bindAltLaunchKernel(py::module &mod,
10021044
std::function<std::string()> &&getTL) {
10031045
jitCache = std::make_unique<JITExecutionCache>();
10041046
getTransportLayer = std::move(getTL);
10051047

1006-
auto callableArgHandler = [](cudaq::OpaqueArguments &argData,
1007-
py::object &arg) {
1008-
if (py::hasattr(arg, "module")) {
1009-
// Just give it some dummy data that will not be used.
1010-
// We synthesize away all callables, the block argument
1011-
// remains but it is not used, so just give argsCreator
1012-
// something, and we'll make sure its cleaned up.
1013-
long *ourAllocatedArg = new long();
1014-
argData.emplace_back(ourAllocatedArg,
1015-
[](void *ptr) { delete static_cast<long *>(ptr); });
1016-
return true;
1017-
}
1018-
return false;
1019-
};
1020-
10211048
mod.def(
10221049
"pyAltLaunchKernel",
10231050
[&](const std::string &kernelName, MlirModule module,
@@ -1026,7 +1053,7 @@ void bindAltLaunchKernel(py::module &mod,
10261053

10271054
cudaq::OpaqueArguments args;
10281055
setDataLayout(module);
1029-
cudaq::packArgs(args, runtimeArgs, kernelFunc, callableArgHandler);
1056+
cudaq::packArgs(args, runtimeArgs, kernelFunc, getCallableArgHandler());
10301057
pyAltLaunchKernel(kernelName, module, args, callable_names);
10311058
},
10321059
py::arg("kernelName"), py::arg("module"), py::kw_only(),
@@ -1040,7 +1067,7 @@ void bindAltLaunchKernel(py::module &mod,
10401067

10411068
cudaq::OpaqueArguments args;
10421069
setDataLayout(module);
1043-
cudaq::packArgs(args, runtimeArgs, kernelFunc, callableArgHandler);
1070+
cudaq::packArgs(args, runtimeArgs, kernelFunc, getCallableArgHandler());
10441071
return pyAltLaunchKernelR(kernelName, module, returnType, args,
10451072
callable_names);
10461073
},

python/runtime/cudaq/platform/py_alt_launch_kernel.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,23 @@ namespace cudaq {
2626
/// @brief Set current architecture's data layout attribute on a module.
2727
void setDataLayout(MlirModule module);
2828

29+
/// @brief Get the default callable argument handler for packing arguments.
30+
std::function<bool(OpaqueArguments &argData, py::object &arg)>
31+
getCallableArgHandler();
32+
33+
/// @brief Get the names of callable arguments from the given kernel and
34+
/// arguments.
35+
// As we process the arguments, we also perform any extra processing required
36+
// for callable arguments.
37+
std::vector<std::string> getCallableNames(py::object &kernel, py::args &args);
38+
2939
/// @brief Create a new OpaqueArguments pointer and pack the
3040
/// python arguments in it. Clients must delete the memory.
31-
OpaqueArguments *toOpaqueArgs(py::args &args, MlirModule mod,
32-
const std::string &name);
41+
OpaqueArguments *
42+
toOpaqueArgs(py::args &args, MlirModule mod, const std::string &name,
43+
const std::optional<
44+
std::function<bool(OpaqueArguments &argData, py::object &arg)>>
45+
&optionalBackupHandler = std::nullopt);
3346

3447
inline std::size_t byteSize(mlir::Type ty) {
3548
if (isa<mlir::ComplexType>(ty)) {

python/tests/kernel/test_observe_kernel.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import pytest
1212
import numpy as np
13-
from typing import List
13+
from typing import List, Callable
1414

1515
import cudaq
1616
from cudaq import spin
@@ -343,3 +343,31 @@ def gqeCirc2(N: int, thetas: list[float], paulis: list[cudaq.pauli_word]):
343343
exp_val2 = cudaq.observe_async(gqeCirc2, obs, numQubits, list(ts),
344344
pauliStings).get().expectation()
345345
print('observe_async exp_val2', exp_val2)
346+
347+
348+
def test_observe_callable():
349+
"""Test that we can observe kernels with callable arguments."""
350+
351+
@cudaq.kernel
352+
def ansatz_callable(angle: float, rotate: Callable[[cudaq.qubit, float],
353+
None]):
354+
q = cudaq.qvector(2)
355+
x(q[0])
356+
rotate(q[1], angle)
357+
x.ctrl(q[1], q[0])
358+
359+
@cudaq.kernel
360+
def ry_rotate(qubit: cudaq.qubit, angle: float):
361+
ry(angle, qubit)
362+
363+
hamiltonian = 5.907 - 2.1433 * spin.x(0) * spin.x(1) - 2.1433 * spin.y(
364+
0) * spin.y(1) + .21829 * spin.z(0) - 6.125 * spin.z(1)
365+
366+
result = cudaq.observe(ansatz_callable, hamiltonian, .59, ry_rotate)
367+
print(result.expectation())
368+
assert np.isclose(result.expectation(), -1.74, atol=1e-2)
369+
370+
result_async = cudaq.observe_async(ansatz_callable, hamiltonian, .59,
371+
ry_rotate).get()
372+
print(result_async.expectation())
373+
assert np.isclose(result_async.expectation(), -1.74, atol=1e-2)

0 commit comments

Comments (0)