Skip to content

Commit 071a59d

Browse files
committed
Use automatic downloads of pretrained weights
PyTorch will automatically download pretrained weights into `os.environ['TORCH_MODEL_ZOO']` using the mechanism described here: (http://pytorch.org/docs/master/model_zoo.html) Removed hardcoded paths
1 parent 4a1721f commit 071a59d

File tree

8 files changed

+20
-43
lines changed

8 files changed

+20
-43
lines changed

models/config.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
import os
22

3-
# here (https://github.com/pytorch/vision/tree/master/torchvision/models) to find the download link of pretrained models
4-
5-
root = '/media/b3-542/LIBRARY/ZijunDeng/PyTorch Pretrained'
6-
res101_path = os.path.join(root, 'ResNet', 'resnet101-5d3b4d8f.pth')
7-
res152_path = os.path.join(root, 'ResNet', 'resnet152-b121ed2d.pth')
8-
inception_v3_path = os.path.join(root, 'Inception', 'inception_v3_google-1a9a5a14.pth')
9-
vgg19_bn_path = os.path.join(root, 'VggNet', 'vgg19_bn-c79401a0.pth')
10-
vgg16_path = os.path.join(root, 'VggNet', 'vgg16-397923af.pth')
11-
dense201_path = os.path.join(root, 'DenseNet', 'densenet201-4c113574.pth')
3+
# PyTorch will automatically download pretrained weights into `os.environ['TORCH_MODEL_ZOO']`
4+
# using the mechanism described here: (http://pytorch.org/docs/master/model_zoo.html)
5+
# Download links used are also listed here: (https://github.com/pytorch/vision/tree/master/torchvision/models)
126

137
'''
148
vgg16 trained using caffe
159
visit this (https://github.com/jcjohnson/pytorch-vgg) to download the converted vgg16
1610
'''
17-
vgg16_caffe_path = os.path.join(root, 'VggNet', 'vgg16-caffe.pth')
11+
vgg16_caffe_path = os.path.join(os.environ.get('TORCH_MODEL_ZOO', '.'), 'vgg16-caffe.pth')

models/duc_hdc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
from torch import nn
33
from torchvision import models
44

5-
from .config import res152_path
6-
7-
85
class _DenseUpsamplingConvModule(nn.Module):
96
def __init__(self, down_factor, in_dim, num_classes):
107
super(_DenseUpsamplingConvModule, self).__init__()
@@ -26,9 +23,7 @@ class ResNetDUC(nn.Module):
2623
# the size of image should be multiple of 8
2724
def __init__(self, num_classes, pretrained=True):
2825
super(ResNetDUC, self).__init__()
29-
resnet = models.resnet152()
30-
if pretrained:
31-
resnet.load_state_dict(torch.load(res152_path))
26+
resnet = models.resnet152(pretrained=pretrained)
3227
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
3328
self.layer1 = resnet.layer1
3429
self.layer2 = resnet.layer2

models/fcn16s.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
class FCN16VGG(nn.Module):
1010
def __init__(self, num_classes, pretrained=True):
1111
super(FCN16VGG, self).__init__()
12-
vgg = models.vgg16()
13-
if pretrained:
14-
vgg.load_state_dict(torch.load(vgg16_caffe_path))
12+
vgg = models.vgg16(pretrained=pretrained)
1513
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())
1614

1715
features[0].padding = (100, 100)

models/fcn32s.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,12 @@
33
from torchvision import models
44

55
from ..utils import get_upsampling_weight
6-
from .config import vgg16_caffe_path
76

87

98
class FCN32VGG(nn.Module):
109
def __init__(self, num_classes, pretrained=True):
1110
super(FCN32VGG, self).__init__()
12-
vgg = models.vgg16()
13-
if pretrained:
14-
vgg.load_state_dict(torch.load(vgg16_caffe_path))
11+
vgg = models.vgg16(pretrained=pretrained)
1512
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())
1613

1714
features[0].padding = (100, 100)

models/fcn8s.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,22 @@
33
from torchvision import models
44

55
from ..utils import get_upsampling_weight
6-
from .config import vgg16_path, vgg16_caffe_path
6+
from .config import vgg16_caffe_path
77

88

99
# This is implemented in full accordance with the original one (https://github.com/shelhamer/fcn.berkeleyvision.org)
1010
class FCN8s(nn.Module):
1111
def __init__(self, num_classes, pretrained=True, caffe=False):
1212
super(FCN8s, self).__init__()
13-
vgg = models.vgg16()
14-
if pretrained:
15-
if caffe:
16-
# load the pretrained vgg16 used by the paper's author
17-
vgg.load_state_dict(torch.load(vgg16_caffe_path))
18-
else:
19-
vgg.load_state_dict(torch.load(vgg16_path))
13+
14+
if pretrained and caffe:
15+
vgg = models.vgg16()
16+
# load the pretrained vgg16 used by the paper's author
17+
vgg.load_state_dict(torch.load(vgg16_caffe_path))
18+
else:
19+
# if pretrained, load the weights from PyTorch model zoo
20+
vgg = models.vgg16(pretrained=pretrained)
21+
2022
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())
2123

2224
'''

models/gcn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from torchvision import models
55

66
from ..utils import initialize_weights
7-
from .config import res152_path
87

98

109
# many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py
@@ -52,9 +51,7 @@ class GCN(nn.Module):
5251
def __init__(self, num_classes, input_size, pretrained=True):
5352
super(GCN, self).__init__()
5453
self.input_size = input_size
55-
resnet = models.resnet152()
56-
if pretrained:
57-
resnet.load_state_dict(torch.load(res152_path))
54+
resnet = models.resnet152(pretrained=pretrained)
5855
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
5956
self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1)
6057
self.layer2 = resnet.layer2

models/psp_net.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from ..utils import initialize_weights
77
from ..utils.misc import Conv2dDeformable
8-
from .config import res101_path
98

109

1110
class _PyramidPoolingModule(nn.Module):
@@ -34,9 +33,7 @@ class PSPNet(nn.Module):
3433
def __init__(self, num_classes, pretrained=True, use_aux=True):
3534
super(PSPNet, self).__init__()
3635
self.use_aux = use_aux
37-
resnet = models.resnet101()
38-
if pretrained:
39-
resnet.load_state_dict(torch.load(res101_path))
36+
resnet = models.resnet101(pretrained=pretrained)
4037
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
4138
self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
4239

models/seg_net.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from torchvision import models
44

55
from ..utils import initialize_weights
6-
from .config import vgg19_bn_path
76

87

98
class _DecoderBlock(nn.Module):
@@ -35,9 +34,7 @@ def forward(self, x):
3534
class SegNet(nn.Module):
3635
def __init__(self, num_classes, pretrained=True):
3736
super(SegNet, self).__init__()
38-
vgg = models.vgg19_bn()
39-
if pretrained:
40-
vgg.load_state_dict(torch.load(vgg19_bn_path))
37+
vgg = models.vgg19_bn(pretrained=pretrained)
4138
features = list(vgg.features.children())
4239
self.enc1 = nn.Sequential(*features[0:7])
4340
self.enc2 = nn.Sequential(*features[7:14])

0 commit comments

Comments
 (0)