|
22 | 22 | import pytest
|
23 | 23 |
|
24 | 24 | import cirq
|
| 25 | +from cirq.protocols.apply_channel_protocol import _apply_kraus |
25 | 26 |
|
26 | 27 | LOCAL_DEFAULT: list[np.ndarray] = [np.array([])]
|
27 | 28 |
|
@@ -171,3 +172,74 @@ def test_has_kraus_when_decomposed(decomposed_cls) -> None:
|
171 | 172 | op = HasKrausWhenDecomposed(decomposed_cls).on(cirq.NamedQubit('test'))
|
172 | 173 | assert cirq.has_kraus(op)
|
173 | 174 | assert not cirq.has_kraus(op, allow_decompose=False)
|
| 175 | + |
| 176 | + |
| 177 | +def test_strat_kraus_from_apply_channel_returns_none(): |
| 178 | + # Remove _kraus_ and _apply_channel_ methods |
| 179 | + class NoApplyChannelReset(cirq.ResetChannel): |
| 180 | + def _kraus_(self): |
| 181 | + return NotImplemented |
| 182 | + |
| 183 | + def _apply_channel_(self, args): |
| 184 | + return NotImplemented |
| 185 | + |
| 186 | + gate_no_apply = NoApplyChannelReset() |
| 187 | + with pytest.raises( |
| 188 | + TypeError, |
| 189 | + match="does have a _kraus_, _mixture_ or _unitary_ method, but it returned NotImplemented", |
| 190 | + ): |
| 191 | + cirq.kraus(gate_no_apply) |
| 192 | + |
| 193 | + |
| 194 | +@pytest.mark.parametrize( |
| 195 | + 'channel_cls,params', |
| 196 | + [ |
| 197 | + (cirq.BitFlipChannel, (0.5,)), |
| 198 | + (cirq.PhaseFlipChannel, (0.3,)), |
| 199 | + (cirq.DepolarizingChannel, (0.2,)), |
| 200 | + (cirq.AmplitudeDampingChannel, (0.4,)), |
| 201 | + (cirq.PhaseDampingChannel, (0.25,)), |
| 202 | + ], |
| 203 | +) |
| 204 | +def test_kraus_fallback_to_apply_channel(channel_cls, params) -> None: |
| 205 | + """Kraus protocol falls back to _apply_channel_ when no _kraus_, _mixture_, or _unitary_.""" |
| 206 | + # Create the expected channel and get its Kraus operators |
| 207 | + expected_channel = channel_cls(*params) |
| 208 | + expected_kraus = cirq.kraus(expected_channel) |
| 209 | + |
| 210 | + class TestChannel: |
| 211 | + def __init__(self, channel_cls, params): |
| 212 | + self.channel_cls = channel_cls |
| 213 | + self.params = params |
| 214 | + self.expected_kraus = cirq.kraus(channel_cls(*params)) |
| 215 | + |
| 216 | + def _num_qubits_(self): |
| 217 | + return 1 |
| 218 | + |
| 219 | + def _apply_channel_(self, args: cirq.ApplyChannelArgs): |
| 220 | + return _apply_kraus(self.expected_kraus, args) |
| 221 | + |
| 222 | + chan = TestChannel(channel_cls, params) |
| 223 | + kraus_ops = cirq.kraus(chan) |
| 224 | + |
| 225 | + # Compare the superoperator matrices for equivalence |
| 226 | + expected_super = sum(np.kron(k, k.conj()) for k in expected_kraus) |
| 227 | + actual_super = sum(np.kron(k, k.conj()) for k in kraus_ops) |
| 228 | + np.testing.assert_allclose(actual_super, expected_super, atol=1e-8) |
| 229 | + |
| 230 | + |
| 231 | +def test_reset_channel_kraus_apply_channel_consistency(): |
| 232 | + Reset = cirq.ResetChannel |
| 233 | + # Original gate |
| 234 | + gate = Reset() |
| 235 | + cirq.testing.assert_has_consistent_apply_channel(gate) |
| 236 | + cirq.testing.assert_consistent_channel(gate) |
| 237 | + |
| 238 | + # Remove _kraus_ method |
| 239 | + class NoKrausReset(Reset): |
| 240 | + def _kraus_(self): |
| 241 | + return NotImplemented |
| 242 | + |
| 243 | + gate_no_kraus = NoKrausReset() |
| 244 | + # Should still match the original superoperator |
| 245 | + np.testing.assert_allclose(cirq.kraus(gate), cirq.kraus(gate_no_kraus), atol=1e-8) |
0 commit comments