
Commit 85e71b8

add ConvLstm

Author: zihao.chen_tp
1 parent e14aa13 · commit 85e71b8

File tree

2 files changed: +257 −7 lines changed

ConvLSTM.py (new file, +204)

@@ -0,0 +1,204 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017 - zihao.chen <[email protected]>
'''
Author: zihao.chen
Create Date: 2018-04-20
Modify Date: 2018-04-20
description: ""
'''

import torch.nn as nn
from torch.autograd import Variable
import torch


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class CLSTM_cell(nn.Module):
    """A basic ConvLSTM cell.

    Args:
        shape: int tuple, the height and width (H, W) of the hidden states h and c
        input_chans: int, the number of channels of the input
        filter_size: int, the height and width of the filters
        num_features: int, the number of channels of the states, like hidden_size
    """

    def __init__(self, shape, input_chans, filter_size, num_features):
        super(CLSTM_cell, self).__init__()

        self.shape = shape  # H, W
        self.input_chans = input_chans
        self.filter_size = filter_size
        self.num_features = num_features
        self.dropout = nn.Dropout(p=0.5)
        # integer division; this padding keeps the output the same size as the input
        self.padding = (filter_size - 1) // 2
        # a single convolution produces all four gates: 4 * num_features output channels
        self.conv = nn.Conv2d(self.input_chans + self.num_features, 4 * self.num_features, self.filter_size, 1,
                              self.padding)

    def forward(self, input, hidden_state):
        hidden, c = hidden_state  # hidden and c are images with several channels
        combined = torch.cat((input, hidden), 1)  # concatenate along the channel dimension
        A = self.conv(combined)
        (ai, af, ao, ag) = torch.split(A, self.num_features, dim=1)  # returns 4 tensors
        i = torch.sigmoid(ai)
        i = self.dropout(i)
        f = torch.sigmoid(af)
        f = self.dropout(f)
        o = torch.sigmoid(ao)
        o = self.dropout(o)
        g = torch.tanh(ag)
        g = self.dropout(g)

        next_c = f * c + i * g
        next_h = o * torch.tanh(next_c)
        next_h = self.dropout(next_h)
        # returns (output, state) so the cell also works inside MultiConvRNNCell
        return next_h, (next_h, next_c)

    def init_hidden(self, batch_size):
        # the zero states are allocated directly on the GPU
        return (Variable(torch.zeros(batch_size, self.num_features, self.shape[0], self.shape[1])).cuda(),
                Variable(torch.zeros(batch_size, self.num_features, self.shape[0], self.shape[1])).cuda())
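For reference, the gate arithmetic in forward is the standard ConvLSTM update, with the single convolution split four ways (\ast denotes convolution, \circ the elementwise product); the dropout applied to each gate and to h_t is an addition of this implementation, not part of the standard formulation:

\begin{aligned}
i_t &= \sigma(W_i \ast [x_t, h_{t-1}] + b_i) \\
f_t &= \sigma(W_f \ast [x_t, h_{t-1}] + b_f) \\
o_t &= \sigma(W_o \ast [x_t, h_{t-1}] + b_o) \\
g_t &= \tanh(W_g \ast [x_t, h_{t-1}] + b_g) \\
c_t &= f_t \circ c_{t-1} + i_t \circ g_t \\
h_t &= o_t \circ \tanh(c_t)
\end{aligned}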

class MultiConvRNNCell(nn.Module):
    def __init__(self, cells, state_is_tuple=True):
        super(MultiConvRNNCell, self).__init__()
        # an nn.ModuleList (rather than a plain list) so the cells' parameters
        # are registered with this module
        self._cells = nn.ModuleList(cells)
        self._state_is_tuple = state_is_tuple

    def init_hidden(self, batch_size):
        init_states = []  # this is a list of tuples
        for i in xrange(len(self._cells)):
            init_states.append(self._cells[i].init_hidden(batch_size))
        return init_states

    def forward(self, input, hidden_state):
        cur_inp = input
        new_states = []
        for i, cell in enumerate(self._cells):
            cur_state = hidden_state[i]
            # the output of each layer becomes the input to the next
            cur_inp, new_state = cell(cur_inp, cur_state)
            new_states.append(new_state)

        new_states = tuple(new_states)
        return cur_inp, new_states
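A minimal sketch of driving the stacked cell for one timestep (the sizes here are illustrative, not from the commit; a CUDA device is assumed because init_hidden allocates the states on the GPU):

# illustrative sizes only; any (H, W) and channel counts work the same way
cell_a = CLSTM_cell(shape=(64, 64), input_chans=3, filter_size=5, num_features=32).cuda()
cell_b = CLSTM_cell(shape=(64, 64), input_chans=32, filter_size=5, num_features=32).cuda()
stacked = MultiConvRNNCell([cell_a, cell_b])

state = stacked.init_hidden(batch_size=2)       # one (h, c) tuple per layer
x = Variable(torch.rand(2, 3, 64, 64)).cuda()   # a single frame: B, C, H, W
out, state = stacked(x, state)                  # out: (2, 32, 64, 64), same H, W as the input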

class CLSTM(nn.Module):
    """A multi-layer ConvLSTM.

    Args:
        shape: int tuple, the height and width (H, W) of the hidden states h and c
        input_chans: int, the number of channels of the input
        filter_size: int, the height and width of the filters
        num_features: int, the number of channels of the states, like hidden_size
        num_layers: int, the number of stacked CLSTM_cell layers
    """

    def __init__(self, shape, input_chans, filter_size, num_features, num_layers):
        super(CLSTM, self).__init__()

        self.shape = shape  # H, W
        self.input_chans = input_chans
        self.filter_size = filter_size
        self.num_features = num_features
        self.num_layers = num_layers
        cell_list = []
        # the first cell has a different number of input channels
        cell_list.append(
            CLSTM_cell(self.shape, self.input_chans, self.filter_size, self.num_features).cuda())

        for idcell in xrange(1, self.num_layers):
            cell_list.append(CLSTM_cell(self.shape, self.num_features, self.filter_size, self.num_features).cuda())
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input, hidden_state):
        """
        args:
            hidden_state: list of tuples, one for every layer, each tuple is (hidden_layer_i, c_layer_i)
            input: tensor of shape seq_len, Batch, Chans, H, W
        """

        # current_input = input.transpose(0, 1)  # only needed if input were B, seq_len, C, H, W
        current_input = input
        next_hidden = []  # hidden states (h and c)
        seq_len = current_input.size(0)

        for idlayer in xrange(self.num_layers):  # loop over layers

            hidden_c = hidden_state[idlayer]  # hidden and c are images with several channels
            output_inner = []
            for t in xrange(seq_len):  # loop over timesteps
                # each cell returns (output, (h, c)); unpack both so the next step
                # receives a proper (h, c) tuple rather than the full return value
                output, hidden_c = self.cell_list[idlayer](current_input[t, ...], hidden_c)
                output_inner.append(output)

            next_hidden.append(hidden_c)
            current_input = torch.cat(output_inner, 0).view(current_input.size(0),
                                                            *output_inner[0].size())  # seq_len,B,chans,H,W

        return next_hidden, current_input

    def init_hidden(self, batch_size):
        init_states = []  # this is a list of tuples
        for i in xrange(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states

if __name__ == '__main__':

    ########### Usage #######################################
    num_features = 64
    filter_size = 5
    batch_size = 8
    shape = (120, 120)  # H, W
    inp_chans = 3
    nlayers = 2
    seq_len = 10

    # input is already seq_len, B, C, H, W, so no transpose is needed in CLSTM
    input = Variable(torch.rand(seq_len, batch_size, inp_chans, shape[0], shape[1])).cuda()

    conv_lstm = CLSTM(shape, inp_chans, filter_size, num_features, nlayers)
    conv_lstm.apply(weights_init)
    conv_lstm.cuda()

    print 'convlstm module:', conv_lstm

    print 'params:'
    params = conv_lstm.parameters()
    for p in params:
        print 'param ', p.size()
        print 'mean ', torch.mean(p)

    hidden_state = conv_lstm.init_hidden(batch_size)
    print 'len hidden ', len(hidden_state)
    print 'hidden_h shape ', hidden_state[0][0].size()
    out = conv_lstm(input, hidden_state)
    print 'out shape', out[1].size()
    print 'len hidden ', len(out[0])
    print 'next hidden', out[0][0][0].size()
    print 'convlstm dict', conv_lstm.state_dict().keys()

    # L = torch.sum(out[1])
    # L.backward()
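With these settings, out[1] (the top layer's output sequence) has shape (seq_len, batch_size, num_features, H, W) = (10, 8, 64, 120, 120), and out[0] is a list of nlayers (h, c) tuples whose tensors are each (8, 64, 120, 120).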

RNN.py (+53 −7)

@@ -54,26 +54,72 @@ def forward(self,data):
         return data
 
 
+class RNNConvLSTM(nn.Module):
+    def __init__(self, inplanes, input_num_seqs, output_num_seqs, shape):
+        super(RNNConvLSTM, self).__init__()
+        self.inplanes = inplanes
+        self.input_num_seqs = input_num_seqs
+        self.output_num_seqs = output_num_seqs
+        self.shape = (shape, shape)
+        num_filter = 84
+        kernel_size = 7
+
+        self.cell1 = CLSTM_cell(self.shape, self.inplanes, kernel_size, num_filter)
+        self.cell2 = CLSTM_cell(self.shape, num_filter, kernel_size, num_filter)
+
+        self.stacked_lstm = MultiConvRNNCell([self.cell1, self.cell2])
+
+        # 3x3 transposed conv with stride 1 and padding 1 keeps H, W and maps
+        # the hidden channels back to a single output channel
+        self.deconv1 = nn.ConvTranspose2d(num_filter, out_channels=1, kernel_size=3, stride=1, padding=1,
+                                          bias=True)
+
+    def forward(self, data):
+        # data is seq_len, B, C, H, W; size()[1] is the batch size
+        new_state = self.stacked_lstm.init_hidden(data.size()[1])
+        x_unwrap = []
+        # encoder-decoder loop: feed the real frames for the first
+        # input_num_seqs steps, then feed the model's own previous
+        # prediction back in for the output steps
+        for i in xrange(self.input_num_seqs + self.output_num_seqs):
+            if i < self.input_num_seqs:
+                y_1, new_state = self.stacked_lstm(data[i], new_state)
+            else:
+                y_1, new_state = self.stacked_lstm(x_1, new_state)
+            x_1 = self.deconv1(y_1)
+            if i >= self.input_num_seqs:
+                x_unwrap.append(x_1)
+
+        return x_unwrap
+
+
 def test(num_seqs, channels_img, size_image, max_epoch, model, cuda_test):
-    input_image = torch.rand(num_seqs, 8, channels_img, size_image, size_image)
+    input_image = torch.rand(num_seqs, 4, channels_img, size_image, size_image)
     input_gru = Variable(input_image.cuda())
     MSE_criterion = nn.MSELoss()
-    for time in xrange(num_seqs):
-        h_next = model(input_gru[time])
+    model = model.cuda()
+    model.train()
+    model(input_gru)
 
 
 if __name__ == '__main__':
     num_seqs = 10
     hidden_size = 3
-    channels_img = 2
+    channels_img = 1
     size_image = 120
     max_epoch = 10
     cuda_flag = False
     kernel_size = 3
-    rcg = RNNCovnGRU(inplanes=2, input_num_seqs=10,output_num_seqs=10)
+    rcg = RNNConvLSTM(inplanes=1, input_num_seqs=10, output_num_seqs=10, shape=size_image)
     print(rcg)
-    rcg = rcg.cuda()
-    test(num_seqs,channels_img,size_image,max_epoch,rcg,cuda_flag)
+    test(num_seqs, channels_img, size_image, max_epoch, rcg, cuda_flag)
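Note that test constructs MSE_criterion but never applies it, so the smoke test runs no backward pass. A minimal training step consistent with these shapes might look like the sketch below; the targets list and the SGD optimizer are illustrative assumptions, not part of the commit:

# hypothetical training step, placed inside test after model(input_gru);
# `targets` is assumed: a list of output_num_seqs ground-truth frames,
# each a Variable of shape (batch, 1, size_image, size_image) on the GPU
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
preds = model(input_gru)                   # list of output_num_seqs predicted frames
loss = MSE_criterion(preds[0], targets[0])
for pred, target in zip(preds[1:], targets[1:]):
    loss = loss + MSE_criterion(pred, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()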