Llama3.1 #2132
Merged
Commits (50)
a2e9207
Add Llama 3.1
pctablet505 c31f222
Update llama3_rotary_embedding.py
pctablet505 68afeb8
Update llama3_rotary_embedding.py
pctablet505 6dc38b9
Update llama31_attention.py
pctablet505 39b29e7
Update __init__.py
pctablet505 ea667a5
Update llama31_attention.py
pctablet505 021b553
Update llama31_causal_lm.py
pctablet505 9b38b25
Update llama31_attention.py
pctablet505 7989b74
Update llama3_rotary_embedding.py
pctablet505 510523f
Update llama31_decoder.py
pctablet505 593f867
Update llama31_attention.py
pctablet505 11a3fb7
Code fix
pctablet505 1e651c5
code fix
pctablet505 dba14d0
code refactoring
pctablet505 b9b9ce7
Merge branch 'keras-team:master' into llama3.1
pctablet505 e896b03
replaced bitwise_and to logical_and. bitwise_and is not supported for…
pctablet505 4919575
deleted files added by mistake.
pctablet505 6ef7abe
Added changes in llama3 to support llama3.1\n Modified rotary embeddi…
pctablet505 203ade4
removed llama31 from model, and added support for llama3.1 using same…
pctablet505 8c8a710
Update rotary_embedding.py
pctablet505 1d67423
Update rotary_embedding.py
pctablet505 f7c8a58
Added argument details, and optimized operations
pctablet505 f2c980f
optimized operations
pctablet505 f3dae63
optimized operations
pctablet505 70b18a4
Added argument details in docstring
pctablet505 4eea6f3
Improved the code to adapt llama3.1 and llama3.2
pctablet505 4b7893d
Update llama_attention.py
pctablet505 d212621
Update llama_rotary_embedding.py
pctablet505 f0c2d41
changed variable name
pctablet505 fd63dc4
Code refactoring
pctablet505 a11c021
Typo Fix
pctablet505 7d8d239
code fix
pctablet505 8e3f2af
Delete convert_llama3_checkpoints.py
pctablet505 7b172d1
Delete keras-hub
pctablet505 d53020b
Update llama_backbone.py
pctablet505 840d181
Update convert_llama3.py
pctablet505 b603c5b
Update convert_llama3.py
pctablet505 1655a63
moved llama_rotary_embedding from keras_hub.models to llama directory
pctablet505 e4d523b
Update llama_rotary_embedding.py
pctablet505 a08a403
Update llama_rotary_embedding.py
pctablet505 dc924c0
Update convert_llama3.py
pctablet505 184a808
Merge branch 'keras-team:master' into llama3.1
pctablet505 be205a8
docstring changes
pctablet505 b7988b4
Update llama_backbone.py
pctablet505 4e4fa14
Update llama_rotary_embedding.py
pctablet505 93b6d2b
Update llama_rotary_embedding.py
pctablet505 d453b1b
Update llama_rotary_embedding.py
pctablet505 672a6f2
typo fix
pctablet505 1de4e7d
Update convert_llama3.py
pctablet505 c034ed4
Added presets for llama3.1 and llama3.2
Files changed
llama_backbone.py
@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
         constructor.
 
     Args:
-        vocabulary_size (int): The size of the token vocabulary.
-        num_layers (int): The number of transformer layers.
-        num_query_heads (int): The number of query attention heads for
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads : int. The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
+        hidden_dim : int. The size of the transformer encoding and pooling
             layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in
+        intermediate_dim : int. The output dimension of the first Dense layer in
             a three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads
+        num_key_value_heads : int. The number of key and value attention heads
             for each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of
+        rope_max_wavelength : int. The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer
-            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        rope_position_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`
+        rope_frequency_adjustment_factor: flaot. The scaling factor
+            used to scale the inverse frequencies. Defaults to `None`.
+        rope_low_freq_factor: flaot. The low frequency scaling
+            factor. Defaults to `None`.
+        rope_high_freq_factor: flaot. Used for Llama3.1+. The high
+            frequency scaling factor. Defaults to `None`.
+        rope_pretraining_sequence_length: int. Used for Llama3.1+.
+            Defaults to `None`.
+        layer_norm_epsilon : flaot. Epsilon for the layer normalization layers
+            in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
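For review context, here is a minimal NumPy sketch of the frequency adjustment that these new `rope_*` arguments appear to parameterize, assuming the PR follows the published Llama 3.1 RoPE scaling scheme. The helper name and default values below are illustrative assumptions, not code from this PR.

```python
import numpy as np

# Hypothetical helper (not from this PR) showing the published Llama 3.1
# RoPE frequency adjustment. Defaults mirror Meta's Llama 3.1 release values.
def adjust_inv_freq(
    inv_freq,
    frequency_adjustment_factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    pretraining_sequence_length=8192,
):
    wavelen = 2.0 * np.pi / inv_freq
    low_freq_wavelen = pretraining_sequence_length / low_freq_factor
    high_freq_wavelen = pretraining_sequence_length / high_freq_factor

    # Long-wavelength (low-frequency) components are slowed down by the
    # adjustment factor; short-wavelength components are left untouched.
    scaled = inv_freq / frequency_adjustment_factor
    adjusted = np.where(wavelen > low_freq_wavelen, scaled, inv_freq)

    # The band between the two thresholds is interpolated smoothly.
    smooth = (pretraining_sequence_length / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    interpolated = (1.0 - smooth) * scaled + smooth * inv_freq
    is_mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return np.where(is_mid, interpolated, adjusted)


# Base inverse frequencies for a head dimension of 128 and wavelength 500_000.
head_dim = 128
inv_freq = 1.0 / (500_000.0 ** (np.arange(0, head_dim, 2) / head_dim))
print(adjust_inv_freq(inv_freq)[:4])
```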
@@ -87,7 +95,11 @@ def __init__(
         intermediate_dim,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
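A minimal construction sketch using the new signature, for reviewers who want to try it locally. This assumes the usual keras-hub backbone inputs (`token_ids`, `padding_mask`); the tiny hyperparameters are arbitrary, and the RoPE values shown are the published Llama 3.1 defaults rather than anything taken from this PR.

```python
import numpy as np
from keras_hub.models import LlamaBackbone

# Tiny, randomly initialized backbone exercising the new RoPE arguments.
backbone = LlamaBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    rope_max_wavelength=500_000,
    rope_frequency_adjustment_factor=8.0,
    rope_low_freq_factor=1.0,
    rope_high_freq_factor=4.0,
    rope_pretraining_sequence_length=8192,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, 64) -- sequence outputs at hidden_dim width
```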
@@ -110,7 +122,15 @@ def __init__(
                 num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
                 rope_max_wavelength=rope_max_wavelength,
-                rope_scaling_factor=rope_scaling_factor,
+                rope_position_scaling_factor=rope_position_scaling_factor,
+                rope_frequency_adjustment_factor=(
+                    rope_frequency_adjustment_factor
+                ),
+                rope_low_freq_factor=rope_low_freq_factor,
+                rope_high_freq_factor=rope_high_freq_factor,
+                rope_pretraining_sequence_length=(
+                    rope_pretraining_sequence_length
+                ),
                 layer_norm_epsilon=layer_norm_epsilon,
                 activation=ops.silu,
                 kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ def __init__(
         self.num_query_heads = num_query_heads
         self.hidden_dim = hidden_dim
         self.intermediate_dim = intermediate_dim
-        self.rope_max_wavelength = rope_max_wavelength
         self.num_key_value_heads = num_key_value_heads
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_position_scaling_factor": (
+                    self.rope_position_scaling_factor
+                ),
+                "rope_frequency_adjustment_factor": (
+                    self.rope_frequency_adjustment_factor
+                ),
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,

Review comment (on the `"rope_low_freq_factor"` line above): Are all the original Llama UTs passing?
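Since `get_config` now serializes the new fields, a quick round-trip check can confirm they survive reconstruction. This continues the construction sketch above and is a hedged illustration, not taken from the PR's test suite.

```python
# The new RoPE fields should survive a config round-trip.
config = backbone.get_config()
assert config["rope_frequency_adjustment_factor"] == 8.0
assert config["rope_pretraining_sequence_length"] == 8192

restored = LlamaBackbone.from_config(config)
assert restored.rope_low_freq_factor == backbone.rope_low_freq_factor
```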
Review comment: flaot --> float. Here, and other places too.