@@ -299,13 +299,106 @@ pub fn sigmoid(x: f32) -> f32 {
     1.0 / (1.0 + (-x).exp())
 }
 
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "sse")]
+#[target_feature(enable = "fma")]
+unsafe fn dot_sse_fma(xs: &[f32], ys: &[f32]) -> f32 {
+    use core::arch::x86_64::_mm_add_ps;
+    use core::arch::x86_64::_mm_add_ss;
+    use core::arch::x86_64::_mm_cvtss_f32;
+    use core::arch::x86_64::_mm_fmadd_ps;
+    use core::arch::x86_64::_mm_loadu_ps;
+    use core::arch::x86_64::_mm_movehdup_ps;
+    use core::arch::x86_64::_mm_movehl_ps;
+    use core::arch::x86_64::_mm_setzero_ps;
+
+    debug_assert_eq!(xs.len(), ys.len());
+
+    let xc = xs.chunks_exact(2 * 4);
+    let yc = ys.chunks_exact(2 * 4);
+
+    // Handle the remainder (fewer than 8 elements) with a scalar dot product.
+    let sum_all = xc
+        .remainder()
+        .iter()
+        .zip(yc.remainder().iter())
+        .map(|(x, y)| x * y)
+        .sum::<f32>();
+    let mut sum = _mm_setzero_ps();
+
+    // Two 4-wide SSE fused multiply-adds per iteration (8 f32s).
+    for (x, y) in xc.zip(yc) {
+        let xptr = x.as_ptr();
+        let yptr = y.as_ptr();
+
+        let xv = _mm_loadu_ps(xptr);
+        let yv = _mm_loadu_ps(yptr);
+        sum = _mm_fmadd_ps(xv, yv, sum);
+
+        let xv = _mm_loadu_ps(xptr.add(4));
+        let yv = _mm_loadu_ps(yptr.add(4));
+        sum = _mm_fmadd_ps(xv, yv, sum);
+    }
+
+    // Horizontal sum of the four lanes, using the shuffle tricks from
+    // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
+    let mut shuf = _mm_movehdup_ps(sum);
+    let mut sums = _mm_add_ps(sum, shuf);
+    shuf = _mm_movehl_ps(shuf, sums);
+    sums = _mm_add_ss(sums, shuf);
+    sum_all + _mm_cvtss_f32(sums)
+}
+
+#[cfg(target_arch = "x86_64")]
+#[target_feature(enable = "avx")]
+#[target_feature(enable = "fma")]
+unsafe fn dot_avx_fma(xs: &[f32], ys: &[f32]) -> f32 {
+    use core::arch::x86_64::{
+        _mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
+        _mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehdup_ps, _mm_movehl_ps,
+    };
+    debug_assert_eq!(xs.len(), ys.len());
+
+    let xc = xs.chunks_exact(8);
+    let yc = ys.chunks_exact(8);
+
+    // Handle the remainder (fewer than 8 elements) with a scalar dot product.
+    let sum_all = xc
+        .remainder()
+        .iter()
+        .zip(yc.remainder().iter())
+        .map(|(x, y)| x * y)
+        .sum::<f32>();
+    let mut sum = _mm256_setzero_ps();
+
+    // One 8-wide AVX fused multiply-add per iteration.
+    for (x, y) in xc.zip(yc) {
+        let xptr = x.as_ptr();
+        let yptr = y.as_ptr();
+
+        let xv = _mm256_loadu_ps(xptr);
+        let yv = _mm256_loadu_ps(yptr);
+        sum = _mm256_fmadd_ps(xv, yv, sum);
+    }
+
+    // Add the upper and lower 128-bit halves, then horizontally sum the four
+    // remaining lanes, using the shuffle tricks from
+    // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
+    let mut lo = _mm256_castps256_ps128(sum);
+    let hi = _mm256_extractf128_ps::<1>(sum);
+    lo = _mm_add_ps(lo, hi);
+
+    let mut shuf = _mm_movehdup_ps(lo);
+    let mut sums = _mm_add_ps(lo, shuf);
+    shuf = _mm_movehl_ps(shuf, sums);
+    sums = _mm_add_ss(sums, shuf);
+    sum_all + _mm_cvtss_f32(sums)
+}
+
 /// Compute the dot product.
 ///
 /// `xs` and `ys` must be the same length
 ///
 /// (From ndarray 0.15.6)
 fn unrolled_dot(xs: &[f32], ys: &[f32]) -> f32 {
     debug_assert_eq!(xs.len(), ys.len());
+    // Use the AVX+FMA code path when the running CPU supports both features.
+    #[cfg(target_arch = "x86_64")]
+    {
+        if std::is_x86_feature_detected!("avx") && std::is_x86_feature_detected!("fma") {
+            unsafe {
+                return dot_avx_fma(xs, ys);
+            }
+        }
+    }
     // eightfold unrolled so that floating point can be vectorized
     // (even with strict floating point accuracy semantics)
     let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
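
For reference, here is a minimal sketch of how the new AVX+FMA path could be sanity-checked against a plain scalar dot product. The test module, test name, input data, and tolerance are illustrative assumptions and are not part of this commit; FMA and the different summation order mean the SIMD result only matches the scalar one approximately.

#[cfg(all(test, target_arch = "x86_64"))]
mod simd_dot_tests {
    use super::dot_avx_fma;

    #[test]
    fn dot_avx_fma_matches_scalar_reference() {
        // Hypothetical test (see note above); skip quietly on CPUs without AVX/FMA.
        if !std::is_x86_feature_detected!("avx") || !std::is_x86_feature_detected!("fma") {
            return;
        }

        // Lengths chosen to exercise both the 8-wide loop and the scalar remainder.
        for &len in &[0usize, 1, 7, 8, 9, 31, 64, 100] {
            let xs: Vec<f32> = (0..len).map(|i| i as f32 * 0.25 - 3.0).collect();
            let ys: Vec<f32> = (0..len).map(|i| 1.0 - i as f32 * 0.125).collect();

            // Plain scalar dot product as the reference value.
            let expected: f32 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
            let got = unsafe { dot_avx_fma(&xs, &ys) };

            // Compare with a small relative tolerance rather than exact equality.
            assert!(
                (got - expected).abs() <= 1e-4 * (1.0 + expected.abs()),
                "len {}: {} vs {}",
                len,
                got,
                expected
            );
        }
    }
}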