Skip to content

Commit c8bd563

Browse files
committed
Perf attack save
1 parent 936e86d commit c8bd563

File tree

1 file changed

+76
-35
lines changed

1 file changed

+76
-35
lines changed

torchattacks/attack.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def save(
248248
self,
249249
data_loader,
250250
save_path=None,
251+
save_every_iter=False,
251252
verbose=True,
252253
return_verbose=False,
253254
save_predictions=False,
@@ -260,6 +261,7 @@ def save(
260261
Arguments:
261262
save_path (str): save_path.
262263
data_loader (torch.utils.data.DataLoader): data loader.
264+
save_every_iter (bool): True for save every results every iter. (Default: False)
263265
verbose (bool): True for displaying detailed information. (Default: True)
264266
return_verbose (bool): True for returning detailed information. (Default: False)
265267
save_predictions (bool): True for saving predicted labels (Default: False)
@@ -319,45 +321,33 @@ def save(
319321
if save_path is not None:
320322
adv_input_list.append(adv_inputs.detach().cpu())
321323
label_list.append(labels.detach().cpu())
322-
323-
adv_input_list_cat = torch.cat(adv_input_list, 0)
324-
label_list_cat = torch.cat(label_list, 0)
325-
326-
save_dict = {
327-
"adv_inputs": adv_input_list_cat,
328-
"labels": label_list_cat,
329-
} # nopep8
330-
331324
if save_predictions:
332325
pred_list.append(pred.detach().cpu())
333-
pred_list_cat = torch.cat(pred_list, 0)
334-
save_dict["preds"] = pred_list_cat
335-
336326
if save_clean_inputs:
337327
input_list.append(inputs.detach().cpu())
338-
input_list_cat = torch.cat(input_list, 0)
339-
save_dict["clean_inputs"] = input_list_cat
340-
341-
if self.normalization_used is not None:
342-
save_dict["adv_inputs"] = self.inverse_normalize(
343-
save_dict["adv_inputs"]
344-
) # nopep8
345-
if save_clean_inputs:
346-
save_dict["clean_inputs"] = self.inverse_normalize(
347-
save_dict["clean_inputs"]
348-
) # nopep8
349-
350-
if save_type == "int":
351-
save_dict["adv_inputs"] = self.to_type(
352-
save_dict["adv_inputs"], "int"
353-
) # nopep8
354-
if save_clean_inputs:
355-
save_dict["clean_inputs"] = self.to_type(
356-
save_dict["clean_inputs"], "int"
357-
) # nopep8
358-
359-
save_dict["save_type"] = save_type
360-
torch.save(save_dict, save_path)
328+
if save_every_iter:
329+
self._save_adv_examples(
330+
save_type,
331+
save_path,
332+
adv_input_list,
333+
label_list,
334+
save_predictions = save_predictions,
335+
pred_list = pred_list if save_predictions else None,
336+
save_clean_inputs = save_clean_inputs,
337+
input_list = input_list if save_clean_inputs else None,
338+
)
339+
340+
if save_path is not None and not save_every_iter:
341+
self._save_adv_examples(
342+
save_type,
343+
save_path,
344+
adv_input_list,
345+
label_list,
346+
save_predictions = save_predictions,
347+
pred_list = pred_list if save_predictions else None,
348+
save_clean_inputs = save_clean_inputs,
349+
input_list = input_list if save_clean_inputs else None,
350+
)
361351

362352
# To avoid erasing the printed information.
363353
if verbose:
@@ -388,6 +378,57 @@ def to_type(inputs, type):
388378
raise ValueError(type + " is not a valid type. [Options: float, int]")
389379
return inputs
390380

381+
382+
def _save_adv_examples(
383+
self,
384+
save_type,
385+
save_path,
386+
adv_input_list,
387+
label_list,
388+
save_predictions = False,
389+
pred_list = [],
390+
save_clean_inputs = False,
391+
input_list = [],
392+
):
393+
394+
395+
adv_input_list_cat = torch.cat(adv_input_list, 0)
396+
label_list_cat = torch.cat(label_list, 0)
397+
398+
save_dict = {
399+
"adv_inputs": adv_input_list_cat,
400+
"labels": label_list_cat,
401+
}
402+
403+
if save_predictions:
404+
pred_list_cat = torch.cat(pred_list, 0)
405+
save_dict["preds"] = pred_list_cat
406+
407+
if save_clean_inputs:
408+
input_list_cat = torch.cat(input_list, 0)
409+
save_dict["clean_inputs"] = input_list_cat
410+
411+
if self.normalization_used is not None:
412+
save_dict["adv_inputs"] = self.inverse_normalize(
413+
save_dict["adv_inputs"]
414+
) # nopep8
415+
if save_clean_inputs:
416+
save_dict["clean_inputs"] = self.inverse_normalize(
417+
save_dict["clean_inputs"]
418+
) # nopep8
419+
420+
if save_type == "int":
421+
save_dict["adv_inputs"] = self.to_type(
422+
save_dict["adv_inputs"], "int"
423+
) # nopep8
424+
if save_clean_inputs:
425+
save_dict["clean_inputs"] = self.to_type(
426+
save_dict["clean_inputs"], "int"
427+
) # nopep8
428+
429+
save_dict["save_type"] = save_type
430+
torch.save(save_dict, save_path)
431+
391432
@staticmethod
392433
def _save_print(progress, rob_acc, l2, elapsed_time, end):
393434
print(

0 commit comments

Comments
 (0)