1- use super :: { args:: Meta , fill_pos, Args , Rope , Seq , SinCosTable } ;
1+ use super :: { args:: Meta , args :: RopeType as R , fill_pos, Args , Rope , Seq , SinCosTable } ;
22use crate :: {
33 common_cpu:: Cpu , get_static, strides_not_support, ByteOf , LaunchError , QueueAlloc , SchemeError ,
44 Unsigned ,
55} ;
66use digit_layout:: { types as ty, DigitLayout } ;
77use half:: f16;
8+ use std:: ptr:: null;
9+ #[ derive( Copy , Clone ) ]
10+ enum NtkPartsType {
11+ None ,
12+ Yarn ,
13+ }
814
15+ #[ derive( Copy , Clone ) ]
16+ enum SchemeType {
17+ Rope {
18+ s : f32 ,
19+ } ,
20+ Long {
21+ long : * const u8 ,
22+ short : * const u8 ,
23+ s : f32 ,
24+ origin_pos : u32 ,
25+ } ,
26+ NtkParts {
27+ alpha : f32 ,
28+ beta : f32 ,
29+ l0 : f32 ,
30+ s : f32 ,
31+ ntktype : NtkPartsType ,
32+ } ,
33+ }
934pub struct Operator ;
1035
1136impl Rope < Cpu > for Operator {
@@ -78,6 +103,7 @@ impl crate::Operator for Operator {
78103 p_layout,
79104 p_base,
80105 theta,
106+ rope_type,
81107 ..
82108 } = args;
83109 let & [ _, nh, dh] = t_layout. shape ( ) else {
@@ -99,6 +125,50 @@ impl crate::Operator for Operator {
99125 return Err ( strides_not_support ( "" ) . into ( ) ) ;
100126 }
101127
128+ let ( theta, scheme_type) = match rope_type {
129+ R :: Rope | R :: Dyn { .. } | R :: Ntk { .. } | R :: Pi { .. } => {
130+ let ( theta, s) = match rope_type {
131+ R :: Rope => ( * theta, 1. ) ,
132+ R :: Dyn { s, a } => ( theta * ( a * s - a + 1. ) , 1. ) ,
133+ R :: Ntk { s } => ( theta * s, 1. ) ,
134+ R :: Pi { s } => ( * theta, * s) ,
135+ _ => unreachable ! ( ) ,
136+ } ;
137+ ( theta, SchemeType :: Rope { s } )
138+ }
139+ R :: Long {
140+ long,
141+ short,
142+ max_pos,
143+ origin_pos,
144+ } => {
145+ let s = 1.0
146+ + ( ( * max_pos as f32 / * origin_pos as f32 ) . ln ( ) / ( * origin_pos as f32 ) . ln ( ) )
147+ . sqrt ( ) ;
148+ let scheme_type = SchemeType :: Long {
149+ long : long. cast ( ) ,
150+ short : short. cast ( ) ,
151+ s,
152+ origin_pos : * origin_pos,
153+ } ;
154+ ( * theta, scheme_type)
155+ }
156+ R :: Yarn { alpha, beta, l0, s } | R :: NtkParts { alpha, beta, l0, s } => {
157+ let ntktype = match rope_type {
158+ R :: NtkParts { .. } => NtkPartsType :: None ,
159+ R :: Yarn { .. } => NtkPartsType :: Yarn ,
160+ _ => unreachable ! ( ) ,
161+ } ;
162+ let scheme_type = SchemeType :: NtkParts {
163+ alpha : * alpha,
164+ beta : * beta,
165+ l0 : * l0,
166+ s : * s,
167+ ntktype,
168+ } ;
169+ ( * theta, scheme_type)
170+ }
171+ } ;
102172 macro_rules! calculate {
103173 ( $t: ty, $p: ty) => {
104174 Scheme :: <$t, $p> {
@@ -108,9 +178,10 @@ impl crate::Operator for Operator {
108178 st,
109179 sh,
110180 sp,
111- theta: * theta ,
181+ theta,
112182 t_base: t_base. cast( ) ,
113183 p_base: p_base. cast( ) ,
184+ scheme_type,
114185 }
115186 . calculate( )
116187 } ;
@@ -142,15 +213,15 @@ struct Scheme<A, P> {
142213 theta : f32 ,
143214 t_base : * mut A ,
144215 p_base : * const P ,
216+ scheme_type : SchemeType ,
145217}
146218
147219unsafe impl < A , P > Send for Scheme < A , P > { }
148220unsafe impl < A , P > Sync for Scheme < A , P > { }
149-
150221/// 激活值。
151222trait Activation : Sized {
152223 /// 激活值类型决定计算类型。
153- type Calculation ;
224+ type Calculation : Copy ;
154225 /// 计算流程。
155226 fn calculate ( pair : & mut [ Self ; 2 ] , sin : Self :: Calculation , cos : Self :: Calculation ) ;
156227}
@@ -187,15 +258,69 @@ impl Activation for f64 {
187258}
188259
189260trait Position < Calculation > {
190- fn freq_sin_cos ( self , k : isize , dh : isize , theta : f32 ) -> ( Calculation , Calculation ) ;
261+ fn freq_sin_cos_rope (
262+ self ,
263+ k : isize ,
264+ dh : isize ,
265+ theta : f32 ,
266+ s : f32 ,
267+ ) -> ( Calculation , Calculation ) ;
268+ fn freq_sin_cos_long (
269+ self ,
270+ k : isize ,
271+ dh : isize ,
272+ t : f32 ,
273+ f : Calculation ,
274+ s : f32 ,
275+ ) -> ( Calculation , Calculation ) ;
276+ #[ allow( clippy:: too_many_arguments) ]
277+ fn freq_sin_cos_ntk_part (
278+ self ,
279+ k : isize ,
280+ dh : isize ,
281+ theta : f32 ,
282+ alpha : f32 ,
283+ beta : f32 ,
284+ l0 : f32 ,
285+ s : f32 ,
286+ ntktype : NtkPartsType ,
287+ ) -> ( Calculation , Calculation ) ;
191288}
192289
193290macro_rules! impl_position {
194291 ( $a: ty) => {
195292 impl <T : Unsigned > Position <$a> for T {
196293 #[ inline]
197- fn freq_sin_cos( self , k: isize , dh: isize , theta: f32 ) -> ( $a, $a) {
198- ( self . val( ) as $a / ( theta as $a) . powf( k as $a / dh as $a) ) . sin_cos( )
294+ fn freq_sin_cos_rope( self , k: isize , dh: isize , theta: f32 , s: f32 ) -> ( $a, $a) {
295+ ( self . val( ) as $a * s as $a * ( theta as $a) . powf( k as $a / dh as $a) . recip( ) )
296+ . sin_cos( )
297+ }
298+ #[ inline]
299+ fn freq_sin_cos_long( self , k: isize , dh: isize , t: f32 , f: $a, s: f32 ) -> ( $a, $a) {
300+ let ( sin, cos) =
301+ ( self . val( ) as $a * ( t as $a) . powf( k as $a / dh as $a) . recip( ) * f) . sin_cos( ) ;
302+ ( sin * s as $a, cos * s as $a)
303+ }
304+ #[ inline]
305+ fn freq_sin_cos_ntk_part(
306+ self ,
307+ k: isize ,
308+ dh: isize ,
309+ theta: f32 ,
310+ alpha: f32 ,
311+ beta: f32 ,
312+ l0: f32 ,
313+ s: f32 ,
314+ ntktype: NtkPartsType ,
315+ ) -> ( $a, $a) {
316+ use std:: f32 :: consts:: PI ;
317+ let pos = match ntktype {
318+ NtkPartsType :: None => self . val( ) as $a,
319+ NtkPartsType :: Yarn => self . val( ) as $a * ( 0.1 * s. ln( ) + 1. ) as $a,
320+ } ;
321+ let theta = theta. powf( k as f32 / dh as f32 ) . recip( ) ;
322+ let r = ( ( l0 / ( 2. * PI / theta) - alpha) / ( beta - alpha) ) . clamp( 0. , 1. ) ;
323+ ( pos * ( ( 1. - r) / s + r) as $a * theta as $a) . sin_cos( )
199324 }
200325 }
201326 } ;
@@ -206,8 +331,8 @@ impl_position!(f64);
206331
207332impl < A , P > Scheme < A , P >
208333where
209- A : Activation ,
210- P : Position < A :: Calculation > + Sync + Copy ,
334+ A : Activation + Copy ,
335+ P : Position < A :: Calculation > + Sync + Copy + Unsigned ,
211336{
212337 fn calculate ( & self ) {
213338 let & Self {
@@ -220,6 +345,7 @@ where
220345 theta,
221346 t_base,
222347 p_base,
348+ scheme_type,
223349 } = self ;
224350 let nt = nt as isize ;
225351 let nh = nh as isize ;
@@ -229,10 +355,38 @@ where
229355 for i in 0 ..nt {
230356 let t = unsafe { t_base. byte_offset ( i * st) . cast :: < [ A ; 2 ] > ( ) } ;
231357 let p = unsafe { * p_base. byte_offset ( i * sp) } ;
358+ let factor = match scheme_type {
359+ SchemeType :: Long {
360+ long,
361+ short,
362+ origin_pos,
363+ ..
364+ } => unsafe {
365+ if p. val ( ) < origin_pos as usize {
366+ ( short as * const P ) . byte_offset ( i * st) . cast ( )
367+ } else {
368+ ( long as * const P ) . byte_offset ( i * st) . cast ( )
369+ }
370+ } ,
371+ _ => null ( ) ,
372+ } ;
232373 for j in 0 ..nh {
233374 for k in 0 ..dh {
234375 let pair = unsafe { & mut * t. byte_offset ( j * sh + k * sd) } ;
235- let ( sin, cos) = p. freq_sin_cos ( k, dh, theta) ;
376+ let ( sin, cos) = match scheme_type {
377+ SchemeType :: Rope { s } => p. freq_sin_cos_rope ( k, dh, theta, s) ,
378+ SchemeType :: Long { s, .. } => {
379+ let factor = unsafe { * factor } ;
380+ p. freq_sin_cos_long ( k, dh, theta, factor, s)
381+ }
382+ SchemeType :: NtkParts {
383+ alpha,
384+ beta,
385+ l0,
386+ s,
387+ ntktype,
388+ } => p. freq_sin_cos_ntk_part ( k, dh, theta, alpha, beta, l0, s, ntktype) ,
389+ } ;
236390 A :: calculate ( pair, sin, cos)
237391 }
238392 }
0 commit comments