-
Notifications
You must be signed in to change notification settings - Fork 488
Add GLM4_MOE model support #952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vvvdwbvvv
wants to merge
18
commits into
linkedin:main
Choose a base branch
from
vvvdwbvvv:add-glm4moe
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
26487c2
[GLM4MOE] Add support for Liger kernel patches in GLM-4MOE models
vvvdwbvvv dac15e9
[GLM4MOE] Formatting functions
vvvdwbvvv 14cfb90
Rename function for GLM-4MOE kernel application and update model type…
vvvdwbvvv 973e418
Refactor lce_forward function: update return type and remove deprecat…
vvvdwbvvv 39e7d18
Fix import path for Glm4MoeConfig in test_apply_liger_kernel_to_insta…
vvvdwbvvv ca27242
fix tests
vvvdwbvvv 5af9d16
modify to adapt to new API
vvvdwbvvv 72c7ec6
Merge branch 'main' into add-glm4moe
lancerts 2f320f6
Merge branch 'main' into add-glm4moe
lancerts 00c0d13
Merge branch 'main' into add-glm4moe
lancerts 0be4cc5
Merge branch 'main' into add-glm4moe
vvvdwbvvv 7f0ebf4
Enhance GLM4-MoE support by adding MLP handling in monkey patching
vvvdwbvvv 5cf027e
fix typo
vvvdwbvvv 876d075
fix typo
vvvdwbvvv c3eb0fc
Merge remote-tracking branch 'origin/main' into add-glm4moe
vvvdwbvvv cf9f212
fix: update Glm4Moe import and clean up MOE layer patching in monkey_…
vvvdwbvvv ccc308a
fix: update rotary position embedding assignment in apply_liger_kerne…
vvvdwbvvv 375903c
feat: add MOE configuration parameters for GLM4_MOE in test models
vvvdwbvvv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| from typing import List | ||
| from typing import Optional | ||
| from typing import Tuple | ||
| from typing import Union | ||
|
|
||
| import torch | ||
|
|
||
| from transformers.utils.deprecation import deprecate_kwarg | ||
|
|
||
| from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss | ||
| from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result | ||
| from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast | ||
|
|
||
|
|
||
| @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") | ||
| def lce_forward( | ||
| self, | ||
| input_ids: torch.LongTensor = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||
| labels: Optional[torch.LongTensor] = None, | ||
| use_cache: Optional[bool] = None, | ||
| output_attentions: Optional[bool] = None, | ||
| output_hidden_states: Optional[bool] = None, | ||
| return_dict: Optional[bool] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, | ||
| logits_to_keep: Union[int, torch.Tensor] = 0, | ||
| skip_logits: Optional[bool] = None, | ||
| **kwargs, | ||
| ) -> Union[Tuple, LigerCausalLMOutputWithPast]: | ||
| r""" | ||
| Args: | ||
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | ||
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., | ||
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored | ||
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. | ||
| image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): | ||
| The temporal, height and width of feature shape of each image in LLM. | ||
| video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): | ||
| The temporal, height and width of feature shape of each video in LLM. | ||
| rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): | ||
| The rope index difference between sequence length and multimodal rope. | ||
|
|
||
|
|
||
| logits_to_keep (`int` or `torch.Tensor`, *optional*): | ||
| If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all | ||
| `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that | ||
| token can save memory, which becomes pretty significant for long sequences or large vocabulary size. | ||
| If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. | ||
| This is useful when using packed tensor format (single dimension for batch and sequence length). | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| >>> from transformers import AutoProcessor, Glm4MoeForCausalLM | ||
| >>> import torch | ||
|
|
||
| >>> MODEL_PATH = "meta-glm4_moe/Glm4Moe-2-7b-hf" | ||
| >>> messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| { | ||
| "type": "image", | ||
| "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png" | ||
| }, | ||
| { | ||
| "type": "text", | ||
| "text": "describe this image" | ||
| } | ||
| ], | ||
| } | ||
| ] | ||
| >>> processor = AutoProcessor.from_pretrained(MODEL_PATH) | ||
| >>> model = Glm4MoeForCausalLM.from_pretrained( | ||
| pretrained_model_name_or_path=MODEL_PATH, | ||
| dtype="auto", | ||
| device_map="auto", | ||
| ) | ||
| >>> inputs = processor.apply_chat_template( | ||
| messages, | ||
| tokenize=True, | ||
| add_generation_prompt=True, | ||
| return_dict=True, | ||
| return_tensors="pt" | ||
| ).to(model.device) | ||
| >>> inputs.pop("token_type_ids", None) | ||
| >>> generated_ids = model.generate(**inputs, max_new_tokens=8192) | ||
| >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False) | ||
| ``` | ||
| """ | ||
|
|
||
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
| output_hidden_states = ( | ||
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
| ) | ||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
|
||
| # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
| outputs = self.model( | ||
| input_ids=input_ids, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| inputs_embeds=inputs_embeds, | ||
| use_cache=use_cache, | ||
| output_attentions=output_attentions, | ||
| output_hidden_states=output_hidden_states, | ||
| return_dict=return_dict, | ||
| cache_position=cache_position, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| hidden_states = outputs[0] | ||
| # Only compute necessary logits, and do not upcast them to float if we are not computing the loss | ||
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep | ||
| kept_hidden_states = hidden_states[:, slice_indices, :] | ||
|
|
||
| shift_labels = kwargs.pop("shift_labels", None) | ||
| logits = None | ||
| loss = None | ||
| token_accuracy = None | ||
|
|
||
| if skip_logits and labels is None and shift_labels is None: | ||
| raise ValueError("skip_logits is True, but labels and shift_labels are None") | ||
|
|
||
| if skip_logits is None: | ||
| # By default, if in training mode, don't materialize logits | ||
| skip_logits = self.training and (labels is not None or shift_labels is not None) | ||
|
|
||
| if skip_logits: | ||
| result = LigerForCausalLMLoss( | ||
| hidden_states=kept_hidden_states, | ||
| lm_head_weight=self.lm_head.weight, | ||
| labels=labels, | ||
| shift_labels=shift_labels, | ||
| hidden_size=self.config.hidden_size, | ||
| **kwargs, | ||
| ) | ||
| loss, _, token_accuracy = unpack_cross_entropy_result(result) | ||
|
|
||
| else: | ||
| logits = self.lm_head(kept_hidden_states) | ||
| if labels is not None or shift_labels is not None: | ||
| loss = self.loss_function( | ||
| logits=logits, | ||
| labels=labels, | ||
| shift_labels=shift_labels, | ||
| vocab_size=self.config.vocab_size, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| if not return_dict: | ||
| output = (logits,) + outputs[1:] | ||
| output = ((loss,) + output) if loss is not None else output | ||
| output = output + (token_accuracy,) if token_accuracy is not None else output | ||
| return output | ||
|
|
||
| # Return custom output class with accuracy field | ||
| return LigerCausalLMOutputWithPast( | ||
| loss=loss, | ||
| logits=logits, | ||
| past_key_values=outputs.past_key_values, | ||
| hidden_states=outputs.hidden_states, | ||
| attentions=outputs.attentions, | ||
| token_accuracy=token_accuracy, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Set smaller experts related numbers as well