Commit 6fcc92b

Merge branch 'fix-test-conseq-oneshot' of github.com:vllm-project/llm-compressor into fix-test-conseq-oneshot
2 parents 14d5c49 + 11fd1f0 commit 6fcc92b

File tree

15 files changed, +1451 -35 lines


.github/workflows/test-check-transformers.yaml (+3 -1)

@@ -3,6 +3,8 @@ on:
   pull_request:
     branches: main
     types: [ labeled, synchronize ]
+  push:
+    branches: main
 
 env:
   CADENCE: "commit"
@@ -15,7 +17,7 @@ env:
 jobs:
   transformers-tests:
     runs-on: gcp-k8s-vllm-l4-solo
-    if: contains(github.event.pull_request.labels.*.name, 'ready')
+    if: contains(github.event.pull_request.labels.*.name, 'ready') || github.event_name == 'push'
     steps:
       - uses: actions/setup-python@v5
         with:
New file (+93 lines) — Phi-3-vision GPTQ example:

@@ -0,0 +1,93 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.utils.data_collator import phi3_vision_data_collator

# Load model.
model_id = "microsoft/Phi-3-vision-128k-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
    _attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
processor.chat_template = processor.tokenizer.chat_template

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


# Apply chat template
def preprocess(example):
    messages = [{"role": "user", "content": "<|image_1|>\nWhat does the image show?"}]
    return {
        "text": processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
        ),
        "images": example["image"],
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return processor(
        text=sample["text"],
        images=sample["images"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )


# long data lengths produced by the phi3_vision processor
# can lead to integer overflows when mapping, avoid with writer_batch_size
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["Phi3DecoderLayer"],
        ignore=["lm_head", "re:model.vision_embed_tokens.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=phi3_vision_data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")

# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
New file (+123 lines) — Qwen2-VL GPTQ example:

@@ -0,0 +1,123 @@
import base64
from io import BytesIO

from datasets import load_dataset
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TraceableQwen2VLForConditionalGeneration
from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator

# Load model.
model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = TraceableQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42)


# Apply chat template and tokenize inputs.
def preprocess_and_tokenize(example):
    # preprocess
    buffered = BytesIO()
    example["image"].save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue())
    encoded_image_text = encoded_image.decode("utf-8")
    base64_qwen = f"data:image;base64,{encoded_image_text}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": base64_qwen},
                {"type": "text", "text": "What does the image show?"},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)

    # tokenize
    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )


ds = ds.map(preprocess_and_tokenize, remove_columns=ds["calibration"].column_names)

# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["Qwen2VLDecoderLayer"],
        ignore=["lm_head", "re:visual.*"],
    ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=qwen2_vl_data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "http://images.cocodataset.org/train2017/000000231895.jpg",
            },
            {"type": "text", "text": "Please describe the animal in this image\n"},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[prompt],
    images=image_inputs,
    videos=video_inputs,
    padding=False,
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    return_tensors="pt",
).to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")


# Save to disk compressed.
SAVE_DIR = model_id.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)

examples/sparse_2of4_quantization_fp8/README.md (+12 -1)

@@ -93,7 +93,7 @@ oneshot(
 )
 ```
 
-3. **Save the Compressed Model**
+### Saving the Compressed Model
 
 The compressed model and tokenizer are saved to the output directory:
 
@@ -106,6 +106,17 @@
 - Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse`
 - With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token`
 
+#### Saving Without Sparse Compression
+
+To save the model on disk without sparse compression:
+
+```python
+model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True)
+tokenizer.save_pretrained(save_dir)
+```
+
+> **Note:** Saving a model with both the `save_compressed` and `disable_sparse_compression` options will compress the model using the quantization compressor; however, instead of using the more disk-efficient sparsity compressor(s), the dense sparsity compressor will be used. The `dense` sparsity compressor saves model params as is, and does not leverage sparsity for disk-efficient storage. These options only affect how the model(s) are saved on disk and do not impact the actual pruning or quantization processes.
+
 ### Validation
 
 After compression, the script validates the model by generating a sample output:
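To make the new note concrete, here is a minimal sketch contrasting the two save paths; both keyword arguments appear in the diff above, and `save_dir`/`tokenizer` follow the README's own variables:

```python
# Default save path: quantized weights are written with the quantization
# compressor and, for a 2:4-sparse model, a disk-efficient sparsity compressor.
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)

# Opting out of sparse compression: weights are still quantization-compressed,
# but sparsity falls back to the `dense` compressor (parameters stored as-is).
model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True)
tokenizer.save_pretrained(save_dir)
```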

src/llmcompressor/transformers/compression/quantization_format.py (+24 -4)

@@ -1,7 +1,7 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,10 +16,30 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
-    Infers a quantization format based on model state and compression args
+    Infers the quantization format for a model based on its state and provided
+    compression arguments.
+
+    The following table outlines the possible quantization and sparsity formats
+    along with their corresponding compressor formats:
+
+    +---------------+----------+----------------------+---------------------+
+    | Quantization  | Sparsity | Quant Compressor     | Sparsity Compressor |
+    |               |          | Format               | Format              |
+    +---------------+----------+----------------------+---------------------+
+    | W8A8 - int    | None     | int_quantized        | Dense               |
+    | W8A8 - float  | None     | float_quantized      | Dense               |
+    | W4A16 - int   | None     | pack_quantized       | Dense               |
+    | W8A16 - int   | None     | pack_quantized       | Dense               |
+    | W8A16 - float | None     | naive_quantized      | Dense               |
+    | W8A8 - int    | 2:4      | int_quantized        | Sparse24            |
+    | W8A8 - float  | 2:4      | float_quantized      | Sparse24            |
+    | W4A16 - int   | 2:4      | marlin_24            | Dense               |
+    | W8A16 - int   | 2:4      | marlin_24            | Dense               |
+    | W8A16 - float | 2:4      | naive_quantized      | Dense               |
+    +---------------+----------+----------------------+---------------------+
 
     :param model: model to check for quantization, if the model is not quantized no
         quantization format is returned
@@ -37,7 +57,7 @@ def infer_quantization_format(
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
         is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
+            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
         )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
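Outside the diff, a minimal usage sketch of the changed signature: callers now pass the sparsity structure as a plain string rather than a full `SparsityCompressionConfig`. The call below assumes `model` is an already-quantized model object and that the module path mirrors the file location under `src/`:

```python
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)

# Hypothetical call site; `model` is assumed to be a quantized model.
# Per the docstring table, a weight-only W4A16 int model with 2:4 sparsity
# would resolve to the marlin_24 quantization compressor format.
quant_format = infer_quantization_format(
    model=model,
    quantization_format=None,   # let the helper infer the format
    save_compressed=True,
    sparsity_structure="2:4",   # replaces the old sparsity_config argument
)
print(quant_format)
```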
