-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdigit_recognizer.py
78 lines (76 loc) · 2 KB
/
digit_recognizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os.path
import pickle

import matplotlib.pyplot as plot
import numpy
import pandas
import sklearn.externals as external
import sklearn.svm as svm

try:
    # scikit-learn >= 0.21 depends on the standalone joblib package, and the
    # vendored copy under sklearn.externals was removed in 0.23.
    import joblib
except ImportError:
    # Older scikit-learn releases only ship the vendored copy.
    from sklearn.externals import joblib
# Path of the file the trained SVM classifier is dumped to and loaded from.
PICKLE_SVM = 'digit_recognization_svm.pickle'
# Module-level slot for the SVM classifier; starts out as None and is checked
# by recognize_svm() before predicting.
CLFSVM = None
# Training set, read at import time. train_svm() uses column 0 as the label
# and every remaining column as a feature.
data = pandas.read_csv('./data/train.csv')
# Test set, read at import time; not referenced by the functions below.
test = pandas.read_csv('./data/test.csv')
# Debug output, disabled: shape and first rows of the training set.
#print('%s row, %s column\n' % data.shape)
#print(data.head())
def show(i, data):
    """
    Show the image represented by the ith row.

    param:
        i: integer position of the row to display.
        data: data from train.csv; column 0 is skipped and the remaining
              columns are treated as the pixels of one square image.
    return:
        none.
    raise:
        ValueError: if the number of pixel columns is not a positive
                    perfect square.
    """
    pixel_count = data.shape[1] - 1
    # pow(n, 0.5) yields a float, which reshape() rejects as a dimension;
    # round it to an int and verify that it really is the square root.
    side = int(round(pow(max(pixel_count, 0), 0.5)))
    if side == 0 or side * side != pixel_count:
        raise ValueError(
            '%d pixel columns cannot form a square image' % pixel_count)
    # Reshape the underlying ndarray: a pandas Series cannot be reshaped
    # into a 2-D image.
    image = data.iloc[i, 1:].values.reshape(side, side)
    plot.imshow(image)
    plot.show()
# show the first 8 pics
##for i in range(8): show(i, data)
def scale(data):
    """
    Rescale pixel values for the SVM, from [0, 255] to [0, 1].

    param:
        data: the train matrix; the first column is carried over untouched
              and every other column is divided by 255.
    return:
        scaled train data as a single array: the first column followed by
        the rescaled columns.
    """
    first_column = data.iloc[:, :1]
    scaled_pixels = numpy.divide(data.iloc[:, 1:], 255.0)
    return numpy.concatenate((first_column, scaled_pixels), axis=1)
def train_svm():
    """
    Train an SVM model with the default penalty and kernel.

    The model is fitted on the module-level ``data`` frame (column 0 is the
    label, the remaining columns are the features), stored in the
    module-level ``CLFSVM`` and dumped to ``PICKLE_SVM`` with joblib.

    TODO use CV to select the parameters.
    NOTE(review): the model is fitted on the unscaled columns; scale() is
    not applied here -- confirm that this is intended.

    param:
        None
    return:
        a SVM classifier.
    """
    # Without this declaration the assignment below only creates a local
    # name and the module-level CLFSVM stays None.
    global CLFSVM
    classifier = svm.SVC()
    values = data.values
    classifier.fit(values[:, 1:], values[:, 0])
    CLFSVM = classifier
    joblib.dump(classifier, PICKLE_SVM)
    return classifier
def create_svm_classifier():
    """
    Create a SVM classifier.

    Loads the model stored at ``PICKLE_SVM`` when that file exists,
    otherwise trains a new one with train_svm(). The result is also stored
    in the module-level ``CLFSVM``.

    param:
        None
    return:
        a SVM classifier.
    """
    # Without this declaration the assignment below only creates a local
    # name and the module-level CLFSVM stays None.
    global CLFSVM
    if os.path.isfile(PICKLE_SVM):
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code; only load model files from a trusted location.
        classifier = joblib.load(PICKLE_SVM)
    else:
        classifier = train_svm()
    CLFSVM = classifier
    return classifier
def recognize_svm(image):
    """
    Recognize a digit image by using the trained SVM model.

    The classifier is taken from the module-level ``CLFSVM``; when that is
    still None it is obtained from create_svm_classifier() and kept in
    ``CLFSVM`` for later calls.

    param:
        image: a 28 * 28 matrix representing a pixels pic, each element of
               the matrix is an integer in [0, 255]
    return:
        the digit of the image, as an int.
    """
    global CLFSVM
    # Compare against None explicitly instead of relying on the truthiness
    # of an estimator object, and always go on to predict: an empty cache
    # must not make this function return the classifier itself.
    if CLFSVM is None:
        CLFSVM = create_svm_classifier()
    # predict() expects one row per sample, so flatten the matrix into a
    # single row of features.
    features = numpy.asarray(image).reshape(1, -1)
    # predict() returns a 1-D array with one label per sample.
    return int(CLFSVM.predict(features)[0])