@@ -248,6 +248,7 @@ def save(
248248 self ,
249249 data_loader ,
250250 save_path = None ,
251+ save_every_iter = False ,
251252 verbose = True ,
252253 return_verbose = False ,
253254 save_predictions = False ,
@@ -260,6 +261,7 @@ def save(
260261 Arguments:
261262 save_path (str): save_path.
262263 data_loader (torch.utils.data.DataLoader): data loader.
264+ save_every_iter (bool): True for save every results every iter. (Default: False)
263265 verbose (bool): True for displaying detailed information. (Default: True)
264266 return_verbose (bool): True for returning detailed information. (Default: False)
265267 save_predictions (bool): True for saving predicted labels (Default: False)
@@ -319,45 +321,33 @@ def save(
319321 if save_path is not None :
320322 adv_input_list .append (adv_inputs .detach ().cpu ())
321323 label_list .append (labels .detach ().cpu ())
322-
323- adv_input_list_cat = torch .cat (adv_input_list , 0 )
324- label_list_cat = torch .cat (label_list , 0 )
325-
326- save_dict = {
327- "adv_inputs" : adv_input_list_cat ,
328- "labels" : label_list_cat ,
329- } # nopep8
330-
331324 if save_predictions :
332325 pred_list .append (pred .detach ().cpu ())
333- pred_list_cat = torch .cat (pred_list , 0 )
334- save_dict ["preds" ] = pred_list_cat
335-
336326 if save_clean_inputs :
337327 input_list .append (inputs .detach ().cpu ())
338- input_list_cat = torch . cat ( input_list , 0 )
339- save_dict [ "clean_inputs" ] = input_list_cat
340-
341- if self . normalization_used is not None :
342- save_dict [ "adv_inputs" ] = self . inverse_normalize (
343- save_dict [ "adv_inputs" ]
344- ) # nopep8
345- if save_clean_inputs :
346- save_dict [ "clean_inputs" ] = self . inverse_normalize (
347- save_dict [ "clean_inputs" ]
348- ) # nopep8
349-
350- if save_type == "int" :
351- save_dict [ "adv_inputs" ] = self .to_type (
352- save_dict [ "adv_inputs" ], "int"
353- ) # nopep8
354- if save_clean_inputs :
355- save_dict [ "clean_inputs" ] = self . to_type (
356- save_dict [ "clean_inputs" ], "int"
357- ) # nopep8
358-
359- save_dict [ "save_type" ] = save_type
360- torch . save ( save_dict , save_path )
328+ if save_every_iter :
329+ self . _save_adv_examples (
330+ save_type ,
331+ save_path ,
332+ adv_input_list ,
333+ label_list ,
334+ save_predictions = save_predictions ,
335+ pred_list = pred_list if save_predictions else None ,
336+ save_clean_inputs = save_clean_inputs ,
337+ input_list = input_list if save_clean_inputs else None ,
338+ )
339+
340+ if save_path is not None and not save_every_iter :
341+ self ._save_adv_examples (
342+ save_type ,
343+ save_path ,
344+ adv_input_list ,
345+ label_list ,
346+ save_predictions = save_predictions ,
347+ pred_list = pred_list if save_predictions else None ,
348+ save_clean_inputs = save_clean_inputs ,
349+ input_list = input_list if save_clean_inputs else None ,
350+ )
361351
362352 # To avoid erasing the printed information.
363353 if verbose :
@@ -388,6 +378,57 @@ def to_type(inputs, type):
388378 raise ValueError (type + " is not a valid type. [Options: float, int]" )
389379 return inputs
390380
381+
382+ def _save_adv_examples (
383+ self ,
384+ save_type ,
385+ save_path ,
386+ adv_input_list ,
387+ label_list ,
388+ save_predictions = False ,
389+ pred_list = [],
390+ save_clean_inputs = False ,
391+ input_list = [],
392+ ):
393+
394+
395+ adv_input_list_cat = torch .cat (adv_input_list , 0 )
396+ label_list_cat = torch .cat (label_list , 0 )
397+
398+ save_dict = {
399+ "adv_inputs" : adv_input_list_cat ,
400+ "labels" : label_list_cat ,
401+ }
402+
403+ if save_predictions :
404+ pred_list_cat = torch .cat (pred_list , 0 )
405+ save_dict ["preds" ] = pred_list_cat
406+
407+ if save_clean_inputs :
408+ input_list_cat = torch .cat (input_list , 0 )
409+ save_dict ["clean_inputs" ] = input_list_cat
410+
411+ if self .normalization_used is not None :
412+ save_dict ["adv_inputs" ] = self .inverse_normalize (
413+ save_dict ["adv_inputs" ]
414+ ) # nopep8
415+ if save_clean_inputs :
416+ save_dict ["clean_inputs" ] = self .inverse_normalize (
417+ save_dict ["clean_inputs" ]
418+ ) # nopep8
419+
420+ if save_type == "int" :
421+ save_dict ["adv_inputs" ] = self .to_type (
422+ save_dict ["adv_inputs" ], "int"
423+ ) # nopep8
424+ if save_clean_inputs :
425+ save_dict ["clean_inputs" ] = self .to_type (
426+ save_dict ["clean_inputs" ], "int"
427+ ) # nopep8
428+
429+ save_dict ["save_type" ] = save_type
430+ torch .save (save_dict , save_path )
431+
391432 @staticmethod
392433 def _save_print (progress , rob_acc , l2 , elapsed_time , end ):
393434 print (
0 commit comments