Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a84f7d8
Make PerturbedImageVisualizer more generic
dxoigmn Jun 5, 2023
25259c3
Disable test
dxoigmn Jun 5, 2023
10accf5
Use attrgetter
dxoigmn Jun 9, 2023
d8fe8a0
Restore image_visualizer config
dxoigmn Jun 9, 2023
48e3be9
Make *_step_log dicts where the key is the logging name and value is …
dxoigmn Jun 12, 2023
01a2066
Fix configs
dxoigmn Jun 13, 2023
df1d0b2
remove sync_dist
dxoigmn Jun 13, 2023
14f4d1f
backwards compatibility
dxoigmn Jun 13, 2023
2e30587
Revert "Fix configs"
dxoigmn Jun 13, 2023
ca17006
Merge branch 'main' into better_litmodular
dxoigmn Jun 13, 2023
6fef148
style
dxoigmn Jun 13, 2023
c4e0d78
Make metric logging keys configurable
dxoigmn Jun 12, 2023
508798c
cleanup
dxoigmn Jun 13, 2023
fc770e8
Remove *_step_end
dxoigmn Jun 14, 2023
5050163
Merge branch 'better_litmodular2' into better_sequentialdict
dxoigmn Jun 14, 2023
4414822
Merge branch 'better_litmodular3' into better_sequentialdict
dxoigmn Jun 14, 2023
c31f4de
Don't require output module with SequentialDict
dxoigmn Jun 12, 2023
549f705
fix configs and tests
dxoigmn Jun 14, 2023
5e73817
Generalize attack objectives
dxoigmn Jun 14, 2023
62216f2
Merge branch 'better_litmodular2' into better_sequentialdict
dxoigmn Jun 14, 2023
a113f7e
Merge branch 'main' into general_visualizer
dxoigmn Jun 14, 2023
94b949b
Merge branch 'better_sequentialdict' into general_visualizer
dxoigmn Jun 14, 2023
1f4c3f9
Merge branch 'main' into general_visualizer
dxoigmn Jun 22, 2023
494a2db
Merge branch 'main' into general_visualizer
dxoigmn Jun 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions mart/callbacks/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,37 @@
# SPDX-License-Identifier: BSD-3-Clause
#

import os
from operator import attrgetter

from pytorch_lightning.callbacks import Callback
from torchvision.transforms import ToPILImage

__all__ = ["PerturbedImageVisualizer"]
__all__ = ["ImageVisualizer"]


class PerturbedImageVisualizer(Callback):
"""Save adversarial images as files."""
class ImageVisualizer(Callback):
def __init__(self, frequency: int = 100, **tag_paths):
self.frequency = frequency
self.tag_paths = tag_paths

def __init__(self, folder):
super().__init__()
def log_image(self, trainer, tag, image):
# Add image to each logger
for logger in trainer.loggers:
# FIXME: Should we just use isinstance(logger.experiment, SummaryWriter)?
if not hasattr(logger.experiment, "add_image"):
continue

# FIXME: This should use the Trainer's logging directory.
self.folder = folder
self.convert = ToPILImage()
logger.experiment.add_image(tag, image, global_step=trainer.global_step)

if not os.path.isdir(self.folder):
os.makedirs(self.folder)
def log_images(self, trainer, pl_module):
for tag, path in self.tag_paths.items():
image = attrgetter(path)(pl_module)
self.log_image(trainer, tag, image)

def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
# Save input and target for on_train_end
self.input = batch["input"]
self.target = batch["target"]
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if batch_idx % self.frequency != 0:
return

def on_train_end(self, trainer, model):
# FIXME: We should really just save this to outputs instead of recomputing adv_input
adv_input = model(input=self.input, target=self.target)
self.log_images(trainer, pl_module)

for img, tgt in zip(adv_input, self.target):
fname = tgt["file_name"]
fpath = os.path.join(self.folder, fname)
im = self.convert(img / 255)
im.save(fpath)
def on_train_end(self, trainer, pl_module):
self.log_images(trainer, pl_module)
4 changes: 4 additions & 0 deletions mart/configs/callbacks/perturbation_visualizer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
perturbation_visualizer:
_target_: mart.callbacks.ImageVisualizer
frequency: 100
perturbation: ???
69 changes: 35 additions & 34 deletions tests/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,41 @@
from torchvision.transforms import ToPILImage

from mart.attack import Adversary
from mart.callbacks import PerturbedImageVisualizer

# from mart.callbacks import PerturbedImageVisualizer

def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path):
folder = tmp_path / "test"
input_list = [input_data]
target_list = [target_data]

# simulate an addition perturbation
def perturb(input):
result = [sample + perturbation for sample in input]
return result

trainer = Mock()
model = Mock(return_value=perturb(input_list))
outputs = Mock()
batch = {"input": input_list, "target": target_list}
adversary = Mock(spec=Adversary, side_effect=perturb)

visualizer = PerturbedImageVisualizer(folder)
visualizer.on_train_batch_end(trainer, model, outputs, batch, 0)
visualizer.on_train_end(trainer, model)

# verify that the visualizer created the JPG file
expected_output_path = folder / target_data["file_name"]
assert expected_output_path.exists()

# verify image file content
perturbed_img = input_data + perturbation
converter = ToPILImage()
expected_img = converter(perturbed_img / 255)
expected_img.save(folder / "test_expected.jpg")

stored_img = Image.open(expected_output_path)
expected_stored_img = Image.open(folder / "test_expected.jpg")
diff = ImageChops.difference(expected_stored_img, stored_img)
assert not diff.getbbox()
# def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path):
# folder = tmp_path / "test"
# input_list = [input_data]
# target_list = [target_data]
#
# # simulate an addition perturbation
# def perturb(input):
# result = [sample + perturbation for sample in input]
# return result
#
# trainer = Mock()
# model = Mock(return_value=perturb(input_list))
# outputs = Mock()
# batch = {"input": input_list, "target": target_list}
# adversary = Mock(spec=Adversary, side_effect=perturb)
#
# visualizer = PerturbedImageVisualizer(folder)
# visualizer.on_train_batch_end(trainer, model, outputs, batch, 0)
# visualizer.on_train_end(trainer, model)
#
# # verify that the visualizer created the JPG file
# expected_output_path = folder / target_data["file_name"]
# assert expected_output_path.exists()
#
# # verify image file content
# perturbed_img = input_data + perturbation
# converter = ToPILImage()
# expected_img = converter(perturbed_img / 255)
# expected_img.save(folder / "test_expected.jpg")
#
# stored_img = Image.open(expected_output_path)
# expected_stored_img = Image.open(folder / "test_expected.jpg")
# diff = ImageChops.difference(expected_stored_img, stored_img)
# assert not diff.getbbox()