Skip to content

Commit 205f1d7

Browse files
quic-meet and Meet Doshi authored
Gemma, Gemma2 and CodeGemma support (#123)
* Support for Gemma models Signed-off-by: quic-meet <[email protected]> * Updated modeling file: batch generation still buggy Signed-off-by: quic-meet <[email protected]> * lint fixes Signed-off-by: quic-meet <[email protected]> * Test case fix for sequence length mismatch of LM head outputs Signed-off-by: quic-meet <[email protected]> * Update rope calculations Signed-off-by: quic-meet <[email protected]> * Gemma, Gemma2, CodeGemma support Signed-off-by: quic-meet <[email protected]> * Gemma2 test case fix, RMS norm weight update Signed-off-by: quic-meet <[email protected]> * PR Fix Signed-off-by: Meet Doshi <[email protected]> * Fixed test bugs Signed-off-by: Meet Doshi <[email protected]> * Removed gated models from test Signed-off-by: Meet Doshi <[email protected]> * revert to opset13 Signed-off-by: Meet Doshi <[email protected]> * ruff format Signed-off-by: Meet Doshi <[email protected]> --------- Signed-off-by: quic-meet <[email protected]> Signed-off-by: Meet Doshi <[email protected]> Signed-off-by: Meet Doshi <[email protected]> Co-authored-by: Meet Doshi <[email protected]>
1 parent 4778d42 commit 205f1d7

File tree

14 files changed

+1372
-18
lines changed

14 files changed

+1372
-18
lines changed

QEfficient/customop/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
CtxScatterFuncCB,
1313
CtxScatterFuncCB3D,
1414
)
15-
from QEfficient.customop.rms_norm import CustomRMSNormAIC
15+
from QEfficient.customop.rms_norm import CustomRMSNormAIC, GemmaCustomRMSNormAIC
1616

1717
__all__ = [
1818
"CtxGatherFunc",
1919
"CtxScatterFunc",
2020
"CtxGatherFunc3D",
2121
"CtxScatterFunc3D",
2222
"CustomRMSNormAIC",
23+
"GemmaCustomRMSNormAIC",
2324
"CtxGatherFuncCB",
2425
"CtxScatterFuncCB",
2526
"CtxGatherFuncCB3D",

QEfficient/customop/rms_norm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,20 @@ class CustomRMSNormAIC(nn.Module):
4545
def __init__(self, hidden_size, eps=1e-05):
4646
super(CustomRMSNormAIC, self).__init__()
4747
self.variance_epsilon = eps
48+
self.eps = eps # Added to support GemmaRMSNorm
4849
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
4950

5051
def forward(self, hidden_states):
51-
return CustomRMSNormFunc.apply(hidden_states, self.weight, self.variance_epsilon)
52+
return CustomRMSNormFunc.apply(
53+
hidden_states, self.weight, self.variance_epsilon if hasattr(self, "variance_epsilon") else self.eps
54+
)
55+
56+
57+
class GemmaCustomRMSNormAIC(CustomRMSNormAIC):
58+
"""
59+
Modify the init function to add +1 to the weights
60+
"""
61+
62+
def __qeff_init__(self):
63+
with torch.no_grad():
64+
self.weight.copy_(self.weight + 1.0)

QEfficient/exporter/export_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def export_onnx(
103103
except Exception as e:
104104
raise RuntimeError("Exporting to ONNX failed. {}".format(e))
105105

106-
onnx.checker.check_model(f"{gen_models_path}_tmp/{model_base_name}.onnx")
106+
onnx.checker.check_model(f"{gen_models_path}_tmp/{model_base_name}.onnx", full_check=True)
107107
loaded_model = onnx.load(f"{gen_models_path}_tmp/{model_base_name}.onnx")
108108
shutil.rmtree(f"{gen_models_path}_tmp")
109109
os.makedirs(f"{gen_models_path}", exist_ok=True)
@@ -123,7 +123,7 @@ def export_onnx(
123123
size_threshold=1024,
124124
convert_attribute=False,
125125
)
126-
onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"))
126+
onnx.checker.check_model(os.path.join(gen_models_path, f"{model_base_name}.onnx"), full_check=True)
127127

128128
# Run shape inference in intial model itself
129129
onnx.shape_inference.infer_shapes_path(

QEfficient/transformers/modeling_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@
2020
FalconForCausalLM,
2121
FalconModel,
2222
)
23+
from transformers.models.gemma.modeling_gemma import (
24+
GemmaAttention,
25+
GemmaDecoderLayer,
26+
GemmaForCausalLM,
27+
GemmaModel,
28+
GemmaRMSNorm,
29+
)
30+
from transformers.models.gemma2.modeling_gemma2 import (
31+
Gemma2Attention,
32+
Gemma2DecoderLayer,
33+
Gemma2ForCausalLM,
34+
Gemma2Model,
35+
Gemma2RMSNorm,
36+
)
2337
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
2438
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
2539
GPTBigCodeAttention,
@@ -74,6 +88,13 @@
7488
QEffFalconForCausalLM,
7589
QEffFalconModel,
7690
)
91+
from .models.gemma.modeling_gemma import QEffGemmaAttention, QEffGemmaDecoderLayer, QEffGemmaForCausalLM, QEffGemmaModel
92+
from .models.gemma2.modeling_gemma2 import (
93+
QEffGemma2Attention,
94+
QEffGemma2DecoderLayer,
95+
QEffGemma2ForCausalLM,
96+
QEffGemma2Model,
97+
)
7798
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
7899
from .models.gpt_bigcode.modeling_gpt_bigcode import (
79100
QEffGPTBigCodeAttention,
@@ -119,6 +140,8 @@
119140
get_lists_of_cb_qeff_models = ModelArchitectures(
120141
[
121142
LlamaForCausalLM.__name__,
143+
GemmaForCausalLM.__name__,
144+
Gemma2ForCausalLM.__name__,
122145
MistralForCausalLM.__name__,
123146
MixtralForCausalLM.__name__,
124147
Starcoder2ForCausalLM.__name__,
@@ -141,6 +164,8 @@
141164
MptForCausalLM.__name__,
142165
CodeGenForCausalLM.__name__,
143166
LlamaForCausalLM.__name__,
167+
GemmaForCausalLM.__name__,
168+
Gemma2ForCausalLM.__name__,
144169
MistralForCausalLM.__name__,
145170
MixtralForCausalLM.__name__,
146171
Phi3ForCausalLM.__name__,
@@ -170,6 +195,18 @@
170195
LlamaForCausalLM: QEffLlamaForCausalLM,
171196
LlamaDecoderLayer: QEffLlamaDecoderLayer,
172197
LlamaRMSNorm: CustomRMSNormAIC,
198+
# Gemma model layers
199+
GemmaModel: QEffGemmaModel,
200+
GemmaAttention: QEffGemmaAttention,
201+
GemmaForCausalLM: QEffGemmaForCausalLM,
202+
GemmaDecoderLayer: QEffGemmaDecoderLayer,
203+
GemmaRMSNorm: CustomRMSNormAIC,
204+
# Gemma2 model layers
205+
Gemma2Model: QEffGemma2Model,
206+
Gemma2Attention: QEffGemma2Attention,
207+
Gemma2ForCausalLM: QEffGemma2ForCausalLM,
208+
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
209+
Gemma2RMSNorm: CustomRMSNormAIC,
173210
# MPT model layers
174211
MptAttention: QEffMptAttention,
175212
MptBlock: QEffMptBlock,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------

0 commit comments

Comments (0)