
Commit 366ae3e: Add xPos embeddings

Parent: e52bdab

5 files changed: +136 -7 lines

finetune_t0_non_causal_decoder.py

Lines changed: 5 additions & 1 deletion
@@ -94,7 +94,11 @@ def get_batch_pipe(data):
         segment_ids=segment_ids.long(),
     )
 
-    if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
+    if args.position_embedding_type not in [
+        PositionEmbeddingType.alibi,
+        PositionEmbeddingType.rotary,
+        PositionEmbeddingType.xpos,
+    ]:
         raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")
 
     return (tokens, position_ids, attention_mask), (labels, loss_mask)

megatron/arguments.py

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ def _add_network_size_args(parser):
     group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
                        choices=list(PositionEmbeddingType),
                        default=PositionEmbeddingType.absolute,
-                       help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.'
+                       help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.'
                        )
     group.add_argument('--glu-activation', type=str,
                        choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(),
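
A brief usage note (not part of the diff): the argument's type converter is
lambda x: PositionEmbeddingType[x], so the new scheme is selected by passing the
member name, i.e. --position-embedding-type xpos, on the training command line.
A minimal sketch of the lookup the parser performs:

# Illustrative only: how the string "xpos" is resolved by the argparse type converter.
from megatron.enums import PositionEmbeddingType

value = PositionEmbeddingType["xpos"]        # what the type converter returns for "xpos"
assert value is PositionEmbeddingType.xpos   # the enum member added in megatron/enums.py below
print(list(PositionEmbeddingType))           # now includes PositionEmbeddingType.xpos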

megatron/enums.py

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ class PositionEmbeddingType(enum.Enum):
     rotary = 1
     absolute = 2
     alibi = 3
+    xpos = 4

megatron/model/positional_embeddings.py

Lines changed: 108 additions & 1 deletion
@@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
 
 def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):  # jitting fails with bf16
     cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
-    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+
+# Original implementation adjusted from https://github.com/sunyt32/torchscale
+
+def fixed_pos_embedding(x, base):
+    seq_len, dim = x.shape
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim))
+    sinusoid_inp = (
+        torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
+    )
+    return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp)
+
+
+class XPosEmbedding(torch.nn.Module):
+    """
+    xPos positional embeddings from https://arxiv.org/abs/2212.10554.
+    """
+
+    def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half):
+        super().__init__()
+        self.scale_base = scale_base
+        self.register_buffer(
+            "scale",
+            (
+                (torch.arange(0, head_dim, 2) + gamma * head_dim)
+                / ((1.0 + gamma) * head_dim)
+            ),
+        )
+        self.max_seq_len_cached = None
+        self.precision = precision
+        self.freq_base = freq_base
+
+    def forward(self, x, seq_dim=1, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        if (
+            self.max_seq_len_cached is None
+            or (seq_len > self.max_seq_len_cached)
+        ):
+            self.max_seq_len_cached = seq_len
+            scale = (
+                self.scale
+                ** (
+                    torch.arange(0, seq_len, 1) - seq_len // 2
+                ).to(self.scale).div(self.scale_base)[:, None]
+            )
+            cos, sin = fixed_pos_embedding(scale, self.freq_base)
+            self.cos_cached = cos
+            self.sin_cached = sin
+            self.scale_cached = scale
+            if self.precision == torch.bfloat16:
+                self.cos_cached = self.cos_cached.bfloat16()
+                self.sin_cached = self.sin_cached.bfloat16()
+        return (
+            self.cos_cached[:seq_len],
+            self.sin_cached[:seq_len],
+            self.scale_cached[:seq_len],
+        )
+
+
+def rotate_every_two(x):
+    x1 = x[:, :, ::2]
+    x2 = x[:, :, 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+
+
+def duplicate_interleave(m):
+    """
+    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
+    """
+    dim0 = m.shape[0]
+    m = m.view(-1, 1)  # flatten the matrix
+    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
+    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
+    return m.unsqueeze(1)
+
+
+def _apply_xpos_emb(x, cos, sin, scale):
+    # x is assumed to be (seq_len, batch_size, dim) here.
+    cos = duplicate_interleave(cos * scale)
+    sin = duplicate_interleave(sin * scale)
+    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
+    return (x * cos) + (rotate_every_two(x) * sin)
+
+
+@torch.jit.script
+def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0):
+    # q/k are assumed to be (seq_len, batch_size, dim) here.
+    cos = cos[offset:q.shape[0] + offset]
+    sin = sin[offset:q.shape[0] + offset]
+    scale = scale[offset:q.shape[0] + offset]
+    return (
+        _apply_xpos_emb(q, cos, sin, scale),
+        _apply_xpos_emb(k, cos, sin, 1.0 / scale),
+    )
+
+
+def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0):
+    # q/k are assumed to be (seq_len, batch_size, dim) here.
+    cos = cos[offset:q.shape[0] + offset]
+    sin = sin[offset:q.shape[0] + offset]
+    scale = scale[offset:q.shape[0] + offset]
+    return (
+        _apply_xpos_emb(q, cos, sin, scale),
+        _apply_xpos_emb(k, cos, sin, 1.0 / scale),
+    )
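
For reference, a minimal usage sketch (not part of the commit) of how the new helpers fit together, assuming q and k are shaped (seq_len, batch_size, head_dim) as the comments above state; it only checks that shapes are preserved:

import torch

from megatron.model.positional_embeddings import XPosEmbedding, apply_xpos_emb_torch

seq_len, batch_size, head_dim = 16, 2, 64
q = torch.randn(seq_len, batch_size, head_dim)
k = torch.randn(seq_len, batch_size, head_dim)

xpos = XPosEmbedding(head_dim, precision=torch.float)
cos, sin, scale = xpos(q, seq_len=seq_len)   # each of shape (seq_len, head_dim // 2)

# queries are scaled by scale, keys by 1 / scale (see apply_xpos_emb_torch above)
q_xpos, k_xpos = apply_xpos_emb_torch(q, k, cos, sin, scale)
assert q_xpos.shape == q.shape and k_xpos.shape == k.shape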

megatron/model/transformer.py

Lines changed: 21 additions & 4 deletions
@@ -31,7 +31,14 @@
 import deepspeed
 
 from .glu_activations import GLU_ACTIVATIONS
-from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb
+from .positional_embeddings import (
+    apply_rotary_pos_emb,
+    apply_rotary_pos_emb_torch,
+    apply_xpos_emb,
+    apply_xpos_emb_torch,
+    RotaryEmbedding,
+    XPosEmbedding,
+)
 
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -204,6 +211,8 @@ def __init__(self, init_method,
 
         if self.position_embedding_type == PositionEmbeddingType.rotary:
             self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)
+        elif self.position_embedding_type == PositionEmbeddingType.xpos:
+            self.xpos_emb = XPosEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)
 
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False, encoder_output=None, alibi=None):
@@ -291,16 +300,24 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
             matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]
 
         # Rotary embeddings
-        if self.position_embedding_type == PositionEmbeddingType.rotary:
-            apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
-
+        if self.position_embedding_type in [
+                PositionEmbeddingType.rotary, PositionEmbeddingType.xpos]:
             seq_len = key_layer.shape[0]
             offset = 0
             if layer_past is not None and layer_past.numel() > 0:
                 offset = layer_past[0].shape[0]
                 seq_len += offset
+
+            if self.position_embedding_type == PositionEmbeddingType.rotary:
+                apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
                 cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
                 query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
+            elif self.position_embedding_type == PositionEmbeddingType.xpos:
+                apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb
+                cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len)
+                query_layer, key_layer = apply_xpos_fn(
+                    query_layer, key_layer, cos, sin, scale, offset=offset)
+
 
         # Raw attention scores. [b * np, sq, sk]
