@@ -81,6 +81,15 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

+class MultiHeadedRMSNorm(nn.Module):
+    def __init__(self, dim, heads = 1):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * self.gamma
+
 # positional embeds

 class LearnedSinusoidalPosEmb(nn.Module):
@@ -104,6 +113,7 @@ def __init__(
         heads = 4,
         dim_head = 32,
         norm = False,
+        qk_norm = False,
         time_cond_dim = None
     ):
         super().__init__()
@@ -127,6 +137,11 @@ def __init__(

         self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)

+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.to_out = nn.Sequential(
             nn.Linear(hidden_dim, dim, bias = False),
             LayerNorm(dim)
@@ -148,6 +163,10 @@ def forward(
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)

@@ -169,7 +188,8 @@ def __init__(
         norm = False,
         norm_context = False,
         time_cond_dim = None,
-        flash = False
+        flash = False,
+        qk_norm = False
     ):
         super().__init__()
         hidden_dim = dim_head * heads
@@ -197,6 +217,11 @@ def __init__(
         self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
         self.to_out = nn.Linear(hidden_dim, dim, bias = False)

+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.attend = Attend(flash = flash)

     def forward(
@@ -222,6 +247,10 @@ def forward(
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         out = self.attend(q, k, v)

         out = rearrange(out, 'b h n d -> b n (h d)')
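
For reference, the added `MultiHeadedRMSNorm` is an RMSNorm with a learned per-head gain: it l2-normalizes each head's feature vector, then rescales by `sqrt(dim)` and `gamma`. A minimal standalone sketch of how the new qk-norm is applied (not part of the diff; the example shapes are assumed, following the `b h n d` layout used above):

```python
import torch
import torch.nn.functional as F
from torch import nn

class MultiHeadedRMSNorm(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x):
        # unit-normalize along the feature dim, then restore scale with
        # sqrt(dim) and a learned per-head gain (RMSNorm without mean subtraction)
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# q, k as produced by the attention layers above: (batch, heads, seq, dim_head)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)

q_norm = MultiHeadedRMSNorm(32, heads = 4)
k_norm = MultiHeadedRMSNorm(32, heads = 4)

q, k = q_norm(q), k_norm(k)
print(q.shape, k.shape)  # torch.Size([2, 4, 16, 32]) torch.Size([2, 4, 16, 32])
```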