1 change: 1 addition & 0 deletions mart/callbacks/__init__.py
@@ -1,4 +1,5 @@
from .eval_mode import *
from .fiftyone import *
from .gradients import *
from .no_grad_mode import *
from .progress_bar import *
102 changes: 102 additions & 0 deletions mart/callbacks/fiftyone.py
@@ -0,0 +1,102 @@
#
# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: BSD-3-Clause
#

import logging

from lightning.pytorch.callbacks import BasePredictionWriter, Callback

from ..datamodules import FiftyOneDataset

logger = logging.getLogger(__name__)
try:
    import fiftyone as fo
    import fiftyone.brain as fob
except ImportError:
    # fo/fob stay undefined in this case; the callbacks below raise NameError
    # if used without fiftyone installed
    logger.debug("fiftyone module is not installed!")

__all__ = ["FiftyOneEvaluateDetections", "FiftyOneMistakenness", "FiftyOnePredictionAdder"]


class FiftyOneEvaluateDetections(Callback):
    def __init__(self, run_id: str, gt_field: str = "ground_truth_detections") -> None:
        self.run_id = run_id
        self.gt_field = gt_field

    def on_predict_end(self, trainer, pl_module):
        predict_dataset = trainer.datamodule.predict_dataset
        assert isinstance(predict_dataset, FiftyOneDataset)

        # eval keys must be valid identifiers, so strip the separators
        # that a timestamped run_id contains
        eval_key = f"eval_{self.run_id}".replace("-", "").replace("_", "")
        results = predict_dataset.filtered_dataset.evaluate_detections(
            f"prediction_{self.run_id}",
            gt_field=self.gt_field,
            eval_key=eval_key,
            compute_mAP=True,
        )

        logger.info(f"Prediction mAP={results.mAP()}")

        # Get the 10 most common classes in the dataset
        counts = predict_dataset.filtered_dataset.count_values(f"{self.gt_field}.detections.label")
        classes_top10 = sorted(counts, key=counts.get, reverse=True)[:10]

        # Print a classification report for the top-10 classes
        results.print_report(classes=classes_top10)
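
A minimal usage sketch for wiring this callback into a `Trainer` by hand; the `run_id` value is illustrative (the configs below derive it from Hydra's `${now:...}` resolver), and `model`/`datamodule` are assumed to exist:

```python
# Sketch only: manual wiring of the callback (run_id illustrative)
from lightning.pytorch import Trainer

from mart.callbacks import FiftyOneEvaluateDetections

trainer = Trainer(callbacks=[FiftyOneEvaluateDetections(run_id="2024-01-01_12-00-00")])
trainer.predict(model=model, datamodule=datamodule)
```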


class FiftyOneMistakenness(Callback):
    def __init__(self, run_id: str, gt_field: str = "ground_truth_detections") -> None:
        self.prediction_field = f"prediction_{run_id}"
        self.gt_field = gt_field

    def on_predict_start(self, trainer, pl_module):
        self.predict_dataset = trainer.datamodule.predict_dataset
        assert isinstance(self.predict_dataset, FiftyOneDataset)

        # reset any existing mistakenness brain run
        if self.predict_dataset.dataset.has_brain_run("mistakenness"):
            self.predict_dataset.dataset.delete_brain_run("mistakenness")

    def on_predict_end(self, trainer, pl_module):
        fob.compute_mistakenness(
            self.predict_dataset.filtered_dataset, self.prediction_field, label_field=self.gt_field
        )
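
`compute_mistakenness` stores its results on the dataset, so they can be queried after the predict run; a sketch, assuming FiftyOne's default `mistakenness_field="mistakenness"`:

```python
# Sketch: inspect the most "mistaken" samples after prediction
# (assumes the default mistakenness field name)
view = predict_dataset.filtered_dataset.sort_by("mistakenness", reverse=True)
print(view.first())
```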


class FiftyOnePredictionAdder(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str) -> None:
        super().__init__(write_interval)
        # the (timestamped) output_dir doubles as the run identifier for the prediction field
        self.run_id = f"prediction_{output_dir}"

    def _write_predictions(self, predictions, groundtruth_preds, dataset):
        for pred, gt_pred in zip(predictions, groundtruth_preds):
            filename = gt_pred["file_name"]
            dataset.add_predictions(filename, pred, self.run_id)

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        predict_dataset = trainer.datamodule.predict_dataset
        assert isinstance(predict_dataset, FiftyOneDataset)

        self._write_predictions(
            prediction[pl_module.output_preds_key],
            prediction[pl_module.output_target_key],
            predict_dataset,
        )

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        predict_dataset = trainer.datamodule.predict_dataset
        assert isinstance(predict_dataset, FiftyOneDataset)

        for output in predictions:
            self._write_predictions(
                output[pl_module.output_preds_key],
                output[pl_module.output_target_key],
                predict_dataset,
            )
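
Note on `write_interval`: Lightning's `BasePredictionWriter` dispatches on this value ("batch", "epoch", or "batch_and_epoch"), so with the "epoch" setting used in the config below only `write_on_epoch_end` fires. A sketch (the timestamp is illustrative):

```python
# Sketch: with write_interval="epoch", only write_on_epoch_end is invoked
writer = FiftyOnePredictionAdder(output_dir="2024-01-01_12-00-00", write_interval="epoch")
```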
3 changes: 3 additions & 0 deletions mart/configs/callbacks/fiftyone_evaluate_detections.yaml
@@ -0,0 +1,3 @@
fiftyone_evaluate_detections:
  _target_: mart.callbacks.FiftyOneEvaluateDetections
  run_id: ${now:%Y-%m-%d}_${now:%H-%M-%S}
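
Worked through, here is how the timestamped `run_id` becomes the eval key inside `FiftyOneEvaluateDetections` (timestamp illustrative):

```python
# Sketch: ${now:...} resolves to e.g. "2024-01-01_12-00-00"; the callback strips separators
run_id = "2024-01-01_12-00-00"
eval_key = f"eval_{run_id}".replace("-", "").replace("_", "")
assert eval_key == "eval20240101120000"
```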
3 changes: 3 additions & 0 deletions mart/configs/callbacks/fiftyone_mistakenness.yaml
@@ -0,0 +1,3 @@
fiftyone_mistakenness:
  _target_: mart.callbacks.FiftyOneMistakenness
  run_id: ${now:%Y-%m-%d}_${now:%H-%M-%S}
4 changes: 4 additions & 0 deletions mart/configs/callbacks/fiftyone_prediction_adder.yaml
@@ -0,0 +1,4 @@
fiftyone_prediction_adder:
  _target_: mart.callbacks.FiftyOnePredictionAdder
  output_dir: ${now:%Y-%m-%d}_${now:%H-%M-%S}
  write_interval: "epoch"
22 changes: 22 additions & 0 deletions mart/configs/datamodule/fiftyone.yaml
@@ -69,6 +69,28 @@ test_dataset:
        quant_min: 0
        quant_max: 255

predict_dataset:
  _target_: mart.datamodules.fiftyone.FiftyOneDataset
  dataset_name: ???
  gt_field: ${..train_dataset.gt_field}
  sample_tags: []
  label_tags: []
  transforms:
    _target_: mart.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      - _target_: mart.transforms.ConvertCocoPolysToMask
      - _target_: mart.transforms.Denormalize
        center: 0
        scale: 255
      - _target_: torch.fake_quantize_per_tensor_affine
        _partial_: true
        # (x/1+0).round().clamp(0, 255) * 1
        scale: 1
        zero_point: 0
        quant_min: 0
        quant_max: 255

num_workers: 2
collate_fn:
  _target_: hydra.utils.get_method
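
The fake-quantize transform above implements `(x/scale + zero_point).round().clamp(quant_min, quant_max) * scale`, which with these parameters snaps pixel values to the integer range [0, 255]. A worked example:

```python
# Worked example of the fake-quantize transform with scale=1, zero_point=0
import torch

x = torch.tensor([12.3, 255.7, -1.2])
y = torch.fake_quantize_per_tensor_affine(x, scale=1.0, zero_point=0, quant_min=0, quant_max=255)
print(y)  # tensor([12., 255., 0.])
```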
24 changes: 24 additions & 0 deletions mart/configs/datamodule/fiftyone_perturbable_mask.yaml
@@ -74,3 +74,27 @@ test_dataset:
        zero_point: 0
        quant_min: 0
        quant_max: 255

predict_dataset:
  _target_: mart.datamodules.fiftyone.FiftyOneDataset
  dataset_name: ???
  gt_field: ${..train_dataset.gt_field}
  sample_tags: []
  label_tags: []
  transforms:
    _target_: mart.transforms.Compose
    transforms:
      - _target_: torchvision.transforms.ToTensor
      # ConvertCocoPolysToMask must run before ConvertInstanceSegmentationToPerturbable.
      - _target_: mart.transforms.ConvertCocoPolysToMask
      - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable
      - _target_: mart.transforms.Denormalize
        center: 0
        scale: 255
      - _target_: torch.fake_quantize_per_tensor_affine
        _partial_: true
        # (x/1+0).round().clamp(0, 255) * 1
        scale: 1
        zero_point: 0
        quant_min: 0
        quant_max: 255
23 changes: 23 additions & 0 deletions mart/configs/experiment/FiftyOne_TorchvisionFasterRCNN.yaml
@@ -0,0 +1,23 @@
# @package _global_

defaults:
  - COCO_TorchvisionFasterRCNN
  - override /datamodule: fiftyone
  - override /callbacks:
      [
        model_checkpoint,
        lr_monitor,
        fiftyone_prediction_adder,
        fiftyone_evaluate_detections,
        fiftyone_mistakenness,
      ]

task_name: "FiftyOne_TorchvisionFasterRCNN"

model:
  predict_sequence:
    seq010:
      preprocessor: ["input"]

    seq020:
      losses_and_detections: ["preprocessor", "target"]
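
The `seq`-prefixed keys only fix the execution order: `ModularModel` sorts them and keeps the values (a module name mapped to its argument keys), mirroring the `predict_sequence` handling added in mart/models/modular.py below. A sketch applied to this config:

```python
# Sketch of the ordering logic applied to this config's predict_sequence
predict_sequence = {
    "seq020": {"losses_and_detections": ["preprocessor", "target"]},
    "seq010": {"preprocessor": ["input"]},
}
ordered = [predict_sequence[key] for key in sorted(predict_sequence)]
# -> preprocessor runs first, then losses_and_detections
```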
3 changes: 3 additions & 0 deletions mart/configs/lightning.yaml
@@ -42,6 +42,9 @@ fit: True
# lightning chooses best model based on metric specified in checkpoint callback
test: True

# run inference on the predict set.
predict: False

# Whether to resume training using configuration and checkpoint in specified directory
resume: null

33 changes: 33 additions & 0 deletions mart/datamodules/fiftyone.py
@@ -108,3 +108,36 @@ def __getitem__(self, index: int) -> Any:

    def __len__(self) -> int:
        return len(self.filtered_dataset)

    def add_predictions(self, sample_identifier: Any, preds: dict, field_name: str) -> None:
        # get the sample that the detections will be added to
        sample = self.filtered_dataset[sample_identifier]
        w = sample.metadata.width
        h = sample.metadata.height

        # get the dataset classes
        classes = self.filtered_dataset.default_classes

        # extract prediction values
        labels = preds["labels"]
        scores = preds["scores"]
        boxes = preds["boxes"]

        # convert detections to FiftyOne format
        detections = []
        for label, score, box in zip(labels, scores, boxes):
            # skip labels outside the dataset's class list
            if label >= len(classes):
                continue

            # Convert to [top-left-x, top-left-y, width, height]
            # in relative coordinates in [0, 1] x [0, 1]
            x1, y1, x2, y2 = box
            rel_box = [x1 / w, y1 / h, (x2 - x1) / w, (y2 - y1) / h]

            detections.append(
                fo.Detection(label=classes[label], bounding_box=rel_box, confidence=score)
            )

        # save detections to the sample
        sample[field_name] = fo.Detections(detections=detections)
        sample.save()
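
A worked example of the coordinate conversion above, since FiftyOne expects relative `[top-left-x, top-left-y, width, height]` rather than absolute `[x1, y1, x2, y2]` (values illustrative):

```python
# Example values only: a 640x480 image with one absolute xyxy box
w, h = 640, 480
x1, y1, x2, y2 = 64.0, 48.0, 320.0, 240.0
rel_box = [x1 / w, y1 / h, (x2 - x1) / w, (y2 - y1) / h]
assert rel_box == [0.1, 0.1, 0.4, 0.4]
```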
27 changes: 27 additions & 0 deletions mart/datamodules/modular.py
@@ -21,9 +21,11 @@ def __init__(
        train_dataset,
        val_dataset,
        test_dataset=None,
        predict_dataset=None,
        train_sampler=None,
        val_sampler=None,
        test_sampler=None,
        predict_sampler=None,
        num_workers=0,
        collate_fn=None,
        ims_per_batch=1,
@@ -46,6 +48,9 @@
        self.test_dataset = test_dataset
        self.test_sampler = test_sampler

        self.predict_dataset = predict_dataset
        self.predict_sampler = predict_sampler

        self.num_workers = num_workers
        self.collate_fn = collate_fn
        if not callable(self.collate_fn):
@@ -78,6 +83,10 @@ def setup(self, stage=None):
        if not isinstance(self.test_dataset, (Dataset, type(None))):
            self.test_dataset = instantiate(self.test_dataset)

        if stage == "predict" or stage is None:
            if not isinstance(self.predict_dataset, (Dataset, type(None))):
                self.predict_dataset = instantiate(self.predict_dataset)

    def train_dataloader(self):
        batch_sampler = self.train_sampler
        if not isinstance(batch_sampler, (Sampler, type(None))):
@@ -132,3 +141,21 @@ def test_dataloader(self):
            collate_fn=self.collate_fn,
            **kwargs,
        )

    def predict_dataloader(self):
        batch_sampler = self.predict_sampler
        if not isinstance(batch_sampler, (Sampler, type(None))):
            batch_sampler = instantiate(batch_sampler, self.predict_dataset)

        kwargs = {"batch_sampler": batch_sampler, "pin_memory": self.pin_memory}

        if batch_sampler is None:
            kwargs["batch_size"] = self.batch_size
            kwargs["shuffle"] = False

        return DataLoader(
            self.predict_dataset,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            **kwargs,
        )
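
A hedged usage sketch for the new predict path; the datamodule class name and dataset arguments here are assumed for illustration:

```python
# Sketch: exercising the predict path directly (class/argument names assumed)
dm = LitDataModule(train_dataset=train_ds, val_dataset=val_ds, predict_dataset=predict_ds)
dm.setup(stage="predict")  # instantiates predict_dataset if it is still a config
loader = dm.predict_dataloader()
```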
22 changes: 22 additions & 0 deletions mart/models/modular.py
@@ -34,6 +34,8 @@
        test_sequence=None,
        test_step_log=None,
        test_metrics=None,
        predict_sequence=None,
        predict_step_log=None,
        load_state_dict=None,
        output_loss_key="loss",
        output_preds_key="preds",
@@ -54,6 +56,8 @@
            validation_sequence = [validation_sequence[key] for key in sorted(validation_sequence)]
        if isinstance(test_sequence, dict):
            test_sequence = [test_sequence[key] for key in sorted(test_sequence)]
        if isinstance(predict_sequence, dict):
            predict_sequence = [predict_sequence[key] for key in sorted(predict_sequence)]

        # *_step() functions make some assumptions about the type of Module they can call.
        # That is, injecting an nn.Module generally won't work, so better to hardcode SequentialDict.
@@ -62,6 +66,7 @@
"training": training_sequence,
"validation": validation_sequence,
"test": test_sequence,
"predict": predict_sequence,
}
self.model = SequentialDict(modules, sequences)

@@ -90,6 +95,11 @@
        self.test_step_log = test_step_log or {}
        self.test_metrics = test_metrics

        # Be backwards compatible by turning a list into a dict where each item is its own key-value pair
        if isinstance(predict_step_log, (list, tuple)):
            predict_step_log = {item: item for item in predict_step_log}
        self.predict_step_log = predict_step_log or {}

        # Load state dict for specified modules. We flatten it because Hydra
        # command lines convert dotted paths to nested dictionaries.
        if isinstance(load_state_dict, str):
@@ -204,6 +214,18 @@ def on_test_epoch_end(self):

        self.log_metrics(metrics, prefix="test_metrics")

    #
    # Predict
    #
    def predict_step(self, batch, batch_idx):
        input, target = batch
        pred = self(input=input, target=target, model=self.model, step="predict")

        for log_name, output_key in self.predict_step_log.items():
            self.log(f"predict/{log_name}", pred[output_key])

        return pred
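
Tracing the backwards-compatibility shim above: a list-form `predict_step_log` becomes a dict, and `predict_step` then logs each mapped output key:

```python
# Sketch of the list-to-dict shim and its effect in predict_step
predict_step_log = ["loss"]
if isinstance(predict_step_log, (list, tuple)):
    predict_step_log = {item: item for item in predict_step_log}
# -> {"loss": "loss"}, so predict_step logs pred["loss"] under "predict/loss"
```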

    #
    # Utilities
    #
6 changes: 6 additions & 0 deletions mart/tasks/lightning.py
@@ -84,6 +84,12 @@ def lightning(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

    test_metrics = trainer.callback_metrics

    if cfg.get("predict"):
        log.info("Starting predictions!")
        trainer.predict(
            model=model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=False
        )
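
With the new flag defaulting to `False` in mart/configs/lightning.yaml, the predict stage is opt-in, e.g. something like `python -m mart experiment=FiftyOne_TorchvisionFasterRCNN predict=true` (entry point assumed).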

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}
