Skip to content

Commit fa546aa

Browse files
committed
wip
1 parent 4fc71b2 commit fa546aa

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

experiments/mnist/mnist_classifier_from_scratch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def predict(params, inputs):
4646
activations = inputs
4747
for w, b in params[:-1]:
4848
jsa.ops.debug_callback(partial(print_mean_std, "W"), w)
49+
jsa.ops.debug_callback(partial(print_mean_std, "B"), b)
4950
(w,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "WGrad"), w)
5051

5152
# Matmul + relu
@@ -107,7 +108,7 @@ def update(params, batch):
107108
grads = grad(loss)(params, batch)
108109
return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]
109110

110-
num_batches = 2
111+
num_batches = 4
111112
num_epochs = 2
112113
for epoch in range(num_epochs):
113114
print("EPOCH:", epoch)

0 commit comments

Comments
 (0)