Skip to content

Commit e296bab

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo] Remove DICT_SUBCLASS_GUARD_MANAGER and use dict.keys (pytorch#143722)
In hinsight, we never needed a DICT_SUBCLASS_GUARD_MANAGER, because Dynamo would inline through the overridden keys method. In this PR, we ensure that while creating guards and constructing variable trackers, we get the `d.keys()` value by using `dict.keys(d)`. This ensures that we do not call overridden keys method. Therefore, the C++ guard can use `PyDict_Next` directly to check the guards. Pull Request resolved: pytorch#143722 Approved by: https://github.com/jansel
1 parent d60282c commit e296bab

File tree

9 files changed

+235
-366
lines changed

9 files changed

+235
-366
lines changed

test/dynamo/test_dicts.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# flake8: noqa
55

66
import dataclasses
7+
import unittest
78
from collections import OrderedDict
89
from dataclasses import dataclass, fields, is_dataclass
910
from typing import Any, Optional, Tuple
@@ -174,6 +175,149 @@ def fn(x):
174175
self.assertEqual(ref["x"], res["x"])
175176
self.assertEqual(ref["y"], res["y"])
176177

178+
def test_custom_iter_dict(self):
179+
class ReversedDict(dict):
180+
def __iter__(self):
181+
return reversed(list(self.keys()))
182+
183+
d = {
184+
"foo": 1,
185+
"bar": 2,
186+
}
187+
188+
d = ReversedDict(d)
189+
190+
@torch.compile(backend="eager")
191+
def fn(x, d):
192+
return x * d["foo"] * d["bar"]
193+
194+
fn(torch.randn(4), d)
195+
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
196+
fn(torch.randn(4), d)
197+
198+
def test_custom_keys_iter_dict(self):
199+
class ReversedDict(dict):
200+
def keys(self):
201+
return ["bar", "foo"]
202+
203+
d = {
204+
"foo": 1,
205+
"bar": 2,
206+
}
207+
208+
d = ReversedDict(d)
209+
210+
@torch.compile(backend="eager")
211+
def fn(x, d):
212+
return x * d["foo"] * d["bar"]
213+
214+
fn(torch.randn(4), d)
215+
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
216+
fn(torch.randn(4), d)
217+
218+
def test_dict_guard_on_keys_order(self):
219+
d = {
220+
2: 4,
221+
3: 5,
222+
}
223+
224+
cnts = torch._dynamo.testing.CompileCounter()
225+
226+
def fn(x, d):
227+
for key, value in d.items():
228+
x = x * key + value
229+
return x
230+
231+
opt_fn = torch.compile(fn, backend=cnts)
232+
opt_fn(torch.randn(4), d)
233+
opt_fn(torch.randn(4), d)
234+
# No recompilation
235+
self.assertEqual(cnts.frame_count, 1)
236+
237+
# move 2 to the end
238+
d[2] = d.pop(2)
239+
240+
x = torch.randn(4)
241+
res = opt_fn(x, d)
242+
# Check recompilation
243+
self.assertEqual(cnts.frame_count, 2)
244+
self.assertEqual(res, fn(x, d))
245+
246+
def test_dict_guard_on_keys_order2(self):
247+
d = {
248+
2: 4,
249+
3: 5,
250+
}
251+
252+
cnts = torch._dynamo.testing.CompileCounter()
253+
254+
def fn(x, d):
255+
for key in d:
256+
value = d[key]
257+
x = x * key + value
258+
return x
259+
260+
opt_fn = torch.compile(fn, backend=cnts)
261+
opt_fn(torch.randn(4), d)
262+
opt_fn(torch.randn(4), d)
263+
# No recompilation
264+
self.assertEqual(cnts.frame_count, 1)
265+
266+
# move 2 to the end
267+
d[2] = d.pop(2)
268+
269+
x = torch.randn(4)
270+
res = opt_fn(x, d)
271+
# Check recompilation
272+
self.assertEqual(cnts.frame_count, 2)
273+
self.assertEqual(res, fn(x, d))
274+
275+
def test_ordered_dict_reordered_keys(self):
276+
d = OrderedDict()
277+
d[2] = 4
278+
d[3] = 5
279+
d.move_to_end(2)
280+
281+
cnts = torch._dynamo.testing.CompileCounter()
282+
283+
def fn(x, d):
284+
y = 0
285+
for idx, (key, value) in enumerate(d.items()):
286+
if idx == 0:
287+
y += torch.sin(x * value)
288+
else:
289+
y += torch.cos(x * value)
290+
return y
291+
292+
opt_fn = torch.compile(fn, backend=cnts)
293+
x = torch.randn(4)
294+
self.assertEqual(opt_fn(x, d), fn(x, d))
295+
296+
def test_ordered_dict_subclass_reordered_keys(self):
297+
class ODSubclass(OrderedDict):
298+
def keys():
299+
return super().keys()
300+
301+
d = ODSubclass()
302+
d[2] = 4
303+
d[3] = 5
304+
d.move_to_end(2)
305+
306+
cnts = torch._dynamo.testing.CompileCounter()
307+
308+
def fn(x, d):
309+
y = 0
310+
for idx, (key, value) in enumerate(d.items()):
311+
if idx == 0:
312+
y += torch.sin(x * value)
313+
else:
314+
y += torch.cos(x * value)
315+
return y
316+
317+
opt_fn = torch.compile(fn, backend=cnts)
318+
x = torch.randn(4)
319+
self.assertEqual(opt_fn(x, d), fn(x, d))
320+
177321

178322
def is_tensor(x):
179323
import torch

test/dynamo/test_guard_manager.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
RootGuardManager = guards.RootGuardManager
1616
DictGuardManager = guards.DictGuardManager
17-
DictSubclassGuardManager = guards.DictSubclassGuardManager
1817
GetAttrGuardAccessor = guards.GetAttrGuardAccessor
1918
GetItemGuardAccessor = guards.GetItemGuardAccessor
2019
TypeGuardAccessor = guards.TypeGuardAccessor

test/dynamo/test_misc.py

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -11466,103 +11466,6 @@ def f(mask, box):
1146611466

1146711467
f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda"))
1146811468

11469-
def test_custom_iter_dict(self):
11470-
class ReversedDict(dict):
11471-
def __iter__(self):
11472-
return reversed(list(self.keys()))
11473-
11474-
d = {
11475-
"foo": 1,
11476-
"bar": 2,
11477-
}
11478-
11479-
d = ReversedDict(d)
11480-
11481-
@torch.compile(backend="eager")
11482-
def fn(x, d):
11483-
return x * d["foo"] * d["bar"]
11484-
11485-
fn(torch.randn(4), d)
11486-
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
11487-
fn(torch.randn(4), d)
11488-
11489-
def test_custom_keys_iter_dict(self):
11490-
class ReversedDict(dict):
11491-
def keys(self):
11492-
return ["bar", "foo"]
11493-
11494-
d = {
11495-
"foo": 1,
11496-
"bar": 2,
11497-
}
11498-
11499-
d = ReversedDict(d)
11500-
11501-
@torch.compile(backend="eager")
11502-
def fn(x, d):
11503-
return x * d["foo"] * d["bar"]
11504-
11505-
fn(torch.randn(4), d)
11506-
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
11507-
fn(torch.randn(4), d)
11508-
11509-
def test_dict_guard_on_keys_order(self):
11510-
d = {
11511-
2: 4,
11512-
3: 5,
11513-
}
11514-
11515-
cnts = torch._dynamo.testing.CompileCounter()
11516-
11517-
def fn(x, d):
11518-
for key, value in d.items():
11519-
x = x * key + value
11520-
return x
11521-
11522-
opt_fn = torch.compile(fn, backend=cnts)
11523-
opt_fn(torch.randn(4), d)
11524-
opt_fn(torch.randn(4), d)
11525-
# No recompilation
11526-
self.assertEqual(cnts.frame_count, 1)
11527-
11528-
# move 2 to the end
11529-
d[2] = d.pop(2)
11530-
11531-
x = torch.randn(4)
11532-
res = opt_fn(x, d)
11533-
# Check recompilation
11534-
self.assertEqual(cnts.frame_count, 2)
11535-
self.assertEqual(res, fn(x, d))
11536-
11537-
def test_dict_guard_on_keys_order2(self):
11538-
d = {
11539-
2: 4,
11540-
3: 5,
11541-
}
11542-
11543-
cnts = torch._dynamo.testing.CompileCounter()
11544-
11545-
def fn(x, d):
11546-
for key in d:
11547-
value = d[key]
11548-
x = x * key + value
11549-
return x
11550-
11551-
opt_fn = torch.compile(fn, backend=cnts)
11552-
opt_fn(torch.randn(4), d)
11553-
opt_fn(torch.randn(4), d)
11554-
# No recompilation
11555-
self.assertEqual(cnts.frame_count, 1)
11556-
11557-
# move 2 to the end
11558-
d[2] = d.pop(2)
11559-
11560-
x = torch.randn(4)
11561-
res = opt_fn(x, d)
11562-
# Check recompilation
11563-
self.assertEqual(cnts.frame_count, 2)
11564-
self.assertEqual(res, fn(x, d))
11565-
1156611469
def test_contains_dunder_dict(self):
1156711470
class UserDefined:
1156811471
def __init__(self) -> None:

test/dynamo/test_modules.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2428,7 +2428,6 @@ def new_forward_hook(
24282428
m._forward_hooks[handle.id] = new_forward_hook
24292429
self.assertEqual(compiled_func(inp), outer_func(inp))
24302430
self.assertEqual(compiled_func(inp).item(), 16)
2431-
self.assertRegex(failure_reason, r"___check_obj_id\(L\['m'\]._forward_hooks")
24322431

24332432
@patch.object(torch._dynamo.config, "guard_nn_modules", False)
24342433
@patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True)

0 commit comments

Comments
 (0)