
Fix kv_channels bug in TransformerLayer and add Gemma tutorial #731

Closed
pggPL wants to merge 10 commits

Conversation

@pggPL pggPL (Collaborator) commented Mar 22, 2024

I'm working on a tutorial similar to the Llama tutorial, but with the Gemma model. Most parts are similar, but I have encountered a few differences:

  1. The official weights are distributed in safetensors format rather than as torch checkpoints like Llama's, so I modified the loading function (see the sketch after this list).

  2. Gemma uses the GEGLU activation function instead of SwiGLU, which is also supported by TE; this was a single change in the config.

  3. Gemma's hidden dimension differs from its attention dimension: the attention projections map from the hidden size (3072) to the key/query/value size (4096). The kv_channels parameter of TransformerLayer in TE seems to be responsible for exactly that, yet setting it to a value different from hidden_dim / num_heads raises an AssertionError and does not change the number of parameters.
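For point 1, here is a minimal sketch of what the safetensors loading could look like. It assumes the standard safetensors.torch API and a hypothetical directory of *.safetensors shards; the actual loading function in the tutorial may be organized differently.

import glob

from safetensors.torch import load_file

def load_gemma_state_dict(checkpoint_dir: str) -> dict:
    """Merge all *.safetensors shards in checkpoint_dir into a single state dict."""
    state_dict = {}
    for shard_path in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
        # load_file returns a dict mapping parameter names to torch.Tensor objects
        state_dict.update(load_file(shard_path, device="cpu"))
    return state_dict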

I propose changing the kv_channels argument to attention_hidden_dims. I updated the description and the number of neurons in the attention projection layers in the transformer.py and attention.py files.

The results for Gemma with bf16, bf16 + TE, and fp8 + TE are proportional to the results for Llama.
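For context on points 2 and 3, here is a sketch of the TransformerLayer configuration this PR is about. The sizes are the publicly documented Gemma-7B values (hidden size 3072, 16 heads, 256 channels per head, i.e. a 4096-wide attention projection) and are used here as assumptions for illustration; the tutorial code may differ. With the current assertion, a construction like this fails, which is what the proposed change addresses.

import transformer_engine.pytorch as te

hidden_size = 3072           # Gemma hidden dimension
num_attention_heads = 16
head_dim = 256               # per-head size; 16 * 256 = 4096 attention dimension

layer = te.TransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=24576,           # assumed Gemma-7B FFN size
    num_attention_heads=num_attention_heads,
    kv_channels=head_dim,            # 256 != hidden_size // num_attention_heads (192),
                                     # which currently triggers the AssertionError described above
    activation="geglu",
)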

root and others added 6 commits March 22, 2024 22:55
@pggPL pggPL (Collaborator, Author) commented Mar 22, 2024

I was thinking about adding two parameters, attention_hidden_dims_v and attention_hidden_dims_kq, because it is mathematically possible for the hidden dimension of the key and query vectors to differ from the dimension of the value vectors.

However, FlashAttention assumes these numbers are equal, so I gave up on this idea.

@ptrendx ptrendx (Member) commented Mar 25, 2024

Hmm, can we get the same effect without changing the API? That would be a breaking change and we would like to avoid that if possible.

  ])
  if self.input_layernorm:
      self.layernorm_qkv = LayerNormLinear(
          hidden_size,
-         hidden_size + 2 * self.hidden_size_kv,
+         3 * attention_hidden_size,
Review comment from a collaborator:

Suggested change
-  3 * attention_hidden_size,
+  self.hidden_size_per_attention_head * num_attention_heads + 2 * self.hidden_size_kv

@pggPL, we'd still need the separate hidden size (total) for k and v heads to handle group query attention

@sudhakarsingh27 sudhakarsingh27 (Collaborator) commented Mar 25, 2024

Here's what we could do alternatively:

# We still keep this line
self.hidden_size_per_attention_head = kv_channels

# Use `kv_channels` to calculate the `attention_hidden_size`
self.attention_hidden_size = self.hidden_size_per_attention_head * num_attention_heads

# Use `self.attention_hidden_size` to calculate `hidden_size_kv` (which is used in the case of GQA)
self.hidden_size_kv = int(self.attention_hidden_size * self.num_gqa_groups // num_attention_heads)

...

LayerNormLinear(
    hidden_size,
    self.attention_hidden_size + 2 * self.hidden_size_kv,
    ...
)

We could probably also rename attention_hidden_size to hidden_size_q because essentially that's what it's used for in the layer/weight creation (in the call to LayerNormLinear)

@pggPL pggPL (Collaborator, Author) replied:

I see the point, but I'm not sure whether using kv_channels is the best option. Currently the kv_channels argument does nothing when it is set to the correct value, and an error is raised when it is set to an incorrect value. Changing the behavior of this argument is an option, but the behavior would then no longer match the name, since we would also be changing the query channels.

I propose adding an optional argument attn_hidden_size; legacy code will keep working without it. We can leave kv_channels in the code, for example only printing a "this argument is deprecated" warning, or even doing nothing. I believe (maybe wrongly) that this will not cause additional problems with legacy code, and the argument names will be more accurate. Let me know what you think.
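To make the proposal concrete, here is a rough, hypothetical sketch of how an optional attn_hidden_size could coexist with a deprecated kv_channels argument. None of the names or the helper below are existing TE API; they are only illustration.

from typing import Optional
import warnings

def resolve_attention_hidden_size(hidden_size: int,
                                  num_attention_heads: int,
                                  kv_channels: Optional[int] = None,
                                  attn_hidden_size: Optional[int] = None) -> int:
    """Return the total attention hidden size, preferring the new argument."""
    if attn_hidden_size is not None:
        return attn_hidden_size
    if kv_channels is not None:
        warnings.warn(
            "kv_channels is deprecated; pass attn_hidden_size instead.",
            DeprecationWarning,
        )
        return kv_channels * num_attention_heads
    # Legacy default: attention hidden size equals the model hidden size.
    return hidden_size

Calls that pass neither argument keep today's default, so legacy code would continue to work unchanged.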

pggPL and others added 4 commits March 27, 2024 20:51
@pggPL pggPL mentioned this pull request May 3, 2024
@pggPL pggPL (Collaborator, Author) commented May 3, 2024

I split this PR into two:
#833 (the kv_channels part)
#829 (the Gemma part)
Thus I'm closing this one.

@pggPL pggPL closed this May 3, 2024