Skip to content

Add Esm #2244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open

Add Esm #2244

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,22 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESM2Backbone
from keras_hub.src.models.esm.esm_backbone import ESMBackbone as ESMBackbone

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This import of ESMBackbone is redundant. The name ESMBackbone is already available in the module's namespace from the import on the previous line. Removing this line will make the code cleaner. While I see the file is autogenerated, this redundancy should ideally be fixed in the generation script.

from keras_hub.src.models.esm.esm_classifier import (
ESMProteinClassifier as ESMProteinClassifier,
)
from keras_hub.src.models.esm.esm_classifier_preprocessor import (
ESMProteinClassifierPreprocessor as ESMProteinClassifierPreprocessor,
)
from keras_hub.src.models.esm.esm_masked_plm import (
ESMMaskedPLM as ESM2MaskedPLM,
)
from keras_hub.src.models.esm.esm_masked_plm import ESMMaskedPLM as ESMMaskedPLM

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This import of ESMMaskedPLM is redundant. The name ESMMaskedPLM is already available from the import on lines 200-202. Removing this line will improve code clarity. Even though this file is autogenerated, it's good practice to address such issues in the source generator if possible.

from keras_hub.src.models.esm.esm_masked_plm_preprocessor import (
ESMMaskedPLMPreprocessor as ESMMaskedPLMPreprocessor,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_backbone import (
FNetBackbone as FNetBackbone,
)
Expand Down
1 change: 1 addition & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from keras_hub.src.models.electra.electra_tokenizer import (
ElectraTokenizer as ElectraTokenizer,
)
from keras_hub.src.models.esm.esm_tokenizer import ESMTokenizer as ESMTokenizer
from keras_hub.src.models.f_net.f_net_tokenizer import (
FNetTokenizer as FNetTokenizer,
)
Expand Down
Empty file.
95 changes: 95 additions & 0 deletions keras_hub/src/models/esm/esm_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import keras
from keras import ops
from packaging import version

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.roformer_v2.roformer_v2_attention import (
RoformerAttention,
)


class ESMRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_embedding(self, x, position=1):
dim = x.shape[-1]
inv_freq = self.scaling_factor / (
self.max_wavelength ** (ops.arange(0, dim, 2, dtype=x.dtype) / dim)
)
t = ops.arange(x.shape[position], dtype=x.dtype)
freqs = ops.outer(t, inv_freq)
emb = ops.concatenate((freqs, freqs), axis=-1)

cos_emb = ops.cos(emb)[None, :, None, :]
sin_emb = ops.sin(emb)[None, :, None, :]
return cos_emb, sin_emb

def call(self, q, k, position=1):
cos_emb, sin_emb = self._compute_cos_sin_embedding(q, position)

return (
self.apply_rotary_pos_emb(q, cos_emb, sin_emb),
self.apply_rotary_pos_emb(k, cos_emb, sin_emb),
)

def rotate_half(self, x):
x1, x2 = ops.split(x, 2, -1)
return ops.concatenate((-x2, x1), axis=-1)

def apply_rotary_pos_emb(self, x, cos, sin):
cos = cos[:, : x.shape[1], :, :]
sin = sin[:, : x.shape[1], :, :]

return (x * cos) + (self.rotate_half(x) * sin)


class EsmSelfAttention(RoformerAttention):
"""MultiHeadAttention by ESM2

Referred to the implementation of HuggingFace.
In fact, this part of the calculation is exactly the same as RoFormer.
Only the calculation of the rotary part is different.
"""

def __init__(self, use_rotary=True, **kwargs):
super().__init__(**kwargs)
self.use_rotary = use_rotary

def build(self, input_shape):
super().build(input_shape)
if self.use_rotary:
self.rotary_embedding_layer = ESMRotaryEmbedding(
max_wavelength=self.max_wavelength, dtype=self.dtype_policy
)
self.rotary_embedding_layer.build([])

def call(self, x, attention_mask=None):
qw = self.q_dense(x)
kw = self.k_dense(x)
vw = self.v_dense(x)

b, s = ops.shape(qw)[:2]
qw = ops.reshape(qw, (b, s, self.heads, self.head_size))
kw = ops.reshape(kw, (b, s, self.heads, self.head_size))
vw = ops.reshape(vw, (b, s, self.heads, self.head_size))

if self.use_rotary:
qw, kw = self.rotary_embedding_layer(qw, kw)
if version.parse(keras.__version__) < version.parse("3.6"):
raise ValueError("Please make sure your Keras version is >=3.6.")
flash_attention = keras.config.is_flash_attention_enabled()
attention_mask = ops.reshape(attention_mask, [b, 1, s, 1])
if keras.config.backend() == "torch":
attention_mask = ops.repeat(attention_mask, s, -1)
attention_mask = ops.transpose(attention_mask, [0, 1, 3, 2])
o = ops.dot_product_attention(
qw, kw, vw, mask=attention_mask, flash_attention=flash_attention
)
return self.o_dense(ops.reshape(o, [b, s, -1]))

def get_config(self):
config = super().get_config()
config.update(
{
"use_rotary": self.use_rotary,
}
)
return config
227 changes: 227 additions & 0 deletions keras_hub/src/models/esm/esm_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
import keras
from keras import activations

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.esm.esm_encoder import ESMEncoder


def esm2_kernel_initializer(stddev=0.02):
return keras.initializers.TruncatedNormal(stddev=stddev)


@keras_hub_export(
["keras_hub.models.ESM2Backbone", "keras_hub.models.ESMBackbone"]
)
class ESMBackbone(Backbone):
"""A ESM2 and ESM encoder network.

This class implements a bi-directional Transformer-based encoder as
described in ["ESM"](https://github.com/facebookresearch/esm).

The default constructor gives a fully customizable, randomly initialized
ESM2 encoder with any number of layers, heads, and embed dim.To
load preset architectures and weights, use the `from_preset()` constructor.


Args:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Defaults to in the arg description wherever you're using default values.
max_wavelength arg detail is missing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still activation and max_wavelength description is missing!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add arg description for pad_token_id as well

vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_heads: int. The number of attention heads for each transformer.
The hidden size must be divisible by the number of attention heads.
hidden_dim: int. The size of the transformer encoding and pooler layers.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer.
dropout: float. Dropout probability for the Transformer encoder.
Defaults to 0.1
use_pre_layer_norm:bool.If true, then layer norm will be used before
entering the transformer block.
Since it's pre-norm, the default is false.
max_sequence_length: int. The maximum sequence length that this encoder
can consume. If None, `max_sequence_length` uses the value from
sequence length. This determines the variable shape for positional
embeddings.
position_embedding_type: str. The position embedding type to use.
One of "absolute" and "rotary".
Use "absolute" for ESM1. Use "rotary" for ESM2. Defaults to "rotary"
max_wavelength : int. The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
activation :string or keras.activations. The activation to
use for the transformer. Defaults to `"gelu"`.
pad_token_id: int.padding token id. Normally 0,
but is set to 1 in the esm2 model
dtype: None or str or keras.mixed_precision.DTypePolicy. The dtype to
use for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.

Examples:
```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
}

# Pretrained ESM2 encoder.
model = keras_hub.models.ESM2Backbone.from_preset('hf://facebook/esm2_t6_8M_UR50D')
model(input_data)

# Randomly initialized ESM2 encoder with a custom config.
model = keras_hub.models.ESM2Backbone(
vocabulary_size=30552,
num_layers=4,
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
head_size = 64,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The ESMBackbone constructor does not accept a head_size argument. This parameter is calculated internally as hidden_dim // num_heads. Including it in the example will cause an error for users who copy-paste the code. Please remove this line from the example.

)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
num_layers,
num_heads,
hidden_dim,
intermediate_dim,
use_bias=True,
activation="gelu",
dropout=0.1,
dtype=None,
max_sequence_length=1024,
max_wavelength=10000,
layer_norm_eps=1e-12,
use_pre_layer_norm=False,
position_embedding_type="rotary",
pad_token_id=0,
**kwargs,
):
if position_embedding_type not in (
"rotary",
"absolute",
):
raise ValueError(
'`position_embedding_type` must be either `"rotary"`, or '
'`"absolute"`. Received '
f"position_embedding_type={position_embedding_type}."
)
head_size = hidden_dim // num_heads
# === Layers ===
self.token_embedding = keras.layers.Embedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
embeddings_initializer=esm2_kernel_initializer(),
dtype=dtype,
name="token_embedding",
)
if position_embedding_type == "absolute":
self.position_embedding = PositionEmbedding(
initializer=esm2_kernel_initializer(),
sequence_length=max_sequence_length,
dtype=dtype,
name="position_embedding",
)
self.embeddings_add = keras.layers.Add(
dtype=dtype,
name="embeddings_add",
)

self.output_layer_norm = keras.layers.LayerNormalization(
epsilon=layer_norm_eps,
dtype=dtype,
name="output_layer_norm",
)
if use_pre_layer_norm:
self.emb_layer_norm = keras.layers.LayerNormalization(
epsilon=layer_norm_eps,
dtype=dtype,
name="emb_layer_norm",
)
self.transformer_layers = []
for i in range(num_layers):
layer = ESMEncoder(
heads=num_heads,
head_size=head_size,
intermediate_size=intermediate_dim,
use_bias=use_bias,
max_wavelength=max_wavelength,
dropout=dropout,
activation=activation,
kernel_initializer=esm2_kernel_initializer(),
layer_norm_eps=layer_norm_eps,
dtype=dtype,
use_rotary=position_embedding_type == "rotary",
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)

# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

attention_mask = keras.ops.not_equal(token_id_input, pad_token_id)

token_vector = self.token_embedding(token_id_input)
if position_embedding_type == "absolute":
position_vector = self.position_embedding(
token_vector, start_index=pad_token_id
)
x = self.embeddings_add([token_vector, position_vector])
else:
x = token_vector
if use_pre_layer_norm:
x = self.emb_layer_norm(x)
for transformer_layer in self.transformer_layers:
x = transformer_layer(x, attention_mask=attention_mask)
output = self.output_layer_norm(x)
super().__init__(
inputs={
"token_ids": token_id_input,
},
outputs=output,
dtype=dtype,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_heads = num_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.dropout = dropout
self.max_wavelength = max_wavelength
self.head_size = head_size
self.activation = activations.get(activation)
self.use_bias = use_bias
self.start_token_index = 0
self.layer_norm_eps = layer_norm_eps
self.max_sequence_length = max_sequence_length
self.use_pre_layer_norm = use_pre_layer_norm
self.position_embedding_type = position_embedding_type
self.pad_token_id = pad_token_id

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"num_layers": self.num_layers,
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"dropout": self.dropout,
"max_wavelength": self.max_wavelength,
"use_bias": self.use_bias,
"activation": activations.serialize(self.activation),
"layer_norm_eps": self.layer_norm_eps,
"use_pre_layer_norm": self.use_pre_layer_norm,
"position_embedding_type": self.position_embedding_type,
"max_sequence_length": self.max_sequence_length,
"pad_token_id": self.pad_token_id,
}
)
return config
39 changes: 39 additions & 0 deletions keras_hub/src/models/esm/esm_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import keras
import pytest
from keras import ops
from packaging import version

from keras_hub.src.models.esm.esm_backbone import ESMBackbone
from keras_hub.src.tests.test_case import TestCase


class ESMBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 10,
"num_layers": 2,
"num_heads": 1,
"hidden_dim": 2,
"intermediate_dim": 4,
}
self.input_data = {
"token_ids": ops.ones((2, 5), dtype="int32"),
}

def test_backbone_basics(self):
if version.parse(keras.__version__) < version.parse("3.6"):
self.skipTest("Failing on keras lower version")
self.run_backbone_test(
cls=ESMBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=ESMBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
Loading
Loading