
Commit 6a1ba3c

Update gemma2 examples with a note about sample generation (#1176)
SUMMARY:
- Add a note advising users to either downgrade transformers from 4.49 or use vLLM for generation.
- Why this only happens during generation with this new release is unclear; it can be revisited down the road.
1 parent 45f2b33 commit 6a1ba3c

3 files changed: +10 −0 lines
examples/quantization_kv_cache/gemma2_fp8_kv_example.py (+4)

@@ -86,6 +86,10 @@ def process_and_tokenize(example):
     "Please use vLLM for inference with the quantized kv_cache.",
 )
 # Confirm generations of the quantized model look sane.
+
+# NOTE: transformers 4.49.0 results in a generation error with gemma2.
+# Consider either downgrading your transformers version to a previous version
+# or use vLLM for sample generation.
 print("\n\n")
 print("========== SAMPLE GENERATION ==============")
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
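
Since the added note points to vLLM as the workaround for sample generation, here is a minimal sketch of what that could look like, assuming vllm is installed and the quantized model has already been saved by the example script; the model path and sampling settings below are placeholders, not values taken from the examples.

# Minimal vLLM-based sample generation (sketch only).
from vllm import LLM, SamplingParams

llm = LLM(model="./gemma2_fp8_kv")  # placeholder path to the saved quantized model
sampling = SamplingParams(max_tokens=20)
outputs = llm.generate(["Hello my name is"], sampling)
print(outputs[0].outputs[0].text)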

examples/quantization_w8a8_fp8/gemma2_example.py (+3)

@@ -29,6 +29,9 @@
 )
 
 # Confirm generations of the quantized model look sane.
+# NOTE: transformers 4.49.0 results in a generation error with gemma2.
+# Consider either downgrading your transformers version to a previous version
+# or use vLLM for sample generation.
 print("========== SAMPLE GENERATION ==============")
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
 output = model.generate(input_ids, max_new_tokens=20)

examples/quantization_w8a8_int8/gemma2_example.py (+3)

@@ -68,6 +68,9 @@ def tokenize(sample):
 )
 
 # Confirm generations of the quantized model look sane.
+# NOTE: transformers 4.49.0 results in a generation error with gemma2.
+# Consider either downgrading your transformers version to a previous version
+# or use vLLM for sample generation.
 print("========== SAMPLE GENERATION ==============")
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
 output = model.generate(input_ids, max_new_tokens=20)
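
One way to act on the note programmatically, rather than only in a comment, would be a version guard around the Hugging Face generation call. The sketch below is not part of the commit; it reuses the model, tokenizer, and input_ids names from the examples and assumes the packaging library is available.

# Hypothetical guard (not in the commit): skip HF generation on the affected
# transformers release and point users to vLLM instead.
import transformers
from packaging import version

if version.parse(transformers.__version__) >= version.parse("4.49.0"):
    print("transformers >= 4.49.0 detected; run sample generation with vLLM instead.")
else:
    output = model.generate(input_ids, max_new_tokens=20)
    print(tokenizer.decode(output[0]))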
