Skip to content

Commit b95eb82

Browse files
Fix: Cross Entropy loss function in logistict regression from Scratch
Changes to be committed: modified: LogisticRegressionGluon.ipynb modified: LogisticRegressionScratch.ipynb
1 parent ea033f2 commit b95eb82

File tree

2 files changed

+37
-67
lines changed

2 files changed

+37
-67
lines changed

LogisticRegressionGluon.ipynb

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
},
4646
{
4747
"cell_type": "code",
48-
"execution_count": 13,
48+
"execution_count": 19,
4949
"metadata": {},
5050
"outputs": [],
5151
"source": [
@@ -56,13 +56,13 @@
5656
" y_hat = net(X)\n",
5757
" y_hat = softmax(y_hat)\n",
5858
" accumulator += (y_hat.argmax(axis=1)==y.astype('float32')).sum()\n",
59-
" size = len(y)\n",
59+
" size += len(y)\n",
6060
" return accumulator / size"
6161
]
6262
},
6363
{
6464
"cell_type": "code",
65-
"execution_count": 15,
65+
"execution_count": 20,
6666
"metadata": {},
6767
"outputs": [],
6868
"source": [
@@ -73,7 +73,7 @@
7373
},
7474
{
7575
"cell_type": "code",
76-
"execution_count": 16,
76+
"execution_count": 21,
7777
"metadata": {},
7878
"outputs": [],
7979
"source": [
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 17,
87+
"execution_count": 22,
8888
"metadata": {},
8989
"outputs": [],
9090
"source": [
@@ -94,22 +94,23 @@
9494
},
9595
{
9696
"cell_type": "code",
97-
"execution_count": null,
97+
"execution_count": 23,
9898
"metadata": {},
9999
"outputs": [
100100
{
101101
"name": "stdout",
102102
"output_type": "stream",
103103
"text": [
104-
"Epoch 0, acc: 501.333344\n",
105-
"Epoch 1, acc: 514.375000\n",
106-
"Epoch 2, acc: 518.010437\n",
107-
"Epoch 3, acc: 521.375000\n"
104+
"Epoch 0, acc: 0.805600\n",
105+
"Epoch 1, acc: 0.822417\n",
106+
"Epoch 2, acc: 0.826817\n",
107+
"Epoch 3, acc: 0.831850\n",
108+
"Epoch 4, acc: 0.838033\n"
108109
]
109110
}
110111
],
111112
"source": [
112-
"epochs = 10\n",
113+
"epochs = 5\n",
113114
"for epoch in range(epochs):\n",
114115
" for X, y in train_iter:\n",
115116
" with autograd.record():\n",

LogisticRegressionScratch.ipynb

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 95,
5+
"execution_count": 126,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -14,7 +14,7 @@
1414
},
1515
{
1616
"cell_type": "code",
17-
"execution_count": 96,
17+
"execution_count": 127,
1818
"metadata": {},
1919
"outputs": [],
2020
"source": [
@@ -33,7 +33,7 @@
3333
},
3434
{
3535
"cell_type": "code",
36-
"execution_count": 97,
36+
"execution_count": 128,
3737
"metadata": {},
3838
"outputs": [],
3939
"source": [
@@ -45,17 +45,17 @@
4545
},
4646
{
4747
"cell_type": "code",
48-
"execution_count": 98,
48+
"execution_count": 168,
4949
"metadata": {},
5050
"outputs": [],
5151
"source": [
5252
"def cross_entropy(y_hat, y):\n",
53-
" return -nd.pick(y_hat, y, axis=1)"
53+
" return -nd.pick(y_hat, y).log()"
5454
]
5555
},
5656
{
5757
"cell_type": "code",
58-
"execution_count": 99,
58+
"execution_count": 169,
5959
"metadata": {},
6060
"outputs": [],
6161
"source": [
@@ -66,7 +66,7 @@
6666
},
6767
{
6868
"cell_type": "code",
69-
"execution_count": 100,
69+
"execution_count": 170,
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
@@ -76,7 +76,7 @@
7676
},
7777
{
7878
"cell_type": "code",
79-
"execution_count": 101,
79+
"execution_count": 171,
8080
"metadata": {},
8181
"outputs": [],
8282
"source": [
@@ -87,13 +87,12 @@
8787
" y_hat = net(X, W, b, num_features)\n",
8888
" accumulator += accuracy(y_hat, y)\n",
8989
" size += len(y)\n",
90-
" print(accumulator)\n",
9190
" return accumulator / size"
9291
]
9392
},
9493
{
9594
"cell_type": "code",
96-
"execution_count": 102,
95+
"execution_count": 172,
9796
"metadata": {},
9897
"outputs": [],
9998
"source": [
@@ -103,7 +102,7 @@
103102
},
104103
{
105104
"cell_type": "code",
106-
"execution_count": 103,
105+
"execution_count": 173,
107106
"metadata": {},
108107
"outputs": [],
109108
"source": [
@@ -114,11 +113,11 @@
114113
},
115114
{
116115
"cell_type": "code",
117-
"execution_count": 104,
116+
"execution_count": 174,
118117
"metadata": {},
119118
"outputs": [],
120119
"source": [
121-
"num_inputs = 784\n",
120+
"num_inputs = 28 * 28\n",
122121
"num_outputs = 10\n",
123122
"W = nd.random.normal(scale=0.01, shape=(num_inputs, num_outputs))\n",
124123
"b = nd.zeros(num_outputs)\n",
@@ -128,53 +127,23 @@
128127
},
129128
{
130129
"cell_type": "code",
131-
"execution_count": 124,
130+
"execution_count": 175,
132131
"metadata": {},
133132
"outputs": [
134133
{
135134
"name": "stdout",
136135
"output_type": "stream",
137136
"text": [
138-
"\n",
139-
"[40504.]\n",
140-
"<NDArray 1 @cpu(0)>\n",
141-
"Epoch 0, acc: 0.675067\n",
142-
"\n",
143-
"[41122.]\n",
144-
"<NDArray 1 @cpu(0)>\n",
145-
"Epoch 1, acc: 0.685367\n",
146-
"\n",
147-
"[42817.]\n",
148-
"<NDArray 1 @cpu(0)>\n",
149-
"Epoch 2, acc: 0.713617\n",
150-
"\n",
151-
"[44973.]\n",
152-
"<NDArray 1 @cpu(0)>\n",
153-
"Epoch 3, acc: 0.749550\n",
154-
"\n",
155-
"[45543.]\n",
156-
"<NDArray 1 @cpu(0)>\n",
157-
"Epoch 4, acc: 0.759050\n",
158-
"\n",
159-
"[45997.]\n",
160-
"<NDArray 1 @cpu(0)>\n",
161-
"Epoch 5, acc: 0.766617\n",
162-
"\n",
163-
"[46272.]\n",
164-
"<NDArray 1 @cpu(0)>\n",
165-
"Epoch 6, acc: 0.771200\n",
166-
"\n",
167-
"[46486.]\n",
168-
"<NDArray 1 @cpu(0)>\n",
169-
"Epoch 7, acc: 0.774767\n",
170-
"\n",
171-
"[46629.]\n",
172-
"<NDArray 1 @cpu(0)>\n",
173-
"Epoch 8, acc: 0.777150\n",
174-
"\n",
175-
"[46824.]\n",
176-
"<NDArray 1 @cpu(0)>\n",
177-
"Epoch 9, acc: 0.780400\n"
137+
"Epoch 0, acc: 0.805417\n",
138+
"Epoch 1, acc: 0.820217\n",
139+
"Epoch 2, acc: 0.829133\n",
140+
"Epoch 3, acc: 0.834200\n",
141+
"Epoch 4, acc: 0.839000\n",
142+
"Epoch 5, acc: 0.841550\n",
143+
"Epoch 6, acc: 0.844400\n",
144+
"Epoch 7, acc: 0.845233\n",
145+
"Epoch 8, acc: 0.846100\n",
146+
"Epoch 9, acc: 0.846767\n"
178147
]
179148
}
180149
],
@@ -197,7 +166,7 @@
197166
},
198167
{
199168
"cell_type": "code",
200-
"execution_count": 123,
169+
"execution_count": 154,
201170
"metadata": {},
202171
"outputs": [
203172
{
@@ -208,7 +177,7 @@
208177
"(256,)\n",
209178
"(256,)\n",
210179
"\n",
211-
"[0.08203125]\n",
180+
"[0.11328125]\n",
212181
"<NDArray 1 @cpu(0)>\n"
213182
]
214183
}

0 commit comments

Comments
 (0)