@@ -29,6 +29,9 @@ inline SHARPY::id_type get_guid() { return ++_nguid; }
29
29
// Transceiver * theTransceiver = MPITransceiver();
30
30
31
31
template <typename T> T *mr_to_ptr (void *ptr, intptr_t offset) {
32
+ if (!ptr) {
33
+ throw std::runtime_error (" Fatal: cannot handle offset on nullptr" );
34
+ }
32
35
return reinterpret_cast <T *>(ptr) + offset;
33
36
}
34
37
@@ -91,7 +94,7 @@ uint64_t idtr_nprocs(SHARPY::Transceiver *tc) {
91
94
initMPIRuntime ();
92
95
tc = SHARPY::getTransceiver ();
93
96
#endif
94
- return tc->nranks ();
97
+ return tc ? tc ->nranks () : 1 ;
95
98
}
96
99
#pragma weak _idtr_nprocs = idtr_nprocs
97
100
#pragma weak _mlir_ciface__idtr_nprocs = idtr_nprocs
@@ -102,7 +105,7 @@ uint64_t idtr_prank(SHARPY::Transceiver *tc) {
102
105
initMPIRuntime ();
103
106
tc = SHARPY::getTransceiver ();
104
107
#endif
105
- return tc->rank ();
108
+ return tc ? tc ->rank () : 0 ;
106
109
}
107
110
#pragma weak _idtr_prank = idtr_prank
108
111
#pragma weak _mlir_ciface__idtr_prank = idtr_prank
@@ -226,6 +229,9 @@ mlir2sharpy(const ::imex::ndarray::DType dt) {
226
229
void bufferize (void *cptr, SHARPY::DTypeId dtype, const int64_t *sizes,
227
230
const int64_t *strides, const int64_t *tStarts,
228
231
const int64_t *tSizes, uint64_t nd, uint64_t N, void *out) {
232
+ if (!cptr || !sizes || !strides || !tStarts || !tSizes) {
233
+ return ;
234
+ }
229
235
dispatch (dtype, cptr,
230
236
[sizes, strides, tStarts, tSizes, nd, N, out](auto *ptr) {
231
237
auto buff = static_cast <decltype (ptr)>(out);
@@ -252,6 +258,9 @@ void bufferize(void *cptr, SHARPY::DTypeId dtype, const int64_t *sizes,
252
258
void unpack (void *in, SHARPY::DTypeId dtype, const int64_t *sizes,
253
259
const int64_t *strides, const int64_t *tStarts,
254
260
const int64_t *tSizes, uint64_t nd, uint64_t N, void *out) {
261
+ if (!in || !sizes || !strides || !tStarts || !tSizes || !out) {
262
+ return ;
263
+ }
255
264
dispatch (dtype, out, [sizes, strides, tStarts, tSizes, nd, N, in](auto *ptr) {
256
265
auto buff = static_cast <decltype (ptr)>(in);
257
266
@@ -276,6 +285,9 @@ template <typename T>
276
285
void copy_ (uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
277
286
const int64_t *strides, const uint64_t *chunks, uint64_t nd,
278
287
uint64_t start, uint64_t end, T *&out) {
288
+ if (!cptr || !sizes || !strides || !chunks || !out) {
289
+ return ;
290
+ }
279
291
auto stride = strides[d];
280
292
uint64_t sz = sizes[d];
281
293
uint64_t chunk = chunks[d];
@@ -311,6 +323,9 @@ void copy_(uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes,
311
323
void bufferizeN (void *cptr, SHARPY::DTypeId dtype, const int64_t *sizes,
312
324
const int64_t *strides, const int64_t *tStarts,
313
325
const int64_t *tEnds, uint64_t nd, uint64_t N, void *out) {
326
+ if (!cptr || !sizes || !strides || !tStarts || !tEnds || !out) {
327
+ return ;
328
+ }
314
329
std::vector<uint64_t > chunks (nd);
315
330
chunks[nd - 1 ] = 1 ;
316
331
for (uint64_t i = 1 ; i < nd; ++i) {
@@ -377,6 +392,10 @@ void _idtr_reshape(SHARPY::DTypeId sharpytype, int64_t lRank,
377
392
initMPIRuntime ();
378
393
tc = SHARPY::getTransceiver ();
379
394
#endif
395
+ if (!gShapePtr || !lDataPtr || !lShapePtr || !lStridesPtr || !lOffsPtr ||
396
+ !oGShapePtr || !oDataPtr || !oShapePtr || !oOffsPtr || !tc) {
397
+ throw std::runtime_error (" Fatal: received nullptr in reshape" );
398
+ }
380
399
381
400
assert (std::accumulate (&gShapePtr [0 ], &gShapePtr [lRank], 1 ,
382
401
std::multiplies<int64_t >()) ==
@@ -730,6 +749,13 @@ void *_idtr_update_halo(SHARPY::DTypeId sharpytype, int64_t ndims,
730
749
initMPIRuntime ();
731
750
tc = SHARPY::getTransceiver ();
732
751
#endif
752
+
753
+ if (!ownedOff || !ownedShape || !ownedStride || !bbOff || !bbShape ||
754
+ !ownedData || !leftHaloShape || !leftHaloStride || !leftHaloData ||
755
+ !rightHaloShape || !rightHaloStride || !rightHaloData || !tc) {
756
+ throw std::runtime_error (" Fatal error: received nullptr in update_halo." );
757
+ }
758
+
733
759
auto nworkers = tc->nranks ();
734
760
if (nworkers <= 1 || getenv (" SHARPY_SKIP_COMM" ))
735
761
return nullptr ;
@@ -824,6 +850,11 @@ void *_idtr_update_halo(SHARPY::Transceiver *tc, int64_t gShapeRank,
824
850
void *bbShapeDescr, int64_t lHaloRank, void *lHaloDescr,
825
851
int64_t rHaloRank, void *rHaloDescr, int64_t key) {
826
852
853
+ if (!gShapeDescr || !oOffDescr || !oDataDescr || !bbOffDescr ||
854
+ !bbShapeDescr || !lHaloDescr || !rHaloDescr) {
855
+ throw std::runtime_error (" Fatal error: received nullptr in update_halo." );
856
+ }
857
+
827
858
auto sharpytype = SHARPY::DTYPE<T>::value;
828
859
829
860
// Construct unranked memrefs for metadata and data
0 commit comments