-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathours_transforms.py
138 lines (113 loc) · 4.73 KB
/
ours_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import numpy as np
from numpy import random
class RandomCrop(object):
    """Randomly crop the image and segmentation maps, favouring diverse crops.

    When the incoming ``results`` dict carries no ``'crop_bbox'``, several
    candidate boxes are drawn and each one is scored on the cropped
    ``results['gt_semantic_seg']``: the score is the sum of the logarithms of
    the per-category pixel counts (i.e. the log of their product). The
    highest-scoring candidate is used; ties keep the earliest candidate.

    Args:
        crop_size (tuple): Expected size after cropping, (h, w).
        cat_max_ratio (float): The maximum ratio that single category could
            occupy. A candidate whose dominant category reaches this ratio
            scores 0.
        ignore_index (int): Label value left out of the category statistics.
        num_attempts (int): Number of candidate boxes drawn when no
            ``'crop_bbox'`` is supplied. Values below 1 behave like 1.
            Default: 10.
    """

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255,
                 num_attempts=10):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index
        self.num_attempts = num_attempts

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box.

        Args:
            img: Array-like whose ``shape[0]`` / ``shape[1]`` are compared
                against ``crop_size[0]`` / ``crop_size[1]`` (height, width).

        Returns:
            tuple: ``(crop_y1, crop_y2, crop_x1, crop_x2)``.
        """
        # The margin is clamped at 0, so an input smaller than the crop size
        # always gets offset 0 along that axis.
        margin_h = max(img.shape[0] - self.crop_size[0], 0)
        margin_w = max(img.shape[1] - self.crop_size[1], 0)
        # Height is drawn before width; keep this order so seeded runs
        # reproduce the same boxes.
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img`` along its first two axes.

        Args:
            img: Array-like supporting ``img[y1:y2, x1:x2, ...]`` indexing.
            crop_bbox (tuple): ``(crop_y1, crop_y2, crop_x1, crop_x2)``.

        Returns:
            The sliced ``img``; trailing axes are left untouched.
        """
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        return img[crop_y1:crop_y2, crop_x1:crop_x2, ...]

    def _score_crop(self, seg):
        """Return the diversity score of one cropped segmentation map."""
        labels, counts = torch.unique(seg, return_counts=True)
        counts = counts[labels != self.ignore_index]
        # Fewer than two categories: nothing to balance.
        if len(counts) <= 1:
            return 0
        dominant_ratio = torch.max(counts).item() / torch.sum(counts).item()
        if dominant_ratio >= self.cat_max_ratio:
            return 0
        # Categories covering a single pixel contribute log(1) == 0 anyway,
        # so they are dropped before taking the log.
        return counts[counts > 1].float().log().sum().item()

    def _select_crop_bbox(self, img, seg):
        """Draw candidate boxes and return the one with the highest score."""
        best_bbox = self.get_crop_bbox(img)
        best_score = self._score_crop(self.crop(seg, best_bbox))
        for _ in range(self.num_attempts - 1):
            bbox = self.get_crop_bbox(img)
            score = self._score_crop(self.crop(seg, bbox))
            # Strict comparison: on ties the earlier candidate is kept.
            if score > best_score:
                best_score, best_bbox = score, bbox
        return best_bbox

    def __call__(self, results):
        """Call function to randomly crop images, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline. Reads
                ``'img'``, optionally ``'crop_bbox'`` and ``'seg_fields'``,
                and ``'gt_semantic_seg'`` when no ``'crop_bbox'`` is given.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size. ``'crop_bbox'`` holds the box
                that was applied. The input dict is modified in place.
        """
        img = results['img']
        if 'crop_bbox' in results:
            # A box supplied by the caller is applied as-is, without search.
            crop_bbox = results['crop_bbox']
        else:
            crop_bbox = self._select_crop_bbox(
                img, results['gt_semantic_seg'])

        # crop the image
        img = self.crop(img, crop_bbox)
        results['img'] = img
        results['img_shape'] = img.shape
        results['crop_bbox'] = crop_bbox

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
class RandomCropNoProd(RandomCrop):
    """Randomly crop the image and segmentation maps by rejection sampling.

    Unlike the parent class, candidates are not ranked by a score. When
    ``cat_max_ratio < 1`` and no ``'crop_bbox'`` is supplied, a box is
    redrawn until the cropped ``results['gt_semantic_seg']`` holds more than
    one non-ignored category and its dominant category stays below
    ``cat_max_ratio``, or until the attempts run out.

    Args:
        crop_size (tuple): Expected size after cropping, (h, w).
        cat_max_ratio (float): The maximum ratio that single category could
            occupy.
        ignore_index (int): Label value left out of the category statistics.
        num_attempts (int): Maximum number of boxes checked against
            ``cat_max_ratio`` before giving up. Default: 10.
    """

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255,
                 num_attempts=10):
        super().__init__(crop_size, cat_max_ratio, ignore_index)
        # Stored here so this class does not rely on the parent knowing
        # about the attempt count.
        self.num_attempts = num_attempts

    def __call__(self, results):
        """Call function to randomly crop images, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline. Reads
                ``'img'``, optionally ``'crop_bbox'`` and ``'seg_fields'``,
                and ``'gt_semantic_seg'`` when candidates must be checked.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size. ``'crop_bbox'`` holds the box
                that was applied. The input dict is modified in place.
        """
        img = results['img']
        if 'crop_bbox' in results:
            # A box supplied by the caller is applied as-is.
            crop_bbox = results['crop_bbox']
        else:
            crop_bbox = self.get_crop_bbox(img)
            if self.cat_max_ratio < 1.:
                seg = results['gt_semantic_seg']
                for _ in range(self.num_attempts):
                    seg_temp = self.crop(seg, crop_bbox)
                    labels, cnt = torch.unique(seg_temp, return_counts=True)
                    cnt = cnt[labels != self.ignore_index]
                    if len(cnt) > 1 and torch.max(cnt).item() / torch.sum(
                            cnt).item() < self.cat_max_ratio:
                        break
                    # Rejected: draw a replacement. After the final attempt
                    # the replacement is used without being checked.
                    crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        results['img'] = img
        results['img_shape'] = img.shape
        results['crop_bbox'] = crop_bbox

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        return results