@@ -27,29 +27,27 @@ def _reshape_activation_tensor(
 
     @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)
 
     @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
     @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
 
     @staticmethod
     def paged_attention_v1(
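
In this hunk the activation wrappers stop pre-splitting the input with _reshape_activation_tensor and call the fused IPEX kernels directly, while gelu_fast and gelu_new now return a tensor instead of copying into out. A minimal usage sketch of the new entry points, assuming an XPU-enabled intel_extension_for_pytorch build and that this class is importable as vllm._ipex_ops.ipex_ops (import path assumed):

import torch
from vllm._ipex_ops import ipex_ops  # import path assumed

num_tokens, d = 16, 4096
# Gated layout from the removed helper: x packs two halves along the last
# dim, so x.shape[-1] == 2 * out.shape[-1].
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="xpu")
out = torch.empty(num_tokens, d, dtype=torch.float16, device="xpu")

ipex_ops.silu_and_mul(out, x)        # fused SiLU(x1) * x2, written into out
ipex_ops.gelu_tanh_and_mul(out, x)   # same layout, GELU gating

y = ipex_ops.gelu_fast(x)            # now returns a tensor instead of
z = ipex_ops.gelu_new(x)             # filling a caller-provided out

quick = torch.empty_like(x)
ipex_ops.gelu_quick(quick, x)        # newly wired to ipex.llm.functional.gelu_quick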
@@ -160,67 +158,26 @@ def rotary_embedding(
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
 
     @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-            cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
 
     @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
 
     @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
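
Here both rotary wrappers collapse onto ipex.llm.functional.rotary_embedding_batched, which rotates query and key in place, and rms_norm now returns the normalized tensor rather than copying into an out argument. A hedged call-site sketch (same import-path assumption as above; the cos/sin cache layout is inferred from the chunk(2, dim=-1) in the removed Python path):

import torch
from vllm._ipex_ops import ipex_ops  # import path assumed

num_tokens, num_heads, head_size = 8, 32, 128
rot_dim = head_size                    # the wrapper reads rot_dim from cos_sin_cache.size(1)

positions = torch.arange(num_tokens, device="xpu")
query = torch.randn(num_tokens, num_heads * head_size,
                    dtype=torch.float16, device="xpu")
key = torch.randn_like(query)
# [cos_sin_dim, rot_dim]: cos in the first half of the last dim, sin in the
# second half, matching the removed code's chunk(2, dim=-1).
cos_sin_cache = torch.randn(4096, rot_dim, dtype=torch.float16, device="xpu")

# Applies NeoX-style rotation to query/key in place via the batched IPEX kernel.
ipex_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)

# rms_norm now returns its result instead of filling an out tensor.
weight = torch.ones(num_heads * head_size, dtype=torch.float16, device="xpu")
normed = ipex_ops.rms_norm(query, weight, 1e-6)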
@@ -246,11 +203,14 @@ def varlen_attention(
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
 
     @staticmethod
     def reshape_and_cache(
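
The varlen_attention wrapper now normalizes its arguments before calling into IPEX: query, key, and value are made contiguous and the two sequence-length tensors are cast to int32. A sketch of a call with two packed variable-length sequences follows; treating seqlen_q/seqlen_k as cumulative sequence lengths, the [total_tokens, num_heads, head_size] packing, and passing None for gen_ when dropout is disabled are all assumptions:

import torch
from vllm._ipex_ops import ipex_ops  # import path assumed

num_heads, head_size = 8, 64
seq_lens = [5, 3]                      # two sequences packed back to back
total_tokens = sum(seq_lens)

query = torch.randn(total_tokens, num_heads, head_size,
                    dtype=torch.float16, device="xpu")
key = torch.randn_like(query)
value = torch.randn_like(query)
out = torch.empty_like(query)

# Assumed cumulative-length convention; the wrapper casts these to int32 itself.
cu_seqlens = torch.tensor([0, 5, 8], device="xpu")

ipex_ops.varlen_attention(query, key, value, out,
                          seqlen_q=cu_seqlens, seqlen_k=cu_seqlens,
                          max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
                          pdropout=0.0, softmax_scale=head_size ** -0.5,
                          zero_tensors=False, is_causal=True,
                          return_softmax=False, gen_=None)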