Llama-3_1-Nemotron 51B support #726


Open · wants to merge 3 commits into master

Conversation


@ymcki ymcki commented Jan 28, 2025

Dear all,

I am the person who added Llama-3_1-Nemotron-51B support to llama.cpp.

ggml-org/llama.cpp#10669

I tried to add this support to exllamav2 and came up with a hack that
can convert and run inference on Llama-3_1-Nemotron-51B.

While the hack is working, I am not sure if it is the best way to implement it,
as it changes quite a lot of code in exllamav2.

This is because the current exllamav2 codebase is not designed for the case
where different layers of an LLM can have different numbers of key_value_heads and
different structures, as in DeciLMForCausalLM (this 51B model) and Apple's
OpenELMForCausalLM.

For this 51B model, there are three types of layers:

  1. A normal layer that is the same as in the Llama 3 model it is based on.
  2. A linear attention layer that has no q_proj, k_proj, v_proj or o_proj, but has
    a linear_attn weight that is simply matmuled with the input.
  3. An attention-free layer that is simply an MLP layer.

Also, for this model, the number of kv_heads and the intermediate_size can differ
between layers.
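
For illustration, here is a rough sketch of how the three block types can be told apart from the per-layer config. The field names (block_configs, replace_with_linear, no_op, n_heads_in_group) follow the model's config.json as I understand it; treat them as assumptions rather than a definitive reference.

```python
# Hypothetical sketch: classify DeciLM/Nemotron per-layer block configs into the
# three layer types described above. Field names are assumptions based on the
# model's config.json and may not match exactly.
import json

def classify_blocks(config_path):
    with open(config_path) as f:
        cfg = json.load(f)
    kinds = []
    for block in cfg["block_configs"]:
        attn = block["attention"]
        if attn.get("replace_with_linear"):
            kinds.append("linear_attn")   # type 2: linear_attn weight, no q/k/v/o proj
        elif attn.get("no_op"):
            kinds.append("mlp_only")      # type 3: attention-free, just the MLP
        else:
            kinds.append("normal")        # type 1: regular Llama-3-style attention
    return kinds
```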

As a result, there are quite a lot of changes to the code in my fork. I also added
a file called linear_attn.py to define ExLlamaV2LinearAttention to handle the
linear attention layer.

While it can run without errors based on my testing so far, I am not sure if it
covers all situations. Maybe it would be better to wait for a rewrite that accommodates
these per-layer-variable models like DeciLMForCausalLM and OpenELMForCausalLM.

It would be great if this hack could serve as a starting point for such a rewrite and
allow me to add the support later as a cleaner contribution.

Thank you very much for your time.

@ymcki
Author

ymcki commented Jan 29, 2025

Removed linear_attn.py by merging the ExLlamaV2LinearAttention class into ExLlamaV2Attention.

@turboderp
Member

Sorry I've taken my time getting around to this. I'm not really in a position to test it thoroughly, but I can merge it into the dev branch if it's (sort of) working and you're pretty sure it isn't breaking anything else.

It will be very useful as a reference for ExLlamaV3 though, which takes a much more flexible approach (maybe trading off some performance, TBD). Based on this description supporting Nemotron (and similar architectures) should be trivial in V3.

@ymcki
Author

ymcki commented Mar 12, 2025

Good to hear that there will be an exllamav3. I hope it will natively support architectures that have different configs for different layers, so that I can easily add support for this 51B model.

I don't think it is necessary to merge it to the dev branch. If people want to try it, they can download my fork.

Please let me know when exllamav3 is out.

@ymcki
Author

ymcki commented Mar 31, 2025

A user of the Nemotron model reported that he is seeing an uneven distribution of VRAM usage when using multiple GPUs with llama.cpp. I suggested this might be because the first 40 layers have far more self-attention layers than the second half. The user then suggested a way to improve the situation.

ggml-org/llama.cpp#12654

It would be great if this is taken into account when exllamav3 is developed.

@turboderp
Member

This is already how auto split works in ExLlamaV2, and V3 will have the same feature. It allocates layers (including cache tensors) on cuda:0 until cuda:0 has no more VRAM, then starts allocating on cuda:1, etc. (Note that parameter count is not enough to ensure an even split, though.)

The tensor-P split is simpler in principle, since every layer is distributed across all GPUs. The split can get a little complicated if the number of GPUs doesn't divide the number of keys, or if the GPUs differ in size, but dissimilar layers in the model wouldn't play a role in that.

Manual split I'm not sure about yet. It's trivial to just let the user select the number of layers to assign to each GPU, as a list. This provides a stable way to dial in the split, at least.

V2 took the approach of trying to estimate the total VRAM usage per layer (not including cache since cache is allocated separately from loading the model in that case), then letting the user select how much VRAM to use per device. This is very tricky to get right, though, since Torch's memory allocation behavior is unpredictable, subject to change with every new version, and not entirely accounted for in Torch's own API functions (since VRAM is also allocated by the CUDA runtime in response to what Torch is doing on top of that.)
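
To make the auto split behaviour above concrete, here's a toy greedy allocator over per-layer VRAM estimates. This is not ExLlamaV2's actual implementation (which catches OoM errors rather than relying on estimates); it just illustrates the fill-cuda:0-then-cuda:1 idea.

```python
# Toy illustration of the auto-split idea: fill cuda:0 until its budget is
# exhausted, then move on to cuda:1, and so on. Sizes and budgets are made up.
def greedy_split(layer_sizes_mb, device_budgets_mb):
    assignment = []                    # device index assigned to each layer
    device = 0
    free = device_budgets_mb[0]
    for size in layer_sizes_mb:
        while size > free:
            device += 1
            if device >= len(device_budgets_mb):
                raise RuntimeError("model does not fit in the given budgets")
            free = device_budgets_mb[device]
        assignment.append(device)
        free -= size
    return assignment

# Example: layers of uneven size (as in Nemotron-51B) across two 24 GB cards
print(greedy_split([900, 900, 300, 300, 900, 120] * 10, [24000, 24000]))
```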

@ymcki
Author

ymcki commented Apr 1, 2025

Thanks a lot for your detailed reply.

Good to hear that exllamav2 already has a mechanism to deal with that. I suppose torch is something upstream that you don't have control over, so you probably don't need to think too much about it.

"V2 took the approach of trying to estimate the total VRAM usage per layer (not including cache since cache is allocated separately from loading the model in that case)" <= Why not including KV cache when you split layers? I think they can be easily estimated with formula for each layer. Ideally, layers should be split by taking into account of parameter size plus kv cache for each layer.

My hack doesn't change anything related to tensor parallelism. Do you think it will work already? If so, I can ask the person who reported the problem to test it out.

@turboderp
Member

It's not only that Torch is upstream, it's that it's a relatively high-level framework. That doesn't make it inefficient in every sense. I.e. as long as you can shape the workload so as to keep the GPU busy at all times, the abstractions don't necessarily slow anything down. But it does manage GPU resources dynamically, and trying to code around that everywhere kind of defeats the purpose. Especially if it's just to squeeze out that last 0.5% of VRAM only to crash anyway because the user resized a window at the wrong time.

The reason manual split doesn't account for the cache is that the cache size isn't known at the time the model is loaded. The cache is created separately, but you can't do that until the model is already split. Autosplit combines the two operations by first creating the cache with lazy initialization, then creating the actual K/V tensors while the model is loading. It works well for the "simple case" of a single model with a single, static cache, but it gets complicated if you have multiple models, multiple caches for the same model, models without cache, or whatever.

I suppose the regular split loader could also take a lazily initialized cache as an argument to account for it that way. I just never really built out the functionality because it's so unpredictable anyway, and autosplit seems to be perfectly sufficient most of the time. Keep in mind that autosplit can also reserve an arbitrary amount of VRAM on each device, so you could just reserve whatever the current max is minus what you want the split to be.
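
For reference, the lazy-cache autosplit pattern looks roughly like this. This is a sketch based on the V2 examples as I remember them; in particular the reserve_vram argument name is my recollection and should be checked against the current API.

```python
# Sketch of auto split with a lazily initialized cache, per the description above.
# The model path is a placeholder; reserve_vram (bytes per device) is assumed.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"            # hypothetical model directory
config.prepare()
model = ExLlamaV2(config)

cache = ExLlamaV2Cache(model, lazy = True)     # K/V tensors created while loading
model.load_autosplit(cache, reserve_vram = [1024**3, 0])  # keep ~1 GB free on cuda:0
```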

@ymcki
Author

ymcki commented Apr 9, 2025

"The reason manual split doesn't account for the cache is that the cache size isn't known at the time the model is loaded. "

What do you mean by that? When the KV cache is created, the user must supply the context length. With it, it is possible to calculate the KV cache size in advance. Here are the formulas for the 51B, 49B and 253B models:

51B:
fp16_KV_cache_in_MB = (context_length) * (2*8*128*44*2/(1024*1024) + 2*4*128*1*2/(1024*1024) + 2*2*128*6*2/(1024*1024) + 2*1*128*3*2/(1024*1024))

49B:
fp16_KV_cache_in_MB = (context_length) * (2*8*128*49*2/(1024*1024))

253B:
fp16_KV_cache_in_MB = (context_length) * (2*8*128*64*2/(1024*1024))
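
In code form, the same estimate (2 for K and V, number of KV heads, head_dim 128, number of layers, 2 bytes for fp16) would be something like the sketch below. The per-group layer counts are taken from the formulas above and are my assumptions about the respective models.

```python
# fp16 KV-cache size estimate matching the formulas above.
# kv_layout maps number of KV heads -> number of layers with that many heads.
def fp16_kv_cache_mb(context_length, kv_layout, head_dim = 128, bytes_per_el = 2):
    total = 0
    for n_kv_heads, n_layers in kv_layout.items():
        total += 2 * n_kv_heads * head_dim * n_layers * bytes_per_el   # K and V
    return context_length * total / (1024 * 1024)

print(fp16_kv_cache_mb(8192, {8: 44, 4: 1, 2: 6, 1: 3}))   # 51B
print(fp16_kv_cache_mb(8192, {8: 49}))                      # 49B
print(fp16_kv_cache_mb(8192, {8: 64}))                      # 253B
```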

@turboderp
Member

I've added DeciLM support to ExLlamaV3 here

Much simpler with some better assumptions in the underlying framework, but also this PR helped a lot as a reference. I'm not sure the V3 implementation is 100% correct, but it seems to be working. I skipped over sliding window support, the attention sink stuff and linear attention, since I've only tested on some Llama-based models that don't seem to use any of those features. Also not sure what the attention sink business is exactly, would have to read up on that.

As for the cache stuff, it's not known at load time because the cache is a separate object from the model. So by the time you load the model (in exl2, with manual split) it can only split the model itself. But to do this correctly, it needs to take the cache size into account. But to allocate the cache tensors on the right devices you need to know how the model is going to be split beforehand.

I.e. you need to specify the cache size, then split/load the model, then allocate cache tensors. Lazy initialization does this in exl2, but only for autosplit mode. I never got around to adding that to the manual mode because the overall VRAM estimates are very shaky anyway. It has to also try to account for intermediates like attention scores etc., with all the parameters that go into that, and that's another chicken/egg situation all on its own because which layers end up getting included on a device also determines the need for intermediate storage.

Of course, autosplit doesn't need to care because it relies on catching an OoM exception to know when a device is actually full, and it runs inference while loading to make sure all the intermediate usage is accounted for as well.

For V3 I discovered the torch.cuda.set_per_process_memory_fraction function which allows manual split and autosplit to use the same process, so now everything is great. It also automatically manages references between the model and the cache, so now loading the model always creates tensors for however many caches are attached to it, and unloading the model destroys the cache. I'll see about backporting this to V2 at some point, maybe. Though I don't want to break the existing API, so idk.
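
For example, capping each device with that function looks roughly like this (the GB targets are made-up values, not a recommendation):

```python
# Sketch of using torch.cuda.set_per_process_memory_fraction to cap usable VRAM
# per device, so an auto-split style loader fills each GPU only up to a chosen
# budget. The target sizes below are hypothetical.
import torch

targets_gb = [20, 24]   # desired budget per device (example values)
for device, target in enumerate(targets_gb):
    total = torch.cuda.get_device_properties(device).total_memory
    fraction = min(1.0, target * 1024**3 / total)
    torch.cuda.set_per_process_memory_fraction(fraction, device)
```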

@grimulkan

If you ever do back-port it to V2, turbo, that'd maybe provide tensor-P support for the Nemotron Llamas, at least until V3 gets around to that. So a vote in favor of that if it doesn't break things!

@Panchovix

Panchovix commented Apr 12, 2025

Hi there, thanks for your work! When trying to convert Nemotron-Ultra-253B with this PR, I get this error:

(venv_linux) pancho@fedora:/exllamav2$ python convert.py -i /models_llm/Llama-3_1-Nemotron-Ultra-253B-v1/ -o /Llama-3_1-Nemotron-Ultra-253B-v1_3.9bpw-exl2/ -b 3.9
 -- Beginning new job
 -- Input: /models_llm/Llama-3_1-Nemotron-Ultra-253B-v1/
 -- Output: /Llama-3_1-Nemotron-Ultra-253B-v1_3.9bpw-exl2/
 -- Using default calibration dataset
 -- Target bits per weight: 3.9 (decoder), 6 (head)
 -- Max shard size: 8192 MB

Traceback (most recent call last):
  File "//exllamav2/convert.py", line 1, in <module>
    import exllamav2.conversion.convert_exl2
  File "/exllamav2/exllamav2/conversion/convert_exl2.py", line 192, in <module>
    config.prepare()
  File "/exllamav2/exllamav2/config.py", line 328, in prepare
    intm_size = int(2 * ffn_mult * self.hidden_size / 3)
                    ~~^~~~~~~~~~
TypeError: unsupported operand type(s) for *: 'int' and 'NoneType'

When doing the same command but for exl3, it starts to convert (but it fails afterward with a different error)

@ymcki
Author

ymcki commented Apr 13, 2025

Good to hear that you added the support yourself.

I don't think it is necessary to handle "sliding window support, the attention sink stuff" because none of the 49B, 51B and 253B models have them. I think "linear attention" is still necessary for the 51B model, but if you think it is an outdated model already, you can skip it. The newer 49B and 253B models don't have "linear_attention".

To also support the new 253B model, all you need to do is handle the new dummy layers that have no self-attention and no FFN. You can refer to the llama.cpp PR for reference:
ggml-org/llama.cpp#12843
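
For the config.prepare() error reported above, something along these lines is probably what is needed for the dummy blocks: when a block's FFN is a no-op, ffn_mult is None, so the intermediate size calculation has to be skipped. The field names are my assumption from the model's config.json.

```python
# Hypothetical sketch of handling the 253B "dummy" blocks when computing
# per-layer intermediate sizes (cf. the config.prepare() traceback above).
# Blocks with a no-op FFN have ffn_mult = None, so skip the usual calculation.
def intermediate_sizes(block_configs, hidden_size):
    sizes = []
    for block in block_configs:
        ffn = block["ffn"]
        if ffn.get("no_op") or ffn.get("ffn_mult") is None:
            sizes.append(0)                                       # dummy layer: no FFN
        else:
            sizes.append(int(2 * ffn["ffn_mult"] * hidden_size / 3))
    return sizes
```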
