@@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
48
48
49
49
def apply_rotary_pos_emb_torch (q , k , cos , sin , offset : int = 0 ): # jitting fails with bf16
50
50
cos , sin = cos [offset :q .shape [0 ] + offset , ...], sin [offset :q .shape [0 ] + offset , ...]
51
- return (q * cos ) + (rotate_half (q ) * sin ), (k * cos ) + (rotate_half (k ) * sin )
51
+ return (q * cos ) + (rotate_half (q ) * sin ), (k * cos ) + (rotate_half (k ) * sin )
52
+
53
+
54
+ # Original implementation adjusted from https://github.com/sunyt32/torchscale
55
+
56
+ def fixed_pos_embedding (x , base ):
57
+ seq_len , dim = x .shape
58
+ inv_freq = 1.0 / (base ** (torch .arange (0 , dim ) / dim ))
59
+ sinusoid_inp = (
60
+ torch .einsum ("i , j -> i j" , torch .arange (0 , seq_len , dtype = torch .float ), inv_freq ).to (x )
61
+ )
62
+ return torch .cos (sinusoid_inp ), torch .sin (sinusoid_inp )
63
+
64
+
65
+ class XPosEmbedding (torch .nn .Module ):
66
+ """
67
+ xPos positional embeddings from https://arxiv.org/abs/2212.10554.
68
+ """
69
+
70
+ def __init__ (self , head_dim , freq_base = 10000 , scale_base = 512 , gamma = 0.4 , precision = torch .half ):
71
+ super ().__init__ ()
72
+ self .scale_base = scale_base
73
+ self .register_buffer (
74
+ "scale" ,
75
+ (
76
+ (torch .arange (0 , head_dim , 2 ) + gamma * head_dim )
77
+ / ((1.0 + gamma ) * head_dim )
78
+ ),
79
+ )
80
+ self .max_seq_len_cached = None
81
+ self .precision = precision
82
+ self .freq_base = freq_base
83
+
84
+ def forward (self , x , seq_dim = 1 , seq_len = None ):
85
+ if seq_len is None :
86
+ seq_len = x .shape [seq_dim ]
87
+ if (
88
+ self .max_seq_len_cached is None
89
+ or (seq_len > self .max_seq_len_cached )
90
+ ):
91
+ self .max_seq_len_cached = seq_len
92
+ scale = (
93
+ self .scale
94
+ ** (
95
+ torch .arange (0 , seq_len , 1 ) - seq_len // 2
96
+ ).to (self .scale ).div (self .scale_base )[:, None ]
97
+ )
98
+ cos , sin = fixed_pos_embedding (scale , self .freq_base )
99
+ self .cos_cached = cos
100
+ self .sin_cached = sin
101
+ self .scale_cached = scale
102
+ if self .precision == torch .bfloat16 :
103
+ self .cos_cached = self .cos_cached .bfloat16 ()
104
+ self .sin_cached = self .sin_cached .bfloat16 ()
105
+ return (
106
+ self .cos_cached [:seq_len ],
107
+ self .sin_cached [:seq_len ],
108
+ self .scale_cached [:seq_len ],
109
+ )
110
+
111
+
112
+ def rotate_every_two (x ):
113
+ x1 = x [:, :, ::2 ]
114
+ x2 = x [:, :, 1 ::2 ]
115
+ x = torch .stack ((- x2 , x1 ), dim = - 1 )
116
+ return x .flatten (- 2 ) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
117
+
118
+
119
+ def duplicate_interleave (m ):
120
+ """
121
+ A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
122
+ """
123
+ dim0 = m .shape [0 ]
124
+ m = m .view (- 1 , 1 ) # flatten the matrix
125
+ m = m .repeat (1 , 2 ) # repeat all elements into the 2nd dimension
126
+ m = m .view (dim0 , - 1 ) # reshape into a matrix, interleaving the copy
127
+ return m .unsqueeze (1 )
128
+
129
+
130
+ def _apply_xpos_emb (x , cos , sin , scale ):
131
+ # x is assumed to be (seq_len, batch_size, dim) here.
132
+ cos = duplicate_interleave (cos * scale )
133
+ sin = duplicate_interleave (sin * scale )
134
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
135
+ return (x * cos ) + (rotate_every_two (x ) * sin )
136
+
137
+
138
+ @torch .jit .script
139
+ def apply_xpos_emb (q , k , cos , sin , scale , offset : int = 0 ):
140
+ # q/k are assumed to be (seq_len, batch_size, dim) here.
141
+ cos = cos [offset :q .shape [0 ] + offset ]
142
+ sin = sin [offset :q .shape [0 ] + offset ]
143
+ scale = scale [offset :q .shape [0 ] + offset ]
144
+ return (
145
+ _apply_xpos_emb (q , cos , sin , scale ),
146
+ _apply_xpos_emb (q , cos , sin , 1.0 / scale ),
147
+ )
148
+
149
+
150
+ def apply_xpos_emb_torch (q , k , cos , sin , scale , offset : int = 0 ):
151
+ # q/k are assumed to be (seq_len, batch_size, dim) here.
152
+ cos = cos [offset :q .shape [0 ] + offset ]
153
+ sin = sin [offset :q .shape [0 ] + offset ]
154
+ scale = scale [offset :q .shape [0 ] + offset ]
155
+ return (
156
+ _apply_xpos_emb (q , cos , sin , scale ),
157
+ _apply_xpos_emb (q , cos , sin , 1.0 / scale ),
158
+ )
0 commit comments