Skip to content

Commit 2576dbb

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (pytorch#131725)
Fixes pytorch#112794. Pull Request resolved: pytorch#131725 Approved by: https://github.com/anijain2305 ghstack dependencies: pytorch#131413, pytorch#131716
1 parent 35b4de3 commit 2576dbb

File tree

5 files changed

+89
-15
lines changed

5 files changed

+89
-15
lines changed

test/dynamo/test_functions.py

+39
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,45 @@ def test_map_unpack_twice(a, b):
29872987
l2 = list(m)
29882988
return l1, l2
29892989

2990+
@make_test
2991+
def test_enumerate(a, b):
2992+
return list(enumerate([a, b], start=1)), a + 1
2993+
2994+
@make_test
2995+
def test_map_enumerate(a, b):
2996+
return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1
2997+
2998+
def test_enumerate_custom(self):
2999+
class MyClass:
3000+
def __iter__(self):
3001+
self.a = 1
3002+
return self
3003+
3004+
def __next__(self):
3005+
if self.a > 3:
3006+
raise StopIteration
3007+
self.a += 1
3008+
return self.a
3009+
3010+
def fn(x):
3011+
for i, it in enumerate(MyClass()):
3012+
x += i + it
3013+
return x
3014+
3015+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3016+
self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3)))
3017+
3018+
def test_enumerate_reconstruct(self):
3019+
def fn(a, b):
3020+
return enumerate([a, b], start=1)
3021+
3022+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
3023+
inps = (torch.randn(3, 3), torch.randn(3, 3))
3024+
it1 = fn(*inps)
3025+
it2 = opt_fn(*inps)
3026+
self.assertIsInstance(it2, enumerate)
3027+
self.assertEqual(list(it1), list(it2))
3028+
29903029

29913030
def udf_mul(x, y):
29923031
return x * y

torch/_dynamo/polyfill.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def any(iterator):
2424

2525

2626
def index(iterator, item, start=0, end=None):
27-
for i, elem in enumerate(list(iterator))[start:end]:
27+
for i, elem in enumerate(list(iterator)[start:end], start):
2828
if item == elem:
2929
return i
3030
# This will not run in dynamo
@@ -124,3 +124,10 @@ def getattr_and_trace(*args, **kwargs):
124124
attr_name = args[1]
125125
fn = getattr(wrapper_obj, attr_name)
126126
return fn(*args[2:], **kwargs)
127+
128+
129+
def enumerate(iterable, start=0):
130+
n = start
131+
for elem in iterable:
132+
yield n, elem
133+
n += 1

torch/_dynamo/variables/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .iter import (
4343
CountIteratorVariable,
4444
CycleIteratorVariable,
45+
EnumerateVariable,
4546
IteratorVariable,
4647
ItertoolsVariable,
4748
MapVariable,

torch/_dynamo/variables/builtin.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -1419,22 +1419,29 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs):
14191419
]
14201420
return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal())
14211421

1422-
def call_enumerate(self, tx: "InstructionTranslator", *args):
1423-
if len(args) == 1:
1422+
_call_enumerate_polyfill = _polyfill_call_impl("enumerate")
1423+
1424+
def call_enumerate(self, tx: "InstructionTranslator", iterable, start=_SENTINEL):
1425+
if start is self._SENTINEL:
14241426
start = 0
14251427
else:
1426-
assert len(args) == 2
1427-
assert isinstance(args[1], variables.ConstantVariable)
1428-
start = args[1].as_python_constant()
1429-
if args[0].has_unpack_var_sequence(tx):
1430-
items = [
1431-
variables.TupleVariable(
1432-
[variables.ConstantVariable.create(idx), var],
1433-
)
1434-
for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
1435-
]
1436-
return variables.TupleVariable(items)
1437-
# could have an iterable version
1428+
assert isinstance(start, variables.ConstantVariable)
1429+
start = start.as_python_constant()
1430+
1431+
if iterable.has_unpack_var_sequence(tx):
1432+
return variables.EnumerateVariable(
1433+
iterable.unpack_var_sequence(tx),
1434+
start,
1435+
mutable_local=MutableLocal(),
1436+
)
1437+
elif isinstance(iterable, variables.IteratorVariable):
1438+
return variables.EnumerateVariable(
1439+
iterable, start, mutable_local=MutableLocal()
1440+
)
1441+
1442+
return self._call_enumerate_polyfill(
1443+
tx, iterable, variables.ConstantVariable.create(start)
1444+
)
14381445

14391446
def call_len(self, tx: "InstructionTranslator", *args, **kwargs):
14401447
return args[0].call_method(tx, "__len__", args[1:], kwargs)

torch/_dynamo/variables/iter.py

+20
Original file line numberDiff line numberDiff line change
@@ -480,3 +480,23 @@ def reconstruct(self, codegen):
480480
create_instruction("CALL_FUNCTION_EX", arg=0),
481481
]
482482
)
483+
484+
485+
class EnumerateVariable(ZipVariable):
486+
def __init__(
487+
self,
488+
iterable: Union[List[VariableTracker], VariableTracker],
489+
start: int = 0,
490+
**kwargs,
491+
):
492+
super().__init__(
493+
[CountIteratorVariable(start, mutable_local=MutableLocal()), iterable],
494+
**kwargs,
495+
)
496+
497+
def reconstruct(self, codegen):
498+
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "enumerate"))
499+
codegen(self.iterables[1])
500+
assert isinstance(self.iterables[0], CountIteratorVariable)
501+
codegen(self.iterables[0].item)
502+
codegen.extend_output(codegen.create_call_function_kw(2, ("start",), False))

0 commit comments

Comments
 (0)