
Support for KV caching and batched inference #1934


Open · wants to merge 2 commits into main from kvcache3

Conversation

@mseeger (Contributor) commented Feb 6, 2025

Adds abstraction for key-value caches, implements batched inference.

I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.

The abstraction contains methods not used by these baselines, but they are required to implement more advanced KV caches such as Heavy Hitter Oracle (H2O).

I have implemented some of these, but I may not be allowed to contribute them here (working for a company). I'll see what I can do.
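To give a rough idea of the shape, here is a minimal sketch of the last-recent baseline (illustrative names and layout, not the PR's actual code):

```python
import torch

class LastRecentKVCache:
    """Sliding-window ("last-recent") KV cache sketch: keeps at most
    `cache_length` positions; once full, the oldest entries are overwritten."""

    def __init__(self, batch_size, n_heads, cache_length, head_dim, dtype=torch.float32):
        shape = (batch_size, n_heads, cache_length, head_dim)
        self.k = torch.zeros(shape, dtype=dtype)
        self.v = torch.zeros(shape, dtype=dtype)
        self.cache_length = cache_length
        self.next_pos = 0  # ring-buffer write pointer
        self.filled = 0    # number of valid entries so far

    def forward(self, key, value):
        # key, value: (batch, n_heads, num_new, head_dim)
        num_new = key.shape[2]
        for i in range(num_new):  # simple, unoptimized ring-buffer write
            self.k[:, :, self.next_pos] = key[:, :, i]
            self.v[:, :, self.next_pos] = value[:, :, i]
            self.next_pos = (self.next_pos + 1) % self.cache_length
            self.filled = min(self.filled + 1, self.cache_length)
        # After wraparound, entries are no longer in temporal order; that is
        # fine as long as positional information was applied to K before caching.
        return self.k[:, :, : self.filled], self.v[:, :, : self.filled]
```

The default cache is then the degenerate case where cache_length equals max_seq_length, so nothing is ever evicted.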

@mseeger mseeger requested review from lantiga and t-vi as code owners February 6, 2025 09:37
@t-vi (Collaborator) commented Feb 6, 2025

Hey, great work @mseeger.

Can we decouple things a lot, though?

Some initial thoughts:

  • I would prefer if we kept the KVCache initialization as in the current version in this PR (i.e. you initialize the model, then potentially adjust the max seq len, and then initialize the KVCache; see the snippet after this list). Adding this to the init parameters seems orthogonal to the other changes.
  • We do have batched generation today. Can we please split the changes to batched generation from the KVCache improvements? We probably don't want to do batching via lists of tensors. I'm currently looking at passing in "packed" input/input_pos sequences, and these changes would get in the way of that. Changing the existing tests should be a bit of a red flag, as changing the API will break existing users (we can do this if we need to, but TBH I am not convinced this is the case here).
  • In general, can we be very conservative with adding arguments? For optional arguments, we should look into making them keyword-only unless there is a good reason not to.
  • We do keep control flow simple. self._default_kv_cache = False is not a good idea.
  • I'm not sure I understand the both_in_parallel. Maybe the right time to add it and the associated refactors is when they are used?
  • I'm generally a bit wary of the amount of data structures and cases that are being passed around here; those add a lot of complexity. To my mind, this likely means that the right abstraction has not yet been found. Maybe integrating KVCache and SDPA more could be a thing, but I am not sure.
  • In general, we do not want to do the cache setup during the forward. Please keep the initialization separate. I think we are rather seeing movement towards less of a distinction between pre-fill and next token, so this seems a bit in the wrong direction.
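For reference, a snippet of the current usage pattern the first bullet refers to (this assumes litgpt's existing API; the model name is just an example):

```python
from litgpt.model import GPT, Config

config = Config.from_name("pythia-160m")  # any supported checkpoint name
model = GPT(config)
model.max_seq_length = 512        # optionally adjust before allocating caches
model.set_kv_cache(batch_size=1)  # then allocate the default KV caches
```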

Again, super good stuff in the PR! I think there are a few things to split out and consider individually and then maybe we can have a video call about the core KVCache things, wdyt?

Thanks for taking the initiative on better KV caching!

@mseeger (Contributor, Author) commented Feb 6, 2025

Hello, sure, we can have a call. I am in the Central European (Germany) time zone.

@mseeger (Contributor, Author) commented Feb 6, 2025

My impression was that batched generation is not really there. But if it is, I'm not asking to change it.

One thing is important, though. KV caches really work by filling positions sequentially. So if you have filled positions 0:(T-1), you need to continue with T, or with T:(T+k). The current API of just passing in arbitrary position indexes is really not going to work.

@mseeger (Contributor, Author) commented Feb 6, 2025

Also, the implementation right now allows you to pass in KV cache objects from the start. If you do not do that, default ones are created. This is done by set_kv_cache. If you also do not call that, it is done in the first forward with for_prefill=True.

Note that prefill here means a single pass whose input the cache can take in full, without having to evict anything. It does not mean that a single pass can encode even the shortest prompt in the batch: if prompts are longer than the max prefill length, you need to process them sequentially in chunks.
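In pseudo-Python, the sequential chunking looks like this (a sketch with made-up names, not the PR's code):

```python
import torch

def process_prompt(model, prompt: torch.Tensor, max_prefill_length: int) -> torch.Tensor:
    """Feed a prompt through the model in chunks the KV cache can absorb.

    `model(x, input_pos=...)` and `max_prefill_length` are illustrative
    stand-ins. Positions are filled strictly left to right: each chunk starts
    exactly where the previous one ended."""
    T = prompt.shape[-1]
    logits = None
    for start in range(0, T, max_prefill_length):
        chunk = prompt[..., start : start + max_prefill_length]
        logits = model(chunk, input_pos=start)  # cache continues at position `start`
    # Logits at the final position of the last chunk predict the next token.
    return logits
```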

Maybe there is an easier way, we can discuss.

@mseeger (Contributor, Author) commented Feb 6, 2025

It is annoying that I cannot show you the KV cache code I have. But in a call, I could explain why a few things are the way they are. Of course, I am not on top of the other constraints you guys have.

@mseeger (Contributor, Author) commented Feb 6, 2025

You may ask why KVCache.prefill exists. The main reason is that you want to use SDPA whenever you can, but SDPA cannot return the attention weights, which some KV cache algorithms (H2O) need in order to decide what to evict next.

We could arrange things so that the very first call to the model, with input_pos=0, does this. So, instead of

`model(x, for_prefill=True)`

you'd call

`model(x, input_pos=0)`

This I could do. That would indeed be a little simpler.
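For context: F.scaled_dot_product_attention only returns the output, while eviction policies like H2O score tokens by their accumulated attention mass. The fallback one needs when weights are required looks roughly like this (illustrative, not the PR's code):

```python
import math
import torch
import torch.nn.functional as F

def attention_with_weights(q, k, v, causal_mask=None):
    """SDPA-equivalent that also returns the attention weights.

    Slower and more memory-hungry than F.scaled_dot_product_attention, since
    it materializes the full weight matrix; hence you use the fused kernel
    whenever the weights are not needed."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal_mask is not None:
        scores = scores.masked_fill(causal_mask, float("-inf"))  # True = disallowed
    weights = F.softmax(scores, dim=-1)
    # H2O-style caches aggregate `weights` over queries to pick evictions.
    return weights @ v, weights
```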

@mseeger (Contributor, Author) commented Feb 6, 2025

@t-vi Let me know what the next steps here should be. If I understand correctly, I could:

  • Get rid of the for_prefill parameter and use input_pos=0 instead
  • Don't create a default KV cache in forward; rather, fail the call if input_pos is used, so that the user needs to call set_kv_cache
  • You don't seem to approve of passing the KV caches at construction (if the user does not want to use the default ones). Would you rather use set_kv_cache for that?

@t-vi (Collaborator) commented Feb 9, 2025

Hi, so I think we should try to break things down.

We could either start with the core caching itself and try to see how to integrate it with minimal changes, or first see what the deal is with batching and prefill.
I sent a mail to your gmail address to find a good time to discuss.

@mseeger (Contributor, Author) commented Feb 21, 2025

Hello @t-vi, let me try to break things down. The changes are these:

  1. KVCache and its implementations. This replaces the default cache, which just stores everything. No behavior changes.
  2. Caches for each layer can be passed when the model is created. Before that, there is set_kvcache, which creates the default caches. If nothing is done at all, default caches are created when first needed. This is a change: before, this would raise an exception.
  3. Refactoring of the generation code: this works for batch generation now, and single-sequence generation is a special case. Inside, it properly supports large prompts by splitting generation into prefill (as large as the caches allow) and then sequential blocks of the desired length.

@mseeger (Contributor, Author) commented Feb 21, 2025

If I understand you correctly, your concerns are about 2., especially the automatic creation of the default cache when nothing is done, and the change to __init__ of GPT. This I can work on. I could do the following:

  • Allow passing KV caches per layer in set_kvcache (or have another method?)
  • Create default KV caches by calling set_kvcache. If this is not done, calling forward for inference fails, so no cache is created automatically

Would that be what you prefer?
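Concretely, I am thinking of something like this (a sketch; `transformer.h`, `attn.kv_cache`, and `DefaultKVCache` are hypothetical stand-ins for whatever the real attribute names end up being):

```python
from typing import Optional, Sequence

def set_kvcache(model, batch_size: int, kv_caches: Optional[Sequence] = None) -> None:
    """Attach one KV cache per transformer block."""
    blocks = model.transformer.h
    if kv_caches is None:
        # Default: a dense cache per layer that stores all keys and values.
        kv_caches = [DefaultKVCache(batch_size=batch_size) for _ in blocks]
    if len(kv_caches) != len(blocks):
        raise ValueError("need exactly one KV cache per layer")
    for block, cache in zip(blocks, kv_caches):
        block.attn.kv_cache = cache
```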

@mseeger (Contributor, Author) commented Feb 21, 2025

As for 1. and 3.: in the end they go together, but I can try to split them in two. I'd first do 1., keeping the generation code in place, which would however not work for batches and would not properly support sequential processing of prompts.

Doing 3. first is not really sensible, because it requires things from 1.

What do you think?

@mseeger (Contributor, Author) commented Feb 21, 2025

Note that with DeepSeek (I am involved in trying to bring this to Hugging Face), there is a lot of momentum now for not ignoring KV caching in the future. They even released a paper on how they can train with large contexts.

@mseeger (Contributor, Author) commented Feb 24, 2025

OK, I did 2), as far as I understand it. I'll work on 1) once I find time.

@mseeger (Contributor, Author) commented Feb 24, 2025

No idea why all these tests are failing; they work for me locally.

@mseeger (Contributor, Author) commented Feb 26, 2025

@t-vi Maybe I can change your mind about keeping the current generation code in place for now and only contributing the KV cache support?

This is quite a bit of extra work for me, and my new code has a number of improvements. In particular, the current code does not really do batch generation; it is marked with several TODOs and is not used.

If we could have a chat, I'd appreciate that.

@mseeger (Contributor, Author) commented Feb 27, 2025

Your CI system seems to be broken still.

@mseeger (Contributor, Author) commented Feb 27, 2025

Out of curiosity: why do you object to batch prompts being a list of tensors? In the end, they can have wildly different lengths, and there is not much you can do about that (sure, if you get lots of requests, you can maybe cluster them by length, but doing this too much delays requests and so increases latency).

Also, you really don't want to push PAD tokens into models just because one prompt in a batch happens to be shorter than the others. The model, not having been trained on this, would certainly get confused. And since you need to switch to token-by-token forward for generation anyway, you really gain nothing by padding prompts.

I always thought of this as some kind of TensorFlow artefact from when all tensors had to be allocated up front, etc. But I thought we had overcome this with PyTorch.

@t-vi (Collaborator) commented Feb 27, 2025

Hey, sorry, I am totally swamped, still want to have a video call to chat.

> Out of curiosity: why do you object to batch prompts being a list of tensors? In the end, they can have wildly different lengths, and there is not much you can do about that (sure, if you get lots of requests, you can maybe cluster them by length, but doing this too much delays requests and so increases latency).

Because lists are a lot less nice to work with in various setups: passing to kernels, CUDA graphs, etc.

For somewhat homogeneous sequence lengths, padding works fine. We are using it in production, so I'm doubting claims that it does not work. It does have limitations with inhomogeneous sequence lengths, which we want to support.

But the proper way to support this is packed sequences, i.e. pass in flat input_tokens and input_pos (i.e. 1d shape, no batch index) and then batch_seq_lens of shape batch_size.
batch_seq_lens gives the length of each batch item (and might even be 0).

This is hugely more flexible. It needs FlexAttention or some such (https://pytorch.org/docs/stable/nn.attention.flex_attention.html) to work efficiently in stock PyTorch.
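To illustrate the packed layout with stock SDPA and an explicit block-diagonal causal mask (FlexAttention expresses the same thing as a `mask_mod` function without materializing the mask; `batch_seq_lens` as above):

```python
import torch
import torch.nn.functional as F

# Three sequences of lengths 5, 3, 4 packed into one flat token dimension.
batch_seq_lens = torch.tensor([5, 3, 4])
total = int(batch_seq_lens.sum())
input_pos = torch.cat([torch.arange(n) for n in batch_seq_lens])   # restarts per sequence (feeds RoPE)
doc_id = torch.repeat_interleave(torch.arange(3), batch_seq_lens)  # which sequence owns each token

# Tokens attend only within their own sequence, and causally.
same_doc = doc_id[:, None] == doc_id[None, :]
causal = torch.arange(total)[:, None] >= torch.arange(total)[None, :]
mask = same_doc & causal  # (total, total) boolean, True = may attend

n_heads, head_dim = 4, 16
q = k = v = torch.randn(1, n_heads, total, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # mask broadcasts over batch/heads
```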

@mseeger (Contributor, Author) commented Feb 27, 2025

Let me know when a good time is. I am in a European time zone.

@mseeger (Contributor, Author) commented Mar 5, 2025

After our call, I think I understand better what you mean. Something like an abstraction in multi-head attention, where the inputs are the keys, values, and queries for the current input chunk, all of the same size, and then this is bundled:

  • Take in keys and values and replace them with the KV-cached ones, so now keys and values are larger
  • Do the SDPA computation
  • Feed attention weights back to the KV cache if needed
  • Return the MHA output before the final linear mapping

This makes a lot of sense, and is quite elegant.
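As a sketch of that bundling (illustrative names; `needs_attn_weights` and `update_attn_weights` are assumed hooks, not an agreed API):

```python
import torch.nn.functional as F

def attend_with_cache(kv_cache, q, k, v):
    """Extend K/V from the cache, run attention, optionally feed the weights
    back to the cache, and return the pre-projection MHA output.
    (Causal masking for multi-token chunks is elided for brevity.)"""
    k, v = kv_cache.forward(k, v)  # cached K, V now cover all past positions
    if getattr(kv_cache, "needs_attn_weights", False):
        # Manual path: caches like H2O score tokens by attention mass.
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        weights = F.softmax(scores, dim=-1)
        kv_cache.update_attn_weights(weights)  # assumed hook
        return weights @ v
    return F.scaled_dot_product_attention(q, k, v)  # fused fast path
```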

@Borda added the enhancement (New feature or request) label on Mar 12, 2025
@Borda (Member) commented Mar 18, 2025

> Your CI system seems to be broken still.

@mseeger should be fixed now, thank you for your patience :)

@mseeger (Contributor, Author) commented Mar 19, 2025

As discussed with @t-vi, I'll refactor this as stated in the comment above. Makes total sense.

mseeger force-pushed the kvcache3 branch 3 times, most recently from 8c220a3 to 4122f2e on March 27, 2025
@mseeger (Contributor, Author) commented Mar 27, 2025

@t-vi, is this what you had in mind? KVCache is now very simple, and the code in CausalSelfAttention is not polluted by details of KV caching. I added DefaultKVCache, which most KV caches will build on.

@mseeger (Contributor, Author) commented Mar 27, 2025

I'd be OK with taking out my batched inference part here, but that means there won't be any. Do you have plans to add the batched inference code you talked about any time soon?

mseeger force-pushed the kvcache3 branch 2 times, most recently from a1e724a to b5a63e4 on April 4, 2025
@mseeger (Contributor, Author) commented Apr 4, 2025

OK, I've taken out the batched inference code. Still working on fixing the tests (and I need to refactor speculative decoding), but this is essentially it.
