Skip to content

Commit 1a517c3

Browse files
committed
unet model backbones regularized
1 parent beb40a4 commit 1a517c3

File tree

2 files changed

+68
-57
lines changed

2 files changed

+68
-57
lines changed

main.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from find_nearest_box import NearestBox
88
from pytorch_unet.unet_predict import UnetModel
9+
from pytorch_unet.unet_predict import Res34BackBone
910
from extract_words import OcrFactory
1011
import extract_words
1112
import os
@@ -136,20 +137,20 @@ def getBoxRegions(regions):
136137
parser.add_argument('--neighbor_box_distance', default = 50, type = float, help='Nearest box distance threshold')
137138
parser.add_argument('--face_recognition', default = "ssd", type = str, help='face detection algorithm')
138139
parser.add_argument('--ocr_method', default = "EasyOcr", type = str, help='Type of ocr method for converting images to text')
139-
parser.add_argument('--rotation_interval', default = 15, type = int, help='Face search interval for rotation matrix')
140+
parser.add_argument('--rotation_interval', default = 30, type = int, help='Face search interval for rotation matrix')
140141
args = parser.parse_args()
141142

142143
Folder = args.folder_name # identity card images folder
143144
ORI_THRESH = 3 # Orientation angle threshold for skew correction
144145

145146
use_cuda = "cuda" if torch.cuda.is_available() else "cpu"
146147

147-
model = UnetModel("resnet34", use_cuda)
148+
model = UnetModel(Res34BackBone(), use_cuda)
148149
nearestBox = NearestBox(distance_thresh = args.neighbor_box_distance, draw_line=False)
149150
face_detector = detect_face.face_factory(face_model = args.face_recognition)
150151
findFaceID = face_detector.get_face_detector()
151-
Image2Text = extract_words.ocr_factory(ocr_method = args.ocr_method, border_thresh=3, denoise = False)
152-
152+
#Image2Text = extract_words.ocr_factory(ocr_method = args.ocr_method, border_thresh=3, denoise = False)
153+
Image2Text = OcrFactory().select_ocr_method(ocr_method = args.ocr_method, border_thresh=3, denoise = False)
153154

154155
start = time.time()
155156
end = 0

pytorch_unet/unet_predict.py

+63-53
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
from abc import ABC, abstractmethod
23
from matplotlib import pyplot as plt
34
import numpy as np
45
import torch
@@ -122,52 +123,29 @@ def forward(self, image):
122123
# print(x.shape)
123124
return x
124125

125-
class UnetModel:
126-
"""
127-
The Unet model takes the character density map image
128-
and returns the masks of the ID card number, first name,
129-
surname and date of birth regions on this image.
130-
The Unet model was trained with 3 different backbones,
131-
the most successful of which was obtained from the resnet34 backbone.
132-
"""
133-
134-
def __init__(self, model_name, device):
135-
self.device = device
136-
self.model_name = model_name
137-
138-
print("Loading {} model".format( self.model_name))
139-
140-
def predict(self,input_img):
141-
142-
predicted_mask = None
126+
class UnetBackBones(ABC):
127+
@abstractmethod
128+
def load_model(self, device):
129+
pass
143130

144-
if (self.model_name == "resnet34"):
145-
predicted_mask = self.__load_resnet34_model(input_img)
146-
147-
elif (self.model_name == "resnet50"):
148-
predicted_mask = self.__load_resnet50_model(input_img)
149-
150-
elif (self.model_name == "vgg13"):
151-
predicted_mask = self.__load_vgg13_model(input_img)
152-
153-
elif (self.model_name == "original"):
154-
predicted_mask = self.__load_orig_model(input_img)
155-
156-
else:
157-
print("Select from resnet34, resnet50 or original")
158-
159-
return predicted_mask
131+
@abstractmethod
132+
def predict(self, model, img):
133+
pass
160134

161-
def __load_resnet34_model(self, input_img):
162-
135+
class Res34BackBone(UnetBackBones):
136+
137+
def load_model(self, device):
163138
model = smp.Unet(encoder_name="resnet34" , encoder_weights="imagenet", in_channels=3, classes = 1)
164-
model.load_state_dict(torch.load('model/resnet34/UNet_sig.pth',map_location=self.device))
165-
model = model.to(self.device)
139+
model.load_state_dict(torch.load('model/resnet34/UNet_sig.pth',map_location = device))
140+
model = model.to(device)
141+
return model
142+
143+
def predict(self, model, input_img, device):
166144

167145
img = torch.tensor(input_img)
168146
img = img.permute((2, 0, 1)).unsqueeze(0).float()
169147

170-
img = img.to(self.device)
148+
img = img.to(device)
171149
output = model(img)
172150
output= output.squeeze(0)
173151
output[output>0.0] = 1.0
@@ -177,16 +155,20 @@ def __load_resnet34_model(self, input_img):
177155
predicted_mask = output.detach().cpu().numpy()
178156

179157
return np.uint8(predicted_mask)
180-
181158

182-
def __load_resnet50_model(self, input_img):
159+
class Res50BackBone(UnetBackBones):
160+
161+
def load_model(self, device):
183162

184163
model = smp.Unet(encoder_name="resnet50", encoder_weights="imagenet", in_channels=3, classes = 1)
185164
model.load_state_dict(torch.load('model/resnet50/UNet.pth'))
186-
model = model.to(self.device)
165+
model = model.to(device)
166+
return model
167+
168+
def predict(self, model, input_img, device):
187169

188170
input_tensor = torch.tensor(input_img)
189-
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(self.device)
171+
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(device)
190172

191173
output = model(input_tensor)
192174
output= output.squeeze(0)
@@ -197,15 +179,19 @@ def __load_resnet50_model(self, input_img):
197179
predicted_mask = output.detach().cpu().numpy()
198180

199181
return np.uint8(predicted_mask)
182+
183+
class Vgg13BackBone(UnetBackBones):
200184

201-
def __load_vgg13_model(self, input_img):
185+
def load_model(self, device):
202186

203187
model = smp.Unet(encoder_name="vgg13", encoder_weights="imagenet", in_channels=3, classes = 1)
204188
model.load_state_dict(torch.load('model/vgg13/UNet.pth'))
205-
model = model.to(self.device)
206-
189+
model = model.to(device)
190+
191+
def predict(self, model, input_img, device):
192+
207193
input_tensor = torch.tensor(input_img)
208-
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(self.device)
194+
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(device)
209195

210196
output = model(input_tensor)
211197
output= output.squeeze(0)
@@ -216,23 +202,47 @@ def __load_vgg13_model(self, input_img):
216202
predicted_mask = output.detach().cpu().numpy()
217203

218204
return np.uint8(predicted_mask)
205+
206+
class NoBackBone(UnetBackBones):
219207

220-
def __load_orig_model(self, input_img):
221-
208+
def load_model(self, device):
222209
model = UNET()
223210
model.load_state_dict(torch.load('model/orig_unet/unetModel_20.pth'))
224-
model = model.to(self.device)
225-
211+
model = model.to(device)
212+
213+
def predict(self, model, input_img, device):
226214
input_tensor = torch.tensor(input_img)
227-
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(self.device)
215+
input_tensor = input_tensor.permute((2, 0, 1)).unsqueeze(0).float().to(device)
228216

229217
output = model(input_tensor)
230218
output= output.squeeze(0)
231219
output[output>0.0] = 1.0
232-
output[output<=0.0]=0
220+
output[output<=0.0]= 0
233221
output = output.squeeze(0)
234222

235223
predicted_mask = output.detach().cpu().numpy()
236224

237225
return np.uint8(predicted_mask)
238226

227+
class UnetModel:
228+
"""
229+
The Unet model takes the character density map image
230+
and returns the masks of the ID card number, first name,
231+
surname and date of birth regions on this image.
232+
The Unet model was trained with 3 different backbones,
233+
the most successful of which was obtained from the resnet34 backbone.
234+
"""
235+
236+
def __init__(self, backbone:UnetBackBones = Res34BackBone(), device = "cuda"):
237+
self.device = device
238+
self.backbone = backbone
239+
240+
241+
def predict(self,input_img):
242+
243+
model = self.backbone.load_model(self.device)
244+
predicted_mask = self.backbone.predict(model, input_img, self.device)
245+
246+
return predicted_mask
247+
248+

0 commit comments

Comments
 (0)