Commit 0c9fb04

ariG23498 and SauravMaheshkar authored on Oct 9, 2024

Adding FLUX porting code (#11)

Co-authored-by: Saurav Maheshkar <[email protected]>

1 parent 9162e4d · commit 0c9fb04
File tree: 8 files changed, +400 −146 lines changed
 

.github/CODEOWNERS
+1 −1

@@ -1 +1 @@
-* @SauravMaheshkar
+* @SauravMaheshkar @ariG23498
jflux/cli.py
+2 −3

@@ -6,13 +6,12 @@

 import jax
 import jax.numpy as jnp
-from flax import nnx
+from einops import rearrange
 from fire import Fire
+from flax import nnx
 from jax.typing import DTypeLike
-
 from PIL import Image

-from einops import rearrange
 from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5

jflux/modules/conditioner.py
+2 −2

@@ -1,7 +1,7 @@
 # Note: This is a torch module not a Jax module
-from torch import nn
-from chex import Array
 import jax.numpy as jnp
+from chex import Array
+from torch import nn
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
jflux/modules/layers.py
+6 −0

@@ -214,6 +214,7 @@ def __init__(
         self.img_norm1 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -229,6 +230,7 @@ def __init__(
         self.img_norm2 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -257,6 +259,7 @@ def __init__(
         self.txt_norm1 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -272,6 +275,7 @@ def __init__(
         self.txt_norm2 = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -382,6 +386,7 @@ def __init__(
         self.pre_norm = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
@@ -419,6 +424,7 @@ def __init__(
         self.norm_final = nnx.LayerNorm(
             num_features=hidden_size,
             use_scale=False,
+            use_bias=False,
             epsilon=1e-6,
             rngs=rngs,
             param_dtype=param_dtype,
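All six hunks make the same change: these nnx.LayerNorm layers already disabled the learned scale, and the commit now also disables the learned bias, so the Flax modules carry no affine parameters, matching the non-affine LayerNorm used in the reference PyTorch FLUX blocks. A minimal sketch of the resulting configuration (the hidden size and input shape below are illustrative, not taken from the diff):

import jax.numpy as jnp
from flax import nnx

# LayerNorm with no learnable parameters: no scale (gamma) and, after this
# commit, no bias (beta) either.
norm = nnx.LayerNorm(
    num_features=3072,  # illustrative hidden size
    use_scale=False,
    use_bias=False,
    epsilon=1e-6,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((1, 16, 3072))
y = norm(x)  # purely normalised activations, no affine transform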

jflux/port.py
+369 −118

Large diffs are not rendered by default.
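The bulk of the commit lives in jflux/port.py, which is collapsed in this view. From the call sites in jflux/util.py below, it exposes at least port_autoencoder(autoencoder, tensors) and port_flux(flux, tensors), which copy a flat dict of checkpoint tensors into the corresponding nnx modules. The actual implementation is not shown here; the following is only a sketch of what porting a single linear layer from a PyTorch-style checkpoint into an nnx.Linear could look like (the helper name and key layout are hypothetical):

import jax.numpy as jnp
from flax import nnx

def port_linear(linear: nnx.Linear, tensors: dict, prefix: str) -> nnx.Linear:
    # Hypothetical helper, not the actual code in jflux/port.py.
    # PyTorch stores Linear weights as (out_features, in_features); flax/nnx
    # kernels are (in_features, out_features), hence the transpose.
    linear.kernel.value = jnp.asarray(tensors[f"{prefix}.weight"]).T
    if f"{prefix}.bias" in tensors and linear.bias is not None:
        linear.bias.value = jnp.asarray(tensors[f"{prefix}.bias"])
    return linear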

jflux/util.py
+10 −10

@@ -1,8 +1,8 @@
 import os
 from dataclasses import dataclass

-import torch # need for t5 and clip
 import jax
+import torch # need for t5 and clip
 from flax import nnx
 from huggingface_hub import hf_hub_download
 from jax import numpy as jnp
@@ -12,8 +12,7 @@
 from jflux.model import Flux, FluxParams
 from jflux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 from jflux.modules.conditioner import HFEmbedder
-
-from port import port_autoencoder
+from jflux.port import port_autoencoder, port_flux


 @dataclass
@@ -128,13 +127,14 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:

     model = Flux(params=configs[name].params)

-    # TODO (ariG23498): Port the flux model
     if ckpt_path is not None:
-        print("Loading checkpoint")
-        # load_sft doesn't support torch.device
-        sd = load_sft(ckpt_path)
-        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
-        print_load_warning(missing, unexpected)
+        tensors = {}
+        with safe_open(ckpt_path, framework="flax") as f:
+            for k in f.keys():
+                tensors[k] = f.get_tensor(k)
+
+        model = port_flux(flux=model, tensors=tensors)
+        del tensors
     return model


@@ -166,12 +166,12 @@ def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
     print("Init AE")
     ae = AutoEncoder(params=configs[name].ae_params)

-    # TODO (ariG23498): Port the flux model
     if ckpt_path is not None:
         tensors = {}
         with safe_open(ckpt_path, framework="flax") as f:
             for k in f.keys():
                 tensors[k] = f.get_tensor(k)

         ae = port_autoencoder(autoencoder=ae, tensors=tensors)
+        del tensors
     return ae
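With this change both loaders follow the same pattern: open the safetensors checkpoint with safe_open(..., framework="flax"), copy every tensor into a dict, hand the dict to the matching port_* function, and delete the dict afterwards to free host memory. A usage sketch (the "flux-schnell" config key is an assumption based on the upstream FLUX configs and is not visible in this diff):

from jflux.util import load_ae, load_flow_model

# hf_download=True (the default) fetches the safetensors checkpoints from the
# Hugging Face Hub before the weights are ported into the Flax/nnx modules.
model = load_flow_model("flux-schnell")  # config key assumed, see `configs`
ae = load_ae("flux-schnell")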

tests/modules/test_layers.py
+3 −4

@@ -1,27 +1,26 @@
+import jax
 import jax.numpy as jnp
 import numpy as np
-import jax
 import torch
 from einops import rearrange, repeat
 from flax import nnx
 from flux.modules.layers import DoubleStreamBlock as TorchDoubleStreamBlock
+from flux.modules.layers import EmbedND as TorchEmbedND
 from flux.modules.layers import MLPEmbedder as TorchMLPEmbedder
 from flux.modules.layers import Modulation as TorchModulation
 from flux.modules.layers import QKNorm as TorchQKNorm
 from flux.modules.layers import RMSNorm as TorchRMSNorm
 from flux.modules.layers import SelfAttention as TorchSelfAttention
 from flux.modules.layers import timestep_embedding as torch_timesetp_embedding
-from flux.modules.layers import EmbedND as TorchEmbedND

 from jflux.modules.layers import DoubleStreamBlock as JaxDoubleStreamBlock
+from jflux.modules.layers import EmbedND as JaxEmbedND
 from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder
 from jflux.modules.layers import Modulation as JaxModulation
 from jflux.modules.layers import QKNorm as JaxQKNorm
 from jflux.modules.layers import RMSNorm as JaxRMSNorm
 from jflux.modules.layers import SelfAttention as JaxSelfAttention
 from jflux.modules.layers import timestep_embedding as jax_timestep_embedding
-from jflux.modules.layers import EmbedND as JaxEmbedND
-
 from tests.utils import torch2jax

tests/test_sampling.py
+7 −8

@@ -1,11 +1,11 @@
-import numpy as np
+import chex
 import jax
+import numpy as np
 import torch
-import chex
-from jflux.sampling import get_noise as jax_get_noise
-
 from flux.sampling import get_noise as torch_get_noise

+from jflux.sampling import get_noise as jax_get_noise
+

 class SamplingTestCase(chex.TestCase):
     def test_get_noise(self):
@@ -22,16 +22,15 @@ def test_get_noise(self):
             num_samples=1,
             height=height,
             width=width,
-            dtype=jax.dtypes.bfloat16,
+            dtype=jax.numpy.float32,
             seed=jax.random.PRNGKey(seed=42),
         )
         x_torch = torch_get_noise(
             num_samples=1,
             height=height,
             width=width,
-            dtype=torch.bfloat16,
+            dtype=torch.float32,
             seed=42,
-            device="cuda",
+            device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        print(x_jax.shape)
         chex.assert_equal_shape([x_jax, x_torch])
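Two tweaks make this test portable: both noise tensors are generated in float32 rather than bfloat16 (likely so the JAX and PyTorch outputs stay easy to compare and convert), and the torch call falls back to the CPU when CUDA is unavailable so the test can run on CPU-only machines. The final assertion only compares shapes, so the two framework tensors never need to share a device or dtype; a standalone sketch of that check:

import chex
import jax
import torch

# chex's shape assertion only inspects `.shape`, so a jax.Array and a
# torch.Tensor can be compared directly without converting either one.
x_jax = jax.random.normal(jax.random.PRNGKey(42), (1, 16, 64))
x_torch = torch.randn(1, 16, 64)
chex.assert_equal_shape([x_jax, x_torch])  # passes: both are (1, 16, 64)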
