Skip to content

Commit 3e70ce5

Browse files
committed
Add casting from f64 to f32 that preserves values and avoids UB
1 parent 5a22bb9 commit 3e70ce5

File tree

3 files changed

+85
-17
lines changed

3 files changed

+85
-17
lines changed

subspace/num/__private/intrinsics.h

+38-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include <fenv.h>
1718
#include <stddef.h>
1819
#include <stdint.h>
1920

@@ -618,7 +619,7 @@ sus_pure_const inline constexpr OverflowOut<T> add_with_overflow(T x,
618619
};
619620
}
620621

621-
template <class T, class U = decltype(to_signed(std::declval<T>()))>
622+
template <class T, class U = decltype(into_signed(std::declval<T>()))>
622623
requires(std::is_integral_v<T> && !std::is_signed_v<T> &&
623624
::sus::mem::size_of<T>() <= 8 &&
624625
::sus::mem::size_of<T>() == ::sus::mem::size_of<U>())
@@ -631,7 +632,7 @@ sus_pure_const inline constexpr OverflowOut<T> add_with_overflow_signed(
631632
};
632633
}
633634

634-
template <class T, class U = decltype(to_unsigned(std::declval<T>()))>
635+
template <class T, class U = decltype(into_unsigned(std::declval<T>()))>
635636
requires(std::is_integral_v<T> && std::is_signed_v<T> &&
636637
::sus::mem::size_of<T>() <= 8 &&
637638
::sus::mem::size_of<T>() == ::sus::mem::size_of<U>())
@@ -668,7 +669,7 @@ sus_pure_const inline constexpr OverflowOut<T> sub_with_overflow(T x,
668669
};
669670
}
670671

671-
template <class T, class U = decltype(to_unsigned(std::declval<T>()))>
672+
template <class T, class U = decltype(into_unsigned(std::declval<T>()))>
672673
requires(std::is_integral_v<T> && std::is_signed_v<T> &&
673674
::sus::mem::size_of<T>() <= 8 &&
674675
::sus::mem::size_of<T>() == ::sus::mem::size_of<U>())
@@ -1258,12 +1259,14 @@ sus_pure_const sus_always_inline constexpr int32_t exponent_bits(
12581259
unchecked_shr(into_unsigned_integer(x) & mask, 52));
12591260
}
12601261

1261-
sus_pure_const sus_always_inline constexpr int32_t exponent_value(
1262+
/// This function requires that `x` is a normal value to produce a value result.
1263+
sus_pure_const sus_always_inline constexpr int32_t float_normal_exponent_value(
12621264
float x) noexcept {
12631265
return exponent_bits(x) - int32_t{127};
12641266
}
12651267

1266-
sus_pure_const sus_always_inline constexpr int32_t exponent_value(
1268+
/// This function requires that `x` is a normal value to produce a value result.
1269+
sus_pure_const sus_always_inline constexpr int32_t float_normal_exponent_value(
12671270
double x) noexcept {
12681271
return exponent_bits(x) - int32_t{1023};
12691272
}
@@ -1399,8 +1402,10 @@ sus_pure_const inline constexpr T truncate_float(T x) noexcept {
13991402
: uint32_t{52};
14001403

14011404
if (float_is_inf_or_nan(x) || float_is_zero(x)) return x;
1405+
if (float_nonzero_is_subnormal(x)) [[unlikely]]
1406+
return T{0};
14021407

1403-
const int32_t exponent = exponent_value(x);
1408+
const int32_t exponent = float_normal_exponent_value(x);
14041409

14051410
// If the exponent is greater than the most negative mantissa
14061411
// exponent, then x is already an integer.
@@ -1521,4 +1526,31 @@ sus_pure_const inline T next_toward(T from, T to) {
15211526
return std::nexttoward(from, to);
15221527
}
15231528

1529+
#pragma warning(push)
1530+
// MSVC claims that "overflow in constant arithmetic" occurs on the static_cast
1531+
// in `into_smaller_float()` but we check for overflow first, the conversion is
1532+
// in range.
1533+
#pragma warning(disable : 4756)
1534+
1535+
// Not constexpr as rounding is always toward zero in a constexpr context.
1536+
template <class Out, class T>
1537+
requires(std::is_floating_point_v<T> && ::sus::mem::size_of<T>() == 8 &&
1538+
::sus::mem::size_of<Out>() == 4)
1539+
sus_pure_const inline Out into_smaller_float(T x) noexcept {
1540+
if (x <= T{max_value<Out>()} && x >= T{min_value<Out>()}) [[likely]] {
1541+
// SAFETY: Because the value `x` is at or between two valid values of type
1542+
// `Out`, the static_cast does not cause UB.
1543+
return static_cast<Out>(x); // Handles values in range.
1544+
}
1545+
if (x > T{max_value<Out>()}) {
1546+
return infinity<Out>(); // Handles large values and INFINITY.
1547+
}
1548+
if (x < T{min_value<Out>()}) {
1549+
return negative_infinity<Out>(); // Handles small values and NEG_INFINITY.
1550+
}
1551+
return nan<Out>(); // All that's left are NaNs.
1552+
}
1553+
1554+
#pragma warning(pop)
1555+
15241556
} // namespace sus::num::__private

subspace/num/convert.h

+14-10
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,21 @@
1818
#include <type_traits>
1919

2020
#include "subspace/construct/to_bits.h"
21+
#include "subspace/num/__private/intrinsics.h"
2122
#include "subspace/num/float.h"
2223
#include "subspace/num/signed_integer.h"
2324
#include "subspace/num/unsigned_integer.h"
2425

25-
/// Casting from a float to an integer will round the float towards zero,
26-
/// except:
27-
/// * NaN will return 0.
28-
/// * Values larger than the maximum integer value, including `INFINITY`, will
26+
/// * Casting from a float to an integer will round the float towards zero,
27+
/// except:
28+
/// * NaN will return 0.
29+
/// * Values larger than the maximum integer value, including `INFINITY`, will
2930
/// saturate to the maximum value of the integer type.
30-
/// * Values smaller than the minimum integer value, including `NEG_INFINITY`,
31-
/// will saturate to the minimum value of the integer type.
31+
/// * Values smaller than the minimum integer value, including `NEG_INFINITY`,
32+
/// will saturate to the minimum value of the integer type.
33+
/// * Casting from an f32 to an f64 preserves the value unchanged.
34+
/// * Casting f64 to f32...
35+
3236

3337
// # ================ From signed integers. ============================
3438

@@ -335,13 +339,13 @@ struct sus::construct::ToBitsImpl<T, F> {
335339
if constexpr (::sus::mem::size_of<F>() == 4u) {
336340
return from;
337341
} else {
338-
return std::bit_cast<float>(
339-
static_cast<uint32_t>(std::bit_cast<uint64_t>(from)));
342+
return ::sus::num::__private::into_smaller_float<float>(from);
340343
}
341344
} else {
342345
if constexpr (::sus::mem::size_of<F>() == 4u) {
343-
return std::bit_cast<double>(
344-
static_cast<uint64_t>(std::bit_cast<uint32_t>(from)));
346+
// C++20 Section 7.3.7: A prvalue of type float can be converted to a
347+
// prvalue of type double. The value is unchanged.
348+
return T{from};
345349
} else {
346350
return from;
347351
}

subspace/num/convert_unittest.cc

+33-1
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,37 @@ TEST(ConvertToBits, isize) {
354354
}
355355
}
356356

357+
TEST(ConvertToBits, LosslessFloatConversion) {
358+
EXPECT_EQ(sus::to_bits<f64>(-1.8949651689383756e-14_f32),
359+
-1.8949651689383756e-14_f64);
360+
EXPECT_EQ(sus::to_bits<f32>(-1.8949651689383756e-14_f32),
361+
-1.8949651689383756e-14_f32);
362+
EXPECT_EQ(sus::to_bits<f64>(-4.59218127443847370761468605771e-102_f64),
363+
-4.59218127443847370761468605771e-102_f64);
364+
}
365+
366+
TEST(ConvertToBits, f64tof32) {
367+
EXPECT_EQ(sus::to_bits<f32>(f64::NAN).is_nan(), true);
368+
EXPECT_EQ(sus::to_bits<f32>(f64::INFINITY), f32::INFINITY);
369+
EXPECT_EQ(sus::to_bits<f32>(f64::NEG_INFINITY), f32::NEG_INFINITY);
370+
EXPECT_EQ(sus::to_bits<f32>(f64::MAX), f32::INFINITY);
371+
EXPECT_EQ(sus::to_bits<f32>(f64::MIN), f32::NEG_INFINITY);
372+
373+
// Just past the valid range of values for f32 in either direciton. A
374+
// static_cast<float>(double) for these values would cause UB.
375+
EXPECT_EQ(sus::to_bits<f32>(
376+
sus::to_bits<f64>(f32::MIN).next_toward(f64::NEG_INFINITY)),
377+
f32::NEG_INFINITY);
378+
EXPECT_EQ(
379+
sus::to_bits<f32>(sus::to_bits<f64>(f32::MAX).next_toward(f64::INFINITY)),
380+
f32::INFINITY);
381+
382+
// This is a value with bits set throughout the exponent and mantissa. Its
383+
// exponent is <= 127 and >= -126 so it's possible to represent it in f32.
384+
EXPECT_EQ(sus::to_bits<f32>(-4.59218127443847370761468605771e-102_f64),
385+
-4.59218127443847370761468605771e-102_f32);
386+
}
387+
357388
TEST(ConvertToBits, f32) {
358389
static_assert(std::same_as<decltype(sus::to_bits<u16>(0_f32)), u16>);
359390

@@ -546,7 +577,8 @@ TEST(ConvertToBits, f64) {
546577
EXPECT_EQ(sus::to_bits<i64>(0.51_f64), 0_i64);
547578
EXPECT_EQ(sus::to_bits<i64>(0.9999_f64), 0_i64);
548579
EXPECT_EQ(sus::to_bits<i64>(1_f64), 1_i64);
549-
EXPECT_LT(sus::to_bits<i64>((9223372036854775807_f64).next_toward(0_f64)), i64::MAX);
580+
EXPECT_LT(sus::to_bits<i64>((9223372036854775807_f64).next_toward(0_f64)),
581+
i64::MAX);
550582
EXPECT_EQ(sus::to_bits<i64>(9223372036854775807_f64), i64::MAX);
551583
EXPECT_EQ(sus::to_bits<i64>(9223372036854775807.00001_f64), i64::MAX);
552584
EXPECT_EQ(sus::to_bits<i64>(9223372036854775807_f64 * 2_f64), i64::MAX);

0 commit comments

Comments
 (0)