1+ import argparse
2+ import base64
3+ import importlib .util
4+ import json
15import os
26import sys
3- import json
4- import base64
5- import argparse
6- from typing import Type
77
88os .environ ["FLAGS_logging_pir_py_code_dir" ] = "/tmp/dump"
99
1010import paddle
1111from graph_net import imp_util
1212from graph_net .paddle import utils
13+ from jinja2 import Template
1314
1415
1516def load_class_from_file (file_path : str , class_name : str ):
@@ -41,16 +42,125 @@ def _convert_to_dict(config_str):
4142 return config
4243
4344
44- def _get_decorator (args ):
45- if args .decorator_config is None :
46- return lambda model : model
47- decorator_config = _convert_to_dict (args .decorator_config )
48- if "decorator_path" not in decorator_config :
45+ def _get_decorator (arg ):
46+ """兼容旧接口:既接受 argparse.Namespace,也接受已解析的 dict。"""
47+ if arg is None :
4948 return lambda model : model
50- decorator_class = load_class_from_file (
51- decorator_config ["decorator_path" ], class_name = "RunModelDecorator"
49+
50+ decorator_config = (
51+ _convert_to_dict (arg .decorator_config )
52+ if hasattr (arg , "decorator_config" )
53+ else arg
5254 )
53- return decorator_class (decorator_config .get ("decorator_config" , {}))
55+ if not decorator_config :
56+ return lambda model : model
57+
58+ class_name = decorator_config .get ("decorator_class_name" , "RunModelDecorator" )
59+ decorator_kwargs = decorator_config .get ("decorator_config" , {})
60+
61+ if "decorator_path" in decorator_config :
62+ decorator_class = load_class_from_file (
63+ decorator_config ["decorator_path" ], class_name = class_name
64+ )
65+ return decorator_class (decorator_kwargs )
66+
67+ if hasattr (sys .modules [__name__ ], class_name ):
68+ decorator_class = getattr (sys .modules [__name__ ], class_name )
69+ return decorator_class (decorator_kwargs )
70+
71+ return lambda model : model
72+
73+
74+ class AgentUnittestGenerator :
75+ """生成 Paddle 子图的独立 unittest 脚本,验证前向可运行。"""
76+
77+ def __init__ (self , config ):
78+ defaults = {
79+ "model_path" : None ,
80+ "output_path" : None ,
81+ "force_device" : "auto" , # auto / cpu / gpu
82+ "use_numpy" : True ,
83+ }
84+ merged = {** defaults , ** (config or {})}
85+ if merged ["model_path" ] is None :
86+ raise ValueError ("AgentUnittestGenerator requires 'model_path' in config" )
87+ self .model_path = merged ["model_path" ]
88+ self .output_path = merged ["output_path" ] or self ._default_output_path ()
89+ self .force_device = merged ["force_device" ]
90+ self .use_numpy = merged ["use_numpy" ]
91+
92+ def __call__ (self , model ):
93+ self ._generate_unittest_file ()
94+ return model
95+
96+ def _default_output_path (self ):
97+ base = os .path .basename (os .path .normpath (self .model_path ))
98+ return os .path .join (self .model_path , f"{ base } _test.py" )
99+
100+ def _choose_device (self ):
101+ if self .force_device == "cpu" :
102+ return "cpu"
103+ if self .force_device == "gpu" :
104+ return "gpu"
105+ return "gpu" if paddle .device .is_compiled_with_cuda () else "cpu"
106+
107+ def _generate_unittest_file (self ):
108+ target_device = self ._choose_device ()
109+ template_str = """
110+ import importlib.util
111+ import os
112+ import unittest
113+
114+ import paddle
115+ from graph_net.paddle import utils
116+
117+
118+ def _load_graph_module(model_path: str):
119+ source_path = os.path.join(model_path, "model.py")
120+ spec = importlib.util.spec_from_file_location("agent_graph_module", source_path)
121+ module = importlib.util.module_from_spec(spec)
122+ spec.loader.exec_module(module)
123+ return module.GraphModule
124+
125+
126+ class AgentGraphTest(unittest.TestCase):
127+ def setUp(self):
128+ self.model_path = os.path.dirname(__file__)
129+ self.target_device = "{{ target_device }}"
130+ paddle.set_device(self.target_device)
131+ self.GraphModule = _load_graph_module(self.model_path)
132+ self.meta = utils.load_converted_from_text(self.model_path)
133+ self.use_numpy = {{ use_numpy_flag }}
134+
135+ def _with_device(self, info):
136+ cloned = {"info": dict(info["info"]), "data": info.get("data")}
137+ cloned["info"]["device"] = self.target_device
138+ return cloned
139+
140+ def _build_tensor(self, meta):
141+ return utils.replay_tensor(self._with_device(meta), use_numpy=self.use_numpy)
142+
143+ def test_forward_runs(self):
144+ model = self.GraphModule()
145+ inputs = {k: self._build_tensor(v) for k, v in self.meta["input_info"].items()}
146+ params = {k: self._build_tensor(v) for k, v in self.meta["weight_info"].items()}
147+ model.__graph_net_file_path__ = self.model_path
148+ output = model(**params, **inputs)
149+ self.assertIsNotNone(output)
150+
151+
152+ if __name__ == "__main__":
153+ unittest.main()
154+ """
155+
156+ rendered = Template (template_str ).render (
157+ target_device = target_device , use_numpy_flag = self .use_numpy
158+ )
159+
160+ os .makedirs (os .path .dirname (self .output_path ), exist_ok = True )
161+ with open (self .output_path , "w" , encoding = "utf-8" ) as f :
162+ f .write (rendered )
163+ print (f"[Agent] unittest 已生成: { self .output_path } (device={ target_device } )" )
54164
55165
56166def main (args ):
@@ -61,9 +171,14 @@ def main(args):
61171 assert model_class is not None
62172 model = model_class ()
63173 print (f"{ model_path = } " )
174+ decorator_config = _convert_to_dict (args .decorator_config )
175+ if decorator_config :
176+ decorator_config .setdefault ("decorator_config" , {})
177+ decorator_config ["decorator_config" ].setdefault ("model_path" , model_path )
178+ decorator_config ["decorator_config" ].setdefault ("use_numpy" , True )
64179
180+ model = _get_decorator (decorator_config )(model )
65181 input_dict = get_input_dict (args .model_path )
66- model = _get_decorator (args )(model )
67182 model (** input_dict )
68183
69184
0 commit comments