Skip to content

Commit 9162e4d

Browse files
AE Porting Code and Cleanup (#10)
Co-authored-by: Saurav Maheshkar <[email protected]>
1 parent e9dbdf7 commit 9162e4d

16 files changed

+615
-477
lines changed

jflux/cli.py

+50-83
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
import jax
88
import jax.numpy as jnp
9+
from flax import nnx
910
from fire import Fire
1011
from jax.typing import DTypeLike
1112

13+
from PIL import Image
14+
15+
from einops import rearrange
1216
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
1317
from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5
1418

@@ -101,51 +105,35 @@ def main(
101105
"a photo of a forest with mist swirling around the tree trunks. The word "
102106
'"FLUX" is painted over it in big, red brush strokes with visible texture'
103107
),
104-
device: str = "gpu" if jax.device_get("gpu") else "cpu",
105108
num_steps: int | None = None,
106109
loop: bool = False,
107110
guidance: float = 3.5,
108-
# TODO: JAX variant of offloading to CPU
109111
offload: bool = False,
110112
output_dir: str = "output",
111-
dtype: DTypeLike = jax.dtypes.bfloat16,
112-
param_dtype: DTypeLike = None,
113-
) -> None:
113+
add_sampling_metadata: bool = True,
114+
):
114115
"""
115-
Sample the flux model.
116+
Sample the flux model. Either interactively (set `--loop`) or run for a
117+
single image.
116118
117119
Args:
118-
name(str): Name of the model to use. Choose from 'flux-schnell' or 'flux-dev'.
119-
width(int): Width of the generated image.
120-
height(int): Height of the generated image.
121-
seed(int, optional): Seed for the random number generator.
122-
prompt(str): Text prompt to generate the image from.
123-
device(str): Device to run the model on. Choose from 'cpu' or 'gpu'.
124-
num_steps(int, optional): Number of steps to run the model for.
125-
loop(bool): Whether to loop the sampling process.
126-
guidance(float, optional): Guidance for the model, defaults to 3.5.
127-
offload(bool, optional): Whether to offload the model to CPU, defaults to False.
128-
output_dir(str, optional): Directory to save the output images in, defaults to 'output'.
129-
dtype(DTypeLike, optional): Data type for the model, defaults to jax.dtypes.bfloat16.
130-
param_dtype(DTypeLike, optional): Data type for the model parameters, defaults to None.
120+
name: Name of the model to load
121+
height: height of the sample in pixels (should be a multiple of 16)
122+
width: width of the sample in pixels (should be a multiple of 16)
123+
seed: Set a seed for sampling
124+
output_name: where to save the output image, `{idx}` will be replaced
125+
by the index of the sample
126+
prompt: Prompt used for sampling
127+
device: Pytorch device
128+
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
129+
loop: start an interactive session and sample multiple times
130+
guidance: guidance value used for guidance distillation
131+
add_sampling_metadata: Add the prompt to the image Exif metadata
131132
"""
132-
133-
if param_dtype is None:
134-
param_dtype = dtype
135-
136133
if name not in configs:
137134
available = ", ".join(configs.keys())
138135
raise ValueError(f"Got unknown model name: {name}, chose from {available}")
139136

140-
jax_device = jax.devices(device)
141-
if len(jax_device) == 1:
142-
jax_device = jax_device[0]
143-
else:
144-
# TODO (ariG23498)
145-
# this will be when there are more than
146-
# one devices to work on
147-
pass
148-
149137
if num_steps is None:
150138
num_steps = 4 if name == "flux-schnell" else 50
151139

@@ -169,26 +157,11 @@ def main(
169157
idx = 0
170158

171159
# init all components
172-
import sys
173-
174-
sys.exit(0)
175-
t5 = load_t5(max_length=256 if name == "flux-schnell" else 512)
160+
t5 = load_t5()
176161
clip = load_clip()
177-
model = load_flow_model(
178-
name,
179-
device="cpu" if offload else jax_device,
180-
dtype=dtype,
181-
param_dtype=param_dtype,
182-
)
183-
ae = load_ae(
184-
name,
185-
device="cpu" if offload else jax_device,
186-
dtype=dtype,
187-
param_dtype=param_dtype,
188-
)
162+
model = load_flow_model(name)
163+
ae = load_ae(name)
189164

190-
# TODO (ariG23498)
191-
# rngs = nnx.Rngs(0)
192165
opts = SamplingOptions(
193166
prompt=prompt,
194167
width=width,
@@ -200,57 +173,51 @@ def main(
200173

201174
while opts is not None:
202175
if opts.seed is None:
203-
# TODO (ariG23498)
204-
# set the rng seed
205-
# opts.seed = rng.seed()
206-
pass
176+
opts.seed = jax.random.PRNGKey(seed=42)
207177
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
208178
t0 = time.perf_counter()
209179

210180
# prepare input
211181
x = get_noise(
212-
1,
213-
opts.height,
214-
opts.width,
215-
device=jax_device,
182+
num_samples=1,
183+
height=opts.height,
184+
width=opts.width,
216185
dtype=jax.dtypes.bfloat16,
217-
seed=opts.seed, # type: ignore
186+
seed=opts.seed,
218187
)
219188
opts.seed = None
220-
# TODO: JAX variant of offloading to CPU
221-
# if offload:
222-
# ae = ae.cpu()
223-
# torch.cuda.empty_cache()
224-
# t5, clip = t5.to(torch_device), clip.to(torch_device)
225-
inp = prepare(t5, clip, img=x, prompt=opts.prompt, device=jax_device)
189+
190+
inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)
226191
timesteps = get_schedule(
227-
opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")
192+
num_steps=opts.num_steps,
193+
image_seq_len=inp["img"].shape[1],
194+
shift=(name != "flux-schnell"),
228195
)
229196

230-
# offload TEs to CPU, load model to gpu
231-
# TODO: JAX variant of offloading to CPU
232-
# if offload:
233-
# t5, clip = t5.cpu(), clip.cpu()
234-
# torch.cuda.empty_cache()
235-
# model = model.to(torch_device)
236-
237197
# denoise initial noise
238-
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
239-
240-
# offload model, load autoencoder to gpu
241-
# TODO: JAX variant of offloading to CPU
242-
# if offload:
243-
# model.cpu()
244-
# torch.cuda.empty_cache()
245-
# ae.decoder.to(x.device)
198+
x = denoise(
199+
model=model,
200+
img=inp["img"],
201+
img_ids=inp["img_ids"],
202+
txt=inp["txt"],
203+
txt_ids=inp["txt_ids"],
204+
vec=inp["vec"],
205+
timesteps=timesteps,
206+
guidance=opts.guidance,
207+
)
246208

247209
# decode latents to pixel space
248-
x = unpack(x.astype(jnp.float32), opts.height, opts.width)
249-
x = ae.decode(x).astype(dtype=jax.dtypes.bfloat16) # noqa
210+
x = unpack(x=x.astype(jnp.float32), height=opts.height, width=opts.width)
211+
x = ae.decode(x)
250212
t1 = time.perf_counter()
251213

252214
fn = output_name.format(idx=idx)
253215
print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
216+
# bring into PIL format and save
217+
x = x.clip(-1, 1)
218+
x = rearrange(x[0], "c h w -> h w c")
219+
220+
img = Image.fromarray((127.5 * (x + 1.0)))
254221

255222
if loop:
256223
print("-" * 80)

jflux/math.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import jax
21
from chex import Array
32
from einops import rearrange
43
from flax import nnx
@@ -16,7 +15,7 @@ def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
1615

1716
def rope(pos: Array, dim: int, theta: int) -> Array:
1817
assert dim % 2 == 0
19-
scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
18+
scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
2019
omega = 1.0 / (theta**scale)
2120
out = jnp.einsum("...n,d->...nd", pos, omega)
2221
out = jnp.stack([jnp.cos(out), -jnp.sin(out), jnp.sin(out), jnp.cos(out)], axis=-1)

jflux/model.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import jax.numpy as jnp
44
from chex import Array
55
from flax import nnx
6-
from flux.modules.layers import (
6+
from jax.typing import DTypeLike
7+
8+
from jflux.modules.layers import (
79
DoubleStreamBlock,
810
EmbedND,
911
LastLayer,
1012
MLPEmbedder,
1113
SingleStreamBlock,
1214
timestep_embedding,
1315
)
14-
from jax.typing import DTypeLike
1516

1617

1718
@dataclass
@@ -67,8 +68,18 @@ def __init__(self, params: FluxParams):
6768
rngs=params.rngs,
6869
param_dtype=params.param_dtype,
6970
)
70-
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
71-
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
71+
self.time_in = MLPEmbedder(
72+
in_dim=256,
73+
hidden_dim=self.hidden_size,
74+
rngs=params.rngs,
75+
param_dtype=params.param_dtype,
76+
)
77+
self.vector_in = MLPEmbedder(
78+
params.vec_in_dim,
79+
self.hidden_size,
80+
rngs=params.rngs,
81+
param_dtype=params.param_dtype,
82+
)
7283
self.guidance_in = (
7384
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
7485
if params.guidance_embed
@@ -109,7 +120,13 @@ def __init__(self, params: FluxParams):
109120
]
110121
)
111122

112-
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
123+
self.final_layer = LastLayer(
124+
self.hidden_size,
125+
1,
126+
self.out_channels,
127+
rngs=params.rngs,
128+
param_dtype=params.param_dtype,
129+
)
113130

114131
def __call__(
115132
self,
@@ -136,14 +153,14 @@ def __call__(
136153
vec = vec + self.vector_in(y)
137154
txt = self.txt_in(txt)
138155

139-
ids = jnp.concatenate((txt_ids, img_ids), dim=1)
156+
ids = jnp.concatenate((txt_ids, img_ids), axis=1)
140157
pe = self.pe_embedder(ids)
141158

142-
for block in self.double_blocks:
159+
for block in self.double_blocks.layers:
143160
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
144161

145-
img = jnp.concatenate((txt, img), 1)
146-
for block in self.single_blocks:
162+
img = jnp.concatenate((txt, img), axis=1)
163+
for block in self.single_blocks.layers:
147164
img = block(img, vec=vec, pe=pe)
148165
img = img[:, txt.shape[1] :, ...]
149166

jflux/modules/autoencoder.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -524,14 +524,27 @@ def __init__(
524524
self.shift_factor = params.shift_factor
525525

526526
def encode(self, x: Array) -> Array:
527+
# rearrange for jax
528+
x = rearrange(x, "b c h w -> b h w c")
529+
527530
z = self.reg(self.encoder(x))
528531
z = self.scale_factor * (z - self.shift_factor)
532+
533+
# rearrange for jax
534+
z = rearrange(z, "b h w c -> b c h w")
529535
return z
530536

531537
def decode(self, z: Array) -> Array:
538+
# rearrange for jax
539+
z = rearrange(z, "b c h w -> b h w c")
540+
532541
z = z / self.scale_factor + self.shift_factor
533-
return self.decoder(z)
542+
z = self.decoder(z)
543+
544+
# rearrange for jax
545+
z = rearrange(z, "b h w c -> b c h w")
546+
return z
534547

535548
def __call__(self, x: Array) -> Array:
536-
# x -> (b, h, w, c)
549+
# x -> (b, c, h, w)
537550
return self.decode(self.encode(x))

jflux/modules/conditioner.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1+
# Note: This is a torch module not a Jax module
2+
from torch import nn
13
from chex import Array
2-
from flax import nnx
3-
from transformers import (
4-
CLIPTokenizer,
5-
FlaxCLIPTextModel,
6-
FlaxT5EncoderModel,
7-
T5Tokenizer,
8-
)
4+
import jax.numpy as jnp
5+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
96

107

11-
class HFEmbedder(nnx.Module):
12-
def __init__(self, version: str, max_length: int, **hf_kwargs) -> None:
8+
class HFEmbedder(nn.Module):
9+
def __init__(self, version: str, max_length: int, **hf_kwargs):
10+
super().__init__()
1311
self.is_clip = version.startswith("openai")
1412
self.max_length = max_length
1513
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
@@ -18,33 +16,36 @@ def __init__(self, version: str, max_length: int, **hf_kwargs) -> None:
1816
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
1917
version, max_length=max_length
2018
)
21-
self.hf_module: FlaxCLIPTextModel = FlaxCLIPTextModel.from_pretrained(
19+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
2220
version, **hf_kwargs
2321
)
2422
else:
2523
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
2624
version, max_length=max_length
2725
)
28-
self.hf_module: FlaxT5EncoderModel = FlaxT5EncoderModel.from_pretrained(
29-
version, from_pt=True, **hf_kwargs
26+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
27+
version, **hf_kwargs
3028
)
3129

32-
self.hf_module = self.hf_module.eval().requires_grad_(False) # noqa: ignore
30+
self.hf_module = self.hf_module.eval().requires_grad_(False)
3331

34-
def __call__(self, text: list[str]) -> Array:
32+
def forward(self, text: list[str]) -> Array:
3533
batch_encoding = self.tokenizer(
3634
text,
3735
truncation=True,
3836
max_length=self.max_length,
3937
return_length=False,
4038
return_overflowing_tokens=False,
4139
padding="max_length",
42-
return_tensors="np",
40+
return_tensors="pt",
4341
)
4442

4543
outputs = self.hf_module(
4644
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
4745
attention_mask=None,
4846
output_hidden_states=False,
4947
)
50-
return outputs[self.output_key]
48+
torch_outputs = outputs[self.output_key]
49+
50+
jax_outputs = jnp.array(torch_outputs.cpu().float(), dtype=jnp.bfloat16)
51+
return jax_outputs

0 commit comments

Comments
 (0)