
Commit 63f7da2

andrewor14, sekyondaMeta, AlannaBurke, and svekars authored

Update "GPU Quantization with TorchAO" (#3439)

Refresh outdated links and APIs.

Co-authored-by: sekyondaMeta <[email protected]>
Co-authored-by: Alanna Burke <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 6d169f6 commit 63f7da2

File tree

1 file changed: 12 additions, 12 deletions


prototype_source/gpu_quantization_torchao_tutorial.py

Lines changed: 12 additions & 12 deletions
@@ -31,7 +31,7 @@
 # > conda create -n myenv python=3.10
 # > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
 # > pip install git+https://github.com/facebookresearch/segment-anything.git
-# > pip install git+https://github.com/pytorch-labs/ao.git
+# > pip install git+https://github.com/pytorch/ao.git
 #
 # Segment Anything Model checkpoint setup:
 #
@@ -44,7 +44,7 @@
 #

 import torch
-from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig
 from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
@@ -143,7 +143,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # for improvements.
 #
 # Next, let's apply quantization. Quantization for GPUs comes in three main forms
-# in `torchao <https://github.com/pytorch-labs/ao>`_ which is just native
+# in `torchao <https://github.com/pytorch/ao>`_ which is just native
 # pytorch+python code. This includes:
 #
 # * int8 dynamic quantization
@@ -157,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``int8_dynamic_activation_int8_weight()``,
-# ``int8_weight_only()`` or
-# ``int4_weight_only()``
+# ``Int8DynamicActivationInt8WeightConfig()``,
+# ``Int8WeightOnlyConfig()`` or
+# ``Int4WeightOnlyConfig()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -171,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``int8_dynamic_activation`` API
+# The difference between the two APIs is that the ``Int8DynamicActivationInt8WeightConfig`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -186,7 +186,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -224,7 +224,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -258,7 +258,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -290,7 +290,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(False, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
 if not TORCH_VERSION_AT_LEAST_2_5:
     # needed for subclass + compile to work on older versions of pytorch
     unwrap_tensor_subclass(model)
@@ -315,6 +315,6 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # the model. For example, this can be done with some form of flash attention.
 #
 # For more information visit
-# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
+# `torchao <https://github.com/pytorch/ao>`_ and try it on your own
 # models.
 #
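
As a quick illustration of the renamed entry points this commit switches to, the sketch below quantizes a small toy model with the config-based API. It is a minimal sketch, not part of the tutorial: it assumes a recent torchao build on a CUDA machine, reuses only the `quantize_` and `Int8DynamicActivationInt8WeightConfig` names shown in the diff, and the toy module and shapes are made up for illustration.

import torch
from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig

# Toy model; in the tutorial this role is played by a block of the
# Segment Anything image encoder.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(torch.bfloat16).cuda()

# The config object replaces the old int8_dynamic_activation_int8_weight()
# callable; quantize_ rewrites the weights of matching linear modules in place.
quantize_(model, Int8DynamicActivationInt8WeightConfig())

# As in the tutorial, compile with max-autotune so the quantized kernels are fused.
model = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    out = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))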

0 commit comments
