Skip to content

Commit dbb1c1d

Browse files
committed
wip
1 parent 091d7cd commit dbb1c1d

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
The primary aim here is simplicity and minimal dependencies.
1818
"""
1919
import time
20+
from functools import partial
2021

2122
import datasets
2223
import jax
@@ -28,8 +29,6 @@
2829

2930
import jax_scaled_arithmetics as jsa
3031

31-
# from functools import partial
32-
3332

3433
def print_mean_std(name, v):
3534
data, scale = jsa.lax.get_data_scale(v)
@@ -58,19 +57,29 @@ def predict(params, inputs):
5857
final_w, final_b = params[-1]
5958
logits = jnp.dot(activations, final_w) + final_b
6059

61-
# jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
60+
jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
6261
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
6362

6463
logits = jsa.ops.dynamic_rescale_l2_grad(logits)
6564
# logits = logits.astype(np.float32)
6665
# (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
6766

68-
return logits - logsumexp(logits, axis=1, keepdims=True)
67+
logits = logits - logsumexp(logits, axis=1, keepdims=True)
68+
jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits)
69+
(logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
70+
return logits
6971

7072

7173
def loss(params, batch):
7274
inputs, targets = batch
7375
preds = predict(params, inputs)
76+
jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds)
77+
loss = jnp.sum(preds * targets, axis=1)
78+
# loss = jsa.ops.dynamic_rescale_l2(loss)
79+
jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss)
80+
loss = -jnp.mean(loss)
81+
jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss)
82+
return loss
7483
return -jnp.mean(jnp.sum(preds * targets, axis=1))
7584

7685

jax_scaled_arithmetics/ops/rescaling.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@
22
from functools import partial
33

44
import jax
5-
import jax.numpy as jnp
5+
6+
# import jax.numpy as jnp
67
import numpy as np
78

89
from jax_scaled_arithmetics.core import ScaledArray, pow2_round

0 commit comments

Comments
 (0)