Skip to content

Commit 4a46bd8

Browse files
committed
refactor
1 parent a83864e commit 4a46bd8

File tree

6 files changed

+108
-127
lines changed

6 files changed

+108
-127
lines changed

break_captcha

+2-8
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,10 @@
2222

2323
import sys
2424
import Image
25-
from sklearn.externals import joblib
26-
from image_processing import DigitSeparator
2725
import urllib2
2826
import StringIO
2927

30-
def detect_number(model, image):
31-
digit_separator = DigitSeparator(image)
32-
digits = digit_separator.get_digits()
33-
labels = model.predict(digits)
34-
print ''.join(map(lambda x: str(int(x)), labels))
28+
from sklearn.externals import joblib
3529

3630
def main():
3731
model = joblib.load(sys.argv[1])
@@ -47,7 +41,7 @@ def main():
4741
else:
4842
with open(img_path) as f_image:
4943
image = Image.open(f_image).convert('L')
50-
detect_number(model, image)
44+
print model.decode_image(image)
5145
except EOFError:
5246
pass
5347

dataset.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (C) 2012 Rafael Cunha de Almeida <[email protected]>
2+
#
3+
# Permission is hereby granted, free of charge, to any person obtaining
4+
# a copy of this software and associated documentation files (the
5+
# "Software"), to deal in the Software without restriction, including
6+
# without limitation the rights to use, copy, modify, merge, publish,
7+
# distribute, sublicense, and/or sell copies of the Software, and to
8+
# permit persons to whom the Software is furnished to do so, subject to
9+
# the following conditions:
10+
#
11+
# The above copyright notice and this permission notice shall be
12+
# included in all copies or substantial portions of the Software.
13+
#
14+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15+
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16+
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
17+
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
19+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20+
# OTHER DEALINGS IN THE SOFTWARE.
21+
22+
import os
23+
import re
24+
import Image
25+
26+
def _get_files(base_dir):
27+
return map(lambda x: os.path.join(base_dir, x), os.listdir(base_dir))
28+
29+
def load_captcha_dataset(base_dir):
    """Load labelled captcha images from a directory.

    File names are expected to look like ``<label>-<index>.<ext>``
    (e.g. ``1234-0.png``); the leading digit run becomes the label.
    Every image is converted to 8-bit greyscale ('L' mode).

    Returns the pair ``(captchas, labels)`` -- the transpose of the
    list of (image, label) tuples, produced by ``zip(*dataset)``.
    """
    files = _get_files(base_dir)
    dataset = []
    for file_path in files:
        file_name = os.path.basename(file_path)
        # First match group is the digit run before the '-'; raises
        # IndexError if a file name does not follow the pattern.
        label = re.findall(r'^([0-9]+)-[0-9]+\..*$', file_name)[0]
        with open(file_path) as f:
            captcha = Image.open(f).convert('L')
        dataset.append((captcha, label))
    return zip(*dataset) # unzip

features.py

-19
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import ImageOps
2222
import ImageFilter
2323
import numpy
24-
from sklearn.feature_extraction import DictVectorizer
2524

2625
class compose_extractors(object):
2726
def __init__(self, extractors):
@@ -39,24 +38,6 @@ def __call__(self, arg):
3938
extractor(image, image_features)
4039
return image_features
4140

42-
class FeatureHandler(object):
43-
def __init__(self, extractor, dataset):
44-
self.extractor = extractor
45-
self.vectorizer = DictVectorizer()
46-
digits = self.__extract_features(dataset[0])
47-
self.train_digits = self.vectorizer.fit_transform(digits).toarray()
48-
self.labels = dataset[1]
49-
50-
def __extract_features(self, values):
51-
return map(self.extractor, values)
52-
53-
def sklearn_format_train(self):
54-
return self.train_digits,self.labels
55-
56-
def sklearn_format_test(self, items):
57-
features = self.__extract_features(items)
58-
return self.vectorizer.transform(features).toarray()
59-
6041
def border_detection(digit):
6142
digit.image = digit.image.filter(ImageFilter.FIND_EDGES)
6243
digit.pix = digit.image.load()

models.py

+49-31
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,23 @@
1212
#
1313
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1414
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE X
16-
# CONSORTIUM BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
15+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
# AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
1717
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
1818
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

2020
from features import *
21+
from image_processing import DigitSeparator
2122
from functools import partial
22-
from sklearn import naive_bayes
23-
from sklearn import tree
24-
from sklearn import linear_model
25-
from sklearn import svm
26-
from sklearn import ensemble
27-
from sklearn.neighbors.nearest_centroid import NearestCentroid
28-
from sklearn import decomposition
23+
2924
import time
3025

31-
class ScikitWrapper(object):
32-
def __init__(self, engine, extractors, dataset):
33-
self.feature_handler = FeatureHandler(
34-
compose_extractors(extractors),
35-
dataset)
36-
self.engine = engine
37-
vector, labels = self.feature_handler.sklearn_format_train()
38-
self.engine.fit(vector, labels)
26+
from sklearn import svm
27+
from sklearn import ensemble
28+
from sklearn.feature_extraction import DictVectorizer
3929

40-
def predict(self, items):
41-
return self.engine.predict(self.feature_handler.sklearn_format_test(items))
30+
class ModelUnavailable(Exception):
    """Signals that a requested model cannot be provided.

    NOTE(review): no raise site is visible in this file -- confirm the
    intended usage against callers before relying on it.
    """
    pass
4232

4333
ALL_EXTRACTORS = [
4434
x_histogram,
@@ -55,20 +45,48 @@ def predict(self, items):
5545
horizontal_symmetry,
5646
]
5747

58-
def NaiveBayes(dataset):
59-
return ScikitWrapper(naive_bayes.MultinomialNB(), [positions], dataset)
48+
# Feature extractors paired with the SVM engine.
SVM_EXTRACTORS = [positions]
def svm_engine():
    """Build the default SVM classifier (polynomial kernel, degree 2)."""
    return svm.SVC(kernel='poly', degree=2)
51+
52+
# The random-forest engine uses every available extractor.
FOREST_EXTRACTORS = ALL_EXTRACTORS
def forest_engine():
    """Build a random-forest classifier (50 trees, 2 worker processes)."""
    return ensemble.RandomForestClassifier(n_estimators=50, n_jobs=2)
55+
56+
class CaptchaDecoder(object):
    """Decode multi-digit captcha images with an SVM over digit features.

    Each captcha is split into digit images (DigitSeparator), features
    are extracted per digit (SVM_EXTRACTORS) and vectorized, and an SVM
    is trained on the result.  Prediction runs the same pipeline and
    joins the per-digit labels back into one string per captcha.
    """

    def __init__(self, x, y):
        # x: sequence of captcha images; y: matching label strings,
        # one character per digit in the corresponding image.
        self.engine = svm_engine()
        self.feature_extractor = compose_extractors(SVM_EXTRACTORS)
        self.fit(x, y)

    def fit(self, x, y):
        """Train the engine on captcha images *x* with label strings *y*."""
        digits = []
        labels = []
        for image, param_labels in zip(x, y):
            separator = DigitSeparator(image)
            digits.extend(map(self.feature_extractor, separator.get_digits()))
            # One character of the label string per separated digit.
            labels.extend(param_labels)
        self.vectorizer = DictVectorizer()
        train_array = self.vectorizer.fit_transform(digits).toarray()
        self.engine.fit(train_array, labels)

    def predict(self, x):
        """Return one decoded string per captcha image in *x*."""
        prediction = []
        for image in x:
            separator = DigitSeparator(image)
            features = map(self.feature_extractor, separator.get_digits())
            digits = self.vectorizer.transform(features).toarray()
            labels = self.engine.predict(digits)
            prediction.append(''.join('%d' % label for label in labels))
        return prediction

    def score(self, data, labels):
        """Fraction of captchas in *data* decoded exactly as *labels*."""
        pred_labels = self.predict(data)
        # Count whole-captcha matches (every digit right), not digit hits.
        matches = sum(1 for expected, got in zip(labels, pred_labels)
                      if expected == got)
        return float(matches)/len(labels)

    def decode_image(self, image):
        """Decode a single captcha image to its label string."""
        return self.predict([image])[0]

    def get_params(self, *args, **kwargs):
        """Expose the underlying engine's parameters (sklearn API).

        Bug fix: the engine instance is not callable, so the original
        ``self.engine(*args, **kwargs)`` raised TypeError; delegate to
        the estimator's own ``get_params`` instead.
        """
        return self.engine.get_params(*args, **kwargs)

profiler

+6-24
Original file line numberDiff line numberDiff line change
@@ -20,39 +20,21 @@
2020
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
2121
# OTHER DEALINGS IN THE SOFTWARE.
2222

23-
import os
2423
import sys
25-
import re
26-
import Image
2724
import time
28-
from sklearn.externals import joblib
29-
from image_processing import DigitSeparator
3025
import cProfile
3126

32-
model = joblib.load(sys.argv[1])
27+
from sklearn.externals import joblib
3328

34-
def make_test_dataset(files):
35-
dataset = []
36-
for file_path in files:
37-
file_name = os.path.basename(file_path)
38-
label = re.findall(r'^([0-9]+)-[0-9]+\..*$', file_name)[0]
39-
with open(file_path) as f:
40-
digits = DigitSeparator(Image.open(f).convert("L")).get_digits()
41-
dataset.append((label, digits))
42-
return dataset
29+
from dataset import load_captcha_dataset
30+
31+
model = joblib.load(sys.argv[1])
4332

4433
def prof():
45-
base_dir = sys.argv[2]
46-
files = map(lambda x: os.path.join(base_dir, x), os.listdir(base_dir))
47-
dataset = make_test_dataset(files)
34+
dataset = load_captcha_dataset(sys.argv[2])
4835
t0 = time.time()
49-
matches = 0
50-
for labels,digits in dataset:
51-
pred_labels = model.predict(digits)
52-
if labels == ''.join(map(lambda x: str(int(x)), pred_labels)):
53-
matches += 1
36+
print 'Matches:', model.score(dataset[0], dataset[1])
5437
spent_time = time.time() - t0
5538
print 'Spent time:', spent_time, 'avg per predict:', spent_time/len(dataset)
56-
print 'Matches:', float(matches)/len(dataset)
5739

5840
cProfile.run('prof()', filename='/tmp/profile')

train.py

+13-45
Original file line numberDiff line numberDiff line change
@@ -18,73 +18,41 @@
1818
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1919

2020
import sys
21-
import re
2221
import random
23-
import os
24-
import pickle
22+
2523
from models import *
2624
from sklearn.externals import joblib
2725
from image_processing import DigitSeparator
2826

29-
def make_train_dataset(files):
30-
dataset = []
31-
for file_path in files:
32-
file_name = os.path.basename(file_path)
33-
labels = re.findall(r'^([0-9]+)-[0-9]+\..*$', file_name)[0]
34-
with open(file_path) as f:
35-
digits = DigitSeparator(Image.open(f).convert("L")).get_digits()
36-
for i,digit in enumerate(digits):
37-
dataset.append((digit, labels[i]))
38-
return zip(*dataset) # unzip
39-
40-
def make_test_dataset(files):
41-
dataset = []
42-
for file_path in files:
43-
file_name = os.path.basename(file_path)
44-
label = re.findall(r'^([0-9]+)-[0-9]+\..*$', file_name)[0]
45-
with open(file_path) as f:
46-
digits = DigitSeparator(Image.open(f).convert("L")).get_digits()
47-
dataset.append((label, digits))
48-
return dataset
49-
50-
def get_files(base_dir):
51-
return map(lambda x: os.path.join(base_dir, x), os.listdir(base_dir))
27+
from dataset import load_captcha_dataset
5228

5329
def generate_datasets(base_dir):
54-
files = get_files(base_dir)
55-
random.shuffle(files)
56-
train_size = int(0.4*len(files))
57-
train = files[:train_size]
58-
test = files[train_size:]
59-
print "Number of trains:", len(train), "Number of tests:", len(test)
60-
train_dataset = make_train_dataset(train)
61-
test_dataset = make_test_dataset(test)
30+
dataset = load_captcha_dataset(base_dir)
31+
ziped_dataset = zip(*dataset)
32+
random.shuffle(ziped_dataset)
33+
dataset = zip(*ziped_dataset)
34+
train_size = int(0.4*len(ziped_dataset))
35+
train_dataset = (dataset[0][:train_size], dataset[1][:train_size])
36+
test_dataset = (dataset[0][train_size:], dataset[1][train_size:])
37+
print "Number of trains:", len(train_dataset), "Number of tests:", len(test_dataset)
6238
return train_dataset, test_dataset
6339

6440
def largest_label_size(dataset):
    """Length of the longest value sequence in the *dataset* mapping."""
    return max(len(labels) for labels in dataset.values())
6642

67-
def test(model, test_dataset):
68-
matches = 0
69-
for labels,digits in test_dataset:
70-
pred_labels = model.predict(digits)
71-
if labels == ''.join(map(lambda x: str(int(x)), pred_labels)):
72-
matches += 1
73-
print 'Matches:', float(matches)/len(test_dataset)
74-
7543
def main():
    """Train a captcha decoder, then either persist it or evaluate it.

    With two extra argv entries: train on every captcha in argv[1] and
    dump the fitted model to argv[2] via joblib.  With one: split
    argv[1] into train/test sets and print the match rate.
    """
    if len(sys.argv) > 2:
        # Dedicated training run: the whole directory is training data.
        train_dataset = load_captcha_dataset(sys.argv[1])
    else:
        train_dataset, test_dataset = generate_datasets(sys.argv[1])
    t0 = time.time()
    # train_dataset is the (images, labels) pair returned above.
    model = CaptchaDecoder(train_dataset[0], train_dataset[1])
    print 'Train time:', time.time() - t0
    if len(sys.argv) > 2:
        joblib.dump(model, sys.argv[2])
    else:
        t0 = time.time()
        print 'Matches:', model.score(test_dataset[0], test_dataset[1])
        print 'Test time:', time.time() - t0
8957

9058
if __name__ == '__main__':

0 commit comments

Comments
 (0)