Skip to content

Commit

Permalink
minor edit: attnwalk
Browse files Browse the repository at this point in the history
- disable l1 emb reg by default (not in the orig paper)
- use to_scipy_sparse_array instead of adjacency_matrix to suppress nx warning
- improve organization of loss computation
  • Loading branch information
RemyLau committed Aug 6, 2023
1 parent 3ecd545 commit 2588bc4
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/obnb/ext/attnwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def attnwalk_embed(
walk_length: int = 80,
window_size: int = 5,
beta: float = 0.5,
gamma: float = 0.5,
gamma: float = 0.0,
epochs: int = 200,
lr: float = 0.01,
verbose: bool = False,
Expand All @@ -50,7 +50,7 @@ def attnwalk_embed(
walk_length: Random walk length.
window_size: Window size.
beta: Attention l2 regularization parameter.
gamma: Embedding l2 regularization parameter.
gamma: Embedding l1 regularization parameter.
epochs: Training epochs.
lr: Learning rate.
device: Compute device.
Expand Down Expand Up @@ -153,7 +153,7 @@ def prepare_features(self):
where the first dimension is the powering number.
"""
adj_mat = nx.adjacency_matrix(self.g, dtype=np.float32).toarray()
adj_mat = nx.to_scipy_sparse_array(self.g, dtype=np.float32).toarray()

inv_sqrt_degs = 1 / np.sqrt(adj_mat.sum(0, keepdims=True))
norm_adj_mat = adj_mat * inv_sqrt_degs * inv_sqrt_degs.T
Expand Down Expand Up @@ -192,14 +192,17 @@ def forward(self):
target_mat = (target_tensor * attn).sum(0)

pred = torch.mm(self.left, self.right).sigmoid().clamp(EPS, 1 - EPS)
pos_loss = -(target_mat * torch.log(pred))
neg_loss = -(adj_opposite * torch.log(1 - pred))
nlgl = (self.walk_length * target_mat.shape[0] * pos_loss + neg_loss).mean()
pos_loss = -(target_mat * torch.log(pred)).mean()
neg_loss = -(adj_opposite * torch.log(1 - pred)).mean()
nlgl = self.walk_length * target_mat.shape[0] * pos_loss + neg_loss

attn_reg = self.beta * self.attn_weights.norm(2).pow(2)
emb_reg = self.gamma * (self.left.abs().mean() + self.right.abs().mean())
attn_reg = self.attn_weights.norm(2).pow(2)

loss = nlgl + attn_reg + emb_reg
if self.gamma > 0:
loss = nlgl + self.beta * attn_reg
else:
emb_reg = self.left.abs().mean() + self.right.abs().mean()
loss = nlgl + self.beta * attn_reg + self.gamma * emb_reg

return loss

Expand Down

0 comments on commit 2588bc4

Please sign in to comment.