@@ -130,12 +130,20 @@ def forward(self, x):
130
130
131
131
class QEffFP8Config (QuantizationConfigMixin ):
132
132
def __init__ (
133
- self , quant_method : str , activation_scheme : str , ignored_layers : List [str ] = None , kv_cache_scheme : str = None
133
+ self ,
134
+ quant_method : str ,
135
+ activation_scheme : str ,
136
+ ignored_layers : List [str ] = None ,
137
+ kv_cache_scheme : str = None ,
138
+ run_compressed : bool = True ,
134
139
):
135
140
self .quant_method = quant_method
136
141
self .activation_scheme = activation_scheme
137
142
self .ignored_layers = ignored_layers
138
143
self .kv_cache_scheme = kv_cache_scheme
144
+ self .run_compressed = run_compressed
145
+ self .quantization_config = None
146
+ self .sparsity_config = None
139
147
if kv_cache_scheme :
140
148
logger .warning (
141
149
f"kv_cache_scheme={ kv_cache_scheme } will be ignored please use `mxint8_kv_cache=True` during compile call if you want to keep kv cache in int8 at runtime on Cloud AI 100"
@@ -156,7 +164,7 @@ def __init__(self, quantization_config, **kwargs):
156
164
raise TypeError (f"Only { QEffFP8Config } is supported for initialization got { type (quantization_config )} " )
157
165
158
166
self .quantization_config = quantization_config
159
-
167
+ self . run_compressed = quantization_config . run_compressed
160
168
# -- Handle extra kwargs below --
161
169
self .modules_to_not_convert = kwargs .pop ("modules_to_not_convert" , [])
162
170
self .modules_to_not_convert = list (
@@ -216,6 +224,7 @@ def __init__(
216
224
ignore = None ,
217
225
sparsity_config = None ,
218
226
quant_method = "compressed-tensors" ,
227
+ run_compressed : bool = True ,
219
228
** kwargs ,
220
229
):
221
230
self .config_groups = config_groups
@@ -226,6 +235,10 @@ def __init__(
226
235
self .global_compression_ratio = global_compression_ratio
227
236
self .ignore = ignore
228
237
238
+ self .quantization_config = None
239
+ self .sparsity_config = None
240
+
241
+ self .run_compressed = run_compressed
229
242
# Validate configuration
230
243
if len (self .config_groups ) != 1 :
231
244
raise NotImplementedError (
@@ -318,7 +331,7 @@ def __init__(self, quantization_config, **kwargs):
318
331
raise TypeError (
319
332
f"Only { QEffCompressedTensorsConfig } is supported for initialization got { type (quantization_config )} "
320
333
)
321
-
334
+ self . run_compressed = quantization_config . run_compressed
322
335
self .quantization_config = quantization_config
323
336
324
337
# -- Handle extra kwargs below --
0 commit comments