Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions quaddtype/numpy_quaddtype/src/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ common_instance(QuadPrecDTypeObject *dtype1, QuadPrecDTypeObject *dtype2)
static PyArray_DTypeMeta *
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
{
// Handle Python abstract dtypes (PyLongDType, PyFloatDType)
// These have type_num = -1
if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) {
Py_INCREF(cls);
return cls;
}

// Promote integer and floating-point types to QuadPrecDType
if (other->type_num >= 0 &&
(PyTypeNum_ISINTEGER(other->type_num) || PyTypeNum_ISFLOAT(other->type_num))) {
Expand All @@ -116,14 +123,21 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
static PyArray_Descr *
quadprec_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), PyObject *obj)
{
if (Py_TYPE(obj) != &QuadPrecision_Type) {
PyErr_SetString(PyExc_TypeError, "Can only store QuadPrecision in a QuadPrecDType array.");
return NULL;
if (Py_TYPE(obj) == &QuadPrecision_Type) {
/* QuadPrecision scalar: use its backend */
QuadPrecisionObject *quad_obj = (QuadPrecisionObject *)obj;
return (PyArray_Descr *)new_quaddtype_instance(quad_obj->backend);
}

QuadPrecisionObject *quad_obj = (QuadPrecisionObject *)obj;

return (PyArray_Descr *)new_quaddtype_instance(quad_obj->backend);

/* For Python int/float/other numeric types: return default descriptor */
/* The casting machinery will handle conversion to QuadPrecision */
if (PyLong_Check(obj) || PyFloat_Check(obj)) {
return (PyArray_Descr *)new_quaddtype_instance(BACKEND_SLEEF);
}

/* Unknown type - ERROR */
PyErr_SetString(PyExc_TypeError, "Can only store QuadPrecision, int, or float in a QuadPrecDType array.");
return NULL;
}

static int
Expand Down Expand Up @@ -261,6 +275,50 @@ quadprec_get_constant(PyArray_Descr *descr, int constant_id, void *ptr)
return 1;
}

/*
* Fill function.
* The buffer already has the first two elements set:
* buffer[0] = start
* buffer[1] = start + step
* We need to fill buffer[2..length-1] with the arithmetic progression.
*/
static int
quadprec_fill(void *buffer, npy_intp length, void *arr_)
{
PyArrayObject *arr = (PyArrayObject *)arr_;
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(arr);
QuadBackendType backend = descr->backend;
npy_intp i;

if (length < 2) {
return 0; // Nothing to fill
}

if (backend == BACKEND_SLEEF) {
Sleef_quad *buf = (Sleef_quad *)buffer;
Sleef_quad start = buf[0];
Sleef_quad delta = Sleef_subq1_u05(buf[1], start); // delta = buf[1] - start

for (i = 2; i < length; ++i) {
// buf[i] = start + i * delta
Sleef_quad i_quad = Sleef_cast_from_doubleq1(i);
Sleef_quad i_delta = Sleef_mulq1_u05(i_quad, delta);
buf[i] = Sleef_addq1_u05(start, i_delta);
}
}
else {
long double *buf = (long double *)buffer;
long double start = buf[0];
long double delta = buf[1] - start;

for (i = 2; i < length; ++i) {
buf[i] = start + i * delta;
}
}

return 0;
}

static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_ensure_canonical, &ensure_canonical},
{NPY_DT_common_instance, &common_instance},
Expand All @@ -270,6 +328,7 @@ static PyType_Slot QuadPrecDType_Slots[] = {
{NPY_DT_getitem, &quadprec_getitem},
{NPY_DT_default_descr, &quadprec_default_descr},
{NPY_DT_get_constant, &quadprec_get_constant},
{NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill},
{0, NULL}};

static PyObject *
Expand Down
71 changes: 71 additions & 0 deletions quaddtype/tests/test_quaddtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,74 @@ def test_hyperbolic_functions(op, val):
if float_result == 0.0:
assert np.signbit(float_result) == np.signbit(
quad_result), f"Zero sign mismatch for {op}({val})"


class TestTypePomotionWithPythonAbstractTypes:
"""Tests for common_dtype handling of Python abstract dtypes (PyLongDType, PyFloatDType)"""

def test_promotion_with_python_int(self):
"""Test that Python int promotes to QuadPrecDType"""
# Create array from Python int
arr = np.array([1, 2, 3], dtype=QuadPrecDType)
assert arr.dtype.name == "QuadPrecDType128"
assert len(arr) == 3
assert float(arr[0]) == 1.0
assert float(arr[1]) == 2.0
assert float(arr[2]) == 3.0

def test_promotion_with_python_float(self):
"""Test that Python float promotes to QuadPrecDType"""
# Create array from Python float
arr = np.array([1.5, 2.7, 3.14], dtype=QuadPrecDType)
assert arr.dtype.name == "QuadPrecDType128"
assert len(arr) == 3
np.testing.assert_allclose(float(arr[0]), 1.5, rtol=1e-15)
np.testing.assert_allclose(float(arr[1]), 2.7, rtol=1e-15)
np.testing.assert_allclose(float(arr[2]), 3.14, rtol=1e-15)

def test_result_dtype_binary_ops_with_python_types(self):
"""Test that binary operations between QuadPrecDType and Python scalars return QuadPrecDType"""
quad_arr = np.array([QuadPrecision("1.0"), QuadPrecision("2.0")])

# Addition with Python int
result = quad_arr + 5
assert result.dtype.name == "QuadPrecDType128"
assert float(result[0]) == 6.0
assert float(result[1]) == 7.0

# Multiplication with Python float
result = quad_arr * 2.5
assert result.dtype.name == "QuadPrecDType128"
np.testing.assert_allclose(float(result[0]), 2.5, rtol=1e-15)
np.testing.assert_allclose(float(result[1]), 5.0, rtol=1e-15)

def test_concatenate_with_python_types(self):
"""Test concatenation handles Python numeric types correctly"""
quad_arr = np.array([QuadPrecision("1.0")])
# This should work if promotion is correct
int_arr = np.array([2], dtype=np.int64)

# The result dtype should be QuadPrecDType
result = np.concatenate([quad_arr, int_arr.astype(QuadPrecDType)])
assert result.dtype.name == "QuadPrecDType128"
assert len(result) == 2


@pytest.mark.parametrize("func,args,expected", [
# arange tests
(np.arange, (0, 10), list(range(10))),
(np.arange, (0, 10, 2), [0, 2, 4, 6, 8]),
(np.arange, (0.0, 5.0, 0.5), [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]),
(np.arange, (10, 0, -1), [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
(np.arange, (-5, 5), list(range(-5, 5))),
# linspace tests
(np.linspace, (0, 10, 11), list(range(11))),
(np.linspace, (0, 1, 5), [0.0, 0.25, 0.5, 0.75, 1.0]),
])
def test_fill_function(func, args, expected):
"""Test quadprec_fill function with arange and linspace"""
arr = func(*args, dtype=QuadPrecDType())
assert arr.dtype.name == "QuadPrecDType128"
assert len(arr) == len(expected)
for i, exp_val in enumerate(expected):
np.testing.assert_allclose(float(arr[i]), float(exp_val), rtol=1e-15, atol=1e-15)
Loading