Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,34 @@

<h3>Improvements 🛠</h3>

* `catalyst.cond` can accept branch functions with arguments, e.g

```python
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2))
def func():
qml.PauliX(wires=1) # |01>
m0 = catalyst.measure(0) # will measure 0

@catalyst.cond(m0 == 1)
def conditional(wire):
qml.PauliX(wires=wire)

@conditional.otherwise
def false_fn(wire): # will come here
qml.RX(1.23, wires=wire+1)

conditional(0)

return qml.probs()

print(func())
```
```
[0.33288114 0.66711886 0. 0. ]
```
[(#1531)](https://github.com/PennyLaneAI/catalyst/pull/1531)

* Changed pattern rewritting in `quantum-to-ion` lowering pass to use MLIR's dialect conversion
infrastracture.
[(#1442)](https://github.com/PennyLaneAI/catalyst/pull/1442)
Expand Down
91 changes: 47 additions & 44 deletions frontend/catalyst/api_extensions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def cond(pred: DynamicJaxprTracer):
Returns:
A callable decorator that wraps the first 'if' branch of the conditional.

Raises:
AssertionError: Branch functions cannot have arguments.

**Example**

.. code-block:: python
Expand Down Expand Up @@ -240,14 +237,8 @@ def conditional_fn():
def _decorator(true_fn: Callable):

if len(inspect.signature(true_fn).parameters):
if isinstance(true_fn, type) and issubclass(true_fn, qml.operation.Operation):
# Special treatment if conditional function body is a single pennylane gate
# The qml.operation.Operation base class represents things that
# can reasonably be considered as a gate,
# e.g. qml.Hadamard, qml.RX, etc.
return CondCallableSingleGateHandler(pred, true_fn)
else:
raise TypeError("Conditional 'True' function is not allowed to have any arguments")
# Special treatment if conditional function body has arguments
return CondCallableArgumentsHandler(pred, true_fn)

return CondCallable(pred, true_fn)

Expand Down Expand Up @@ -761,45 +752,67 @@ def __call__(self):
return self._call_during_interpretation()


class CondCallableSingleGateHandler(CondCallable):
class CondCallableArgumentsHandler(CondCallable):
"""
Special CondCallable when the conditional body function is a single pennylane gate.
Special CondCallable when the conditional body function has arguments.

A usual pennylane conditional call for a gate looks like
For example, a usual pennylane conditional call for a gate looks like
`qml.cond(x == 42, qml.RX)(theta, wires=0)`

Since gates are guaranteed to take in arguments (at the very least the wire argument),
the usual CondCallable class, which expects the conditional body function to have no arguments,
cannot be used.
This class inherits from base CondCallable, but wraps the gate in a function with no arguments,
and sends that function to CondCallable.
This allows us to perform the conditional branch gate function with arguments.

This class inherits from base CondCallable, but wraps the branch function in a function with
no arguments, and sends that function to CondCallable.
This allows us to perform the conditional branch function with arguments.
"""

def __init__(self, pred, true_fn): # pylint:disable=super-init-not-called
self.sgh_preds = [pred]
self.sgh_branch_fns = [true_fn]
self.sgh_otherwise_fn = None
self.ccah_preds = [pred]
self.ccah_branch_fns = [true_fn]
self.ccah_otherwise_fn = None

def __call__(self, *args, **kwargs):
def argless_true_fn():
self.sgh_branch_fns[0](*args, **kwargs)
def argless_true_fn(): # pylint:disable=inconsistent-return-statements
# Special treatment if conditional function body is a single pennylane gate
# In such cases, the gate function should only be called, but not returned.
# Note: The qml.operation.Operation base class represents things that
# can reasonably be considered as a gate,
# e.g. qml.Hadamard, qml.RX, etc.
if isinstance(self.ccah_branch_fns[0], type) and issubclass(
self.ccah_branch_fns[0], qml.operation.Operation
):
self.ccah_branch_fns[0](*args, **kwargs)
else:
return self.ccah_branch_fns[0](*args, **kwargs)

super().__init__(self.sgh_preds[0], argless_true_fn)
super().__init__(self.ccah_preds[0], argless_true_fn)

if self.sgh_otherwise_fn is not None:
if self.ccah_otherwise_fn is not None:

def argless_otherwise_fn():
self.sgh_otherwise_fn(*args, **kwargs)
def argless_otherwise_fn(): # pylint:disable=inconsistent-return-statements
if isinstance(self.ccah_otherwise_fn, type) and issubclass(
self.ccah_otherwise_fn, qml.operation.Operation
):
self.ccah_otherwise_fn(*args, **kwargs)
else:
return self.ccah_otherwise_fn(*args, **kwargs)

super().set_otherwise_fn(argless_otherwise_fn)

for i in range(1, len(self.sgh_branch_fns)):
for i in range(1, len(self.ccah_branch_fns)):

def argless_elseif_fn(i=i): # i=i to work around late binding
self.sgh_branch_fns[i](*args, **kwargs)
def argless_elseif_fn(i=i): # pylint:disable=inconsistent-return-statements
# i=i to work around late binding
if isinstance(self.ccah_branch_fns[i], type) and issubclass(
self.ccah_branch_fns[i], qml.operation.Operation
):
self.ccah_branch_fns[i](*args, **kwargs)
else:
return self.ccah_branch_fns[i](*args, **kwargs)

super().add_pred(self.sgh_preds[i])
super().add_pred(self.ccah_preds[i])
super().add_branch_fn(argless_elseif_fn)

return super().__call__()
Expand All @@ -810,27 +823,17 @@ def else_if(self, _pred):
"""

def decorator(branch_fn):
if isinstance(branch_fn, type) and issubclass(branch_fn, qml.operation.Operation):
self.sgh_preds.append(_pred)
self.sgh_branch_fns.append(branch_fn)
return self
else: # pylint:disable=line-too-long
raise TypeError(
"Conditional 'else if' function can have arguments only if it is a PennyLane gate."
)
self.ccah_preds.append(_pred)
self.ccah_branch_fns.append(branch_fn)
return self

return decorator

def otherwise(self, otherwise_fn):
"""
Override the "can't have arguments" check in the original CondCallable's `otherwise`
"""
if isinstance(otherwise_fn, type) and issubclass(otherwise_fn, qml.operation.Operation):
self.sgh_otherwise_fn = otherwise_fn
else:
raise TypeError(
"Conditional 'False' function can have arguments only if it is a PennyLane gate."
)
self.ccah_otherwise_fn = otherwise_fn


class ForLoopCallable:
Expand Down
22 changes: 22 additions & 0 deletions frontend/test/pytest/test_capture_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,28 @@ def ansatz_false():
experimental_capture_result = qml.qjit(circuit, experimental_capture=True)(0.1)
assert default_capture_result == experimental_capture_result

def test_cond_workflow_if_else_args(self, backend):
"""Test the integration for a circuit with a cond primitive with true and false branches
with args."""

@qml.qnode(qml.device(backend, wires=1))
def circuit(x: float):

def ansatz_true(wire):
qml.RX(x, wires=wire)
qml.Hadamard(wires=wire)

def ansatz_false(wire):
qml.RY(x, wires=wire)

qml.cond(x > 1.4, ansatz_true, ansatz_false)(0)

return qml.expval(qml.Z(0))

default_capture_result = qml.qjit(circuit)(0.1)
experimental_capture_result = qml.qjit(circuit, experimental_capture=True)(0.1)
assert default_capture_result == experimental_capture_result

def test_cond_workflow_if(self, backend):
"""Test the integration for a circuit with a cond primitive with a true branch only."""

Expand Down
142 changes: 75 additions & 67 deletions frontend/test/pytest/test_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,54 @@ def else_fn():
assert circuit(5) == 25
assert circuit(6) == 36

def test_simple_cond_with_args(self, backend):
"""Test basic function with conditional, with the branch functions having arguments."""

@qjit
@qml.qnode(qml.device(backend, wires=1))
def circuit(x):
@cond(x > 4.8)
def cond_fn(multiplier):
return x * multiplier

@cond_fn.else_if(x > 2.7)
def cond_elif(multiplier):
return x * multiplier + 1

@cond_fn.else_if(x > 1.4)
def cond_elif2(multiplier):
return x * multiplier + 2

@cond_fn.otherwise
def cond_else(multiplier):
return x * multiplier - 10

return cond_fn(8)

assert circuit(5) == 40
assert circuit(3) == 25
assert circuit(2) == 18
assert circuit(-3) == -34

def test_simple_cond_multiple_args(self, backend):
"""Test function with conditional, with the branch function having multiple arguments."""

@qjit
@qml.qnode(qml.device(backend, wires=1))
def circuit(x):
@cond(x > 4.8)
def cond_fn(multiplier, adder):
return x * multiplier + adder

@cond_fn.otherwise
def cond_else(multiplier, adder):
return (x * multiplier + adder) * 0

return cond_fn(8, 42)

assert circuit(5) == 82
assert circuit(-3) == 0

def test_cond_one_else_if(self, backend):
"""Test a cond with one else_if branch"""

Expand Down Expand Up @@ -416,42 +464,6 @@ def conditional_flip():
assert circuit(False) == 0
assert circuit(True) == 1

def test_argument_error_with_callables(self):
"""Test for the error when arguments are supplied and the target is not a function."""

def f(x: int):

res = qml.cond(x < 5, lambda z: z + 1)(0)

return res

with pytest.raises(TypeError, match="not allowed to have any arguments"):
qjit(f)

def g(x: int):

res = qml.cond(x < 5, qml.Hadamard, lambda z: z + 1)(0)

return res

with pytest.raises(
TypeError,
match="Conditional 'False' function can have arguments only if it is a PennyLane gate.",
):
qjit(g)

def h(x: int):

res = qml.cond(x < 5, qml.Hadamard, qml.Hadamard, ((x < 6, lambda z: z + 1),))(0)

return res

with pytest.raises(
TypeError,
match="Conditional 'else if' function can have arguments only if it is a PennyLane gate.", # pylint:disable=line-too-long
):
qjit(h)


class TestInterpretationConditional:
"""Test that the conditional operation's execution is semantically equivalent
Expand Down Expand Up @@ -574,37 +586,6 @@ def arithi(x, y, op1, op2):

assert arithi(x, y, op1, op2) == arithc(x, y, op1, op2)

def test_no_true_false_parameters(self):
"""Test non-empty parameter detection in conditionals"""

def arithc2():
@cond(True)
def branch(_):
return 1

@branch.otherwise
def branch():
return 0

return branch()

with pytest.raises(TypeError, match="Conditional 'True'"):
qjit(arithc2)

def arithc1():
@cond(True)
def branch():
return 1

@branch.otherwise
def branch(_):
return 0

return branch() # pylint: disable=no-value-for-parameter

with pytest.raises(TypeError, match="Conditional 'False'"):
qjit(arithc1)


class TestCondOperatorAccess:
"""Test suite for accessing the Cond operation in quantum contexts in Catalyst."""
Expand Down Expand Up @@ -742,6 +723,33 @@ def func(x, y):
assert np.allclose(expected_2, observed_2)
assert np.allclose(expected_3, observed_3)

def test_cond_measurement(self, backend):
"""
Test conditionals with measurements being the predicate.
"""

@qjit
@qml.qnode(qml.device(backend, wires=2))
def func():
qml.PauliX(wires=1) # |01>
m0 = measure(0) # will measure 0

@cond(m0 == 1)
def conditional(wire): # should not be triggered
qml.RX(1.23, wires=wire + 1)

@conditional.otherwise
def false_fn(wire): # will come here
qml.PauliX(wires=wire + 1)

conditional(0)

return qml.probs()

observed = func()

assert np.allclose(observed, np.array([1, 0, 0, 0]))


class TestCondPredicateConversion:
"""Test suite for checking predicate conversion to bool."""
Expand Down
Loading