@@ -67,10 +67,7 @@ def __getitem__(self, index):
67
67
batch_y = []
68
68
for f in fs :
69
69
img , path = f .result ()
70
- if self .augmentation :
71
- img = self .transform (image = img )['image' ]
72
- img = cv2 .resize (img , (self .input_shape [1 ], self .input_shape [0 ]))
73
- x = np .asarray (img ).reshape (self .input_shape ).astype ('float32' ) / 255.0
70
+ x = self .preprocess (img , aug = self .augmentation )
74
71
batch_x .append (x )
75
72
76
73
dir_name = path .replace (self .root_path , '' ).split ('/' )[1 ]
@@ -99,8 +96,16 @@ def random_blur(self, img):
99
96
img = cv2 .GaussianBlur (img , (kernel_size , kernel_size ), 0 )
100
97
return img
101
98
102
- def load_img (self , path ):
103
- img = cv2 .imdecode (np .fromfile (path , dtype = np .uint8 ), cv2 .IMREAD_GRAYSCALE if self .input_shape [2 ] == 1 else cv2 .IMREAD_COLOR )
99
+ def preprocess (self , img , aug = False ):
100
+ img = cv2 .resize (img , (self .input_shape [1 ], self .input_shape [0 ]))
101
+ if aug :
102
+ img = self .transform (image = img )['image' ]
104
103
if self .input_shape [- 1 ] == 3 :
105
104
img = cv2 .cvtColor (img , cv2 .COLOR_BGR2RGB ) # swap rb
105
+ x = np .asarray (img ).reshape (self .input_shape ).astype ('float32' ) / 255.0
106
+ return x
107
+
108
+ def load_img (self , path ):
109
+ img = cv2 .imdecode (np .fromfile (path , dtype = np .uint8 ), cv2 .IMREAD_GRAYSCALE if self .input_shape [- 1 ] == 1 else cv2 .IMREAD_COLOR )
106
110
return img , path
111
+
0 commit comments