Skip to content

Commit 7e02a5d

Browse files
authored
1) Init sym dim reifier in graph net json; 2) refactor model_path_handler (#420)
* init 'symbolic_dimension_reifier' field in graph_net.json * refactor model_path_handler
1 parent e6a16e2 commit 7e02a5d

15 files changed

+434
-530
lines changed

graph_net/config/empty_cstr_torch_samples_list.txt

Lines changed: 151 additions & 487 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
samples/transformers-auto-model/microsoft_xclip-base-patch32-16-frames

graph_net/model_path_handler.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import traceback
21
import argparse
32
from graph_net.imp_util import load_module
43
import logging
54
import sys
65
import json
76
import base64
7+
import subprocess
88

99
logging.basicConfig(
1010
level=logging.WARNING, format="%(asctime)s [%(levelname)s] %(message)s"
@@ -37,33 +37,49 @@ def _get_handler(args):
3737

3838
def main(args):
3939
handler = _get_handler(args)
40-
for model_path in _get_model_paths(args):
41-
print(f"{model_path=}")
40+
if args.model_path is not None:
41+
handle_model_path(handler, args.model_path)
42+
elif args.use_subprocess:
43+
handle_model_path_list_in_subprocess(args)
44+
else:
45+
handle_model_path_list_in_current_process(handler, args)
46+
47+
48+
def handle_model_path_list_in_current_process(handler, args):
49+
for model_path in _get_model_path_list(args):
4250
try:
43-
handler(model_path)
51+
handle_model_path(handler, model_path)
4452
except KeyboardInterrupt:
45-
sys.exit(-1)
46-
except Exception as e:
47-
print("--- Concise Error Message ---")
48-
print(e)
53+
print("KeyboardInterrupt")
54+
return
4955

50-
print("\n--- Full Traceback ---")
51-
traceback.print_exc()
5256

57+
def handle_model_path_list_in_subprocess(args):
58+
for model_path in _get_model_path_list(args):
59+
cmd = f"{sys.executable} -m graph_net.model_path_handler --model-path {model_path} --handler-config {args.handler_config}"
60+
try:
61+
subprocess.Popen(cmd, shell=True).wait()
62+
except KeyboardInterrupt:
63+
print("KeyboardInterrupt")
64+
return
5365

54-
def _get_model_paths(args):
55-
assert args.model_path is not None or args.model_path_list is not None
56-
if args.model_path is not None:
57-
yield args.model_path
58-
if args.model_path_list is not None:
59-
with open(args.model_path_list) as f:
60-
yield from (
61-
clean_line
62-
for line in f
63-
for clean_line in [line.strip()]
64-
if len(clean_line) > 0
65-
if not clean_line.startswith("#")
66-
)
66+
67+
def handle_model_path(handler, model_path):
68+
print(f"{model_path=}", flush=True)
69+
handler(model_path)
70+
71+
72+
def _get_model_path_list(args):
73+
assert args.model_path is None
74+
assert args.model_path_list is not None
75+
with open(args.model_path_list) as f:
76+
yield from (
77+
clean_line
78+
for line in f
79+
for clean_line in [line.strip()]
80+
if len(clean_line) > 0
81+
if not clean_line.startswith("#")
82+
)
6783

6884

6985
if __name__ == "__main__":
@@ -89,5 +105,11 @@ def _get_model_paths(args):
89105
default=None,
90106
help="handler configuration string",
91107
)
108+
parser.add_argument(
109+
"--use-subprocess",
110+
action="store_true",
111+
default=False,
112+
help="use subprocess",
113+
)
92114
args = parser.parse_args()
93115
main(args=args)

graph_net/test/decomposer_validator_test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ echo "Results saved in: $FILE_PATH/ES_result.png"
4848
echo ""
4949
echo "IMPORTANT: Please verify if the curve in ES_result.png is a straight line"
5050
echo "If the curve is NOT a straight line, please check the log file: $FILE_PATH/log.log"
51-
echo "=================================================="
51+
echo "=================================================="

graph_net/test/naive_graph_decomposer_test.sh

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ config_json_str=$(cat <<EOF
1313
"handler_config": {
1414
"output_dir": "/tmp/naive_decompose_workspace",
1515
"split_positions": [8, 16, 32],
16-
"group_head_and_tail": true,
17-
"filter_path":"$GRAPH_NET_ROOT/torch/naive_subgraph_filter.py",
18-
"filter_config": {}
16+
"chain_style": true,
17+
"group_head_and_tail": true
1918
}
2019
}
2120
EOF

graph_net/tools/_get_in_tensor_symbolic_shapes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sympy
12
from pathlib import Path
23
from graph_net.dynamic_dim_constraints import DynamicDimConstraints
34
import graph_net.graph_net_json_file_util as gn_json
@@ -27,6 +28,10 @@ def __call__(self, model_path):
2728
dyn_dim_cstrs = DynamicDimConstraints.unserialize_from_py_file(
2829
str(input_tensor_cstr_filepath)
2930
)
31+
for shape, name in dyn_dim_cstrs.input_shapes:
32+
if not any(isinstance(dim, sympy.Expr) for dim in shape):
33+
continue
34+
print(f"{shape=} {name=}")
3035
input_shapes_str = str(dyn_dim_cstrs.serialize_symbolic_input_shapes_to_str())
3136
print(f"get-in-tensor-symbolic-shapes {input_shapes_str} {model_path}")
3237

graph_net/tools/batch_init_input_tensor_constraints.sh

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ config_json_str=$(cat <<EOF
3434
"non_batch_call_function_arange_plus_one_pass"
3535
]
3636
},
37-
"limits_handled_models": 1,
37+
"limits_handled_models": 999999,
3838
"last_model_log_file": "/tmp/a.py"
3939
}
4040
}
4141
EOF
4242
)
4343
CONFIG=$(echo $config_json_str | base64 -w 0)
4444

45-
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG
45+
python3 -m graph_net.model_path_handler --model-path-list $GRAPH_NET_ROOT/config/empty_cstr_torch_samples_list.txt --handler-config=$CONFIG --use-subprocess

graph_net/tools/statisticize_in_tensor_symbolic_shapes.sh

100644100755
File mode changed.

graph_net/tools/update_sym_dim_reifier.sh

100644100755
File mode changed.

graph_net/torch/dim_gen_passes/batch_call_method_view_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def get_new_tuple_args(input_tensor_node, view_args):
7676
input_tensor_node = node.args[0]
7777
# Get the target shape arguments for view (e.g., 1, -1, 6, 64)
7878
view_args = node.args[1:]
79-
print(f"{view_args=}")
8079
new_view_args = get_new_tuple_args(input_tensor_node, view_args)
8180

8281
# --- Rebuild the view node ---

0 commit comments

Comments
 (0)