[feature request] support log-bmm to context-free grammars #6

Open
sustcsonglin opened this issue Aug 9, 2020 · 22 comments

Comments

@sustcsonglin

I found log-bmm very useful for linear-chain CRFs to save memory and speed things up. For context-free grammars, the rule A->BC requires a large amount of GPU memory, which is even more serious, so it is difficult to increase the number of non-terminals or terminals on a single graphics card.

@srush
Contributor

srush commented Aug 10, 2020

It should just work? Is there a bug?

@sustcsonglin
Author

I am afraid it does not work for CFGs. logbmm only takes two tensors, and both must be 3-dimensional.

For example, in a compound PCFG we have a rule tensor A->BC for each sentence of size (B, NT, NT, NT) (considering only non-terminals here), while the scores for B and C have shape (B, n-w, w, NT) (we are using the linear-scan formulation here).

If we want to get the final A, which has shape (B, n-w, NT), it seems that we have to create a temporary tensor of shape (B, n-w, w, NT, NT) to combine B and C, then reshape it to (B * (n-w) * w, NT * NT, 1). The grammar A->BC then needs to be expanded to (B * (n-w) * w, NT, NT * NT) to apply the logbmm function, so in this case we still need (B, (n-w) * w, NT, NT, NT) memory, because logbmm requires contiguous tensors and the expanded A->BC has to be materialized. Both (n-w) * w and NT^3 can be very large.

I believe there is an inherent difference between CFGs and linear-chain models. The ideal situation would be to combine B, C and A->BC directly into A: (B, n-w, w, NT) + (B, n-w, w, NT) + (B, NT, NT, NT) -> (B, n-w, w, NT) -> (B, n-w, NT), without any intermediate tensor larger than (B, n-w, w, NT).
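
To make the sizes concrete, here is a rough sketch of that two-step route in plain PyTorch; the `logmatmulexp` helper is only an illustrative stand-in for genbmm.logbmm (it materializes an even larger tensor itself), and all names and shapes are made up for the example:

```python
import torch

def logmatmulexp(a, b):
    # Stand-in for genbmm.logbmm: (batch, M, K) x (batch, K, N) in log space.
    # This naive version materializes (batch, M, K, N), so it is only here
    # to illustrate the shapes, not the memory behaviour of the CUDA kernel.
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

B, nw, w, NT = 4, 20, 10, 30               # batch, n-w, w, #non-terminals
left  = torch.randn(B, nw, w, NT)          # inside scores for the B children
right = torch.randn(B, nw, w, NT)          # inside scores for the C children
rule  = torch.randn(B, NT, NT, NT)         # log-potentials for A -> B C

# Step 1: pair the children -> (B, n-w, w, NT, NT), then flatten for a bmm.
pair = (left.unsqueeze(-1) + right.unsqueeze(-2)).reshape(B * nw * w, NT * NT, 1)

# Step 2: the rule tensor has to be expanded (and copied, since logbmm wants
# contiguous inputs) to (B * (n-w) * w, NT, NT * NT) before the batched matmul.
rule_exp = (rule.reshape(B, NT, NT * NT)
                .unsqueeze(1)
                .expand(B, nw * w, NT, NT * NT)
                .reshape(B * nw * w, NT, NT * NT))

out = logmatmulexp(rule_exp, pair).reshape(B, nw, w, NT)   # (B, n-w, w, A)
spans = torch.logsumexp(out, dim=2)                        # (B, n-w, A)
```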

@srush
Contributor

srush commented Aug 10, 2020

I see what you mean. So you are suggesting writing another intermediate operator that directly does both combinations without storing intermediates.

Let's do this together. Maybe you can give a minimal suggestion of what that operator would need to look like?

One idea would be to directly support einsum-style operators (https://pypi.org/project/opt-einsum/)?

@srush
Contributor

srush commented Aug 10, 2020

Or perhaps you are just suggesting that genbmm should support broadcasting along the first dimension? Would that work? (A->BC could be of size (1, NT, NT * NT) and still be contiguous without ever being explicitly expanded, right?)

@srush
Contributor

srush commented Aug 10, 2020

I like the second solution better. If you are motivated to give it a try, here's how to do it.

  1. Edit this line so you check both a.size(0) and b.size(0):
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L373

  2. Set the block size based on the max of the two:
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L385

  3. Pass the batch sizes into this function:
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L398

  4. Instead of a single n here, make n_a and n_b variables that are always 0 when the corresponding batch size is 1:
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L33

  5-9. Do the same for the backward version of the function:
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L437
    https://github.com/harvardnlp/genbmm/blob/master/matmul_cuda_kernel.cu#L129

Test that it works here (a rough sketch of such a test follows below):
https://github.com/harvardnlp/genbmm/blob/master/genbmm/test_cuda.py#L12
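
A rough sketch of what such a broadcast test could look like (the broadcasting call itself is hypothetical until the steps above are implemented; shapes are arbitrary):

```python
import torch
import genbmm

# A "rule"-like operand shared across the batch, and a per-item operand.
a = torch.rand(1, 5, 6).cuda().requires_grad_(True)
b = torch.rand(32, 6, 7).cuda().requires_grad_(True)

# Current workaround: explicitly expand (and copy) a to the full batch size.
expected = genbmm.logbmm(a.expand(32, 5, 6).contiguous(), b)

# With broadcasting support, this call would accept `a` as-is:
# result = genbmm.logbmm(a, b)
# assert torch.allclose(result, expected, atol=1e-4)
```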

@sustcsonglin
Author

I found a slightly better way to reduce the O(batch, n-w, w, A, B, C) memory to O(batch, n-w, A, B, C).

Instead of combining B and C first, we can combine the grammar rule A->BC with C first:
(batch, AB, C) + (batch, C, (n-w)w) -> (batch, AB, (n-w)w) -> (batch * (n-w), AB, w) + (batch * (n-w), w, C) -> (batch * (n-w), AB, C) -> (batch, n-w, A, BC) -> (batch, n-w, A)

But it still suffers from the O(NT^3) factor.

I think the previous idea is more attractive: if logbmm could take three arguments, grammars (batch, A, B, C), left (batch, n-w, w, B) and right (batch, n-w, w, C), and a kernel supported final[:, :, :, k] = logsumexp over (i, j) of left[:, :, :, i] + right[:, :, :, j] + grammars[:, k, i, j], that would be great: we would only need O(batch, n-w, w, A) memory in that situation.
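
For concreteness, a plain-PyTorch reference for the semantics of that proposed three-argument op might look like the sketch below; it still materializes the full six-dimensional tensor, so it only documents what a fused kernel would need to compute, not how it should compute it:

```python
import torch

def fused_rule_apply_reference(grammars, left, right):
    """Memory-hungry reference semantics for the proposed fused op.

    grammars: (batch, A, B, C)    log-potentials for rules A -> B C
    left:     (batch, n-w, w, B)  inside scores of left children
    right:    (batch, n-w, w, C)  inside scores of right children
    returns:  (batch, n-w, w, A)  scores per span and split point
    """
    # A real kernel would keep this (batch, n-w, w, A, B, C) reduction in
    # registers / shared memory instead of materializing it.
    scores = (left[:, :, :, None, :, None]
              + right[:, :, :, None, None, :]
              + grammars[:, None, None, :, :, :])
    return torch.logsumexp(scores, dim=(-2, -1))
```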

I found a similar library, KeOps, for lazy reductions with logsumexp support, but I have not tried it yet.

@srush
Contributor

srush commented Aug 10, 2020

Cool, yeah I played with KeOps a bit but it didn't perform as well as I would have liked (see https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/keops.py). But I think maybe that was because I was only trying to do binary reductions. If you think the key here is a triple reduction, you should definitely try it out.

@srush
Contributor

srush commented Aug 10, 2020

Btw, does this same issue appear for dependency parsing? It would be nice to have a kernel that wasn't so CFG specific.

@sustcsonglin
Author

sustcsonglin commented Aug 10, 2020

No, it is not an issue for dependency parsing, since dependency parsing does not have "non-terminals". Dependency parsing can be regarded as a lexicalized CFG whose non-terminals are null for plain dependency parsing, or the valence numbers for the dependency model with valence (DMV).
Do you suggest that I modify your binary reduction kernel into a triplet version? I have no CUDA programming experience; is it very difficult?

@srush
Contributor

srush commented Aug 10, 2020

Oh no, you should definitely not try to do triplets in CUDA; that would be really messy.

I think the right way to do this is to remove this expansion function https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/fast_semirings.py#L19 and instead implement binary broadcasting correctly in CUDA. That way you just never create the bigger (A->BC) tensor.

However, I think KeOps is worth exploring as well. You might consider just copying my code and trying out KeOps directly without any of the semiring stuff.

@sustcsonglin
Author

Thank you, I'll give it a try.

@sustcsonglin
Author

Btw, I found that PyTorch's autograd uses a large amount of GPU memory to compute gradients. If I use a linear scan to explicitly implement the outside algorithm and use inside-outside to compute the gradient, it uses 10x less GPU memory and is 1.5x faster, but it is annoying to implement the outside algorithm manually for each algorithm. Do you have any ideas for combining the advantages of both?

@srush
Contributor

srush commented Aug 10, 2020

Which algorithm are you talking about particularly? Also what do you mean by linear scan here? I don't use linear-scan for any of the tree approaches.

I started by implementing backward manually, but it didn't actually make things much faster and it was difficult to use different semirings.

@sustcsonglin
Author

Eisner, zeroth-order CKY, PCFGs, and so on.

By linear scan I mean the O(n)-step implementation (considering all spans of the same width at the same time).
I am trying to combine the genbmm.logbmm function with the outside algorithm now; it seems to use around 30x less memory compared to your CKY_CRF implementation, which does not use logbmm and relies on autograd to compute gradients.

@srush
Contributor

srush commented Aug 10, 2020

Hmm, would be curious to know how CKY_CRF with logbmm compares to manual backward. Not sure where it is storing so much extra memory.

@sustcsonglin
Author

In cky_crf, with batch=10, length=50, NT=25, T=25:
logbmm + autograd saves 10x memory;
logbmm + inside-outside saves 80x memory.
I have to re-compute many terms to save the space; it is a trade-off between speed and space, but it is not too slow, only about 1.5x slower than the original implementation. In zeroth-order CKY there is no need to re-compute anything, so I can get a 1.5x speedup if I use the inside-outside algorithm.

@srush
Contributor

srush commented Aug 12, 2020

One really nice trick to save memory (without more code) via recomputation is to use checkpointing. It basically just automatically reruns the forward pass for you.

Here is an example of that:
https://github.com/harvardnlp/pytorch-struct/blob/master/torch_struct/semirings/checkpoint.py
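
As a minimal illustration of the idea outside of pytorch-struct, one could wrap a single inside step with torch.utils.checkpoint (names and shapes below are made up; the step reuses the reference math from the earlier sketch):

```python
import torch
from torch.utils.checkpoint import checkpoint

def span_step(left, right, rule):
    # One width-w step of the inside pass; the big broadcasted tensor below
    # is exactly the intermediate that checkpointing avoids keeping around.
    scores = (left[:, :, :, None, :, None]
              + right[:, :, :, None, None, :]
              + rule[:, None, None, :, :, :])
    return torch.logsumexp(scores, dim=(2, 4, 5))

B, nw, w, NT = 2, 8, 4, 16
left = torch.randn(B, nw, w, NT, requires_grad=True)
right = torch.randn(B, nw, w, NT, requires_grad=True)
rule = torch.randn(B, NT, NT, NT, requires_grad=True)

# Forward keeps only span_step's inputs and output; the large intermediate
# is recomputed during backward instead of being stored.
new_spans = checkpoint(span_step, left, right, rule)   # (B, n-w, A)
new_spans.sum().backward()
```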

@srush
Contributor

srush commented Aug 14, 2020 via email

@sustcsonglin
Author

Me too. I am going to try the approach described here:
https://stackoverflow.com/a/52916131

@srush
Contributor

srush commented Aug 14, 2020

Yeah unfortunately I tried a bunch of these tricks and ended up realizing it is hard to beat the CUDA version. But I love that you are looking into this! Would be fantastic if there was a good trick here.

@sustcsonglin
Author

sustcsonglin commented Aug 14, 2020

I found this piece of code in https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/hmm.py and used it in place of logbmm; it saves about 10% memory, but it is around 2x slower than the CUDA version.

```python
import torch


class _SafeLog(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.log()

    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def safe_log(x):
    """
    Like :func:`torch.log` but avoids infinite gradients at log(0)
    by clamping them to at most ``1 / finfo.eps``.
    """
    return _SafeLog.apply(x)


def _logmatmulexp(x, y):
    """
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.
    """
    finfo = torch.finfo(x.dtype)  # avoid nan due to -inf - -inf
    x_shift = x.max(-1, keepdim=True).values.clamp(min=finfo.min)
    y_shift = y.max(-2, keepdim=True).values.clamp(min=finfo.min)
    xy = torch.matmul((x - x_shift).exp(), (y - y_shift).exp()).log()
    return xy + x_shift + y_shift
```
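
For reference, a quick usage check of the snippet above against a direct log-space reduction (shapes arbitrary):

```python
import torch

a = torch.randn(8, 16, 32)
b = torch.randn(8, 32, 24)

out = _logmatmulexp(a, b)                                         # (8, 16, 24)
ref = torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)  # naive version
assert torch.allclose(out, ref, atol=1e-4)
```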

@srush
Contributor

srush commented Aug 14, 2020 via email
