Issues with implementation #2
Hi Felipe, thanks for your questions! They're actually very good questions, and I'll try to address them:
I hope this resolves your questions. If you have any other questions, please feel free to let us know!
Thanks for such a helpful, detailed response! The algorithm gets more beautiful the more you think about it. I guess I was trying to think about GSAT in terms of increasing graph sparsity, which is obviously not the point. If you don't mind me asking: if I had to extract a graph (nodes, edges) with the smallest possible number of relevant edges, what would you recommend? I can see two simple options:
But I am afraid that these heuristics might be a bit arbitrary. Please let me know if you recommend any particular approach, and, again, thank you so much!
If you had to extract a subgraph, I think the options you mentioned make sense and I would do the same thing. Currently, we don't have a very good way to directly output a critical subgraph, as GSAT tells us the relative importance of edges. You could try:
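For instance, something along these lines would give a fixed-size or thresholded subgraph from the learned edge attention (a rough sketch, not code from GSAT itself; the function names and the particular k/threshold are only illustrative):

```python
import torch

def extract_topk_subgraph(edge_index: torch.Tensor, edge_att: torch.Tensor, k: int):
    """Keep the k edges with the highest learned attention.

    edge_index: [2, num_edges] COO edge list (as in PyTorch Geometric)
    edge_att:   [num_edges] attention scores produced by the model
    """
    k = min(k, edge_att.numel())
    idx = torch.topk(edge_att, k).indices
    return edge_index[:, idx], edge_att[idx]

def extract_threshold_subgraph(edge_index: torch.Tensor, edge_att: torch.Tensor, threshold: float = 0.5):
    """Keep every edge whose attention exceeds a (necessarily arbitrary) threshold."""
    mask = edge_att >= threshold
    return edge_index[:, mask], edge_att[mask]
```

Whichever variant you pick, the cut-off itself remains a heuristic layered on top of the relative importance scores that GSAT provides.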
I will be closing this issue now, but if you have any other questions, please feel free to let us know!
Hi all, I tried to implement the info loss in my own GNN. I am using a custom convolution on a custom dataset that might have leakage, so that might be the source of error, but I am trying to understand why the model behaves the way it does. I would appreciate any ideas or feedback.
My model does link prediction on small subgraphs: for each edge I want to predict, I sample a subgraph around it.
I am implementing the info_loss just like in your code:
info_loss = (edge_att * torch.log(edge_att/r + 1e-6) + (1-edge_att) * torch.log((1-edge_att)/(1-r+1e-6) + 1e-6)).mean()
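For context, here is the same computation as a self-contained sketch (only the formula is taken from your code; the surrounding names and the usage example are mine):

```python
import torch

def info_loss(edge_att: torch.Tensor, r: float, eps: float = 1e-6) -> torch.Tensor:
    """KL divergence between Bernoulli(edge_att) and Bernoulli(r), averaged over edges.

    edge_att: per-edge attention values in (0, 1), shape [num_edges]
    r:        the prior / target sparsity from the GSAT paper
    """
    kl = edge_att * torch.log(edge_att / r + eps) \
        + (1 - edge_att) * torch.log((1 - edge_att) / (1 - r + eps) + eps)
    return kl.mean()

# usage sketch with random attention values, just for illustration
edge_att = torch.rand(100)
loss = info_loss(edge_att, r=0.5)
```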
If I don't use any info loss at all, my edge attention after training looks like this:
If I use an L1 loss (just minimizing edge_att.mean()), my edge attention looks like this:
If I use the L1 loss but multiply it by 1e-3, it looks like this:
However, if I use the info loss proposed in your paper, my edge attention clusters around values close to r. For example, for r = 0.3, I get the attention distribution below. If I use r = 0.5, the dense part of the histogram moves to the middle, and if I use something like r = 0.7 or r = 0.9, then all my attention weights are closer to 1.
To understand the intuition behind this, I plotted the curve of info_loss as a function of edge attention for different values of r:
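Roughly, something like this reproduces the curves I mean (a small sketch, assuming matplotlib; not my exact plotting code):

```python
import torch
import matplotlib.pyplot as plt

def info_loss_per_edge(att, r, eps=1e-6):
    # same formula as above, but without the .mean(), so it can be plotted per attention value
    return att * torch.log(att / r + eps) + (1 - att) * torch.log((1 - att) / (1 - r + eps) + eps)

att = torch.linspace(0.01, 0.99, 200)
for r in (0.3, 0.5, 0.7, 0.9):
    plt.plot(att, info_loss_per_edge(att, r), label=f"r = {r}")
plt.xlabel("edge attention")
plt.ylabel("info loss (per edge)")
plt.legend()
plt.show()
```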
So basically the info_loss is approximately zero when the attention is close to r, and positive everywhere else. This pushes my model to keep the attention always close to r (which I am not sure I fully understand why), and apparently that is exactly what my model is doing. What confuses me is that your paper recommends r between 0.5 and 0.9; in my current setting, however, this forces the majority of my edge attentions to be > 0.5 instead of making them sparse.
I wonder if I am doing something wrong, if info_loss should have a smaller weight, or if my concrete_sampler should use a lower temperature to force a more Bernoulli-like distribution. Or maybe my model simply doesn't really need the edges and is fine with any edge_attention value, hacking its way to the same solution from node embeddings alone, for example, without message passing. Maybe I use too much dropout during training? (I do both node and edge dropout.)
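For reference, this is the kind of concrete sampler I have in mind (a rough sketch of a standard binary concrete / Gumbel-sigmoid relaxation, not my exact code):

```python
import torch

def concrete_sample(att_logits: torch.Tensor, temperature: float = 1.0, training: bool = True) -> torch.Tensor:
    """Binary concrete (Gumbel-sigmoid) relaxation of Bernoulli edge masks.

    Lower temperatures push samples toward {0, 1}; higher temperatures keep them
    closer to sigmoid(att_logits). At eval time, return the deterministic mean.
    """
    if training:
        u = torch.rand_like(att_logits)
        noise = torch.log(u + 1e-10) - torch.log(1.0 - u + 1e-10)
        return torch.sigmoid((att_logits + noise) / temperature)
    return torch.sigmoid(att_logits)
```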
Please let me know if you have any ideas. Thanks in advance!