Skip to content

Commit af29c94

Browse files
committed
initial commit
1 parent c970851 commit af29c94

File tree

4 files changed

+84
-100
lines changed

4 files changed

+84
-100
lines changed

.idea/vcs.xml

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

+54-15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

model.py

+16-50
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,12 @@ def __init__(self,
136136
# layers
137137
layer_list = OrderedDict()
138138
for l in range(len(self.num_features_list)):
139-
if tt.arg.inter_deactivate:
140-
layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l-1] if l > 0 else self.in_features * 2,
141-
out_channels=self.num_features_list[l],
142-
kernel_size=1,
143-
bias=False)
144-
else:
145-
layer_list['conv{}'.format(l)] = nn.Conv2d(
146-
in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
147-
out_channels=self.num_features_list[l],
148-
kernel_size=1,
149-
bias=False)
139+
140+
layer_list['conv{}'.format(l)] = nn.Conv2d(
141+
in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
142+
out_channels=self.num_features_list[l],
143+
kernel_size=1,
144+
bias=False)
150145
layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l],
151146
)
152147
layer_list['relu{}'.format(l)] = nn.LeakyReLU()
@@ -170,10 +165,7 @@ def forward(self, node_feat, edge_feat):
170165
# compute attention and aggregate
171166
aggr_feat = torch.bmm(torch.cat(torch.split(edge_feat, 1, 1), 2).squeeze(1), node_feat)
172167

173-
if tt.arg.inter_deactivate:
174-
node_feat = torch.cat([node_feat, aggr_feat.split(num_data, 1)[0]], -1).transpose(1, 2)
175-
else:
176-
node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)
168+
node_feat = torch.cat([node_feat, torch.cat(aggr_feat.split(num_data, 1), -1)], -1).transpose(1, 2)
177169

178170
# non-linear transform
179171
node_feat = self.network(node_feat.unsqueeze(-1)).transpose(1, 2).squeeze(-1)
@@ -251,41 +243,15 @@ def forward(self, node_feat, edge_feat):
251243
dsim_val = 1.0 - sim_val
252244

253245

254-
if tt.arg.inter_deactivte:
255-
diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 1, 1,
256-
1).to(tt.arg.device)
257-
edge_feat = edge_feat * diag_mask.detach()
258-
259-
edge_feat_1 = edge_feat
260-
edge_feat_2 = 1 - edge_feat
261-
262-
merge_sum_1 = torch.sum(edge_feat_1, -1, True)
263-
edge_feat_1 = F.normalize(sim_val * edge_feat_1, p=1, dim=-1) * merge_sum_1
264-
265-
merge_sum_2 = torch.sum(edge_feat_2, -1, True)
266-
edge_feat_2 = F.normalize(dsim_val * edge_feat_2, p=1, dim=-1) * merge_sum_2
267-
268-
edge_feat = torch.cat([edge_feat_1, edge_feat_2], 1)
269-
270-
force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0),
271-
torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)),
272-
0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
273-
274-
edge_feat = edge_feat + force_edge_feat.detach()
275-
276-
edge_feat = edge_feat + 1e-6
277-
edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)
278-
279-
else:
280-
diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)
281-
edge_feat = edge_feat * diag_mask
282-
merge_sum = torch.sum(edge_feat, -1, True)
283-
# set diagonal as zero and normalize
284-
edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum
285-
force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
286-
edge_feat = edge_feat + force_edge_feat
287-
edge_feat = edge_feat + 1e-6
288-
edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)
246+
diag_mask = 1.0 - torch.eye(node_feat.size(1)).unsqueeze(0).unsqueeze(0).repeat(node_feat.size(0), 2, 1, 1).to(tt.arg.device)
247+
edge_feat = edge_feat * diag_mask
248+
merge_sum = torch.sum(edge_feat, -1, True)
249+
# set diagonal as zero and normalize
250+
edge_feat = F.normalize(torch.cat([sim_val, dsim_val], 1) * edge_feat, p=1, dim=-1) * merge_sum
251+
force_edge_feat = torch.cat((torch.eye(node_feat.size(1)).unsqueeze(0), torch.zeros(node_feat.size(1), node_feat.size(1)).unsqueeze(0)), 0).unsqueeze(0).repeat(node_feat.size(0), 1, 1, 1).to(tt.arg.device)
252+
edge_feat = edge_feat + force_edge_feat
253+
edge_feat = edge_feat + 1e-6
254+
edge_feat = edge_feat / torch.sum(edge_feat, dim=1).unsqueeze(1).repeat(1, 2, 1, 1)
289255

290256
return edge_feat
291257

0 commit comments

Comments
 (0)