
Conversation

vulkomilev

Description of the change

I have made some temporary changes to the Mistral backbone, and I need help from the team because I am still encountering problems; at least the model now loads in the Mistral backbone. I am creating this PR in order to collaborate with other contributors.

Reference

[keras-hub-2333]


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @vulkomilev, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces temporary adjustments to the Mistral model's attention mechanism and tokenizer handling within the Keras Hub project. The primary goal is to facilitate debugging and collaboration with the team, as the author is currently facing issues with model loading despite these changes.

Highlights

  • Mistral Attention Layer Adjustments: The build method for EinsumDense layers (_query_dense, _key_dense, _value_dense, _output_dense) in MistralAttention has been modified to use hardcoded input shapes instead of dynamically derived inputs_shape. This includes (None, None, 4096) for query, key, and value dense layers, and (None, None, self._num_query_heads, 128) for the output dense layer.
  • Tokenizer File Switch: The convert_tokenizer function now attempts to load a tokenizer from "tekken.json" instead of the standard "tokenizer.model".
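To make the highlighted change concrete, the hardcoded shapes can be related back to the dynamic computation they replace. Below is a minimal illustrative sketch (not the keras-hub source), assuming a Mistral-7B-like configuration where `hidden_dim=4096` and `num_query_heads=32`; the function name is hypothetical:

```python
# Illustrative sketch: how build() can derive the EinsumDense build shapes
# dynamically from inputs_shape, versus the hardcoded values in this PR.
def derive_attention_build_shapes(inputs_shape, num_query_heads=32):
    hidden_dim = inputs_shape[-1]             # 4096 in the hardcoded version
    head_dim = hidden_dim // num_query_heads  # 4096 // 32 = 128
    # query/key/value dense layers are built on the raw inputs_shape;
    # the output dense layer is built on the per-head shape.
    qkv_build_shape = inputs_shape
    output_build_shape = (None, None, num_query_heads, head_dim)
    return qkv_build_shape, output_build_shape

print(derive_attention_build_shapes((None, None, 4096)))
# ((None, None, 4096), (None, None, 32, 128))
```

The hardcoded `(None, None, 4096)` and `128` are therefore the special case of this derivation for one config.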
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@vulkomilev
Author

@sachinprasadhs hi, please take a look at the changes. These are temporary, and in my final commit I will move them to a more appropriate place.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request appears to introduce temporary changes aimed at debugging the Mistral backbone and tokenizer integration. My review highlights several debugging artifacts, such as print statements, and hardcoded values that should be addressed. The feedback focuses on removing these temporary elements and restoring dynamic values to ensure the code is robust, maintainable, and ready for collaboration or merging. Key areas for improvement include removing hardcoded dimensions in mistral_attention.py and reverting the tokenizer file path in convert_mistral.py.

self._query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self._num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
- self._query_dense.build(inputs_shape)
+ self._query_dense.build((None,None,4096))#inputs_shape


Severity: high

The input shape for _query_dense.build is hardcoded. This makes the layer less flexible and seems to be a temporary change for debugging, as indicated by the commented-out code. Please consider restoring the original dynamic inputs_shape.

Suggested change
- self._query_dense.build((None,None,4096))#inputs_shape
+ self._query_dense.build(inputs_shape)

@@ -77,7 +79,7 @@ def build(self, inputs_shape):
dtype=self.dtype_policy,
name="key",
)
- self._key_dense.build(inputs_shape)
+ self._key_dense.build((None,None,4096))#input_shape


Severity: high

Similar to the query layer, the input shape for _key_dense.build is hardcoded. This should be reverted to use inputs_shape to maintain the model's flexibility.

Suggested change
- self._key_dense.build((None,None,4096))#input_shape
+ self._key_dense.build(inputs_shape)

@@ -90,7 +92,7 @@ def build(self, inputs_shape):
dtype=self.dtype_policy,
name="value",
)
- self._value_dense.build(inputs_shape)
+ self._value_dense.build((None,None,4096))


Severity: high

The input shape for _value_dense.build is also hardcoded. This should be reverted to use inputs_shape to maintain the model's flexibility.

Suggested change
- self._value_dense.build((None,None,4096))
+ self._value_dense.build(inputs_shape)

@@ -111,7 +113,7 @@ def build(self, inputs_shape):
name="attention_output",
)
self._output_dense.build(
- (None, None, self._num_query_heads, self._head_dim)
+ (None, None, self._num_query_heads, 128)#self._head_dim)


Severity: high

The head dimension for _output_dense.build is hardcoded to 128. This should use the calculated self._head_dim to ensure the model works correctly with different configurations.

Suggested change
- (None, None, self._num_query_heads, 128)#self._head_dim)
+ (None, None, self._num_query_heads, self._head_dim)
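The hardcoded 128 equals `self._head_dim` only for configurations where `hidden_dim // num_query_heads == 128` (e.g. 4096 // 32). A quick illustrative check, with example config values that are assumptions rather than values from this PR:

```python
# Illustrative: head_dim as computed in build(); the hardcoded 128 only
# matches configs whose hidden size and head count happen to divide to 128.
def head_dim(hidden_dim, num_query_heads):
    return hidden_dim // num_query_heads

print(head_dim(4096, 32))  # 128 -> matches the hardcoded value
print(head_dim(2048, 32))  # 64  -> a smaller config would silently mismatch
```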

@@ -113,4 +115,4 @@ def convert_weights(backbone, loader, transformers_config):


def convert_tokenizer(cls, preset, **kwargs):
- return cls(get_file(preset, "tokenizer.model"), **kwargs)
+ return cls(get_file(preset, "tekken.json"),**kwargs)#)"tokenizer.model"), **kwargs)


Severity: high

The tokenizer file is hardcoded to "tekken.json", which seems like a temporary value for testing. This should be reverted to the original "tokenizer.model" to ensure presets can be loaded correctly.

Suggested change
- return cls(get_file(preset, "tekken.json"),**kwargs)#)"tokenizer.model"), **kwargs)
+ return cls(get_file(preset, "tokenizer.model"), **kwargs)
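If both asset layouts need to be supported while experimenting, one option is a fallback rather than a hardcoded swap. A hedged sketch: `get_file` and the two asset names come from the diff above, but the fallback helper itself is only a suggestion, not keras-hub API:

```python
# Hypothetical helper: try the standard tokenizer asset first, then fall
# back to the alternative name used in this PR.
def load_tokenizer_asset(preset, get_file):
    candidates = ("tokenizer.model", "tekken.json")
    for name in candidates:
        try:
            return get_file(preset, name)
        except FileNotFoundError:
            continue
    raise FileNotFoundError(f"none of {candidates} found in preset {preset!r}")
```

With such a helper, convert_tokenizer could stay preset-agnostic instead of committing to one file name.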

@@ -45,6 +45,7 @@ def __init__(
self._rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
+ print("inputs_shape",inputs_shape)


Severity: medium

This appears to be a debugging print statement. Please remove it before this change is merged.

@@ -54,17 +55,18 @@ def build(self, inputs_shape):
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
+ print("self._hidden_dim // self._num_query_heads",self._hidden_dim , self._num_query_heads)


Severity: medium

This appears to be a debugging print statement. Please remove it before this change is merged.

self._head_dim = self._hidden_dim // self._num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

+ print("(None, self._num_query_heads, self._head_dim)",(None, self._num_query_heads, self._head_dim))


Severity: medium

This appears to be a debugging print statement. Please remove it before this change is merged.

@@ -50,7 +50,7 @@ def convert_weights(backbone, loader, transformers_config):
hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight",
hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16),
)

+ print("decoder_layer._self_attention_layer._query_dense.kernel",decoder_layer._self_attention_layer._query_dense.kernel,index)


Severity: medium

This appears to be a debugging print statement. Please remove it before this change is merged.

Comment on lines 62 to 63
+ print("decoder_layer._self_attention_layer._key_dense.kernel",decoder_layer._self_attention_layer._key_dense.kernel,index)


Severity: medium

This appears to be a debugging print statement and an extra newline. Please remove them before this change is merged.
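If the shape diagnostics flagged above are still useful while debugging, a gated logger keeps them out of merged code paths. A sketch using Python's standard logging module; the helper name and logger name are illustrative, not a keras-hub convention:

```python
import logging

# Hypothetical debug helper: emits shape diagnostics only when DEBUG
# logging is enabled, instead of unconditional print() calls.
logger = logging.getLogger("mistral_debug")

def log_build_shapes(inputs_shape, num_query_heads):
    hidden_dim = inputs_shape[-1]
    logger.debug(
        "inputs_shape=%s hidden_dim=%s head_dim=%s",
        inputs_shape, hidden_dim, hidden_dim // num_query_heads,
    )

log_build_shapes((None, None, 4096), 32)  # silent unless DEBUG is enabled
```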

@sachinprasadhs
Collaborator

> @sachinprasadhs hi, please take a look at the changes. These are temporary, and in my final commit I will move them to a more appropriate place.

Hi, could you remove all the print statements and make the final changes so the PR is ready for review?

@vulkomilev
Author

@sachinprasadhs done. I may have some other formatting issues to fix, and I will address them as I see the failing tests.
