Skip to content

Commit c09b3d0

Browse files
authored
Derive _kraus_ from _apply_channel_ (#7434)
Add fallback strategy to the kraus protocol to obtain Kraus values from `_apply_channel_`. Attempted as a very-last strategy, because it is computationally expensive. Resolves #5921
1 parent 6b6a7ff commit c09b3d0

File tree

3 files changed

+113
-0
lines changed

3 files changed

+113
-0
lines changed

cirq-core/cirq/protocols/kraus_protocol.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
from typing_extensions import Protocol
2525

26+
from cirq import protocols, qis
2627
from cirq._doc import doc_private
2728
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
2829
from cirq.protocols.mixture_protocol import has_mixture
@@ -94,6 +95,37 @@ def _has_kraus_(self) -> bool:
9495
"""
9596

9697

98+
def _strat_kraus_from_apply_channel(val: Any) -> tuple[np.ndarray, ...] | None:
99+
"""Attempts to compute a value's Kraus operators via its _apply_channel_ method.
100+
This is very expensive (O(16^N)), so only do this as a last resort."""
101+
method = getattr(val, '_apply_channel_', None)
102+
if method is None:
103+
return None
104+
105+
qid_shape = protocols.qid_shape(val)
106+
107+
eye = qis.eye_tensor(qid_shape * 2, dtype=np.complex128)
108+
buffer = np.empty_like(eye)
109+
buffer.fill(float('nan'))
110+
superop = protocols.apply_channel(
111+
val=val,
112+
args=protocols.ApplyChannelArgs(
113+
target_tensor=eye,
114+
out_buffer=buffer,
115+
auxiliary_buffer0=buffer.copy(),
116+
auxiliary_buffer1=buffer.copy(),
117+
left_axes=list(range(len(qid_shape))),
118+
right_axes=list(range(len(qid_shape), len(qid_shape) * 2)),
119+
),
120+
default=None,
121+
)
122+
if superop is None or superop is NotImplemented:
123+
return None
124+
n = np.prod(qid_shape) ** 2
125+
kraus_ops = qis.superoperator_to_kraus(superop.reshape((n, n)))
126+
return tuple(kraus_ops)
127+
128+
97129
def kraus(
98130
val: Any, default: Any = RaiseTypeErrorIfNotProvided
99131
) -> tuple[np.ndarray, ...] | TDefault:
@@ -159,6 +191,14 @@ def kraus(
159191
if channel_result is not NotImplemented:
160192
return tuple(channel_result) # pragma: no cover
161193

194+
# Last-resort fallback: try to derive Kraus from _apply_channel_.
195+
# Note: _apply_channel can lead to kraus being called again, so if default
196+
# is None, this can trigger an infinite loop.
197+
if default is not None:
198+
result = _strat_kraus_from_apply_channel(val)
199+
if result is not None:
200+
return result
201+
162202
if default is not RaiseTypeErrorIfNotProvided:
163203
return default
164204

cirq-core/cirq/protocols/kraus_protocol_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytest
2323

2424
import cirq
25+
from cirq.protocols.apply_channel_protocol import _apply_kraus
2526

2627
LOCAL_DEFAULT: list[np.ndarray] = [np.array([])]
2728

@@ -171,3 +172,74 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None:
171172
op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test'))
172173
assert cirq.has_kraus(op)
173174
assert not cirq.has_kraus(op, allow_decompose=False)
175+
176+
177+
def test_strat_kraus_from_apply_channel_returns_none():
178+
# Remove _kraus_ and _apply_channel_ methods
179+
class NoApplyChannelReset(cirq.ResetChannel):
180+
def _kraus_(self):
181+
return NotImplemented
182+
183+
def _apply_channel_(self, args):
184+
return NotImplemented
185+
186+
gate_no_apply = NoApplyChannelReset()
187+
with pytest.raises(
188+
TypeError,
189+
match="does have a _kraus_, _mixture_ or _unitary_ method, but it returned NotImplemented",
190+
):
191+
cirq.kraus(gate_no_apply)
192+
193+
194+
@pytest.mark.parametrize(
195+
'channel_cls,params',
196+
[
197+
(cirq.BitFlipChannel, (0.5,)),
198+
(cirq.PhaseFlipChannel, (0.3,)),
199+
(cirq.DepolarizingChannel, (0.2,)),
200+
(cirq.AmplitudeDampingChannel, (0.4,)),
201+
(cirq.PhaseDampingChannel, (0.25,)),
202+
],
203+
)
204+
def test_kraus_fallback_to_apply_channel(channel_cls, params) -> None:
205+
"""Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_."""
206+
# Create the expected channel and get its Kraus operators
207+
expected_channel = channel_cls(*params)
208+
expected_kraus = cirq.kraus(expected_channel)
209+
210+
class TestChannel:
211+
def __init__(self, channel_cls, params):
212+
self.channel_cls = channel_cls
213+
self.params = params
214+
self.expected_kraus = cirq.kraus(channel_cls(*params))
215+
216+
def _num_qubits_(self):
217+
return 1
218+
219+
def _apply_channel_(self, args: cirq.ApplyChannelArgs):
220+
return _apply_kraus(self.expected_kraus, args)
221+
222+
chan = TestChannel(channel_cls, params)
223+
kraus_ops = cirq.kraus(chan)
224+
225+
# Compare the superoperator matrices for equivalence
226+
expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus)
227+
actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops)
228+
np.testing.assert_allclose(actual_super, expected_super, atol=1e-8)
229+
230+
231+
def test_reset_channel_kraus_apply_channel_consistency():
232+
Reset = cirq.ResetChannel
233+
# Original gate
234+
gate = Reset()
235+
cirq.testing.assert_has_consistent_apply_channel(gate)
236+
cirq.testing.assert_consistent_channel(gate)
237+
238+
# Remove _kraus_ method
239+
class NoKrausReset(Reset):
240+
def _kraus_(self):
241+
return NotImplemented
242+
243+
gate_no_kraus = NoKrausReset()
244+
# Should still match the original superoperator
245+
np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8)

cirq-core/cirq/testing/circuit_compare.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def assert_has_consistent_apply_channel(val: Any, *, atol: float = 1e-8) -> None
336336
atol: Absolute error tolerance.
337337
"""
338338
__tracebackhide__ = True
339+
assert hasattr(val, '_apply_channel_')
339340

340341
kraus = protocols.kraus(val, default=None)
341342
expected = qis.kraus_to_superoperator(kraus) if kraus is not None else None

0 commit comments

Comments
 (0)