Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

defaults:
- COCO_TorchvisionFasterRCNN
- override /model/modules@modules.preprocessor: tuple_tensorizer_normalizer
- override /datamodule: armory_carla_over_objdet_perturbable_mask

task_name: "ArmoryCarlaOverObjDet_TorchvisionFasterRCNN"
Expand Down
9 changes: 0 additions & 9 deletions mart/configs/model/modules/tuple_normalizer.yaml

This file was deleted.

14 changes: 0 additions & 14 deletions mart/configs/model/modules/tuple_tensorizer_normalizer.yaml

This file was deleted.

8 changes: 7 additions & 1 deletion mart/configs/model/torchvision_object_detection.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# We simply wrap a torchvision object detection model for validation.
defaults:
- modular
- /model/modules@modules.preprocessor: tuple_normalizer

training_step_log:
loss: "loss"
Expand All @@ -13,6 +12,13 @@ test_sequence: ???
output_preds_key: "losses_and_detections.eval"

modules:
# Preprocessor module: TupleTransforms applies `transforms` to every element
# of a tuple and returns the results as a tuple.
preprocessor:
_target_: mart.transforms.TupleTransforms
transforms:
# Normalize computes (x - mean) / std, so mean=0 and std=255 divides each
# value by 255 (presumably mapping [0, 255] pixels to [0, 1] -- confirm the
# input range against the datamodule).
_target_: torchvision.transforms.Normalize
mean: 0
std: 255

losses_and_detections:
# Return losses in the training mode and predictions in the eval mode in one pass.
_target_: mart.models.DualMode
Expand Down
17 changes: 0 additions & 17 deletions mart/datamodules/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,6 @@ def _load_target(self, id: int) -> List[Any]:

return {"image_id": id, "file_name": file_name, "annotations": annotations}

def __getitem__(self, index: int):
    """Fetch one sample and regroup its image by modality when modalities are set.

    The parent class's ``__getitem__()`` runs first (i.e. after ``_load_image()``
    and ``transforms()``, which typically turn images into tensors). Its image is
    only regrouped here when ``self.modalities`` is not None.

    Args:
        index: Position of the sample to fetch.

    Returns:
        ``(image, target_dict)``. ``image`` is a dict keyed by modality name for
        multi-modality datasets; otherwise it is returned exactly as the parent
        produced it.
    """
    image, target_dict = super().__getitem__(index)

    # Nothing to regroup for single-modality datasets.
    if self.modalities is None:
        return image, target_dict

    # The image is assumed to be a multi-channel tensor in which every modality
    # owns 3 channels, so the leading dimension must be 3 * number of modalities.
    assert image.shape[0] == len(self.modalities) * 3
    channel_groups = image.split(3)
    image = dict(zip(self.modalities, channel_groups))
    return image, target_dict


# Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203
def collate_fn(batch):
Expand Down
22 changes: 1 addition & 21 deletions mart/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@
import torch
from torchvision.transforms import transforms as T

__all__ = [
"Denormalize",
"Cat",
"Permute",
"Unsqueeze",
"Squeeze",
"Chunk",
"TupleTransforms",
"GetItems",
]
__all__ = ["Denormalize", "Cat", "Permute", "Unsqueeze", "Squeeze", "Chunk", "TupleTransforms"]


class Denormalize(T.Normalize):
Expand Down Expand Up @@ -90,14 +81,3 @@ def __init__(self, transforms):
def forward(self, x_tuple):
    """Apply ``self.transforms`` to each element of ``x_tuple``.

    Args:
        x_tuple: Iterable of inputs; each element is transformed independently.

    Returns:
        A tuple with one transformed element per input element, in input order.
    """
    transformed = []
    for item in x_tuple:
        transformed.append(self.transforms(item))
    return tuple(transformed)


class GetItems:
    """Callable that looks up a fixed sequence of keys in a subscriptable object.

    Args:
        keys: Keys to read, in the order their values should be returned.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, x):
        """Return ``x[key]`` for every key in ``self.keys``, in order, as a list."""
        x_list = []
        for key in self.keys:
            x_list.append(x[key])
        return x_list