 
 
 def print_mean_std(name, v):
-    _, scale = jsa.lax.get_data_scale(v)
+    data, scale = jsa.lax.get_data_scale(v)
     # Always use np.float32, to avoid floating errors in descaling + stats.
-    v = jsa.asarray(v, dtype=np.float32)
+    v = jsa.asarray(data, dtype=np.float32)
     m, s = np.mean(v), np.std(v)
     print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})")
 
@@ -45,19 +45,23 @@ def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
-        jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
-        jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
-        (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
+        # jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
+        # (w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
 
         # Matmul + relu
         outputs = jnp.dot(activations, w) + b
         activations = jnp.maximum(outputs, 0)
+        # activations = jsa.ops.dynamic_rescale_l2_grad(activations)
 
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
 
-    jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
-    (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+    # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
+
+    logits = jsa.ops.dynamic_rescale_l1_grad(logits)
+    # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits)
 
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
@@ -81,7 +85,7 @@ def accuracy(params, batch):
 step_size = 0.001
 num_epochs = 10
 batch_size = 128
-training_dtype = np.float32
+training_dtype = np.float16
 
 train_images, train_labels, test_images, test_labels = datasets.mnist()
 num_train = train_images.shape[0]
@@ -106,15 +110,19 @@ def data_stream():
 @jsa.autoscale
 def update(params, batch):
     grads = grad(loss)(params, batch)
-    return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
-
-num_batches = 4
-num_epochs = 2
+    # return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
+    return [
+        (jsa.ops.dynamic_rescale_l1(w - step_size * dw), jsa.ops.dynamic_rescale_l1(b - step_size * db))
+        for (w, b), (dw, db) in zip(params, grads)
+    ]
+
+# num_batches = 4
+# num_epochs = 2
 for epoch in range(num_epochs):
-    print("EPOCH:", epoch)
+    # print("EPOCH:", epoch)
     start_time = time.time()
     for _ in range(num_batches):
-        print("BATCH...")
+        # print("BATCH...")
         batch = next(batches)
         # Scaled micro-batch + training dtype cast.
         batch = jsa.as_scaled_array(batch)
@@ -127,8 +135,8 @@ def update(params, batch):
 
     # Evaluation in float32, for consistency.
     raw_params = jsa.asarray(params, dtype=np.float32)
-    # train_acc = accuracy(raw_params, (train_images, train_labels))
-    # test_acc = accuracy(raw_params, (test_images, test_labels))
-    # print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
-    # print(f"Training set accuracy {train_acc:0.5f}")
-    # print(f"Test set accuracy {test_acc:0.5f}")
+    train_acc = accuracy(raw_params, (train_images, train_labels))
+    test_acc = accuracy(raw_params, (test_images, test_labels))
+    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
+    print(f"Training set accuracy {train_acc:0.5f}")
+    print(f"Test set accuracy {test_acc:0.5f}")
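For context, a minimal sketch of the training-step pattern this commit switches the example to: decorate the step with jsa.autoscale, feed it as_scaled_array inputs in float16, re-normalise the updated parameters with dynamic_rescale_l1, and descale back to float32 for evaluation. This is not code from the repository; the import line behind the `jsa` alias and the toy shapes are assumptions based on the calls visible in the diff above.

# Hedged sketch, not from the repository. Assumes the example's `jsa` alias
# comes from `import jax_scaled_arithmetics as jsa` (an assumption).
import numpy as np
import jax_scaled_arithmetics as jsa  # assumed import behind the `jsa` alias

step_size = 0.001


@jsa.autoscale
def sgd_step(w, dw):
    # Rescale after the update so the tensor's scale tracks its post-update
    # statistics, mirroring the new `update` body in the diff above.
    return jsa.ops.dynamic_rescale_l1(w - step_size * dw)


# Scaled float16 inputs, as with `batch = jsa.as_scaled_array(batch)` above.
w = jsa.as_scaled_array(np.random.randn(8, 8).astype(np.float16))
dw = jsa.as_scaled_array(np.random.randn(8, 8).astype(np.float16))
w = sgd_step(w, dw)

# Descale to a plain float32 array for inspection/evaluation, mirroring
# `raw_params = jsa.asarray(params, dtype=np.float32)` above.
print(np.std(jsa.asarray(w, dtype=np.float32)))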