Llama3.1 #2132

Merged
50 commits merged on Apr 17, 2025
Changes from 47 commits
Commits
50 commits
a2e9207
Add Llama 3.1
pctablet505 Mar 5, 2025
c31f222
Update llama3_rotary_embedding.py
pctablet505 Mar 5, 2025
68afeb8
Update llama3_rotary_embedding.py
pctablet505 Mar 5, 2025
6dc38b9
Update llama31_attention.py
pctablet505 Mar 5, 2025
39b29e7
Update __init__.py
pctablet505 Mar 5, 2025
ea667a5
Update llama31_attention.py
pctablet505 Mar 5, 2025
021b553
Update llama31_causal_lm.py
pctablet505 Mar 5, 2025
9b38b25
Update llama31_attention.py
pctablet505 Mar 5, 2025
7989b74
Update llama3_rotary_embedding.py
pctablet505 Mar 5, 2025
510523f
Update llama31_decoder.py
pctablet505 Mar 5, 2025
593f867
Update llama31_attention.py
pctablet505 Mar 5, 2025
11a3fb7
Code fix
pctablet505 Mar 6, 2025
1e651c5
code fix
pctablet505 Mar 7, 2025
dba14d0
code refactoring
pctablet505 Mar 7, 2025
b9b9ce7
Merge branch 'keras-team:master' into llama3.1
pctablet505 Mar 7, 2025
e896b03
replaced bitwise_and to logical_and. bitwise_and is not supported for…
pctablet505 Mar 10, 2025
4919575
deleted files added by mistake.
pctablet505 Mar 10, 2025
6ef7abe
Added changes in llama3 to support llama3.1\n Modified rotary embeddi…
pctablet505 Mar 17, 2025
203ade4
removed llama31 from model, and added support for llama3.1 using same…
pctablet505 Mar 17, 2025
8c8a710
Update rotary_embedding.py
pctablet505 Mar 17, 2025
1d67423
Update rotary_embedding.py
pctablet505 Mar 17, 2025
f7c8a58
Added argument details, and optimized operations
pctablet505 Mar 18, 2025
f2c980f
optimized operations
pctablet505 Mar 18, 2025
f3dae63
optimized operations
pctablet505 Mar 18, 2025
70b18a4
Added argument details in docstring
pctablet505 Mar 18, 2025
4eea6f3
Improved the code to adapt llama3.1 and llama3.2
pctablet505 Mar 24, 2025
4b7893d
Update llama_attention.py
pctablet505 Mar 24, 2025
d212621
Update llama_rotary_embedding.py
pctablet505 Mar 24, 2025
f0c2d41
changed variable name
pctablet505 Mar 26, 2025
fd63dc4
Code refactoring
pctablet505 Mar 26, 2025
a11c021
Typo Fix
pctablet505 Mar 26, 2025
7d8d239
code fix
pctablet505 Mar 26, 2025
8e3f2af
Delete convert_llama3_checkpoints.py
pctablet505 Mar 26, 2025
7b172d1
Delete keras-hub
pctablet505 Mar 26, 2025
d53020b
Update llama_backbone.py
pctablet505 Mar 28, 2025
840d181
Update convert_llama3.py
pctablet505 Apr 1, 2025
b603c5b
Update convert_llama3.py
pctablet505 Apr 1, 2025
1655a63
moved llama_rotary_embedding from keras_hub.models to llama directory
pctablet505 Apr 4, 2025
e4d523b
Update llama_rotary_embedding.py
pctablet505 Apr 4, 2025
a08a403
Update llama_rotary_embedding.py
pctablet505 Apr 4, 2025
dc924c0
Update convert_llama3.py
pctablet505 Apr 4, 2025
184a808
Merge branch 'keras-team:master' into llama3.1
pctablet505 Apr 4, 2025
be205a8
docstring changes
pctablet505 Apr 7, 2025
b7988b4
Update llama_backbone.py
pctablet505 Apr 7, 2025
4e4fa14
Update llama_rotary_embedding.py
pctablet505 Apr 7, 2025
93b6d2b
Update llama_rotary_embedding.py
pctablet505 Apr 7, 2025
d453b1b
Update llama_rotary_embedding.py
pctablet505 Apr 7, 2025
672a6f2
typo fix
pctablet505 Apr 8, 2025
1de4e7d
Update convert_llama3.py
pctablet505 Apr 9, 2025
c034ed4
Added presets for llama3.1 and llama3.2
pctablet505 Apr 15, 2025
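One technical detail from the commit list: commit e896b03 swaps bitwise_and for logical_and when combining masks. As a rough illustration of the difference (a hypothetical mask setup, not the PR's actual masking code), `keras.ops.logical_and` is the portable elementwise AND for boolean tensors across backends:

```python
import numpy as np
from keras import ops

seq_len = 4
# Causal mask: position i may only attend to positions j <= i.
causal_mask = ops.tril(ops.ones((seq_len, seq_len), dtype="bool"))
# Padding mask: True for real tokens, False for padded positions.
padding_mask = ops.convert_to_tensor(np.array([True, True, True, False]))
# logical_and is the elementwise boolean AND; the commit notes that
# bitwise_and was not supported for this use on at least one backend.
combined_mask = ops.logical_and(causal_mask, padding_mask[None, :])
print(ops.convert_to_numpy(combined_mask))
```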
30 changes: 24 additions & 6 deletions keras_hub/src/models/llama/llama_attention.py
@@ -3,7 +3,9 @@
import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.models.llama.llama_rotary_embedding import (
LlamaRotaryEmbedding,
)
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available

@@ -16,7 +18,11 @@ def __init__(
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
rope_position_scaling_factor=1.0,
rope_frequency_adjustment_factor=None,
rope_low_freq_factor=None,
rope_high_freq_factor=None,
rope_pretraining_sequence_length=None,
kernel_initializer="glorot_uniform",
dropout=0,
**kwargs,
@@ -28,13 +34,16 @@

self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength
self.rope_position_scaling_factor = rope_position_scaling_factor
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
self.rope_low_freq_factor = rope_low_freq_factor
self.rope_high_freq_factor = rope_high_freq_factor
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

self.kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self.rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
# b = batch size
@@ -103,9 +112,13 @@ def build(self, inputs_shape):
)
self._output_dense.build((None, None, self.num_query_heads, head_dim))

self.rotary_embedding_layer = RotaryEmbedding(
self.rotary_embedding_layer = LlamaRotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
position_scaling_factor=self.rope_position_scaling_factor,
frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
low_freq_factor=self.rope_low_freq_factor,
high_freq_factor=self.rope_high_freq_factor,
pretraining_sequence_length=self.rope_pretraining_sequence_length,
dtype=self.dtype_policy,
)

@@ -224,6 +237,11 @@ def get_config(self):
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"rope_low_freq_factor": self.rope_low_freq_factor,
"rope_high_freq_factor": self.rope_high_freq_factor,
"rope_pretraining_sequence_length": (
self.rope_pretraining_sequence_length
),
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
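The `LlamaRotaryEmbedding` layer wired in above is not part of this diff view, but the new arguments correspond to the Llama 3.1 rope-scaling scheme. Below is a minimal NumPy sketch of that frequency adjustment, following the published Llama 3.1 reference recipe; the mapping of `frequency_adjustment_factor` and `pretraining_sequence_length` onto the reference recipe's `factor` and `original_max_position_embeddings`, and the default values shown, are assumptions rather than a description of this PR's exact implementation:

```python
import numpy as np

def adjust_rope_frequencies(
    inv_freq,                           # base rotary inverse frequencies
    frequency_adjustment_factor=8.0,    # assumed to map to "factor"
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    pretraining_sequence_length=8192,   # assumed original context length
):
    """Llama 3.1-style piecewise rescaling of rotary inverse frequencies."""
    wavelen = 2.0 * np.pi / inv_freq
    low_freq_wavelen = pretraining_sequence_length / low_freq_factor
    high_freq_wavelen = pretraining_sequence_length / high_freq_factor

    # Long wavelengths (low frequencies) are divided by the adjustment factor.
    scaled = inv_freq / frequency_adjustment_factor
    # Mid-band wavelengths interpolate smoothly between scaled and unscaled.
    smooth = (pretraining_sequence_length / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1.0 - smooth) * scaled + smooth * inv_freq

    adjusted = np.where(wavelen > low_freq_wavelen, scaled, inv_freq)
    return np.where(
        (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen),
        smoothed,
        adjusted,
    )

# Example: base inverse frequencies for head dim 128 and rope theta 500000.
inv_freq = 1.0 / (500000.0 ** (np.arange(0, 128, 2) / 128.0))
print(adjust_rope_frequencies(inv_freq)[:4])
```

The idea behind the scheme is that short wavelengths (high frequencies) are left untouched, so long-context extension does not degrade short-range attention.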
66 changes: 50 additions & 16 deletions keras_hub/src/models/llama/llama_backbone.py
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
constructor.

Args:
vocabulary_size (int): The size of the token vocabulary.
num_layers (int): The number of transformer layers.
num_query_heads (int): The number of query attention heads for
vocabulary_size: int. The size of the token vocabulary.
num_layers: int. The number of transformer layers.
num_query_heads: int. The number of query attention heads for
each transformer.
hidden_dim (int): The size of the transformer encoding and pooling
hidden_dim: int. The size of the transformer encoding and pooling
layers.
intermediate_dim (int): The output dimension of the first Dense layer in
intermediate_dim: int. The output dimension of the first Dense layer in
a three-layer feedforward network for each transformer.
num_key_value_heads (int): The number of key and value attention heads
num_key_value_heads: int. The number of key and value attention heads
for each transformer.
rope_max_wavelength (int, optional): The maximum angular wavelength of
rope_max_wavelength: int. The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_factor (float, optional): The scaling factor for
calculation of roatary embedding. Defaults to `1.0`.
layer_norm_epsilon (float, optional): Epsilon for the layer
normalization layers in the transformer decoder. Defaults to `1e-6`.
rope_position_scaling_factor: float. The scaling factor for
calculation of rotary embedding. Defaults to `1.0`.
rope_frequency_adjustment_factor: float. The scaling factor
used to scale the inverse frequencies. Defaults to `None`.
rope_low_freq_factor: float. The low frequency scaling
factor. Defaults to `None`.
[Collaborator review comment: flaot --> float. Here, and other places too.]
rope_high_freq_factor: float. Used for Llama3.1+. The high
frequency scaling factor. Defaults to `None`.
rope_pretraining_sequence_length: int. Used for Llama3.1+.
Defaults to `None`.
layer_norm_epsilon: float. Epsilon for the layer normalization layers
in the transformer decoder. Defaults to `1e-6`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
@@ -87,7 +95,11 @@ def __init__(
intermediate_dim,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
rope_position_scaling_factor=1.0,
rope_frequency_adjustment_factor=None,
rope_low_freq_factor=None,
rope_high_freq_factor=None,
rope_pretraining_sequence_length=None,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
@@ -110,7 +122,15 @@
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
rope_max_wavelength=rope_max_wavelength,
rope_scaling_factor=rope_scaling_factor,
rope_position_scaling_factor=rope_position_scaling_factor,
rope_frequency_adjustment_factor=(
rope_frequency_adjustment_factor
),
rope_low_freq_factor=rope_low_freq_factor,
rope_high_freq_factor=rope_high_freq_factor,
rope_pretraining_sequence_length=(
rope_pretraining_sequence_length
),
layer_norm_epsilon=layer_norm_epsilon,
activation=ops.silu,
kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ def __init__(
self.num_query_heads = num_query_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.rope_max_wavelength = rope_max_wavelength
self.num_key_value_heads = num_key_value_heads
self.rope_scaling_factor = rope_scaling_factor
self.rope_max_wavelength = rope_max_wavelength
self.rope_position_scaling_factor = rope_position_scaling_factor
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
self.rope_low_freq_factor = rope_low_freq_factor
self.rope_high_freq_factor = rope_high_freq_factor
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ def get_config(self):
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"rope_position_scaling_factor": (
self.rope_position_scaling_factor
),
"rope_frequency_adjustment_factor": (
self.rope_frequency_adjustment_factor
),
"rope_low_freq_factor": self.rope_low_freq_factor,
[Collaborator review comment: Are all the original Llama UTs passing?]
"rope_high_freq_factor": self.rope_high_freq_factor,
"rope_pretraining_sequence_length": (
self.rope_pretraining_sequence_length
),
"num_key_value_heads": self.num_key_value_heads,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
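For context, a hypothetical end-to-end use of the extended constructor (toy sizes; the rope values mirror the published Llama 3.1 `rope_scaling` configuration under the same assumed argument mapping as above, and the new keyword arguments exist only on this PR's branch):

```python
import numpy as np
from keras_hub.models import LlamaBackbone

# Toy configuration; the rope_* values follow the published Llama 3.1
# rope_scaling settings (factor 8, low_freq_factor 1, high_freq_factor 4,
# original context length 8192) under the assumed argument mapping.
backbone = LlamaBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    rope_max_wavelength=500_000,
    rope_position_scaling_factor=1.0,
    rope_frequency_adjustment_factor=8.0,
    rope_low_freq_factor=1.0,
    rope_high_freq_factor=4.0,
    rope_pretraining_sequence_length=8192,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, 64), i.e. (batch, sequence, hidden_dim)
```

Leaving the new arguments at their `None` defaults presumably falls back to the original Llama 3 behaviour, which is what the reviewer's question about the existing unit tests is probing.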
23 changes: 20 additions & 3 deletions keras_hub/src/models/llama/llama_decoder.py
@@ -21,7 +21,11 @@
num_query_heads,
num_key_value_heads,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
rope_position_scaling_factor=1.0,
rope_frequency_adjustment_factor=None,
rope_low_freq_factor=None,
rope_high_freq_factor=None,
rope_pretraining_sequence_length=None,
activation="silu",
layer_norm_epsilon=1e-5,
kernel_initializer="glorot_uniform",
@@ -34,7 +38,11 @@
self.num_key_value_heads = num_key_value_heads

self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_factor = rope_scaling_factor
self.rope_position_scaling_factor = rope_position_scaling_factor
self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
self.rope_low_freq_factor = rope_low_freq_factor
self.rope_high_freq_factor = rope_high_freq_factor
self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

self.dropout = dropout

@@ -53,7 +61,11 @@ def build(self, decoder_sequence_shape):
num_query_heads=self.num_query_heads,
num_key_value_heads=self.num_key_value_heads,
rope_max_wavelength=self.rope_max_wavelength,
rope_scaling_factor=self.rope_scaling_factor,
rope_position_scaling_factor=self.rope_position_scaling_factor,
rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
rope_low_freq_factor=self.rope_low_freq_factor,
rope_high_freq_factor=self.rope_high_freq_factor,
rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
kernel_initializer=clone_initializer(self.kernel_initializer),
dropout=self.dropout,
dtype=self.dtype_policy,
@@ -221,6 +233,11 @@ def get_config(self):
"num_query_heads": self.num_query_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"rope_low_freq_factor": self.rope_low_freq_factor,
"rope_high_freq_factor": self.rope_high_freq_factor,
"rope_pretraining_sequence_length": (
self.rope_pretraining_sequence_length
),
"num_key_value_heads": self.num_key_value_heads,
"activation": keras.activations.serialize(self.activation),
"layer_norm_epsilon": self.layer_norm_epsilon,