Issues with implementation #2

Closed
felipemello1 opened this issue Feb 17, 2022 · 4 comments

felipemello1 commented Feb 17, 2022

Hi all, I tried to implement the info loss in my own GNN. I am using a custom convolution on a custom dataset that might have leakage, so that might be the source of error, but I am still trying to understand why the model behaves the way it does. I would appreciate any ideas/feedback.

My model does link prediction on small subgraphs: for each edge I want to predict, I sample a subgraph around it.

I am implementing the info_loss just like in your code:
info_loss = (edge_att * torch.log(edge_att/r + 1e-6) + (1-edge_att) * torch.log((1-edge_att)/(1-r+1e-6) + 1e-6)).mean()
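
As I understand it, this term is the KL divergence between Bern(edge_att) and Bern(r), averaged over edges. Wrapped as a function, it would look roughly like the sketch below (the name info_loss_fn is mine, just for illustration):

```python
import torch

def info_loss_fn(edge_att: torch.Tensor, r: float, eps: float = 1e-6) -> torch.Tensor:
    """KL divergence between Bern(edge_att) and Bern(r), averaged over edges.

    edge_att: per-edge keep probabilities p in (0, 1); r: the target keep ratio.
    """
    kl = edge_att * torch.log(edge_att / r + eps) \
        + (1 - edge_att) * torch.log((1 - edge_att) / (1 - r + eps) + eps)
    return kl.mean()
```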

If I don't use any sort of info loss, when I train my model, my edge attention looks like this:
[histogram of edge attention without any info loss]

If I use an L1 loss (just minimizing edge_att.mean()), my edge attention looks like this:
[histogram of edge attention with the L1 loss]

If I use the L1 loss but multiply it by 1e-3, it looks like this:
[histogram of edge attention with the L1 loss scaled by 1e-3]

However, if I use the info loss proposed in your paper, my edge attention clusters around values close to r. For example, for r = 0.3, I get the attention distribution below. If I use r = 0.5, the dense part of the histogram moves to the middle, and if I use something like r = 0.7 or r = 0.9, all my attention weights end up closer to 1.
[histogram of edge attention with the info loss, r = 0.3]

I tried to understand the intuition behind it by plotting the curve of info_loss vs. edge attention for different values of r:
[plots of info_loss vs. edge attention for different values of r]
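
For reference, such curves can be reproduced with something like the snippet below (numpy/matplotlib; the small epsilon terms are dropped for clarity):

```python
import numpy as np
import matplotlib.pyplot as plt

p = np.linspace(1e-3, 1 - 1e-3, 500)  # candidate edge attention values
for r in (0.3, 0.5, 0.7):
    # KL(Bern(p) || Bern(r)), epsilons dropped for clarity
    kl = p * np.log(p / r) + (1 - p) * np.log((1 - p) / (1 - r))
    plt.plot(p, kl, label=f"r = {r}")
plt.xlabel("edge_att (p)")
plt.ylabel("info_loss")
plt.legend()
plt.show()
```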

So basically the info_loss is approximately zero when the attention is close to r, and positive everywhere else. This forces my model to keep the attention always close to r (which I am not sure I understand why), and apparently that is exactly what my model is doing. What confuses me is that your paper recommends r between 0.5 and 0.9; however, in my current setting, this forces the majority of my edge attention weights to be > 0.5, instead of making them sparse.

I wonder if I am doing something wrong, if info_loss should have a smaller weight, or if my concrete_sampler should have a higher temperature to force a Bernoulli-like distribution. Or maybe my model simply doesn't really need the edges and is fine with any edge attention value, finding a way to reach the same solution from the node embeddings alone, for example, without message passing. Or maybe I have too much dropout during training? (I do both node and edge dropout.)

Please let me know if you have any ideas. Thanks in advance!


siqim commented Feb 17, 2022

Hi Felipe, thanks for your questions! They're actually very good questions and I'll try to resolve them:

  1. We do not recommend using an L1-norm loss on p, as it encourages too much sparsity and can easily cause the model to collapse. In Figure 7 of our paper, we observe similar behavior to what you describe: if we replace the info_loss with the L1 loss, the model becomes much more sensitive to the penalty coefficient, and when the coefficient is not well tuned, it collapses (learning p to be all zeros) early in training.

  2. Yes, the info_loss will push attention weights (denoted as p in Figure 1 of our paper) towards r, as info_loss is the KL divergence between Bern(p) and Bern(r), which is zero only when p = r. However, since p is the probability of keeping the corresponding edge during training, the other loss term (cross-entropy) pushes the model to learn a larger p for edges that are critical for its predictions; otherwise those critical edges would have too high a probability of being dropped, which would result in a large cross-entropy loss and hurt model predictions. So it is expected that many of the learned p values are close to r (those edges do not affect model predictions much), while the critical edges should have p greater than r to keep the cross-entropy loss small. The ranking of p then represents the importance of edges.

  3. GSAT doesn't encourage generating sparse subgraphs. We find r = 0.7 generally works very well for all datasets in our experiments, which means roughly 70% of edges are kept during training (still fairly dense). This is because GSAT does not try to provide interpretability by finding a small/sparse subgraph of the original input graph (which is what previous literature does); instead, it provides interpretability by pushing the critical edges to have relatively lower stochasticity during training.

  4. When r = 0.7, attention weights for non-critical edges should be around 0.7, and those for critical edges should be relatively larger. If, in this case, all your attention weights are close to 1, then probably (1) all your edges are important and dropping any of them results in a large cross-entropy loss; or (2) your cross-entropy loss is so large that the info_loss does not impose enough penalty, in which case you may need a smaller r, say 0.5. But there might be other reasons depending on your specific settings. You could tune r over {0.5, 0.7, 0.9} based on your validation prediction performance.

  5. If by sparsity you mean edge attention being 0 or 1 (the Bernoulli-like distribution you mentioned), then yes, the majority of edge attention values will not be discrete at test time. This is because the sampling procedure (Gumbel-softmax) is only used during training. So, during training you should see the sampled attention (denoted as \alpha in Figure 1) being (roughly) discrete; but at test time \alpha = p, p is close to r, and r > 0.5, hence the majority of p (or \alpha) should be > 0.5 (a small sampler sketch follows this list).

  6. If you don't use info_loss, this corresponds to the case \beta = 0 in Table 4 of our paper, which means we don't control the values of the learned p at all. GSAT may still work and provide interpretability and generalizability in this case, because the general idea of GSAT still applies. However, our ablation study shows that this setting may suffer from initialization problems and produce results with high variance across random seeds: with no regularization on p, bad network initializations can easily lead to trivial solutions, e.g. learning p to be all ones or learning large p for non-critical edges.

  7. We find the penalty coefficient of info_loss (denoted as \beta in our paper) does not need to be tuned extensively to get good results, so we use 1/|E| for all datasets in our experiments (the .mean() does this), which provides good enough results. That said, we do observe that tuning \beta can yield even better performance, but it may take a lot of extra effort. We did not try to tune the temperature in the Gumbel-softmax trick, so it is possible that tuning it gives better results.

  8. We did not use node/edge dropout in our experiments. My intuition is that GSAT is already dropping nodes/edges for us, so it may conflict with such regularization methods.
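
To make point 5 more concrete, here is a minimal sketch of a binary-concrete (Gumbel-sigmoid) sampler of the kind described above. It is an illustration under stated assumptions (per-edge logits and a fixed temperature), not necessarily the exact implementation in our repo:

```python
import torch

def sample_edge_att(att_logits: torch.Tensor, temp: float, training: bool) -> torch.Tensor:
    """Binary-concrete relaxation of per-edge Bernoulli sampling.

    att_logits: per-edge logits, so p = sigmoid(att_logits).
    Training: add Logistic noise and squash with a temperature -> roughly discrete alpha.
    Testing: no sampling, simply return alpha = p.
    """
    if training:
        u = torch.empty_like(att_logits).uniform_(1e-10, 1 - 1e-10)
        logistic_noise = torch.log(u) - torch.log(1 - u)
        return torch.sigmoid((att_logits + logistic_noise) / temp)
    return torch.sigmoid(att_logits)
```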

I hope this resolves your questions, but if you have any others, please feel free to let us know!


felipemello1 commented Feb 17, 2022

Thanks for such a helpful, detailed response! The algorithm gets more beautiful the more you think about it. I guess I was trying to think about GSAT in terms of increasing graph sparsity, which is obviously not the point.

If you don't mind me asking: if I had to extract a graph (nodes, edges) with the smallest possible number of relevant edges, what would you recommend? I can see two simple options:

  1. A threshold, for example edge_att > 0.9;
  2. Top-k, for example, the N highest scores.

But I am afraid these heuristics might be a bit arbitrary. Please let me know if you recommend any particular approach, and, again, thank you so much!


siqim commented Feb 17, 2022

As the learned p will be around r, what we do is normalize p to [0, 1] and use the normalized p as the transparency of edges when plotting a graph. See line 96 and line 126 here.

If you had to extract a subgraph, I think the options you mentioned make sense and I would do the same thing. Currently, we don't have a very good way to directly output a critical subgraph, as GSAT only tells us the relative importance of edges. You could try the following (a small extraction sketch follows this list):

  1. First normalize p to [0, 1]; then a good threshold is easier to find. You would probably need to try several thresholds, say from [0.8, 1.0], and see which one produces the most plausible subgraph. This approach needs some prior knowledge about the dataset, though;
  2. If you don't have prior knowledge about the dataset, another way could be to try different thresholds to extract subgraphs and train new models only on these extracted subgraphs. Then, based on the results of those new models, choose the threshold that gives you the best validation performance. This approach sounds intuitive to me, but we haven't tried it and it may require a lot more computation;
  3. If you have a budget (a fixed number of edges to keep), top-k selection is perfect.
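
A minimal sketch of the normalization plus the two extraction heuristics, assuming a PyG-style edge_index of shape [2, num_edges]; the function names are just for illustration:

```python
import torch

def normalize_att(edge_att: torch.Tensor) -> torch.Tensor:
    """Min-max normalize edge attention to [0, 1] (e.g. for plotting transparency)."""
    return (edge_att - edge_att.min()) / (edge_att.max() - edge_att.min() + 1e-12)

def extract_by_threshold(edge_index: torch.Tensor, edge_att: torch.Tensor, thr: float) -> torch.Tensor:
    """Keep edges whose normalized attention is at least `thr`."""
    mask = normalize_att(edge_att) >= thr
    return edge_index[:, mask]

def extract_topk(edge_index: torch.Tensor, edge_att: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k edges with the highest attention."""
    idx = torch.topk(edge_att, k).indices
    return edge_index[:, idx]
```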


siqim commented Feb 21, 2022

I will be closing this issue now. But if you have any other questions, please feel free to let us know!

@siqim siqim closed this as completed Feb 21, 2022
@siqim siqim pinned this issue Jul 25, 2022