@@ -17,7 +17,6 @@
 The primary aim here is simplicity and minimal dependencies.
 """
 import time
-from functools import partial
 
 import datasets
 import jax
@@ -29,6 +28,8 @@
 
 import jax_scaled_arithmetics as jsa
 
+# from functools import partial
+
 
 def print_mean_std(name, v):
     data, scale = jsa.lax.get_data_scale(v)
@@ -60,7 +61,8 @@ def predict(params, inputs):
     # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
     # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
-    logits = jsa.ops.dynamic_rescale_l1_grad(logits)
+    logits = jsa.ops.dynamic_rescale_l2_grad(logits)
+    # logits = logits.astype(np.float32)
     # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
     return logits - logsumexp(logits, axis=1, keepdims=True)
@@ -110,7 +112,7 @@ def data_stream():
 @jsa.autoscale
 def update(params, batch):
     grads = grad(loss)(params, batch)
-    # return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
     return [
         (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
        for (w, b), (dw, db) in zip(params, grads)
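For context, the ops toggled above operate on the scaled-array representation used throughout this example: a tensor is carried as a (data, scale) pair (see jsa.lax.get_data_scale), and "dynamic rescaling" recomputes the scale from the current statistics of data so that data stays near unit magnitude. The sketch below is a minimal, hypothetical illustration of that idea in plain NumPy; it is not the jax_scaled_arithmetics implementation and the helper name is made up. The _l1/_l2 suffixes refer to the statistic used (mean absolute value vs. root mean square), and the _grad variant applied to the logits presumably rescales the backward gradient rather than the forward value.

import numpy as np

def dynamic_rescale(data, scale, norm="l2"):
    # Hypothetical helper (not the library's API): fold the current magnitude of
    # `data` into `scale`, keeping the represented value data * scale unchanged.
    if norm == "l1":
        stat = np.mean(np.abs(data))            # mean absolute value, as in dynamic_rescale_l1
    else:
        stat = np.sqrt(np.mean(data ** 2))      # root mean square, as in dynamic_rescale_l2
    stat = max(stat, np.finfo(data.dtype).tiny)  # guard against an all-zero tensor
    return data / stat, scale * stat

data = np.array([0.02, -0.04, 0.08], dtype=np.float32)
scale = np.float32(16.0)
new_data, new_scale = dynamic_rescale(data, scale, norm="l2")
# new_data * new_scale ~= data * scale, but new_data is now O(1), which is what keeps
# low-precision formats such as float16 inside their representable range.

Note that in update the newly uncommented plain SGD return takes effect first, so the dynamically rescaled update below it is effectively disabled by this commit.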