@@ -22,6 +22,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -28,6 +28,7 @@
"granite",
"gemma",
"dbrx_converted",
"granitemoe",
]

EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
@@ -29,6 +29,7 @@
from .gpt_bigcode import GPTBigCodeGPTQ
from .gpt_neox import GPTNeoXGPTQ
from .granite import GraniteGPTQ
from .granitemoe import GraniteMoeGPTQ
from .llama import LlamaGPTQ
from .mistral import MistralGPTQ
from .mixtral import MixtralGPTQ
@@ -43,6 +44,7 @@
"granite": GraniteGPTQ,
"dbrx": DbrxGPTQ,
"dbrx_converted": DbrxConvertedGPTQ,
"granitemoe": GraniteMoeGPTQ,
}

at_least_one_cuda_v6 = any(
@@ -15,7 +15,8 @@
###############################################################################
# Standard
from os.path import isfile, join
-from typing import Dict, List, Optional, Union
+from types import MethodType
+from typing import Callable, Dict, List, Optional, Union
import copy
import json
import logging
@@ -74,6 +75,7 @@
move_to,
nested_move_to,
pack_model,
replace_3d_parameters_with_module_list,
simple_dispatch_model,
verify_model_hash,
verify_sharded_model_hashes,
@@ -94,6 +96,12 @@ class BaseGPTQModel(nn.Module):
# does not include the node which holds all the repeating layers
base_modules: List[str] = None

# Whether 3D expert parameters should be converted into ModuleLists of Linear modules
convert3dparameters: bool = False

# User-provided forward passes, keyed by module class name, that replace the existing forward passes
update_forwards: Dict[str, Callable] = None

# name of lm_head
lm_head: str = "lm_head"

@@ -128,6 +136,13 @@ def __init__(
super().__init__()

self.model = model
if self.convert3dparameters:
    replace_3d_parameters_with_module_list(model)
    for mod in model.modules():
        # bind the user-provided forward to this module instance
        forward = (self.update_forwards or {}).get(mod.__class__.__name__)
        if forward is not None:
            mod.forward = MethodType(forward, mod)

self.model_type = self.model.config.model_type
self._quantized = quantized
self.quantize_config = quantize_config
@@ -561,7 +576,7 @@ def save_quantized(
self.quantize_config.meta_set_versionable(
key=META_FIELD_QUANTIZER,
value=META_QUANTIZER_GPTQMODEL,
-version=__version__,
+version="1.0.0",
Contributor
why does this need to be changed?

)

# The config, quantize_config and model may be edited in place in save_quantized.
@@ -0,0 +1,55 @@
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Third Party
import torch

# Local
from .base import BaseGPTQModel


def new_forward(self, inputs, expert_size):
    """
    Forward pass of the GraniteMoeParallelExperts module.
    Args:
        inputs (Tensor):
            Input tensor.
        expert_size:
            Expert size information.
    Returns:
        Tensor: Output tensor.
    """
    input_list = inputs.split(expert_size, dim=0)
    output_list = []
    for i in range(self.num_experts):
        # the key is that we need to call the module rather than index the raw weight
        output_list.append(self.weight[i](input_list[i]))
    results = torch.cat(output_list, dim=0)
    return results

Collaborator Author
I was testing by defining this new_forward method here. But as you said, if we are expecting it from users, how does the user pass this function? Is it in the GPTQModel.pretrained() function here, and then passed forward from there to BaseGPTQModel.__init__()?

Contributor
No. The user is the one who writes the new model, in this case GraniteMoeGPTQ:
  • while writing GraniteMoeGPTQ, the user configures all the forwards to be overwritten in the update_forwards class member.
So how you have done it now is just fine.
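For context on why the module must be called here: in the stock GraniteMoeParallelExperts, self.weight is a single 3D Parameter and each expert slice is applied functionally, so GPTQ has no per-expert Linear module to hook. Below is a minimal sketch of that original style, using a toy stand-in rather than the actual transformers class:

import torch

class ToyParallelExperts(torch.nn.Module):
    # Illustrative stand-in: one 3D parameter of shape
    # (num_experts, out_features, in_features), as in the HF layout.
    def __init__(self, num_experts, in_features, out_features):
        super().__init__()
        self.num_experts = num_experts
        self.weight = torch.nn.Parameter(
            torch.randn(num_experts, out_features, in_features))

    def forward(self, inputs, expert_size):
        input_list = inputs.split(expert_size, dim=0)
        # Functional application of each expert slice: nothing here is an
        # nn.Linear, which is what new_forward above changes.
        output_list = [
            torch.nn.functional.linear(input_list[i], self.weight[i])
            for i in range(self.num_experts)
        ]
        return torch.cat(output_list, dim=0)

Once replace_3d_parameters_with_module_list has run, self.weight[i] is an nn.Linear, so indexing plus functional application would no longer work; new_forward calls the module instead.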


class GraniteMoeGPTQ(BaseGPTQModel):
    base_modules = ["model.embed_tokens", "model.norm"]
    convert3dparameters = True
    update_forwards = {"GraniteMoeParallelExperts": new_forward}

    layers_node = "model.layers"
    layer_type = "GraniteMoeDecoderLayer"
Contributor
I suggest you add some simple key to inform the format of input_linear and output_linear, i.e. that these are 3D tensors.

Also in the granitemoe case, another complication is that input_linear fuses w1 and w3. It might be ok for a first cut just to leave them as fused.

Contributor
@fabianlim fabianlim Jan 28, 2025
So basically the simple key needs to say what to look for when converting the 3D tensors, and then when you write layer_modules you write it as though they have already been converted:

class GraniteMoeGPTQ(BaseGPTQModel):

    convert3dToModuleList = ["block_sparse_moe.input_linear", "block_sparse_moe.output_linear"]

    layer_modules = [
        [
            "block_sparse_moe.input_linear.0.weight",
            "block_sparse_moe.input_linear.1.weight",
            ...
        ], [
            "block_sparse_moe.output_linear.0.weight",
            "block_sparse_moe.output_linear.1.weight",
            ...
        ]
    ]

    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        [f"block_sparse_moe.input_linear.weight.{i}" for i in range(40)],
        [f"block_sparse_moe.output_linear.weight.{i}" for i in range(40)],
    ]
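To make the reviewer's suggestion concrete, here is a minimal sketch of how a declarative key such as convert3dToModuleList might be consumed, converting only the named submodules rather than scanning every parameter by shape. The key name and the helper convert_named_3d_parameters follow the reviewer's proposal and are illustrative assumptions, not existing GPTQModel API; in granitemoe the names would also be resolved relative to each decoder layer.

import torch

def convert_named_3d_parameters(root: torch.nn.Module, names: list):
    # Sketch: turn the 3D `weight` Parameter of each named submodule into a
    # ModuleList of per-expert Linear modules. Assumes the HF layout
    # (num_experts, out_features, in_features), matching nn.Linear weights.
    for name in names:
        module = root.get_submodule(name)
        param = module.weight
        num_experts, out_features, in_features = param.shape
        experts = torch.nn.ModuleList()
        for i in range(num_experts):
            linear = torch.nn.Linear(in_features, out_features, bias=False,
                                     device=param.device, dtype=param.dtype)
            linear.weight.data = param.data[i]
            experts.append(linear)
        delattr(module, "weight")   # drop the raw 3D Parameter
        module.weight = experts     # register the ModuleList in its place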
@@ -715,3 +715,33 @@ def get_moe_layer_modules(layer_modules: List, num_experts: int) -> List:
new_inside_layer_modules[-1].append(n)

return new_inside_layer_modules


def replace_3d_parameters_with_module_list(
    model: torch.nn.Module,
):
    for name, module in model.named_modules():
        # materialize the list first, since we mutate the module's parameters below
        for param_name, param in list(module.named_parameters(recurse=False)):
            if len(param.shape) == 3:
                device = param.device
                dtype = param.dtype
                # assumes the layout (num_experts, out_features, in_features),
                # which matches nn.Linear's (out_features, in_features) weight
                num, out_features, in_features = param.shape

                module_list = []
                for i in range(num):
                    linear = torch.nn.Linear(
                        in_features=in_features,
                        out_features=out_features,
                        device=device,
                        dtype=dtype,
                        bias=False,  # FIXME: how to support bias?
                    )
                    linear.weight.data = param.data[i]
                    module_list.append(linear)

                module_list = torch.nn.ModuleList(module_list)

                # replace the raw Parameter with the ModuleList
                delattr(module, param_name)
                setattr(module, param_name, module_list)
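A minimal usage sketch of this helper on a toy module (ToyExperts and the shapes are illustrative assumptions, not part of the PR):

import torch

class ToyExperts(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 3D parameter: (num_experts=2, out_features=8, in_features=4)
        self.weight = torch.nn.Parameter(torch.randn(2, 8, 4))

model = ToyExperts()
replace_3d_parameters_with_module_list(model)

# The raw Parameter has become a ModuleList of per-expert Linear modules,
# so each expert is now callable (and quantizable) like any Linear layer.
assert isinstance(model.weight, torch.nn.ModuleList)
out = model.weight[0](torch.randn(3, 4))  # shape: (3, 8)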