10
10
#include " ddptensor/TypeDispatch.hpp"
11
11
#include " ddptensor/jit/mlir.hpp"
12
12
13
+ #include < imex/Dialect/Dist/IR/DistOps.h>
13
14
#include < imex/Dialect/PTensor/IR/PTensorOps.h>
14
15
#include < imex/Utils/PassUtils.h>
15
16
@@ -35,8 +36,8 @@ struct DeferredFull : public Deferred {
35
36
36
37
DeferredFull () = default ;
37
38
DeferredFull (const shape_type &shape, PyScalar val, DTypeId dtype,
38
- uint64_t team)
39
- : Deferred(dtype, shape, team, true ), _val(val) {}
39
+ const std::string &device, uint64_t team)
40
+ : Deferred(dtype, shape, device, team ), _val(val) {}
40
41
41
42
template <typename T> struct ValAndDType {
42
43
static ::mlir::Value op (::mlir::OpBuilder &builder,
@@ -67,35 +68,27 @@ struct DeferredFull : public Deferred {
67
68
68
69
::imex::ptensor::DType dtyp;
69
70
::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
70
-
71
- auto transceiver = getTransceiver ();
72
- auto teamV = team () == 0
73
- ? ::mlir::Value ()
74
- : ::imex::createIndex (loc, builder,
75
- reinterpret_cast <uint64_t >(team ()));
76
-
77
- auto rTyp = ::imex::ptensor::PTensorType::get (
78
- shape (), imex::ptensor::toMLIR (builder, dtyp));
79
-
80
- dm.addVal (this ->guid (),
81
- builder.create <::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
82
- val, nullptr , teamV),
83
- [this ](uint64_t rank, void *l_allocated, void *l_aligned,
84
- intptr_t l_offset, const intptr_t *l_sizes,
85
- const intptr_t *l_strides, void *o_allocated,
86
- void *o_aligned, intptr_t o_offset,
87
- const intptr_t *o_sizes, const intptr_t *o_strides,
88
- void *r_allocated, void *r_aligned, intptr_t r_offset,
89
- const intptr_t *r_sizes, const intptr_t *r_strides,
90
- uint64_t *lo_allocated, uint64_t *lo_aligned) {
91
- assert (rank == this ->rank ());
92
- this ->set_value (std::move (mk_tnsr (
93
- reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
94
- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
95
- l_strides, o_allocated, o_aligned, o_offset, o_sizes,
96
- o_strides, r_allocated, r_aligned, r_offset, r_sizes,
97
- r_strides, lo_allocated, lo_aligned)));
98
- });
71
+ auto envs = jit::mkEnvs (builder, rank (), _device, team ());
72
+
73
+ dm.addVal (
74
+ this ->guid (),
75
+ builder.create <::imex::ptensor::CreateOp>(loc, shp, dtyp, val, envs),
76
+ [this ](uint64_t rank, void *l_allocated, void *l_aligned,
77
+ intptr_t l_offset, const intptr_t *l_sizes,
78
+ const intptr_t *l_strides, void *o_allocated, void *o_aligned,
79
+ intptr_t o_offset, const intptr_t *o_sizes,
80
+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
81
+ intptr_t r_offset, const intptr_t *r_sizes,
82
+ const intptr_t *r_strides, uint64_t *lo_allocated,
83
+ uint64_t *lo_aligned) {
84
+ assert (rank == this ->rank ());
85
+ this ->set_value (std::move (
86
+ mk_tnsr (reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
87
+ this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
88
+ l_strides, o_allocated, o_aligned, o_offset, o_sizes,
89
+ o_strides, r_allocated, r_aligned, r_offset, r_sizes,
90
+ r_strides, lo_allocated, lo_aligned)));
91
+ });
99
92
return false ;
100
93
}
101
94
@@ -109,9 +102,11 @@ struct DeferredFull : public Deferred {
109
102
};
110
103
111
104
ddptensor *Creator::full (const shape_type &shape, const py::object &val,
112
- DTypeId dtype, uint64_t team) {
105
+ DTypeId dtype, const std::string &device,
106
+ uint64_t team) {
113
107
auto v = mk_scalar (val, dtype);
114
- return new ddptensor (defer<DeferredFull>(shape, v, dtype, mkTeam (team)));
108
+ return new ddptensor (
109
+ defer<DeferredFull>(shape, v, dtype, device, mkTeam (team)));
115
110
}
116
111
117
112
// ***************************************************************************
@@ -121,33 +116,25 @@ struct DeferredArange : public Deferred {
121
116
122
117
DeferredArange () = default ;
123
118
DeferredArange (uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
124
- uint64_t team)
119
+ const std::string &device, uint64_t team)
125
120
: Deferred(dtype,
126
121
{static_cast <shape_type::value_type>(
127
122
(end - start + step + (step < 0 ? 1 : -1 )) / step)},
128
- team, true ),
123
+ device, team ),
129
124
_start (start), _end(end), _step(step) {}
130
125
131
126
bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
132
127
jit::DepManager &dm) override {
133
- // ::mlir::Value
134
- auto transceiver = getTransceiver ();
135
- auto teamV = team () == 0
136
- ? ::mlir::Value ()
137
- : ::imex::createIndex (loc, builder,
138
- reinterpret_cast <uint64_t >(team ()));
139
-
140
128
auto _num = shape ()[0 ];
141
-
142
129
auto start = ::imex::createFloat (loc, builder, _start);
143
130
auto stop = ::imex::createFloat (loc, builder, _start + _num * _step);
144
131
auto num = ::imex::createIndex (loc, builder, _num);
145
- auto rTyp = :: imex::ptensor::PTensorType::get (
146
- shape (), imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype) ));
132
+ auto dtyp = jit::getPTDType ( dtype ());
133
+ auto envs = jit::mkEnvs (builder, rank (), _device, team ( ));
147
134
148
135
dm.addVal (this ->guid (),
149
- builder.create <::imex::ptensor::LinSpaceOp>(
150
- loc, rTyp, start, stop, num, false , nullptr , teamV ),
136
+ builder.create <::imex::ptensor::LinSpaceOp>(loc, start, stop, num,
137
+ false , dtyp, envs ),
151
138
[this ](uint64_t rank, void *l_allocated, void *l_aligned,
152
139
intptr_t l_offset, const intptr_t *l_sizes,
153
140
const intptr_t *l_strides, void *o_allocated,
@@ -157,7 +144,7 @@ struct DeferredArange : public Deferred {
157
144
const intptr_t *r_sizes, const intptr_t *r_strides,
158
145
uint64_t *lo_allocated, uint64_t *lo_aligned) {
159
146
assert (rank == 1 );
160
- assert (l_strides [0 ] == 1 );
147
+ assert (o_strides [0 ] == 1 );
161
148
this ->set_value (std::move (mk_tnsr (
162
149
reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
163
150
this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
@@ -178,9 +165,10 @@ struct DeferredArange : public Deferred {
178
165
};
179
166
180
167
ddptensor *Creator::arange (uint64_t start, uint64_t end, uint64_t step,
181
- DTypeId dtype, uint64_t team) {
168
+ DTypeId dtype, const std::string &device,
169
+ uint64_t team) {
182
170
return new ddptensor (
183
- defer<DeferredArange>(start, end, step, dtype, mkTeam (team)));
171
+ defer<DeferredArange>(start, end, step, dtype, device, mkTeam (team)));
184
172
}
185
173
186
174
// ***************************************************************************
@@ -192,27 +180,22 @@ struct DeferredLinspace : public Deferred {
192
180
193
181
DeferredLinspace () = default ;
194
182
DeferredLinspace (double start, double end, uint64_t num, bool endpoint,
195
- DTypeId dtype, uint64_t team)
196
- : Deferred(dtype, {static_cast <shape_type::value_type>(num)}, team, true ),
183
+ DTypeId dtype, const std::string &device, uint64_t team)
184
+ : Deferred(dtype, {static_cast <shape_type::value_type>(num)}, device,
185
+ team),
197
186
_start (start), _end(end), _num(num), _endpoint(endpoint) {}
198
187
199
188
bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
200
189
jit::DepManager &dm) override {
201
- // ::mlir::Value
202
- auto teamV = team () == 0
203
- ? ::mlir::Value ()
204
- : ::imex::createIndex (loc, builder,
205
- reinterpret_cast <uint64_t >(team ()));
206
-
207
190
auto start = ::imex::createFloat (loc, builder, _start);
208
191
auto stop = ::imex::createFloat (loc, builder, _end);
209
192
auto num = ::imex::createIndex (loc, builder, _num);
210
- auto rTyp = :: imex::ptensor::PTensorType::get (
211
- shape (), imex::ptensor::toMLIR (builder, jit::getPTDType (_dtype) ));
193
+ auto dtyp = jit::getPTDType ( dtype ());
194
+ auto envs = jit::mkEnvs (builder, rank (), _device, team ( ));
212
195
213
196
dm.addVal (this ->guid (),
214
197
builder.create <::imex::ptensor::LinSpaceOp>(
215
- loc, rTyp, start, stop, num, _endpoint, nullptr , teamV ),
198
+ loc, start, stop, num, _endpoint, dtyp, envs ),
216
199
[this ](uint64_t rank, void *l_allocated, void *l_aligned,
217
200
intptr_t l_offset, const intptr_t *l_sizes,
218
201
const intptr_t *l_strides, void *o_allocated,
@@ -244,9 +227,10 @@ struct DeferredLinspace : public Deferred {
244
227
};
245
228
246
229
ddptensor *Creator::linspace (double start, double end, uint64_t num,
247
- bool endpoint, DTypeId dtype, uint64_t team) {
248
- return new ddptensor (
249
- defer<DeferredLinspace>(start, end, num, endpoint, dtype, mkTeam (team)));
230
+ bool endpoint, DTypeId dtype,
231
+ const std::string &device, uint64_t team) {
232
+ return new ddptensor (defer<DeferredLinspace>(start, end, num, endpoint, dtype,
233
+ device, mkTeam (team)));
250
234
}
251
235
252
236
// ***************************************************************************
@@ -255,11 +239,12 @@ extern DTypeId DEFAULT_FLOAT;
255
239
extern DTypeId DEFAULT_INT;
256
240
257
241
std::pair<ddptensor *, bool > Creator::mk_future (const py::object &b,
242
+ const std::string &device,
258
243
uint64_t team, DTypeId dtype) {
259
244
if (py::isinstance<ddptensor>(b)) {
260
245
return {b.cast <ddptensor *>(), false };
261
246
} else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
262
- return {Creator::full ({}, b, dtype, team), true };
247
+ return {Creator::full ({}, b, dtype, device, team), true };
263
248
}
264
249
throw std::runtime_error (
265
250
" Invalid right operand to elementwise binary operation" );
0 commit comments