Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 239 additions & 14 deletions libopenage/util/fixed_point.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ class FixedPoint {
return FixedPoint::from_int(0);
}

/**
* FixedPoint value that is preinitialized to one.
*/
static constexpr FixedPoint one() {
return FixedPoint::from_int(1);
}

/**
* Math constants represented in FixedPoint
*/
Expand Down Expand Up @@ -366,6 +373,26 @@ class FixedPoint {
static_cast<FixedPoint::unsigned_int_type>(this->raw_value) & std::integral_constant<int_type, FixedPoint::fractional_part_bitmask()>::value);
}

/**
* Converter to retrieve the integral (pre-decimal) part of the number.
*/
constexpr FixedPoint get_integral_part() const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for this method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! I'll try to get those added soon!

// similar to the get_fractional_part() implementation, but the opposite bits.
return FixedPoint::from_raw_value(this->raw_value & ~std::integral_constant<int_type, FixedPoint::fractional_part_bitmask()>::value);
}

/**
* Round to the nearest integer.
*/
constexpr FixedPoint round() const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for this method?

if (this->get_fractional_part() < same_type_but_unsigned(0.5)) {
return this->get_integral_part();
}
Comment on lines +388 to +390
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably consider IEE 754 rounding rules here (see Wikipedia for details). The default one seems to be "round to nearest, ties to even". That would mean there needs to be an additional condition for when the fraction is exactly == 0.5. The rules are a bit counter-intuitive because rounding in real-life is done differently, but apparently it checks out when I test this with Python:

  • -1.5 => -2
  • -0.5 => 0
  • 0.5 => 0
  • 1.5 => 2
  • 2.5 => 2

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should also mention the rounding rule in the docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, I'll take a look soon!

else {
return this->get_integral_part() + one();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return this->get_integral_part() + one();
return this->get_integral_part() + FixedPoint::one();

}
}

// Comparison operators for comparison with other
constexpr auto operator<=>(const FixedPoint &o) const = default;

Expand Down Expand Up @@ -413,6 +440,18 @@ class FixedPoint {
return *this;
}

/**
* FixedPoint *= FixedPoint
*
* This shares the same caveats as FixedPint * FixedPoint.
*
* Use a larger intermediate type to avoid overflow.
*/
constexpr FixedPoint operator*=(const FixedPoint &rhs) {
*this = *this *rhs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*this = *this *rhs;
*this = *this * rhs;

return *this;
}

/**
* FixedPoint /= N
*/
Expand Down Expand Up @@ -499,16 +538,111 @@ class FixedPoint {
return std::atan2(this->to_double(), n.to_double());
}

constexpr double sin() {
return std::sin(this->to_double());
/**
* Calculate pure FixedPoint sine using a power series approximation.
*
* This may lose absolute precision when the fractional part is small. Because trig functions are
* cyclic, we don't lose precision for large integer values. For best results, you would want
* intermediate_size to be twice the size of raw_type.
*/
constexpr FixedPoint sin() {
size_t order = this->approx_decimal_places;
FixedPoint x = *this;

// Sine is an odd function, so we can pull out the sign.
bool negative = x < 0;
if (negative) {
x = -x;
}

// Ensure we're in the interval (-pi, pi)
// Since this is a series expansion approximation around 0, this interval is where we are most accurate
if (x > FixedPoint::pi()) {
int_type n = (x / FixedPoint::tau()).round().to_int();
x -= FixedPoint::tau() * n;
}

FixedPoint pow_x = x;
FixedPoint sin_x = 0;
size_t factorial = 1;
bool term_sign = 0;
for (size_t i = 0; i < order; i++) {
sin_x += pow_x * (-2 * term_sign + 1) / factorial;
term_sign = !term_sign;
pow_x *= x * x;
factorial *= (2 * i + 2) * (2 * i + 3);
}

return negative ? -sin_x : sin_x;
}

constexpr double cos() {
return std::cos(this->to_double());
/**
* Calculate pure FixedPoint cosine using a power series approximation.
*
* This may lose absolute precision when the fractional part is small. Because trig functions are
* cyclic, we don't lose precision for large integer values. For best results, you would want
* intermediate_size to be twice the size of raw_type.
*/
constexpr FixedPoint cos() {
size_t order = this->approx_decimal_places;
FixedPoint x = *this;

// Cosine is an even function so we can drop the sign
if (x < 0) {
x = -x;
}

// Ensure we're in the interval (-pi, pi)
// Since this is a series expansion approximation around 0, this interval is where we are most accurate
if (x > FixedPoint::pi()) {
int_type n = (x / FixedPoint::tau()).round().to_int();
x -= FixedPoint::tau() * n;
}

FixedPoint pow_x = 1;
FixedPoint cos_x = 0;
size_t factorial = 1;
bool term_sign = 0;
for (size_t i = 0; i < order; i++) {
cos_x += pow_x * (-2 * term_sign + 1) / factorial;
term_sign = !term_sign;
pow_x *= x * x;
factorial *= (2 * i + 1) * (2 * i + 2);
}

return cos_x;
}

constexpr double tan() {
return std::tan(this->to_double());
/**
* Calculate pure FixedPoint tangent using sin() and cos().
*
* This may lose absolute precision when the fractional part is small. Because trig functions are
* cyclic, we don't lose precision for large integer values. For best results, you would want
* intermediate_size to be twice the size of raw_type.
* This is guaranteed to lose precision when approaching its asymptotes (k * pi/2 for odd k).
*/
constexpr FixedPoint tan() {
FixedPoint x = *this;

// Tangent is odd, we can pull out the sign
bool negative = x < 0;
if (negative) {
x = -x;
}

// Ensure we are in the interval (-pi/2, pi/2) for maximum accuracy
if (x > FixedPoint::pi_2()) {
int_type n = (x / FixedPoint::pi()).round().to_int();
x -= FixedPoint::pi() * n;
}

FixedPoint cos_x = x.cos();

// Raise an exception when too near an asymptote.
ENSURE(cos_x != std::clamp(cos_x.to_double(), -1e-7, 1e-7), "FixedPoint::tan() approaches +/- infinity for this value.");
FixedPoint tan_x = x.sin() / cos_x;

return negative ? -tan_x : tan_x;
}
};

Expand Down Expand Up @@ -575,10 +709,56 @@ typename std::enable_if<std::is_arithmetic<N>::value, FixedPoint<I, F, Inter>>::
*/
template <typename I, unsigned int F, typename Inter>
constexpr FixedPoint<I, F, Inter> operator*(const FixedPoint<I, F, Inter> lhs, const FixedPoint<I, F, Inter> rhs) {
Inter ret = static_cast<Inter>(lhs.get_raw_value()) * static_cast<Inter>(rhs.get_raw_value());
ret >>= F;
using uInter = typename std::make_unsigned<Inter>::type;

// An optimization that can prevent overflows.
// This only helps overflows from the actual operations happening here, but not an ordinary overflow of int_type
// This is essentially just Karatsuba
if constexpr (sizeof(I) == sizeof(Inter)) {
constexpr int hwidth = sizeof(Inter) * 4;
uInter lower_mask = ~static_cast<uInter>(0) >> hwidth;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type of unsafe casting should be avoided at all cost, please. This only leads to buggy code. Also, this line of code doesn't work as intended because the shift does nothing. What gets executed is:

uInter lower_mask = ~(static_cast<uInter>(0) >> hwidth);

You should either do

Suggested change
uInter lower_mask = ~static_cast<uInter>(0) >> hwidth;
constexpr uInter lower_mask = std::numeric_limits<uInter>::max() >> hwidth;

or a longer version

Suggested change
uInter lower_mask = ~static_cast<uInter>(0) >> hwidth;
uInter lower_mask = 0;
lower_mask = ~lower_mask;
lower_mask >>= hwidth;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also means we need more tests for the function :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ~ takes precedence over the >> I thought? But I do agree your first solution is cleaner - thanks! I'll add those tests soon!

Copy link
Member

@heinezen heinezen Apr 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it in the debugger and it did not take precendence (even though it should, I agree). I think the cast introduced some weird behavior that the compiler couldn't resolve properly.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. for uint16_t the value of lower_mask was 65535 even though it should have been 255.


// Store half of each value in each variable
bool l_pos = lhs > 0;
uInter lhs_lower = static_cast<uInter>(std::abs(lhs.get_raw_value()));
Inter lhs_upper = lhs_lower >> hwidth;
lhs_lower = lhs_lower & lower_mask;

bool r_pos = rhs > 0;
uInter rhs_lower = static_cast<uInter>(std::abs(rhs.get_raw_value()));
Inter rhs_upper = rhs_lower >> hwidth;
rhs_lower = rhs_lower & lower_mask;
Comment on lines +722 to +730
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No casting required.

Suggested change
bool l_pos = lhs > 0;
uInter lhs_lower = static_cast<uInter>(std::abs(lhs.get_raw_value()));
Inter lhs_upper = lhs_lower >> hwidth;
lhs_lower = lhs_lower & lower_mask;
bool r_pos = rhs > 0;
uInter rhs_lower = static_cast<uInter>(std::abs(rhs.get_raw_value()));
Inter rhs_upper = rhs_lower >> hwidth;
rhs_lower = rhs_lower & lower_mask;
bool l_pos = lhs > 0;
uInter lhs_lower = std::abs(lhs.get_raw_value());
Inter lhs_upper = lhs_lower >> hwidth;
lhs_lower = lhs_lower & lower_mask;
bool r_pos = rhs > 0;
uInter rhs_lower = std::abs(rhs.get_raw_value());
Inter rhs_upper = rhs_lower >> hwidth;
rhs_lower = rhs_lower & lower_mask;


// Calculate the multiplication piecewise
uInter result_lower = lhs_lower * rhs_lower;
Inter result_mid = lhs_lower * rhs_upper + lhs_upper * rhs_lower;
Inter result_upper = lhs_upper * rhs_upper;

// And recombine.
I result = result_lower >> F;
if constexpr (F > hwidth) {
result += result_mid >> (F - hwidth);
}
else {
result += result_mid << (hwidth - F);
}
if constexpr (F > 2 * hwidth) {
result += result_upper >> (F - 2 * hwidth);
}
else {
// These are the bits that would have been lost. We still may lose some, but there are some we save
result += result_upper << (2 * hwidth - F);
}
result = l_pos ^ r_pos ? -result : result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For readability, we should probably write this out.

Suggested change
result = l_pos ^ r_pos ? -result : result;
if (l_pos xor r_pos) {
result = -result;
}


return FixedPoint<I, F, Inter>::from_raw_value(result);
}
else {
Inter ret = static_cast<Inter>(lhs.get_raw_value()) * static_cast<Inter>(rhs.get_raw_value());
ret >>= F;

return FixedPoint<I, F, Inter>::from_raw_value(static_cast<I>(ret));
return FixedPoint<I, F, Inter>::from_raw_value(static_cast<I>(ret));
}
}


Expand All @@ -587,8 +767,53 @@ constexpr FixedPoint<I, F, Inter> operator*(const FixedPoint<I, F, Inter> lhs, c
*/
template <typename I, unsigned int F, typename Inter>
constexpr FixedPoint<I, F, Inter> operator/(const FixedPoint<I, F, Inter> lhs, const FixedPoint<I, F, Inter> rhs) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function should look similar to the multiplication function, so I'm going to make a few suggestions about that.

Inter ret = div((static_cast<Inter>(lhs.get_raw_value()) << F), static_cast<Inter>(rhs.get_raw_value()));
return FixedPoint<I, F, Inter>::from_raw_value(static_cast<I>(ret));
using uInter = typename std::make_unsigned<Inter>::type;
using FP = FixedPoint<I, F, Inter>;

// Implementation that doesn't lose bits using small intermediate values.
if constexpr (sizeof(I) == sizeof(Inter)) {
constexpr uInter lower_mask = ~static_cast<uInter>(0) >> (sizeof(Inter) * 8 - F);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
constexpr uInter lower_mask = ~static_cast<uInter>(0) >> (sizeof(Inter) * 8 - F);
constexpr int integral_width = sizeof(Inter) * 8 - F;
constexpr uInter lower_mask = std::numeric_limits<uInter>::max() >> integral_width;


// Store the integral and fractional parts in the upper and lower "halves"
bool l_pos = lhs > 0;
uInter lhs_lower = static_cast<uInter>(std::abs(lhs.get_raw_value()));
Inter lhs_upper = lhs_lower >> F;
lhs_lower = lhs_lower & lower_mask;

bool r_pos = rhs > 0;
uInter rhs_lower = static_cast<uInter>(std::abs(rhs.get_raw_value()));
Inter rhs_upper = rhs_lower >> F;
rhs_lower = rhs_lower & lower_mask;
Comment on lines +778 to +786
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool l_pos = lhs > 0;
uInter lhs_lower = static_cast<uInter>(std::abs(lhs.get_raw_value()));
Inter lhs_upper = lhs_lower >> F;
lhs_lower = lhs_lower & lower_mask;
bool r_pos = rhs > 0;
uInter rhs_lower = static_cast<uInter>(std::abs(rhs.get_raw_value()));
Inter rhs_upper = rhs_lower >> F;
rhs_lower = rhs_lower & lower_mask;
bool l_pos = lhs > 0;
uInter lhs_lower = std::abs(lhs.get_raw_value());
Inter lhs_upper = lhs_lower >> F;
lhs_lower = lhs_lower & lower_mask;
bool r_pos = rhs > 0;
uInter rhs_lower = std::abs(rhs.get_raw_value());
Inter rhs_upper = rhs_lower >> F;
rhs_lower = rhs_lower & lower_mask;


// special case when integeral part of rhs is 0
if (rhs_upper == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be a test that covers this case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is - just not explicitly. It took me ages to figure out why my prior tests weren't passing because of this 😅

I'll add an explicit set of tests that call this out though!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see it now being used by sin(..) for example. An explicit test would be nice though :)

// It's very likely this upper term is zero, but we should consider it regardless.
FP upper_term = FP::from_raw_value((lhs_upper << (2 * F)) / rhs_lower);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generates compiler warning because you are shifting by 2 * 32 sometimes when lhs_upper is a 64-Bit signed integer which causes undefined behavior. You should consider using safe_shiftleft.

FP lower_term = FP::from_raw_value((lhs_lower << F) / rhs_lower);
FP result = upper_term + lower_term;
return l_pos ^ r_pos ? -result : result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return l_pos ^ r_pos ? -result : result;
if (l_pos xor r_pos) {
result = -result;
}
return result;

}

// Calculate the multiplication piecewise
FP upper_term = FixedPoint<I, F, Inter>::from_raw_value((lhs_upper << F) / rhs_upper);
FP lower_term = FP::from_raw_value(((lhs_lower << F) / rhs_upper) >> F);
FP mixed_term = FixedPoint<I, F, Inter>::from_raw_value((rhs_lower) / rhs_upper);
Comment on lines +798 to +800
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
FP upper_term = FixedPoint<I, F, Inter>::from_raw_value((lhs_upper << F) / rhs_upper);
FP lower_term = FP::from_raw_value(((lhs_lower << F) / rhs_upper) >> F);
FP mixed_term = FixedPoint<I, F, Inter>::from_raw_value((rhs_lower) / rhs_upper);
FP upper_term = FP::from_raw_value((lhs_upper << F) / rhs_upper);
FP lower_term = FP::from_raw_value(((lhs_lower << F) / rhs_upper) >> F);
FP mixed_term = FP::from_raw_value((rhs_lower) / rhs_upper);


// Basically doing a power series expansion here for (lhs_upper + lhs_lower) / (rhs_upper + rhs_lower)
FP result = FP::zero();
FP term = FP::one();
for (size_t i = 0; i < 8; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this 8?

result += term * upper_term;
result += term * lower_term;
term *= -mixed_term;
}

return l_pos ^ r_pos ? -result : result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return l_pos ^ r_pos ? -result : result;
if (l_pos xor r_pos) {
result = -result;
}
return result;

}
else {
Inter ret = div((static_cast<Inter>(lhs.get_raw_value()) << F), static_cast<Inter>(rhs.get_raw_value()));
return FixedPoint<I, F, Inter>::from_raw_value(static_cast<I>(ret));
}
}


Expand Down Expand Up @@ -629,17 +854,17 @@ constexpr double atan2(openage::util::FixedPoint<I, F, Inter> x, openage::util::

template <typename I, unsigned F, typename Inter>
constexpr double sin(openage::util::FixedPoint<I, F, Inter> n) {
return n.sin();
return static_cast<double>(n.sin());
}

template <typename I, unsigned F, typename Inter>
constexpr double cos(openage::util::FixedPoint<I, F, Inter> n) {
return n.cos();
return static_cast<double>(n.cos());
}

template <typename I, unsigned F, typename Inter>
constexpr double tan(openage::util::FixedPoint<I, F, Inter> n) {
return n.tan();
return static_cast<double>(n.tan());
}

template <typename I, unsigned F, typename Inter>
Expand Down
Loading
Loading