Add Esm #2244


Open

wants to merge 26 commits into master

Conversation

@pass-lin (Contributor) commented May 3, 2025

From #2177.
Achieved a smaller numerical error against the Hugging Face (HF) implementation:

```python
import os

os.environ["KERAS_BACKEND"] = "torch"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from keras import ops
from transformers.models.esm.modeling_esm import EsmModel

from keras_hub.src.models.esm.esm_backbone import ESMBackbone

# Load the reference Hugging Face model.
weights_path = "facebook/esm2_t6_8M_UR50D"
hf_model = EsmModel.from_pretrained(weights_path)
hf_model.cuda().eval()
hf_model.embeddings.token_dropout = False

# Load the converted KerasHub model from the same checkpoint.
keras_model = ESMBackbone.from_preset("hf://" + weights_path)
keras_model.summary()

# Run both models on the same input and compare outputs.
x = ops.array([[1, 2, 3, 4, 5]]) + 1
hf_out = hf_model(x, ops.ones_like(x))[0]
keras_out = keras_model({"token_ids": x})

print(ops.all(ops.isclose(hf_out, keras_out, atol=1e-4)))
```
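For a quantitative view, the boolean check can be replaced by the worst-case deviation (a small sketch extending the script above; assumes the same variables are in scope and both outputs live on the same device):

```python
from keras import ops

# Maximum absolute elementwise error between the HF and KerasHub outputs.
err = ops.max(ops.abs(hf_out - keras_out))
print(float(ops.convert_to_numpy(err)))
```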

ESM Checkpoint Conversion and Numerics Verification Demo (across multiple backends): Notebook Link

Train Demo: Notebook Link

@pass-lin (Contributor, Author) commented May 3, 2025

```
ruff.....................................................................Passed
ruff-format..............................................................Passed
Error: Process completed with exit code 1.
```

Please help me figure out how to solve this problem.

@mattdangerw (Member)

Probably an issue with generating the API symbols. It looks like you need to sync with the latest changes on master, then try running `./shell/api_gen.sh`.

@sachinprasadhs (Collaborator)

> Error: Process completed with exit code 1.
> Please help me figure out how to solve this problem.

You can rebase onto the latest master and then run `pre-commit run --all-files`, and `pip install -U namex`.

@pass-lin (Contributor, Author)
```
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_dtype_argument_tie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_dtype_argument_untie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_int8_tie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/layers/modeling/reversible_embedding_test.py::ReversibleEmbeddingTest::test_quantize_int8_untie_weights - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/albert/albert_backbone_test.py::AlbertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bart/bart_backbone_test.py::BartBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bert/bert_backbone_test.py::BertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/bloom/bloom_backbone_test.py::BloomBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/clip/clip_backbone_test.py::CLIPBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/deberta_v3/deberta_v3_backbone_test.py::DebertaV3BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/distil_bert/distil_bert_backbone_test.py::DistilBertBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/electra/electra_backbone_test.py::ElectraBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/f_net/f_net_backbone_test.py::FNetBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/falcon/falcon_backbone_test.py::FalconBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gemma/gemma_backbone_test.py::GemmaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gemma/gemma_backbone_test.py::Gemma2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gpt2/gpt2_backbone_test.py::GPT2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/gpt_neo_x/gpt_neo_x_backbone_test.py::GPTNeoXBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/llama/llama_backbone_test.py::LlamaTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/mistral/mistral_backbone_test.py::MistralBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/opt/opt_backbone_test.py::OPTBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py::PaliGemmaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py::PaliGemma2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/phi3/phi3_backbone_test.py::Phi3Test::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/phi3/phi3_backbone_test.py::Phi3Test::test_backbone_basics_with_su_rotary - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/roberta/roberta_backbone_test.py::RobertaBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/siglip/siglip_backbone_test.py::SigLIPBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/siglip/siglip_backbone_test.py::SigLIP2BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/t5/t5_backbone_test.py::T5BackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/whisper/whisper_backbone_test.py::WhisperBackboneTest::test_backbone_basics - TypeError: _int8_build() takes 2 positional arguments but 3 were given
FAILED keras_hub/src/models/xlm_roberta/xlm_roberta_backbone_test.py
```

@mattdangerw @sachinprasadhs
Is this a problem with the test environment? Why are there so many failures unrelated to my changes?

@sachinprasadhs (Collaborator)

It's not related to your code; it looks like some issue with the JAX backend. We will look into it.

@sachinprasadhs (Collaborator) left a comment

Thanks for the PR. I have added my comments; also add the checkpoint conversion script under keras-hub/tools/checkpoint_conversion.

```
intermediate_dim: int. The output dimension of the first Dense layer in
    a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.
layer_norm_eps:bool.Should we use ln after embedding?
```
@sachinprasadhs (Collaborator)

I didn't get the point here. Are you asking for our input, or is this the arg detail? If it is the arg detail, it needs to be rephrased: avoid question marks, and the argument name is emb_layer_norm_before.

The layer_norm_eps description also needs to be updated.

@pass-lin (Contributor, Author) commented May 17, 2025

@sachinprasadhs @mattdangerw
Can anybody review my code?

@pass-lin (Contributor, Author) commented Jun 2, 2025

@mattdangerw @sachinprasadhs
Please check my code, thank you.

@sachinprasadhs (Collaborator) left a comment

Added a few more comments; a few of the previous review comments still need to be addressed.

```
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.
Args:
```
@sachinprasadhs (Collaborator)

The activation and max_wavelength descriptions are still missing!

```
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind.
Args:
```
@sachinprasadhs (Collaborator)

Add an arg description for pad_token_id as well.

Comment on lines 45 to 46:

```
position_embedding_type:esm1 use abs position embeding,esm2 use rope.
    so this parameter is only except for absolute and rotary.
```
@sachinprasadhs (Collaborator) commented Jun 2, 2025

This still needs to be changed to:

```
position_embedding_type: str. The position embedding type to use. One of
    "absolute" and "rotary". Use "absolute" for ESM1. Use "rotary" for
    ESM2. Defaults to "rotary".
```


```python
@keras_hub_export("keras_hub.models.ESMProteinClassifierPreprocessor")
class ESMProteinClassifierPreprocessor(BertTextClassifierPreprocessor):
```
@sachinprasadhs (Collaborator) commented Jun 2, 2025

Pending change here: this should be subclassed from TextClassifierPreprocessor instead of BertTextClassifierPreprocessor.
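A minimal sketch of the requested change (the import paths are assumptions; the exact module layout in keras_hub may differ):

```python
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.text_classifier_preprocessor import (
    TextClassifierPreprocessor,
)


@keras_hub_export("keras_hub.models.ESMProteinClassifierPreprocessor")
class ESMProteinClassifierPreprocessor(TextClassifierPreprocessor):
    # ESM-specific tokenizer and packing configuration would go here,
    # instead of inheriting behavior from the BERT task class.
    ...
```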

```python
max_sequence_length=1024,
max_wavelength=10000,
layer_norm_eps=1e-12,
emb_layer_norm_before=False,
```
@sachinprasadhs (Collaborator) commented Jun 2, 2025

Pending change: instead of emb_layer_norm_before, use use_pre_layer_norm.



```python
@keras_hub_export("keras_hub.models.ESMProteinClassifier")
class ESMProteinClassifier(RobertaTextClassifier):
```
@sachinprasadhs (Collaborator) commented Jun 2, 2025

Pending change: you can subclass TextClassifier and make the same changes as RobertaTextClassifier, instead of subclassing from another model. A sketch of that structure follows.
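A sketch only (hypothetical import path; the comment describes the head that RobertaTextClassifier builds on top of its backbone):

```python
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.text_classifier import TextClassifier


@keras_hub_export("keras_hub.models.ESMProteinClassifier")
class ESMProteinClassifier(TextClassifier):
    # Replicate the Roberta-style functional model here
    # (backbone -> pooled token -> dropout -> dense head),
    # rather than inheriting from another model's task class.
    ...
```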

@sachinprasadhs (Collaborator)

Once you address all the comments, add an end-to-end working Colab along with the checkpoint conversion script under keras-hub/tools/checkpoint_conversion.

@pass-lin (Contributor, Author)

> Thanks, a few minor comments.
> Also, we need more details specific to the Keras 3.6 / older-version issue.
> Finally, in the PR description, add the Colab notebook to show end-to-end working of the model and numerics verification. You can follow the PR description template from a recent PR.

How do I add a Colab notebook? Can you give me a demo?

@sachinprasadhs (Collaborator)

> How do I add a Colab notebook? Can you give me a demo?

Here is an example from a recent PR that got merged; you can do something like this:

  • DeiT Checkpoint Conversion and Numerics Verification Demo (across multiple backends): Notebook Link
  • DeiT End-to-End Demo (zero-shot and finetuning): Notebook Link
  • Here are the converted DeiT presets from Hugging Face checkpoints for reference.

@pass-lin (Contributor, Author)

Hello, I've already added the Colab demo of tools/checkpoint_conversion/convert_esm_checkpoints.py to the PR description. I think this is enough, and we can refer to BERT for the rest.
Can we merge now?

@sachinprasadhs (Collaborator)

We don't have access to view the notebook; can you make it public? Thanks.

@pass-lin (Contributor, Author)

OK, sharing has been enabled.

@sachinprasadhs (Collaborator)

Hi, the intention of the notebook is to verify the correctness of the model, including the backbone and tasks, with usage details and the expected outcome, and to verify numerical stability after the weights are transferred to the Keras architecture, via a forward pass.

@pass-lin (Contributor, Author)

Okay, I've added another notebook: a demo for predicting the suitable pH of enzymes using ESM.

@sachinprasadhs (Collaborator)

You can remove the esm2_t6_8M directory; it will be generated by the conversion script you have provided and will be uploaded to Kaggle.

The notebook you have provided doesn't have a predict call; take any suitable sample input and display the output with predict.

Also, in your conversion script you have set atol=1e-3; what would the error percentage be at atol=1e-4? We also need the following things in your notebook (see the sketch after this list):

  • Numerics verification: load the original ESM model and do a forward pass, do the same forward pass with the Keras-Hub ESM implementation, and compare the numerics layer by layer to show whether they match (preferably to 1e-4 precision).
  • Demonstrate usage of the preprocessor, tokenizer, and other functionalities of ESM.
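For the layer-by-layer comparison, a minimal helper might look like this (a sketch only; the layer names and the way intermediate activations are captured are assumptions, not the PR's actual code):

```python
import numpy as np

def report_diff(name, hf_activation, keras_activation, atol=1e-4):
    # Print the max absolute difference between two activations,
    # assuming both have already been converted to NumPy arrays.
    diff = float(np.max(np.abs(hf_activation - keras_activation)))
    print(f"{name}: max_abs_diff={diff:.2e} within_atol={diff < atol}")

# Hypothetical usage, comparing the embedding output and the first block:
# report_diff("embeddings", hf_emb_out, keras_emb_out)
# report_diff("encoder_layer_0", hf_layer0_out, keras_layer0_out)
```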

I have provided the reference notebooks; please refer to those.

You can keep only the ESM changes in this PR and create a new PR for RoFormer, which also needs a checkpoint conversion script, so that we can maintain the latest weights on Kaggle by regenerating them with the script after any future model-specific changes in Keras Hub.

@pass-lin
Copy link
Contributor Author

You can remove the esm2_t6_8M directory, that will be generated using the conversion script you have provided and will be uploaded to Kaggle.

The notebook which you have provided doesn't have predict method, take any sample suitable input and display the output with predict.

Also in your conversion script, you have mentioned atol=1e-3, what would be the error percentage when the atol=1e-04 and we need following things in your notebook

  • Numerics verification, load the original ESM model and do forward pass, and do the same forward pass to Keras-Hub ESM implementation and compare the numerics layer by layer to show if numerics are matching(preferably to the 1e-4 precision)
  • Demonstrating usage of proprocessor, Tokenizer and other functionalities of ESM

I have provided the reference notebooks, please refer those.

You can keep only ESM changes in this PR, you can create a new PR for roformer which also needs checkpoint conversion script, so that we can maintain the latest weight in Kaggle by generating the new weights with the script with any future changes to Keras Hub model specific.

OK, I have modified the notebook, please check. In addition, roformerV2 does not need to convert scripts, it is a native keras model. I just modified the keras2 api

@pass-lin (Contributor, Author)

@sachinprasadhs, please check my notebook.

@sachinprasadhs sachinprasadhs added the kokoro:force-run Runs Tests on GPU label Jul 9, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jul 9, 2025
@pass-lin (Contributor, Author)

> Hi, your notebook still does not demonstrate actual use-case examples like https://huggingface.co/docs/transformers/en/model_doc/esm#transformers.EsmForSequenceClassification.forward.example or https://huggingface.co/docs/transformers/en/model_doc/esm#transformers.EsmForProteinFolding.forward.example or https://huggingface.co/docs/transformers/en/model_doc/esm#transformers.EsmForTokenClassification.forward.example; please include them.

We've included a training demo for ESM. As for ESMFold, that would be a brand-new PR. So could you point me to exactly which demo to add? Sorry for the trouble.

@sachinprasadhs (Collaborator)

Any demo with your implementation that predicts on actual data, or on sample input data, and displays the output in the existing Colab. Also remove the folder/directory named esm2_t6_8M from your code; the rest all looks good.
Thanks for all the work.

@divyashreepathihalli (Collaborator)

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for ESM models, including the backbone, classifier, and masked protein language modeling tasks, along with their corresponding preprocessors, tokenizers, and tests. I've identified several areas for improvement, including fixing a critical bug in an exception raise, correcting several documentation examples and descriptions that could mislead users, and addressing inconsistencies in model configuration and weight conversion. Addressing the feedback will improve the quality and robustness of the new ESM model support.

```python
if self.use_rotary:
    qw, kw = self.rotary_embedding_layer(qw, kw)
if version.parse(keras.__version__) < version.parse("3.6"):
    raise ("Please make sure your Keras version is >=3.6.")
```

critical

Raising a string or a tuple does not work as intended in Python 3 and will result in a TypeError. You should raise an instance of an exception class, such as ValueError.

```python
raise ValueError("Please make sure your Keras version is >=3.6.")
```
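In context, the corrected guard would read (a runnable sketch; assumes the packaging library, which the quoted code already uses for version.parse):

```python
import keras
from packaging import version

# Raise an exception *instance*; `raise ("...")` is itself a TypeError
# in Python 3, masking the real error message.
if version.parse(keras.__version__) < version.parse("3.6"):
    raise ValueError("Please make sure your Keras version is >=3.6.")
```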

Comment on lines 69 to 72:

```python
    "token_ids": np.ones(shape=(2, 12), dtype="int32"),
    "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2),
}
```

high

The example for preprocessed integer data includes "segment_ids" and "padding_mask" in the input features. However, the ESMBackbone and ESMProteinClassifierPreprocessor only expect "token_ids". This example is misleading and will not work as written. Please update it to only include "token_ids".

Suggested change:

```python
features = {
    "token_ids": np.ones(shape=(2, 12), dtype="int32"),
}
```

Comment on lines +146 to +150:

```python
    hf_weight_key=f"{hf_prefix}{i}.LayerNorm.weight",
)
loader.port_weight(
    keras_variable=block.feedforward_norm.beta,
    hf_weight_key=f"{hf_prefix}{i}.LayerNorm.bias",
```

high

The Hugging Face weight key for the feedforward layer normalization seems to be incorrect. Based on the transformers implementation of EsmLayer, the LayerNorm is part of the output block. The key should likely be encoder.layer.{i}.output.LayerNorm.weight instead of encoder.layer.{i}.LayerNorm.weight.

Suggested change:

```python
    hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.weight",
)
loader.port_weight(
    keras_variable=block.feedforward_norm.beta,
    hf_weight_key=f"{hf_prefix}{i}.output.LayerNorm.bias",
```

@pass-lin (Contributor, Author)

This is a hallucination by the review bot; I recommend ignoring it.

```python
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
head_size=64,
```

high

The ESMBackbone constructor does not accept a head_size argument. This parameter is calculated internally as hidden_dim // num_heads. Including it in the example will cause an error for users who copy-paste the code. Please remove this line from the example.

Comment on lines 52 to 56:

```python
    keras_variable=backbone.get_layer("embeddings_layer_norm").gamma,
    hf_weight_key="embeddings.layer_norm.weight",
)
loader.port_weight(
    keras_variable=backbone.get_layer("embeddings_layer_norm").beta,
```
high

The LayerNormalization layer for embeddings in ESMBackbone is named "emb_layer_norm", but the conversion script is trying to access it with the name "embeddings_layer_norm". This will cause a ValueError when trying to get the layer.

Suggested change:

```python
    keras_variable=backbone.get_layer("emb_layer_norm").gamma,
    hf_weight_key="embeddings.layer_norm.weight",
)
loader.port_weight(
    keras_variable=backbone.get_layer("emb_layer_norm").beta,
```

```python
self.dropout = dropout
self.max_wavelength = max_wavelength
self.head_size = head_size
self.dropout = dropout
```

medium

The self.dropout attribute is assigned twice in the __init__ method (lines 195 and 198). The second assignment is redundant and can be removed.

```python
# Unbatched input.
tokenizer = keras_hub.models.ESMTokenizer.from_preset(
    "roformer_v2_base_zh",
```

medium

The example for ESMTokenizer uses a roformer_v2_base_zh preset, which is confusing. It would be better to use an actual ESM preset to avoid misleading users.

Suggested change:

```python
    "hf://facebook/esm2_t6_8M_UR50D",
```

```python
raise ValueError(
    '`position_embedding_type` must be either `"rotary"`, or '
    '`"absolute"`. Received '
    "position_embedding_type={position_embedding_type}."
```

medium

The string in the ValueError is intended to be an f-string to include the value of position_embedding_type, but it's missing the f prefix. This will result in the literal string {position_embedding_type} being part of the error message.

Suggested change:

```python
    f"position_embedding_type={position_embedding_type}."
```

"""A ESM2 and ESM encoder network.

This class implements a bi-directional Transformer-based encoder as
described in ["Roformer"](https://github.com/facebookresearch/esm).

medium

The docstring mentions "Roformer" but links to the ESM repository. To avoid confusion, the link text should be updated to "ESM" to match the model being implemented.

Suggested change:

```python
described in ["ESM"](https://github.com/facebookresearch/esm).
```

```python
from keras_hub.src.models.esm.esm_masked_plm import (
    ESMMaskedPLM as ESM2MaskedPLM,
)
from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM
```

medium

This import of ESMMaskedPLM is redundant. The name ESMMaskedPLM is already available from the import on lines 200-202. Removing this line will improve code clarity. Even though this file is autogenerated, it's good practice to address such issues in the source generator if possible.

@pass-lin (Contributor, Author) commented Jul 11, 2025

> ...remove the folder/directory named esm2_t6_8M in your code...

I'm not sure what you mean by "delete the esm2_t6_8M directory."

Looking at the reference demo notebook, all it does is install the environment, set the environment variables, and then run:

```
python tools/checkpoint_conversion/convert_deit_checkpoints.py --preset deit-base-distilled-patch16-384
```

In my notebook I did exactly the same thing: installed the environment, set the environment variables, and then ran:

```
python tools/checkpoint_conversion/convert_esm_checkpoints.py --preset esm2_t6_8M
```

Could you give a more precise and detailed description of which notebook has the problem and what it is missing compared to the reference notebook?
In the reference notebook, what exactly shows that the esm2_t6_8M directory should be removed?

Further, in another notebook I explicitly provide demonstrations of predict, fit, and evaluate. What exactly is still missing?


A clear description would be greatly appreciated; thank you for your help!
And sorry for the extra work that writing a detailed description causes you.

@pass-lin (Contributor, Author) commented Jul 11, 2025

> /gemini review

Thanks, I fixed some errors based on Gemini's review.

@sachinprasadhs (Collaborator)

In your commit, there are generated checkpoint files. We don't keep these files in our GitHub repo; we upload the checkpoints to Kaggle / Hugging Face, and the files are generated by running the conversion script you have provided, so you don't need to include the converted checkpoints in your commit.

@pass-lin (Contributor, Author)

I'm so sorry; these were intermediate artifacts from an earlier test run, and I didn't notice their existence. Thanks very much for the reminder.
