|
4 | 4 | import torch # need for t5 and clip
|
5 | 5 | from flax import nnx
|
6 | 6 | from huggingface_hub import hf_hub_download
|
| 7 | +import jax |
7 | 8 | from jax import numpy as jnp
|
8 | 9 | from safetensors import safe_open
|
9 | 10 |
|
|
12 | 13 | from jflux.modules.conditioner import HFEmbedder
|
13 | 14 | from jflux.port import port_autoencoder, port_flux
|
14 | 15 |
|
def torch2jax(torch_tensor: torch.Tensor) -> jnp.ndarray:
    """Convert a PyTorch tensor into a bfloat16 JAX array.

    Args:
        torch_tensor: Source tensor. Any dtype that can be cast to float32
            is accepted; the tensor may require grad or live on a non-CPU
            device.

    Returns:
        A JAX array with the same shape as ``torch_tensor`` and dtype
        ``jnp.bfloat16``, placed on JAX's current default device.
    """
    # Detach and move to host explicitly: implicit NumPy conversion raises
    # for tensors that require grad or that live on an accelerator.
    # The float32 upcast is required because torch cannot export bfloat16
    # tensors to NumPy directly.
    host_array = (
        torch_tensor.detach().to(device="cpu", dtype=torch.float32).numpy()
    )
    return jnp.array(host_array, dtype=jnp.bfloat16)
| 20 | + |
15 | 21 |
|
16 | 22 | @dataclass
|
17 | 23 | class ModelSpec:
|
@@ -127,9 +133,10 @@ def load_flow_model(name: str, hf_download: bool = True) -> Flux:
|
127 | 133 |
|
128 | 134 | if ckpt_path is not None:
|
129 | 135 | tensors = {}
|
130 |
| - with safe_open(ckpt_path, framework="flax") as f: |
| 136 | + with safe_open(ckpt_path, framework="pt") as f: |
131 | 137 | for k in f.keys():
|
132 |
| - tensors[k] = f.get_tensor(k) |
| 138 | + with jax.default_device(jax.devices("cpu")[0]): |
| 139 | + tensors[k] = torch2jax(f.get_tensor(k)) |
133 | 140 |
|
134 | 141 | model = port_flux(flux=model, tensors=tensors)
|
135 | 142 | del tensors
|
@@ -166,9 +173,10 @@ def load_ae(name: str, hf_download: bool = True) -> AutoEncoder:
|
166 | 173 |
|
167 | 174 | if ckpt_path is not None:
|
168 | 175 | tensors = {}
|
169 |
| - with safe_open(ckpt_path, framework="flax") as f: |
| 176 | + with safe_open(ckpt_path, framework="pt") as f: |
170 | 177 | for k in f.keys():
|
171 |
| - tensors[k] = f.get_tensor(k) |
| 178 | + with jax.default_device(jax.devices("cpu")[0]): |
| 179 | + tensors[k] = torch2jax(f.get_tensor(k)) |
172 | 180 |
|
173 | 181 | ae = port_autoencoder(autoencoder=ae, tensors=tensors)
|
174 | 182 | del tensors
|
|
0 commit comments