1
1
2
+ from abc import ABC , abstractmethod
2
3
from matplotlib import pyplot as plt
3
4
import numpy as np
4
5
import torch
@@ -122,52 +123,29 @@ def forward(self, image):
122
123
# print(x.shape)
123
124
return x
124
125
125
- class UnetModel :
126
- """
127
- The Unet model takes the character density map image
128
- and returns the masks of the ID card number, first name,
129
- surname and date of birth regions on this image.
130
- The Unet model was trained with 3 different backbones,
131
- the most successful of which was obtained from the resnet34 backbone.
132
- """
133
-
134
- def __init__ (self , model_name , device ):
135
- self .device = device
136
- self .model_name = model_name
137
-
138
- print ("Loading {} model" .format ( self .model_name ))
139
-
140
- def predict (self ,input_img ):
141
-
142
- predicted_mask = None
126
+ class UnetBackBones (ABC ):
127
+ @abstractmethod
128
+ def load_model (self , device ):
129
+ pass
143
130
144
- if (self .model_name == "resnet34" ):
145
- predicted_mask = self .__load_resnet34_model (input_img )
146
-
147
- elif (self .model_name == "resnet50" ):
148
- predicted_mask = self .__load_resnet50_model (input_img )
149
-
150
- elif (self .model_name == "vgg13" ):
151
- predicted_mask = self .__load_vgg13_model (input_img )
152
-
153
- elif (self .model_name == "original" ):
154
- predicted_mask = self .__load_orig_model (input_img )
155
-
156
- else :
157
- print ("Select from resnet34, resnet50 or original" )
158
-
159
- return predicted_mask
131
+ @abstractmethod
132
+ def predict (self , model , img ):
133
+ pass
160
134
161
- def __load_resnet34_model (self , input_img ):
162
-
135
+ class Res34BackBone (UnetBackBones ):
136
+
137
+ def load_model (self , device ):
163
138
model = smp .Unet (encoder_name = "resnet34" , encoder_weights = "imagenet" , in_channels = 3 , classes = 1 )
164
- model .load_state_dict (torch .load ('model/resnet34/UNet_sig.pth' ,map_location = self .device ))
165
- model = model .to (self .device )
139
+ model .load_state_dict (torch .load ('model/resnet34/UNet_sig.pth' ,map_location = device ))
140
+ model = model .to (device )
141
+ return model
142
+
143
+ def predict (self , model , input_img , device ):
166
144
167
145
img = torch .tensor (input_img )
168
146
img = img .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ()
169
147
170
- img = img .to (self . device )
148
+ img = img .to (device )
171
149
output = model (img )
172
150
output = output .squeeze (0 )
173
151
output [output > 0.0 ] = 1.0
@@ -177,16 +155,20 @@ def __load_resnet34_model(self, input_img):
177
155
predicted_mask = output .detach ().cpu ().numpy ()
178
156
179
157
return np .uint8 (predicted_mask )
180
-
181
158
182
- def __load_resnet50_model (self , input_img ):
159
+ class Res50BackBone (UnetBackBones ):
160
+
161
+ def load_model (self , device ):
183
162
184
163
model = smp .Unet (encoder_name = "resnet50" , encoder_weights = "imagenet" , in_channels = 3 , classes = 1 )
185
164
model .load_state_dict (torch .load ('model/resnet50/UNet.pth' ))
186
- model = model .to (self .device )
165
+ model = model .to (device )
166
+ return model
167
+
168
+ def predict (self , model , input_img , device ):
187
169
188
170
input_tensor = torch .tensor (input_img )
189
- input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (self . device )
171
+ input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (device )
190
172
191
173
output = model (input_tensor )
192
174
output = output .squeeze (0 )
@@ -197,15 +179,19 @@ def __load_resnet50_model(self, input_img):
197
179
predicted_mask = output .detach ().cpu ().numpy ()
198
180
199
181
return np .uint8 (predicted_mask )
182
+
183
+ class Vgg13BackBone (UnetBackBones ):
200
184
201
- def __load_vgg13_model (self , input_img ):
185
+ def load_model (self , device ):
202
186
203
187
model = smp .Unet (encoder_name = "vgg13" , encoder_weights = "imagenet" , in_channels = 3 , classes = 1 )
204
188
model .load_state_dict (torch .load ('model/vgg13/UNet.pth' ))
205
- model = model .to (self .device )
206
-
189
+ model = model .to (device )
190
+
191
+ def predict (self , model , input_img , device ):
192
+
207
193
input_tensor = torch .tensor (input_img )
208
- input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (self . device )
194
+ input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (device )
209
195
210
196
output = model (input_tensor )
211
197
output = output .squeeze (0 )
@@ -216,23 +202,47 @@ def __load_vgg13_model(self, input_img):
216
202
predicted_mask = output .detach ().cpu ().numpy ()
217
203
218
204
return np .uint8 (predicted_mask )
205
+
206
+ class NoBackBone (UnetBackBones ):
219
207
220
- def __load_orig_model (self , input_img ):
221
-
208
+ def load_model (self , device ):
222
209
model = UNET ()
223
210
model .load_state_dict (torch .load ('model/orig_unet/unetModel_20.pth' ))
224
- model = model .to (self .device )
225
-
211
+ model = model .to (device )
212
+
213
+ def predict (self , model , input_img , device ):
226
214
input_tensor = torch .tensor (input_img )
227
- input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (self . device )
215
+ input_tensor = input_tensor .permute ((2 , 0 , 1 )).unsqueeze (0 ).float ().to (device )
228
216
229
217
output = model (input_tensor )
230
218
output = output .squeeze (0 )
231
219
output [output > 0.0 ] = 1.0
232
- output [output <= 0.0 ]= 0
220
+ output [output <= 0.0 ]= 0
233
221
output = output .squeeze (0 )
234
222
235
223
predicted_mask = output .detach ().cpu ().numpy ()
236
224
237
225
return np .uint8 (predicted_mask )
238
226
227
+ class UnetModel :
228
+ """
229
+ The Unet model takes the character density map image
230
+ and returns the masks of the ID card number, first name,
231
+ surname and date of birth regions on this image.
232
+ The Unet model was trained with 3 different backbones,
233
+ the most successful of which was obtained from the resnet34 backbone.
234
+ """
235
+
236
+ def __init__ (self , backbone :UnetBackBones = Res34BackBone (), device = "cuda" ):
237
+ self .device = device
238
+ self .backbone = backbone
239
+
240
+
241
+ def predict (self ,input_img ):
242
+
243
+ model = self .backbone .load_model (self .device )
244
+ predicted_mask = self .backbone .predict (model , input_img , self .device )
245
+
246
+ return predicted_mask
247
+
248
+
0 commit comments