Skip to content

Add Phi-4 Backbone #2272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open

Add Phi-4 Backbone #2272

wants to merge 3 commits into from

Conversation

yrahul3910
Copy link

Description of the change

This is the first PR in contributing the Phi-4 model to KerasHub, and includes the backbone and its test file.

Reference

Colab Notebook

I've had some trouble getting this part to work, so I need some help. This is my Colab notebook, but the HF model has been pretty annoying to run. On CPU machines, it seems to constantly allocate all available memory (I gave up after giving it 280GB), and on an H200 on Modal, I couldn't get an output after 15 minutes. In the notebook, this line:

hf_output = pt_model(**hf_sample_input)

at the bottom is the one I have trouble with.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@yrahul3910
Copy link
Author

I uploaded the output from the KerasHub model, could someone upload a HF version that I can compare and add to the Colab?

@yrahul3910
Copy link
Author

P.S. A lot of this code is based on the existing code for Phi-3 (the technical report states it mostly follows the Phi-3-Medium architecture; I simply made the changes from the report and the reference implementation). Should I refactor it to inherit from Phi3Backbone?

Copy link
Collaborator

@sachinprasadhs sachinprasadhs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! added my review comments.

Comment on lines +148 to +151
'`rope_scaling_type` must be `None` or `"su"`.'
"if `None` is choosed, `RotaryEmbedding` will be used."
'if `"su"` is choosed, `Phi4SuScaledRotaryEmbedding` will be '
"used."
Copy link
Collaborator

@sachinprasadhs sachinprasadhs May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be change this to --> "rope_scaling_type must be None or su. If None, RotaryEmbedding will be used. If su, Phi4SuScaledRotaryEmbedding will be used."
Add backtick wherever it is necessary.

Comment on lines +30 to +31
vocabulary_size (int): The size of the token vocabulary. Defaults to
`100_352`.
Copy link
Collaborator

@sachinprasadhs sachinprasadhs May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this to --> vocabulary_size: int. The size of the token vocabulary. Defaults to `100_352`.

Follow the above arg pattern for others as well, i know this follows same as phi3, but this will be consistent with majority of our models.

Comment on lines +85 to +92
@pytest.mark.extra_large
def test_all_presets(self):
for preset in Phi4Backbone.presets:
self.run_preset_test(
cls=Phi4Backbone,
preset=preset,
input_data=self.input_data,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually how big these models will be and how many presets are we testing here?

Comment on lines +20 to +30
hidden_dim=5120,
intermediate_dim=17_920,
num_query_heads=40,
num_key_value_heads=10,
activation="silu",
layer_norm_epsilon=1e-5,
kernel_initializer="glorot_uniform",
dropout=0,
max_sequence_length=16_384,
pretraining_sequence_length=16_384,
rope_max_wavelength=250_000,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these default values are mostly common for all the presets in phi-4, if not may be we can remove default values?

Comment on lines +5 to +7
# TODO: Deprecate this in favor of
# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is
# removed.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have keras 2 support now, either update the code or remove/update this comment.

@sachinprasadhs
Copy link
Collaborator

P.S. A lot of this code is based on the existing code for Phi-3 (the technical report states it mostly follows the Phi-3-Medium architecture; I simply made the changes from the report and the reference implementation). Should I refactor it to inherit from Phi3Backbone?

How much similar is Phi-4 compared to Phi-3? What is the approx percentage of code we can reuse?

@sachinprasadhs
Copy link
Collaborator

I guess still Tokenizer and CausalLM and preset file with necessary test files still needs to be added?

@yrahul3910
Copy link
Author

Actually, I think we might get away with directly subclassing Phi3Backbone and changing the defaults. If you do

diff src/models/phi3/phi3_backbone.py src/models/phi4/phi4_backbone.py

The only differences are the model name and the defaults; initially I did this copy anticipating architectural changes, but it seems the only ones are in the attention. From the paper's Section 3:

The architecture closely follows phi-3-medium, except that we now use the tiktoken tokenizer (for better
multilingual support) with a padded vocabulary size of 100,352 (including unused tokens) and we use
full attention over the 4K context length, rather than a 2K sliding window used in phi-3-medium.

I could not find this sliding window attention in the code, however, so that also remained unchanged, and the tokenizer would be part of the third PR (based on the contributing guidelines). Do you think it's better if I just did that instead?

@mattdangerw
Copy link
Member

Yeah this is an interesting question if the only thing that is really changing is the tokenizer. Would it work to just subclass all the phi3 classes as stubs that only update the classes? Like this

@keras_hub_export("keras_hub.models.Phi4CausalLM")
class Phi4CausalLM(Phi3CausalLM):
    backbone_cls = Phi4Backbone
    preprocessor_cls = Phi4CausalLMPreprocessor

And then define a new tokenizer? I don't think we need to worry about switchin the defaults, that we can just reflect in the preset configs we upload to kaggle.

Might be worth trying all of that on a single PR and trying to convert a model to see if everything works. Since we'd mostly be dealing with stub classes it shouldn't be too much code.

@yrahul3910
Copy link
Author

Yes, that makes sense to me. I'll try getting it out over the weekend.

@abheesht17
Copy link
Collaborator

@yrahul3910 - are you still working on this?

@yrahul3910
Copy link
Author

Yes, so sorry for the delay! I'll push some commits soon.

@divyashreepathihalli
Copy link
Collaborator

/gemini review

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the Phi-4 backbone and its associated components. The implementation is solid and well-structured. I've identified a few minor areas for improvement, mainly related to typos and redundant code. Addressing these will improve the clarity and maintainability of the new model.

rope_scaling_short_factor List[float]: List of factors used to adjust
rope frequencies when the `rope_scaling_type` is `"su"`. List must
be of length `hidden_dim//num_query_heads//2`. It is used when
`sequence_length` is smaller than `original_max_sequence_length`.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the docstring for rope_scaling_short_factor, original_max_sequence_length is used. This should be pretraining_sequence_length to align with the parameter name and its usage in the rotary embedding implementation. Apply this same correction to the docstring for rope_scaling_long_factor on line 65.

Suggested change
`sequence_length` is smaller than `original_max_sequence_length`.
`sequence_length` is smaller than `pretraining_sequence_length`.

self.max_sequence_length = max_sequence_length
self.pretraining_sequence_length = pretraining_sequence_length
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_type = rope_scaling_type

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The attribute self.rope_scaling_type is assigned twice in the __init__ method. This line is a duplicate of the assignment on line 187 and can be removed.

Comment on lines +139 to +140
inverese_freq_short_factor=self.rope_scaling_short_factor,
inverese_freq_long_factor=self.rope_scaling_long_factor,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Typo in the keyword arguments inverese_freq_short_factor and inverese_freq_long_factor. They should be inverse_freq_short_factor and inverse_freq_long_factor respectively. This change is dependent on correcting the corresponding typo in the Phi4SuScaledRotaryEmbedding class definition first.

Suggested change
inverese_freq_short_factor=self.rope_scaling_short_factor,
inverese_freq_long_factor=self.rope_scaling_long_factor,
inverse_freq_short_factor=self.rope_scaling_short_factor,
inverse_freq_long_factor=self.rope_scaling_long_factor,

inverese_freq_short_factor List[float]: List of factors used to adjust
rope frequencies when the `rope_scaling_type` is `"su"`. List must
be of length `hidden_dim//num_query_heads//2`. It is used when
`sequence_length` is smaller than `original_max_sequence_length`.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the docstring, original_max_sequence_length is mentioned. This should be pretraining_sequence_length to match the class's parameter and its usage. This correction should also be applied on line 19.

Suggested change
`sequence_length` is smaller than `original_max_sequence_length`.
`sequence_length` is smaller than `pretraining_sequence_length`.

Comment on lines +41 to +42
inverese_freq_short_factor,
inverese_freq_long_factor,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Typo in the parameter names: inverese should be inverse. This affects inverese_freq_short_factor and inverese_freq_long_factor.

Correct this here and propagate throughout the class in the following places:

  • The __init__ method body (lines 64-65)
  • The _compute_cos_sin_embedding method (lines 78 and 83)
  • The get_config method (lines 120-121)
  • The docstring (lines 12 and 16)

Fixing this will improve API clarity and code maintainability.

Suggested change
inverese_freq_short_factor,
inverese_freq_long_factor,
inverse_freq_short_factor,
inverse_freq_long_factor,

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

5 participants