Skip to content

Commit bf2d84c

Browse files
feat: chex-ify testsuite (#12)
* drop `.pre-commit-config.yaml`. * fix `README` instructions for running the cli. * add `.vscode/settings.json` with a simple vscode config. * delete redundant `jflux/__main__.py`. * enforce ruff style rules `E`, `F` and `W`. * `chex`-ify testsuite.
1 parent 0c9fb04 commit bf2d84c

17 files changed

+284
-240
lines changed

.github/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ $ uv sync
1212
## Running
1313

1414
```shell
15-
$ uv jflux
15+
$ uv run jflux
1616
```
1717

1818
## References

.pre-commit-config.yaml

-20
This file was deleted.

.vscode/settings.json

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"python.testing.pytestArgs": [
3+
"tests"
4+
],
5+
"python.testing.unittestEnabled": false,
6+
"python.testing.pytestEnabled": true,
7+
"[python]": {
8+
"editor.formatOnSave": true,
9+
"editor.codeActionsOnSave": {
10+
"source.fixAll": "explicit",
11+
"source.organizeImports": "explicit"
12+
},
13+
"editor.defaultFormatter": "charliermarsh.ruff",
14+
},
15+
}

jflux/__main__.py

-4
This file was deleted.

jflux/cli.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77
import jax
88
import jax.numpy as jnp
9+
import numpy as np
910
from einops import rearrange
1011
from fire import Fire
11-
from flax import nnx
12-
from jax.typing import DTypeLike
1312
from PIL import Image
1413

1514
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
@@ -124,7 +123,8 @@ def main(
124123
by the index of the sample
125124
prompt: Prompt used for sampling
126125
device: Pytorch device
127-
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
126+
num_steps: number of sampling steps
127+
(default 4 for schnell, 50 for guidance distilled)
128128
loop: start an interactive session and sample multiple times
129129
guidance: guidance value used for guidance distillation
130130
add_sampling_metadata: Add the prompt to the image Exif metadata
@@ -216,7 +216,12 @@ def main(
216216
x = x.clip(-1, 1)
217217
x = rearrange(x[0], "c h w -> h w c")
218218

219-
img = Image.fromarray((127.5 * (x + 1.0)))
219+
x = 127.5 * (x + 1.0)
220+
x_numpy = np.array(x.astype(jnp.uint8))
221+
img = Image.fromarray(x_numpy)
222+
223+
img.save(fn, quality=95, subsampling=0)
224+
idx += 1
220225

221226
if loop:
222227
print("-" * 80)

jflux/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, params: FluxParams):
4949
self.out_channels = self.in_channels
5050
if params.hidden_size % params.num_heads != 0:
5151
raise ValueError(
52-
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
52+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" # noqa: E501
5353
)
5454
pe_dim = params.hidden_size // params.num_heads
5555
if sum(params.axes_dim) != pe_dim:

jflux/port.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from einops import rearrange
22

3+
34
##############################################################################################
45
# AUTOENCODER MODEL PORTING
56
##############################################################################################
@@ -481,3 +482,5 @@ def port_flux(flux, tensors):
481482
tensors=tensors,
482483
prefix="final_layer",
483484
)
485+
486+
return flux

jflux/util.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import os
22
from dataclasses import dataclass
33

4-
import jax
54
import torch # need for t5 and clip
65
from flax import nnx
76
from huggingface_hub import hf_hub_download
87
from jax import numpy as jnp
9-
from jax.typing import DTypeLike
108
from safetensors import safe_open
119

1210
from jflux.model import Flux, FluxParams

pyproject.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ dependencies = [
99
"einops>=0.8.0",
1010
"fire>=0.6.0",
1111
"flax>=0.9.0",
12-
"jflux",
1312
# FIXME: Allow for local installation without GPUs as well `jax[cuda12]`
1413
"jax>=0.4.31",
1514
"mypy>=1.11.2",
@@ -22,6 +21,7 @@ dependencies = [
2221
jflux = "jflux.cli:app"
2322

2423
[tool.uv]
24+
package = true
2525
dev-dependencies = [
2626
"flux",
2727
"pytest>=8.3.3",
@@ -32,7 +32,10 @@ jflux = { workspace = true }
3232
flux = { git = "https://github.com/black-forest-labs/flux.git" }
3333

3434
[tool.ruff.lint]
35-
select = ["I001"]
35+
select = ["E", "F", "I001", "W"]
36+
37+
[tool.ruff.lint.isort]
38+
lines-after-imports = 2
3639

3740
[tool.ruff.lint.pydocstyle]
3841
convention = "google"

tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)