Skip to content

Commit fa95b25

Browse files
suquarkmingyuliutw
authored andcommitted
Add a converter which shows how to convert old torch models into pytorch models.
1 parent 6aa8f90 commit fa95b25

File tree

1 file changed

+130
-0
lines changed

1 file changed

+130
-0
lines changed

converter.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import os
2+
3+
import torch
4+
import torch.nn as nn
5+
from torch.utils.serialization import load_lua
6+
7+
from models import VGGEncoder, VGGDecoder
8+
from photo_wct import PhotoWCT
9+
10+
11+
def weight_assign(lua, pth, maps):
12+
for k, v in maps.items():
13+
getattr(pth, k).weight = nn.Parameter(lua.get(v).weight.float())
14+
getattr(pth, k).bias = nn.Parameter(lua.get(v).bias.float())
15+
16+
17+
def photo_wct_loader(p_wct):
18+
p_wct.e1.load_state_dict(torch.load('pth_models/vgg_normalised_conv1.pth'))
19+
p_wct.d1.load_state_dict(torch.load('pth_models/feature_invertor_conv1.pth'))
20+
p_wct.e2.load_state_dict(torch.load('pth_models/vgg_normalised_conv2.pth'))
21+
p_wct.d2.load_state_dict(torch.load('pth_models/feature_invertor_conv2.pth'))
22+
p_wct.e3.load_state_dict(torch.load('pth_models/vgg_normalised_conv3.pth'))
23+
p_wct.d3.load_state_dict(torch.load('pth_models/feature_invertor_conv3.pth'))
24+
p_wct.e4.load_state_dict(torch.load('pth_models/vgg_normalised_conv4.pth'))
25+
p_wct.d4.load_state_dict(torch.load('pth_models/feature_invertor_conv4.pth'))
26+
27+
28+
if __name__ == '__main__':
29+
if not os.path.exists('pth_models'):
30+
os.mkdir('pth_models')
31+
32+
## VGGEncoder1
33+
vgg1 = load_lua('models/vgg_normalised_conv1_1_mask.t7')
34+
e1 = VGGEncoder(1)
35+
weight_assign(vgg1, e1, {
36+
'conv0': 0,
37+
'conv1_1': 2,
38+
})
39+
torch.save(e1.state_dict(), 'pth_models/vgg_normalised_conv1.pth')
40+
41+
## VGGDecoder1
42+
inv1 = load_lua('models/feature_invertor_conv1_1_mask.t7')
43+
d1 = VGGDecoder(1)
44+
weight_assign(inv1, d1, {
45+
'conv1_1': 1,
46+
})
47+
torch.save(d1.state_dict(), 'pth_models/feature_invertor_conv1.pth')
48+
49+
## VGGEncoder2
50+
vgg2 = load_lua('models/vgg_normalised_conv2_1_mask.t7')
51+
e2 = VGGEncoder(2)
52+
weight_assign(vgg2, e2, {
53+
'conv0': 0,
54+
'conv1_1': 2,
55+
'conv1_2': 5,
56+
'conv2_1': 9,
57+
})
58+
torch.save(e2.state_dict(), 'pth_models/vgg_normalised_conv2.pth')
59+
60+
## VGGDecoder2
61+
inv2 = load_lua('models/feature_invertor_conv2_1_mask.t7')
62+
d2 = VGGDecoder(2)
63+
weight_assign(inv2, d2, {
64+
'conv2_1': 1,
65+
'conv1_2': 5,
66+
'conv1_1': 8,
67+
})
68+
torch.save(d2.state_dict(), 'pth_models/feature_invertor_conv2.pth')
69+
70+
## VGGEncoder3
71+
vgg3 = load_lua('models/vgg_normalised_conv3_1_mask.t7')
72+
e3 = VGGEncoder(3)
73+
weight_assign(vgg3, e3, {
74+
'conv0': 0,
75+
'conv1_1': 2,
76+
'conv1_2': 5,
77+
'conv2_1': 9,
78+
'conv2_2': 12,
79+
'conv3_1': 16,
80+
})
81+
torch.save(e3.state_dict(), 'pth_models/vgg_normalised_conv3.pth')
82+
83+
## VGGDecoder3
84+
inv3 = load_lua('models/feature_invertor_conv3_1_mask.t7')
85+
d3 = VGGDecoder(3)
86+
weight_assign(inv3, d3, {
87+
'conv3_1': 1,
88+
'conv2_2': 5,
89+
'conv2_1': 8,
90+
'conv1_2': 12,
91+
'conv1_1': 15,
92+
})
93+
torch.save(d3.state_dict(), 'pth_models/feature_invertor_conv3.pth')
94+
95+
## VGGEncoder4
96+
vgg4 = load_lua('models/vgg_normalised_conv4_1_mask.t7')
97+
e4 = VGGEncoder(4)
98+
weight_assign(vgg4, e4, {
99+
'conv0': 0,
100+
'conv1_1': 2,
101+
'conv1_2': 5,
102+
'conv2_1': 9,
103+
'conv2_2': 12,
104+
'conv3_1': 16,
105+
'conv3_2': 19,
106+
'conv3_3': 22,
107+
'conv3_4': 25,
108+
'conv4_1': 29,
109+
})
110+
torch.save(e4.state_dict(), 'pth_models/vgg_normalised_conv4.pth')
111+
112+
## VGGDecoder4
113+
inv4 = load_lua('models/feature_invertor_conv4_1_mask.t7')
114+
d4 = VGGDecoder(4)
115+
weight_assign(inv4, d4, {
116+
'conv4_1': 1,
117+
'conv3_4': 5,
118+
'conv3_3': 8,
119+
'conv3_2': 11,
120+
'conv3_1': 14,
121+
'conv2_2': 18,
122+
'conv2_1': 21,
123+
'conv1_2': 25,
124+
'conv1_1': 28,
125+
})
126+
torch.save(d4.state_dict(), 'pth_models/feature_invertor_conv4.pth')
127+
128+
p_wct = PhotoWCT()
129+
photo_wct_loader(p_wct)
130+
torch.save(p_wct.state_dict(), 'PhotoWCTModels/photo_wct.pth')

0 commit comments

Comments
 (0)