Skip to content

Commit 8e795c5

Browse files
committed
add vgg16,resnet18
1 parent 87c472c commit 8e795c5

File tree

3 files changed

+181
-7
lines changed

3 files changed

+181
-7
lines changed

classify_cifar10.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import torch.optim as optim
55
import pickle
66
from torchvision import datasets, transforms
7+
from resnet18 import ResNet,ResidualBlock
8+
from vgg16 import VGG16
79

8-
BATCH_SIZE=4 # 批次大小
10+
11+
BATCH_SIZE=10 # 批次大小
912
EPOCHS=200 # 总共训练批次
1013
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1114

@@ -65,15 +68,17 @@ def forward(self,x):
6568
dict[4] = unpickle('Cifar/cifar-10-batches-py/data_batch_5')
6669
dict[5] = unpickle('Cifar/cifar-10-batches-py/test_batch')
6770

68-
model = Net().to(DEVICE)
71+
#model = Net().to(DEVICE)
72+
model = ResNet(ResidualBlock).to(DEVICE)
73+
#model = VGG16().to(DEVICE)
6974
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
7075
criterion = nn.CrossEntropyLoss()
7176
model.train()
7277

7378
def process_train(num):
7479
data = []
7580
target = []
76-
for j in range(2500):
81+
for j in range(1000):
7782
batch_data = []
7883
batch_target = []
7984
for i in range(BATCH_SIZE):
@@ -94,7 +99,7 @@ def process_train(num):
9499
test_data,test_target = process_train(5)
95100

96101
def train(inputdata,inputtarget):
97-
for i in range(2500):
102+
for i in range(1000):
98103
batch_data = inputdata[i]
99104
batch_target = inputtarget[i]
100105
batch_target = torch.tensor(batch_target).to(DEVICE)
@@ -112,12 +117,12 @@ def train(inputdata,inputtarget):
112117
loss = criterion(out, target_t)
113118
loss.backward()
114119
optimizer.step()
115-
if i%500 ==0:
120+
if i%1000 ==0:
116121
print(loss.item())
117122

118123

119124

120-
for i in range(2):
125+
for i in range(20):
121126
train(data,target)
122127
train(data1,target1)
123128
train(data2,target2)
@@ -128,7 +133,7 @@ def train(inputdata,inputtarget):
128133

129134
model.eval()
130135
corrent_nums = 0
131-
for i in range(2500):
136+
for i in range(1000):
132137
batch_data = test_data[i]
133138
batch_target = test_target[i]
134139
# print(batch_data)

resnet18.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class ResidualBlock(nn.Module):
6+
def __init__(self, inchannel, outchannel, stride=1):
7+
super(ResidualBlock, self).__init__()
8+
self.left = nn.Sequential(
9+
nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
10+
nn.BatchNorm2d(outchannel),
11+
nn.ReLU(inplace=True),
12+
nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
13+
nn.BatchNorm2d(outchannel)
14+
)
15+
self.shortcut = nn.Sequential()
16+
if stride != 1 or inchannel != outchannel:
17+
self.shortcut = nn.Sequential(
18+
nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
19+
nn.BatchNorm2d(outchannel)
20+
)
21+
22+
def forward(self, x):
23+
out = self.left(x)
24+
out += self.shortcut(x)
25+
out = F.relu(out)
26+
return out
27+
28+
class ResNet(nn.Module):
29+
def __init__(self, ResidualBlock, num_classes=10):
30+
super(ResNet, self).__init__()
31+
self.inchannel = 64
32+
self.conv1 = nn.Sequential(
33+
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
34+
nn.BatchNorm2d(64),
35+
nn.ReLU(),
36+
)
37+
self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
38+
self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
39+
self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
40+
self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
41+
self.fc = nn.Linear(512, num_classes)
42+
43+
def make_layer(self, block, channels, num_blocks, stride):
44+
strides = [stride] + [1] * (num_blocks - 1) #strides=[1,1]
45+
layers = []
46+
for stride in strides:
47+
layers.append(block(self.inchannel, channels, stride))
48+
self.inchannel = channels
49+
return nn.Sequential(*layers)
50+
51+
def forward(self, x):
52+
out = self.conv1(x)
53+
out = self.layer1(out)
54+
out = self.layer2(out)
55+
out = self.layer3(out)
56+
out = self.layer4(out)
57+
out = F.avg_pool2d(out, 4)
58+
out = out.view(out.size(0), -1)
59+
out = self.fc(out)
60+
return out
61+
62+

vgg16.py

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class VGG16(nn.Module):
7+
def __init__(self):
8+
super(VGG16, self).__init__()
9+
10+
# 3 * 32 * 32
11+
self.conv1_1 = nn.Conv2d(3, 64, 3) # 64 * 30 * 30
12+
self.bn1_1 = nn.BatchNorm2d(64)
13+
self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # 64 * 30* 30
14+
self.bn1_2 = nn.BatchNorm2d(64)
15+
# self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 64 * 112 * 112
16+
17+
self.conv2_1 = nn.Conv2d(64, 128, 3) # 128 * 28 * 28
18+
self.bn2_1 = nn.BatchNorm2d(128)
19+
# self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1)) # 128 * 28 * 28
20+
# self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 128 * 56 * 56
21+
22+
self.conv3_1 = nn.Conv2d(128, 256, 3) # 256 * 26 * 26
23+
self.bn3_1 = nn.BatchNorm2d(256)
24+
self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 26 * 26
25+
self.bn3_2 = nn.BatchNorm2d(256)
26+
# self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 26 * 26
27+
# self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 256 * 28 * 28
28+
29+
self.conv4_1 = nn.Conv2d(256, 512, 3) # 512 * 24 * 24
30+
self.bn4_1 = nn.BatchNorm2d(512)
31+
# self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 24 * 24
32+
self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 24 * 24
33+
self.bn4_3 = nn.BatchNorm2d(512)
34+
self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 512 * 13 * 13
35+
36+
self.conv5_1 = nn.Conv2d(512, 512, 3) # 512 * 11 * 11
37+
self.bn5_1 = nn.BatchNorm2d(512)
38+
# self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 11 * 11
39+
self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 11 * 11
40+
self.bn5_3 = nn.BatchNorm2d(512)
41+
self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1)) # pooling 512 * 6 * 6
42+
43+
# view
44+
45+
self.fc1 = nn.Linear(512 * 6 * 6, 360)
46+
# self.fc2 = nn.Linear(4096, 4096)
47+
self.fc3 = nn.Linear(360, 10)
48+
# softmax 1 * 1 * 1000
49+
50+
def forward(self, x):
51+
52+
# x.size(0)即为batch_size
53+
in_size = x.size(0)
54+
55+
out = self.conv1_1(x) # 222
56+
out = self.bn1_1(out)
57+
out = F.relu(out)
58+
out = self.conv1_2(out) # 222
59+
out = self.bn1_2(out)
60+
out = F.relu(out)
61+
# out = self.maxpool1(out) # 112
62+
63+
out = self.conv2_1(out) # 110
64+
out = self.bn2_1(out)
65+
out = F.relu(out)
66+
#out = self.conv2_2(out) # 110
67+
#out = F.relu(out)
68+
# out = self.maxpool2(out) # 56
69+
70+
out = self.conv3_1(out) # 54
71+
out = self.bn3_1(out)
72+
out = F.relu(out)
73+
out = self.conv3_2(out) # 54
74+
out = self.bn3_2(out)
75+
out = F.relu(out)
76+
#out = self.conv3_3(out) # 54
77+
#out = F.relu(out)
78+
# out = self.maxpool3(out) # 28
79+
80+
out = self.conv4_1(out) # 26
81+
out = self.bn4_1(out)
82+
out = F.relu(out)
83+
# out = self.conv4_2(out) # 26
84+
# out = F.relu(out)
85+
out = self.conv4_3(out) # 26
86+
out = self.bn4_3(out)
87+
out = F.relu(out)
88+
out = self.maxpool4(out) # 14
89+
90+
out = self.conv5_1(out) # 12
91+
out = F.relu(out)
92+
#out = self.conv5_2(out) # 12
93+
#out = F.relu(out)
94+
out = self.conv5_3(out) # 12
95+
out = F.relu(out)
96+
out = self.maxpool5(out) # 7
97+
98+
# 展平
99+
out = out.view(in_size, -1)
100+
101+
out = self.fc1(out)
102+
out = F.relu(out)
103+
#out = self.fc2(out)
104+
#out = F.relu(out)
105+
out = self.fc3(out)
106+
# out = F.log_softmax(out, dim=1)
107+
return out

0 commit comments

Comments
 (0)