Skip to content

Commit b2a820e

Browse files
committed
feat: add AgentUnittestGenerator class for run_model.py in both paddle and torch.
1 parent 7e02a5d commit b2a820e

File tree

3 files changed

+317
-28
lines changed

3 files changed

+317
-28
lines changed

graph_net/paddle/run_model.py

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
import argparse
2+
import base64
3+
import importlib.util
4+
import json
15
import os
26
import sys
3-
import json
4-
import base64
5-
import argparse
6-
from typing import Type
77

88
os.environ["FLAGS_logging_pir_py_code_dir"] = "/tmp/dump"
99

1010
import paddle
1111
from graph_net import imp_util
1212
from graph_net.paddle import utils
13+
from jinja2 import Template
1314

1415

1516
def load_class_from_file(file_path: str, class_name: str):
@@ -41,16 +42,125 @@ def _convert_to_dict(config_str):
4142
return config
4243

4344

44-
def _get_decorator(args):
45-
if args.decorator_config is None:
46-
return lambda model: model
47-
decorator_config = _convert_to_dict(args.decorator_config)
48-
if "decorator_path" not in decorator_config:
45+
def _get_decorator(arg):
46+
"""兼容旧接口:既接受 argparse.Namespace,也接受已解析的 dict。"""
47+
if arg is None:
4948
return lambda model: model
50-
decorator_class = load_class_from_file(
51-
decorator_config["decorator_path"], class_name="RunModelDecorator"
49+
50+
decorator_config = (
51+
_convert_to_dict(arg.decorator_config)
52+
if hasattr(arg, "decorator_config")
53+
else arg
5254
)
53-
return decorator_class(decorator_config.get("decorator_config", {}))
55+
if not decorator_config:
56+
return lambda model: model
57+
58+
class_name = decorator_config.get("decorator_class_name", "RunModelDecorator")
59+
decorator_kwargs = decorator_config.get("decorator_config", {})
60+
61+
if "decorator_path" in decorator_config:
62+
decorator_class = load_class_from_file(
63+
decorator_config["decorator_path"], class_name=class_name
64+
)
65+
return decorator_class(decorator_kwargs)
66+
67+
if hasattr(sys.modules[__name__], class_name):
68+
decorator_class = getattr(sys.modules[__name__], class_name)
69+
return decorator_class(decorator_kwargs)
70+
71+
return lambda model: model
72+
73+
74+
class AgentUnittestGenerator:
    """Generate a standalone unittest script for a Paddle subgraph.

    Used as a ``RunModelDecorator``-style callable: calling the instance with a
    model writes a ``*_test.py`` file (verifying the forward pass runs) next to
    the sample and returns the model unchanged.

    Config keys (all optional except ``model_path``):
        model_path:   directory of the sample; also the default output location.
        output_path:  where to write the generated test; defaults to
                      ``<model_path>/<basename>_test.py``.
        force_device: ``"auto"`` (pick gpu iff CUDA is compiled in), ``"cpu"``,
                      or ``"gpu"``.
        use_numpy:    forwarded into the generated test's ``replay_tensor`` calls.
    """

    def __init__(self, config):
        defaults = {
            "model_path": None,
            "output_path": None,
            "force_device": "auto",  # auto / cpu / gpu
            "use_numpy": True,
        }
        merged = {**defaults, **(config or {})}
        if merged["model_path"] is None:
            raise ValueError("AgentUnittestGenerator requires 'model_path' in config")
        self.model_path = merged["model_path"]
        self.output_path = merged["output_path"] or self._default_output_path()
        self.force_device = merged["force_device"]
        self.use_numpy = merged["use_numpy"]

    def __call__(self, model):
        # Decorator protocol: side effect only; the model is passed through untouched.
        self._generate_unittest_file()
        return model

    def _default_output_path(self):
        # normpath strips a trailing slash so basename is the sample dir name.
        base = os.path.basename(os.path.normpath(self.model_path))
        return os.path.join(self.model_path, f"{base}_test.py")

    def _choose_device(self):
        """Return the device string for the generated test ('cpu' or 'gpu')."""
        if self.force_device == "cpu":
            return "cpu"
        if self.force_device == "gpu":
            return "gpu"
        # "auto" (and any unrecognized value) falls through to CUDA detection.
        return "gpu" if paddle.device.is_compiled_with_cuda() else "cpu"

    def _generate_unittest_file(self):
        """Render the unittest template and write it to ``self.output_path``."""
        target_device = self._choose_device()
        template_str = """
import importlib.util
import os
import unittest

import paddle
from graph_net.paddle import utils


def _load_graph_module(model_path: str):
    source_path = os.path.join(model_path, "model.py")
    spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.GraphModule


class AgentGraphTest(unittest.TestCase):
    def setUp(self):
        self.model_path = os.path.dirname(__file__)
        self.target_device = "{{ target_device }}"
        paddle.set_device(self.target_device)
        self.GraphModule = _load_graph_module(self.model_path)
        self.meta = utils.load_converted_from_text(self.model_path)
        self.use_numpy = {{ use_numpy_flag }}

    def _with_device(self, info):
        cloned = {"info": dict(info["info"]), "data": info.get("data")}
        cloned["info"]["device"] = self.target_device
        return cloned

    def _build_tensor(self, meta):
        return utils.replay_tensor(self._with_device(meta), use_numpy=self.use_numpy)

    def test_forward_runs(self):
        model = self.GraphModule()
        inputs = {k: self._build_tensor(v) for k, v in self.meta["input_info"].items()}
        params = {k: self._build_tensor(v) for k, v in self.meta["weight_info"].items()}
        model.__graph_net_file_path__ = self.model_path
        output = model(**params, **inputs)
        self.assertIsNotNone(output)


if __name__ == "__main__":
    unittest.main()
"""

        rendered = Template(template_str).render(
            target_device=target_device, use_numpy_flag=self.use_numpy
        )

        # Fix: a bare filename has no directory component, and
        # os.makedirs("") raises FileNotFoundError — only create real dirs.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.output_path, "w", encoding="utf-8") as f:
            f.write(rendered)
        print(f"[Agent] unittest 已生成: {self.output_path} (device={target_device})")
54164

55165

56166
def main(args):
@@ -61,9 +171,14 @@ def main(args):
61171
assert model_class is not None
62172
model = model_class()
63173
print(f"{model_path=}")
174+
decorator_config = _convert_to_dict(args.decorator_config)
175+
if decorator_config:
176+
decorator_config.setdefault("decorator_config", {})
177+
decorator_config["decorator_config"].setdefault("model_path", model_path)
178+
decorator_config["decorator_config"].setdefault("use_numpy", True)
64179

180+
model = _get_decorator(decorator_config)(model)
65181
input_dict = get_input_dict(args.model_path)
66-
model = _get_decorator(args)(model)
67182
model(**input_dict)
68183

69184

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env bash
set -euo pipefail

# Smoke tests for AgentUnittestGenerator on one CV and one NLP sample (Torch side).
# It runs run_model with the decorator, which will drop a *_test.py under each sample directory.

ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
TORCH_RUN="python -m graph_net.torch.run_model"

CV_SAMPLE="$ROOT_DIR/samples/torchvision/resnet18"
NLP_SAMPLE="$ROOT_DIR/samples/transformers-auto-model/albert-base-v2"

# Print the base64-encoded decorator config for the sample directory given as $1.
# The sample path is handed to the inline Python via the MODEL_PATH env var so
# no shell quoting of the path can corrupt the JSON.
# NOTE(review): this config uses "use_dummy_inputs" — confirm the torch-side
# AgentUnittestGenerator accepts that key (the paddle side uses "use_numpy").
encode_cfg() {
  MODEL_PATH="$1" python - <<'PY'
import base64, json, os
cfg = {
    "decorator_class_name": "AgentUnittestGenerator",
    "decorator_config": {
        "model_path": os.environ["MODEL_PATH"],
        "force_device": "auto",
        "output_path": None,
        "use_dummy_inputs": False,
    },
}
print(base64.b64encode(json.dumps(cfg).encode()).decode())
PY
}

# Run one sample through run_model with the generator decorator attached.
run_case() {
  local sample_dir="$1"
  local label="$2"
  local encoded
  echo "[AgentTest] running $label sample at $sample_dir"
  encoded="$(encode_cfg "$sample_dir")"
  $TORCH_RUN --model-path "$sample_dir" --decorator-config "$encoded"
}

run_case "$CV_SAMPLE" "CV (torchvision/resnet18)"
run_case "$NLP_SAMPLE" "NLP (transformers-auto-model/albert-base-v2)"

echo "[AgentTest] done. Generated *_test.py files should now exist beside the samples."

0 commit comments

Comments
 (0)