
Commit 6506320

Author: Youngmin Baek
Committed: initial commit
0 parents, commit 6506320

11 files changed: +588 -0 lines

.gitignore

+6
*.pyc
*.swp
*.pkl
*.pth
result*
weights*

README.md

+83
## CRAFT: Character Region Awareness for Text Detection

Official PyTorch implementation of the CRAFT text detector | [Paper](https://arxiv.org/abs/1904.01941) | [Pretrained Model](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ) | [Supplementary](https://youtu.be/HI8MzpY8KMI)

**[Youngmin Baek](mailto:[email protected]), Bado Lee, Dongyoon Han, Sangdoo Yun, Hwalsuk Lee.**

Clova AI Research, NAVER Corp.

### Sample Results

### Overview
PyTorch implementation of the CRAFT text detector, which effectively detects text areas by exploring each character region and the affinity between characters. The bounding boxes of text are obtained by simply finding minimum bounding rectangles on the binary map produced by thresholding the character-region and affinity scores. A minimal sketch of this step follows.
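The sketch below assumes `text_score` and `link_score` are the network's `(H, W)` score maps in `[0, 1]`; the threshold values are illustrative, and `craft_utils.py` below implements the full version with size filtering and dilation:

```
import cv2
import numpy as np

def boxes_from_scores(text_score, link_score, text_thresh=0.7, link_thresh=0.4):
    # binarize the character-region and affinity maps, then combine them
    binary = np.logical_or(text_score > text_thresh, link_score > link_thresh)
    n, labels = cv2.connectedComponents(binary.astype(np.uint8), connectivity=4)
    boxes = []
    for k in range(1, n):  # label 0 is the background
        pts = np.column_stack(np.where(labels == k))[:, ::-1].astype(np.float32)  # (x, y) points
        boxes.append(cv2.boxPoints(cv2.minAreaRect(pts)))  # 4-point rotated rectangle
    return boxes
```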
<img width="1000" alt="teaser" src="./figures/craft_example.gif">
## Updates
**4 Jun, 2019**: Initial update

## Getting started
### Install dependencies
#### Requirements
- PyTorch>=0.4.1
- torchvision>=0.2.1
- opencv-python>=3.4.2
- see `requirements.txt`
```
pip install -r requirements.txt
```

### Training
We are currently cleaning up the training code for release.

### Test instruction using pretrained model
* Download the pretrained weights: [Trained Model on IC13,IC17](https://drive.google.com/open?id=1Jk4eGD7crsqCCg9C9VjCLkMN3ze8kutZ)
* Run with the pretrained model (tested with Python 3.7):
```
python test.py --trained_model=[weightfile] --test_folder=[folder path to test images]
```

The result images and score maps will be saved to `./result` by default.
### Arguments
* `--trained_model`: path to the pretrained model
* `--text_threshold`: text confidence threshold
* `--low_text`: text low-bound score
* `--link_threshold`: link confidence threshold
* `--canvas_size`: max image size for inference
* `--mag_ratio`: image magnification ratio
* `--show_time`: show processing time
* `--test_folder`: folder path to input images

A full invocation with these arguments is sketched below.
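For example (the weight filename and threshold values here are illustrative, not documented defaults):

```
python test.py --trained_model=weights/craft.pth --test_folder=./test_images \
    --text_threshold=0.7 --low_text=0.4 --link_threshold=0.4 --show_time
```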
## Citation
```
@article{baek2019character,
  title={Character Region Awareness for Text Detection},
  author={Baek, Youngmin and Lee, Bado and Han, Dongyoon and Yun, Sangdoo and Lee, Hwalsuk},
  journal={arXiv preprint arXiv:1904.01941},
  year={2019}
}
```
## License
```
Copyright (c) 2019-present NAVER Corp.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
```

basenet/__init__.py

Whitespace-only changes.

basenet/vgg16_bn.py

+73
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models
from torchvision.models.vgg import model_urls

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

class vgg16_bn(torch.nn.Module):
    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(12):         # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):     # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):     # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):     # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6, fc7 without atrous conv
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            for param in self.slice1.parameters():  # only first conv
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu2_2 = h
        h = self.slice2(h)
        h_relu3_2 = h
        h = self.slice3(h)
        h_relu4_3 = h
        h = self.slice4(h)
        h_relu5_3 = h
        h = self.slice5(h)
        h_fc7 = h
        vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
        out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
        return out
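A quick shape check for this backbone (a hedged sketch; the expected shapes assume torchvision's `vgg16_bn` layer layout, which the slice indices above rely on):

```
import torch
from basenet.vgg16_bn import vgg16_bn

net = vgg16_bn(pretrained=False, freeze=False)  # pretrained=False avoids the download for a smoke test
out = net(torch.randn(1, 3, 768, 768))
print(out.fc7.shape)      # torch.Size([1, 1024, 48, 48]) -- stride 16
print(out.relu2_2.shape)  # torch.Size([1, 128, 384, 384]) -- stride 2
```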

craft.py

+80
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from basenet.vgg16_bn import vgg16_bn, init_weights

class double_conv(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class CRAFT(nn.Module):
    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        """ Base network """
        self.basenet = vgg16_bn(pretrained, freeze)

        """ U network """
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())

    def forward(self, x):
        """ Base network """
        sources = self.basenet(x)

        """ U network """
        y = torch.cat([sources[0], sources[1]], dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[2]], dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[3]], dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
        y = torch.cat([y, sources[4]], dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0, 2, 3, 1), feature

if __name__ == '__main__':
    model = CRAFT(pretrained=True).cuda()
    output, _ = model(torch.randn(1, 3, 768, 768).cuda())
    print(output.shape)
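For the 768×768 input above, the printed shape is `torch.Size([1, 384, 384, 2])`: the maps are at half the input resolution, with two channels per pixel. A hedged sketch of splitting the permuted output into the two score maps (treating channel 0 as the character-region score and channel 1 as the affinity score is an assumption here, since the consuming code is not part of this file):

```
# split the permuted (N, H/2, W/2, 2) output into the two score maps
score_text = output[0, :, :, 0].detach().cpu().numpy()  # character-region score (assumed channel 0)
score_link = output[0, :, :, 1].detach().cpu().numpy()  # affinity score (assumed channel 1)
```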

craft_utils.py

+81
# -*- coding: utf-8 -*-
import numpy as np
import cv2
import math


def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
    # prepare data
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    """ labeling method """
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    for k in range(1, nLabels):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        # thresholding
        if np.max(textmap[labels==k]) < text_threshold: continue

        # make segmentation map
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        segmap[np.logical_and(link_score==1, text_score==0)] = 0  # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        # make box
        np_contours = np.roll(np.array(np.where(segmap!=0)), 1, axis=0).transpose().reshape(-1, 2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:, 0]), max(np_contours[:, 0])
            t, b = min(np_contours[:, 1]), max(np_contours[:, 1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4 - startidx, 0)
        box = np.array(box)

        det.append(box)
        mapper.append(k)

    return det, labels, mapper


def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text):
    boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)

    return boxes


def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net=2):
    if len(polys) > 0:
        polys = np.array(polys)
        for k in range(len(polys)):
            if polys[k] is not None:
                polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
    return polys
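A hedged usage sketch tying these helpers together (`score_text`/`score_link` stand in for the network's score maps; the threshold values and the unit resize ratios are illustrative, not defaults defined in this file):

```
import numpy as np
from craft_utils import getDetBoxes, adjustResultCoordinates

# stand-ins for the (H, W) float score maps produced by the network
score_text = np.random.rand(384, 384).astype(np.float32)
score_link = np.random.rand(384, 384).astype(np.float32)

boxes = getDetBoxes(score_text, score_link,
                    text_threshold=0.7, link_threshold=0.4, low_text=0.4)
# scale boxes from the half-resolution maps back to image coordinates (ratio_net=2)
boxes = adjustResultCoordinates(boxes, ratio_w=1.0, ratio_h=1.0)
```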

figures/craft_example.gif

869 KB
