From 8610dc7689a1c31a996617cce728d1fe713dc24e Mon Sep 17 00:00:00 2001 From: Geolm Date: Thu, 25 Jan 2024 21:35:35 -0500 Subject: [PATCH] exp and exp2 compliance --- math_intrinsics.h | 18 +++++++++++++++--- tests/test.c | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/math_intrinsics.h b/math_intrinsics.h index d90a328..97130fe 100644 --- a/math_intrinsics.h +++ b/math_intrinsics.h @@ -462,6 +462,8 @@ static inline simd_vector simd_sign(simd_vector a) __m256 mm256_exp_ps(__m256 x) #endif { + simd_vector invalid_mask = simd_isnan(x); + simd_vector input_is_infinity = simd_cmp_eq(x, simd_splat_positive_infinity()); simd_vector tmp = simd_splat_zero(); simd_vector fx; simd_vector one = simd_splat(1.f); @@ -493,8 +495,11 @@ static inline simd_vector simd_sign(simd_vector a) emm0 = simd_shift_left_i(emm0, 23); simd_vector pow2n = simd_cast_from_int(emm0); - y = simd_mul(y, pow2n); - return y; + simd_vector result = simd_mul(y, pow2n); + result = simd_or(result, invalid_mask); + result = simd_select(result, simd_splat_positive_infinity(), input_is_infinity); // +inf arg will be +inf + + return result; } //---------------------------------------------------------------------------------------------------------------------- @@ -505,6 +510,9 @@ static inline simd_vector simd_sign(simd_vector a) __m256 mm256_exp2_ps(__m256 x) #endif { + simd_vector invalid_mask = simd_isnan(x); + simd_vector input_is_infinity = simd_cmp_eq(x, simd_splat_positive_infinity()); + // clamp values x = simd_clamp(x, simd_splat(-127.f), simd_splat(127.f)); simd_vector equal_to_zero = simd_cmp_eq(x, simd_splat_zero()); @@ -522,7 +530,11 @@ static inline simd_vector simd_sign(simd_vector a) px = simd_fmad(px, x, one); px = simd_ldexp(px, i0); - return simd_select(px, one, equal_to_zero); + simd_vector result = simd_select(px, one, equal_to_zero); + result = simd_or(result, invalid_mask); + result = simd_select(result, simd_splat_positive_infinity(), input_is_infinity); // +inf arg will be +inf + + return result; } //---------------------------------------------------------------------------------------------------------------------- diff --git a/tests/test.c b/tests/test.c index 7dcee33..60e2694 100644 --- a/tests/test.c +++ b/tests/test.c @@ -199,17 +199,35 @@ SUITE(infinity_nan_compliance) const float not_a_number = nanf(""); #ifdef __MATH__INTRINSICS__AVX__ + + // log RUN_TESTp(nan_expected, -1.f, mm256_log_ps); RUN_TESTp(nan_expected, not_a_number, mm256_log_ps); RUN_TESTp(value_expected, 1.f, 0.f, mm256_log_ps); RUN_TESTp(value_expected, 0.f, negative_inf, mm256_log_ps); RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_log_ps); + // log2 RUN_TESTp(nan_expected, -1.f, mm256_log2_ps); RUN_TESTp(nan_expected, not_a_number, mm256_log2_ps); RUN_TESTp(value_expected, 1.f, 0.f, mm256_log2_ps); RUN_TESTp(value_expected, 0.f, negative_inf, mm256_log2_ps); RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_log2_ps); + + // exp + RUN_TESTp(nan_expected, not_a_number, mm256_exp_ps); + RUN_TESTp(value_expected, 0.f, 1.f, mm256_exp_ps); + RUN_TESTp(value_expected,-0.f, 1.f, mm256_exp_ps); + RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_exp_ps); + RUN_TESTp(value_expected, negative_inf, 0.f, mm256_exp_ps); + + // exp2 + RUN_TESTp(nan_expected, not_a_number, mm256_exp2_ps); + RUN_TESTp(value_expected, 0.f, 1.f, mm256_exp2_ps); + RUN_TESTp(value_expected,-0.f, 1.f, mm256_exp2_ps); + RUN_TESTp(value_expected, positive_inf, positive_inf, mm256_exp2_ps); + RUN_TESTp(value_expected, negative_inf, 0.f, mm256_exp2_ps); + #else RUN_TESTp(nan_expected, -1.f, vlogq_f32); RUN_TESTp(nan_expected, not_a_number, vlogq_f32); @@ -222,6 +240,20 @@ SUITE(infinity_nan_compliance) RUN_TESTp(value_expected, 1.f, 0.f, vlog2q_f32); RUN_TESTp(value_expected, 0.f, negative_inf, vlog2q_f32); RUN_TESTp(value_expected, positive_inf, positive_inf, vlog2q_f32); + + // exp + RUN_TESTp(nan_expected, not_a_number, vexpq_f32); + RUN_TESTp(value_expected, 0.f, 1.f, vexpq_f32); + RUN_TESTp(value_expected,-0.f, 1.f, vexpq_f32); + RUN_TESTp(value_expected, positive_inf, positive_inf, vexpq_f32); + RUN_TESTp(value_expected, negative_inf, 0.f, vexpq_f32); + + // exp2 + RUN_TESTp(nan_expected, not_a_number, vexp2q_f32); + RUN_TESTp(value_expected, 0.f, 1.f, vexp2q_f32); + RUN_TESTp(value_expected,-0.f, 1.f, vexp2q_f32); + RUN_TESTp(value_expected, positive_inf, positive_inf, vexp2q_f32); + RUN_TESTp(value_expected, negative_inf, 0.f, vexp2q_f32); #endif }