Skip to content

Commit e2fac8c

Browse files
gitignore added
1 parent 10fe2f3 commit e2fac8c

40 files changed

+977
-13
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
checkpoint/
2+
data/
3+
checkpoint_/

.idea/PyTorch-BayesianCNN.iml

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/modules.xml

+8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/vcs.xml

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

+921
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
1.26 KB
Binary file not shown.

__pycache__/config.cpython-36.pyc

1.16 KB
Binary file not shown.

bayesian_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
optim_type = 'Adam'
88
lr = 0.001
99
weight_decay = 0.0005
10-
num_samples = 100
10+
num_samples = 25
1111
beta_type = "Blundell"
1212

1313

config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
start_epoch = 1
55
num_epochs = 100
6-
batch_size = 128
6+
batch_size = 256
77
optim_type = 'Adam'
88

99
mean = {

main_Bayes.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torchvision
1313
import torchvision.transforms as transforms
14+
from utils.autoaugment import CIFAR10Policy
1415

1516
import torch
1617
import torch.utils.data as data
@@ -31,22 +32,22 @@
3132

3233
parser = argparse.ArgumentParser(description='PyTorch Bayesian Model Training')
3334
#parser.add_argument('--lr', default=0.001, type=float, help='learning_rate')
34-
parser.add_argument('--net_type', default='lenet', type=str, help='model')
35+
parser.add_argument('--net_type', default='alexnet', type=str, help='model')
3536
#parser.add_argument('--depth', default=28, type=int, help='depth of model')
3637
#parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
3738
#parser.add_argument('--num_samples', default=10, type=int, help='Number of samples')
3839
#parser.add_argument('--beta_type', default="Blundell", type=str, help='Beta type')
3940
#parser.add_argument('--p_logvar_init', default=0, type=int, help='p_logvar_init')
4041
#parser.add_argument('--q_logvar_init', default=-10, type=int, help='q_logvar_init')
4142
#parser.add_argument('--weight_decay', default=0.0005, type=float, help='weight_decay')
42-
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [mnist/cifar10/cifar100/fashionmnist/stl10]')
43+
parser.add_argument('--dataset', default='stl10', type=str, help='dataset = [mnist/cifar10/cifar100/fashionmnist/stl10]')
4344
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
4445
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
4546
args = parser.parse_args()
4647

4748
# Hyper Parameter settings
4849
use_cuda = torch.cuda.is_available()
49-
torch.cuda.set_device(0)
50+
torch.cuda.set_device(1)
5051
best_acc = 0
5152
resize=32
5253

@@ -56,15 +57,17 @@
5657
transform_train = transforms.Compose([
5758
transforms.Resize((resize, resize)),
5859
transforms.RandomCrop(32, padding=4),
59-
transforms.RandomHorizontalFlip(),
60+
#transforms.RandomHorizontalFlip(),
61+
#CIFAR10Policy(),
6062
transforms.ToTensor(),
6163
transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
6264
]) # meanstd transformation
6365

6466
transform_test = transforms.Compose([
6567
transforms.Resize((resize, resize)),
6668
transforms.RandomCrop(32, padding=4),
67-
transforms.RandomHorizontalFlip(),
69+
#transforms.RandomHorizontalFlip(),
70+
#CIFAR10Policy(),
6871
transforms.ToTensor(),
6972
transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
7073
])
@@ -141,7 +144,7 @@ def getNetwork(args):
141144
print('| Resuming from checkpoint...')
142145
assert os.path.isdir('checkpoint'), 'Error: No checkpoint directory found!'
143146
_, file_name = getNetwork(args)
144-
checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+'.t7')
147+
checkpoint = torch.load('./checkpoint/'+args.dataset+os.sep+file_name+str(cf.num_samples)+'.t7')
145148
net = checkpoint['net']
146149
best_acc = checkpoint['acc']
147150
cf.start_epoch = checkpoint['epoch']
@@ -154,7 +157,7 @@ def getNetwork(args):
154157

155158
vi = GaussianVariationalInference(torch.nn.CrossEntropyLoss())
156159

157-
logfile = os.path.join('diagnostics_Bayes{}_{}.txt'.format(args.net_type, args.dataset))
160+
logfile = os.path.join('diagnostics_Bayes{}_{}_{}.txt'.format(args.net_type, args.dataset, cf.num_samples))
158161

159162
# Training
160163
def train(epoch):
@@ -255,7 +258,7 @@ def test(epoch):
255258
save_point = './checkpoint/'+args.dataset+os.sep
256259
if not os.path.isdir(save_point):
257260
os.mkdir(save_point)
258-
torch.save(state, save_point+file_name+'.t7')
261+
torch.save(state, save_point+file_name+str(cf.num_samples)+'.t7')
259262
best_acc = acc
260263

261264
print('\n[Phase 3] : Training model')

main_nonBayes.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch
44
import torch.nn as nn
55
import torch.optim as optim
6-
import torch.nn.functional as F
76
import torch.backends.cudnn as cudnn
87
import config as cf
98

@@ -14,7 +13,6 @@
1413
import sys
1514
import time
1615
import argparse
17-
import datetime
1816

1917
from torch.autograd import Variable
2018

@@ -25,14 +23,15 @@
2523
from utils.NonBayesianModels.SqueezeNet import SqueezeNet
2624
from utils.NonBayesianModels.wide_resnet import Wide_ResNet
2725
from utils.NonBayesianModels.ThreeConvThreeFC import ThreeConvThreeFC
26+
from utils.autoaugment import CIFAR10Policy
2827

2928
parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training')
3029
parser.add_argument('--lr', default=0.001, type=float, help='learning_rate')
3130
parser.add_argument('--net_type', default='alexnet', type=str, help='model')
3231
parser.add_argument('--depth', default=28, type=int, help='depth of model')
3332
parser.add_argument('--widen_factor', default=10, type=int, help='width of model')
3433
parser.add_argument('--dropout', default=0.3, type=float, help='dropout_rate')
35-
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset = [mnist/cifar10/cifar100/fashionmnist/stl10]')
34+
parser.add_argument('--dataset', default='stl10', type=str, help='dataset = [mnist/cifar10/cifar100/fashionmnist/stl10]')
3635
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
3736
parser.add_argument('--testOnly', '-t', action='store_true', help='Test mode with the saved model')
3837
args = parser.parse_args()
@@ -48,12 +47,18 @@
4847

4948
transform_train = transforms.Compose([
5049
transforms.Resize((resize, resize)),
50+
transforms.RandomCrop(32, padding=4),
51+
#transforms.RandomHorizontalFlip(),
52+
#CIFAR10Policy(),
5153
transforms.ToTensor(),
5254
transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
5355
]) # meanstd transformation
5456

5557
transform_test = transforms.Compose([
5658
transforms.Resize((resize, resize)),
59+
transforms.RandomCrop(32, padding=4),
60+
#transforms.RandomHorizontalFlip(),
61+
#CIFAR10Policy(),
5762
transforms.ToTensor(),
5863
transforms.Normalize(cf.mean[args.dataset], cf.std[args.dataset]),
5964
])
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
7.11 KB
Binary file not shown.
178 Bytes
Binary file not shown.
8.67 KB
Binary file not shown.

weights/weights_BBBLeNet_CIFAR100.pkl

4.77 MB
Binary file not shown.

weights/weights_LeNet_CIFAR-100.pkl

2.64 MB
Binary file not shown.

0 commit comments

Comments
 (0)