-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathexample_eval.py
59 lines (53 loc) · 2.06 KB
/
example_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import kornia.augmentation
from torch.utils.data import DataLoader
import tyc_dataset
# Dataset location handed to TYCDataset in main(). This is a machine-specific
# absolute path; point it at the local copy of the split to evaluate.
DATASET_PATH: str = "/Users/christoph/Desktop/tyc/labeled_set/val"
def main() -> None:
    """Run a demo evaluation of Panoptic Quality and cell IoU on the dataset.

    The targets themselves, rolled by five pixels along dims 1 and 2, act as a
    stand-in "near perfect" prediction. For the very first sample the
    prediction is replaced by ``None`` to exercise the "nothing detected"
    path of both metrics. The aggregated scores are printed at the end.
    """
    # Augmentation pipeline: a horizontal flip that always fires (p=1.0),
    # applied to the input image and the mask alike
    horizontal_flip = kornia.augmentation.RandomHorizontalFlip(p=1.0)
    augmentations = kornia.augmentation.AugmentationSequential(
        horizontal_flip,
        data_keys=["input", "mask"],
        same_on_batch=False,
    )
    # Dataset read from DATASET_PATH
    evaluation_set = tyc_dataset.data.TYCDataset(path=DATASET_PATH, augmentations=augmentations)
    # Loader; batch size must stay at 1 because the images differ in shape
    loader = DataLoader(
        dataset=evaluation_set,
        num_workers=8,
        batch_size=1,
        drop_last=True,
        collate_fn=tyc_dataset.data.collate_function_tyc_dataset,
    )
    # Metrics, kept in the order in which they are updated
    panoptic_quality = tyc_dataset.eval.PanopticQuality()
    cell_class_iou = tyc_dataset.eval.CellIoU()
    metrics = (panoptic_quality, cell_class_iou)
    for batch_index, (images, instances, class_labels) in enumerate(loader):
        # Get semantic classes from the one-hot vectors
        target_classes = []
        for one_hot in class_labels:
            target_classes.append(one_hot.argmax(dim=-1))
        # Slightly shifted copies of the targets simulate a near perfect prediction
        shifted_instances = []
        for instance_maps in instances:
            shifted_instances.append(instance_maps.clone().roll(shifts=(5, 5), dims=(1, 2)))
        # Shallow copy so the target list is untouched when an entry is overwritten
        predicted_classes = list(target_classes)
        # Simulate the case where no instance was detected
        if batch_index == 0:
            shifted_instances[0] = None
            predicted_classes[0] = None
        # Feed the same prediction/target pair to every metric
        for metric in metrics:
            metric.update(
                instances_pred=shifted_instances,
                classes_pred=predicted_classes,
                instances_target=instances,
                classes_target=target_classes,
            )
    # Report the scores accumulated over all samples
    print(f"Panoptic Quality: {panoptic_quality.compute().item()}")
    print(f"Cell class IoU: {cell_class_iou.compute().item()}")
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()