3
3
*/
4
4
5
5
#include " sharpy/Creator.hpp"
6
- #include " sharpy/NDArray.hpp"
7
6
#include " sharpy/Deferred.hpp"
8
7
#include " sharpy/Factory.hpp"
8
+ #include " sharpy/NDArray.hpp"
9
9
#include " sharpy/Transceiver.hpp"
10
10
#include " sharpy/TypeDispatch.hpp"
11
11
#include " sharpy/jit/mlir.hpp"
@@ -82,12 +82,11 @@ struct DeferredFull : public Deferred {
82
82
const intptr_t *r_strides, uint64_t *lo_allocated,
83
83
uint64_t *lo_aligned) {
84
84
assert (rank == this ->rank ());
85
- this ->set_value (std::move (
86
- mk_tnsr (reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
87
- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
88
- l_strides, o_allocated, o_aligned, o_offset, o_sizes,
89
- o_strides, r_allocated, r_aligned, r_offset, r_sizes,
90
- r_strides, lo_allocated, lo_aligned)));
85
+ this ->set_value (std::move (mk_tnsr (
86
+ this ->guid (), _dtype, this ->shape (), this ->device (), this ->team (),
87
+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated,
88
+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
89
+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
91
90
});
92
91
return false ;
93
92
}
@@ -102,8 +101,8 @@ struct DeferredFull : public Deferred {
102
101
};
103
102
104
103
FutureArray *Creator::full (const shape_type &shape, const py::object &val,
105
- DTypeId dtype, const std::string &device,
106
- uint64_t team) {
104
+ DTypeId dtype, const std::string &device,
105
+ uint64_t team) {
107
106
auto v = mk_scalar (val, dtype);
108
107
return new FutureArray (
109
108
defer<DeferredFull>(shape, v, dtype, device, mkTeam (team)));
@@ -132,26 +131,26 @@ struct DeferredArange : public Deferred {
132
131
auto dtyp = jit::getPTDType (dtype ());
133
132
auto envs = jit::mkEnvs (builder, rank (), _device, team ());
134
133
135
- dm.addVal (this -> guid (),
136
- builder. create <::imex::ndarray::LinSpaceOp>(loc, start, stop, num ,
137
- false , dtyp, envs) ,
138
- [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
139
- intptr_t l_offset, const intptr_t *l_sizes ,
140
- const intptr_t *l_strides, void *o_allocated ,
141
- void *o_aligned, intptr_t o_offset ,
142
- const intptr_t *o_sizes , const intptr_t *o_strides ,
143
- void *r_allocated, void *r_aligned, intptr_t r_offset ,
144
- const intptr_t *r_sizes , const intptr_t *r_strides ,
145
- uint64_t *lo_allocated , uint64_t *lo_aligned) {
146
- assert (rank == 1 );
147
- assert (o_strides[ 0 ] == 1 );
148
- this -> set_value ( std::move ( mk_tnsr (
149
- reinterpret_cast <Transceiver *>( this ->team ()), _dtype,
150
- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes ,
151
- l_strides, o_allocated, o_aligned, o_offset, o_sizes ,
152
- o_strides, r_allocated, r_aligned, r_offset, r_sizes ,
153
- r_strides, lo_allocated, lo_aligned)));
154
- });
134
+ dm.addVal (
135
+ this -> guid () ,
136
+ builder. create <::imex::ndarray::LinSpaceOp>(loc, start, stop, num ,
137
+ false , dtyp, envs) ,
138
+ [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
139
+ intptr_t l_offset, const intptr_t *l_sizes ,
140
+ const intptr_t *l_strides, void *o_allocated, void *o_aligned ,
141
+ intptr_t o_offset , const intptr_t *o_sizes ,
142
+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
143
+ intptr_t r_offset , const intptr_t *r_sizes ,
144
+ const intptr_t *r_strides , uint64_t *lo_allocated,
145
+ uint64_t *lo_aligned) {
146
+ assert (rank == 1 );
147
+ assert (o_strides[ 0 ] == 1 );
148
+ this ->set_value ( std::move ( mk_tnsr (
149
+ this ->guid (), _dtype, this -> shape (), this -> device (), this -> team () ,
150
+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated ,
151
+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
152
+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
153
+ });
155
154
return false ;
156
155
}
157
156
@@ -165,8 +164,8 @@ struct DeferredArange : public Deferred {
165
164
};
166
165
167
166
FutureArray *Creator::arange (uint64_t start, uint64_t end, uint64_t step,
168
- DTypeId dtype, const std::string &device,
169
- uint64_t team) {
167
+ DTypeId dtype, const std::string &device,
168
+ uint64_t team) {
170
169
return new FutureArray (
171
170
defer<DeferredArange>(start, end, step, dtype, device, mkTeam (team)));
172
171
}
@@ -193,26 +192,26 @@ struct DeferredLinspace : public Deferred {
193
192
auto dtyp = jit::getPTDType (dtype ());
194
193
auto envs = jit::mkEnvs (builder, rank (), _device, team ());
195
194
196
- dm.addVal (this -> guid (),
197
- builder. create <::imex::ndarray::LinSpaceOp>(
198
- loc, start, stop, num, _endpoint, dtyp, envs) ,
199
- [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
200
- intptr_t l_offset, const intptr_t *l_sizes ,
201
- const intptr_t *l_strides, void *o_allocated ,
202
- void *o_aligned, intptr_t o_offset ,
203
- const intptr_t *o_sizes , const intptr_t *o_strides ,
204
- void *r_allocated, void *r_aligned, intptr_t r_offset ,
205
- const intptr_t *r_sizes , const intptr_t *r_strides ,
206
- uint64_t *lo_allocated , uint64_t *lo_aligned) {
207
- assert (rank == 1 );
208
- assert (l_strides[ 0 ] == 1 );
209
- this -> set_value ( std::move ( mk_tnsr (
210
- reinterpret_cast <Transceiver *>( this ->team ()), _dtype,
211
- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes ,
212
- l_strides, o_allocated, o_aligned, o_offset, o_sizes ,
213
- o_strides, r_allocated, r_aligned, r_offset, r_sizes ,
214
- r_strides, lo_allocated, lo_aligned)));
215
- });
195
+ dm.addVal (
196
+ this -> guid (),
197
+ builder. create <::imex::ndarray::LinSpaceOp>( loc, start, stop, num,
198
+ _endpoint, dtyp, envs) ,
199
+ [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
200
+ intptr_t l_offset, const intptr_t *l_sizes ,
201
+ const intptr_t *l_strides, void *o_allocated, void *o_aligned ,
202
+ intptr_t o_offset , const intptr_t *o_sizes ,
203
+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
204
+ intptr_t r_offset , const intptr_t *r_sizes ,
205
+ const intptr_t *r_strides , uint64_t *lo_allocated,
206
+ uint64_t *lo_aligned) {
207
+ assert (rank == 1 );
208
+ assert (l_strides[ 0 ] == 1 );
209
+ this ->set_value ( std::move ( mk_tnsr (
210
+ this ->guid (), _dtype, this -> shape (), this -> device (), this -> team () ,
211
+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated ,
212
+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
213
+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
214
+ });
216
215
return false ;
217
216
}
218
217
@@ -227,10 +226,10 @@ struct DeferredLinspace : public Deferred {
227
226
};
228
227
229
228
FutureArray *Creator::linspace (double start, double end, uint64_t num,
230
- bool endpoint, DTypeId dtype,
231
- const std::string &device, uint64_t team) {
232
- return new FutureArray (defer<DeferredLinspace>(start, end, num, endpoint, dtype,
233
- device, mkTeam (team)));
229
+ bool endpoint, DTypeId dtype,
230
+ const std::string &device, uint64_t team) {
231
+ return new FutureArray (defer<DeferredLinspace>(start, end, num, endpoint,
232
+ dtype, device, mkTeam (team)));
234
233
}
235
234
236
235
// ***************************************************************************
@@ -239,8 +238,9 @@ extern DTypeId DEFAULT_FLOAT;
239
238
extern DTypeId DEFAULT_INT;
240
239
241
240
std::pair<FutureArray *, bool > Creator::mk_future (const py::object &b,
242
- const std::string &device,
243
- uint64_t team, DTypeId dtype) {
241
+ const std::string &device,
242
+ uint64_t team,
243
+ DTypeId dtype) {
244
244
if (py::isinstance<FutureArray>(b)) {
245
245
return {b.cast <FutureArray *>(), false };
246
246
} else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
0 commit comments