Skip to content

Commit 73bd2ed

Browse files
Feat: Implement Naive Bayes Classifier for MNIST digits.
On branch master Changes to be committed: new file: naive_bayes.py
0 parents  commit 73bd2ed

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

naive_bayes.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from mxnet import nd, gluon
2+
import matplotlib.pyplot as plt
3+
import random
4+
5+
def show_images(imgs, rows, cols):
6+
_, axes = plt.subplots(rows, cols)
7+
axes = axes.flatten()
8+
for i, (ax, img) in enumerate(zip(axes, imgs)):
9+
ax.imshow(img)
10+
ax.get_xaxis().set_visible(False)
11+
ax.get_yaxis().set_visible(False)
12+
plt.show()
13+
return axes
14+
15+
def transform(data, label):
16+
return (data/128).astype('float32').squeeze(axis=-1), label
17+
18+
def train(train, n_classes):
19+
X, Y = train[:]
20+
n_y = nd.zeros(n_classes)
21+
for y in range(n_classes):
22+
n_y[y] = (Y==y).sum()
23+
P_y = n_y / n_y.sum()
24+
25+
n_x = nd.zeros((n_classes, 28, 28))
26+
for y in range(n_classes):
27+
n_x[y] = nd.array(X.asnumpy()[Y==y].sum(axis=0))
28+
P_xy = (n_x+1) / (n_y+1).reshape((10, 1, 1))
29+
show_images(P_xy.asnumpy(), 2, 5)
30+
31+
return P_xy, P_y
32+
33+
def predict(img, P_xy, P_y):
34+
img = img.expand_dims(axis=0)
35+
log_P_xy = nd.log(P_xy)
36+
neg_log_P_xy = nd.log(1-P_xy)
37+
pxy = log_P_xy * img + neg_log_P_xy * (1-img)
38+
pxy = pxy.reshape((10, -1)).sum(axis=1)
39+
40+
probs = pxy+nd.log(P_y)
41+
42+
return probs.argmax(axis=0).asscalar()
43+
44+
def test(test, n_classes, P_xy, P_y):
45+
X, Y = test[:]
46+
correct = 0
47+
for i, img in enumerate(X):
48+
result = predict(img, P_xy, P_y)
49+
if result == Y[i]:
50+
correct += 1
51+
acc = (correct/X.shape[0]) * 100
52+
print("Accuracy {}%".format(acc))
53+
return acc
54+
55+
56+
def main():
57+
n_classes = 10
58+
59+
mnist_train = gluon.data.vision.datasets.MNIST(train=True, transform=transform)
60+
mnist_test = gluon.data.vision.datasets.MNIST(train=False, transform=transform)
61+
62+
P_xy, P_y = train(mnist_train, n_classes)
63+
64+
acc = test(mnist_test, n_classes, P_xy, P_y)
65+
66+
test_X, test_Y = mnist_test[:]
67+
index = random.randint(0, test_X.shape[0])
68+
result = predict(test_X[index], P_xy, P_y)
69+
70+
print("Predicted value {}".format(result))
71+
plt.imshow(test_X[index].asnumpy())
72+
plt.show()
73+
74+
if __name__ == "__main__":
75+
main()

0 commit comments

Comments
 (0)