Skip to content

Commit 091d7cd

Browse files
committed
wip
1 parent 11f0270 commit 091d7cd

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
The primary aim here is simplicity and minimal dependencies.
1818
"""
1919
import time
20-
from functools import partial
2120

2221
import datasets
2322
import jax
@@ -29,6 +28,8 @@
2928

3029
import jax_scaled_arithmetics as jsa
3130

31+
# from functools import partial
32+
3233

3334
def print_mean_std(name, v):
3435
data, scale = jsa.lax.get_data_scale(v)
@@ -60,7 +61,8 @@ def predict(params, inputs):
6061
# jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
6162
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
6263

63-
logits = jsa.ops.dynamic_rescale_l1_grad(logits)
64+
logits = jsa.ops.dynamic_rescale_l2_grad(logits)
65+
# logits = logits.astype(np.float32)
6466
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
6567

6668
return logits - logsumexp(logits, axis=1, keepdims=True)
@@ -110,7 +112,7 @@ def data_stream():
110112
@jsa.autoscale
111113
def update(params, batch):
112114
grads = grad(loss)(params, batch)
113-
# return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
115+
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
114116
return [
115117
(jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
116118
for (w, b), (dw, db) in zip(params, grads)

0 commit comments

Comments (0)