Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from typing import Any, Iterable

import torch
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.transforms.functional import InterpolationMode


class Composer(abc.ABC):
def __call__(
class Composer(torch.nn.Module):
def forward(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
*,
Expand All @@ -24,6 +28,17 @@ def __call__(
if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
return self.compose(perturbation, input=input, target=target)

elif (
isinstance(perturbation, torch.Tensor)
and isinstance(input, Iterable) # noqa: W503
and isinstance(target, Iterable) # noqa: W503
):
# FIXME: replace tuple with whatever input's type is
return tuple(
self.compose(perturbation, input=input_i, target=target_i)
for input_i, target_i in zip(input, target)
)

elif (
isinstance(perturbation, Iterable)
and isinstance(input, Iterable) # noqa: W503
Expand Down Expand Up @@ -56,8 +71,13 @@ def compose(self, perturbation, *, input, target):
return input + perturbation


class Overlay(Composer):
"""We assume an adversary overlays a patch to the input."""
class Composite(Composer):
"""We assume an adversary underlays a patch to the input."""

def __init__(self, premultiplied_alpha=False):
super().__init__()

self.premultiplied_alpha = premultiplied_alpha

def compose(self, perturbation, *, input, target):
# True is mutable, False is immutable.
Expand All @@ -67,14 +87,7 @@ def compose(self, perturbation, *, input, target):
# because some data modules (e.g. Armory) gives binary mask.
mask = mask.to(input)

return input * (1 - mask) + perturbation * mask


class MaskAdditive(Composer):
"""We assume an adversary adds masked perturbation to the input."""

def compose(self, perturbation, *, input, target):
mask = target["perturbable_mask"]
masked_perturbation = perturbation * mask
if not self.premultiplied_alpha:
perturbation = perturbation * mask

return input + masked_perturbation
return input * (1 - mask) + perturbation
1 change: 1 addition & 0 deletions mart/configs/attack/composer/composite.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_target_: mart.attack.composer.Composite
1 change: 0 additions & 1 deletion mart/configs/attack/composer/mask_additive.yaml

This file was deleted.

1 change: 0 additions & 1 deletion mart/configs/attack/composer/overlay.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion mart/configs/attack/object_detection_mask_adversary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defaults:
- perturber: default
- perturber/initializer: constant
- perturber/projector: mask_range
- composer: overlay
- composer: composite
- /optimizer@optimizer: sgd
- gain: rcnn_training_loss
- gradient_modifier: sign
Expand Down
17 changes: 3 additions & 14 deletions tests/test_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch

from mart.attack.composer import Additive, MaskAdditive, Overlay
from mart.attack.composer import Additive, Composite


def test_additive_composer_forward(input_data, target_data, perturbation):
Expand All @@ -17,22 +17,11 @@ def test_additive_composer_forward(input_data, target_data, perturbation):
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_overlay_composer_forward(input_data, target_data, perturbation):
composer = Overlay()
def test_composite_composer_forward(input_data, target_data, perturbation):
composer = Composite()

output = composer(perturbation, input=input_data, target=target_data)
mask = target_data["perturbable_mask"]
mask = mask.to(input_data)
expected_output = input_data * (1 - mask) + perturbation
torch.testing.assert_close(output, expected_output, equal_nan=True)


def test_mask_additive_composer_forward():
input = torch.zeros((2, 2))
perturbation = torch.ones((2, 2))
target = {"perturbable_mask": torch.eye(2)}
expected_output = torch.eye(2)

composer = MaskAdditive()
output = composer(perturbation, input=input, target=target)
torch.testing.assert_close(output, expected_output, equal_nan=True)