
Commit 3914566

jansel authored and pytorchmergebot committed
[dynamo] Refactor OrderedDict to dict (pytorch#113234)
In Python 3, all dicts are ordered.

Pull Request resolved: pytorch#113234
Approved by: https://github.com/oulgen, https://github.com/lezcano
1 parent 728ed37 commit 3914566
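For reference, a quick illustration of the ordering guarantee the commit message relies on (not part of the diff; the names below are made up):

# Since Python 3.7, the built-in dict preserves insertion order, which is the
# OrderedDict property the refactored call sites depend on.
import collections

d = {}
d["b"] = 1
d["a"] = 2
od = collections.OrderedDict([("b", 1), ("a", 2)])

assert list(d) == list(od) == ["b", "a"]  # same iteration order
assert d == od                            # equal by contents
# OrderedDict still differs in a few corners (order-sensitive equality between
# two OrderedDicts, move_to_end(), popitem(last=False)); none of the call
# sites changed in this diff appear to use those.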

12 files changed: +90 −135 lines

Diff for: torch/_dynamo/codegen.py

+2 −4

@@ -3,7 +3,7 @@
 import re
 import sys
 import types
-from typing import Counter, List, Optional, OrderedDict
+from typing import Counter, Dict, List, Optional
 
 import torch.nn
 from . import utils
@@ -51,9 +51,7 @@ def __init__(
         self.root = root
         self.top_of_stack: Optional[VariableTracker] = None
         self.uses: Counter[VariableTracker] = collections.Counter()
-        self.graph_outputs: OrderedDict[
-            int, GraphOutputEntry
-        ] = collections.OrderedDict()
+        self.graph_outputs: Dict[int, GraphOutputEntry] = {}
         self._output: List[Instruction] = []
         self.tempvars = tempvars or {}
         self.tx = tx

Diff for: torch/_dynamo/convert_frame.py

+1 −3

@@ -405,11 +405,9 @@ def maybe_cprofile(func):
     return func
 
 
-from collections import OrderedDict
-
 from torch.utils.hooks import RemovableHandle
 
-_bytecode_hooks: Dict[int, BytecodeHook] = OrderedDict()
+_bytecode_hooks: Dict[int, BytecodeHook] = {}
 
 
 def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:

Diff for: torch/_dynamo/guards.py

+40 −51

@@ -91,46 +91,37 @@ def uninteresting_files():
     return {inspect.getfile(m) for m in mods}
 
 
-CLOSURE_VARS = collections.OrderedDict(
-    [
-        ("___check_type_id", check_type_id),
-        ("___check_obj_id", check_obj_id),
-        (
-            "___current_backend",
-            lambda: torch._dynamo.eval_frame.guarded_backend_cache.current_backend,
-        ),
-        (
-            "___lookup_backend",
-            lambda backend_obj_id: torch._dynamo.eval_frame.guarded_backend_cache.cached_backends[
-                backend_obj_id
-            ],
-        ),
-        (
-            "___skip_backend_check",
-            lambda: torch._dynamo.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode,
-        ),
-        ("___odict_getitem", collections.OrderedDict.__getitem__),
-        ("___dict_param_key_ids", dict_param_key_ids),
-        ("___dict_const_keys", dict_const_keys),
-        ("___dict_version", dict_version),
-        ("___dict_contains", lambda a, b: a in b),
-        ("___tuple_iterator_len", tuple_iterator_len),
-        ("___tuple_iterator_getitem", tuple_iterator_getitem),
-        ("__math_isnan", math.isnan),
-        ("inf", float("inf")),
-        ("__load_module", lambda name: importlib.import_module(name)),
-        ("utils_device", torch.utils._device),
-        ("device", torch.device),
-        (
-            "___from_numpy",
-            # If not numpy array, piggy back on e.g. tensor guards to check type
-            lambda a: torch.as_tensor(a)
-            if isinstance(a, (np.generic, np.ndarray))
-            else a,
-        ),
-        ("torch", torch),
-    ]
-)
+CLOSURE_VARS = {
+    "___check_type_id": check_type_id,
+    "___check_obj_id": check_obj_id,
+    "___current_backend": (
+        lambda: torch._dynamo.eval_frame.guarded_backend_cache.current_backend
+    ),
+    "___lookup_backend": (
+        lambda backend_obj_id: torch._dynamo.eval_frame.guarded_backend_cache.cached_backends[
+            backend_obj_id
+        ]
+    ),
+    "___skip_backend_check": (
+        lambda: torch._dynamo.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode
+    ),
+    "___odict_getitem": collections.OrderedDict.__getitem__,
+    "___dict_param_key_ids": dict_param_key_ids,
+    "___dict_const_keys": dict_const_keys,
+    "___dict_version": dict_version,
+    "___dict_contains": lambda a, b: a in b,
+    "___tuple_iterator_len": tuple_iterator_len,
+    "___tuple_iterator_getitem": tuple_iterator_getitem,
+    "__math_isnan": math.isnan,
+    "inf": float("inf"),
+    "__load_module": lambda name: importlib.import_module(name),
+    "utils_device": torch.utils._device,
+    "device": torch.device,
+    "___from_numpy":
+    # If not numpy array, piggy back on e.g. tensor guards to check type
+    (lambda a: torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a),
+    "torch": torch,
+}
 
 if sys.version_info[:2] <= (3, 8):
     # [Note: Python Version <= 3.8]
@@ -1138,17 +1129,15 @@ def convert(size_or_stride):
         if global_state is None:
             # we should only hit this case in NopTests()
             global_state = convert_frame.GlobalStateGuard()
-        closure_vars = collections.OrderedDict(
-            [
-                ("___guarded_code", self),
-                ("___check_tensors", check_tensors_fn),
-                ("___check_tensors_verbose", check_tensors_verbose_fn),
-                ("___check_global_state", global_state.check),
-                ("tensor_check_names", tensor_check_names),
-            ]
-            + list(SYMPY_INTERP.items())
-        )
-        closure_vars.update(CLOSURE_VARS)
+        closure_vars = {
+            "___guarded_code": self,
+            "___check_tensors": check_tensors_fn,
+            "___check_tensors_verbose": check_tensors_verbose_fn,
+            "___check_global_state": global_state.check,
+            "tensor_check_names": tensor_check_names,
+            **SYMPY_INTERP,
+            **CLOSURE_VARS,
+        }
 
         unique_code_parts = list(unique(code_parts))
         make_guard_fn_args = ", ".join(closure_vars.keys())

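The new closure_vars construction above merges SYMPY_INTERP and CLOSURE_VARS into the dict literal with ** unpacking instead of building an OrderedDict and calling .update(). A minimal sketch of that idiom, using placeholder dicts rather than the real guard tables:

import collections

base = {"___guarded_code": "code", "tensor_check_names": "names"}
extra = {"inf": float("inf"), "device": "device"}

# Old style: OrderedDict built from a list of pairs, then updated in place.
old_way = collections.OrderedDict(list(base.items()))
old_way.update(extra)

# New style: a single dict literal with ** unpacking; entries keep insertion
# order (base first, then extra) and later keys would override earlier ones.
new_way = {**base, **extra}

assert list(old_way) == list(new_way)
assert old_way == new_way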
Diff for: torch/_dynamo/output_graph.py

+5 −19

@@ -10,18 +10,7 @@
 import traceback
 import weakref
 from dataclasses import dataclass
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    OrderedDict,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union
 
 import sympy
 
@@ -827,9 +816,7 @@ def append_prefix_insts():
         root = FakeRootModule(self.nn_modules)
         # Add all the local vars to the "stack" so restore at the end
         restore_vars = []
-        val_to_names: OrderedDict[
-            VariableTracker, List[str]
-        ] = collections.OrderedDict()
+        val_to_names: Dict[VariableTracker, List[str]] = {}
         if stack_values:
             val_to_names[stack_values[-1]] = list()
         # NB: Typically (i.e., for graph compile from RETURN_VALUE),
@@ -1307,7 +1294,7 @@ def __init__(
         # Map from graph input name to its placeholder proxy object, where the
         # map's keys give all current placeholder node names and can be used to
         # create unique node names
-        self.input_name_to_proxy: OrderedDict[str, fx.Proxy] = collections.OrderedDict()
+        self.input_name_to_proxy: Dict[str, fx.Proxy] = {}
         # Node => computed real value (see utils.get_real_value)
         self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}
 
@@ -1324,9 +1311,8 @@ def __init__(
         # - If we are tracing a HigherOrderOperator's body_fn, then we
         # need to keep track of what free variables were lifted so we can
         # rewrite the HigherOrderOperator call using the traced body_fn.
-        # This is a OrderedDict so that we can
-        # maintain the order of args for the HigherOrderOperator call.
-        self.lifted_freevars = collections.OrderedDict()
+        # Dicts maintain the order of args for the HigherOrderOperator call.
+        self.lifted_freevars = {}
         self.prev_inst = None
 
         self._cur_code = None

Diff for: torch/_dynamo/side_effects.py

+19 −21

@@ -1,4 +1,3 @@
-import collections
 import inspect
 from typing import Any, Dict, List, Optional
 
@@ -76,8 +75,8 @@ def __init__(
         tensor_hooks=None,
     ):
         super().__init__()
-        self.id_to_variable = id_to_variable or collections.OrderedDict()
-        self.store_attr_mutations = store_attr_mutations or collections.OrderedDict()
+        self.id_to_variable = id_to_variable or {}
+        self.store_attr_mutations = store_attr_mutations or {}
         self.keepalive = keepalive or []
         self.save_for_backward = save_for_backward or []
         self.tensor_hooks = tensor_hooks or {}
@@ -115,11 +114,10 @@ def diff(self, other: "SideEffects") -> Optional[str]:
     def clone(self):
         """Create a shallow copy"""
         return self.__class__(
-            id_to_variable=collections.OrderedDict(self.id_to_variable),
-            store_attr_mutations=collections.OrderedDict(
-                (k, collections.OrderedDict(v))
-                for k, v in self.store_attr_mutations.items()
-            ),
+            id_to_variable=dict(self.id_to_variable),
+            store_attr_mutations={
+                k: dict(v) for k, v in self.store_attr_mutations.items()
+            },
             keepalive=list(self.keepalive),
             save_for_backward=self.save_for_backward,
             tensor_hooks=self.tensor_hooks,
@@ -129,14 +127,14 @@ def apply(self, fn, cache=None, skip_fn=lambda _: False):
         if cache is None:
             cache = dict()
 
-        self.id_to_variable = collections.OrderedDict(
-            (k, VariableTracker.apply(fn, v, cache, skip_fn))
+        self.id_to_variable = {
+            k: VariableTracker.apply(fn, v, cache, skip_fn)
             for k, v in self.id_to_variable.items()
-        )
-        self.store_attr_mutations = collections.OrderedDict(
-            (k, VariableTracker.apply(fn, v, cache, skip_fn))
+        }
+        self.store_attr_mutations = {
+            k: VariableTracker.apply(fn, v, cache, skip_fn)
             for k, v in self.store_attr_mutations.items()
-        )
+        }
         self.save_for_backward = VariableTracker.apply(
             fn, self.save_for_backward, cache, skip_fn
         )
@@ -164,7 +162,7 @@ def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
         assert self.is_attribute_mutation(item)
         self.check_allowed_side_effect(item)
         if item.mutable_local not in self.store_attr_mutations:
-            self.store_attr_mutations[item.mutable_local] = collections.OrderedDict()
+            self.store_attr_mutations[item.mutable_local] = {}
         self.store_attr_mutations[item.mutable_local][name] = value
 
     def load_attr(self, item, name, deleted_ok=False):
@@ -320,12 +318,12 @@ def is_live(var: VariableTracker):
         for skip_obj, setattrs in self.store_attr_mutations.items():
             VariableTracker.apply(visit, setattrs)
 
-        self.id_to_variable = collections.OrderedDict(
-            (k, v) for k, v in self.id_to_variable.items() if is_live(v)
-        )
-        self.store_attr_mutations = collections.OrderedDict(
-            (k, v) for k, v in self.store_attr_mutations.items() if is_live(k)
-        )
+        self.id_to_variable = {
+            k: v for k, v in self.id_to_variable.items() if is_live(v)
+        }
+        self.store_attr_mutations = {
+            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
+        }
 
     def mutation(self, oldvar, newvar):
         self.check_allowed_side_effect(oldvar)

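The side_effects.py changes above swap OrderedDict-from-generator constructions for dict comprehensions. A small sketch of the equivalence, using an illustrative mapping that is not from the PR:

import collections

values = {"x": 1, "y": 2, "z": 3}

# Transforming values: both forms visit keys in the same order.
old_map = collections.OrderedDict((k, v * 10) for k, v in values.items())
new_map = {k: v * 10 for k, v in values.items()}
assert list(old_map) == list(new_map) == ["x", "y", "z"]

# Filtering entries (the prune_dead_* pattern): order is preserved as well.
old_pruned = collections.OrderedDict((k, v) for k, v in values.items() if v > 1)
new_pruned = {k: v for k, v in values.items() if v > 1}
assert old_pruned == new_pruned and list(new_pruned) == ["y", "z"]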
Diff for: torch/_dynamo/symbolic_convert.py

+6 −6

@@ -622,9 +622,9 @@ def prune_dead_locals(self):
         # reads = reads | {"__class__"}
         # output variables?
         reads = reads | set(self.cell_and_freevars())
-        self.symbolic_locals = collections.OrderedDict(
-            [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
-        )
+        self.symbolic_locals = {
+            k: v for k, v in self.symbolic_locals.items() if k in reads
+        }
         self.output.side_effects.prune_dead_object_new(self)
 
     def call_function(
@@ -1863,7 +1863,7 @@ def copy_graphstate(self) -> InstructionTranslatorGraphState:
         """Create a checkpoint of the current state by copying everything"""
         return InstructionTranslatorGraphState(
             self.output.copy_graphstate(),
-            collections.OrderedDict(self.symbolic_locals),
+            dict(self.symbolic_locals),
             list(self.stack),
             list(self.block_stack),
             self.instruction_pointer,
@@ -2071,9 +2071,9 @@ def __init__(
             f_globals=f_globals,
             f_builtins=f_builtins,
             code_options=code_options,
-            symbolic_locals=collections.OrderedDict(),  # set below
+            symbolic_locals={},  # set below
             # A global var is inserted only after a STORE_GLOBAL happens to it
-            symbolic_globals=collections.OrderedDict(),
+            symbolic_globals={},
             f_code=f_code,
             export=export,
             inline_depth=0,

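copy_graphstate above now snapshots symbolic_locals with dict(...) instead of collections.OrderedDict(...). A sketch of why that is an equivalent order-preserving shallow copy; the mapping below is a stand-in, not dynamo's real structure:

symbolic_locals = {"a": object(), "b": object()}

checkpoint = dict(symbolic_locals)   # shallow copy, same key order
symbolic_locals["c"] = object()      # later mutation does not affect the copy

assert list(checkpoint) == ["a", "b"]
assert checkpoint["a"] is symbolic_locals["a"]  # values are shared (shallow)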
Diff for: torch/_dynamo/types.py

+2 −12

@@ -1,17 +1,7 @@
 import dataclasses
 import sys
 import types
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    OrderedDict,
-    Protocol,
-    Union,
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
 
 from typing_extensions import TypeAlias
 
@@ -41,7 +31,7 @@ class GuardFail(NamedTuple):
 
 
 class GuardFn(Protocol):
-    closure_vars: OrderedDict[str, object]
+    closure_vars: Dict[str, object]
     args: List[str]
     code_parts: List[str]
     verbose_code_parts: List[str]

Diff for: torch/_dynamo/utils.py

+2 −2

@@ -80,10 +80,10 @@
 log = logging.getLogger(__name__)
 
 # profiling compilation time by function
-compilation_time_metrics = collections.OrderedDict()
+compilation_time_metrics = {}
 
 # profiling compilation time by frame phase
-frame_phase_timing = collections.OrderedDict()
+frame_phase_timing = {}
 
 timer_counter = itertools.count()
 

Diff for: torch/_dynamo/variables/base.py

+1 −5

@@ -212,11 +212,7 @@ def update_object_dict(v):
             result = [cls.apply(fn, v, cache, skip_fn) for v in value]
         elif istype(value, tuple):
             result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
-        elif istype(value, collections.OrderedDict):
-            result = collections.OrderedDict(
-                cls.apply(fn, v, cache, skip_fn) for v in value.items()
-            )
-        elif istype(value, dict):
+        elif istype(value, (dict, collections.OrderedDict)):
             assert "__name__" not in value, "_nonvar_fields should have excluded this"
             result = {
                 k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())

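The base.py change collapses the separate dict and OrderedDict branches into one istype check against a tuple of types. A rough sketch of how such an exact-type check behaves, assuming istype compares type(value) directly rather than via isinstance (which is why OrderedDict has to stay listed explicitly); istype_sketch is a made-up stand-in, not dynamo's helper:

import collections
from typing import Tuple, Type, Union


def istype_sketch(obj: object, allowed: Union[Type, Tuple[Type, ...]]) -> bool:
    # Exact type match: subclasses like OrderedDict do not match a plain dict entry.
    if isinstance(allowed, tuple):
        return type(obj) in allowed
    return type(obj) is allowed


assert istype_sketch({}, dict)
assert not istype_sketch(collections.OrderedDict(), dict)
assert istype_sketch(collections.OrderedDict(), (dict, collections.OrderedDict))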