@@ -1,41 +1,21 @@
-import typing
-
 import jax
 from chex import Array
 from einops import rearrange
+from flax import nnx
 from jax import numpy as jnp
 
 
-@typing.no_type_check
 def attention(q: Array, k: Array, v: Array, pe: Array) -> Array:
-    # TODO (ariG23498): Change all usage of attention to use this function
     q, k = apply_rope(q, k, pe)
 
-    # jax expects this shape
-    x = rearrange(x, "B H L D -> B L H D")  # noqa
-    x = jax.nn.dot_product_attention(q, k, v)
-    x = rearrange(x, "B L H D -> B L (H D)")  # reshape again
+    x = nnx.dot_product_attention(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
 
     return x
 
 
 def rope(pos: Array, dim: int, theta: int) -> Array:
-    """
-    Generate Rotary Position Embedding (RoPE) for positional encoding.
-
-    Args:
-        pos (Array): Positional values, typically a sequence of positions in an array format.
-        dim (int): The embedding dimension, which must be an even number.
-        theta (int): A scaling parameter for RoPE that controls the frequency range of rotations.
-
-    Returns:
-        Array: Rotary embeddings with cosine and sine components for each position and dimension.
-    """
-
-    # Embedding dimension must be an even number
     assert dim % 2 == 0
-
-    # Generate the RoPE embeddings
     scale = jnp.arange(0, dim, 2, dtype=jnp.float64, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
     out = jnp.einsum("...n,d->...nd", pos, omega)
@@ -45,26 +25,10 @@ def rope(pos: Array, dim: int, theta: int) -> Array:
 
 
 def apply_rope(xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    """
-    Apply RoPE to the input query and key tensors.
-
-    Args:
-        xq (Array): Query tensor.
-        xk (Array): Key tensor.
-        freqs_cis (Array): RoPE frequencies.
-
-    Returns:
-        tuple[Array, Array]: Query and key tensors with RoPE applied.
-    """
-    # Reshape and typecast the input tensors
     xq_ = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 1, 2)
-
-    # Apply RoPE to the input tensors
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-
-    # Reshape and typecast the output tensors
     return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(
         xk.dtype
     )
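
For reference, a minimal usage sketch of the rewritten attention path (not part of the commit). It assumes the elided tail of rope() behind the second hunk returns the usual flux-style (B, L, D/2, 2, 2) table of 2x2 rotation blocks that apply_rope() consumes; the B/H/L/D sizes and theta value below are made up for illustration:

import jax
from jax import numpy as jnp

B, H, L, D = 1, 4, 8, 32  # toy batch, heads, sequence length, head dim

kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (B, H, L, D))
k = jax.random.normal(kk, (B, H, L, D))
v = jax.random.normal(kv, (B, H, L, D))

pos = jnp.arange(L, dtype=jnp.float32)[None, :]  # (B, L) position ids
pe = rope(pos, D, theta=10_000)  # assumed to return (B, L, D // 2, 2, 2)
pe = pe[:, None]  # add a head axis so it broadcasts in apply_rope

out = attention(q, k, v, pe)
assert out.shape == (B, L, H * D)  # flattened heads after the final rearrange

One caveat worth double-checking against the callers: flax documents dot_product_attention as taking inputs of shape [batch..., length, heads, head_dim], while the rearrange pattern in this diff implies q, k, and v arrive as (B, H, L, D).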