import torch
import torch.nn as nn
import torch.nn.functional as F


class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()

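        # Spatial sizes in the layer comments follow the usual conv/pool
        # arithmetic: out = floor((in + 2*padding - kernel) / stride) + 1.
        # A 3x3 conv with no padding shrinks the map by 2 (32 -> 30), a 3x3
        # conv with padding=1 preserves it, and a 2x2 max-pool with stride 2
        # and padding=1 maps 24 -> 13 and 11 -> 6.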
        # input: 3 * 32 * 32 (e.g. CIFAR-10)
        self.conv1_1 = nn.Conv2d(3, 64, 3)  # 64 * 30 * 30
        self.bn1_1 = nn.BatchNorm2d(64)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1))  # 64 * 30 * 30
        self.bn1_2 = nn.BatchNorm2d(64)
        # self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1))  # dropped from the full VGG16

        self.conv2_1 = nn.Conv2d(64, 128, 3)  # 128 * 28 * 28
        self.bn2_1 = nn.BatchNorm2d(128)
        # self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1))  # dropped
        # self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1))  # dropped

        self.conv3_1 = nn.Conv2d(128, 256, 3)  # 256 * 26 * 26
        self.bn3_1 = nn.BatchNorm2d(256)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # 256 * 26 * 26
        self.bn3_2 = nn.BatchNorm2d(256)
        # self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # dropped
        # self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1))  # dropped

        self.conv4_1 = nn.Conv2d(256, 512, 3)  # 512 * 24 * 24
        self.bn4_1 = nn.BatchNorm2d(512)
        # self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # dropped
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 24 * 24
        self.bn4_3 = nn.BatchNorm2d(512)
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling: 512 * 13 * 13

        self.conv5_1 = nn.Conv2d(512, 512, 3)  # 512 * 11 * 11
        self.bn5_1 = nn.BatchNorm2d(512)
        # self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # dropped
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 11 * 11
        self.bn5_3 = nn.BatchNorm2d(512)
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling: 512 * 6 * 6

        # classifier: flatten to 512 * 6 * 6 = 18432, then two FC layers
        # (the full VGG16 uses three FC layers of 4096/4096/1000 for ImageNet)
        self.fc1 = nn.Linear(512 * 6 * 6, 360)
        # self.fc2 = nn.Linear(4096, 4096)  # dropped
        self.fc3 = nn.Linear(360, 10)  # 10-way logits

    def forward(self, x):

        # x.size(0) is the batch size
        in_size = x.size(0)

        out = self.conv1_1(x)  # 30
        out = self.bn1_1(out)
        out = F.relu(out)
        out = self.conv1_2(out)  # 30
        out = self.bn1_2(out)
        out = F.relu(out)
        # out = self.maxpool1(out)

        out = self.conv2_1(out)  # 28
        out = self.bn2_1(out)
        out = F.relu(out)
        # out = self.conv2_2(out)
        # out = F.relu(out)
        # out = self.maxpool2(out)

        out = self.conv3_1(out)  # 26
        out = self.bn3_1(out)
        out = F.relu(out)
        out = self.conv3_2(out)  # 26
        out = self.bn3_2(out)
        out = F.relu(out)
        # out = self.conv3_3(out)
        # out = F.relu(out)
        # out = self.maxpool3(out)

        out = self.conv4_1(out)  # 24
        out = self.bn4_1(out)
        out = F.relu(out)
        # out = self.conv4_2(out)
        # out = F.relu(out)
        out = self.conv4_3(out)  # 24
        out = self.bn4_3(out)
        out = F.relu(out)
        out = self.maxpool4(out)  # 13

        out = self.conv5_1(out)  # 11
        out = self.bn5_1(out)  # bn5_1 was defined but never applied; use it here
        out = F.relu(out)
        # out = self.conv5_2(out)
        # out = F.relu(out)
        out = self.conv5_3(out)  # 11
        out = self.bn5_3(out)  # likewise bn5_3, before the activation
        out = F.relu(out)
        out = self.maxpool5(out)  # 6

        # flatten
        out = out.view(in_size, -1)

        out = self.fc1(out)
        out = F.relu(out)
        # out = self.fc2(out)
        # out = F.relu(out)
        out = self.fc3(out)
        # out = F.log_softmax(out, dim=1)  # unnecessary with nn.CrossEntropyLoss
        return out
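

# A minimal smoke test (a sketch, not part of the original module): assumes
# CIFAR-10-shaped input (3 x 32 x 32) and checks that the forward pass yields
# one 10-way logit vector per sample. The logits feed straight into
# nn.CrossEntropyLoss, which applies softmax internally, which is why the
# log_softmax call above is commented out.
if __name__ == "__main__":
    model = VGG16()
    images = torch.randn(4, 3, 32, 32)  # a fake batch of 4 images
    logits = model(images)
    print(logits.shape)  # torch.Size([4, 10])

    targets = torch.randint(0, 10, (4,))  # fake labels for one training step
    loss = nn.CrossEntropyLoss()(logits, targets)
    loss.backward()  # gradients flow through every trainable layer
    print(loss.item())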