
Commit 0a93be7

Add pt2e tutorials to torchao doc page (#2384)
Summary: As titled; after migrating the pt2e quant code from PyTorch to torchao, we also want to migrate the docs.

Test Plan: check generated docs
1 parent 7e7ea92 commit 0a93be7
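
For orientation (not part of the commit itself): the migration described in the summary moves the pt2e entry points so that they are imported from torchao rather than from PyTorch core. A minimal sketch of the import change, assuming the pre-migration APIs lived under torch.ao.quantization.quantize_pt2e; the new paths are the ones used in the quick_start.rst diff below::

    # before the migration (PyTorch core) -- assumed legacy location
    # from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

    # after the migration (torchao), matching the docs added in this commit
    from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e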

File tree

9 files changed: +2369 −10 lines changed

docs/source/index.rst

Lines changed: 14 additions & 1 deletion
@@ -12,6 +12,7 @@ for an overall introduction to the library and recent highlight and updates.
    :caption: Getting Started

    quick_start
+   pt2e_quant

 .. toctree::
    :glob:
@@ -35,11 +36,23 @@ for an overall introduction to the library and recent highlight and updates.
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Tutorials
+   :caption: Eager Quantization Tutorials

    serialization
    subclass_basic
    subclass_advanced
    static_quantization
    pretraining
    torchao_vllm_integration
+
+.. toctree::
+   :glob:
+   :maxdepth: 1
+   :caption: PT2E Quantization Tutorials
+
+   tutorials_source/pt2e_quant_ptq
+   tutorials_source/pt2e_quant_qat
+   tutorials_source/pt2e_quant_x86_inductor
+   tutorials_source/pt2e_quant_xpu_inductor
+   tutorials_source/pt2e_quantizer
+   tutorials_source/openvino_quantizer

docs/source/quick_start.rst

Lines changed: 88 additions & 7 deletions
@@ -29,20 +29,20 @@ First, let's set up our toy model:

     import copy
     import torch
-
+
     class ToyLinearModel(torch.nn.Module):
         def __init__(self, m: int, n: int, k: int):
             super().__init__()
             self.linear1 = torch.nn.Linear(m, n, bias=False)
             self.linear2 = torch.nn.Linear(n, k, bias=False)
-
+
         def forward(self, x):
             x = self.linear1(x)
             x = self.linear2(x)
             return x
-
+
     model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
-
+
     # Optional: compile model for faster inference and generation
     model = torch.compile(model, mode="max-autotune", fullgraph=True)
     model_bf16 = copy.deepcopy(model)

(In this hunk and the next, the paired -/+ blank lines are whitespace-only changes on otherwise empty lines.)
@@ -99,18 +99,18 @@ it is also much faster!
         benchmark_model,
         unwrap_tensor_subclass,
     )
-
+
     # Temporary workaround for tensor subclass + torch.compile
     # Only needed for torch version < 2.5
     if not TORCH_VERSION_AT_LEAST_2_5:
         unwrap_tensor_subclass(model)
-
+
     num_runs = 100
     torch._dynamo.reset()
     example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
     bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
     int4_time = benchmark_model(model, num_runs, example_inputs)
-
+
     print("bf16 mean time: %0.3f ms" % bf16_time)
     print("int4 mean time: %0.3f ms" % int4_time)
     print("speedup: %0.1fx" % (bf16_time / int4_time))
@@ -121,6 +121,87 @@ On a single A100 GPU with 80GB memory, this prints::
       int4 mean time: 4.410 ms
       speedup: 6.9x
 
+PyTorch 2 Export Quantization
+=============================
+
+PyTorch 2 Export Quantization is a full graph quantization workflow, mostly for static quantization. It targets hardware that requires both input and output activations and weights to be quantized, and it relies on recognizing operator patterns (such as linear - relu) to make quantization decisions. PT2E quantization produces a graph with quantize and dequantize ops inserted around the operators; during lowering, the quantized operator patterns are fused into real quantized ops. Currently there are two typical lowering paths: (1) torch.compile through Inductor lowering, and (2) ExecuTorch through delegation.
+
+Here we show an example with X86InductorQuantizer.
+
+API Example::
+
+    import torch
+    from torch.export import export
+    from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
+    from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
+        X86InductorQuantizer,
+        get_default_x86_inductor_quantization_config,
+    )
+
+    class M(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(5, 10)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    # initialize a floating point model and example inputs
+    float_model = M().eval()
+    example_inputs = (torch.randn(1, 5),)
+
+    # define calibration function
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+
+    # Step 1. program capture
+    # we get a model with aten ops
+    m = export(float_model, example_inputs).module()
+
+    # Step 2. quantization
+    # backend developers write their own Quantizer and expose methods to allow
+    # users to express how they want the model to be quantized
+    quantizer = X86InductorQuantizer()
+    quantizer.set_global(get_default_x86_inductor_quantization_config())
+
+    # or prepare_qat_pt2e for Quantization Aware Training
+    m = prepare_pt2e(m, quantizer)
+
+    # run calibration
+    # calibrate(m, sample_inference_data)
+    m = convert_pt2e(m)
+
+    # Step 3. lowering
+    # lower to target backend
+
+    # Optional: use the C++ wrapper instead of the default Python wrapper
+    import torch._inductor.config as config
+    config.cpp_wrapper = True
+
+    with torch.no_grad():
+        optimized_model = torch.compile(m)
+
+        # Running some benchmark
+        optimized_model(*example_inputs)
+
+Please follow these tutorials to get started on PyTorch 2 Export Quantization:
+
+Modeling Users:
+
+- `PyTorch 2 Export Post Training Quantization <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quant_ptq.html>`_
+- `PyTorch 2 Export Quantization Aware Training <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quant_qat.html>`_
+- `PyTorch 2 Export Post Training Quantization with X86 Backend through Inductor <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quant_x86_inductor.html>`_
+- `PyTorch 2 Export Post Training Quantization with XPU Backend through Inductor <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quant_xpu_inductor.html>`_
+- `PyTorch 2 Export Quantization for OpenVINO torch.compile Backend <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quant_openvino.html>`_
+
+Backend Developers (please check out all Modeling Users docs as well):
+
+- `How to Write a Quantizer for PyTorch 2 Export Quantization <https://docs.pytorch.org/ao/stable/tutorial_source/pt2e_quantizer.html>`_
+
 
 Next Steps
 ==========
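
As a supplement to the example added above (not part of the diff): the code comments mention prepare_qat_pt2e for Quantization Aware Training but do not show it. Below is a minimal sketch of that variant, assuming prepare_qat_pt2e is exposed alongside prepare_pt2e/convert_pt2e in torchao.quantization.pt2e.quantize_pt2e and that get_default_x86_inductor_quantization_config accepts an is_qat flag; the fine-tuning loop is a placeholder::

    import torch
    from torch.export import export
    # assumption: QAT prepare API exposed next to prepare_pt2e/convert_pt2e
    from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
    from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
        X86InductorQuantizer,
        get_default_x86_inductor_quantization_config,
    )

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 10)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 5),)

    # Step 1. program capture (same capture call as the PTQ flow above)
    m = export(M().train(), example_inputs).module()

    # Step 2. insert fake-quant observers for QAT
    quantizer = X86InductorQuantizer()
    # assumption: the default config exposes an is_qat switch
    quantizer.set_global(get_default_x86_inductor_quantization_config(is_qat=True))
    m = prepare_qat_pt2e(m, quantizer)

    # placeholder fine-tuning loop: a real run iterates over a labeled dataset
    optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
    for _ in range(3):
        loss = m(*example_inputs).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Step 3. fold fake-quant into quantized ops, then lower as in the PTQ example
    m = convert_pt2e(m)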
