Skip to content

Commit ca662f3

Browse files
committed
change dispatch key to a flag decomposed
1 parent 9860c56 commit ca662f3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchao/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def find_multiple(n: int, *args: int) -> int:
179179
return n + k - (n % k)
180180

181181

182-
def _register_custom_op(lib, implicit=True):
182+
def _register_custom_op(lib, decomposed=True):
183183
"""This decorator is used to preserve some high level operators for torch.export.export
184184
while still allow them to be decomposed for inductor path
185185
@@ -207,7 +207,7 @@ def _the_op_that_needs_to_be_preserved(...)
207207
from torch._inductor.decomposition import register_decomposition
208208

209209
dispatch_key = (
210-
"CompositeImplicitAutograd" if implicit else "CompositeExplicitAutograd"
210+
"CompositeImplicitAutograd" if decomposed else "CompositeExplicitAutograd"
211211
)
212212

213213
def decorator(fn):
@@ -229,7 +229,7 @@ def decorator(fn):
229229

230230
lib_namespace = lib.ns
231231
op = getattr(getattr(torch.ops, lib_namespace), op_name)
232-
if implicit:
232+
if decomposed:
233233
register_decomposition([op])(fn)
234234
return op
235235
else:

0 commit comments

Comments
 (0)