11import config
22import problem
3- import run
43import random
54import generate
5+ import shutil
66import signal
7+ import sys
78import time
89import threading
10+ from colorama import Style
911from pathlib import Path
10- from typing import Any , Optional
12+ from typing import Any , Optional , TextIO
1113
1214import parallel
1315from util import *
16+ from run import Run , Submission
1417from testcase import Testcase
1518from validate import OutputValidator , Mode
1619from verdicts import Verdict
2730class GeneratorTask :
2831 def __init__ (self , fuzz : "Fuzz" , t : generate .TestcaseRule , i : int , tmp_id : int ):
2932 self .fuzz = fuzz
33+ self .rule = t
3034 generator = t .generator
3135 assert generator is not None
3236 self .generator = generator
@@ -52,6 +56,7 @@ def _run(self, bar: ProgressBar) -> bool:
5256 # GENERATE THE TEST DATA
5357 dir = Path ("fuzz" ) / f"tmp_id_{ str (self .tmp_id )} "
5458 cwd = self .fuzz .problem .tmpdir / "tool_runs" / dir
59+ shutil .rmtree (cwd , ignore_errors = True )
5560 cwd .mkdir (parents = True , exist_ok = True )
5661 name = "testcase"
5762 infile = cwd / (name + ".in" )
@@ -63,51 +68,64 @@ def _run(self, bar: ProgressBar) -> bool:
6368
6469 localbar = bar .start (f"{ self .i } : generate" )
6570 result = self .generator .run (localbar , cwd , name , self .seed )
71+ self .fuzz .queue .ensure_alive ()
6672 if not result .status :
6773 return False # No need to call bar.done() in this case, because the Generator calls bar.error()
74+ if ".ans" in self .rule .hardcoded :
75+ ansfile .write_text (self .rule .hardcoded [".ans" ])
6876 localbar .done ()
6977
7078 testcase = Testcase (self .fuzz .problem , infile , short_path = dir / (name + ".in" ))
7179
7280 # Validate the generated .in.
7381 localbar = bar .start (f"{ self .i } : validate input" )
7482 if not testcase .validate_format (Mode .INPUT , bar = localbar , constraints = None ):
83+ self .fuzz .queue .ensure_alive ()
7584 localbar .done (False )
7685 return False
86+ self .fuzz .queue .ensure_alive ()
7787 localbar .done ()
7888
7989 # Generate .ans.
80- if not self .fuzz .problem .interactive and not self .fuzz .problem .multi_pass :
81- if self .solution and not testcase .ans_path .is_file ():
82- if testcase .ans_path .is_file ():
83- testcase .ans_path .unlink ()
84- # Run the solution and validate the generated .ans.
85- localbar = bar .start (f"{ self .i } : generate ans" )
86- if not self .solution .run (bar , cwd ).status :
90+ if not ansfile .is_file ():
91+ if self .fuzz .problem .settings .ans_is_output :
92+ if self .solution :
93+ # Run the solution and validate the generated .ans.
94+ localbar = bar .start (f"{ self .i } : generate ans" )
95+ if not self .solution .run (bar , cwd ).status :
96+ self .fuzz .queue .ensure_alive ()
97+ localbar .done ()
98+ return False
99+ self .fuzz .queue .ensure_alive ()
87100 localbar .done ()
88- return False
89- localbar .done ()
90-
91- if ansfile .is_file ():
92- localbar = bar .start (f"{ self .i } : validate output" )
93- if not testcase .validate_format (Mode .ANSWER , bar = localbar ):
94- localbar .done (False )
95- return False
96- localbar .done ()
97- else :
98- bar .error (f"{ self .i } : { ansfile .name } was not generated." )
101+ elif self .fuzz .problem .interactive or self .fuzz .problem .multi_pass :
102+ ansfile .write_text ("" )
103+
104+ if ansfile .is_file ():
105+ localbar = bar .start (f"{ self .i } : validate output" )
106+ if not testcase .validate_format (Mode .ANSWER , bar = localbar ):
107+ self .fuzz .queue .ensure_alive ()
108+ localbar .done (False )
99109 return False
110+ self .fuzz .queue .ensure_alive ()
111+ localbar .done ()
100112 else :
101- if not testcase . ans_path . is_file ():
102- testcase . ans_path . write_text ( "" )
113+ bar . error ( f" { self . i } : { ansfile . name } was not generated." )
114+ return False
103115
104116 # Run all submissions against the testcase.
105117 with self .fuzz .queue :
106118 for submission in self .fuzz .submissions :
107119 self .fuzz .queue .put (SubmissionTask (self , submission , testcase , self .tmp_id ))
108120 return True
109121
110- def save_test (self , bar : ProgressBar ) -> None :
122+ def get_command (self ) -> dict [str , str ] | str :
123+ if not self .fuzz .problem .settings .ans_is_output and ".ans" in self .rule .hardcoded :
124+ return {"generate" : self .command , "ans" : self .rule .hardcoded [".ans" ]}
125+ else :
126+ return self .command
127+
128+ def save_test (self , bar : ProgressBar , submission : Submission , verdict : Verdict ) -> None :
111129 if self .saved :
112130 return
113131 save = False
@@ -116,19 +134,21 @@ def save_test(self, bar: ProgressBar) -> None:
116134 if not self .saved :
117135 self .saved = True
118136 save = True
137+ self .fuzz .queue .ensure_alive ()
119138 # only save rule if we set self.saved to True
120- if save and not self . fuzz . queue . aborted :
139+ if save :
121140 localbar = bar .start (f"{ self .i } : { self .command } " )
122141 localbar .log ("Saving testcase in generators.yaml." )
142+ self .fuzz .save_test (self .get_command (), submission , verdict )
143+ self .fuzz .queue .ensure_alive ()
123144 localbar .done ()
124- self .fuzz .save_test (self .command )
125145
126146
127147class SubmissionTask :
128148 def __init__ (
129149 self ,
130150 generator_task : GeneratorTask ,
131- submission : run . Submission ,
151+ submission : Submission ,
132152 testcase : Testcase ,
133153 tmp_id : int ,
134154 ):
@@ -142,38 +162,67 @@ def run(self, bar: ProgressBar) -> None:
142162 self .generator_task .fuzz .finish_task (self .tmp_id )
143163
144164 def _run (self , bar : ProgressBar ) -> None :
145- r = run . Run (self .generator_task .fuzz .problem , self .submission , self .testcase )
165+ r = Run (self .generator_task .fuzz .problem , self .submission , self .testcase )
146166 localbar = bar .start (f"{ self .generator_task .i } : { self .submission .name } " )
147167 result = r .run (localbar )
168+ self .generator_task .fuzz .queue .ensure_alive ()
148169 if result .verdict != Verdict .ACCEPTED :
149- self .generator_task .save_test (bar )
170+ self .generator_task .save_test (bar , self . submission , result . verdict )
150171 localbar .done (False , f"{ result .verdict } !" )
151172 else :
152173 localbar .done ()
153174
154175
176+ class FuzzProgressBar (ProgressBar ):
177+ def __init__ (self , queue : parallel .AbstractQueue , prefix : str , max_len : int ):
178+ super ().__init__ (prefix , max_len )
179+ self .queue = queue
180+
181+ def _print (
182+ self ,
183+ * objects ,
184+ sep : str = "" ,
185+ end : str = "\n " ,
186+ file : TextIO = sys .stderr ,
187+ flush : bool = True ,
188+ ):
189+ self .queue .ensure_alive ()
190+ super ()._print (* objects , sep = sep , end = end , file = file , flush = flush )
191+
192+
155193class Fuzz :
156194 def __init__ (self , problem : problem .Problem ):
157195 self .generators_yaml_mutex = threading .Lock ()
158196 self .problem = problem
197+ self .summary : dict [Submission , set [Verdict ]] = {}
198+ self .added = 0
159199
160200 # GENERATOR INVOCATIONS
161201 generator_config = generate .GeneratorConfig (self .problem , config .args .testcases )
162202 self .testcase_rules : list [generate .TestcaseRule ] = []
163203
164204 # Filter to only keep valid rules depending on seed without duplicates from count
165- added_testcase_rules = set ()
205+ added_testcase_rule_data = set ()
166206
167207 def add_testcase (t : generate .TestcaseRule ) -> None :
168208 if (
169- t .in_is_generated
170- and t .parse_error is None
171- and t .generator is not None
172- and t .generator . uses_seed
173- and t .generator .command_string . strip () not in added_testcase_rules
209+ not t .in_is_generated
210+ or t .root in config . INVALID_CASE_DIRECTORIES
211+ or t .parse_error is not None
212+ or t .generator is None
213+ or not t .generator .uses_seed
174214 ):
175- self .testcase_rules .append (t )
176- added_testcase_rules .add (t .generator .command_string .strip ())
215+ return
216+
217+ testcase_rule_data = t .generator .command_string .strip ()
218+ if not problem .settings .ans_is_output and ".ans" in t .hardcoded :
219+ testcase_rule_data += t .hardcoded [".ans" ]
220+
221+ if testcase_rule_data in added_testcase_rule_data :
222+ return
223+
224+ self .testcase_rules .append (t )
225+ added_testcase_rule_data .add (testcase_rule_data )
177226
178227 generator_config .root_dir .walk (add_testcase , dir_f = None )
179228 if len (self .testcase_rules ) == 0 :
@@ -205,33 +254,42 @@ def run(self) -> bool:
205254 def runner (task : GeneratorTask | SubmissionTask ) -> None :
206255 task .run (bar )
207256
208- # config.args.no_bar = True
209- # max(len(s.name) for s in self.submissions)
210- bar = ProgressBar ("Fuzz" , max_len = 60 )
211257 self .start_time = time .monotonic ()
212258 self .iteration = 0
213259 self .tasks = 0
214260 self .queue = parallel .new_queue (runner , pin = True )
215261
262+ # pool of ids used for generators
263+ self .tmp_ids = 2 * max (1 , self .queue .num_threads ) + 1
264+ self .free_tmp_id = {* range (self .tmp_ids )}
265+ self .tmp_id_count = [0 ] * self .tmp_ids
266+
267+ max_len = max (
268+ 25 ,
269+ * [len (s .name ) for s in self .submissions ],
270+ * [
271+ len (t .generator .cache_command (seed = 2 ** 32 ))
272+ for t in self .testcase_rules
273+ if t .generator is not None
274+ ],
275+ )
276+ max_len += len (f"{ self .tmp_ids } : " )
277+ bar = FuzzProgressBar (self .queue , "Fuzz" , max_len = max_len )
278+
216279 def soft_exit (sig : Any , frame : Any ) -> None :
217280 if self .queue .aborted :
218281 fatal ("Running interrupted" , force = True )
219282 else :
220283 self .queue .abort ()
221284 with bar :
222- bar .clearline ( )
285+ print ( bar .carriage_return , file = sys . stderr )
223286 message (
224287 "Running interrupted (waiting on remaining tasks)\n " ,
225288 "\n Fuzz" ,
226289 color_type = MessageType .ERROR ,
227290 )
228291
229- signal .signal (signal .SIGINT , soft_exit )
230-
231- # pool of ids used for generators
232- self .tmp_ids = 2 * max (1 , self .queue .num_threads ) + 1
233- self .free_tmp_id = {* range (self .tmp_ids )}
234- self .tmp_id_count = [0 ] * self .tmp_ids
292+ old_handler = signal .signal (signal .SIGINT , soft_exit )
235293
236294 # add first generator task
237295 self .finish_task ()
@@ -241,11 +299,19 @@ def soft_exit(sig: Any, frame: Any) -> None:
241299 # At this point, no new tasks may be started anymore.
242300 self .queue .done ()
243301
302+ signal .signal (signal .SIGINT , old_handler )
303+
304+ for submission , verdicts in self .summary .items ():
305+ msg = ", " .join (f"{ v .color ()} { v .short ()} { Style .RESET_ALL } " for v in sorted (verdicts ))
306+ message (msg , "Fuzz" , submission .name )
307+ message (f"Found { self .added } testcases in total." , "Fuzz" )
308+
244309 if self .queue .aborted :
245- fatal ("Running interrupted" , force = True )
310+ fatal ("Running interrupted" )
246311
247312 bar .done ()
248313 bar .finalize ()
314+
249315 return True
250316
251317 # finish task from generator with tmp_id
@@ -280,7 +346,9 @@ def finish_task(self, tmp_id: Optional[int] = None, count: int = 1) -> None:
280346
281347 # Write new rule to yaml
282348 # lock between read and write to ensure that no rule gets lost
283- def save_test (self , command : str ) -> None :
349+ def save_test (
350+ self , command : dict [str , str ] | str , submission : Submission , verdict : Verdict
351+ ) -> None :
284352 with self .generators_yaml_mutex :
285353 generators_yaml = self .problem .path / "generators/generators.yaml"
286354 data = None
@@ -298,3 +366,6 @@ def save_test(self, command: str) -> None:
298366
299367 # Overwrite generators.yaml.
300368 write_yaml (data , generators_yaml )
369+
370+ self .summary .setdefault (submission , set ()).add (verdict )
371+ self .added += 1
0 commit comments