4
4
import torch .optim as optim
5
5
import pickle
6
6
from torchvision import datasets , transforms
7
- from resnet18 import ResNet ,ResidualBlock
8
7
from vgg16 import VGG16
8
+ import torchvision
9
9
10
10
11
11
# --- Training hyperparameters ---
BATCH_SIZE = 100   # samples per mini-batch
EPOCHS = 200       # total number of training epochs
LR = 0.01          # initial learning rate for SGD

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
15
def print_lr(optimizer):
    """Print the current learning rate of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        print(group['lr'])
73
73
74
- dict = [[],[],[],[],[],[]]
75
-
76
- dict [0 ] = unpickle ('Cifar/cifar-10-batches-py/data_batch_1' )
77
- dict [1 ] = unpickle ('Cifar/cifar-10-batches-py/data_batch_2' )
78
- dict [2 ] = unpickle ('Cifar/cifar-10-batches-py/data_batch_3' )
79
- dict [3 ] = unpickle ('Cifar/cifar-10-batches-py/data_batch_4' )
80
- dict [4 ] = unpickle ('Cifar/cifar-10-batches-py/data_batch_5' )
81
- dict [5 ] = unpickle ('Cifar/cifar-10-batches-py/test_batch' )
82
-
83
74
# Build the network and move it onto the selected device.
# (Earlier experiments used Net and ResNet(ResidualBlock) backbones.)
model = VGG16().to(DEVICE)

# SGD with momentum and L2 weight decay.
# (An Adam configuration with the same weight decay was also tried.)
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)

# Standard multi-class classification loss.
criterion = nn.CrossEntropyLoss()

model.train()
@@ -104,78 +95,75 @@ def process_train(num):
104
95
target .append (batch_target )
105
96
return data ,target
106
97
107
- data ,target = process_train (0 )
108
- data1 ,target1 = process_train (1 )
109
- data2 ,target2 = process_train (2 )
110
- data3 ,target3 = process_train (3 )
111
- data4 ,target4 = process_train (4 )
98
# Training-time preprocessing: light augmentation followed by
# tensor conversion and per-channel normalization to [-1, 1].
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split; download=False assumes the data already
# exists under data/ — TODO confirm the directory is populated.
train_dataset = torchvision.datasets.CIFAR10(
    root='data/', train=True, transform=train_transform, download=False)
114
106
115
107
116
- def train (inputdata ,inputtarget ):
117
- for i in range ((int (10000 / BATCH_SIZE ))):
118
- batch_data = inputdata [i ]
119
- batch_target = inputtarget [i ]
120
- batch_target = torch .tensor (batch_target ).to (DEVICE )
121
- target_t = torch .tensor (batch_target )
122
- data_t = torch .tensor (batch_data )
123
- # print(target_t)
124
- # print(data_t.size())
125
- # data_t = data_t.reshape(1,3,32,32)
126
- data_t = data_t .to (DEVICE )
127
- data_t = data_t .float ()
128
- data_t = data_t / 255.0
129
- data_t = (data_t - 0.5 )/ 0.5
108
# Test-time preprocessing: no augmentation, only tensor conversion and
# the same normalization used for training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 test split; download=False assumes the data already exists
# under data/ — TODO confirm the directory is populated.
test_dataset = torchvision.datasets.CIFAR10(
    root='data/', train=False, transform=test_transform, download=False)

trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# NOTE(review): shuffling the test loader does not change the accuracy
# computation; kept shuffle=True to preserve existing behavior.
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# Decay the LR by 10x every 50 epochs. scheduler.step() is invoked once
# per batch inside train(), so step_size is expressed in batches.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=50 * len(trainloader), gamma=0.1)

# Fixed: the second message previously printed a duplicated colon
# ("testloader len: :").
print("trainloader len: ", len(train_dataset))
print("testloader len: ", len(test_dataset))
125
+
126
+
127
+
128
def train(trainloader):
    """Run one full epoch of training over *trainloader*.

    Uses the module-level model, optimizer, criterion, scheduler and
    DEVICE; logs the loss and current learning rate periodically.
    """
    model.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        predictions = model(inputs)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # per-batch stepping; step_size counts batches

        # Log once every 10000/BATCH_SIZE batches (batch 0 included).
        if batch_idx % int(10000 / BATCH_SIZE) == 0:
            print('loss %.4f ' % loss.item(), "lr:", optimizer.param_groups[0]['lr'])
141
142
143
def test(testloader):
    """Evaluate the model on *testloader* and print (correct, total).

    Fixed: inference now runs under torch.no_grad() so no autograd graph
    is built during evaluation — saves memory and time; the printed
    output is unchanged.
    """
    model.eval()
    correct_nums = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            # Predicted class = argmax over the class dimension.
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct_nums += (predicted == labels).sum().item()
    print(correct_nums, total)
142
155
143
import os  # needed below to create the checkpoint directory

# Main training loop: train every epoch; every 10th epoch (skipping
# epoch 0) evaluate on the test set and checkpoint the weights.
for epoch in range(EPOCHS):
    print(str(epoch) + '/' + str(EPOCHS))
    train(trainloader)
    if epoch % 10 == 0 and epoch != 0:
        test(testloader)
        # Fixed: nn.Module has no save_weights() method — saving would
        # raise AttributeError on the first checkpoint. Use torch.save
        # with the state_dict, and make sure the target directory exists.
        # NOTE(review): the filename says "vgg11" but the model is VGG16;
        # kept as-is so existing tooling that loads these paths still works.
        os.makedirs("./weights", exist_ok=True)
        model_name = "./weights/vgg11_dropout_" + str(epoch) + ".pth"
        torch.save(model.state_dict(), model_name)
166
+
179
167
180
168
181
169
'''
0 commit comments