@@ -26,6 +26,7 @@
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
 arg_parser = argparse.ArgumentParser()
@@ -60,6 +61,7 @@ def convert_stable_diffusion_to_tflite(
     decoder_ckpt_path: str,
     image_height: int = 512,
     image_width: int = 512,
+    quantize: bool = True,
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
@@ -105,15 +107,17 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
+  quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
+
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      f'{output_dir}/clip.tflite'
-  )
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/clip.tflite')
 
   # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
   #     f'{output_dir}/encoder.tflite'
   # )
 
@@ -122,12 +126,12 @@ def convert_stable_diffusion_to_tflite(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export(f'{output_dir}/diffusion.tflite')
+  ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      f'{output_dir}/decoder.tflite'
-  )
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/decoder.tflite')
 
 
 if __name__ == '__main__':
@@ -139,4 +143,5 @@ def convert_stable_diffusion_to_tflite(
       decoder_ckpt_path=args.decoder_ckpt,
       image_height=512,
       image_width=512,
+      quantize=True,
   )
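
For reviewers who want to try the pattern this diff adopts outside the full Stable Diffusion pipeline, here is a minimal sketch. `TinyBlock`, its sample input, and the output path are hypothetical stand-ins; the `signature(...).convert(quant_config=...).export(...)` chain and `quant_recipes.full_int8_weight_only_recipe()` are the calls used in the diff above, though this toy module has not been verified against a specific ai-edge-torch release.

```python
# Sketch only: TinyBlock and '/tmp/tiny.tflite' are hypothetical stand-ins.
# The convert/export calls mirror the pattern introduced in this diff.
import torch
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes


class TinyBlock(torch.nn.Module):
  """Hypothetical stand-in for one of the Stable Diffusion sub-models."""

  def __init__(self):
    super().__init__()
    self.proj = torch.nn.Linear(4, 4)

  def forward(self, x):
    return torch.relu(self.proj(x))


model = TinyBlock().eval()
sample_args = (torch.randn(1, 4),)

# Passing quant_config=None instead keeps the float model, which is what
# quantize=False does in the converted script.
quant_config = quant_recipes.full_int8_weight_only_recipe()

ai_edge_torch.signature('tiny', model, sample_args).convert(
    quant_config=quant_config
).export('/tmp/tiny.tflite')
```

Weight-only int8 quantizes only the weights and leaves activations in float, so it needs no calibration data and mainly shrinks the exported .tflite files, which matters most for the large diffusion and decoder models.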