Skip to content

Commit e5635ad

Browse files
committed
save model
1 parent fc783fc commit e5635ad

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

pretrained/mnist.pt

1.17 MB
Binary file not shown.

torch_mnist.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def forward(self, x):
2626
epochs = 1000
2727
m = MLP()
2828
optim = adam.Adam(m.parameters(), lr=3e-4)
29+
loss = None
2930
for i in range(epochs):
3031
data, labels = next(iter(train_dataloader))
3132
pred = m(data)
@@ -42,3 +43,10 @@ def forward(self, x):
4243
correct = (predicted == labels).sum().item() # Count correct predictions
4344
accuracy = correct / labels.size(0) # Compute accuracy as a fraction of total samples
4445
print(f"loss: {loss}, accuracy: {accuracy}")
46+
47+
torch.save({
48+
'epoch': epochs,
49+
'model_state_dict': m.state_dict(),
50+
'optimizer_state_dict': optim.state_dict(),
51+
'loss': loss
52+
}, "./pretrained/mnist.pt")

0 commit comments

Comments
 (0)