Skip to content

Commit 59eea0c

Browse files
committed
Merge pull request numpy#3616 from seberg/buffer-pybuf_simple
BUG: Fix PyBUF_SIMPLE flag to GetBuffer.
2 parents c831298 + 5c2dafb commit 59eea0c

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

numpy/core/src/multiarray/buffer.c

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,23 +604,22 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
604604

605605
/* Check whether we can provide the wanted properties */
606606
if ((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS &&
607-
!PyArray_CHKFLAGS(self, NPY_ARRAY_C_CONTIGUOUS)) {
607+
!PyArray_CHKFLAGS(self, NPY_ARRAY_C_CONTIGUOUS)) {
608608
PyErr_SetString(PyExc_ValueError, "ndarray is not C-contiguous");
609609
goto fail;
610610
}
611611
if ((flags & PyBUF_F_CONTIGUOUS) == PyBUF_F_CONTIGUOUS &&
612-
!PyArray_CHKFLAGS(self, NPY_ARRAY_F_CONTIGUOUS)) {
612+
!PyArray_CHKFLAGS(self, NPY_ARRAY_F_CONTIGUOUS)) {
613613
PyErr_SetString(PyExc_ValueError, "ndarray is not Fortran contiguous");
614614
goto fail;
615615
}
616616
if ((flags & PyBUF_ANY_CONTIGUOUS) == PyBUF_ANY_CONTIGUOUS
617-
&& !PyArray_ISONESEGMENT(self)) {
617+
&& !PyArray_ISONESEGMENT(self)) {
618618
PyErr_SetString(PyExc_ValueError, "ndarray is not contiguous");
619619
goto fail;
620620
}
621621
if ((flags & PyBUF_STRIDES) != PyBUF_STRIDES &&
622-
(flags & PyBUF_ND) == PyBUF_ND &&
623-
!PyArray_CHKFLAGS(self, NPY_ARRAY_C_CONTIGUOUS)) {
622+
!PyArray_CHKFLAGS(self, NPY_ARRAY_C_CONTIGUOUS)) {
624623
/* Non-strided N-dim buffers must be C-contiguous */
625624
PyErr_SetString(PyExc_ValueError, "ndarray is not C-contiguous");
626625
goto fail;

numpy/core/src/multiarray/multiarray_tests.c.src

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,6 @@ inplace_increment(PyObject *dummy, PyObject *args)
522522
}
523523
type_number = PyArray_TYPE(a);
524524

525-
526-
527-
528525
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
529526
if (type_number == type_numbers[i]) {
530527
add_inplace = addition_funcs[i];
@@ -558,6 +555,7 @@ fail:
558555
return NULL;
559556
}
560557

558+
561559
#if !defined(NPY_PY3K)
562560
static PyObject *
563561
int_subclass(PyObject *dummy, PyObject *args)
@@ -581,6 +579,116 @@ int_subclass(PyObject *dummy, PyObject *args)
581579
}
582580
#endif
583581

582+
583+
/*
584+
* Create python string from a FLAG and or the corresponding PyBuf flag
585+
* for the use in get_buffer_info.
586+
*/
587+
#define GET_PYBUF_FLAG(FLAG) \
588+
buf_flag = PyUnicode_FromString(#FLAG); \
589+
flag_matches = PyObject_RichCompareBool(buf_flag, tmp, Py_EQ); \
590+
Py_DECREF(buf_flag); \
591+
if (flag_matches == 1) { \
592+
Py_DECREF(tmp); \
593+
flags |= PyBUF_##FLAG; \
594+
continue; \
595+
} \
596+
else if (flag_matches == -1) { \
597+
Py_DECREF(tmp); \
598+
return NULL; \
599+
}
600+
601+
602+
/*
603+
* Get information for a buffer through PyBuf_GetBuffer with the
604+
* corresponding flags or'ed. Note that the python caller has to
605+
* make sure that or'ing those flags actually makes sense.
606+
* More information should probably be returned for future tests.
607+
*/
608+
static PyObject *
609+
get_buffer_info(PyObject *NPY_UNUSED(self), PyObject *args)
610+
{
611+
PyObject *buffer_obj, *pyflags;
612+
PyObject *tmp, *buf_flag;
613+
Py_buffer buffer;
614+
PyObject *shape, *strides;
615+
Py_ssize_t i, n;
616+
int flag_matches;
617+
int flags = 0;
618+
619+
if (!PyArg_ParseTuple(args, "OO", &buffer_obj, &pyflags)) {
620+
return NULL;
621+
}
622+
623+
n = PySequence_Length(pyflags);
624+
if (n < 0) {
625+
return NULL;
626+
}
627+
628+
for (i=0; i < n; i++) {
629+
tmp = PySequence_GetItem(pyflags, i);
630+
if (tmp == NULL) {
631+
return NULL;
632+
}
633+
634+
GET_PYBUF_FLAG(SIMPLE);
635+
GET_PYBUF_FLAG(WRITABLE);
636+
GET_PYBUF_FLAG(STRIDES);
637+
GET_PYBUF_FLAG(ND);
638+
GET_PYBUF_FLAG(C_CONTIGUOUS);
639+
GET_PYBUF_FLAG(F_CONTIGUOUS);
640+
GET_PYBUF_FLAG(ANY_CONTIGUOUS);
641+
GET_PYBUF_FLAG(INDIRECT);
642+
GET_PYBUF_FLAG(FORMAT);
643+
GET_PYBUF_FLAG(STRIDED);
644+
GET_PYBUF_FLAG(STRIDED_RO);
645+
GET_PYBUF_FLAG(RECORDS);
646+
GET_PYBUF_FLAG(RECORDS_RO);
647+
GET_PYBUF_FLAG(FULL);
648+
GET_PYBUF_FLAG(FULL_RO);
649+
GET_PYBUF_FLAG(CONTIG);
650+
GET_PYBUF_FLAG(CONTIG_RO);
651+
652+
Py_DECREF(tmp);
653+
654+
/* One of the flags must match */
655+
PyErr_SetString(PyExc_ValueError, "invalid flag used.");
656+
return NULL;
657+
}
658+
659+
if (PyObject_GetBuffer(buffer_obj, &buffer, flags) < 0) {
660+
return NULL;
661+
}
662+
663+
if (buffer.shape == NULL) {
664+
Py_INCREF(Py_None);
665+
shape = Py_None;
666+
}
667+
else {
668+
shape = PyTuple_New(buffer.ndim);
669+
for (i=0; i < buffer.ndim; i++) {
670+
PyTuple_SET_ITEM(shape, i, PyLong_FromSsize_t(buffer.shape[i]));
671+
}
672+
}
673+
674+
if (buffer.strides == NULL) {
675+
Py_INCREF(Py_None);
676+
strides = Py_None;
677+
}
678+
else {
679+
strides = PyTuple_New(buffer.ndim);
680+
for (i=0; i < buffer.ndim; i++) {
681+
PyTuple_SET_ITEM(strides, i, PyLong_FromSsize_t(buffer.strides[i]));
682+
}
683+
}
684+
685+
PyBuffer_Release(&buffer);
686+
return Py_BuildValue("(NN)", shape, strides);
687+
}
688+
689+
#undef GET_PYBUF_FLAG
690+
691+
584692
static PyMethodDef Multiarray_TestsMethods[] = {
585693
{"test_neighborhood_iterator",
586694
test_neighborhood_iterator,
@@ -602,6 +710,9 @@ static PyMethodDef Multiarray_TestsMethods[] = {
602710
int_subclass,
603711
METH_VARARGS, NULL},
604712
#endif
713+
{"get_buffer_info",
714+
get_buffer_info,
715+
METH_VARARGS, NULL},
605716
{NULL, NULL, 0, NULL} /* Sentinel */
606717
};
607718

numpy/core/tests/test_multiarray.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from numpy.core.multiarray_tests import (
1414
test_neighborhood_iterator, test_neighborhood_iterator_oob,
1515
test_pydatamem_seteventhook_start, test_pydatamem_seteventhook_end,
16-
test_inplace_increment
16+
test_inplace_increment, get_buffer_info
1717
)
1818
from numpy.testing import (
1919
TestCase, run_module_suite, assert_, assert_raises,
@@ -3368,6 +3368,10 @@ def test_export_endian(self):
33683368
else:
33693369
assert_equal(y.format, '<i')
33703370

3371+
def test_export_flags(self):
3372+
# Check SIMPLE flag, see also gh-3613 (exception should be BufferError)
3373+
assert_raises(ValueError, get_buffer_info, np.arange(5)[::2], ('SIMPLE',))
3374+
33713375
def test_padding(self):
33723376
for j in range(8):
33733377
x = np.array([(1,),(2,)], dtype={'f0': (int, j)})

0 commit comments

Comments
 (0)