Skip to content

Commit e43da22

Browse files
Adding layers and math (#7)
Co-authored-by: Saurav Maheshkar <[email protected]>
1 parent 55abdee commit e43da22

14 files changed

+864
-396
lines changed

jflux/cli.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,7 @@
1010
from jax.typing import DTypeLike
1111

1212
from jflux.sampling import denoise, get_noise, get_schedule, prepare, unpack
13-
from jflux.util import (
14-
configs,
15-
load_ae,
16-
load_clip,
17-
load_flow_model,
18-
load_t5,
19-
)
13+
from jflux.util import configs, load_ae, load_clip, load_flow_model, load_t5
2014

2115

2216
@dataclass

jflux/math.py

+3-39
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,21 @@
1-
import typing
2-
31
import jax
42
from chex import Array
53
from einops import rearrange
4+
from flax import nnx
65
from jax import numpy as jnp
76

87

9-
@typing.no_type_check
108
def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
11-
# TODO (ariG23498): Change all usage of attention to use this function
129
q, k = apply_rope(q, k, pe)
1310

14-
# jax expects this shape
15-
x = rearrange(x, "B H L D -> B L H D") # noqa
16-
x = jax.nn.dot_product_attention(q, k, v)
17-
x = rearrange(x, "B L H D -> B L (H D)") # reshape again
11+
x = nnx.dot_product_attention(q, k, v)
12+
x = rearrange(x, "B H L D -> B L (H D)")
1813

1914
return x
2015

2116

2217
def rope(pos: Array, dim: int, theta: int) -> Array:
23-
"""
24-
Generate Rotary Position Embedding (RoPE) for positional encoding.
25-
26-
Args:
27-
pos (Array): Positional values, typically a sequence of positions in an array format.
28-
dim (int): The embedding dimension, which must be an even number.
29-
theta (int): A scaling parameter for RoPE that controls the frequency range of rotations.
30-
31-
Returns:
32-
Array: Rotary embeddings with cosine and sine components for each position and dimension.
33-
"""
34-
35-
# Embedding dimension must be an even number
3618
assert dim % 2 == 0
37-
38-
# Generate the RoPE embeddings
3919
scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
4020
omega = 1.0 / (theta**scale)
4121
out = jnp.einsum("...n,d->...nd", pos, omega)
@@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array:
4525

4626

4727
def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
48-
"""
49-
Apply RoPE to the input query and key tensors.
50-
51-
Args:
52-
xq (Array): Query tensor.
53-
xk (Array): Key tensor.
54-
freqs_cis (Array): RoPE frequencies.
55-
56-
Returns:
57-
tuple[Array, Array]: Query and key tensors with RoPE applied.
58-
"""
59-
# Reshape and typecast the input tensors
6028
xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
6129
xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
62-
63-
# Apply RoPE to the input tensors
6430
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
6531
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
66-
67-
# Reshape and typecast the output tensors
6832
return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(
6933
xk.dtype
7034
)

jflux/model.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,8 @@
66
from jax import numpy as jnp
77
from jax.typing import DTypeLike
88

9-
from jflux.modules.layers import (
10-
AdaLayerNorm,
11-
Embed,
12-
Identity,
13-
timestep_embedding,
14-
)
159
from jflux.modules import DoubleStreamBlock, MLPEmbedder, SingleStreamBlock
10+
from jflux.modules.layers import AdaLayerNorm, Embed, Identity, timestep_embedding
1611

1712

1813
@dataclass

0 commit comments

Comments
 (0)