@@ -136,17 +136,12 @@ def __init__(self,
136
136
# layers
137
137
layer_list = OrderedDict ()
138
138
for l in range (len (self .num_features_list )):
139
- if tt .arg .inter_deactivate :
140
- layer_list ['conv{}' .format (l )] = nn .Conv2d (in_channels = self .num_features_list [l - 1 ] if l > 0 else self .in_features * 2 ,
141
- out_channels = self .num_features_list [l ],
142
- kernel_size = 1 ,
143
- bias = False )
144
- else :
145
- layer_list ['conv{}' .format (l )] = nn .Conv2d (
146
- in_channels = self .num_features_list [l - 1 ] if l > 0 else self .in_features * 3 ,
147
- out_channels = self .num_features_list [l ],
148
- kernel_size = 1 ,
149
- bias = False )
139
+
140
+ layer_list ['conv{}' .format (l )] = nn .Conv2d (
141
+ in_channels = self .num_features_list [l - 1 ] if l > 0 else self .in_features * 3 ,
142
+ out_channels = self .num_features_list [l ],
143
+ kernel_size = 1 ,
144
+ bias = False )
150
145
layer_list ['norm{}' .format (l )] = nn .BatchNorm2d (num_features = self .num_features_list [l ],
151
146
)
152
147
layer_list ['relu{}' .format (l )] = nn .LeakyReLU ()
@@ -170,10 +165,7 @@ def forward(self, node_feat, edge_feat):
170
165
# compute attention and aggregate
171
166
aggr_feat = torch .bmm (torch .cat (torch .split (edge_feat , 1 , 1 ), 2 ).squeeze (1 ), node_feat )
172
167
173
- if tt .arg .inter_deactivate :
174
- node_feat = torch .cat ([node_feat , aggr_feat .split (num_data , 1 )[0 ]], - 1 ).transpose (1 , 2 )
175
- else :
176
- node_feat = torch .cat ([node_feat , torch .cat (aggr_feat .split (num_data , 1 ), - 1 )], - 1 ).transpose (1 , 2 )
168
+ node_feat = torch .cat ([node_feat , torch .cat (aggr_feat .split (num_data , 1 ), - 1 )], - 1 ).transpose (1 , 2 )
177
169
178
170
# non-linear transform
179
171
node_feat = self .network (node_feat .unsqueeze (- 1 )).transpose (1 , 2 ).squeeze (- 1 )
@@ -251,41 +243,15 @@ def forward(self, node_feat, edge_feat):
251
243
dsim_val = 1.0 - sim_val
252
244
253
245
254
- if tt .arg .inter_deactivte :
255
- diag_mask = 1.0 - torch .eye (node_feat .size (1 )).unsqueeze (0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 1 , 1 ,
256
- 1 ).to (tt .arg .device )
257
- edge_feat = edge_feat * diag_mask .detach ()
258
-
259
- edge_feat_1 = edge_feat
260
- edge_feat_2 = 1 - edge_feat
261
-
262
- merge_sum_1 = torch .sum (edge_feat_1 , - 1 , True )
263
- edge_feat_1 = F .normalize (sim_val * edge_feat_1 , p = 1 , dim = - 1 ) * merge_sum_1
264
-
265
- merge_sum_2 = torch .sum (edge_feat_2 , - 1 , True )
266
- edge_feat_2 = F .normalize (dsim_val * edge_feat_2 , p = 1 , dim = - 1 ) * merge_sum_2
267
-
268
- edge_feat = torch .cat ([edge_feat_1 , edge_feat_2 ], 1 )
269
-
270
- force_edge_feat = torch .cat ((torch .eye (node_feat .size (1 )).unsqueeze (0 ),
271
- torch .zeros (node_feat .size (1 ), node_feat .size (1 )).unsqueeze (0 )),
272
- 0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 1 , 1 , 1 ).to (tt .arg .device )
273
-
274
- edge_feat = edge_feat + force_edge_feat .detach ()
275
-
276
- edge_feat = edge_feat + 1e-6
277
- edge_feat = edge_feat / torch .sum (edge_feat , dim = 1 ).unsqueeze (1 ).repeat (1 , 2 , 1 , 1 )
278
-
279
- else :
280
- diag_mask = 1.0 - torch .eye (node_feat .size (1 )).unsqueeze (0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 2 , 1 , 1 ).to (tt .arg .device )
281
- edge_feat = edge_feat * diag_mask
282
- merge_sum = torch .sum (edge_feat , - 1 , True )
283
- # set diagonal as zero and normalize
284
- edge_feat = F .normalize (torch .cat ([sim_val , dsim_val ], 1 ) * edge_feat , p = 1 , dim = - 1 ) * merge_sum
285
- force_edge_feat = torch .cat ((torch .eye (node_feat .size (1 )).unsqueeze (0 ), torch .zeros (node_feat .size (1 ), node_feat .size (1 )).unsqueeze (0 )), 0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 1 , 1 , 1 ).to (tt .arg .device )
286
- edge_feat = edge_feat + force_edge_feat
287
- edge_feat = edge_feat + 1e-6
288
- edge_feat = edge_feat / torch .sum (edge_feat , dim = 1 ).unsqueeze (1 ).repeat (1 , 2 , 1 , 1 )
246
+ diag_mask = 1.0 - torch .eye (node_feat .size (1 )).unsqueeze (0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 2 , 1 , 1 ).to (tt .arg .device )
247
+ edge_feat = edge_feat * diag_mask
248
+ merge_sum = torch .sum (edge_feat , - 1 , True )
249
+ # set diagonal as zero and normalize
250
+ edge_feat = F .normalize (torch .cat ([sim_val , dsim_val ], 1 ) * edge_feat , p = 1 , dim = - 1 ) * merge_sum
251
+ force_edge_feat = torch .cat ((torch .eye (node_feat .size (1 )).unsqueeze (0 ), torch .zeros (node_feat .size (1 ), node_feat .size (1 )).unsqueeze (0 )), 0 ).unsqueeze (0 ).repeat (node_feat .size (0 ), 1 , 1 , 1 ).to (tt .arg .device )
252
+ edge_feat = edge_feat + force_edge_feat
253
+ edge_feat = edge_feat + 1e-6
254
+ edge_feat = edge_feat / torch .sum (edge_feat , dim = 1 ).unsqueeze (1 ).repeat (1 , 2 , 1 , 1 )
289
255
290
256
return edge_feat
291
257
0 commit comments