We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fc783fc commit e5635adCopy full SHA for e5635ad
pretrained/mnist.pt
1.17 MB
torch_mnist.py
@@ -26,6 +26,7 @@ def forward(self, x):
26
epochs = 1000
27
m = MLP()
28
optim = adam.Adam(m.parameters(), lr=3e-4)
29
+loss = None
30
for i in range(epochs):
31
data, labels = next(iter(train_dataloader))
32
pred = m(data)
@@ -42,3 +43,10 @@ def forward(self, x):
42
43
correct = (predicted == labels).sum().item() # Count correct predictions
44
accuracy = correct / labels.size(0) # Compute accuracy as a fraction of total samples
45
print(f"loss: {loss}, accuracy: {accuracy}")
46
+
47
+torch.save({
48
+ 'epoch': epochs,
49
+ 'model_state_dict': m.state_dict(),
50
+ 'optimizer_state_dict': optim.state_dict(),
51
+ 'loss': loss
52
+}, "./pretrained/mnist.pt")
0 commit comments