|
| 1 | +from mxnet import nd, gluon |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import random |
| 4 | + |
| 5 | +def show_images(imgs, rows, cols): |
| 6 | + _, axes = plt.subplots(rows, cols) |
| 7 | + axes = axes.flatten() |
| 8 | + for i, (ax, img) in enumerate(zip(axes, imgs)): |
| 9 | + ax.imshow(img) |
| 10 | + ax.get_xaxis().set_visible(False) |
| 11 | + ax.get_yaxis().set_visible(False) |
| 12 | + plt.show() |
| 13 | + return axes |
| 14 | + |
| 15 | +def transform(data, label): |
| 16 | + return (data/128).astype('float32').squeeze(axis=-1), label |
| 17 | + |
| 18 | +def train(train, n_classes): |
| 19 | + X, Y = train[:] |
| 20 | + n_y = nd.zeros(n_classes) |
| 21 | + for y in range(n_classes): |
| 22 | + n_y[y] = (Y==y).sum() |
| 23 | + P_y = n_y / n_y.sum() |
| 24 | + |
| 25 | + n_x = nd.zeros((n_classes, 28, 28)) |
| 26 | + for y in range(n_classes): |
| 27 | + n_x[y] = nd.array(X.asnumpy()[Y==y].sum(axis=0)) |
| 28 | + P_xy = (n_x+1) / (n_y+1).reshape((10, 1, 1)) |
| 29 | + show_images(P_xy.asnumpy(), 2, 5) |
| 30 | + |
| 31 | + return P_xy, P_y |
| 32 | + |
| 33 | +def predict(img, P_xy, P_y): |
| 34 | + img = img.expand_dims(axis=0) |
| 35 | + log_P_xy = nd.log(P_xy) |
| 36 | + neg_log_P_xy = nd.log(1-P_xy) |
| 37 | + pxy = log_P_xy * img + neg_log_P_xy * (1-img) |
| 38 | + pxy = pxy.reshape((10, -1)).sum(axis=1) |
| 39 | + |
| 40 | + probs = pxy+nd.log(P_y) |
| 41 | + |
| 42 | + return probs.argmax(axis=0).asscalar() |
| 43 | + |
| 44 | +def test(test, n_classes, P_xy, P_y): |
| 45 | + X, Y = test[:] |
| 46 | + correct = 0 |
| 47 | + for i, img in enumerate(X): |
| 48 | + result = predict(img, P_xy, P_y) |
| 49 | + if result == Y[i]: |
| 50 | + correct += 1 |
| 51 | + acc = (correct/X.shape[0]) * 100 |
| 52 | + print("Accuracy {}%".format(acc)) |
| 53 | + return acc |
| 54 | + |
| 55 | + |
| 56 | +def main(): |
| 57 | + n_classes = 10 |
| 58 | + |
| 59 | + mnist_train = gluon.data.vision.datasets.MNIST(train=True, transform=transform) |
| 60 | + mnist_test = gluon.data.vision.datasets.MNIST(train=False, transform=transform) |
| 61 | + |
| 62 | + P_xy, P_y = train(mnist_train, n_classes) |
| 63 | + |
| 64 | + acc = test(mnist_test, n_classes, P_xy, P_y) |
| 65 | + |
| 66 | + test_X, test_Y = mnist_test[:] |
| 67 | + index = random.randint(0, test_X.shape[0]) |
| 68 | + result = predict(test_X[index], P_xy, P_y) |
| 69 | + |
| 70 | + print("Predicted value {}".format(result)) |
| 71 | + plt.imshow(test_X[index].asnumpy()) |
| 72 | + plt.show() |
| 73 | + |
| 74 | +if __name__ == "__main__": |
| 75 | + main() |
0 commit comments