Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions mart/attack/adversary_in_art.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
x (np.ndarray): NHWC, [0, 1]

Returns:
tuple: a tuple of tensors in CHW, [0, 255].
list[torch.Tensor]: a list of tensors in CHW, [0, 255].
"""
input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
input = tuple(inp_ for inp_ in input)
input = [inp_ for inp_ in input]
return input

def convert_input_mart_to_art(self, input: tuple):
def convert_input_mart_to_art(self, input: list[torch.Tensor]):
"""Convert MART input to the ART's format.

Args:
input (tuple): a tuple of tensors in CHW, [0, 255].
input (list[torch.Tensor]): a list of tensors in CHW, [0, 255].

Returns:
np.ndarray: NHWC, [0, 1]
Expand All @@ -112,7 +112,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
y_patch_metadata (_type_): _description_

Returns:
tuple: a tuple of target dictionaies.
list: a list of target dictionaies.
"""
# Copy y to target, and convert ndarray to pytorch tensors accordingly.
target = []
Expand All @@ -132,6 +132,4 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
target.append(target_i)

target = tuple(target)

return target
4 changes: 2 additions & 2 deletions mart/attack/adversary_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(

def forward(
self,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module | None = None,
**kwargs,
):
Expand Down
24 changes: 12 additions & 12 deletions mart/attack/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def on_run_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand All @@ -35,8 +35,8 @@ def on_examine_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand All @@ -46,8 +46,8 @@ def on_examine_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand All @@ -57,8 +57,8 @@ def on_advance_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand All @@ -68,8 +68,8 @@ def on_advance_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand All @@ -79,8 +79,8 @@ def on_run_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
model: torch.nn.Module,
**kwargs,
):
Expand Down
14 changes: 7 additions & 7 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@
class Composer(abc.ABC):
def __call__(
self,
perturbation: torch.Tensor | tuple,
perturbation: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
**kwargs,
) -> torch.Tensor | tuple:
if isinstance(perturbation, tuple):
input_adv = tuple(
) -> torch.Tensor | list[torch.Tensor]:
if isinstance(perturbation, list):
input_adv = [
self.compose(perturbation_i, input=input_i, target=target_i)
for perturbation_i, input_i, target_i in zip(perturbation, input, target)
)
]
else:
input_adv = self.compose(perturbation, input=input, target=target)

Expand Down
14 changes: 7 additions & 7 deletions mart/attack/enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ class ConstraintViolated(Exception):
class Constraint(abc.ABC):
def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
) -> None:
if isinstance(input_adv, tuple):
if isinstance(input_adv, list):
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self.verify(input_adv_i, input=input_i, target=target_i)
else:
Expand Down Expand Up @@ -103,10 +103,10 @@ def __init__(self, constraints: dict[str, Constraint] | None = None) -> None:
@torch.no_grad()
def __call__(
self,
input_adv: torch.Tensor | tuple,
input_adv: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
**kwargs,
) -> None:
for constraint in self.constraints.values():
Expand Down
10 changes: 5 additions & 5 deletions mart/attack/perturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def __init__(

self.perturbation = None

def configure_perturbation(self, input: torch.Tensor | tuple):
def configure_perturbation(self, input: torch.Tensor | list[torch.Tensor]):
def create_and_initialize(inp):
pert = torch.empty_like(inp, dtype=torch.float, requires_grad=True)
self.initializer(pert)
return pert

if isinstance(input, tuple):
self.perturbation = tuple(create_and_initialize(inp) for inp in input)
if isinstance(input, list):
self.perturbation = [create_and_initialize(inp) for inp in input]
elif isinstance(input, dict):
raise NotImplementedError
else:
Expand All @@ -81,9 +81,9 @@ def configure_optimizers(self):
)

params = self.perturbation
if not isinstance(params, tuple):
if not isinstance(params, list):
# FIXME: Should we treat the batch dimension as independent parameters?
params = (params,)
params = [params]

return self.optimizer_fn(params)

Expand Down
20 changes: 10 additions & 10 deletions mart/attack/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@ class Projector:
@torch.no_grad()
def __call__(
self,
perturbation: torch.Tensor | tuple,
perturbation: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
**kwargs,
) -> None:
if isinstance(perturbation, tuple):
if isinstance(perturbation, list):
for perturbation_i, input_i, target_i in zip(perturbation, input, target):
self.project(perturbation_i, input=input_i, target=target_i)
else:
self.project(perturbation, input=input, target=target)

def project(
self,
perturbation: torch.Tensor | tuple,
perturbation: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
) -> None:
pass

Expand All @@ -48,10 +48,10 @@ def __init__(self, projectors: list[Projector]):
@torch.no_grad()
def __call__(
self,
perturbation: torch.Tensor | tuple,
perturbation: torch.Tensor | list[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | list[torch.Tensor],
target: torch.Tensor | dict[str, Any] | list[Any],
**kwargs,
) -> None:
for projector in self.projectors:
Expand Down
31 changes: 29 additions & 2 deletions mart/datamodules/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from typing import Any, Callable, List, Optional

import numpy as np
import torch
from torch.utils.data._utils.collate import ( # WHY ARE THESE PRIVATE?!
collate,
collate_tensor_fn,
default_collate_fn_map,
)
from torchvision.datasets.coco import CocoDetection as CocoDetection_
from torchvision.datasets.folder import default_loader

Expand Down Expand Up @@ -86,6 +92,27 @@ def __getitem__(self, index: int):
return image, target_dict


# Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203
def _collate_tensor_fn(batch, *, collate_fn_map=None):
"""Handle the case when all elements in list are not the same shape.

Instead of throwing an exception, we just leave them as a list of Tensors.
"""

if not all([x.shape == batch[0].shape for x in batch]):
return list(batch)

return collate_tensor_fn(batch, collate_fn_map=collate_fn_map)


def collate_fn(batch):
return tuple(zip(*batch))
collate_fn_map = default_collate_fn_map.copy()
collate_fn_map[torch.Tensor] = _collate_tensor_fn

images, targets = collate(batch, collate_fn_map=collate_fn_map)

# dict of lists to list of dicts for backwards compatibility
if isinstance(targets, dict):
targets = [dict(zip(targets.keys(), values)) for values in zip(*targets.values())]

# FIXME: Ideally we would just return a dict with {"input": images, **targets}
return images, targets