Skip to content

Commit 6417cd8

Browse files
authored
Merge pull request #181 from SwayamInSync/heaviside
2 parents e73f34f + e71a66a commit 6417cd8

File tree

4 files changed

+146
-1
lines changed

4 files changed

+146
-1
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,26 @@ quad_logaddexp2(const Sleef_quad *x, const Sleef_quad *y)
752752
return Sleef_addq1_u05(max_val, log2_term);
753753
}
754754

755+
static inline Sleef_quad
756+
quad_heaviside(const Sleef_quad *x1, const Sleef_quad *x2)
757+
{
758+
// heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0
759+
// NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0)
760+
if (Sleef_iunordq1(*x1, *x1)) {
761+
return *x1; // x1 is NaN, return NaN
762+
}
763+
764+
if (Sleef_icmpltq1(*x1, QUAD_ZERO)) {
765+
return QUAD_ZERO;
766+
}
767+
else if (Sleef_icmpeqq1(*x1, QUAD_ZERO)) {
768+
return *x2; // When x1 == 0, return x2 (even if x2 is NaN)
769+
}
770+
else {
771+
return QUAD_ONE;
772+
}
773+
}
774+
755775
// Binary long double operations
756776
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);
757777
// Binary long double operations with 2 outputs (for divmod, modf, frexp)
@@ -1002,6 +1022,26 @@ ld_logaddexp2(const long double *x, const long double *y)
10021022
return max_val + log2l(1.0L + exp2l(-abs_diff));
10031023
}
10041024

1025+
static inline long double
1026+
ld_heaviside(const long double *x1, const long double *x2)
1027+
{
1028+
// heaviside(x1, x2) = 0 if x1 < 0, x2 if x1 == 0, 1 if x1 > 0
1029+
// NaN propagation: only propagate NaN from x1, not from x2 (unless x1 == 0)
1030+
if (isnan(*x1)) {
1031+
return *x1; // x1 is NaN, return NaN
1032+
}
1033+
1034+
if (*x1 < 0.0L) {
1035+
return 0.0L;
1036+
}
1037+
else if (*x1 == 0.0L) {
1038+
return *x2; // When x1 == 0, return x2 (even if x2 is NaN)
1039+
}
1040+
else {
1041+
return 1.0L;
1042+
}
1043+
}
1044+
10051045
// comparison quad functions
10061046
typedef npy_bool (*cmp_quad_def)(const Sleef_quad *, const Sleef_quad *);
10071047

quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,9 @@ init_quad_binary_ops(PyObject *numpy)
454454
if (create_quad_binary_ufunc<quad_logaddexp2, ld_logaddexp2>(numpy, "logaddexp2") < 0) {
455455
return -1;
456456
}
457+
if (create_quad_binary_ufunc<quad_heaviside, ld_heaviside>(numpy, "heaviside") < 0) {
458+
return -1;
459+
}
457460
if (create_quad_binary_2out_ufunc<quad_divmod, ld_divmod>(numpy, "divmod") < 0) {
458461
return -1;
459462
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
| fabs |||
2727
| rint |||
2828
| sign |||
29-
| heaviside | | |
29+
| heaviside | | |
3030
| conj | | |
3131
| conjugate | | |
3232
| exp |||

quaddtype/tests/test_quaddtype.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1591,3 +1591,105 @@ def test_fabs(val):
15911591
if float_result == 0.0:
15921592
assert not np.signbit(quad_result), f"fabs({val}) should not have negative sign"
15931593
assert not np.signbit(quad_arr_result[0]), f"fabs({val}) should not have negative sign"
1594+
1595+
1596+
@pytest.mark.parametrize("x1,x2", [
1597+
# Basic cases: x1 < 0 -> 0
1598+
("-1.0", "0.5"), ("-5.0", "0.5"), ("-100.0", "0.5"),
1599+
("-1e10", "0.5"), ("-0.1", "0.5"),
1600+
1601+
# Basic cases: x1 == 0 -> x2
1602+
("0.0", "0.5"), ("0.0", "0.0"), ("0.0", "1.0"),
1603+
("-0.0", "0.5"), ("-0.0", "0.0"), ("-0.0", "1.0"),
1604+
1605+
# Basic cases: x1 > 0 -> 1
1606+
("1.0", "0.5"), ("5.0", "0.5"), ("100.0", "0.5"),
1607+
("1e10", "0.5"), ("0.1", "0.5"),
1608+
1609+
# Edge cases with different x2 values
1610+
("0.0", "-1.0"), ("0.0", "2.0"), ("0.0", "100.0"),
1611+
1612+
# Special values: infinity
1613+
("inf", "0.5"), ("-inf", "0.5"),
1614+
("inf", "0.0"), ("-inf", "0.0"),
1615+
1616+
# Special values: NaN (should propagate)
1617+
("nan", "0.5"), ("0.5", "nan"), ("nan", "nan"),
1618+
("-nan", "0.5"), ("0.5", "-nan"),
1619+
1620+
# Edge case: zero x1 with special x2
1621+
("0.0", "inf"), ("0.0", "-inf"), ("0.0", "nan"),
1622+
("-0.0", "inf"), ("-0.0", "-inf"), ("-0.0", "nan"),
1623+
])
1624+
def test_heaviside(x1, x2):
1625+
"""
1626+
Test np.heaviside ufunc for QuadPrecision dtype.
1627+
1628+
heaviside(x1, x2) = 0 if x1 < 0
1629+
x2 if x1 == 0
1630+
1 if x1 > 0
1631+
1632+
This is the Heaviside step function where x2 determines the value at x1=0.
1633+
"""
1634+
quad_x1 = QuadPrecision(x1)
1635+
quad_x2 = QuadPrecision(x2)
1636+
float_x1 = float(x1)
1637+
float_x2 = float(x2)
1638+
1639+
# Test scalar inputs
1640+
quad_result = np.heaviside(quad_x1, quad_x2)
1641+
float_result = np.heaviside(float_x1, float_x2)
1642+
1643+
# Test array inputs
1644+
quad_arr_x1 = np.array([quad_x1], dtype=QuadPrecDType())
1645+
quad_arr_x2 = np.array([quad_x2], dtype=QuadPrecDType())
1646+
quad_arr_result = np.heaviside(quad_arr_x1, quad_arr_x2)
1647+
1648+
# Check results match
1649+
np.testing.assert_array_equal(
1650+
np.array(quad_result).astype(float),
1651+
float_result,
1652+
err_msg=f"Scalar heaviside({x1}, {x2}) mismatch"
1653+
)
1654+
1655+
np.testing.assert_array_equal(
1656+
quad_arr_result.astype(float)[0],
1657+
float_result,
1658+
err_msg=f"Array heaviside({x1}, {x2}) mismatch"
1659+
)
1660+
1661+
# Additional checks for non-NaN results
1662+
if not np.isnan(float_result):
1663+
# Verify the expected value based on x1
1664+
if float_x1 < 0:
1665+
assert float(quad_result) == 0.0, f"Expected 0 for heaviside({x1}, {x2})"
1666+
elif float_x1 == 0.0:
1667+
np.testing.assert_array_equal(
1668+
float(quad_result), float_x2,
1669+
err_msg=f"Expected {x2} for heaviside(0, {x2})"
1670+
)
1671+
else: # float_x1 > 0
1672+
assert float(quad_result) == 1.0, f"Expected 1 for heaviside({x1}, {x2})"
1673+
1674+
1675+
def test_heaviside_broadcast():
1676+
"""Test that heaviside works with broadcasting"""
1677+
x1 = np.array([-1.0, 0.0, 1.0], dtype=QuadPrecDType())
1678+
x2 = QuadPrecision("0.5")
1679+
1680+
result = np.heaviside(x1, x2)
1681+
expected = np.array([0.0, 0.5, 1.0], dtype=np.float64)
1682+
1683+
assert result.dtype.name == "QuadPrecDType128"
1684+
np.testing.assert_array_equal(result.astype(float), expected)
1685+
1686+
# Test with array for both arguments
1687+
x1_arr = np.array([-2.0, -0.0, 0.0, 5.0], dtype=QuadPrecDType())
1688+
x2_arr = np.array([0.5, 0.5, 1.0, 0.5], dtype=QuadPrecDType())
1689+
1690+
result = np.heaviside(x1_arr, x2_arr)
1691+
expected = np.array([0.0, 0.5, 1.0, 1.0], dtype=np.float64)
1692+
1693+
assert result.dtype.name == "QuadPrecDType128"
1694+
np.testing.assert_array_equal(result.astype(float), expected)
1695+

0 commit comments

Comments
 (0)