|
| 1 | +import os |
| 2 | +from graph_net import path_utils |
| 3 | +from graph_net.paddle import utils |
| 4 | + |
| 5 | + |
| 6 | +class GraphMetaRestorer: |
| 7 | + def __init__(self, config, parent_model_path): |
| 8 | + self.config = config |
| 9 | + self.parent_model_path = parent_model_path |
| 10 | + print(f"parent_model_path: {self.parent_model_path}") |
| 11 | + |
| 12 | + assert path_utils.is_single_model_dir( |
| 13 | + parent_model_path |
| 14 | + ), f"{parent_model_path=} is not a graphnet sample." |
| 15 | + ( |
| 16 | + parent_weight_meta_classes, |
| 17 | + parent_input_meta_classes, |
| 18 | + ) = self._load_weight_and_input_meta_classes(parent_model_path) |
| 19 | + self.original_name2parent_weight_meta_class = self._convert_to_dict( |
| 20 | + parent_weight_meta_classes |
| 21 | + ) |
| 22 | + self.original_name2parent_input_meta_class = self._convert_to_dict( |
| 23 | + parent_input_meta_classes |
| 24 | + ) |
| 25 | + |
| 26 | + def __call__(self, model_path): |
| 27 | + assert path_utils.is_single_model_dir( |
| 28 | + model_path |
| 29 | + ), f"{model_path=} is not a graphnet sample." |
| 30 | + ( |
| 31 | + weight_meta_classes, |
| 32 | + input_meta_classes, |
| 33 | + ) = self._load_weight_and_input_meta_classes(model_path) |
| 34 | + |
| 35 | + assert self.config["update_inplace"] |
| 36 | + is_weight_meta_fully_updated = self._update_by_original_name( |
| 37 | + weight_meta_classes, self.original_name2parent_weight_meta_class |
| 38 | + ) |
| 39 | + if is_weight_meta_fully_updated: |
| 40 | + new_weight_meta_codes = [] |
| 41 | + for meta_class in weight_meta_classes: |
| 42 | + new_weight_meta_codes.append( |
| 43 | + self._generate_py_code_from_meta_class(meta_class) |
| 44 | + ) |
| 45 | + |
| 46 | + weight_meta_file_path = os.path.join(model_path, "weight_meta.py") |
| 47 | + if self.config["update_inplace"]: |
| 48 | + print(f"[GraphMetaRestorer] Update {weight_meta_file_path}") |
| 49 | + with open(weight_meta_file_path, "w") as f: |
| 50 | + f.write("\n\n".join(new_weight_meta_codes)) |
| 51 | + |
| 52 | + is_input_meta_fully_updated = self._update_by_tensor_spec( |
| 53 | + input_meta_classes, self.original_name2parent_input_meta_class |
| 54 | + ) |
| 55 | + if is_input_meta_fully_updated: |
| 56 | + new_input_meta_codes = [] |
| 57 | + for meta_class in input_meta_classes: |
| 58 | + new_input_meta_codes.append( |
| 59 | + self._generate_py_code_from_meta_class(meta_class) |
| 60 | + ) |
| 61 | + |
| 62 | + input_meta_file_path = os.path.join(model_path, "input_meta.py") |
| 63 | + if self.config["update_inplace"]: |
| 64 | + print(f"[GraphMetaRestorer] Update {input_meta_file_path}") |
| 65 | + with open(input_meta_file_path, "w") as f: |
| 66 | + f.write("\n\n".join(new_input_meta_codes)) |
| 67 | + |
| 68 | + def _load_weight_and_input_meta_classes(self, model_path): |
| 69 | + weight_meta_file_path = os.path.join(model_path, "weight_meta.py") |
| 70 | + weight_meta_classes = [ |
| 71 | + meta_class |
| 72 | + for (name, meta_class) in utils.get_meta_classes(weight_meta_file_path) |
| 73 | + ] |
| 74 | + |
| 75 | + input_meta_file_path = os.path.join(model_path, "input_meta.py") |
| 76 | + input_meta_classes = [ |
| 77 | + meta_class |
| 78 | + for (name, meta_class) in utils.get_meta_classes(input_meta_file_path) |
| 79 | + ] |
| 80 | + |
| 81 | + return weight_meta_classes, input_meta_classes |
| 82 | + |
| 83 | + def _convert_to_dict(self, meta_classes): |
| 84 | + original_name2meta_class = {} |
| 85 | + for meta_class in meta_classes: |
| 86 | + assert meta_class.original_name not in original_name2meta_class.keys() |
| 87 | + original_name2meta_class[meta_class.original_name] = meta_class |
| 88 | + return original_name2meta_class |
| 89 | + |
| 90 | + def _update_tensor_meta(self, meta_class, parent_meta_class): |
| 91 | + if ( |
| 92 | + parent_meta_class |
| 93 | + and meta_class.dtype == parent_meta_class.dtype |
| 94 | + and meta_class.shape == parent_meta_class.shape |
| 95 | + ): |
| 96 | + for attr_name in ["max_val", "min_val", "mean", "std", "data"]: |
| 97 | + if hasattr(meta_class, attr_name) or hasattr( |
| 98 | + parent_meta_class, attr_name |
| 99 | + ): |
| 100 | + attr_value = getattr(parent_meta_class, attr_name, None) |
| 101 | + setattr(meta_class, attr_name, attr_value) |
| 102 | + return True |
| 103 | + return False |
| 104 | + |
| 105 | + def _update_by_original_name(self, meta_classes, original_name2parent_meta_class): |
| 106 | + updated_class_names = set() |
| 107 | + for meta_class in meta_classes: |
| 108 | + if not meta_class.original_name: |
| 109 | + continue |
| 110 | + |
| 111 | + parent_meta_class = original_name2parent_meta_class.get( |
| 112 | + meta_class.original_name, None |
| 113 | + ) |
| 114 | + if self._update_tensor_meta(meta_class, parent_meta_class): |
| 115 | + updated_class_names.add(meta_class.name) |
| 116 | + |
| 117 | + print( |
| 118 | + f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated." |
| 119 | + ) |
| 120 | + return len(meta_classes) == len(updated_class_names) |
| 121 | + |
| 122 | + def _update_by_tensor_spec(self, meta_classes, original_name2parent_meta_class): |
| 123 | + updated_class_names = set() |
| 124 | + for meta_class in meta_classes: |
| 125 | + matched_parent_meta_class = [ |
| 126 | + parent_meta_class |
| 127 | + for parent_meta_class in original_name2parent_meta_class.values() |
| 128 | + if meta_class.dtype == parent_meta_class.dtype |
| 129 | + and meta_class.shape == parent_meta_class.shape |
| 130 | + ] |
| 131 | + if len(matched_parent_meta_class) == 1: |
| 132 | + self._update_tensor_meta(meta_class, matched_parent_meta_class[0]) |
| 133 | + updated_class_names.add(meta_class.name) |
| 134 | + |
| 135 | + print( |
| 136 | + f"[GraphMetaRestorer] {len(updated_class_names)}/{len(meta_classes)} classes are updated." |
| 137 | + ) |
| 138 | + return len(meta_classes) == len(updated_class_names) |
| 139 | + |
| 140 | + def _generate_py_code_from_meta_class(self, meta_class): |
| 141 | + lines = [f"class {meta_class.__name__}:"] |
| 142 | + members = vars(meta_class) |
| 143 | + members = {k: v for k, v in members.items() if not k.startswith("__")} |
| 144 | + |
| 145 | + if not members: |
| 146 | + return lines[0] + "\n pass" |
| 147 | + |
| 148 | + for name, value in members.items(): |
| 149 | + value_str = ( |
| 150 | + f"float('{repr(value)}')" if isinstance(value, float) else repr(value) |
| 151 | + ) |
| 152 | + lines.append(f" {name} = {value_str}") |
| 153 | + return "\n".join(lines) |
0 commit comments