Skip to content

Commit e79fdc0

Browse files
authored
Add quantization for StableDiffusion (google-ai-edge#116)
BUG=b/355505876. Verified with pipeline.py.
1 parent 5d3367c commit e79fdc0

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
2727
from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
2828
import ai_edge_torch.generative.examples.stable_diffusion.util as util
29+
from ai_edge_torch.generative.quantize import quant_recipes
2930
import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
3031

3132
arg_parser = argparse.ArgumentParser()
@@ -60,6 +61,7 @@ def convert_stable_diffusion_to_tflite(
6061
decoder_ckpt_path: str,
6162
image_height: int = 512,
6263
image_width: int = 512,
64+
quantize: bool = True,
6365
):
6466

6567
clip_model = clip.CLIP(clip.get_model_config())
@@ -105,15 +107,17 @@ def convert_stable_diffusion_to_tflite(
105107
if not os.path.exists(output_dir):
106108
Path(output_dir).mkdir(parents=True, exist_ok=True)
107109

110+
quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
111+
108112
# TODO(yichunk): convert to multi signature tflite model.
109113
# CLIP text encoder
110-
ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
111-
f'{output_dir}/clip.tflite'
112-
)
114+
ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
115+
quant_config=quant_config
116+
).export(f'{output_dir}/clip.tflite')
113117

114118
# TODO(yichunk): enable image encoder conversion
115119
# Image encoder
116-
# ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
120+
# ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
117121
# f'{output_dir}/encoder.tflite'
118122
# )
119123

@@ -122,12 +126,12 @@ def convert_stable_diffusion_to_tflite(
122126
'diffusion',
123127
diffusion_model,
124128
(torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
125-
).convert().export(f'{output_dir}/diffusion.tflite')
129+
).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
126130

127131
# Image decoder
128-
ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
129-
f'{output_dir}/decoder.tflite'
130-
)
132+
ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
133+
quant_config=quant_config
134+
).export(f'{output_dir}/decoder.tflite')
131135

132136

133137
if __name__ == '__main__':
@@ -139,4 +143,5 @@ def convert_stable_diffusion_to_tflite(
139143
decoder_ckpt_path=args.decoder_ckpt,
140144
image_height=512,
141145
image_width=512,
146+
quantize=True,
142147
)

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@
6565
choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
6666
help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms`, `k_euler`, or `k_euler_ancestral`.',
6767
)
68+
arg_parser.add_argument(
69+
'--seed',
70+
default=None,
71+
type=int,
72+
help='A seed to make generation deterministic. A random number is used if unspecified.',
73+
)
6874

6975

7076
class StableDiffusion:
@@ -219,4 +225,5 @@ def run_tflite_pipeline(
219225
output_path=args.output_path,
220226
sampler=args.sampler,
221227
n_inference_steps=args.n_inference_steps,
228+
seed=args.seed,
222229
)

ai_edge_torch/generative/quantize/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,3 @@ def custom_selective_quantization_recipe() -> quant_config.QuantConfig:
4343
```
4444

4545
For example, this recipe specifies that the embedding table, attention, and feedforward layers should be quantized to INT8. Specifically, for attention layers the computation should be in FP32. All other ops should be quantized to the default scheme which is specified as FP16.
46-

0 commit comments

Comments
 (0)