|
17 | 17 | The primary aim here is simplicity and minimal dependencies.
|
18 | 18 | """
|
19 | 19 | import time
|
| 20 | +from functools import partial |
20 | 21 |
|
21 | 22 | import datasets
|
22 | 23 | import jax
|
|
28 | 29 |
|
29 | 30 | import jax_scaled_arithmetics as jsa
|
30 | 31 |
|
31 |
| -# from functools import partial |
32 |
| - |
33 | 32 |
|
34 | 33 | def print_mean_std(name, v):
|
35 | 34 | data, scale = jsa.lax.get_data_scale(v)
|
@@ -58,19 +57,29 @@ def predict(params, inputs):
|
58 | 57 | final_w, final_b = params[-1]
|
59 | 58 | logits = jnp.dot(activations, final_w) + final_b
|
60 | 59 |
|
61 |
| - # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits) |
| 60 | + jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits) |
62 | 61 | # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
|
63 | 62 |
|
64 | 63 | logits = jsa.ops.dynamic_rescale_l2_grad(logits)
|
65 | 64 | # logits = logits.astype(np.float32)
|
66 | 65 | # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
|
67 | 66 |
|
68 |
| - return logits - logsumexp(logits, axis=1, keepdims=True) |
| 67 | + logits = logits - logsumexp(logits, axis=1, keepdims=True) |
| 68 | + jsa.ops.debug_callback(partial(print_mean_std, "Logits2"), logits) |
| 69 | + (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) |
| 70 | + return logits |
69 | 71 |
|
70 | 72 |
|
71 | 73 | def loss(params, batch):
|
72 | 74 | inputs, targets = batch
|
73 | 75 | preds = predict(params, inputs)
|
| 76 | + jsa.ops.debug_callback(partial(print_mean_std, "Preds"), preds) |
| 77 | + loss = jnp.sum(preds * targets, axis=1) |
| 78 | + # loss = jsa.ops.dynamic_rescale_l2(loss) |
| 79 | + jsa.ops.debug_callback(partial(print_mean_std, "LOSS1"), loss) |
| 80 | + loss = -jnp.mean(loss) |
| 81 | + jsa.ops.debug_callback(partial(print_mean_std, "LOSS2"), loss) |
| 82 | + return loss |
74 | 83 | return -jnp.mean(jnp.sum(preds * targets, axis=1))
|
75 | 84 |
|
76 | 85 |
|
|
0 commit comments