Skip to content

Commit 5e8cef6

Browse files
committed
Task 3 Initial Commit
0 parents  commit 5e8cef6

16 files changed

+2108
-0
lines changed

README.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Drone_Task3_OCR
2+
3+
1. Download the pretrained models from the CRAFT and WIW GitHub repositories (links below).
4+
2. python task3.py --craft_weight ./craft_mlt_25k.pth --wiw_weight ./recognition_model.pth
5+
6+
7+
# [STAGE 1] Detection: CRAFT
8+
- Official Code: https://github.com/clovaai/CRAFT-pytorch
9+
- pretrained model: https://drive.google.com/file/d/1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ/view
10+
11+
Baek, Youngmin, et al. "Character region awareness for text detection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
12+
13+
# [STAGE 2] Recognition: WIW
14+
- Official Code: https://github.com/clovaai/deep-text-recognition-benchmark
15+
- pretrained model: https://drive.google.com/file/d/1b59rXuGGmKne1AuHnkgDzoYgKeETNMv9/view?usp=share_link
16+
17+
Baek, Jeonghun, et al. "What is wrong with scene text recognition model comparisons? dataset and model analysis." Proceedings of the IEEE/CVF international conference on computer vision. 2019.

__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .task3 import Task3

basenet/__init__.py

Whitespace-only changes.

basenet/vgg16_bn.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from collections import namedtuple
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.init as init
6+
from torchvision import models
7+
from torchvision.models.vgg import model_urls
8+
9+
def init_weights(modules):
    """Initialise conv, batch-norm and linear layers in place.

    Args:
        modules: Iterable of ``nn.Module`` objects (for example the result
            of ``net.modules()``). Modules of any other type are skipped.

    Behaviour per layer type:
        * ``nn.Conv2d``      - Xavier-uniform weights, zero bias.
        * ``nn.BatchNorm2d`` - weights filled with 1, zero bias.
        * ``nn.Linear``      - weights drawn from a normal distribution
          (mean 0, std 0.01), zero bias.

    Layers created without a bias (``bias=False``) or without affine
    parameters (``affine=False``) expose ``None`` for those attributes;
    they are skipped rather than raising ``AttributeError``.
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False leaves both weight and bias as None.
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            # bias=False leaves bias as None.
            if m.bias is not None:
                m.bias.data.zero_()
21+
22+
class vgg16_bn(torch.nn.Module):
    """VGG16-BN feature extractor split into five sequential stages.

    The first four stages (``slice1`` .. ``slice4``) re-use consecutive
    layers of ``torchvision.models.vgg16_bn(...).features``; ``slice5`` is
    a freshly initialised max-pool plus a dilated 3x3 convolution and a 1x1
    convolution. Sub-module names and registration order are kept stable
    so that existing checkpoints keep loading.

    Args:
        pretrained: Forwarded to ``torchvision.models.vgg16_bn``. When
            false, ``slice1`` .. ``slice4`` are re-initialised with
            ``init_weights``.
        freeze: When true, every parameter of ``slice1`` has
            ``requires_grad`` set to ``False``.
    """

    # Result type of forward(); built once instead of on every call.
    _VggOutputs = namedtuple(
        "VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']
    )

    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        # NOTE(review): this rewrites the weight download URL to plain HTTP,
        # which removes transport security for the checkpoint download.
        # Kept for backward compatibility - confirm it is still required.
        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()

        # Layers keep their original torchvision index as module name.
        # NOTE(review): the stage labels (relu2_2, relu3_2, ...) come from
        # the upstream CRAFT code; verify them against torchvision's layer
        # indices before relying on them.
        for x in range(12):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 replacement: pooling, dilated 3x3 conv, 1x1 conv.
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        # slice5 never has pretrained weights, so it is always initialised.
        init_weights(self.slice5.modules())

        if freeze:
            # Freezes every parameter of slice1 (not just the first conv).
            for param in self.slice1.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the five stages and return every intermediate output.

        Args:
            X: Input passed to ``slice1``.

        Returns:
            A ``VggOutputs`` namedtuple ``(fc7, relu5_3, relu4_3, relu3_2,
            relu2_2)`` holding the outputs of ``slice5`` down to ``slice1``
            in that order.
        """
        h_relu2_2 = self.slice1(X)
        h_relu3_2 = self.slice2(h_relu2_2)
        h_relu4_3 = self.slice3(h_relu3_2)
        h_relu5_3 = self.slice4(h_relu4_3)
        h_fc7 = self.slice5(h_relu5_3)
        return self._VggOutputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)

craft.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Copyright (c) 2019-present NAVER Corp.
3+
MIT License
4+
"""
5+
6+
# -*- coding: utf-8 -*-
7+
import torch
8+
import torch.nn as nn
9+
import torch.nn.functional as F
10+
11+
from .basenet.vgg16_bn import vgg16_bn, init_weights
12+
# from basenet.vgg16_bn import vgg16_bn, init_weights
13+
14+
class double_conv(nn.Module):
    """Two stacked conv -> batch-norm -> ReLU stages.

    The first stage is a 1x1 convolution that maps ``in_ch + mid_ch``
    input channels to ``mid_ch``; the second is a 3x3 convolution with
    padding 1 that maps ``mid_ch`` to ``out_ch``.

    Args:
        in_ch: First part of the input channel count.
        mid_ch: Second part of the input channel count and the width of
            the intermediate feature map.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        merged_ch = in_ch + mid_ch
        # Layer order fixes the state-dict keys (conv.0, conv.1, ...).
        stages = [
            nn.Conv2d(merged_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)
29+
30+
31+
class CRAFT(nn.Module):
    """VGG16-BN backbone followed by a U-shaped decoder and a 2-channel head.

    The decoder merges the backbone outputs from deepest to shallowest with
    four ``double_conv`` blocks; ``conv_cls`` then reduces the 32-channel
    decoder output to two channels.
    NOTE(review): the two channels are presumably the region and affinity
    score maps of the CRAFT paper - confirm against the caller.

    Args:
        pretrained: Forwarded to ``vgg16_bn``.
        freeze: Forwarded to ``vgg16_bn``.
    """

    def __init__(self, pretrained=False, freeze=False):
        super().__init__()

        # Base network
        self.basenet = vgg16_bn(pretrained, freeze)

        # U network
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        # Prediction head; layer order fixes the state-dict keys.
        n_maps = 2
        head_layers = [
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, n_maps, kernel_size=1),
        ]
        self.conv_cls = nn.Sequential(*head_layers)

        # Same initialisation order as the blocks were created.
        for block in (self.upconv1, self.upconv2, self.upconv3,
                      self.upconv4, self.conv_cls):
            init_weights(block.modules())

    def forward(self, x):
        """Compute the head output and the last decoder feature map.

        Args:
            x: Input passed to the backbone.

        Returns:
            A tuple ``(y, feature)``: ``y`` is the ``conv_cls`` output with
            its channel axis moved last via ``permute(0, 2, 3, 1)``, and
            ``feature`` is the output of ``upconv4``.
        """
        # Base network
        sources = self.basenet(x)

        # U network: fuse the two deepest outputs first ...
        y = self.upconv1(torch.cat([sources[0], sources[1]], dim=1))

        # ... then resize to each shallower output, concatenate and refine.
        decoders = (self.upconv2, self.upconv3, self.upconv4)
        for skip, decoder in zip(sources[2:], decoders):
            y = F.interpolate(y, size=skip.size()[2:], mode='bilinear', align_corners=False)
            y = decoder(torch.cat([y, skip], dim=1))

        feature = y
        y = self.conv_cls(feature)

        return y.permute(0, 2, 3, 1), feature
82+
83+
if __name__ == '__main__':
    # Smoke test: push one random 768x768 image through the network and
    # print the shape of the head output. Falls back to the CPU when no
    # CUDA device is available instead of crashing in .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CRAFT(pretrained=True).to(device)
    # Gradients are not needed just to inspect the output shape.
    with torch.no_grad():
        output, _ = model(torch.randn(1, 3, 768, 768, device=device))
    print(output.shape)

0 commit comments

Comments
 (0)