Skip to content

Commit 79c09c6

Browse files
introduce mps_op_lib
1 parent 8a0897d commit 79c09c6

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

torchchat/utils/quantize.py

Lines changed: 17 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -137,12 +137,12 @@ def quantize_model(
137137
group_size = q_kwargs["groupsize"]
138138
bit_width = q_kwargs["bitwidth"]
139139
has_weight_zeros = q_kwargs["has_weight_zeros"]
140-
granularity = PerRow() if group_size == -1 else PerGroup(group_size)
140+
granularity = PerRow() if group_size == -1 else PerGroup(group_size)
141141
weight_dtype = getattr(torch, f"int{bit_width}")
142142

143143
try:
144144
quantize_(
145-
model,
145+
model,
146146
int8_dynamic_activation_intx_weight(
147147
weight_dtype=weight_dtype,
148148
granularity=granularity,
@@ -154,7 +154,7 @@ def quantize_model(
154154
print("Encountered error during quantization: {e}")
155155
print("Trying with PlainLayout")
156156
quantize_(
157-
model,
157+
model,
158158
int8_dynamic_activation_intx_weight(
159159
weight_dtype=weight_dtype,
160160
granularity=granularity,
@@ -979,5 +979,19 @@ def quantized_model(self) -> nn.Module:
979979
except Exception as e:
980980
print("Unable to load torchao mps ops library.")
981981

982+
torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location(
983+
"torchao_experimental_mps_op_lib",
984+
f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
985+
)
986+
torchao_experimental_mps_op_lib = importlib.util.module_from_spec(
987+
torchao_experimental_mps_op_lib_spec
988+
)
989+
sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib
990+
torchao_experimental_mps_op_lib_spec.loader.exec_module(
991+
torchao_experimental_mps_op_lib
992+
)
993+
from torchao_experimental_mps_op_lib import *
994+
995+
982996
except Exception as e:
983997
print("Unable to import torchao experimental quant_api with error: ", e)

0 commit comments

Comments (0)