Skip to content

Commit c12d2d3

Browse files
* Patch from Anna to handle the data capture and storage issue.
* Removed the skipping of Python tests ----------------------- Signed-off-by: Pradnya Khalate <[email protected]> Co-authored-by: Anna Gringauze <[email protected]>
1 parent 706f12f commit c12d2d3

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

python/cudaq/kernel/kernel_decorator.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@ def compile(self):
211211
if self.module != None:
212212
return
213213

214+
# FIX: Cleanup up the captured data if the module needs recompilation.
215+
if self.capturedDataStorage != None:
216+
self.capturedDataStorage.__del__()
217+
self.capturedDataStorage = self.createStorage()
218+
219+
# Caches the module and stores captured data into self.capturedDataStorage.
214220
self.module, self.argTypes, extraMetadata = compile_to_mlir(
215221
self.astModule,
216222
self.capturedDataStorage,
@@ -420,7 +426,8 @@ def __call__(self, *args):
420426
return
421427

422428
# Prepare captured state storage for the run
423-
self.capturedDataStorage = self.createStorage()
429+
# FIX: moved to compile()
430+
# self.capturedDataStorage = self.createStorage()
424431

425432
# Compile, no-op if the module is not None
426433
self.compile()
@@ -501,8 +508,9 @@ def __call__(self, *args):
501508
self.module,
502509
*processedArgs,
503510
callable_names=callableNames)
504-
self.capturedDataStorage.__del__()
505-
self.capturedDataStorage = None
511+
# FIX: do not delete the captured storage
512+
# self.capturedDataStorage.__del__()
513+
# self.capturedDataStorage = None
506514
else:
507515
result = cudaq_runtime.pyAltLaunchKernelR(
508516
self.name,
@@ -511,8 +519,9 @@ def __call__(self, *args):
511519
*processedArgs,
512520
callable_names=callableNames)
513521

514-
self.capturedDataStorage.__del__()
515-
self.capturedDataStorage = None
522+
# FIX: do not delete the captured storage
523+
# self.capturedDataStorage.__del__()
524+
# self.capturedDataStorage = None
516525
return result
517526

518527

python/tests/builder/test_qalloc_init_state.py

-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ def test_kernel_complex_params_f32():
154154

155155

156156
@skipIfNvidiaFP64NotInstalled
157-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
158157
def test_kernel_complex_capture_f64():
159158
cudaq.reset_target()
160159
cudaq.set_target('nvidia-fp64')
@@ -207,7 +206,6 @@ def test_kernel_complex128_capture_f64():
207206

208207

209208
@skipIfNvidiaNotInstalled
210-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
211209
def test_kernel_complex64_capture_f32():
212210
cudaq.reset_target()
213211
cudaq.set_target('nvidia')

python/tests/kernel/test_kernel_qvector_state_init.py

-5
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ def test_kernel_complex_params_f32():
175175

176176

177177
@skipIfNvidiaFP64NotInstalled
178-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
179178
def test_kernel_complex_capture_f64():
180179
cudaq.reset_target()
181180
cudaq.set_target('nvidia', option='fp64')
@@ -228,7 +227,6 @@ def test_kernel_complex128_capture_f64():
228227

229228

230229
@skipIfNvidiaNotInstalled
231-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
232230
def test_kernel_complex64_capture_f32():
233231
cudaq.reset_target()
234232
cudaq.set_target('nvidia')
@@ -316,7 +314,6 @@ def kernel(vec: cudaq.State):
316314

317315

318316
@skipIfNvidiaFP64NotInstalled
319-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
320317
def test_kernel_simulation_dtype_capture_f64():
321318
cudaq.reset_target()
322319
cudaq.set_target('nvidia', option='fp64')
@@ -337,7 +334,6 @@ def kernel():
337334

338335

339336
@skipIfNvidiaNotInstalled
340-
@pytest.mark.skip(reason="kernel.compile() causes a crash")
341337
def test_kernel_simulation_dtype_capture_f32():
342338
cudaq.reset_target()
343339
cudaq.set_target('nvidia')
@@ -418,7 +414,6 @@ def kernel(initialState: cudaq.State):
418414
assert not '01' in counts
419415

420416

421-
@pytest.mark.skip(reason="kernel0.compile() causes a crash")
422417
def test_inner_kernels_state():
423418
c = np.array([1. / np.sqrt(2.) + 0j, 0., 0., 1. / np.sqrt(2.)],
424419
dtype=cudaq.complex())

0 commit comments

Comments
 (0)