@@ -137,12 +137,12 @@ def quantize_model(
             group_size = q_kwargs["groupsize"]
             bit_width = q_kwargs["bitwidth"]
             has_weight_zeros = q_kwargs["has_weight_zeros"]
-            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
+            granularity = PerRow() if group_size == -1 else PerGroup(group_size)
             weight_dtype = getattr(torch, f"int{bit_width}")

             try:
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -154,7 +154,7 @@ def quantize_model(
                 print("Encountered error during quantization: {e}")
                 print("Trying with PlainLayout")
                 quantize_(
-                    model,
+                    model,
                     int8_dynamic_activation_intx_weight(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
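
For context, a minimal sketch of how the q_kwargs in the hunks above map onto torchao's granularity and dtype objects; the torchao import path and the sample config values are assumptions for illustration, not part of this diff:

# Sketch (not part of the diff): q_kwargs -> granularity/dtype selection.
import torch
from torchao.quantization.granularity import PerGroup, PerRow  # assumed import path

q_kwargs = {"groupsize": 256, "bitwidth": 4, "has_weight_zeros": False}  # hypothetical config

group_size = q_kwargs["groupsize"]
bit_width = q_kwargs["bitwidth"]

# groupsize == -1 selects one quantization scale per row; any other value
# quantizes the weights in groups of that size.
granularity = PerRow() if group_size == -1 else PerGroup(group_size)

# bitwidth n resolves to the torch.int{n} dtype (requires a PyTorch build that
# defines the sub-byte integer dtypes, e.g. torch.int4).
weight_dtype = getattr(torch, f"int{bit_width}")
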
@@ -979,5 +979,19 @@ def quantized_model(self) -> nn.Module:
     except Exception as e:
         print("Unable to load torchao mps ops library.")

+    torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_mps_op_lib",
+        f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
+    )
+    torchao_experimental_mps_op_lib = importlib.util.module_from_spec(
+        torchao_experimental_mps_op_lib_spec
+    )
+    sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib
+    torchao_experimental_mps_op_lib_spec.loader.exec_module(
+        torchao_experimental_mps_op_lib
+    )
+    from torchao_experimental_mps_op_lib import *
+
+
 except Exception as e:
     print("Unable to import torchao experimental quant_api with error: ", e)
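
The added lines follow the standard importlib recipe for executing a module from an explicit file path and registering it in sys.modules. A generic, self-contained sketch of that pattern (the helper name and example path below are placeholders, not part of torchao):

import importlib.util
import sys

def load_module_from_path(name: str, path: str):
    # Build a module spec directly from a file path rather than from sys.path.
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    # Register before executing so imports inside the module can resolve it by name.
    sys.modules[name] = module
    # Execute the module's top-level code (which performs any op registrations).
    spec.loader.exec_module(module)
    return module

# Usage (path is illustrative):
# mps_ops = load_module_from_path(
#     "torchao_experimental_mps_op_lib",
#     "/path/to/torchao-build/src/ao/torchao/experimental/ops/mps/mps_op_lib.py",
# )
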