Commit 359db61
mps + compile + quant: introduce mps_op_lib (#1530)

* introduce mps_op_lib
* update torchao pin
* update cpu quantizers

1 parent 8a0897d · commit 359db61

3 files changed: +60 -45 lines

.github/workflows/pull.yml

Lines changed: 2 additions & 18 deletions

```diff
@@ -950,27 +950,11 @@ jobs:
         run: |
           export TORCHCHAT_ROOT=${PWD}
           echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
-      - name: Load or install ET
-        id: install-et
-        uses: actions/cache@v4
-        with:
-          path: |
-            ./et-build
-            ./torchchat/utils/scripts
-          key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh', '**/build_native.sh') }}
-      - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
-        continue-on-error: true
+      - name: Install ExecuTorch
         run: |
           echo "Installing ExecuTorch"
+          export TORCHCHAT_ROOT=${PWD}
           bash torchchat/utils/scripts/install_et.sh
-      - name: Install ExecuTorch python
-        run: |
-          echo "Install ExecuTorch python"
-          export TORCHCHAT_ROOT=$PWD
-          export ET_BUILD_DIR="et-build"
-          ENABLE_ET_PYBIND="${1:-true}"
-          source "torchchat/utils/scripts/install_utils.sh"
-          install_executorch_python_libs $ENABLE_ET_PYBIND
       - name: Install runner
         run: |
           echo "Installing runner"
```

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-711fa0809f06fc97febd0c3fe72563c3fe227e51
+7513042f39515af4c643bc1f9399952ad7f4f904
```

torchchat/utils/quantize.py

Lines changed: 57 additions & 26 deletions

```diff
@@ -34,10 +34,15 @@

 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
+from torchao.experimental.quant_api import EmbeddingQuantizer
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     int4_weight_only,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightQuantizer,
+    Int8DynamicActivationIntxWeightConfig,
+    MappingType,
     quantize_,
 )
 from torchao.utils import unwrap_tensor_subclass
```
```diff
@@ -50,18 +55,6 @@
     state_dict_device,
     use_et_backend,
 )
-from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
-    PackedLinearInt8DynamicActivationIntxWeightLayout,
-)
-from torchao.experimental.quant_api import (
-    int8_dynamic_activation_intx_weight,
-    IntxWeightEmbeddingQuantizer,
-)
-from torchao.quantization.granularity import (
-    PerGroup,
-    PerRow,
-)
-from torchao.dtypes import PlainLayout


 # Flag for whether the a8wxdq quantizer is available.
```
```diff
@@ -87,7 +80,7 @@ def get_named_parameters(func: Callable) -> List[str]:
     return named_params

 def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]:
-    for key in q_kwargs.keys():
+    for key in list(q_kwargs.keys()):
         if key not in named_params:
             print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
             del q_kwargs[key]
```
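The `list(...)` wrapper matters because the loop deletes keys from the dict it is iterating; doing that over a live key view raises `RuntimeError`. A minimal standalone illustration:

```python
# Deleting from a dict while iterating its live key view fails:
d = {"a": 1, "b": 2}
try:
    for k in d.keys():
        del d[k]
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration

# Iterating over a snapshot of the keys is safe:
d = {"a": 1, "b": 2}
for k in list(d.keys()):
    del d[k]
assert d == {}
```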
```diff
@@ -137,29 +130,34 @@ def quantize_model(
         group_size = q_kwargs["groupsize"]
         bit_width = q_kwargs["bitwidth"]
         has_weight_zeros = q_kwargs["has_weight_zeros"]
-        granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+        granularity = PerAxis() if group_size == -1 else PerGroup(group_size)
         weight_dtype = getattr(torch, f"int{bit_width}")
+        weight_mapping_type = (
+            MappingType.ASYMMETRIC
+            if has_weight_zeros
+            else MappingType.SYMMETRIC
+        )

         try:
             quantize_(
-                model,
-                int8_dynamic_activation_intx_weight(
+                model,
+                Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    weight_granularity=granularity,
+                    weight_mapping_type=weight_mapping_type,
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
         except Exception as e:
             print("Encountered error during quantization: {e}")
-            print("Trying with PlainLayout")
+            print("Trying with QDQLayout")
             quantize_(
-                model,
-                int8_dynamic_activation_intx_weight(
+                model,
+                Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
-                    layout=PlainLayout(),
+                    weight_granularity=granularity,
+                    weight_mapping_type=weight_mapping_type,
+                    layout=QDQLayout(),
                 ),
             )

```
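For reference, here is a condensed, standalone sketch of the new call pattern in this hunk. It assumes torchao at the pin above exposes these names with these parameters; the model and the sample values are illustrative only:

```python
import torch
import torch.nn as nn
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

model = nn.Sequential(nn.Linear(256, 256))  # illustrative model
group_size, bit_width, has_weight_zeros = 32, 4, True

quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=getattr(torch, f"int{bit_width}"),
        # groupsize == -1 now selects PerAxis() where the old code used PerRow()
        weight_granularity=PerAxis() if group_size == -1 else PerGroup(group_size),
        # the old has_weight_zeros flag becomes an explicit mapping type
        weight_mapping_type=(
            MappingType.ASYMMETRIC if has_weight_zeros else MappingType.SYMMETRIC
        ),
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
```

If the packed layout fails, the hunk above retries with `layout=QDQLayout()`, which keeps the weights in an explicit quantize/dequantize representation rather than the packed kernel format.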
```diff
@@ -174,6 +172,22 @@ def quantize_model(
             print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
             set_precision(torch.float32)

+        group_size = q_kwargs["groupsize"]
+        bit_width = q_kwargs["bitwidth"]
+        has_weight_zeros = q_kwargs.get("has_weight_zeros", True)
+        q_kwargs["granularity"] = (
+            PerAxis() if group_size == -1 else PerGroup(group_size)
+        )
+        q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
+        q_kwargs["mapping_type"] = (
+            MappingType.ASYMMETRIC
+            if has_weight_zeros
+            else MappingType.SYMMETRIC
+        )
+        q_kwargs["use_fallback"] = False
+        del q_kwargs["groupsize"]
+        del q_kwargs["bitwidth"]
+
     if quantizer == "linear:afpwx" and device != "mps":
         raise RuntimeError("linear:afpwx quantization can only run on mps device!")

```
```diff
@@ -188,7 +202,10 @@ def quantize_model(
     # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs
     if "tokenizer" in named_params:
         q_kwargs["tokenizer"] = tokenizer
-    quant_handler = q(device=device, precision=precision, **q_kwargs)
+    if quantizer == "embedding:wx":
+        quant_handler = q(**q_kwargs)
+    else:
+        quant_handler = q(device=device, precision=precision, **q_kwargs)

     # quantize model
     model = quant_handler.quantize(model)
```
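Taken together with the kwargs rewrite in the previous hunk, the `embedding:wx` path now constructs `EmbeddingQuantizer` without `device`/`precision`. A condensed sketch under the same torchao-pin assumption; the sample values and the use of `pop` are illustrative:

```python
import torch
from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

# User-facing options, e.g. from --quantize '{"embedding:wx": {...}}'
q_kwargs = {"groupsize": 32, "bitwidth": 4}

group_size = q_kwargs.pop("groupsize")
bit_width = q_kwargs.pop("bitwidth")
has_weight_zeros = q_kwargs.pop("has_weight_zeros", True)

# Rewrite CLI-style keys into EmbeddingQuantizer constructor arguments
q_kwargs["granularity"] = PerAxis() if group_size == -1 else PerGroup(group_size)
q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}")
q_kwargs["mapping_type"] = (
    MappingType.ASYMMETRIC if has_weight_zeros else MappingType.SYMMETRIC
)
q_kwargs["use_fallback"] = False

quant_handler = EmbeddingQuantizer(**q_kwargs)  # no device/precision here
# model = quant_handler.quantize(model)
```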
```diff
@@ -939,7 +956,7 @@ def quantized_model(self) -> nn.Module:
 # class references
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
-    "embedding:wx": IntxWeightEmbeddingQuantizer,
+    "embedding:wx": EmbeddingQuantizer,
     "linear:int8": WeightOnlyInt8QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
```
```diff
@@ -979,5 +996,19 @@ def quantized_model(self) -> nn.Module:
     except Exception as e:
         print("Unable to load torchao mps ops library.")

+    torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_mps_op_lib",
+        f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
+    )
+    torchao_experimental_mps_op_lib = importlib.util.module_from_spec(
+        torchao_experimental_mps_op_lib_spec
+    )
+    sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib
+    torchao_experimental_mps_op_lib_spec.loader.exec_module(
+        torchao_experimental_mps_op_lib
+    )
+    from torchao_experimental_mps_op_lib import *
+
 except Exception as e:
     print("Unable to import torchao experimental quant_api with error: ", e)
```
