Commit 23b7ee2
Use grad context for hashing the generated stablehlo program (#1604)
**Context:** After PR #1562, a single function can have multiple JAXPR representations depending on whether it is traced under a grad context. This made the previous cache key, based on the function id alone, prone to collisions. To address this, we hashed the JAXPR string representation instead. (We cannot hash the JAXPR object itself, since each instance is unique.) However, the JAXPR string representation can be very long, and hashing long strings takes a long time.

**Description of the Change:** Instead of hashing the string representation, add a simple key component that records whether the function is inside a grad context.

**Benefits:** Reduced compilation time.

**Possible Drawbacks:** The cache key is getting more complicated. Maybe the drawbacks outweigh the benefits now?

**Related GitHub Issues:** [sc-88454]
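As a rough illustration of the trade-off, here is a minimal sketch of the caching idea; `compile_program` and `get_or_create` are hypothetical stand-ins, not Catalyst's actual API:

```python
# A minimal sketch of the caching idea (hypothetical names, not Catalyst's API).
cache = {}

def compile_program(fn):
    """Stand-in for lowering a function to a StableHLO program."""
    return object()  # pretend this is an expensive compilation

def get_or_create(fn, in_grad_context):
    # Hashing a small tuple is O(1) in program size; hashing the full
    # JAXPR string representation is O(n) and was the bottleneck.
    key = (fn, in_grad_context)
    if key not in cache:
        cache[key] = compile_program(fn)
    return cache[key]

def f(x):
    return x * x

get_or_create(f, in_grad_context=False)
get_or_create(f, in_grad_context=True)
assert len(cache) == 2  # one entry per (function, grad-context) pair
```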
1 parent efe2fa3 · commit 23b7ee2

3 files changed: +18 -13 lines changed

doc/releases/changelog-0.11.0.md

Lines changed: 1 addition & 0 deletions
@@ -193,6 +193,7 @@
   [(#1562)](https://github.com/PennyLaneAI/catalyst/pull/1562)
   [(#1568)](https://github.com/PennyLaneAI/catalyst/pull/1568)
   [(#1569)](https://github.com/PennyLaneAI/catalyst/pull/1569)
+  [(#1604)](https://github.com/PennyLaneAI/catalyst/pull/1604)
 
   Gates that are constant, such as when all parameters are Python or NumPy data types, are not
   decomposed when this is allowable. For the adjoint differentiation method, this is allowable

frontend/catalyst/jax_primitives.py

Lines changed: 4 additions & 4 deletions
@@ -510,7 +510,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
     new_argnums = [num + offset for num in argnums]
     argnum_numpy = np.array(new_argnums)
     diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
-    func_op = lower_jaxpr(ctx, jaxpr)
+    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))
 
     symbol_ref = get_symbolref(ctx, func_op)
     output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
@@ -585,7 +585,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
     val_result_types = flat_output_types[: len(flat_output_types) - len(argnums)]
     gradient_result_types = flat_output_types[len(flat_output_types) - len(argnums) :]
 
-    func_op = lower_jaxpr(ctx, jaxpr)
+    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))
 
     symbol_ref = get_symbolref(ctx, func_op)
     return ValueAndGradOp(
@@ -635,7 +635,7 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
     func_args = consts_and_args[: len(func_call_jaxpr.invars)]
     tang_args = consts_and_args[len(func_call_jaxpr.invars) :]
 
-    func_op = lower_jaxpr(ctx, jaxpr)
+    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))
 
     assert (
         len(flat_output_types) % 2 == 0
@@ -688,7 +688,7 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
     func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)]
     vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :]
 
-    func_op = lower_jaxpr(ctx, jaxpr)
+    func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums))
 
     symbol_ref = get_symbolref(ctx, func_op)
     return VJPOp(
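Each grad-style lowering above now threads its differentiation parameters into the cache as a context tuple. A minimal sketch of the idea (the values below are illustrative; in the real code `method`, `h`, and `argnums` come from `grad_params`):

```python
# Illustrative values only; in Catalyst these come from grad_params.
method = "fd"      # differentiation method
h = 1e-4           # finite-difference step size
argnums = [0, 2]   # which arguments are differentiated

# The context tuple distinguishes otherwise-identical lowerings, so two
# grads of the same function with different settings do not share a FuncOp.
context = (method, h, *argnums)   # -> ("fd", 0.0001, 0, 2)

# func_op = lower_jaxpr(ctx, jaxpr, context)   # as in the diff above
```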

frontend/catalyst/jax_primitives_utils.py

Lines changed: 13 additions & 9 deletions
@@ -47,18 +47,18 @@ def get_call_equation(jaxpr):
     raise AssertionError("No call_jaxpr found in the JAXPR.")
 
 
-def lower_jaxpr(ctx, jaxpr):
+def lower_jaxpr(ctx, jaxpr, context=None):
     """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p"""
     equation = get_call_equation(jaxpr)
     call_jaxpr = equation.params["call_jaxpr"]
     callable_ = equation.params.get("fn")
     if callable_ is None:
         callable_ = equation.params.get("qnode")
     pipeline = equation.params.get("pipeline")
-    return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline)
+    return lower_callable(ctx, callable_, call_jaxpr, pipeline=pipeline, context=context)
 
 
-def lower_callable(ctx, callable_, call_jaxpr, pipeline=None):
+def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None):
     """Lowers _callable to MLIR.
 
     If callable_ is a qnode, then we will first create a module, then
@@ -77,14 +77,16 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None):
         pipeline = tuple()
 
     if not isinstance(callable_, qml.QNode):
-        return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline)
+        return get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
 
-    return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline)
+    return get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context=context)
 
 
-def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline):
+def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, context=None):
     """Get funcOp from cache, or create it from scratch"""
-    key = (str(call_jaxpr), *pipeline)
+    if context is None:
+        context = tuple()
+    key = (callable_, *context, *pipeline)
     if func_op := get_cached(ctx, key):
         return func_op
     func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr)
@@ -123,7 +125,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr):
     return func_op
 
 
-def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline):
+def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
    """A wrapper around lower_qnode_to_funcop that will cache the FuncOp.
 
     Args:
@@ -133,7 +135,9 @@
     Returns:
         FuncOp
     """
-    key = (str(call_jaxpr), *pipeline)
+    if context is None:
+        context = tuple()
+    key = (callable_, *context, *pipeline)
     if func_op := get_cached(ctx, key):
         return func_op
     func_op = lower_qnode_to_funcop(ctx, callable_, call_jaxpr, pipeline)
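The key construction above can be summarized in isolation. A self-contained sketch (the helper `make_cache_key` is hypothetical; the real code builds the key inline in `get_or_create_funcop` and `get_or_create_qnode_funcop`):

```python
# Hypothetical helper mirroring the key construction in the diff above.
# A short tuple of hashable parts is cheap to hash, unlike str(call_jaxpr),
# which can be arbitrarily long.
def make_cache_key(callable_, pipeline=None, context=None):
    if pipeline is None:
        pipeline = tuple()
    if context is None:
        context = tuple()
    return (callable_, *context, *pipeline)

def f(x):
    return x + 1

# The same callable under different grad contexts yields different keys,
# so the cached FuncOps can no longer collide.
assert make_cache_key(f, context=("fd", 1e-4, 0)) != make_cache_key(f)
```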
