Commit 11f0270

wip
1 parent fa546aa commit 11f0270

2 files changed: +32 −22 lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 27 additions & 19 deletions
@@ -31,9 +31,9 @@
 
 
 def print_mean_std(name, v):
-    _, scale = jsa.lax.get_data_scale(v)
+    data, scale = jsa.lax.get_data_scale(v)
     # Always use np.float32, to avoid floating errors in descaling + stats.
-    v = jsa.asarray(v, dtype=np.float32)
+    v = jsa.asarray(data, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
     print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
 
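This hunk changes what the debug callback measures: previously `data` was discarded and `jsa.asarray(v)` appears to have converted the whole ScaledArray (i.e. the descaled value), whereas now the stats are computed on the raw data component returned by `get_data_scale`, so MEAN/STD report how close the data tensor sits to unit range. A minimal sketch of the fixed behaviour, assuming a ScaledArray is just a (data, scale) pair with value == data * scale (the `_sketch` names are illustrative, not the library API):

import numpy as np

def print_mean_std_sketch(name, data, scale):
    # Upcast to float32 so mean/std are not themselves distorted by fp16 error.
    v = np.asarray(data, dtype=np.float32)
    print(f"{name}: MEAN({np.mean(v):.4f}) / STD({np.std(v):.4f}) / SCALE({scale:.4f})")

print_mean_std_sketch("W", np.random.randn(8, 8).astype(np.float16), np.float32(0.25))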

@@ -45,19 +45,23 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
-        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
-        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
 
         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)
 
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
 
-    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    logits = jsa.ops.dynamic_rescale_l1_grad(logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
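Judging by how it is dropped inline into the forward pass here, dynamic_rescale_l1_grad touches only the backward pass: the forward value flows through unchanged, and the cotangent is rebalanced on the way back. A rough sketch of that forward-identity pattern with jax.custom_vjp on plain arrays (with a plain array there is no scale to repartition, so the sketch only computes and logs the norm the real op would fold into the ScaledArray scale):

import jax
import jax.numpy as jnp

@jax.custom_vjp
def rescale_l1_grad_sketch(x):
    return x  # forward pass is the identity

def _fwd(x):
    return x, None

def _bwd(_, g):
    # Epsilon-guarded mean absolute value of the incoming gradient, as in
    # dynamic_rescale_l1_base; the real op moves this factor into the scale.
    norm = jnp.mean(jnp.abs(g)) + jnp.float32(1e-3)
    jax.debug.print("grad L1 norm: {}", norm)
    return (g,)

rescale_l1_grad_sketch.defvjp(_fwd, _bwd)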

@@ -81,7 +85,7 @@ def accuracy(params, batch):
 step_size = 0.001
 num_epochs = 10
 batch_size = 128
-training_dtype = np.float32
+training_dtype = np.float16
 
 train_images, train_labels, test_images, test_labels = datasets.mnist()
 num_train = train_images.shape[0]
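Switching training_dtype to np.float16 is presumably what motivates the rescaling ops: float16 tops out at 65504 with roughly three decimal digits of precision, so unscaled values can overflow upward and flush to zero downward. A quick numpy illustration:

import numpy as np

print(np.finfo(np.float16).max)  # 65504.0
print(np.float16(70000.0))       # inf: overflows float16
print(np.float16(1e-8))          # 0.0: below the smallest subnormal (~6e-8)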
@@ -106,15 +110,19 @@ def data_stream():
 @jsa.autoscale
 def update(params, batch):
     grads = grad(loss)(params, batch)
-    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
-
-num_batches = 4
-num_epochs = 2
+    # return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+    return [
+        (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
+        for (w, b), (dw, db) in zip(params, grads)
+    ]
+
+# num_batches = 4
+# num_epochs = 2
 for epoch in range(num_epochs):
-    print("EPOCH:", epoch)
+    # print("EPOCH:", epoch)
     start_time = time.time()
     for _ in range(num_batches):
-        print("BATCH...")
+        # print("BATCH...")
         batch = next(batches)
         # Scaled micro-batch + training dtype cast.
         batch = jsa.as_scaled_array(batch)
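Rescaling the freshly updated parameters keeps each ScaledArray's data component near unit range before the next step. A minimal numpy sketch of that idea on a single (data, scale) pair, with an illustrative stand-in for jsa.ops.dynamic_rescale_l1 (all `_sketch` names are hypothetical):

import numpy as np

def dynamic_rescale_l1_sketch(data, scale):
    # Pow2-rounded, epsilon-guarded mean |data|, as in the library's L1 variant.
    norm = np.float32(2.0) ** np.round(np.log2(np.mean(np.abs(data)) + np.float32(1e-3)))
    return (data / norm).astype(data.dtype), scale * norm  # value is preserved

w, grad_w, lr = np.float32([2.0, -4.0]), np.float32([0.5, -0.5]), np.float32(0.001)
new_w = w - lr * grad_w
data, scale = dynamic_rescale_l1_sketch(new_w, np.float32(1.0))
assert np.allclose(data * scale, new_w)  # same value, rebalanced representation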
@@ -127,8 +135,8 @@ def update(params, batch):
 
     # Evaluation in float32, for consistency.
     raw_params = jsa.asarray(params, dtype=np.float32)
-    # train_acc = accuracy(raw_params, (train_images, train_labels))
-    # test_acc = accuracy(raw_params, (test_images, test_labels))
-    # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-    # print(f"Training set accuracy {train_acc:0.5f}")
-    # print(f"Test set accuracy {test_acc:0.5f}")
+    train_acc = accuracy(raw_params, (train_images, train_labels))
+    test_acc = accuracy(raw_params, (test_images, test_labels))
+    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
+    print(f"Training set accuracy {train_acc:0.5f}")
+    print(f"Test set accuracy {test_acc:0.5f}")

jax_scaled_arithmetics/ops/rescaling.py

Lines changed: 5 additions & 3 deletions
@@ -2,6 +2,7 @@
 from functools import partial
 
 import jax
+import jax.numpy as jnp
 import numpy as np
 
 from jax_scaled_arithmetics.core import ScaledArray, pow2_round
@@ -48,7 +49,7 @@ def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data)
     axes = tuple(range(data.ndim))
     # Get MAX norm + pow2 rounding.
-    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes)
+    norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
@@ -63,7 +64,7 @@ def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.abs(data.astype(np.float32))
     axes = tuple(range(data.ndim))
     # Get L1 norm + pow2 rounding.
-    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size
+    norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size + np.float32(1e-3)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
@@ -78,7 +79,8 @@ def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray:
     data_sq = jax.lax.integer_pow(data.astype(np.float32), 2)
     axes = tuple(range(data.ndim))
     # Get L2 norm + pow2 rounding.
-    norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes)) / data.size
+    norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / np.float32(data.size)) + np.float32(1e-3)
+    # jax.debug.print("{} // {} // {}", jnp.mean(data.astype(np.float32)), jnp.std(data.astype(np.float32)), norm)
     norm = pow2_round(norm.astype(scale.dtype))
     # Rebalancing based on norm.
     return rebalance(arr, norm)
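Two fixes land in this file: every norm gains a small epsilon before pow2 rounding, and the L2 variant now divides by the size inside the sqrt, computing the proper RMS sqrt(Σx²/N) instead of sqrt(Σx²)/N. The epsilon matters for all-zero tensors, where the norm would otherwise be 0 and the rebalanced scale would collapse. A quick sketch with an illustrative stand-in for pow2_round:

import numpy as np

def pow2_round_sketch(x):
    # Hypothetical stand-in: round a positive scalar to the nearest power of two.
    return np.float32(2.0) ** np.round(np.log2(x))

zeros = np.zeros(16, dtype=np.float32)
# Without the guard: log2(0) = -inf, so the rounded norm degenerates to 0.
print(pow2_round_sketch(np.mean(np.abs(zeros))))                     # 0.0 (after a divide warning)
print(pow2_round_sketch(np.mean(np.abs(zeros)) + np.float32(1e-3)))  # 0.0009765625 == 2**-10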
