Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions graph_net/torch/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
mut_graph_codes=None,
placeholder_auto_rename=False,
workspace_path=None,
param_buffer_ids=None,
):
self.subgraph_counter = 0
self.name = name
Expand All @@ -64,6 +65,7 @@ def __init__(
raise EnvironmentError(
"Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set."
)
self.param_buffer_ids = param_buffer_ids or set()

def move_files(self, source_dir, target_dir):
os.makedirs(target_dir, exist_ok=True)
Expand Down Expand Up @@ -150,7 +152,16 @@ def try_rename_placeholder(node):
self.mut_graph_codes.append(base_code)

# 4. Save tensor metadata
converted = utils.convert_state_and_inputs(params, [])
# Separate model weights (parameters + buffers) from real inputs in params
weights = {}
example_inputs = {}
for name, value in params.items():
if id(value) in self.param_buffer_ids:
weights[name] = value
else:
example_inputs[name] = value

converted = utils.convert_state_and_inputs(weights, example_inputs)
utils.save_converted_to_text(converted, file_path=subgraph_path)
utils.save_constraints_text(
converted,
Expand Down Expand Up @@ -280,9 +291,24 @@ def wrapper(model: torch.nn.Module):
model_path = None
if hasattr(model, "__graph_net_file_path__"):
model_path = os.path.dirname(model.__graph_net_file_path__)
extractor = get_graph_extractor_maker(model_path)(
name, dynamic, mut_graph_codes, placeholder_auto_rename
)

# Collect parameter and buffer ids from the original model for distinguishing weights and inputs in __call__
param_buffer_ids = set()
for _, p in model.named_parameters():
param_buffer_ids.add(id(p))
for _, b in model.named_buffers():
param_buffer_ids.add(id(b))

maker = get_graph_extractor_maker(model_path)
if maker is GraphExtractor:
extractor = maker(
name, dynamic, mut_graph_codes, placeholder_auto_rename,
param_buffer_ids=param_buffer_ids,
)
else:
extractor = maker(
name, dynamic, mut_graph_codes, placeholder_auto_rename
)
# return torch.compile(backend=extractor, dynamic=dynamic)
compiled_model = torch.compile(model, backend=extractor, dynamic=dynamic)
return compiled_model
Expand Down
27 changes: 22 additions & 5 deletions graph_net/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def process_tensor(tensor):
processed_inputs = process_tensor(example_inputs)
elif isinstance(example_inputs, (list, tuple)):
processed_inputs = [process_tensor(t) for t in example_inputs]
elif isinstance(example_inputs, dict):
processed_inputs = {k: process_tensor(v) for k, v in example_inputs.items()}
else:
processed_inputs = {"type": "unknown", "value": example_inputs}

Expand Down Expand Up @@ -181,13 +183,28 @@ def process_tensor_info(tensor_info, name_prefix="example_input"):
return lines

input_infos = converted["input_info"]
if isinstance(input_infos, dict):
input_infos = [input_infos]

input_lines = []
for idx, input_info in enumerate(input_infos):
input_info["name"] = f"input_{idx}"
input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input"))
if isinstance(input_infos, dict) and input_infos:
# Check if it's a dict of named tensor infos (e.g., placeholder inputs)
first_val = next(iter(input_infos.values()))
if isinstance(first_val, dict) and "type" in first_val:
# Named inputs: {name: tensor_info}
for name, input_info in input_infos.items():
input_info["name"] = name
input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input"))
else:
# Single input info dict (e.g., a single tensor's info)
input_infos = [input_infos]
for idx, input_info in enumerate(input_infos):
input_info["name"] = f"input_{idx}"
input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input"))
else:
if isinstance(input_infos, dict):
input_infos = [input_infos]
for idx, input_info in enumerate(input_infos):
input_info["name"] = f"input_{idx}"
input_lines.extend(process_tensor_info(input_info, name_prefix="Program_input"))

with open(f"{file_path}/input_meta.py", "w") as f:
f.write("\n".join(input_lines))
Expand Down