Commit 78010b8

Merge branch 'add_original_name_sample' into add_original_names
2 parents: 3b6c041 + a59fbba

1,824 files changed: +42593 -2362 lines changed

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+samples/transformers-auto-model/dbmdz_electra-large-discriminator-finetuned-conll03-english

graph_net/config/empty_cstr_torch_samples_list.txt

Lines changed: 151 additions & 487 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+samples/transformers-auto-model/microsoft_xclip-base-patch32-16-frames

graph_net/constraint_util.py

Lines changed: 7 additions & 6 deletions
@@ -12,7 +12,6 @@
 import tempfile
 import shutil
 from pathlib import Path
-import json
 from dataclasses import asdict
 
 
@@ -187,12 +186,14 @@ def _save_model_to_log_file(self, model_path):
         shutil.copy(Path(model_path) / "model.py", log_file)
 
     def _save_dim_gen_pass_names(self, dim_gen_pass_names, model_path):
-        from graph_net.graph_net_json_file_util import kDimensionGeneralizationPasses
+        from graph_net.graph_net_json_file_util import (
+            kDimensionGeneralizationPasses,
+            update_json,
+        )
 
-        graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
-        graph_net_json = json.loads(graph_net_json_file_path.read_text())
-        graph_net_json[kDimensionGeneralizationPasses] = list(dim_gen_pass_names)
-        graph_net_json_file_path.write_text(json.dumps(graph_net_json))
+        update_json(
+            model_path, kDimensionGeneralizationPasses, list(dim_gen_pass_names)
+        )
 
     def _save_dyn_dim_cstr(self, dyn_dim_cstr, model_path):
         cstr_code = dyn_dim_cstr.serialize_to_py_str()
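
Note: the hunk above delegates the graph_net.json write to the new update_json helper (added in graph_net/graph_net_json_file_util.py later in this commit), which also switches the file to pretty-printed output (indent=4). A minimal sketch of the delegated call; the model directory and pass names below are hypothetical placeholders, not values from this commit:

from graph_net.graph_net_json_file_util import (
    kDimensionGeneralizationPasses,
    update_json,
)

model_path = "samples/some-model"  # hypothetical directory containing graph_net.json
# Hypothetical pass names, for illustration only.
update_json(model_path, kDimensionGeneralizationPasses, ["pass_a", "pass_b"])
# graph_net.json now carries a "dimension_generalization_passes" field with that list.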

graph_net/dimension_generalizer.py

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
+import logging
+from graph_net.dynamic_dim_constraints import DynamicDimConstraints
+from graph_net.imp_util import load_module
+from graph_net.tensor_meta import TensorMeta
+import functools
+import sys
+import os
+from contextlib import contextmanager
+import tempfile
+import shutil
+from pathlib import Path
+from dataclasses import asdict
+import graph_net.graph_net_json_file_util as gn_json
+from collections import OrderedDict
+import copy
+from graph_net.hash_util import get_sha256_hash
+
+
+class ApplyDimGenPasses:
+    def __init__(self, config=None):
+        if config is None:
+            config = {}
+        self.config = self._make_config(**config)
+        self.num_handled_models = 0
+
+    def _make_config(
+        self,
+        output_dir: str,
+        dimension_generalizer_filepath=None,
+        dimension_generalizer_class_name="StaticToDynamic",
+        dimension_generalizer_config=None,
+        model_path_prefix="",
+        resume=False,
+        last_model_log_file=None,
+        limits_handled_models=None,
+    ):
+        if dimension_generalizer_config is None:
+            dimension_generalizer_config = {}
+        return {
+            "resume": resume,
+            "output_dir": output_dir,
+            "model_path_prefix": model_path_prefix,
+            "dimension_generalizer_filepath": dimension_generalizer_filepath,
+            "dimension_generalizer_class_name": dimension_generalizer_class_name,
+            "dimension_generalizer_config": dimension_generalizer_config,
+            "last_model_log_file": last_model_log_file,
+            "limits_handled_models": limits_handled_models,
+        }
+
+    def __call__(self, rel_model_path):
+        model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
+        output_dir = Path(self.config["output_dir"])
+        output_dir.mkdir(parents=True, exist_ok=True)
+        generalized_model_path = output_dir / rel_model_path
+        if (
+            self.config["resume"]
+            and generalized_model_path.exists()
+            and generalized_model_path.is_dir()
+            and len(list(generalized_model_path.iterdir())) > 0
+        ):
+            return
+        tensor_metas = self._get_tensor_metas(model_path)
+        tensor_meta_attrs_list = [asdict(tensor_meta) for tensor_meta in tensor_metas]
+        dim_gen_pass_names = self._get_dim_gen_pass_names(model_path)
+        dim_generalizer = self._get_dimension_generalizer(dim_gen_pass_names)
+        inputs = dim_generalizer.create_inputs_by_metas(
+            module=self._get_model(model_path),
+            tensor_meta_attrs_list=tensor_meta_attrs_list,
+        )
+        dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
+            os.path.join(model_path, "input_tensor_constraints.py")
+        )
+        dim_axes_pairs = self._get_dim_axes_pairs(dyn_dim_cstrs)
+        if len(dim_axes_pairs) == 0:
+            print(f"No symbolic dims found. {model_path=}")
+            return
+
+        def get_generalized():
+            return self._get_generalized_model_py_file_path(
+                dim_generalizer=dim_generalizer,
+                dim_axes_pairs=dim_axes_pairs,
+                model_path=model_path,
+                inputs=inputs,
+            )
+
+        with get_generalized() as tmp_model_py_path:
+            from_model_path = Path(self.config["model_path_prefix"]) / rel_model_path
+            triples = self._get_reified_tensor_metas(from_model_path, dyn_dim_cstrs)
+            for symbol2example_value, cur_tensor_metas, cur_dyn_dim_cstrs in triples:
+                to_model_path = self._get_to_model_path(
+                    rel_model_path, symbol2example_value
+                )
+                print(f"{str(to_model_path)=}")
+                self._copy_sample_model_path(from_model_path, to_model_path)
+                self._save_generalized_model_path(to_model_path, tmp_model_py_path)
+                self._save_tensor_metas_as_weight_meta(to_model_path, cur_tensor_metas)
+                self._save_dyn_dim_cstrs(to_model_path, cur_dyn_dim_cstrs)
+
+        self._check_num_handled_models()
+
+    def _get_reified_tensor_metas(self, from_model_path, dyn_dim_cstrs):
+        tensor_metas = self._get_tensor_metas(str(from_model_path))
+        symbols, reified_dims = self._get_symbols_and_reified_dims(
+            from_model_path, dyn_dim_cstrs
+        )
+        for dims in reified_dims:
+            symbol2example_value = OrderedDict(list(zip(symbols, dims)))
+            cur_dyn_dim_cstrs = copy.deepcopy(dyn_dim_cstrs)
+            cur_tensor_metas = copy.deepcopy(tensor_metas)
+            cur_dyn_dim_cstrs.update_symbol2example_value(symbol2example_value)
+            update_tensor_metas_by_dyn_dim_cstr(cur_tensor_metas, cur_dyn_dim_cstrs)
+            yield symbol2example_value, cur_tensor_metas, cur_dyn_dim_cstrs
+
+    def _get_symbols_and_reified_dims(self, from_model_path, dyn_dim_cstrs):
+        json_value = gn_json.read_json(str(from_model_path))
+        reifier_name = json_value[gn_json.kSymbolicDimensionReifier]
+        from graph_net.torch.sym_dim_reifiers.reifier_mgr import get_reifier
+
+        reifier_class = get_reifier(reifier_name)
+        reifier_instance = reifier_class(str(from_model_path))
+        assert reifier_instance.match
+        symbols2reified_dims = reifier_instance.reify()
+        assert len(symbols2reified_dims) == 1
+        symbols, reified_dims = next(iter(symbols2reified_dims.items()))
+        assert tuple(symbols) == tuple(dyn_dim_cstrs.symbols)
+        assert all(len(symbols) == len(dims) for dims in reified_dims)
+        return symbols, reified_dims
+
+    def _save_dyn_dim_cstrs(self, to_model_path, dyn_dim_cstrs):
+        cstr_code = dyn_dim_cstrs.serialize_to_py_str()
+        (to_model_path / "input_tensor_constraints.py").write_text(cstr_code)
+
+    def _save_tensor_metas_as_weight_meta(self, to_model_path, tensor_metas):
+        weight_meta_code = "\n".join(
+            tensor_meta.serialize_to_py_str() for tensor_meta in tensor_metas
+        )
+        (to_model_path / "weight_meta.py").write_text(weight_meta_code)
+
+    def _get_to_model_path(self, rel_model_path, symbol2example_value):
+        sym_dim_str = "_".join(
+            f"{sym_name}_{dim}"
+            for symbol, dim in symbol2example_value.items()
+            for sym_name in [symbol.name]
+        )
+        sub_module_name = f"{os.path.basename(rel_model_path)}__{sym_dim_str}"
+        to_model_path = (
+            Path(self.config["output_dir"]) / rel_model_path / sub_module_name
+        )
+        return to_model_path
+
+    def _copy_sample_model_path(self, from_model_path, to_model_path):
+        to_model_path.mkdir(parents=True, exist_ok=True)
+        shutil.copytree(Path(from_model_path), Path(to_model_path), dirs_exist_ok=True)
+
+    def _save_generalized_model_path(self, to_model_path, tmp_model_py_path):
+        generalized_model_py_code = Path(tmp_model_py_path).read_text()
+        (to_model_path / "model.py").write_text(generalized_model_py_code)
+        file_hash = get_sha256_hash(generalized_model_py_code)
+        (to_model_path / "graph_hash.txt").write_text(file_hash)
+
+    def _get_dim_axes_pairs(self, dyn_dim_cstrs):
+        sym_input_shapes = dyn_dim_cstrs.get_sorted_symbolic_input_shapes()
+        return [
+            (dim, axes)
+            for symbol in dyn_dim_cstrs.symbols
+            for dim in [dyn_dim_cstrs.symbol2example_value[symbol]]
+            for axes in [
+                [
+                    axis
+                    for shape in sym_input_shapes
+                    for axis, sym_or_dim in enumerate(shape)
+                    if sym_or_dim == symbol
+                ]
+            ]
+        ]
+
+    def _get_dim_gen_pass_names(self, model_path):
+        json_value = gn_json.read_json(model_path)
+        return json_value.get(gn_json.kDimensionGeneralizationPasses, [])
+
+    def _check_num_handled_models(self):
+        self.num_handled_models += 1
+        limits = self.config["limits_handled_models"]
+        if limits is None:
+            return
+        if self.num_handled_models < limits:
+            return
+        print("`num_handled_models` exceeds config `limits_handled_models`")
+        sys.exit(0)
+
+    def _get_dimension_generalizer(self, dim_gen_pass_names):
+        assert self.config["dimension_generalizer_filepath"] is not None
+        decorator_cls = getattr(
+            load_module(self.config["dimension_generalizer_filepath"]),
+            self.config["dimension_generalizer_class_name"],
+        )
+        config = {"pass_names": dim_gen_pass_names}
+        dim_generalizer = decorator_cls(config)
+        return dim_generalizer
+
+    def _get_model(self, model_path):
+        py_module = load_module(os.path.join(model_path, "model.py"))
+        GraphModule = getattr(py_module, "GraphModule")
+        GraphModule.__graph_net_file_path__ = py_module.__graph_net_file_path__
+        return GraphModule()
+
+    @contextmanager
+    def _get_generalized_model_py_file_path(
+        self, dim_generalizer, dim_axes_pairs, model_path, inputs
+    ):
+        model = self._get_model(model_path)
+        dim_gen_pass = dim_generalizer(model, dim_axes_pairs)
+        logging.warning("before need_rewrite")
+        need_rewrite = dim_gen_pass.need_rewrite(inputs)
+        logging.warning("after need_rewrite")
+        if not need_rewrite:
+            yield os.path.join(model_path, "model.py")
+            return
+        logging.warning("before rewrite")
+        graph_module = dim_gen_pass.rewrite(inputs)
+        logging.warning("after rewrite")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            shutil.copytree(Path(model_path), Path(tmp_dir), dirs_exist_ok=True)
+            dim_gen_pass.save_graph_module(graph_module, tmp_dir)
+            yield os.path.join(tmp_dir, "model.py")
+
+    def _get_tensor_metas(self, model_path):
+        make = TensorMeta.unserialize_from_py_file
+        return [
+            *make(os.path.join(model_path, "input_meta.py")),
+            *make(os.path.join(model_path, "weight_meta.py")),
+        ]
+
+
+def update_tensor_metas_by_dyn_dim_cstr(
+    tensor_metas: list[TensorMeta], dyn_dim_cstr: DynamicDimConstraints
+):
+    input_shapes = dyn_dim_cstr.get_reified_input_shapes()
+    assert len(tensor_metas) == len(input_shapes)
+    for i, tensor_meta in enumerate(tensor_metas):
+        tensor_meta.shape = input_shapes[i]
+        if tensor_meta.data is not None:
+            assert isinstance(tensor_meta.data, (list, tuple))
+            size = functools.reduce(lambda a, b: a * b, tensor_meta.shape, 1)
+            doubled_data = [*tensor_meta.data, *tensor_meta.data]
+            tensor_meta.data = doubled_data[:size]
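
For orientation: ApplyDimGenPasses is a per-model driver. Given a sample directory, it loads model.py, the input/weight tensor metadata, and input_tensor_constraints.py, applies the configured dimension-generalization passes, and writes one reified copy of the sample per symbol assignment under output_dir. A minimal invocation sketch; all paths and config values below are hypothetical placeholders (the keys follow _make_config above), not values taken from this commit:

from graph_net.dimension_generalizer import ApplyDimGenPasses

apply_passes = ApplyDimGenPasses(
    config={
        "output_dir": "/tmp/generalized_samples",
        "model_path_prefix": "samples/transformers-auto-model",
        "dimension_generalizer_filepath": "path/to/static_to_dynamic.py",
        "dimension_generalizer_class_name": "StaticToDynamic",
        "resume": True,
        "limits_handled_models": 10,
    }
)
# Handles one sample directory, given relative to model_path_prefix.
apply_passes("microsoft_xclip-base-patch32-16-frames")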

graph_net/dynamic_dim_constraints.py

Lines changed: 15 additions & 0 deletions
@@ -23,6 +23,21 @@ class DynamicDimConstraints:
     input_shapes: list[(tuple[sympy.Expr | int], str)]
     kInputShapes = "dynamic_dim_constraint_input_shapes"
 
+    def serialize_symbolic_input_shapes_to_str(self):
+        input_shapes = self.get_sorted_symbolic_input_shapes()
+        input_shapes_str = str(input_shapes).replace(" ", "")
+        return input_shapes_str
+
+    def get_sorted_symbolic_input_shapes(self):
+        return sorted(
+            [
+                tuple(shape)
+                for shape, name in self.input_shapes
+                if any(isinstance(dim, sympy.Expr) for dim in shape)
+            ],
+            key=str,
+        )
+
     @classmethod
     def make_by_named_inputs(cls, named_shapes):
         return cls(
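
A rough illustration of the two new helpers, using standalone sympy shapes rather than a real DynamicDimConstraints instance (the full constructor is not shown in this diff); the shapes and names below are hypothetical:

import sympy

S = sympy.Symbol("S")
# Hypothetical (shape, name) pairs in the same form as the input_shapes field above.
input_shapes = [((2, 3), "static_input"), ((S, 768), "hidden_states"), ((1, S), "mask")]

symbolic = sorted(
    [
        tuple(shape)
        for shape, name in input_shapes
        if any(isinstance(dim, sympy.Expr) for dim in shape)
    ],
    key=str,
)
print(symbolic)  # [(1, S), (S, 768)] -- what get_sorted_symbolic_input_shapes returns
print(str(symbolic).replace(" ", ""))  # [(1,S),(S,768)] -- serialize_symbolic_input_shapes_to_str
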
graph_net/graph_net_json_file_util.py

Lines changed: 16 additions & 0 deletions

@@ -1 +1,17 @@
+from pathlib import Path
+import json
+
 kDimensionGeneralizationPasses = "dimension_generalization_passes"
+kSymbolicDimensionReifier = "symbolic_dimension_reifier"
+
+
+def read_json(model_path):
+    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+    return json.loads(graph_net_json_file_path.read_text())
+
+
+def update_json(model_path, field, value):
+    graph_net_json_file_path = Path(f"{model_path}/graph_net.json")
+    graph_net_json = json.loads(graph_net_json_file_path.read_text())
+    graph_net_json[field] = value
+    graph_net_json_file_path.write_text(json.dumps(graph_net_json, indent=4))
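
These helpers centralize graph_net.json access for the rest of the commit: read_json is the read side used by dimension_generalizer.py, and update_json the write side used by constraint_util.py. A short usage sketch; the sample directory is a hypothetical placeholder:

import graph_net.graph_net_json_file_util as gn_json

model_path = "samples/some-model"  # hypothetical directory containing graph_net.json
meta = gn_json.read_json(model_path)
reifier_name = meta.get(gn_json.kSymbolicDimensionReifier)
pass_names = meta.get(gn_json.kDimensionGeneralizationPasses, [])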

graph_net/hash_util.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+import hashlib
+
+
+def get_sha256_hash(content):
+    m = hashlib.sha256()
+    m.update(content.encode())
+    return m.hexdigest()
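
For reference, get_sha256_hash is what dimension_generalizer.py uses to fill graph_hash.txt from the generalized model.py source text. A trivial usage sketch with a placeholder input string:

from graph_net.hash_util import get_sha256_hash

model_py_code = "class GraphModule:\n    pass\n"  # placeholder source text
print(get_sha256_hash(model_py_code))  # 64-character hex digest written to graph_hash.txt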
