Skip to content

Commit 58e7e93

Browse files
committed
autoencoder works
1 parent bf2d84c commit 58e7e93

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

jflux/modules/autoencoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def attention(self, h_: Array) -> Array:
8585
v = rearrange(v, "b h w c-> b (h w) 1 c")
8686

8787
# Calculate Attention
88-
h_ = nnx.dot_product_attention(q, k, v)
88+
h_ = jax.nn.dot_product_attention(q, k, v)
8989

9090
return rearrange(h_, "b (h w) 1 c -> b h w c", h=h, w=w, c=c, b=b)
9191

jflux/util.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch # need for t5 and clip
55
from flax import nnx
66
from huggingface_hub import hf_hub_download
7+
import jax
78
from jax import numpy as jnp
89
from safetensors import safe_open
910

@@ -12,6 +13,11 @@
1213
from jflux.modules.conditioner import HFEmbedder
1314
from jflux.port import port_autoencoder, port_flux
1415

16+
def torch2jax(torch_tensor):
17+
intermediate_tensor = torch_tensor.to(torch.float32)
18+
jax_tensor = jnp.array(intermediate_tensor, dtype=jnp.bfloat16)
19+
return jax_tensor
20+
1521

1622
@dataclass
1723
class ModelSpec:
@@ -127,9 +133,10 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:
127133

128134
if ckpt_path is not None:
129135
tensors = {}
130-
with safe_open(ckpt_path, framework="flax") as f:
136+
with safe_open(ckpt_path, framework="pt") as f:
131137
for k in f.keys():
132-
tensors[k] = f.get_tensor(k)
138+
with jax.default_device(jax.devices("cpu")[0]):
139+
tensors[k] = torch2jax(f.get_tensor(k))
133140

134141
model = port_flux(flux=model, tensors=tensors)
135142
del tensors
@@ -166,9 +173,10 @@ def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
166173

167174
if ckpt_path is not None:
168175
tensors = {}
169-
with safe_open(ckpt_path, framework="flax") as f:
176+
with safe_open(ckpt_path, framework="pt") as f:
170177
for k in f.keys():
171-
tensors[k] = f.get_tensor(k)
178+
with jax.default_device(jax.devices("cpu")[0]):
179+
tensors[k] = torch2jax(f.get_tensor(k))
172180

173181
ae = port_autoencoder(autoencoder=ae, tensors=tensors)
174182
del tensors

0 commit comments

Comments
 (0)