Skip to content

Commit 3b6c041

Browse files
committed
Support use original tensor_meta to to recover the re-extracted samples.
1 parent 4f954ce commit 3b6c041

File tree

5 files changed

+215
-40
lines changed

5 files changed

+215
-40
lines changed

graph_net/imp_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import importlib.util as imp
23

34

@@ -6,5 +7,5 @@ def load_module(path, name="unnamed"):
67
module = imp.module_from_spec(spec)
78
module.__file__ = path
89
spec.loader.exec_module(module)
9-
module.__graph_net_file_path__ = path
10+
module.__graph_net_file_path__ = os.path.normpath(path)
1011
return module
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import os
2+
from graph_net import path_utils
3+
from graph_net.paddle import utils
4+
5+
6+
class GraphMetaRestorer:
7+
def __init__(self, config, parent_model_path):
8+
self.config = config
9+
self.parent_model_path = parent_model_path
10+
print(f"parent_model_path: {self.parent_model_path}")
11+
12+
assert path_utils.is_single_model_dir(
13+
parent_model_path
14+
), f"{parent_model_path=} is not a graphnet sample."
15+
(
16+
parent_weight_meta_classes,
17+
parent_input_meta_classes,
18+
) = self._load_weight_and_input_meta_classes(parent_model_path)
19+
self.original_name2parent_weight_meta_class = self._convert_to_dict(
20+
parent_weight_meta_classes
21+
)
22+
self.original_name2parent_input_meta_class = self._convert_to_dict(
23+
parent_input_meta_classes
24+
)
25+
26+
def __call__(self, model_path):
27+
assert path_utils.is_single_model_dir(
28+
model_path
29+
), f"{model_path=} is not a graphnet sample."
30+
(
31+
weight_meta_classes,
32+
input_meta_classes,
33+
) = self._load_weight_and_input_meta_classes(model_path)
34+
35+
assert self.config["update_inplace"]
36+
is_weight_meta_fully_updated = self._update_by_original_name(
37+
weight_meta_classes, self.original_name2parent_weight_meta_class
38+
)
39+
if is_weight_meta_fully_updated:
40+
new_weight_meta_codes = []
41+
for meta_class in weight_meta_classes:
42+
new_weight_meta_codes.append(
43+
self._generate_py_code_from_meta_class(meta_class)
44+
)
45+
46+
weight_meta_file_path = os.path.join(model_path, "weight_meta.py")
47+
if self.config["update_inplace"]:
48+
print(f"[GraphMetaRestorer] Update {weight_meta_file_path}")
49+
with open(weight_meta_file_path, "w") as f:
50+
f.write("\n\n".join(new_weight_meta_codes))
51+
52+
is_input_meta_fully_updated = self._update_by_tensor_spec(
53+
input_meta_classes, self.original_name2parent_input_meta_class
54+
)
55+
if is_input_meta_fully_updated:
56+
new_input_meta_codes = []
57+
for meta_class in input_meta_classes:
58+
new_input_meta_codes.append(
59+
self._generate_py_code_from_meta_class(meta_class)
60+
)
61+
62+
input_meta_file_path = os.path.join(model_path, "input_meta.py")
63+
if self.config["update_inplace"]:
64+
print(f"[GraphMetaRestorer] Update {input_meta_file_path}")
65+
with open(input_meta_file_path, "w") as f:
66+
f.write("\n\n".join(new_input_meta_codes))
67+
68+
def _load_weight_and_input_meta_classes(self, model_path):
69+
weight_meta_file_path = os.path.join(model_path, "weight_meta.py")
70+
weight_meta_classes = [
71+
meta_class
72+
for (name, meta_class) in utils.get_meta_classes(weight_meta_file_path)
73+
]
74+
75+
input_meta_file_path = os.path.join(model_path, "input_meta.py")
76+
input_meta_classes = [
77+
meta_class
78+
for (name, meta_class) in utils.get_meta_classes(input_meta_file_path)
79+
]
80+
81+
return weight_meta_classes, input_meta_classes
82+
83+
def _convert_to_dict(self, meta_classes):
84+
original_name2meta_class = {}
85+
for meta_class in meta_classes:
86+
assert meta_class.original_name not in original_name2meta_class.keys()
87+
original_name2meta_class[meta_class.original_name] = meta_class
88+
return original_name2meta_class
89+
90+
def _update_tensor_meta(self, meta_class, parent_meta_class):
91+
if (
92+
parent_meta_class
93+
and meta_class.dtype == parent_meta_class.dtype
94+
and meta_class.shape == parent_meta_class.shape
95+
):
96+
for attr_name in ["max_val", "min_val", "mean", "std", "data"]:
97+
if hasattr(meta_class, attr_name) or hasattr(
98+
parent_meta_class, attr_name
99+
):
100+
attr_value = getattr(parent_meta_class, attr_name, None)
101+
setattr(meta_class, attr_name, attr_value)
102+
return True
103+
return False
104+
105+
def _update_by_original_name(self, meta_classes, original_name2parent_meta_class):
106+
updated_class_names = set()
107+
for meta_class in meta_classes:
108+
if not meta_class.original_name:
109+
continue
110+
111+
parent_meta_class = original_name2parent_meta_class.get(
112+
meta_class.original_name, None
113+
)
114+
if self._update_tensor_meta(meta_class, parent_meta_class):
115+
updated_class_names.add(meta_class.name)
116+
117+
print(
118+
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated."
119+
)
120+
return len(meta_classes) == len(updated_class_names)
121+
122+
def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class):
123+
updated_class_names = set()
124+
for meta_class in meta_classes:
125+
matched_parent_meta_class = [
126+
parent_meta_class
127+
for parent_meta_class in original_name2parent_meta_class.values()
128+
if meta_class.dtype == parent_meta_class.dtype
129+
and meta_class.shape == parent_meta_class.shape
130+
]
131+
if len(matched_parent_meta_class) == 1:
132+
self._update_tensor_meta(meta_class, matched_parent_meta_class[0])
133+
updated_class_names.add(meta_class.name)
134+
135+
print(
136+
f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated."
137+
)
138+
return len(meta_classes) == len(updated_class_names)
139+
140+
def _generate_py_code_from_meta_class(self, meta_class):
141+
lines = [f"class {meta_class.__name__}:"]
142+
members = vars(meta_class)
143+
members = {k: v for k, v in members.items() if not k.startswith("__")}
144+
145+
if not members:
146+
return lines[0] + "\n pass"
147+
148+
for name, value in members.items():
149+
value_str = (
150+
f"float('{repr(value)}')" if isinstance(value, float) else repr(value)
151+
)
152+
lines.append(f" {name} = {value_str}")
153+
return "\n".join(lines)

graph_net/paddle/naive_graph_decomposer.py

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import os
2+
from typing import List
3+
import paddle
4+
from graph_net import imp_util
25
from graph_net.paddle.extractor import GraphExtractor as BuiltinGraphExtractor
36

47

@@ -19,47 +22,67 @@ def __init__(
1922

2023
def make_config(
2124
self,
22-
split_positions=(),
25+
split_positions=None,
2326
group_head_and_tail=False,
2427
chain_style=False,
2528
output_dir="./tmp/naive_decomposer_dir",
29+
post_extract_process_path=None,
30+
post_extract_process_class_name=None,
31+
post_extract_process_config=None,
2632
):
27-
for pos in split_positions:
33+
assert not chain_style, "chain_style=True is not supported now."
34+
if split_positions is not None:
2835
assert isinstance(
29-
pos, int
30-
), f"split_positions should be list of int, {split_positions=}"
36+
split_positions, (tuple, list)
37+
), f"split_positions is expected to be tuple or list, but recived {split_positions=}"
38+
for pos in split_positions:
39+
assert isinstance(
40+
pos, int
41+
), f"split_positions is expected to be tuple or list of int, but recived {split_positions=}"
3142
return {
3243
"split_positions": split_positions,
3344
"group_head_and_tail": group_head_and_tail,
3445
"chain_style": chain_style,
3546
"output_dir": output_dir,
47+
"post_extract_process_path": post_extract_process_path,
48+
"post_extract_process_class_name": post_extract_process_class_name,
49+
"post_extract_process_config": post_extract_process_config,
3650
}
3751

3852
def __call__(self, **input_dict):
3953
extracted_model = self.get_naive_decomposer_extractor()(**input_dict)
4054
return extracted_model
4155

4256
def get_naive_decomposer_extractor(self):
43-
return NaiveDecomposerExtractor(self)
57+
return NaiveDecomposerExtractor(
58+
config=self.config,
59+
parent_model=self.model,
60+
parent_model_name=self.name,
61+
parent_input_spec=self.input_spec,
62+
)
4463

4564

4665
class NaiveDecomposerExtractor:
47-
def __init__(self, parent_graph_extractor):
48-
super().__init__()
49-
self.parent_graph_extractor = parent_graph_extractor
66+
def __init__(
67+
self,
68+
config: dict,
69+
parent_model: paddle.nn.Layer,
70+
parent_model_name: str,
71+
parent_input_spec: List[paddle.static.InputSpec],
72+
):
73+
self.config = config
5074
self.extracted = False
75+
self.parent_model_path = os.path.dirname(parent_model.__graph_net_file_path__)
5176
self.builtin_extractor = BuiltinGraphExtractor(
52-
model=parent_graph_extractor.model,
53-
name=parent_graph_extractor.name,
54-
dynamic=parent_graph_extractor.dynamic,
55-
input_spec=parent_graph_extractor.input_spec,
56-
workspace_path=self.parent_graph_extractor.config["output_dir"],
77+
model=parent_model,
78+
name=parent_model_name,
79+
dynamic=False,
80+
input_spec=parent_input_spec,
81+
workspace_path=self.config["output_dir"],
5782
)
58-
self.split_positions = self.parent_graph_extractor.config["split_positions"]
59-
self.group_head_and_tail = self.parent_graph_extractor.config[
60-
"group_head_and_tail"
61-
]
62-
self.post_process = self.make_post_process(self.parent_graph_extractor.config)
83+
self.split_positions = self.config["split_positions"]
84+
self.group_head_and_tail = self.config["group_head_and_tail"]
85+
self.post_extract_process = self.make_post_extract_process(self.config)
6386

6487
def do_extract(self, **input_dict):
6588
# 1. Run the model to dump pir programs
@@ -97,14 +120,17 @@ def __call__(self, **input_dict):
97120
if not self.extracted:
98121
extracted_model = self.do_extract(**input_dict)
99122
self.extracted = True
100-
# if self.extracted:
101-
# for subgraph_path in self.subgraph_path_list:
102-
# self.post_process(subgraph_path)
123+
124+
for subgraph_path in self.subgraph_path_list:
125+
self._post_extract_process(subgraph_path)
103126
return extracted_model
104127

105-
def make_post_process(self, config):
106-
return None
107-
# if config["post_process_path"] is None:
108-
# return None
109-
# module = imp_util.load_module(config["post_process_path"])
110-
# return module.PostExtractProcess(config["post_process_config"])
128+
def _post_extract_process(self, subgraph_path):
129+
return self.post_extract_process(subgraph_path)
130+
131+
def make_post_extract_process(self, config):
132+
if config.get("post_extract_process_path") is None:
133+
return lambda *args, **kwargs: None
134+
module = imp_util.load_module(config["post_extract_process_path"])
135+
cls = getattr(module, config["post_extract_process_class_name"])
136+
return cls(config["post_extract_process_config"], self.parent_model_path)

graph_net/paddle/run_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import os
2-
import sys
32
import json
43
import base64
54
import argparse
6-
from typing import Type
75

86
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
97

@@ -16,6 +14,7 @@ def load_class_from_file(file_path: str, class_name: str):
1614
print(f"Load {class_name} from {file_path}")
1715
module = imp_util.load_module(file_path, "unnamed")
1816
model_class = getattr(module, class_name, None)
17+
setattr(model_class, "__graph_net_file_path__", os.path.normpath(file_path))
1918
return model_class
2019

2120

@@ -26,7 +25,8 @@ def get_input_dict(model_path):
2625

2726
state_dict = {}
2827
for k, v in params.items():
29-
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=k)
28+
name = v["original_name"] if v.get("original_name", None) else k
29+
state_dict[k] = paddle.nn.parameter.Parameter(utils.replay_tensor(v), name=name)
3030
for k, v in inputs.items():
3131
state_dict[k] = utils.replay_tensor(v)
3232
return state_dict
@@ -83,4 +83,5 @@ def main(args):
8383
help="decorator configuration string",
8484
)
8585
args = parser.parse_args()
86+
print(args)
8687
main(args=args)

graph_net/paddle/utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
1-
import re
2-
from collections import OrderedDict
3-
import uuid
4-
import json
5-
import os
6-
import argparse
71
import importlib
8-
import inspect
92
import ast
103
import math
114
import numpy as np
@@ -141,7 +134,7 @@ def convert_to_valid_number(data_type, value):
141134

142135
def convert_meta_classes_to_tensors(file_path):
143136
current_device = paddle.device.get_device()
144-
for name, cls in _get_classes(file_path):
137+
for name, cls in get_meta_classes(file_path):
145138
attrs = {
146139
k: v
147140
for k, v in cls.__dict__.items()
@@ -169,10 +162,11 @@ def convert_meta_classes_to_tensors(file_path):
169162
},
170163
"data": data_value,
171164
"name": attrs.get("name"),
165+
"original_name": attrs.get("original_name", None),
172166
}
173167

174168

175-
def _get_classes(file_path):
169+
def get_meta_classes(file_path):
176170
with open(file_path, "r", encoding="utf-8") as f:
177171
tree = ast.parse(f.read(), filename=file_path)
178172

0 commit comments

Comments
 (0)