Skip to content

Commit 0f474a9

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo] Remove dead code after introducing UserDefinedDictVariable (pytorch#143699)
Pull Request resolved: pytorch#143699 Approved by: https://github.com/williamwen42, https://github.com/yanboliang, https://github.com/jansel ghstack dependencies: pytorch#143722
1 parent e296bab commit 0f474a9

File tree

5 files changed

+3
-145
lines changed

5 files changed

+3
-145
lines changed

torch/_dynamo/guards.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@
9898
LocalSource,
9999
NNModuleSource,
100100
NumpyTensorSource,
101-
ODictGetItemSource,
102101
OptimizerSource,
103102
ScriptObjectQualifiedNameSource,
104103
ShapeEnvSource,
@@ -1055,8 +1054,7 @@ def get_guard_manager_from_source(self, source):
10551054
assert base_guard_manager # to make mypy happy
10561055
if isinstance(base_example_value, (dict, collections.OrderedDict)):
10571056
# TODO(anijain2305) - Consider isolating GetItemSource and
1058-
# DictGetItemSource (or maybe use ODictGetItemSource for
1059-
# dicts) so that GetItemSource is only for non dict objects.
1057+
# DictGetItemSource for dicts) so that GetItemSource is only for non dict objects.
10601058
if isinstance(base_guard_manager, DictGuardManager):
10611059
assert self.manager_guards_on_keys(base_guard_manager_enum)
10621060
out = getitem_on_dict_manager(
@@ -1102,24 +1100,6 @@ def get_guard_manager_from_source(self, source):
11021100
example_value=example_value,
11031101
guard_manager_enum=guard_manager_enum,
11041102
)
1105-
elif istype(source, ODictGetItemSource):
1106-
if isinstance(base_guard_manager, DictGuardManager):
1107-
assert self.manager_guards_on_keys(base_guard_manager_enum)
1108-
out = getitem_on_dict_manager(
1109-
source,
1110-
base_guard_manager,
1111-
base_example_value,
1112-
example_value,
1113-
guard_manager_enum,
1114-
)
1115-
else:
1116-
assert base_guard_manager # to make mypy happy
1117-
out = base_guard_manager.dict_getitem_manager(
1118-
key=source.index,
1119-
source=source_name,
1120-
example_value=example_value,
1121-
guard_manager_enum=guard_manager_enum,
1122-
)
11231103
elif istype(source, DefaultsSource):
11241104
assert base_guard_manager # to make mypy happy
11251105
assert callable(base_example_value)

torch/_dynamo/side_effects.py

-5
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@
3333
from .variables.user_defined import FrozenDataClassVariable
3434

3535

36-
def _manual_update_dict(dict_from, dict_to):
37-
for k, v in dict_from.items():
38-
dict_to[k] = v
39-
40-
4136
def _manual_dict_setitem(dict_from, dict_to, mro_index):
4237
# Carefully calls the dict or OrderedDict `clear` or `__setitem__`. We have
4338
# to be careful because we don't want to trigger the user defined object

torch/_dynamo/source.py

+1-33
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# mypy: allow-untyped-defs
2-
import collections
32
import dataclasses
43
import enum
54
from typing import Any, Optional, Union
@@ -86,9 +85,7 @@ def is_constant_source(source):
8685
return False
8786

8887

89-
def reconstruct_getitem(
90-
source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
91-
):
88+
def reconstruct_getitem(source: "GetItemSource", codegen, index_is_slice):
9289
source.base.reconstruct(codegen)
9390
if isinstance(source.index, Source):
9491
source.index.reconstruct(codegen)
@@ -567,35 +564,6 @@ def name(self):
567564
return f"type({self.base.name()})"
568565

569566

570-
@dataclasses.dataclass(frozen=True)
571-
class ODictGetItemSource(ChainedSource):
572-
index: Any
573-
574-
def __post_init__(self):
575-
assert self.base is not None
576-
577-
def reconstruct(self, codegen):
578-
codegen.add_push_null(
579-
lambda: codegen.append_output(
580-
codegen.create_load_const_unchecked(collections.OrderedDict.__getitem__)
581-
)
582-
)
583-
reconstruct_getitem(self, codegen, index_is_slice=False)
584-
codegen.extend_output(create_call_function(2, False))
585-
586-
def guard_source(self):
587-
return self.base.guard_source()
588-
589-
def name(self):
590-
if isinstance(self.index, type):
591-
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
592-
return f"___odict_getitem({self.base.name()}, {rep})"
593-
elif isinstance(self.index, Source):
594-
return f"___odict_getitem({self.base.name()}, {self.index.name()})"
595-
else:
596-
return f"___odict_getitem({self.base.name()}, {self.index!r})"
597-
598-
599567
@dataclasses.dataclass(frozen=True)
600568
class OptimizerSource(ChainedSource):
601569
def reconstruct(self, codegen):

torch/_dynamo/variables/misc.py

-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# mypy: ignore-errors
2-
import collections
32
import dataclasses
43
import functools
54
import inspect
@@ -25,7 +24,6 @@
2524
AttrSource,
2625
DefaultsSource,
2726
GetItemSource,
28-
ODictGetItemSource,
2927
TypeSource,
3028
WeakRefCallSource,
3129
)
@@ -193,18 +191,6 @@ def call_method(
193191
return variables.UserMethodVariable(
194192
inner_fn.__func__, self.objvar, source=source
195193
).call_function(tx, args, kwargs)
196-
elif (
197-
inner_fn is collections.OrderedDict.__getitem__
198-
and isinstance(self.objvar, variables.UserDefinedObjectVariable)
199-
and self.objvar.source
200-
and len(args) == 1
201-
and len(kwargs) == 0
202-
and args[0].is_python_constant()
203-
):
204-
key = args[0].as_python_constant()
205-
value = collections.OrderedDict.__getitem__(self.objvar.value, key)
206-
source = ODictGetItemSource(self.objvar.source, key)
207-
return VariableTracker.build(tx, value, source)
208194
elif is_standard_setattr(inner_fn) and isinstance(
209195
self.objvar, UserDefinedObjectVariable
210196
):

torch/_dynamo/variables/user_defined.py

+1-72
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from ..source import (
3535
AttrSource,
3636
GetItemSource,
37-
ODictGetItemSource,
3837
RandomValueSource,
3938
UnspecializedParamBufferSource,
4039
)
@@ -757,12 +756,7 @@ def call_method(
757756
args: "List[VariableTracker]",
758757
kwargs: "Dict[str, VariableTracker]",
759758
) -> "VariableTracker":
760-
from . import (
761-
BuiltinVariable,
762-
ConstantVariable,
763-
TupleVariable,
764-
UserMethodVariable,
765-
)
759+
from . import ConstantVariable, UserMethodVariable
766760

767761
method = self._maybe_get_baseclass_method(name)
768762
if method is not None:
@@ -772,54 +766,6 @@ def call_method(
772766
if is_standard_setattr(method) or isinstance(self.value, threading.local):
773767
return self.method_setattr_standard(tx, *args, **kwargs)
774768

775-
# [NOTE] OrderedDict, dict subtypes must always have source
776-
# We cannot instantiate such subtypes in-graph due to builtin __new__
777-
if method is collections.OrderedDict.keys:
778-
# subclass of OrderedDict
779-
assert not (args or kwargs)
780-
assert self.source # OrderedDict, dict subtypes must always have source
781-
keys = list(self.value.keys())
782-
assert all(map(ConstantVariable.is_literal, keys))
783-
install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
784-
tx.output.guard_on_key_order.add(self.source.name())
785-
return TupleVariable([ConstantVariable.create(k) for k in keys])
786-
787-
if (
788-
method in (collections.OrderedDict.__contains__, dict.__contains__)
789-
and len(args) == 1
790-
and isinstance(args[0], (ConstantVariable, BuiltinVariable))
791-
and inspect.getattr_static(type(self.value), "keys")
792-
in (collections.OrderedDict.keys, dict.keys)
793-
):
794-
assert not kwargs
795-
assert self.source # OrderedDict, dict subtypes must always have source
796-
797-
# TODO(anijain2305) - Why do we need to guard on all keys?
798-
install_guard(self.source.make_guard(GuardBuilder.DICT_CONST_KEYS))
799-
return ConstantVariable.create(
800-
args[0].as_python_constant() in self.value
801-
)
802-
803-
if method is collections.OrderedDict.items and isinstance(
804-
self.value, collections.OrderedDict
805-
):
806-
assert self.source # OrderedDict, dict subtypes must always have source
807-
assert not (args or kwargs)
808-
keys = self.call_method(tx, "keys", [], {})
809-
items = [
810-
TupleVariable(
811-
[key, self.odict_getitem(tx, key)],
812-
)
813-
for key in keys.force_unpack_var_sequence(tx)
814-
]
815-
tx.output.guard_on_key_order.add(self.source.name())
816-
return TupleVariable(items)
817-
818-
if method is collections.OrderedDict.__getitem__ and len(args) == 1:
819-
assert not kwargs
820-
assert self.source # OrderedDict, dict subtypes must always have source
821-
return self.odict_getitem(tx, args[0])
822-
823769
if len(args) == 1 and not kwargs:
824770
if method is object.__eq__:
825771
func_var = VariableTracker.build(tx, polyfills.object_eq)
@@ -1279,23 +1225,6 @@ def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTrack
12791225
handle_observed_exception(tx)
12801226
return variables.ConstantVariable.create(False)
12811227

1282-
def odict_getitem(self, tx: "InstructionTranslator", key):
1283-
from .dicts import is_hashable
1284-
1285-
# TODO this should probably be merged with the dict handling
1286-
1287-
index = (
1288-
key.source
1289-
if is_hashable(key) and key.source is not None
1290-
else key.as_python_constant()
1291-
)
1292-
1293-
return VariableTracker.build(
1294-
tx,
1295-
collections.OrderedDict.__getitem__(self.value, key.as_python_constant()),
1296-
self.source and ODictGetItemSource(self.source, index),
1297-
)
1298-
12991228

13001229
class FrozenDataClassVariable(UserDefinedObjectVariable):
13011230
@staticmethod

0 commit comments

Comments
 (0)