Can this really be fused? Maybe at some point in the future we'll write Triton for everything.
There are a lot of things we can fuse.
For example, xr = hidden_states + xx * self.x_r, and likewise xw, xk, xv, xa, xg...
These can be done with torch.addcmul, which computes in FP32 (instead of BF16) and is faster than the original implementation (even faster than Triton in most situations).
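A minimal sketch of the addcmul version, with illustrative shapes and tensor names (the real layer uses hidden_states, the token-shift difference, and learned mix vectors such as self.x_r):

```python
import torch

# Illustrative shapes; in the layer, hidden_states and xx are (B, T, C)
# and the mix parameter x_r is broadcastable, e.g. (1, 1, C).
B, T, C = 2, 16, 512
hidden_states = torch.randn(B, T, C, dtype=torch.bfloat16)
xx = torch.randn(B, T, C, dtype=torch.bfloat16)
x_r = torch.randn(1, 1, C, dtype=torch.bfloat16)

# Original: two elementwise kernels (mul, then add).
xr_ref = hidden_states + xx * x_r

# Fused: a single addcmul kernel computing input + tensor1 * tensor2.
xr = torch.addcmul(hidden_states, xx, x_r)

torch.testing.assert_close(xr, xr_ref)
```

The same pattern applies to xw, xk, xv, xa and xg.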
I opened a PR to implement the fused kernel.
There are many other things that could be fused and implemented in Triton, but we need to be careful: in many cases an unoptimized fused PyTorch implementation is already faster than a Triton kernel.
Proposal
e.g. https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/rwkv7.py#L147C1-L158C1
Either write a fused Triton kernel or use torch.jit (or something similar) to reduce the memory I/O cost.
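As one concrete (hypothetical) way to take the torch.jit route, the six interpolations could be wrapped in a single scripted function so the JIT fuser has a chance to merge the elementwise ops into fewer kernel launches; the function name and signature here are illustrative, not the actual code at the link:

```python
import torch

@torch.jit.script
def fused_mix(hidden_states: torch.Tensor, xx: torch.Tensor,
              x_r: torch.Tensor, x_w: torch.Tensor, x_k: torch.Tensor,
              x_v: torch.Tensor, x_a: torch.Tensor, x_g: torch.Tensor):
    # All six interpolations in one scripted function, giving the fuser a
    # chance to combine them instead of launching a kernel per line.
    xr = hidden_states + xx * x_r
    xw = hidden_states + xx * x_w
    xk = hidden_states + xx * x_k
    xv = hidden_states + xx * x_v
    xa = hidden_states + xx * x_a
    xg = hidden_states + xx * x_g
    return xr, xw, xk, xv, xa, xg
```

On recent PyTorch versions, torch.compile on the same function would serve the same purpose.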
Rationale
These elementwise ops are super slow: each line launches its own memory-bound kernel.