Skip to content

Commit e4547a8

Browse files
committed
merge resize, augment, normalize functions to preprocess
1 parent 78c629f commit e4547a8

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

generator.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,7 @@ def __getitem__(self, index):
6767
batch_y = []
6868
for f in fs:
6969
img, path = f.result()
70-
if self.augmentation:
71-
img = self.transform(image=img)['image']
72-
img = cv2.resize(img, (self.input_shape[1], self.input_shape[0]))
73-
x = np.asarray(img).reshape(self.input_shape).astype('float32') / 255.0
70+
x = self.preprocess(img, aug=self.augmentation)
7471
batch_x.append(x)
7572

7673
dir_name = path.replace(self.root_path, '').split('/')[1]
@@ -99,8 +96,16 @@ def random_blur(self, img):
9996
img = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
10097
return img
10198

102-
def load_img(self, path):
103-
img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_GRAYSCALE if self.input_shape[2] == 1 else cv2.IMREAD_COLOR)
99+
def preprocess(self, img, aug=False):
100+
img = cv2.resize(img, (self.input_shape[1], self.input_shape[0]))
101+
if aug:
102+
img = self.transform(image=img)['image']
104103
if self.input_shape[-1] == 3:
105104
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # swap rb
105+
x = np.asarray(img).reshape(self.input_shape).astype('float32') / 255.0
106+
return x
107+
108+
def load_img(self, path):
109+
img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_GRAYSCALE if self.input_shape[-1] == 1 else cv2.IMREAD_COLOR)
106110
return img, path
111+

0 commit comments

Comments
 (0)