#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017 - zihao.chen <[email protected]>
'''
Author: zihao.chen
Create Date: 2018-04-20
Modify Date: 2018-04-20
description: ""
'''

import torch.nn as nn
# Variable has been a no-op alias of Tensor since PyTorch 0.4; kept for fidelity.
from torch.autograd import Variable
import torch


def weights_init(m):
    # DCGAN-style initialization: N(0, 0.02) for conv weights,
    # N(1.0, 0.02) for BatchNorm weights and zeros for BatchNorm biases.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class CLSTM_cell(nn.Module):
    """A basic ConvLSTM cell.

    Args:
        shape: int tuple, the height and width of the hidden states h and c
        input_chans: int, the number of channels of the input
        filter_size: int, the height and width of the filters
        num_features: int, the number of channels of the states (like hidden_size)
    """

    def __init__(self, shape, input_chans, filter_size, num_features):
        super(CLSTM_cell, self).__init__()

        self.shape = shape  # H, W
        self.input_chans = input_chans
        self.filter_size = filter_size
        self.num_features = num_features
        self.dropout = nn.Dropout(p=0.5)
        # self.batch_size = batch_size
        self.padding = (filter_size - 1) // 2  # keeps the output the same spatial size (for odd filter_size)
        # a single convolution produces the pre-activations of all four gates
        self.conv = nn.Conv2d(self.input_chans + self.num_features, 4 * self.num_features, self.filter_size, 1,
                              self.padding)

    def forward(self, input, hidden_state):
        hidden, c = hidden_state  # hidden and c are feature maps with several channels
        combined = torch.cat((input, hidden), 1)  # concatenate along the channel dimension
        A = self.conv(combined)
        ai, af, ao, ag = torch.split(A, self.num_features, dim=1)  # the four gate pre-activations
        i = torch.sigmoid(ai)
        i = self.dropout(i)
        f = torch.sigmoid(af)
        f = self.dropout(f)
        o = torch.sigmoid(ao)
        o = self.dropout(o)
        g = torch.tanh(ag)
        g = self.dropout(g)

        next_c = f * c + i * g
        next_h = o * torch.tanh(next_c)
        next_h = self.dropout(next_h)
        return next_h, (next_h, next_c)

    def init_hidden(self, batch_size):
        # zero-initialized (h, c) state; assumes a CUDA device is available
        return (Variable(torch.zeros(batch_size, self.num_features, self.shape[0], self.shape[1])).cuda(),
                Variable(torch.zeros(batch_size, self.num_features, self.shape[0], self.shape[1])).cuda())
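

# A quick smoke test for a single CLSTM_cell. This helper is an added sketch
# (not part of the original file); it assumes a CUDA device, and the sizes
# below are arbitrary.
def _demo_single_cell():
    cell = CLSTM_cell(shape=(16, 16), input_chans=3, filter_size=5, num_features=8).cuda()
    h_c = cell.init_hidden(batch_size=2)  # zero-initialized (h, c)
    x = torch.rand(2, 3, 16, 16).cuda()   # one timestep of input
    out, h_c = cell(x, h_c)               # out is the new hidden state h
    print('single-cell output', out.size())  # -> (2, 8, 16, 16)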


class MultiConvRNNCell(nn.Module):
    """Stacks several ConvLSTM cells and runs all of them for one timestep."""

    def __init__(self, cells, state_is_tuple=True):
        super(MultiConvRNNCell, self).__init__()
        self._cells = cells
        self._state_is_tuple = state_is_tuple

    def init_hidden(self, batch_size):
        init_states = []  # a list of (h, c) tuples, one per cell
        for i in range(len(self._cells)):
            init_states.append(self._cells[i].init_hidden(batch_size))
        return init_states

    def forward(self, input, hidden_state):
        cur_inp = input
        new_states = []
        for i, cell in enumerate(self._cells):
            cur_state = hidden_state[i]
            # the output of each cell is the input of the next one in the stack
            cur_inp, new_state = cell(cur_inp, cur_state)
            new_states.append(new_state)

        new_states = tuple(new_states)
        return cur_inp, new_states


class CLSTM(nn.Module):
    """A multi-layer ConvLSTM.

    Args:
        shape: int tuple, the height and width of the hidden states h and c
        input_chans: int, the number of channels of the input
        filter_size: int, the height and width of the filters
        num_features: int, the number of channels of the states (like hidden_size)
        num_layers: int, the number of stacked ConvLSTM layers
    """

    def __init__(self, shape, input_chans, filter_size, num_features, num_layers):
        super(CLSTM, self).__init__()

        self.shape = shape  # H, W
        self.input_chans = input_chans
        self.filter_size = filter_size
        self.num_features = num_features
        self.num_layers = num_layers
        cell_list = []
        # the first cell has a different number of input channels
        cell_list.append(
            CLSTM_cell(self.shape, self.input_chans, self.filter_size, self.num_features).cuda())

        for idcell in range(1, self.num_layers):
            cell_list.append(CLSTM_cell(self.shape, self.num_features, self.filter_size, self.num_features).cuda())
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, input, hidden_state):
        """
        Args:
            hidden_state: list of tuples, one per layer; each tuple is (hidden_layer_i, c_layer_i)
            input: tensor of shape (seq_len, batch, chans, H, W)
        """

        # current_input = input.transpose(0, 1)  # now is seq_len,B,C,H,W
        current_input = input
        next_hidden = []  # the final (h, c) states, one tuple per layer
        seq_len = current_input.size(0)

        for idlayer in range(self.num_layers):  # loop over layers

            hidden_c = hidden_state[idlayer]  # (h, c) feature maps for this layer
            output_inner = []
            for t in range(seq_len):  # loop over timesteps
                # cell_list holds one ConvLSTM cell per layer
                out, hidden_c = self.cell_list[idlayer](current_input[t, ...], hidden_c)
                output_inner.append(out)

            next_hidden.append(hidden_c)
            # stack the per-step outputs back into (seq_len, B, chans, H, W);
            # they become the input sequence of the next layer
            current_input = torch.cat(output_inner, 0).view(seq_len, *output_inner[0].size())
        return next_hidden, current_input

    def init_hidden(self, batch_size):
        init_states = []  # a list of (h, c) tuples, one per layer
        for i in range(self.num_layers):
            init_states.append(self.cell_list[i].init_hidden(batch_size))
        return init_states


if __name__ == '__main__':

    ########### Usage #######################################
    num_features = 64
    filter_size = 5
    batch_size = 8
    shape = (120, 120)  # H, W
    inp_chans = 3
    nlayers = 2
    seq_len = 10

    # input is already (seq_len, batch, chans, H, W); if it were batch-first,
    # we would need to transpose it in CLSTM.forward
    input = Variable(torch.rand(seq_len, batch_size, inp_chans, shape[0], shape[1])).cuda()

    conv_lstm = CLSTM(shape, inp_chans, filter_size, num_features, nlayers)
    conv_lstm.apply(weights_init)
    conv_lstm.cuda()

    print('convlstm module:', conv_lstm)

    print('params:')
    params = conv_lstm.parameters()
    for p in params:
        print('param ', p.size())
        print('mean ', torch.mean(p))

    hidden_state = conv_lstm.init_hidden(batch_size)
    print('hidden_h shape ', len(hidden_state))
    print('hidden_h shape ', hidden_state[0][0].size())
    out = conv_lstm(input, hidden_state)
    print('out shape', out[1].size())
    print('len hidden ', len(out[0]))
    print('next hidden', out[0][0][0].size())
    print('convlstm dict', conv_lstm.state_dict().keys())

    # L = torch.sum(out[1])
    # L.backward()
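
    # MultiConvRNNCell is not exercised above; the lines below are a minimal
    # sketch (an addition for illustration, not from the original file) that
    # drives a two-cell stack one timestep at a time, reusing the settings above.
    cells = nn.ModuleList([
        CLSTM_cell(shape, inp_chans, filter_size, num_features).cuda(),
        CLSTM_cell(shape, num_features, filter_size, num_features).cuda(),
    ])
    stacked = MultiConvRNNCell(cells)
    state = stacked.init_hidden(batch_size)
    for t in range(seq_len):
        out_t, state = stacked(input[t], state)
    print('stacked output shape', out_t.size())  # (batch_size, num_features, H, W)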