Skip to content

Commit d28a3ed

Browse files
committed
sse fmadd
1 parent c7567d4 commit d28a3ed

File tree

3 files changed

+95
-5
lines changed

3 files changed

+95
-5
lines changed

experimental/segmenter/src/lib.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,8 @@
7878
clippy::panic,
7979
clippy::exhaustive_structs,
8080
clippy::exhaustive_enums,
81-
missing_debug_implementations,
8281
)
8382
)]
84-
#![warn(missing_docs)]
8583

8684
extern crate alloc;
8785

@@ -104,9 +102,9 @@ pub mod provider;
104102
pub mod symbols;
105103

106104
#[cfg(feature = "lstm")]
107-
mod lstm;
105+
pub mod lstm;
108106
#[cfg(feature = "lstm")]
109-
mod lstm_bies;
107+
pub mod lstm_bies;
110108
#[cfg(feature = "lstm")]
111109
mod lstm_error;
112110
#[cfg(feature = "lstm")]

experimental/segmenter/src/line.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ impl LineSegmenter {
236236
Self::try_new_lstm_with_options_unstable(provider, Default::default())
237237
}
238238

239-
#[cfg(feature = "lstm")]
240239
icu_provider::gen_any_buffer_constructors!(
241240
locale: skip,
242241
options: skip,

experimental/segmenter/src/math_helper.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,106 @@ pub fn sigmoid(x: f32) -> f32 {
299299
1.0 / (1.0 + (-x).exp())
300300
}
301301

302+
#[cfg(target_arch = "x86_64")]
303+
#[target_feature(enable = "sse")]
304+
#[target_feature(enable = "fma")]
305+
unsafe fn dot_sse_fma(xs: &[f32], ys: &[f32]) -> f32 {
306+
use core::arch::x86_64::_mm_add_ps;
307+
use core::arch::x86_64::_mm_add_ss;
308+
use core::arch::x86_64::_mm_cvtss_f32;
309+
use core::arch::x86_64::_mm_fmadd_ps;
310+
use core::arch::x86_64::_mm_loadu_ps;
311+
use core::arch::x86_64::_mm_movehdup_ps;
312+
use core::arch::x86_64::_mm_movehl_ps;
313+
use core::arch::x86_64::_mm_setzero_ps;
314+
315+
debug_assert_eq!(xs.len(), ys.len());
316+
317+
let xc = xs.chunks_exact(2 * 4);
318+
let yc = ys.chunks_exact(2 * 4);
319+
320+
let sum_all = xc
321+
.remainder()
322+
.iter()
323+
.zip(yc.remainder().iter())
324+
.map(|(x, y)| x * y)
325+
.sum::<f32>();
326+
let mut sum = _mm_setzero_ps();
327+
328+
for (x, y) in xc.zip(yc) {
329+
let xptr = x.as_ptr();
330+
let yptr = y.as_ptr();
331+
332+
let xv = _mm_loadu_ps(xptr);
333+
let yv = _mm_loadu_ps(yptr);
334+
sum = _mm_fmadd_ps(xv, yv, sum);
335+
336+
let xv = _mm_loadu_ps(xptr.add(4));
337+
let yv = _mm_loadu_ps(yptr.add(4));
338+
sum = _mm_fmadd_ps(xv, yv, sum);
339+
}
340+
341+
// Using hacks in
342+
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
343+
let mut shuf = _mm_movehdup_ps(sum);
344+
let mut sums = _mm_add_ps(sum, shuf);
345+
shuf = _mm_movehl_ps(shuf, sums);
346+
sums = _mm_add_ss(sums, shuf);
347+
sum_all + _mm_cvtss_f32(sums)
348+
}
349+
350+
unsafe fn dot_avx_fma(xs: &[f32], ys: &[f32]) -> f32 {
351+
use core::arch::x86_64::{
352+
_mm256_castps256_ps128, _mm256_extractf128_ps, _mm256_fmadd_ps, _mm256_loadu_ps,
353+
_mm256_setzero_ps, _mm_add_ps, _mm_add_ss, _mm_cvtss_f32, _mm_movehdup_ps, _mm_movehl_ps,
354+
};
355+
debug_assert_eq!(xs.len(), ys.len());
356+
357+
let xc = xs.chunks_exact(8);
358+
let yc = ys.chunks_exact(8);
359+
360+
let sum_all = xc
361+
.remainder()
362+
.iter()
363+
.zip(yc.remainder().iter())
364+
.map(|(x, y)| x * y)
365+
.sum::<f32>();
366+
let mut sum = _mm256_setzero_ps();
367+
368+
for (x, y) in xc.zip(yc) {
369+
let xptr = x.as_ptr();
370+
let yptr = y.as_ptr();
371+
372+
let xv = _mm256_loadu_ps(xptr);
373+
let yv = _mm256_loadu_ps(yptr);
374+
sum = _mm256_fmadd_ps(xv, yv, sum);
375+
}
376+
377+
// Using hacks in
378+
// https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
379+
let mut lo = _mm256_castps256_ps128(sum);
380+
let hi = _mm256_extractf128_ps(sum, 1);
381+
lo = _mm_add_ps(lo, hi);
382+
383+
let mut shuf = _mm_movehdup_ps(lo);
384+
let mut sums = _mm_add_ps(lo, shuf);
385+
shuf = _mm_movehl_ps(shuf, sums);
386+
sums = _mm_add_ss(sums, shuf);
387+
sum_all + _mm_cvtss_f32(sums)
388+
}
389+
302390
/// Compute the dot product.
303391
///
304392
/// `xs` and `ys` must be the same length
305393
///
306394
/// (From ndarray 0.15.6)
307395
fn unrolled_dot(xs: &[f32], ys: &[f32]) -> f32 {
308396
debug_assert_eq!(xs.len(), ys.len());
397+
if std::is_x86_feature_detected!("avx") {
398+
unsafe {
399+
return dot_avx_fma(xs, ys);
400+
}
401+
}
309402
// eightfold unrolled so that floating point can be vectorized
310403
// (even with strict floating point accuracy semantics)
311404
let mut p = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0);

0 commit comments

Comments
 (0)