Commit eff9472

Gemma 3 minor fixes (#476)

CI enablement and other minor fixes for Gemma3.

Signed-off-by: Ann Kuruvilla <[email protected]>

1 parent 61b1445 · commit eff9472

File tree

7 files changed: +39 −43 lines

QEfficient/transformers/cache_utils.py

Lines changed: 0 additions & 2 deletions

@@ -288,7 +288,6 @@ def from_legacy_cache(
 class QEffHybridCache(HybridCache):
     def __init__(self, config, batch_size, max_cache_len):
         super().__init__(config, batch_size, max_cache_len=max_cache_len)
-        # breakpoint()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []

@@ -327,7 +326,6 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
         backward compatibility."""
         legacy_cache = ()
-        # breakpoint()
         for layer_idx in range(len(self)):
             legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
         return legacy_cache

QEfficient/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 2 additions & 3 deletions

@@ -238,9 +238,9 @@ def forward(
         )
         kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         if self.is_sliding:
-            cos, sin = self.rotary_emb_local(value_states, seq_len=constants.GEMMA3_MAX_POSITION_EMBEDDINGS)
+            cos, sin = self.rotary_emb_local(value_states, seq_len=self.config.max_position_embeddings)
         else:
-            cos, sin = self.rotary_emb(value_states, seq_len=constants.GEMMA3_MAX_POSITION_EMBEDDINGS)
+            cos, sin = self.rotary_emb(value_states, seq_len=self.config.max_position_embeddings)

         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         if past_key_value is not None:

@@ -687,7 +687,6 @@ def get_specializations(
                 "mm_tokens_per_image": mm_tokens_per_image,
             },
         ]
-
         specializations = {}

         if kv_offload:
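The rotary-embedding change swaps the hard-coded `constants.GEMMA3_MAX_POSITION_EMBEDDINGS` for the checkpoint's own `config.max_position_embeddings`, so the cos/sin tables are sized for whatever context length the loaded model declares. A minimal sketch of the pattern, assuming a simplified rotary module (`SimpleRotaryEmbedding` and its signature are illustrative, not the QEfficient class):

import torch

class SimpleRotaryEmbedding(torch.nn.Module):
    """Illustrative rotary embedding whose cache length comes from the config."""

    def __init__(self, dim: int, max_position_embeddings: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Size the cos/sin tables from the model config rather than a global
        # constant, so every checkpoint gets tables matching its context length.
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)

# Hypothetical usage mirroring the patched call site:
# rotary = SimpleRotaryEmbedding(dim=head_dim, max_position_embeddings=config.max_position_embeddings)
# cos, sin = rotary(value_states, seq_len=config.max_position_embeddings)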

README.md

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@

 *Latest news* :fire: <br>
 - [06/2025] Added support for Llama4 Multi-Model [meta-llama/Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
+- [06/2025] Added support for Gemma3 Multi-Modal-Model [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)
 - [06/2025] Added support of model `hpcai-tech/grok-1` [hpcai-tech/grok-1](https://huggingface.co/hpcai-tech/grok-1)
 - [04/2025] Added support of model `ibm-granite/granite-vision-3.2-2b`[ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b)
 - [03/2025] Added support for swiftkv model [Snowflake/Llama-3.1-SwiftKV-8B-Instruct](https://huggingface.co/Snowflake/Llama-3.1-SwiftKV-8B-Instruct)

docs/source/validate.md

Lines changed: 2 additions & 0 deletions

@@ -63,6 +63,8 @@
 | **MllamaForConditionalGeneration** | Llama 3.2 | [meta-llama/Llama-3.2-11B-Vision Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct)<br>[meta-llama/Llama-3.2-90B-Vision](https://huggingface.co/meta-llama/Llama-3.2-90B-Vision) |
 |**LlavaNextForConditionalGeneration** | Granite Vision | [ibm-granite/granite-vision-3.2-2b](https://huggingface.co/ibm-granite/granite-vision-3.2-2b)
 |**Llama4ForConditionalGeneration** | Llama-4-Scout | [Llama-4-Scout-17B-16E-Instruct](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct)
+|**Gemma3ForConditionalGeneration** | Gemma3 | [google/gemma-3-4b-it](https://huggingface.co/google/gemma-3-4b-it)
+
 ### Audio Models
 (Automatic Speech Recognition) - Transcription Task
 **QEff Auto Class:** `QEFFAutoModelForSpeechSeq2Seq`

examples/gemma3_example/fp32_mm.yaml

Lines changed: 2 additions & 2 deletions

@@ -370,7 +370,7 @@ FP32NodeInstanceNames:
   - /language_model/model/layers.4/self_attn/Mul_6_output_0
   - /language_model/model/layers.4/self_attn/Mul_7_output_0
   - /language_model/model/layers.4/self_attn/Mul_8_output_0
-  - /language_model/model/layers.4/self_attn/Mul_9_output_0 [274/1312]
+  - /language_model/model/layers.4/self_attn/Mul_9_output_0
   - /language_model/model/layers.5/self_attn/Mul_output_0
   - /language_model/model/layers.5/self_attn/Mul_1_output_0
   - /language_model/model/layers.5/self_attn/Mul_2_output_0

@@ -415,7 +415,7 @@ FP32NodeInstanceNames:
   - /language_model/model/layers.9/self_attn/Mul_1_output_0
   - /language_model/model/layers.9/self_attn/Mul_2_output_0
   - /language_model/model/layers.9/self_attn/Mul_3_output_0
-  - /language_model/model/layers.9/self_attn/Mul_4_output_0 [229/1312]
+  - /language_model/model/layers.9/self_attn/Mul_4_output_0
   - /language_model/model/layers.9/self_attn/Mul_5_output_0
   - /language_model/model/layers.9/self_attn/Mul_6_output_0
   - /language_model/model/layers.9/self_attn/Mul_7_output_0
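The bracketed fragments being deleted (`[274/1312]`, `[229/1312]`) look like terminal-pager position markers accidentally pasted into the YAML, which would corrupt the affected node names. A quick sanity-check sketch, assuming PyYAML and the example's own file path:

import re
import yaml  # pip install pyyaml

# Flag FP32 node names containing stray whitespace or bracketed fragments,
# such as the pager residue this commit removes.
with open("examples/gemma3_example/fp32_mm.yaml") as f:
    cfg = yaml.safe_load(f)

for name in cfg.get("FP32NodeInstanceNames", []):
    if re.search(r"\s|\[", name):
        print("suspicious node name:", name)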

examples/gemma3_example/gemma3_mm.py

Lines changed: 10 additions & 13 deletions

@@ -7,7 +7,7 @@

 import torch
 import transformers
-from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer
+from transformers import AutoConfig, AutoProcessor

 from QEfficient import QEFFAutoModelForImageTextToText

@@ -16,12 +16,14 @@
 # For Testing Purpose Only
 config.text_config.num_hidden_layers = 1
 config.vision_config.num_hidden_layers = 2
-
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config)
-model.eval()
 tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 processor = AutoProcessor.from_pretrained(model_id)
-qeff_model = QEFFAutoModelForImageTextToText(model, kv_offload=True)
+
+# pass HF_TOKEN if gated model
+# For running the model in single QPC approach use kv_offload=False. For Dual QPC approach use kv_offload=True ###
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+    model_id, config=config, attn_implementation="eager", kv_offload=True
+)

 ### use skip_vision=Ture, if want to run only text, or false ###
 skip_vision = True

@@ -59,9 +61,7 @@
     return_tensors="pt",
 )

-streamer = TextStreamer(tokenizer)
-output = qeff_model.generate(inputs=inputs, device_ids=[0], generation_len=100)
-print(output.generated_ids)
+output = qeff_model.generate(inputs=inputs, generation_len=100)
 print(tokenizer.batch_decode(output.generated_ids))
 print(output)

@@ -72,7 +72,7 @@
     ctx_len=3072,
     img_size=896,
     num_cores=16,
-    num_devices=8,
+    num_devices=1,
     mxfp6_matmul=False,
     mxint8_kv_cache=False,
     aic_enable_depth_first=True,

@@ -103,9 +103,6 @@
     return_tensors="pt",
 )
 inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
-streamer = TextStreamer(tokenizer)
-output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100)
-print(output.generated_ids)
+output = qeff_model.generate(inputs=inputs, generation_len=100)
 print(tokenizer.batch_decode(output.generated_ids))
 print(output)
-print()
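With this change the example builds the QEfficient model directly via `from_pretrained(..., kv_offload=...)` instead of wrapping a separately loaded HF model, and `generate()` no longer takes explicit `device_ids`. A condensed sketch of the resulting flow; the `compile()` call shape and its `prefill_seq_len` value are assumptions reconstructed from the visible arguments, and the text-only prompt is illustrative:

import transformers
from transformers import AutoConfig, AutoProcessor

from QEfficient import QEFFAutoModelForImageTextToText

model_id = "google/gemma-3-4b-it"  # gated repo: export HF_TOKEN before running

config = AutoConfig.from_pretrained(model_id)
config.text_config.num_hidden_layers = 1   # tiny config, for testing only
config.vision_config.num_hidden_layers = 2

tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id)

# kv_offload=True -> dual-QPC split (vision/language); kv_offload=False -> single QPC
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
    model_id, config=config, attn_implementation="eager", kv_offload=True
)

# Assumed compile() signature; argument values mirror the example diff above.
qeff_model.compile(
    prefill_seq_len=128,  # assumption: not visible in the hunk
    ctx_len=3072,
    img_size=896,
    num_cores=16,
    num_devices=1,
    mxfp6_matmul=False,
    mxint8_kv_cache=False,
    aic_enable_depth_first=True,
)

# Illustrative text-only prompt; the real example also feeds pixel_values.
inputs = tokenizer("Describe a cat in detail.", return_tensors="pt")
output = qeff_model.generate(inputs=inputs, generation_len=100)
print(tokenizer.batch_decode(output.generated_ids))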

tests/transformers/models/test_image_text_to_text_models.py

Lines changed: 22 additions & 23 deletions

@@ -88,29 +88,28 @@
         "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud",
         4,
     ),
-    # FIX: Accuracy in AIC
-    # (
-    #     "google/gemma-3-4b-it",
-    #     True,
-    #     1,
-    #     128,
-    #     3072,
-    #     896,
-    #     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
-    #     "Can you describe the image in detail.",
-    #     6,
-    # ),
-    # (
-    #     "google/gemma-3-4b-it",
-    #     False,
-    #     1,
-    #     128,
-    #     3072,
-    #     896,
-    #     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
-    #     "Can you describe the image in detail.",
-    #     6,
-    # ),
+    (
+        "google/gemma-3-4b-it",
+        True,
+        1,
+        128,
+        3072,
+        896,
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+        "Can you describe the image in detail.",
+        1,
+    ),
+    (
+        "google/gemma-3-4b-it",
+        False,
+        1,
+        128,
+        3072,
+        896,
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
+        "Can you describe the image in detail.",
+        1,
+    ),
     # (
     #     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     #     True,
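These tuples drive a parametrized test. Reading the fields in order, they appear to be: model ID, kv_offload flag, batch size, prompt length, context length, image size, image URL, query, and a layer count, with the last field reduced from 6 to 1, which fits the commit's stated goal of CI enablement with truncated models. A hedged sketch of how such a table typically feeds pytest (the parameter names and test body are illustrative, not the repository's actual harness):

import pytest

# Field names below are my reading of the tuple layout; the repository may differ.
TEST_CASES = [
    ("google/gemma-3-4b-it", True, 1, 128, 3072, 896,
     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png",
     "Can you describe the image in detail.", 1),
]

@pytest.mark.parametrize(
    "model_id, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer",
    TEST_CASES,
)
def test_image_text_to_text(model_id, kv_offload, batch_size, prompt_len,
                            ctx_len, img_size, img_url, query, n_layer):
    # Illustrative body: a real test would export, compile, and compare AIC
    # generation against the PyTorch reference using `n_layer` hidden layers.
    assert batch_size >= 1 and ctx_len >= prompt_len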
