
Commit 99f9f10

jiqing-feng and SunMarc authored
Fix torchao usage (#37034)
* fix load path
* fix path
* Fix torchao usage
* fix tests
* fix format
* revert useless change
* format
* revert fp8 test
* fix fp8 test
* fix fp8 test
* fix torch dtype

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
1 parent 0fb8d49 commit 99f9f10

File tree

2 files changed: +50, −36 lines


src/transformers/utils/quantization_config.py

Lines changed: 16 additions & 2 deletions
@@ -20,7 +20,7 @@
 import importlib.metadata
 import json
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
 from enum import Enum
 from inspect import Parameter, signature
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -1627,6 +1627,7 @@ def get_apply_tensor_subclass(self):
             and is_torchao_available()
             and self.quant_type == "int4_weight_only"
             and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+            and quant_type_kwargs.get("layout", None) is None
         ):
             from torchao.dtypes import Int4CPULayout

@@ -1643,7 +1644,17 @@ def to_dict(self):
         if isinstance(self.quant_type, str):
             # Handle layout serialization if present
             if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
-                d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
+                if is_dataclass(d["quant_type_kwargs"]["layout"]):
+                    d["quant_type_kwargs"]["layout"] = [
+                        d["quant_type_kwargs"]["layout"].__class__.__name__,
+                        dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+                    ]
+                if isinstance(d["quant_type_kwargs"]["layout"], list):
+                    assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+                    assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+                else:
+                    raise ValueError("layout must be a list")
         else:
             # Handle AOBaseConfig serialization
             from torchao.core.config import config_to_dict
@@ -1661,6 +1672,9 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict"
         config_dict = config_dict.copy()
         quant_type = config_dict.pop("quant_type")
+
+        if isinstance(quant_type, str):
+            return cls(quant_type=quant_type, **config_dict)
         # Check if we only have one key which is "default"
         # In the future we may update this
         assert len(quant_type) == 1 and "default" in quant_type, (
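
For context, a layout dataclass is now serialized as a two-element list, [layout class name, layout kwargs], instead of a bare kwargs dict, so the layout type survives a save/reload cycle; and from_dict now short-circuits when quant_type is a plain string. A minimal sketch of the resulting format, assuming torchao is installed and exposes TensorCoreTiledLayout under torchao.dtypes (the inner_k_tiles value shown is illustrative):

    from transformers import TorchAoConfig
    from torchao.dtypes import TensorCoreTiledLayout

    # Serialize a config whose quant_type_kwargs carry a layout dataclass.
    config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
    d = config.to_dict()

    # The layout is now stored as [class name, kwargs dict],
    # e.g. ["TensorCoreTiledLayout", {"inner_k_tiles": 8}].
    layout_name, layout_kwargs = d["quant_type_kwargs"]["layout"]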

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 34 additions & 34 deletions
@@ -104,8 +104,8 @@ def test_json_serializable(self):
         """
         quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
         d = quantization_config.to_dict()
-        self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
-        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
+        self.assertIsInstance(d["quant_type_kwargs"]["layout"], list)
+        self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1])
         quantization_config.to_json_string(use_diff=False)


@@ -159,7 +159,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
         # Note: we quantize the bfloat16 model on the fly to int4
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=None,
+            torch_dtype=torch.bfloat16,
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -282,7 +282,7 @@ def test_autoquant(self):

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype="auto",
             device_map=self.device,
             quantization_config=quant_config,
         )
@@ -295,7 +295,7 @@ def test_autoquant(self):

         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

-        EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -307,9 +307,7 @@ def test_autoquant(self):
 class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
-    # TODO: investigate why we don't have the same output as the original model for this test
-    SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
@@ -326,9 +324,10 @@ def setUpClass(cls):

     def setUp(self):
         self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            torch_dtype=torch.bfloat16,
+            torch_dtype=torch_dtype,
             device_map=self.device,
             quantization_config=self.quant_config,
         )
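
A note on the dtype switch in the setUp hunk above: torchao's int4 weight-only path operates on bfloat16, while the other schemes can defer to the checkpoint dtype. A minimal sketch of the selection rule the tests now share (the helper name is ours, for illustration only):

    import torch

    def pick_torch_dtype(quant_scheme):
        # int4_weight_only kernels in torchao expect bfloat16;
        # other schemes let from_pretrained pick the checkpoint dtype via "auto".
        return torch.bfloat16 if quant_scheme == "int4_weight_only" else "auto"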
@@ -342,50 +341,49 @@ def test_original_model_expected_output(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)

-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

     def check_serialization_expected_output(self, device, expected_output):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
+        torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
             self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
-                self.model_name, torch_dtype=torch.bfloat16, device_map=device
+                tmpdirname, torch_dtype=torch_dtype, device_map=device
             )
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)

             output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

     def test_serialization_expected_output(self):
-        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)


 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)


 class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

     @require_torch_gpu
     def test_serialization_expected_output_on_cuda(self):
         """
         Test if we can serialize on device (cpu) and load/infer the model on cuda
         """
-        self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT)
+        self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT)


 @require_torch_gpu
@@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest):
 @require_torch_gpu
 class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"


 @require_torch_gpu
 class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_weight_only", {}
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"


 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"

-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")

         from torchao.quantization import Float8WeightOnlyConfig

-        self.quant_scheme = Float8WeightOnlyConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Float8WeightOnlyConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()


 @require_torch_gpu
 @require_torchao_version_greater_or_equal("0.10.0")
 class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
-    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
-    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
     device = "cuda:0"

-    def setUp(self):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
         if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
             raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")

         from torchao.quantization import Int8DynamicActivationInt4WeightConfig

-        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
-        self.quant_scheme_kwargs = {}
-        super().setUp()
+        cls.quant_scheme = Int8DynamicActivationInt4WeightConfig()
+        cls.quant_scheme_kwargs = {}
+
+        super().setUpClass()


 if __name__ == "__main__":
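
The substantive fix in this file is in check_serialization_expected_output: the model is now reloaded from tmpdirname rather than self.model_name, so the test actually exercises the serialized checkpoint, which is why ORIGINAL_EXPECTED_OUTPUT and SERIALIZED_EXPECTED_OUTPUT could collapse into a single EXPECTED_OUTPUT. A minimal sketch of the intended round trip, assuming torchao is installed (scheme and device are illustrative):

    import tempfile

    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    quant_config = TorchAoConfig("int8_weight_only")
    quantized = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype="auto", device_map="cpu", quantization_config=quant_config
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    with tempfile.TemporaryDirectory() as tmpdir:
        # torchao tensor subclasses require non-safetensors serialization here.
        quantized.save_pretrained(tmpdir, safe_serialization=False)
        # Reload from the serialized checkpoint, not the hub name, to test the round trip.
        reloaded = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype="auto", device_map="cpu")
        inputs = tokenizer("What are we having for dinner?", return_tensors="pt")
        output = reloaded.generate(**inputs, max_new_tokens=10)
        print(tokenizer.decode(output[0], skip_special_tokens=True))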
