Skip to content

Commit fb052a8

Browse files
authored
Check array bounds in set/getitem (#96)
1 parent ad440dc commit fb052a8

File tree

9 files changed

+134
-18
lines changed

9 files changed

+134
-18
lines changed

src/SetGetItem.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,8 @@ struct DeferredGather
127127
void *outPtr = nullptr;
128128
py::handle res;
129129
if (!sendonly || !trscvr) {
130-
auto tmp = a_ptr->shape();
131-
std::vector<ssize_t> tmpv(tmp, &tmp[a_ptr->ndims()]);
132-
// numpy treats 0d arrays as empty arrays, not as a scalar as we do
133-
if (tmpv.empty()) {
134-
tmpv.emplace_back(1);
135-
}
136-
res = dispatch<mk_array>(a_ptr->dtype(), std::move(tmpv), outPtr);
130+
std::vector<ssize_t> shp(a_ptr->shape());
131+
res = dispatch<mk_array>(a_ptr->dtype(), std::move(shp), outPtr);
137132
}
138133

139134
gather_array(a_ptr, _root, outPtr);
@@ -309,9 +304,37 @@ struct DeferredGetItem : public Deferred {
309304

310305
// ***************************************************************************
311306

307+
// extract "start", "stop", "step" int attrs from py::slice
308+
std::optional<int> getSliceAttr(const py::slice &slice, const char *name) {
309+
auto obj = getattr(slice, name);
310+
if (py::isinstance<py::none>(obj)) {
311+
return std::nullopt;
312+
} else if (py::isinstance<py::int_>(obj)) {
313+
return std::optional<int>{obj.cast<int>()};
314+
} else {
315+
throw std::invalid_argument("Invalid indices");
316+
}
317+
};
318+
319+
// check that multi-dimensional slice start does not exceed given shape
320+
void validateSlice(const shape_type &shape,
321+
const std::vector<py::slice> &slices) {
322+
auto dim = shape.size();
323+
for (std::size_t i = 0; i < dim; i++) {
324+
auto start = getSliceAttr(slices[i], "start");
325+
if (start && start.value() >= shape[i]) {
326+
std::stringstream msg;
327+
msg << "index " << start.value() << " is out of bounds for axis " << i
328+
<< " with size " << shape[i] << "\n";
329+
throw std::out_of_range(msg.str());
330+
}
331+
}
332+
}
333+
312334
FutureArray *GetItem::__getitem__(const FutureArray &a,
313335
const std::vector<py::slice> &v) {
314336
auto afut = a.get();
337+
validateSlice(afut.shape(), v);
315338
NDSlice slc(v, afut.shape());
316339
return new FutureArray(defer<DeferredGetItem>(afut, std::move(slc)));
317340
}
@@ -328,9 +351,10 @@ GetItem::py_future_type GetItem::gather(const FutureArray &a, rank_type root) {
328351
FutureArray *SetItem::__setitem__(FutureArray &a,
329352
const std::vector<py::slice> &v,
330353
const py::object &b) {
331-
auto bb =
332-
Creator::mk_future(b, a.get().device(), a.get().team(), a.get().dtype());
333-
a.put(defer<DeferredSetItem>(a.get(), bb.first->get(), v));
354+
auto afut = a.get();
355+
validateSlice(afut.shape(), v);
356+
auto bb = Creator::mk_future(b, afut.device(), afut.team(), afut.dtype());
357+
a.put(defer<DeferredSetItem>(afut, bb.first->get(), v));
334358
if (bb.second)
335359
delete bb.first;
336360
return &a;

src/include/sharpy/NDArray.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ class NDArray : public array_i, protected ArrayMeta {
9797
virtual DTypeId dtype() const override { return ArrayMeta::dtype(); }
9898

9999
/// @return array's shape
100-
virtual const int64_t *shape() const override {
101-
return ArrayMeta::shape().data();
100+
virtual const shape_type &shape() const override {
101+
return ArrayMeta::shape();
102102
}
103103

104104
/// @returnnumber of dimensions of array

src/include/sharpy/array_i.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class array_i {
106106
/// @return array's element type
107107
virtual DTypeId dtype() const = 0;
108108
/// @return array's shape
109-
virtual const int64_t *shape() const = 0;
109+
virtual const shape_type &shape() const = 0;
110110
/// @return number of dimensions of array
111111
virtual int ndims() const = 0;
112112
/// @return global number of elements in array

src/jit/mlir.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ ::mlir::Value DepManager::addDependent(::mlir::OpBuilder &builder,
314314
auto typ = getMRType(ndims, impl->owned_data(), elType);
315315
_func.insertArgument(idx, typ, {}, loc);
316316
_inputs.push_back(storeMR(impl->owned_data()));
317-
auto arTyp = getTType(builder, impl->dtype(), impl->device(),
318-
{impl->shape(), ndims});
317+
auto arTyp =
318+
getTType(builder, impl->dtype(), impl->device(), impl->shape());
319319
val = _builder.create<::imex::ndarray::FromMemRefOp>(
320320
loc, arTyp, _func.getArgument(idx));
321321
_lastIn += 1;
@@ -355,7 +355,7 @@ ::mlir::Value DepManager::addDependent(::mlir::OpBuilder &builder,
355355

356356
auto darTyp =
357357
getTType(builder, impl->dtype(), impl->device(), {ownShape, ndims},
358-
impl->team(), {impl->shape(), ndims}, {impl->local_offsets()},
358+
impl->team(), impl->shape(), {impl->local_offsets()},
359359
{lhShape, ndims}, {rhShape, ndims});
360360

361361
val = _builder.create<::imex::dist::InitDistArrayOp>(

test/test_create.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy
2+
import pytest
3+
from utils import device, dtypeIsInt, mpi_dtypes
4+
5+
import sharpy as sp
6+
7+
8+
@pytest.fixture(params=mpi_dtypes)
9+
def datatype(request):
10+
return request.param
11+
12+
13+
@pytest.fixture(params=[(), (6,), (6, 5), (6, 5, 4)])
14+
def shape(request):
15+
return request.param
16+
17+
18+
@pytest.fixture(
19+
params=[
20+
(sp.ones, 1.0),
21+
(sp.zeros, 0.0),
22+
],
23+
)
24+
def creator(request):
25+
return request.param[0], request.param[1]
26+
27+
28+
def test_create_datatypes(creator, datatype):
29+
shape = (6, 4)
30+
func, expected_value = creator
31+
a = func(shape, dtype=datatype, device=device)
32+
assert tuple(a.shape) == shape
33+
assert numpy.allclose(sp.to_numpy(a), [expected_value])
34+
35+
36+
def test_create_shapes(creator, shape):
37+
datatype = sp.int32
38+
func, expected_value = creator
39+
a = func(shape, dtype=datatype, device=device)
40+
assert tuple(a.shape) == shape
41+
assert numpy.allclose(sp.to_numpy(a), [expected_value])
42+
43+
44+
@pytest.mark.parametrize("expected_value", [5.0])
45+
def test_full_shapes(expected_value, shape):
46+
datatype = sp.int32
47+
value = int(expected_value) if dtypeIsInt(datatype) else expected_value
48+
a = sp.full(shape, value, dtype=datatype, device=device)
49+
assert tuple(a.shape) == shape
50+
assert numpy.allclose(sp.to_numpy(a), [expected_value])

test/test_setget.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,27 @@ def test_assign_bcast_scalar(self):
179179
a[:, :] = b
180180
a2 = sp.to_numpy(a)
181181
assert numpy.all(a2 == 2)
182+
183+
@pytest.fixture(
184+
params=[
185+
((6,), (slice(6, None))),
186+
((6,), (slice(7, 10))),
187+
((6, 5), (slice(7, None), slice(None, None))),
188+
((6, 5), (slice(None, None), slice(6, None))),
189+
((6, 5, 4), (slice(None, None), slice(None, None), slice(6, None))),
190+
],
191+
)
192+
def shape_and_slices(self, request):
193+
return request.param[0], request.param[1]
194+
195+
def test_get_invalid_bounds(self, shape_and_slices):
196+
shape, slices = shape_and_slices
197+
with pytest.raises(IndexError):
198+
a = sp.ones(shape, dtype=sp.float64)
199+
b = a[slices] # noqa: F841
200+
201+
def test_set_invalid_bounds(self, shape_and_slices):
202+
shape, slices = shape_and_slices
203+
with pytest.raises(IndexError):
204+
a = sp.ones(shape, dtype=sp.float64)
205+
a[slices] = 1.0

test/test_spmd.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import numpy as np
22
import pytest
3-
from mpi4py import MPI
43
from utils import device
54

65
import sharpy as sp
6+
from mpi4py import MPI
77
from sharpy import _sharpy_cw
88

99

@@ -75,6 +75,12 @@ def test_gather2(self):
7575
assert float(c) == v
7676
MPI.COMM_WORLD.barrier()
7777

78+
def test_gather_0d(self):
79+
a = sp.full((), 5, dtype=sp.int32, device=device)
80+
b = sp.spmd.gather(a)
81+
assert float(b) == 5
82+
MPI.COMM_WORLD.barrier()
83+
7884
@pytest.mark.skip(reason="FIXME reshape")
7985
def test_gather_strided1(self):
8086
a = sp.reshape(

test/test_ui.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def sharpy_script(tmp_path):
1717
import os
1818
device = os.getenv("SHARPY_DEVICE", "")
1919
sp.init(False)
20-
a = a = sp.ones((4,), device=device)
20+
a = sp.ones((4,), device=device)
2121
assert a.size == 4
2222
print("SUCCESS")
2323
sp.fini()"""

test/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,15 @@ def runAndCompare(func, do_gather=True):
3333
sharpy.int8,
3434
sharpy.uint8,
3535
]
36+
37+
38+
def dtypeIsInt(dtype):
39+
mpi_int_types = [
40+
sharpy.int8,
41+
sharpy.int32,
42+
sharpy.int64,
43+
sharpy.uint8,
44+
sharpy.uint32,
45+
sharpy.uint64,
46+
]
47+
return dtype in mpi_int_types

0 commit comments

Comments
 (0)