-
Notifications
You must be signed in to change notification settings - Fork 291
Safetensors conversion #2290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Safetensors conversion #2290
Conversation
Thanks for the PR, will take a look in a bit :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just left some initial comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a unit test that calls this util and tries loading the result with transformers and seeing if it works. OK to add transformers to our ci environment here https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt
… into safetensors_conversion merge updated branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Please address the changes from the earlier PR as well
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, nice work!
return hf_config | ||
|
||
|
||
def export_to_hf(keras_model, path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add the API export decorator here, similar to this: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bloom/bloom_backbone.py#L15-L16
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API.
So, this is how the directory keras_hub/src/utils/transformers/convert_to_safetensor/
will look like:
export.py
: this will have the common code. We will expose this as the API. This will also check if we support safetensor conversion for a given passed model yet.gemma.py
: this will just have a way to create the weight dictionary for Gemma. Insideexport.py
, we will call the the weight conversion function specific to a specified model.
Pinging @mattdangerw to confirm if we should do this now or at a later point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we could land and do the API bit a later point. Though agree it's an important concern. I'm not sure if we want a method like model.save_to_preset()
or a function like some_export(model)
. Any thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think structuring the export logic with a utility function (export_to_hf) and model-specific mappings (gemma.py) will enhance scalability and maintainability. New models can be added by creating a new file, while existing tests only need an import update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to Abheesht's comment we need an API instead of a script for Gemma, we already have that
https://github.com/keras-team/keras-hub/blob/master/tools/gemma/export_gemma_to_hf.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Leaving comments since I don't see the changes we discussed last week.
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Outdated
Show resolved
Hide resolved
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, reviewed. Let's fix the tests!
GemmaCausalLMPreprocessor, | ||
) | ||
from keras_hub.src.models.gemma.gemma_tokenizer import ( | ||
GemmaTokenizer as KerasGemmaTokenizer, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not just GemmaTokenizer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HF also imports the tokenizer as GemmaTokenizer
… so just used KerasGemmaTokenizer
to avoid confusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, okay. Then it's better to call the HF import as HFGemmaTokenizer and the Keras one as GemmaTokenizer, I suppose
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Outdated
Show resolved
Hide resolved
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
The code changes introduce a utility function to export Keras Gemma models to Hugging Face format, saving the configuration, weights, and tokenizer assets. The review focuses on improving the robustness of weight mapping, adding checks for empty weight dictionaries, and enhancing the warning message for missing vocabulary files.
@@ -0,0 +1,161 @@ | |||
import json |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets make this a model-agnostic export utility. Rename file to safetensor_exporter.py
add a dict to maintain the mapping
MODEL_EXPORTERS = {
"GemmaBackbone": gemma_exporter.get_gemma_weights_map,
"LlamaBackbone": llama_exporter.get_llama_weights_map, # Future
}
and a user facing API function for the export
def export_to_safetensors(keras_model):
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implement exporter mapping for each model - for this PR's scope just the Gemma model that can serve as a prototype for other models
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, let's land this PR first and do this in a separate PR: #2290 (comment)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, awesome work! Made a few cosmetic changes
keras_hub/src/utils/transformers/convert_to_safetensor/export.py
Outdated
Show resolved
Hide resolved
keras_hub/src/utils/transformers/convert_to_safetensor/export.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets land the generic API in a different PR! Thanks for the great work!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just a few changes.
keras_hub/src/utils/transformers/convert_to_safetensor/export.py
Outdated
Show resolved
Hide resolved
keras_hub/src/utils/transformers/convert_to_safetensor/export.py
Outdated
Show resolved
Hide resolved
keras_hub/src/utils/transformers/convert_to_safetensor/export.py
Outdated
Show resolved
Hide resolved
"This is a test.", | ||
] | ||
proto_prefix = os.path.join(self.get_temp_dir(), "dummy_vocab") | ||
SentencePieceTrainer.train( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a follow up (not on this PR), let's consider using keras_hub/src/tests/test_data/gemma_test_vocab.spm
instead of retraining a new vocab here. Will be faster. Maybe leave a TODO?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually tried using it, and hit some issues. I didn't really dive deeper though, because I was lazy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually let me switch to approval so I won't block merging this. But let's address these comment before merge!
ty! |
Description of the change
Reference
Colab Notebook
https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharingChecklist