Skip to content

Commit b29b43e

Browse files
committed
add dropout and Data enhancement
1 parent a572538 commit b29b43e

File tree

2 files changed

+73
-78
lines changed

2 files changed

+73
-78
lines changed

classify_cifar10.py

+62-74
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import torch.optim as optim
55
import pickle
66
from torchvision import datasets, transforms
7-
from resnet18 import ResNet,ResidualBlock
87
from vgg16 import VGG16
8+
import torchvision
99

1010

1111
BATCH_SIZE=100 # 批次大小
12-
EPOCHS=90 # 总共训练批次
12+
EPOCHS=200 # 总共训练批次
1313
LR = 0.01 #学习率
1414
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1515

@@ -71,20 +71,11 @@ def print_lr(optimizer):
7171
for param_group in optimizer.param_groups:
7272
print(param_group['lr'])
7373

74-
dict = [[],[],[],[],[],[]]
75-
76-
dict[0] = unpickle('Cifar/cifar-10-batches-py/data_batch_1')
77-
dict[1] = unpickle('Cifar/cifar-10-batches-py/data_batch_2')
78-
dict[2] = unpickle('Cifar/cifar-10-batches-py/data_batch_3')
79-
dict[3] = unpickle('Cifar/cifar-10-batches-py/data_batch_4')
80-
dict[4] = unpickle('Cifar/cifar-10-batches-py/data_batch_5')
81-
dict[5] = unpickle('Cifar/cifar-10-batches-py/test_batch')
82-
8374
#model = Net().to(DEVICE)
84-
model = ResNet(ResidualBlock).to(DEVICE)
85-
#model = VGG16().to(DEVICE)
75+
#model = ResNet(ResidualBlock).to(DEVICE)
76+
model = VGG16().to(DEVICE)
77+
#optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4)
8678
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
87-
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5*30*(int(10000/BATCH_SIZE)), gamma=0.1) #每调用一次step(),step_size就会+1,学习率的衰减策略,要根据实际的迭代次数进行设置
8879

8980
criterion = nn.CrossEntropyLoss()
9081
model.train()
@@ -104,78 +95,75 @@ def process_train(num):
10495
target.append(batch_target)
10596
return data,target
10697

107-
data,target = process_train(0)
108-
data1,target1 = process_train(1)
109-
data2,target2 = process_train(2)
110-
data3,target3 = process_train(3)
111-
data4,target4 = process_train(4)
98+
# --- CIFAR-10 data pipeline (torchvision) ---
# Training augmentation: random horizontal flip + random grayscale, then
# convert to tensor and normalize each channel from [0, 1] to [-1, 1].
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# assumes the CIFAR-10 archive is already present under data/ (download=False) — TODO confirm
train_dataset = torchvision.datasets.CIFAR10(root='data/', train=True, transform=train_transform, download=False)

# Test data gets no augmentation: tensor conversion + the same normalization only.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_dataset = torchvision.datasets.CIFAR10(root='data/', train=False, transform=test_transform, download=False)

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)

# scheduler.step() is invoked once per BATCH in train(), so step_size is
# expressed in batches: multiply LR by gamma=0.1 every 50 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50*len(trainloader), gamma=0.1)

# Fix: these report dataset sizes (sample counts), not DataLoader lengths;
# the original labels ("trainloader len: " / "testloader len: :") were
# misleading and the second contained a ": :" typo.
print("train dataset size:", len(train_dataset))
print("test dataset size:", len(test_dataset))
125+
126+
127+
128+
def train(trainloader):
    """Run one full training epoch over *trainloader*.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``scheduler``, ``DEVICE`` and ``BATCH_SIZE``.
    """
    model.train()
    # NOTE(review): logging interval still derived from the old 10k-sample
    # batch files; it only controls print frequency, not training.
    log_every = int(10000 / BATCH_SIZE)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # LR schedule advances once per batch

        if batch_idx % log_every == 0:
            print('loss %.4f ' % loss.item(), "lr:", optimizer.param_groups[0]['lr'])

143+
def test(testloader):
    """Evaluate the model on *testloader* and print ``correct total``.

    Fix: the evaluation loop is wrapped in ``torch.no_grad()`` so no
    autograd graph is built during inference (saves memory and time);
    the printed output is unchanged. Uses module-level ``model``/``DEVICE``.
    """
    model.eval()
    corrent_nums = 0
    total = 0
    with torch.no_grad():
        for i, data in enumerate(testloader, 0):
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            # predicted class = argmax over the 10 logits
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            corrent_nums += (predicted == labels).sum().item()
    print(corrent_nums, total)
142155

143-
for i in range(EPOCHS):
144-
print(str(i)+'/'+str(EPOCHS))
145-
train(data,target)
146-
train(data1,target1)
147-
train(data2,target2)
148-
train(data3,target3)
149-
train(data4,target4)
150156

151157

152158

153-
model.eval()
154-
corrent_nums = 0
155-
for i in range(int(10000/BATCH_SIZE)):
156-
batch_data = test_data[i]
157-
batch_target = test_target[i]
158-
# print(batch_data)
159-
160-
batch_target = torch.tensor(batch_target).to(DEVICE)
161-
target_t = torch.tensor(batch_target)
162-
data_t = torch.tensor(batch_data)
163-
data_t = data_t.to(DEVICE)
164-
data_t = data_t.float()
165-
data_t = data_t/255.0
166-
data_t = (data_t - 0.5)/0.5
167-
output = model(data_t)
168-
_, predicted = torch.max(output.data, 1)
169-
pred = output.max(1, keepdim=True)[1]
170-
171-
# print(predicted)
172-
# print(target_t)
173-
corrent_nums += (predicted == target_t).sum().item()
174-
# for j in range(4):
175-
# if predicted[j] == target_t[j]:
176-
# corrent_nums = corrent_nums + 1
177-
178-
print(corrent_nums)
159+
# Main training loop: train every epoch; every 10th epoch (skipping epoch 0)
# evaluate and checkpoint the weights.
# NOTE(review): structure reconstructed from the diff — the test/checkpoint
# lines are assumed to sit inside the ``if``; confirm against the repository.
for i in range(EPOCHS):
    print(f"{i}/{EPOCHS}")
    train(trainloader)
    if i % 10 == 0 and i != 0:
        test(testloader)
        # NOTE(review): filename says "vgg11" but the model is VGG16 — confirm intended name.
        model_name = f"./weights/vgg11_dropout_{i}.pth"
        model.save_weights(model_name)
166+
179167

180168

181169
'''

vgg16.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,15 @@ def __init__(self):
4343
# view
4444

4545
self.fc1 = nn.Linear(512 * 6 * 6, 360)
46-
# self.fc2 = nn.Linear(4096, 4096)
46+
self.drop1 = nn.Dropout2d()
47+
self.fc2 = nn.Linear(360, 360)
48+
self.drop2 = nn.Dropout2d()
4749
self.fc3 = nn.Linear(360, 10)
4850
# softmax 1 * 1 * 1000
49-
51+
52+
def save_weights(self, path):
53+
torch.save(self.state_dict(), path)
54+
5055
def forward(self, x):
5156

5257
# x.size(0)即为batch_size
@@ -100,8 +105,10 @@ def forward(self, x):
100105

101106
out = self.fc1(out)
102107
out = F.relu(out)
103-
#out = self.fc2(out)
104-
#out = F.relu(out)
108+
out = self.drop1(out)
109+
out = self.fc2(out)
110+
out = F.relu(out)
111+
out = self.drop2(out)
105112
out = self.fc3(out)
106113
# out = F.log_softmax(out, dim=1)
107114
return out

0 commit comments

Comments
 (0)