@@ -27,6 +27,28 @@ def get_pass_name(pass_id):
2727 return f"pass_{ pass_id } "
2828
2929
def get_ranged_incorrect_models(tolerance_args: List[int], log_path: str) -> set:
    """Return models failing at the first tolerance but not at the second.

    With a single tolerance in ``tolerance_args`` this is simply the set of
    models failing at that level.  With two tolerances it is the difference
    ``Fail(t_start) - Fail(t_end)``, i.e. models that fail at the looser
    level yet pass at the stricter one.  Missing ``log_path`` yields an
    empty set.
    """
    if not os.path.exists(log_path):
        return set()

    lower = tolerance_args[0]
    failing_at_lower = set(get_incorrect_models(lower, log_path))
    if len(tolerance_args) == 1:
        return failing_at_lower

    upper = tolerance_args[1]
    failing_at_upper = set(get_incorrect_models(upper, log_path))

    print(f"[Filter] Tolerance Range: {lower} -> {upper}")
    print(
        f"[Filter] Fail({lower}): {len(failing_at_lower)}, Fail({upper}): {len(failing_at_upper)}"
    )

    return failing_at_lower - failing_at_upper
50+
51+
3052class TaskController :
3153 def __init__ (self , args ):
3254 self .root_output_dir = os .path .abspath (args .output_dir )
@@ -198,10 +220,10 @@ def run_decomposer_for_multi_models(
198220 )
199221 for model_name , task_info in tasks_map .items ():
200222 original_path = task_info ["original_path" ]
201- split_positions = calculate_split_positions_for_subgraph (
202- task_info ["subgraph_size" ], max_subgraph_size
203- )
204- task_info [ " split_positions" ] = split_positions
223+
224+ split_positions = task_info ["split_positions" ]
225+ if isinstance ( split_positions , set ):
226+ split_positions = sorted ( list ( split_positions ))
205227
206228 rectified_model_path = get_rectfied_model_path (original_path )
207229 assert os .path .exists (
@@ -262,35 +284,39 @@ def reconstruct_subgraph_size(split_positions: List[int]) -> List[list]:
262284 return subgraph_size
263285
264286
def calculate_split_positions_for_subgraph(subgraph_range, max_subgraph_size):
    """Compute the sorted split positions covering ``subgraph_range``.

    Args:
        subgraph_range: two-element list/tuple ``[start_pos, end_pos]`` giving
            the start and end position in the original model; ``end_pos`` may
            be ``float("inf")``, which is clamped to ``kMaxGraphSize``.
        max_subgraph_size: stride between consecutive split positions.

    Returns:
        Ascending list of split positions from ``start_pos`` up to and
        including ``end_pos`` (when it lands on a stride boundary).
    """
    assert isinstance(subgraph_range, (list, tuple)) and len(subgraph_range) == 2

    start_pos, end_pos = subgraph_range
    end_pos = kMaxGraphSize if end_pos == float("inf") else end_pos

    # range() with a positive step already yields unique, ascending values,
    # so the previous set(...) + sorted(...) round-trip was redundant.
    return list(range(start_pos, end_pos + 1, max_subgraph_size))
275297
276298
277299def generate_initial_tasks (args ):
278300 """Generates tasks for Pass 0 based on the initial log file."""
279301 print (f"[Init] Pass 0: Reading from log file: { args .log_file } " )
280- initial_failures = get_incorrect_models (args .tolerance , args .log_file )
281- t1_incorrect_models = get_incorrect_models (1 , args .log_file )
282- initial_failures = initial_failures - t1_incorrect_models
302+ initial_failures = get_ranged_incorrect_models (args .tolerance , args .log_file )
283303
284304 tasks_map = {}
305+ max_subgraph_size = args .max_subgraph_size
306+
285307 for model_path in initial_failures :
286308 model_name = get_model_name_with_subgraph_tag (model_path )
309+
310+ initial_range = [0 , kMaxGraphSize ]
311+ initial_splits = calculate_split_positions_for_subgraph (
312+ initial_range , max_subgraph_size
313+ )
314+
287315 tasks_map [model_name ] = {
288316 "original_path" : model_path ,
289- "subgraph_size" : [0 , kMaxGraphSize ],
290- "split_positions" : set (),
317+ "split_positions" : list (sorted (initial_splits )),
291318 }
292319
293- max_subgraph_size = args .max_subgraph_size
294320 running_states = {
295321 "pass_0" : {
296322 "num_incorrect_models" : len (initial_failures ),
@@ -322,19 +348,28 @@ def generate_refined_tasks(base_output_dir, current_pass_id):
322348 assert model_name in prev_tasks_map
323349 pre_task_for_model = prev_tasks_map [model_name ]
324350
325- # Reconstruct previous subgraph size to locate the failing segment
326351 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
327- subgraph_size = reconstruct_subgraph_size (prev_split_positions )
352+ subgraph_ranges = reconstruct_subgraph_size (prev_split_positions )
353+
328354 assert subgraph_idx < len (
329- subgraph_size
355+ subgraph_ranges
330356 ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
331357
358+ split_positions = calculate_split_positions_for_subgraph (
359+ subgraph_ranges [subgraph_idx ], max_subgraph_size
360+ )
332361 if model_name not in tasks_map :
333362 tasks_map [model_name ] = {
334363 "original_path" : pre_task_for_model ["original_path" ],
335- "subgraph_size" : subgraph_size [subgraph_idx ],
336- "split_positions" : set (),
364+ "split_positions" : list (sorted (split_positions )),
337365 }
366+ else :
367+ merged_split_positions = (
368+ tasks_map [model_name ]["split_positions" ] + split_positions
369+ )
370+ tasks_map [model_name ]["split_positions" ] = list (
371+ sorted (set (merged_split_positions ))
372+ )
338373
339374 return tasks_map , max_subgraph_size , prev_config .running_states
340375
@@ -399,11 +434,23 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
399434 need_decompose = True
400435 shutil .rmtree (decomposed_samples_dir )
401436 os .makedirs (decomposed_samples_dir , exist_ok = True )
437+ max_subgraph_size = max (1 , max_subgraph_size // 2 )
402438 for model_name , task_info in tasks_map .items ():
403- task_info ["subgraph_size" ][1 ] = (
404- task_info ["subgraph_size" ][0 ] + max_subgraph_size
439+ splits = task_info ["split_positions" ]
440+ if not splits or len (splits ) < 2 :
441+ continue
442+ if isinstance (splits , set ):
443+ splits = sorted (list (splits ))
444+ start_pos = splits [0 ]
445+ first_segment_end = splits [1 ]
446+ new_splits = list (
447+ range (start_pos , first_segment_end + 1 , max_subgraph_size )
405448 )
406- max_subgraph_size = max (1 , max_subgraph_size // 2 )
449+
450+ if new_splits [- 1 ] != first_segment_end :
451+ new_splits .append (first_segment_end )
452+
453+ task_info ["split_positions" ] = sorted (list (set (new_splits )))
407454 else :
408455 need_decompose = False
409456 print ()
@@ -458,6 +505,7 @@ def main(args):
458505 "failed_decomposition_models"
459506 ] = list (failed_decomposition )
460507 else :
508+ print ("\n --- Phase 1: Decomposition (skipped) ---" , flush = True )
461509 config = DecomposeConfig .load (pass_work_dir )
462510 max_subgraph_size = config .max_subgraph_size
463511 tasks_map = config .tasks_map
@@ -466,19 +514,26 @@ def main(args):
466514 # --- Step 3: Evaluation ---
467515 pass_log_path = os .path .join (pass_work_dir , "batch_test_result.log" )
468516 if task_controller .task_scheduler ["run_evaluation" ]:
469- print ("\n --- Phase 2: Evaluation ---" )
517+ print (f "\n --- Phase 2: Evaluation ( { task_controller . test_module_name } ) ---" )
470518 run_evaluation (args .framework , args .test_config , pass_work_dir , pass_log_path )
471519
472520 # --- Step 4: Analysis ---
473521 next_round_models = set ()
474522 if task_controller .task_scheduler ["post_analysis" ]:
475- print ("\n --- Phase 3: Analysis ---" )
476- next_round_models = sorted (get_incorrect_models (args .tolerance , pass_log_path ))
477- print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
523+ tolerance = (
524+ args .tolerance [0 ] if isinstance (args .tolerance , list ) else args .tolerance
525+ )
526+ print (f"\n --- Phase 3: Analysis (torlance={ tolerance } ) ---" )
527+ next_round_models = sorted (get_incorrect_models (tolerance , pass_log_path ))
478528 running_states [f"pass_{ current_pass_id + 1 } " ] = {
479529 "num_incorrect_models" : len (next_round_models ),
480530 "incorrect_models" : list (next_round_models ),
481531 }
532+
533+ print (f"[Analysis] Found { len (next_round_models )} incorrect subgraphs.\n " )
534+ for idx , model_path in enumerate (next_round_models ):
535+ print (f"- [{ idx } ] { model_path } " )
536+
482537 print_summary_and_suggestion (next_round_models , max_subgraph_size )
483538
484539 # --- Step 5: Save States ---
@@ -500,7 +555,11 @@ def main(args):
500555 "--test-config" , type = str , required = True , help = "Base64 encoded test config"
501556 )
502557 parser .add_argument (
503- "--tolerance" , type = int , required = True , help = "Tolerance level range [-10, 5)"
558+ "--tolerance" ,
559+ type = int ,
560+ nargs = "+" ,
561+ required = True ,
562+ help = "Tolerance level range [-10, 5)" ,
504563 )
505564 parser .add_argument ("--max-subgraph-size" , type = int , default = 4096 )
506565 args = parser .parse_args ()
0 commit comments