[feature request] support log-bmm for context-free grammars #6
It should just work? Is there a bug?
I am afraid it does not work for CFGs. logbmm only takes two tensors, each with 3 dimensions. For example, in the compound PCFG we have a rule tensor A->BC for each sentence of size (B, NT, NT, NT) (considering only non-terminals here), and the left child B and right child C have shape (B, n-w, w, NT) (using the linear scan, i.e. processing all spans of the same width at once). To get the parent A, which has shape (B, n-w, NT), it seems we have to create a temporary tensor of shape (B, n-w, w, NT, NT) to combine B and C, then reshape it to (B * (n-w) * w, NT * NT, 1). The grammar A->BC then needs to be expanded to (B * (n-w) * w, NT, NT * NT) to apply logbmm, so we still need O(B, (n-w) * w, NT, NT, NT) memory: expanding A->BC without a copy does not work because logbmm needs contiguous tensors. Both (n-w) * w and NT^3 can be very large. I believe there is an inherent difference between CFGs and linear-chain models. The ideal situation would be to combine B, C, and A->BC directly into A: (B, n-w, w, NT) + (B, n-w, w, NT) + (B, NT, NT, NT) -> (B, n-w, w, NT) -> (B, n-w, NT), without any intermediate tensor larger than (B, n-w, w, NT).
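To make the shapes concrete, here is a rough sketch of that naive route in plain PyTorch; a broadcast torch.logsumexp stands in for genbmm.logbmm, and the sizes are made-up toy values, so this is only an illustration of where the memory goes:

```python
import torch

# Toy sizes; in practice NT can be in the hundreds.
B_sz, n_w, w, NT = 2, 5, 4, 10

left = torch.randn(B_sz, n_w, w, NT)    # left children B
right = torch.randn(B_sz, n_w, w, NT)   # right children C
rules = torch.randn(B_sz, NT, NT, NT)   # scores for A -> B C

# Materialize every (B, C) pair for every span and split point ...
pairs = left.unsqueeze(-1) + right.unsqueeze(-2)          # (B_sz, n_w, w, NT, NT)
pairs = pairs.reshape(B_sz * n_w * w, NT * NT, 1)

# ... and expand the grammar to match; a kernel that needs contiguous inputs
# forces a real O(B_sz * (n-w) * w * NT^3) copy here.
grammar = rules.reshape(B_sz, 1, NT, NT * NT).expand(B_sz, n_w * w, NT, NT * NT)
grammar = grammar.contiguous().view(B_sz * n_w * w, NT, NT * NT)

# Stand-in for logbmm(grammar, pairs): logsumexp over the (B, C) pair index.
parent = torch.logsumexp(grammar + pairs.transpose(1, 2), dim=-1)   # (B_sz*n_w*w, NT)
parent = parent.view(B_sz, n_w, w, NT).logsumexp(dim=2)             # (B_sz, n_w, NT)
```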
I see what you mean. So you are suggesting writing another intermediate operator that directly does both combinations without storing intermediates. Let's do this together. Maybe you can give a minimal suggestion of what that operator would need to look like? One idea would be to support the operators in einsum (https://pypi.org/project/opt-einsum/) directly?
Or perhaps you are just suggesting that genbmm should support broadcasting along the first dimension? Would that work? (A->BC could be size (1, NT, NT * NT) and still be contiguous without ever explicitly expanding, right?)
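For illustration, a sketch of what that broadcasting would buy, with plain torch broadcasting standing in for a genbmm kernel (names and sizes are hypothetical): the rule tensor keeps a batch dimension of 1 and is never copied, though a real kernel would also avoid materializing the broadcast sum below.

```python
import torch

B_sz, n_w, w, NT = 2, 5, 4, 10
pairs = torch.randn(B_sz * n_w * w, NT * NT, 1)   # combined (B, C) children
rules = torch.randn(1, NT, NT * NT)               # A -> B C with batch dim 1, still contiguous

# Broadcasting over the batch dimension: no expanded copy of `rules` is made.
parent = torch.logsumexp(rules + pairs.transpose(1, 2), dim=-1)   # (B_sz*n_w*w, NT)
```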
I like the second solution better. If you are motivated to give it a try, here's how to do it.
5-9) Do the same for the backward version of the function. Test that it works here:
I found a slightly better way to reduce O(batch, n-w, w, A, B, C) to O(batch, n-w, A, B, C): instead of combining B and C first, we can combine the grammar rule A->BC with C first. But it still suffers from the O(NT^3) factor. I think the previous idea is more attractive: if logbmm could take three arguments, grammars (batch, A, B, C), left (batch, n-w, w, B), and right (batch, n-w, w, C), and a kernel supported final[:, :, :, k] = logsumexp over i, j of left[:, :, :, i] + right[:, :, :, j] + grammars[:, k, i, j], that would be great; we would only need O(batch, n-w, w, A) memory in that situation. I found a similar library, "KeOps", for lazy reductions that supports logsumexp, but I have not tried it yet.
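For reference, a memory-hungry but straightforward PyTorch version of that three-argument reduction; it only pins down the semantics a fused kernel would need to reproduce (the function name and toy shapes are made up), since it still materializes the full (batch, n, w, A, B, C) tensor the kernel is meant to avoid.

```python
import torch

def fused_rule_reduce_reference(grammars, left, right):
    """Reference semantics for the proposed kernel.

    grammars: (batch, A, B, C)   rule scores for A -> B C
    left:     (batch, n, w, B)   left-child inside scores
    right:    (batch, n, w, C)   right-child inside scores
    returns:  (batch, n, w, A), where
        out[..., k] = logsumexp_{i,j} left[..., i] + right[..., j] + grammars[:, k, i, j]
    """
    # Broadcast to (batch, n, w, A, B, C); a real kernel would never build this.
    scores = (left[:, :, :, None, :, None]
              + right[:, :, :, None, None, :]
              + grammars[:, None, None, :, :, :])
    return torch.logsumexp(scores, dim=(-2, -1))

# Toy shapes: batch=2, n=5, w=4, A=B=C=3.
out = fused_rule_reduce_reference(torch.randn(2, 3, 3, 3),
                                  torch.randn(2, 5, 4, 3),
                                  torch.randn(2, 5, 4, 3))   # (2, 5, 4, 3)
```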
Cool, yeah I played with keops a bit but it didn't perform as well as I would have liked (see https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/keops.py). But I think maybe that was because I was trying to only do binary reductions. If you think the key here is a triple reduction, you should definitely try it out.
Btw, does this same issue appear for dependency parsing? It would be nice to have a kernel that wasn't so CFG-specific.
No, it is not an issue for dependency parsing, since dependency parsing does not have "non-terminals". Dependency parsing can be regarded as a lexicalized CFG whose non-terminals are null (for plain dependency parsing) or the valence numbers (for the dependency model with valence, DMV).
Oh no, you should definitely not try to do triplets in CUDA; that would be really messy. I think the right way to do this is to remove this expansion function https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/fast_semirings.py#L19 and instead implement binary broadcasting in CUDA correctly. That way you just never create the bigger (A->BC) tensor. However, I think keops is worth exploring as well. You might consider just copying my code and trying out keops directly without any of the semiring stuff.
Thank you, I'll give it a try.
Btw, I found that PyTorch's autograd uses a large amount of GPU memory to compute gradients. If I use the linear scan to explicitly implement the outside algorithm and use the inside-outside algorithm to compute the gradient, it saves 10x GPU memory and is 1.5x faster, but it is annoying to implement the outside algorithm manually for every algorithm. Do you have any ideas for combining the advantages of both?
Which algorithm are you talking about in particular? Also, what do you mean by linear scan here? I don't use linear scan for any of the tree approaches. I started by implementing backward manually; it didn't actually make things much faster, and it was difficult to use different semirings.
Eisner, zeroth-order CKY, PCFGs, and so on. By linear scan I mean the O(n)-step implementation (considering all spans of the same width at the same time).
Hmm, would be curious to know how CKY_CRF with logbmm compares to manual backward. Not sure where it is storing so much extra memory.
In cky_crf |
One really nice trick to save memory (without more code) by recomputing is to use checkpointing. It basically just automatically reruns the forward pass for you. Here is an example of that:
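(A minimal sketch of the idea with torch.utils.checkpoint, separate from the example referred to above; the log_matmul step and sizes here are just stand-ins.)

```python
import torch
from torch.utils.checkpoint import checkpoint

def log_matmul(a, b):
    # Log-space matrix product; materializes a (batch, M, K, N) intermediate
    # that autograd would normally keep alive for the backward pass.
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

a = torch.randn(8, 50, 50, requires_grad=True)
b = torch.randn(8, 50, 50, requires_grad=True)

# With checkpointing the intermediate is freed after forward and recomputed
# during backward, trading extra compute for memory.
c = checkpoint(log_matmul, a, b)
c.sum().backward()
```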
Interesting, I did try this originally, but it is not as precise. I got some underflow errors.
…On Fri, Aug 14, 2020 at 8:06 AM sustcsonglin ***@***.***> wrote:
I found that we can make use of highly optimized matrix multiplications in exp space (instead of log space). I can simply write an alternative to logbmm:
import torch

def logsumexp(a, b):
    # Shift by the global max, multiply in exp space, then shift back.
    m = torch.max(torch.max(a), torch.max(b))
    a1 = (a - m).exp()
    b1 = (b - m).exp()
    c = torch.matmul(a1, b1)
    return c.log().add_(2 * m)
It uses less memory than logbmm, and we can make use of the highly optimized torch.matmul operation here.
Similarly, we can make use of the opt_einsum library to handle the triplet situation:
from opt_einsum import contract

def logsumexp_V2(a, b, c):
    # a: (b, n, w, Y)  left spans
    # b: (b, n, w, Z)  right spans
    # c: (b, X, Y, Z)  grammar rules
    ma = torch.max(a)
    mb = torch.max(b)
    mc = torch.max(c)
    m = torch.max(torch.max(ma, mb), mc)
    a1 = (a - m).exp()
    b1 = (b - m).exp()
    c1 = (c - m).exp()
    # Contract over the split point w and the children Y, Z; result: (b, n, X).
    res = contract("bnwy, bnwz, bxyz -> bnx", a1, b1, c1, backend='torch')
    return res.log().add_(3 * m)
It saves a large amount of memory, and my issue has been solved.
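For what it's worth, a quick sanity check of the opt_einsum route above against a brute-force broadcast (the max-shift is omitted since the toy values are small; shapes and names are illustrative):

```python
import torch
from opt_einsum import contract

b, n, w, X = 2, 5, 4, 3
left, right = torch.randn(b, n, w, X), torch.randn(b, n, w, X)
rules = torch.randn(b, X, X, X)

# opt_einsum route, as in logsumexp_V2 above.
fast = contract("bnwy, bnwz, bxyz -> bnx",
                left.exp(), right.exp(), rules.exp(), backend="torch").log()

# Brute-force reference: logsumexp over w, Y, Z of the broadcast sum.
scores = (left[:, :, :, None, :, None]
          + right[:, :, :, None, None, :]
          + rules[:, None, None, :, :, :])
slow = torch.logsumexp(scores, dim=(2, 4, 5))

print(torch.allclose(fast, slow, atol=1e-4))   # expected: True
```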
Me too. I am going to do as follows:
Yeah unfortunately I tried a bunch of these tricks and ended up realizing it is hard to beat the CUDA version. But I love that you are looking into this! Would be fantastic if there was a good trick here.
I found pieces of code in https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/hmm.py and replaced logbmm with it; it saves about 10% of memory, but it is around 2x slower than the CUDA version.
Isn't this code the same as from the stackoverflow above?
…On Fri, Aug 14, 2020 at 1:38 PM sustcsonglin ***@***.***> wrote:
I found pieces of code in https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/hmm.py and replaced logbmm with it; it saves about 10% of memory, but it is around 2x slower than the CUDA version.
```python
class _SafeLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.log()

    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
    """
    Like :func:`torch.log` but avoids infinite gradients at log(0)
    by clamping them to at most 1 / finfo.eps.
    """
    return _SafeLog.apply(x)


def _logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.log() @ y.log()).exp()``.
    """
    finfo = torch.finfo(x.dtype)  # avoid nan due to -inf - -inf
    x_shift = x.max(-1, keepdim=True).values.clamp(min=finfo.min)
    y_shift = y.max(-2, keepdim=True).values.clamp(min=finfo.min)
    xy = (torch.matmul((x - x_shift).exp(), (y - y_shift).exp())).log()
    return xy + x_shift + y_shift
```
I found log-bmm very useful for linear-chain CRFs to save memory and speed things up, but in context-free grammars the rule A->BC requires a huge amount of GPU memory, which is more serious. So it is difficult to increase the number of non-terminals or terminals on a single graphics card.