-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
61 lines (50 loc) · 1.75 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from mobilenet_v3 import mobilenet_v3_large
# Training-time preprocessing: random 32x32 crop taken from the image padded by
# 4 pixels on each side, random horizontal flip, conversion to a tensor, then
# per-channel normalization with mean 0.5 / std 0.5.
# Only the training set uses this; the script defines no test-set transform.
_train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_train_augmentations + _to_normalized_tensor)
# Load the CIFAR-10 training split into ./data (download=True fetches it if it
# is not already there) and wrap it in a shuffling loader of 256-image batches
# served by 12 worker processes.
# NOTE(review): num_workers > 0 starts worker processes; on platforms whose
# default start method is "spawn" (Windows/macOS) this module-level script
# would need an `if __name__ == "__main__":` guard — confirm target platform.
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=256,
    shuffle=True,
    num_workers=12,
)
# Instantiate the network, loss function, and optimizer.
# Prefer GPU index 2 (the original target); fall back to the default GPU, then
# to the CPU, so the script still runs on machines with fewer devices.
if torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# The model must live on the same device as the batches the training loop
# sends to `device`; without .to(device) the first forward pass fails with a
# device-mismatch error (device inputs vs. CPU weights).
# NOTE(review): the number of output classes comes from mobilenet_v3_large's
# default; CIFAR-10 has 10 labels — confirm the default matches (or pass the
# class count if the factory accepts one).
model = mobilenet_v3_large().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Train for 100 epochs. The mean loss over each window of 100 mini-batches is
# printed and the window is then reset.
# Make training mode explicit (affects layers such as BatchNorm/Dropout, if the
# model contains any).
model.train()
for epoch in range(100):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Print statistics
        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            # Implicit string concatenation keeps the message on one logical
            # line; a line break inside a single-quoted f-string is a
            # SyntaxError before Python 3.12.
            print(
                f"[Epoch {epoch + 1}, Batch {i + 1}] "
                f"loss: {running_loss / 100:.3f}"
            )
            running_loss = 0.0
print("Finished Training")
# Keep the first 8 children of model.features as a truncated feature extractor.
# nn.Sequential re-registers them under the indices "0".."7", so the saved
# state_dict keys are prefixed by those new indices.
h_net = torch.nn.Sequential(*list(model.features.children())[:8])
# Move to the CPU before saving: torch.save records each tensor's device, so a
# checkpoint written from a GPU such as cuda:2 cannot be loaded on a machine
# without that GPU unless map_location is passed. A CPU state_dict loads
# anywhere, and load_state_dict copies it onto the target module's device.
# (This also moves the shared layers inside `model`; training is finished.)
h_net.cpu()
torch.save(h_net.state_dict(), 'h_net.pth')
print(h_net)