-
Notifications
You must be signed in to change notification settings - Fork 19
[WIP] Enable GPTQModel to handle GraniteMoeParallelExperts #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cba3127
a103f5d
3fe665e
775e9fb
56dfbef
c23816b
11aae96
cff9a59
3cd53eb
12b206a
b0f9272
58d2722
f49fdc5
74d18b7
f98e85d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| "granite", | ||
| "gemma", | ||
| "dbrx_converted", | ||
| "granitemoe", | ||
| ] | ||
|
|
||
| EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| ############################################################################### | ||
| # Adapted from https://github.com/ModelCloud/GPTQModel | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| ############################################################################### | ||
| # Third Party | ||
| import torch | ||
|
|
||
| # Local | ||
| from .base import BaseGPTQModel | ||
|
|
||
|
|
||
| def new_forward(self, inputs, expert_size): | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was testing by defining this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. the user is the one who writes the new model, in this case,
|
||
| """ | ||
| Forward pass of the GraniteMoeParallelExperts module. | ||
| Args: | ||
| inputs (Tensor): | ||
| Input tensor. | ||
| expert_size: | ||
| Expert size information. | ||
| Returns: | ||
| Tensor: Output tensor. | ||
| """ | ||
| input_list = inputs.split(expert_size, dim=0) | ||
| output_list = [] | ||
| for i in range(self.num_experts): | ||
| # the key is we need to use call the module | ||
| output_list.append(self.weight[i](input_list[i])) | ||
| results = torch.cat(output_list, dim=0) | ||
| return results | ||
|
|
||
|
|
||
| class GraniteMoeGPTQ(BaseGPTQModel): | ||
| base_modules = ["model.embed_tokens", "model.norm"] | ||
| convert3dparameters = True | ||
| update_forwards = {"GraniteMoeParallelExperts": new_forward} | ||
|
|
||
| layers_node = "model.layers" | ||
| layer_type = "GraniteMoeDecoderLayer" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suggest you add some simple key to inform the format of Also in the granitemoe case, another compilation is that
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so basically the simple key needs to know what do look for to convert it to 3D tensor, and then when you write class GraniteMoeGPTQ(BaseGPTQModel):
convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]
layer_modules = [
[
"block_sparse_moe.input_linear.0.weight",
"block_sparse_moe.input_linear.1.weight",
...
], [
"block_sparse_moe.output_linear.0.weight",
"block_sparse_moe.output_linear.1.weight",
...
]
] |
||
| layer_modules = [ | ||
| ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], | ||
| ["self_attn.o_proj"], | ||
| [f"block_sparse_moe.input_linear.weight.{i}" for i in range(40)], | ||
| [f"block_sparse_moe.output_linear.weight.{i}" for i in range(40)], | ||
| ] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why does this need to be changed?