Skip to content

Commit c8067a4

Browse files
committed
making it serializable
1 parent d9f2183 commit c8067a4

File tree

4 files changed

+36
-27
lines changed

4 files changed

+36
-27
lines changed

break_captcha

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def main():
4141
else:
4242
with open(img_path) as f_image:
4343
image = Image.open(f_image).convert('L')
44-
print model.decode_image(image)
44+
print model.predict(image)
4545
except EOFError:
4646
pass
4747

features.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,29 @@ def __call__(self, arg):
3838
extractor(image, image_features)
3939
return image_features
4040

41-
def _border_detection(digit):
42-
digit.image = digit.image.filter(ImageFilter.FIND_EDGES)
43-
digit.pix = digit.image.load()
44-
return digit
41+
class border(object):
42+
def __init__(self, callback):
43+
self.callback = callback
4544

46-
def border(callback):
47-
return lambda digit,features: callback(_border_detection(digit), features, prefix='border-')
45+
def __border_detection(self, digit):
46+
digit.image = digit.image.filter(ImageFilter.FIND_EDGES)
47+
digit.pix = digit.image.load()
48+
return digit
4849

49-
def _scale_down(digit):
50-
digit.image = digit.image.resize((16,16), Image.BICUBIC)
51-
digit.pix = digit.image.load()
52-
return digit
50+
def __call__(self, digit, features):
51+
return callback(self.__border_detection(digit), features, prefix='border-')
5352

54-
def scale_image_down(callback):
55-
return lambda digit,features: callback(_scale_down(digit), features, prefix='scaled-')
53+
class scale_image_down(object):
54+
def __init__(self, callback):
55+
self.callback = callback
56+
57+
def __scale_down(self, digit):
58+
digit.image = digit.image.resize((16,16), Image.BICUBIC)
59+
digit.pix = digit.image.load()
60+
return digit
61+
62+
def __call__(self, digit, features):
63+
return self.callback(self.__scale_down(digit), features, prefix='scaled-')
5664

5765
def is_white(color):
5866
return color > 230

models.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from sklearn import ensemble
2828
from sklearn.feature_extraction import DictVectorizer
2929

30-
class ModelUnavailable(Exception):
31-
pass
3230

3331
ALL_EXTRACTORS = [
3432
x_histogram,
@@ -69,23 +67,26 @@ def fit(self, x, y):
6967
train_array = self.vectorizer.fit_transform(digits).toarray()
7068
self.engine.fit(train_array, labels)
7169

70+
def __make_prediction(self, image):
71+
separator = DigitSeparator(image)
72+
features = map(self.feature_extractor, separator.get_digits())
73+
digits = self.vectorizer.transform(features).toarray()
74+
labels = self.engine.predict(digits)
75+
return ''.join(map(lambda x: '%d'%x, labels))
76+
7277
def predict(self, x):
73-
prediction = []
74-
for image in x:
75-
separator = DigitSeparator(image)
76-
features = map(self.feature_extractor, separator.get_digits())
77-
digits = self.vectorizer.transform(features).toarray()
78-
labels = self.engine.predict(digits)
79-
prediction.append(''.join(map(lambda x: '%d'%x, labels)))
80-
return prediction
78+
if not hasattr(x, '__iter__'):
79+
return self.__make_prediction(x)
80+
else:
81+
prediction = []
82+
for image in x:
83+
prediction.append(self.__make_prediction(image))
84+
return prediction
8185

8286
def score(self, data, labels):
8387
pred_labels = self.predict(data)
8488
matches = sum(map(lambda (x,y): x==y, zip(labels, pred_labels)))
8589
return float(matches)/len(labels)
8690

87-
def decode_image(self, image):
88-
return self.predict([image])[0]
89-
9091
def get_params(self, *args, **kwargs):
9192
return self.engine.get_params(*args, **kwargs)

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def main():
3939
joblib.dump(model, sys.argv[2])
4040
else:
4141
t0 = time.time()
42-
scores = cross_validation.cross_val_score(model, numpy.array(dataset[0], dtype=object), dataset[1], cv=50)
42+
scores = cross_validation.cross_val_score(model, numpy.array(dataset[0], dtype=object), dataset[1], cv=5)
4343
print 'Accuracy: %0.2f (+/- %0.2f)' % (scores.mean(), scores.std()/2)
4444
print 'Validation time:', time.time() - t0
4545

0 commit comments

Comments
 (0)