import jax
import jax.numpy as jnp
+from flax import nnx
from fire import Fire
from jax.typing import DTypeLike

+from PIL import Image
+
+from einops import rearrange
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5

@@ -101,51 +105,35 @@ def main(
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
-    device: str = "gpu" if jax.device_get("gpu") else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 3.5,
-    # TODO: JAX variant of offloading to CPU
    offload: bool = False,
    output_dir: str = "output",
-    dtype: DTypeLike = jax.dtypes.bfloat16,
-    param_dtype: DTypeLike = None,
-) -> None:
+    add_sampling_metadata: bool = True,
+):
    """
-    Sample the flux model.
+    Sample the flux model. Either interactively (set `--loop`) or run for a
+    single image.

    Args:
-        name(str): Name of the model to use. Choose from 'flux-schnell' or 'flux-dev'.
-        width(int): Width of the generated image.
-        height(int): Height of the generated image.
-        seed(int, optional): Seed for the random number generator.
-        prompt(str): Text prompt to generate the image from.
-        device(str): Device to run the model on. Choose from 'cpu' or 'gpu'.
-        num_steps(int, optional): Number of steps to run the model for.
-        loop(bool): Whether to loop the sampling process.
-        guidance(float, optional): Guidance for the model, defaults to 3.5.
-        offload(bool, optional): Whether to offload the model to CPU, defaults to False.
-        output_dir(str, optional): Directory to save the output images in, defaults to 'output'.
-        dtype(DTypeLike, optional): Data type for the model, defaults to jax.dtypes.bfloat16.
-        param_dtype(DTypeLike, optional): Data type for the model parameters, defaults to None.
+        name: Name of the model to load
+        height: height of the sample in pixels (should be a multiple of 16)
+        width: width of the sample in pixels (should be a multiple of 16)
+        seed: Set a seed for sampling
+        output_name: where to save the output image, `{idx}` will be replaced
+            by the index of the sample
+        prompt: Prompt used for sampling
+        device: Pytorch device
+        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
+        loop: start an interactive session and sample multiple times
+        guidance: guidance value used for guidance distillation
+        add_sampling_metadata: Add the prompt to the image Exif metadata
    """
-
-    if param_dtype is None:
-        param_dtype = dtype
-
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, chose from {available}")

-    jax_device = jax.devices(device)
-    if len(jax_device) == 1:
-        jax_device = jax_device[0]
-    else:
-        # TODO (ariG23498)
-        # this will be when there are more than
-        # one devices to work on
-        pass
-
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

@@ -169,26 +157,11 @@ def main(
    idx = 0

    # init all components
-    import sys
-
-    sys.exit(0)
-    t5 = load_t5(max_length=256 if name == "flux-schnell" else 512)
+    t5 = load_t5()
    clip = load_clip()
-    model = load_flow_model(
-        name,
-        device="cpu" if offload else jax_device,
-        dtype=dtype,
-        param_dtype=param_dtype,
-    )
-    ae = load_ae(
-        name,
-        device="cpu" if offload else jax_device,
-        dtype=dtype,
-        param_dtype=param_dtype,
-    )
+    model = load_flow_model(name)
+    ae = load_ae(name)

-    # TODO (ariG23498)
-    # rngs = nnx.Rngs(0)
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
@@ -200,57 +173,51 @@ def main(

    while opts is not None:
        if opts.seed is None:
-            # TODO (ariG23498)
-            # set the rng seed
-            # opts.seed = rng.seed()
-            pass
+            opts.seed = jax.random.PRNGKey(seed=42)
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
-            1,
-            opts.height,
-            opts.width,
-            device=jax_device,
+            num_samples=1,
+            height=opts.height,
+            width=opts.width,
            dtype=jax.dtypes.bfloat16,
-            seed=opts.seed,  # type: ignore
+            seed=opts.seed,
        )
        opts.seed = None
-        # TODO: JAX variant of offloading to CPU
-        # if offload:
-        # ae = ae.cpu()
-        # torch.cuda.empty_cache()
-        # t5, clip = t5.to(torch_device), clip.to(torch_device)
-        inp = prepare(t5, clip, img=x, prompt=opts.prompt, device=jax_device)
+
+        inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)
        timesteps = get_schedule(
-            opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
+            num_steps=opts.num_steps,
+            image_seq_len=inp["img"].shape[1],
+            shift=(name != "flux-schnell"),
        )

-        # offload TEs to CPU, load model to gpu
-        # TODO: JAX variant of offloading to CPU
-        # if offload:
-        # t5, clip = t5.cpu(), clip.cpu()
-        # torch.cuda.empty_cache()
-        # model = model.to(torch_device)
-
        # denoise initial noise
-        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
-
-        # offload model, load autoencoder to gpu
-        # TODO: JAX variant of offloading to CPU
-        # if offload:
-        # model.cpu()
-        # torch.cuda.empty_cache()
-        # ae.decoder.to(x.device)
+        x = denoise(
+            model=model,
+            img=inp["img"],
+            img_ids=inp["img_ids"],
+            txt=inp["txt"],
+            txt_ids=inp["txt_ids"],
+            vec=inp["vec"],
+            timesteps=timesteps,
+            guidance=opts.guidance,
+        )

        # decode latents to pixel space
-        x = unpack(x.astype(jnp.float32), opts.height, opts.width)
-        x = ae.decode(x).astype(dtype=jax.dtypes.bfloat16)  # noqa
+        x = unpack(x=x.astype(jnp.float32), height=opts.height, width=opts.width)
+        x = ae.decode(x)
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
+        # bring into PIL format and save
+        x = x.clip(-1, 1)
+        x = rearrange(x[0], "c h w -> h w c")
+
+        img = Image.fromarray((127.5 * (x + 1.0)))

        if loop:
            print("-" * 80)
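For reference, a minimal usage sketch of the reworked entrypoint. The `jflux.cli` import path is an assumption (the script exposes `main` via Fire); the keyword arguments and their meanings come from the signature and docstring in this diff.

# Usage sketch -- assumptions: main() is importable from jflux.cli (hypothetical
# path); width/height values here are just examples that satisfy the
# "multiple of 16" note in the docstring.
from jflux.cli import main

main(
    name="flux-schnell",        # or "flux-dev"
    width=1024,                 # multiple of 16
    height=768,                 # multiple of 16
    prompt="a photo of a forest with mist swirling around the tree trunks",
    num_steps=None,             # None -> 4 steps for schnell, 50 otherwise
    guidance=3.5,
    output_dir="output",
    add_sampling_metadata=True,
)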