
Commit 0ec7428

remove dtensortype (#66)

* updating to new imex without DistTensorType
* adding device support
* temporarily disable imex patches to llvm

1 parent 8f36592, commit 0ec7428

21 files changed: +255 -272 lines

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ jobs:
 git remote add origin https://github.com/llvm/llvm-project || exit 1
 git fetch origin ${{ env.LLVM_SHA }} || exit 1
 git reset --hard FETCH_HEAD || exit 1
-if [ -d "$GITHUB_WORKSPACE/third_party/imex/build_tools/patches" ]; then git apply $GITHUB_WORKSPACE/third_party/imex/build_tools/patches/*.patch; fi
+# FIXME if [ -d "$GITHUB_WORKSPACE/third_party/imex/build_tools/patches" ]; then git apply $GITHUB_WORKSPACE/third_party/imex/build_tools/patches/*.patch; fi
 cd -
 mkdir -p build/llvm-mlir || exit 1
 cd build/llvm-mlir || exit 1

ddptensor/__init__.py

Lines changed: 6 additions & 6 deletions

@@ -69,27 +69,27 @@ def to_numpy(a):
     FUNC = func.upper()
     if func == "full":
         exec(
-            f"{func} = lambda shape, val, dtype, team=1: dtensor(_cdt.Creator.full(shape, val, dtype, team))"
+            f"{func} = lambda shape, val, dtype=float64, device='', team=1: dtensor(_cdt.Creator.full(shape, val, dtype, device, team))"
         )
     elif func == "empty":
         exec(
-            f"{func} = lambda shape, dtype, team=1: dtensor(_cdt.Creator.full(shape, None, dtype, team))"
+            f"{func} = lambda shape, dtype=float64, device='', team=1: dtensor(_cdt.Creator.full(shape, None, dtype, device, team))"
         )
     elif func == "ones":
         exec(
-            f"{func} = lambda shape, dtype, team=1: dtensor(_cdt.Creator.full(shape, 1, dtype, team))"
+            f"{func} = lambda shape, dtype=float64, device='', team=1: dtensor(_cdt.Creator.full(shape, 1, dtype, device, team))"
         )
     elif func == "zeros":
        exec(
-            f"{func} = lambda shape, dtype, team=1: dtensor(_cdt.Creator.full(shape, 0, dtype, team))"
+            f"{func} = lambda shape, dtype=float64, device='', team=1: dtensor(_cdt.Creator.full(shape, 0, dtype, device, team))"
         )
     elif func == "arange":
         exec(
-            f"{func} = lambda start, end, step, dtype, team=1: dtensor(_cdt.Creator.arange(start, end, step, dtype, team))"
+            f"{func} = lambda start, end, step, dtype=int64, device='', team=1: dtensor(_cdt.Creator.arange(start, end, step, dtype, device, team))"
         )
     elif func == "linspace":
         exec(
-            f"{func} = lambda start, end, step, endpoint, dtype, team=1: dtensor(_cdt.Creator.linspace(start, end, step, endpoint, dtype, team))"
+            f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: dtensor(_cdt.Creator.linspace(start, end, step, endpoint, dtype, device, team))"
         )
 
 for func in api.api_categories["ReduceOp"]:
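
With device support added, the generated creator wrappers now default their dtype and accept a device string. A minimal usage sketch of the updated Python API follows; it assumes the package imports as ddptensor, and the device/team values are illustrative placeholders:

    import ddptensor as dt

    # dtype now defaults (float64 for most creators, int64 for arange) and every
    # creator accepts a device string ("" selects the default device) and a team.
    a = dt.ones((4, 4))                      # float64, default device
    b = dt.zeros((4, 4), dtype=dt.float32)   # explicit dtype
    c = dt.arange(0, 10, 2, device="")       # int64 by default
    d = dt.linspace(0.0, 1.0, 11, True)      # endpoint=True, float64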

ddptensor/numpy/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,7 +1,7 @@
 from .. import empty, float32
 
 
-def fromfunction(function, shape, *, dtype=float32, team=1):
-    t = empty(shape, dtype, team)
+def fromfunction(function, shape, *, dtype=float32, device="", team=1):
+    t = empty(shape, dtype=dtype, device=device, team=team)
     t._t.map(function)
     return t
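
For illustration, a short sketch of calling the updated fromfunction, which now forwards dtype, device and team to empty() as keywords (import path and argument values are assumptions, mirroring numpy.fromfunction semantics):

    import ddptensor as dt
    from ddptensor.numpy import fromfunction

    # the callable is mapped over the new tensor via t._t.map(function)
    t = fromfunction(lambda i, j: i + j, (3, 3), dtype=dt.float32, device="", team=1)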

imex_version.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-06523703fa96f9731b9c3c1c7f6fee0914ef9f26
+dd921a5893d2956ddee3d7fecb84612edf15fbbe

src/Creator.cpp

Lines changed: 50 additions & 65 deletions

@@ -10,6 +10,7 @@
 #include "ddptensor/TypeDispatch.hpp"
 #include "ddptensor/jit/mlir.hpp"
 
+#include <imex/Dialect/Dist/IR/DistOps.h>
 #include <imex/Dialect/PTensor/IR/PTensorOps.h>
 #include <imex/Utils/PassUtils.h>
 
@@ -35,8 +36,8 @@ struct DeferredFull : public Deferred {
 
   DeferredFull() = default;
   DeferredFull(const shape_type &shape, PyScalar val, DTypeId dtype,
-               uint64_t team)
-      : Deferred(dtype, shape, team, true), _val(val) {}
+               const std::string &device, uint64_t team)
+      : Deferred(dtype, shape, device, team), _val(val) {}
 
   template <typename T> struct ValAndDType {
     static ::mlir::Value op(::mlir::OpBuilder &builder,
@@ -67,35 +68,27 @@ struct DeferredFull : public Deferred {
 
     ::imex::ptensor::DType dtyp;
     ::mlir::Value val = dispatch<ValAndDType>(_dtype, builder, loc, _val, dtyp);
-
-    auto transceiver = getTransceiver();
-    auto teamV = team() == 0
-                     ? ::mlir::Value()
-                     : ::imex::createIndex(loc, builder,
-                                           reinterpret_cast<uint64_t>(team()));
-
-    auto rTyp = ::imex::ptensor::PTensorType::get(
-        shape(), imex::ptensor::toMLIR(builder, dtyp));
-
-    dm.addVal(this->guid(),
-              builder.create<::imex::ptensor::CreateOp>(loc, rTyp, shp, dtyp,
-                                                         val, nullptr, teamV),
-              [this](uint64_t rank, void *l_allocated, void *l_aligned,
-                     intptr_t l_offset, const intptr_t *l_sizes,
-                     const intptr_t *l_strides, void *o_allocated,
-                     void *o_aligned, intptr_t o_offset,
-                     const intptr_t *o_sizes, const intptr_t *o_strides,
-                     void *r_allocated, void *r_aligned, intptr_t r_offset,
-                     const intptr_t *r_sizes, const intptr_t *r_strides,
-                     uint64_t *lo_allocated, uint64_t *lo_aligned) {
-                assert(rank == this->rank());
-                this->set_value(std::move(mk_tnsr(
-                    reinterpret_cast<Transceiver *>(this->team()), _dtype,
-                    this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
-                    l_strides, o_allocated, o_aligned, o_offset, o_sizes,
-                    o_strides, r_allocated, r_aligned, r_offset, r_sizes,
-                    r_strides, lo_allocated, lo_aligned)));
-              });
+    auto envs = jit::mkEnvs(builder, rank(), _device, team());
+
+    dm.addVal(
+        this->guid(),
+        builder.create<::imex::ptensor::CreateOp>(loc, shp, dtyp, val, envs),
+        [this](uint64_t rank, void *l_allocated, void *l_aligned,
+               intptr_t l_offset, const intptr_t *l_sizes,
+               const intptr_t *l_strides, void *o_allocated, void *o_aligned,
+               intptr_t o_offset, const intptr_t *o_sizes,
+               const intptr_t *o_strides, void *r_allocated, void *r_aligned,
+               intptr_t r_offset, const intptr_t *r_sizes,
+               const intptr_t *r_strides, uint64_t *lo_allocated,
+               uint64_t *lo_aligned) {
+          assert(rank == this->rank());
+          this->set_value(std::move(
+              mk_tnsr(reinterpret_cast<Transceiver *>(this->team()), _dtype,
+                      this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
+                      l_strides, o_allocated, o_aligned, o_offset, o_sizes,
+                      o_strides, r_allocated, r_aligned, r_offset, r_sizes,
+                      r_strides, lo_allocated, lo_aligned)));
+        });
     return false;
   }
 
@@ -109,9 +102,11 @@ struct DeferredFull : public Deferred {
 };
 
 ddptensor *Creator::full(const shape_type &shape, const py::object &val,
-                         DTypeId dtype, uint64_t team) {
+                         DTypeId dtype, const std::string &device,
+                         uint64_t team) {
   auto v = mk_scalar(val, dtype);
-  return new ddptensor(defer<DeferredFull>(shape, v, dtype, mkTeam(team)));
+  return new ddptensor(
+      defer<DeferredFull>(shape, v, dtype, device, mkTeam(team)));
 }
 
 // ***************************************************************************
@@ -121,33 +116,25 @@ struct DeferredArange : public Deferred {
 
   DeferredArange() = default;
   DeferredArange(uint64_t start, uint64_t end, uint64_t step, DTypeId dtype,
-                 uint64_t team)
+                 const std::string &device, uint64_t team)
       : Deferred(dtype,
                  {static_cast<shape_type::value_type>(
                      (end - start + step + (step < 0 ? 1 : -1)) / step)},
-                 team, true),
+                 device, team),
         _start(start), _end(end), _step(step) {}
 
   bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                      jit::DepManager &dm) override {
-    // ::mlir::Value
-    auto transceiver = getTransceiver();
-    auto teamV = team() == 0
-                     ? ::mlir::Value()
-                     : ::imex::createIndex(loc, builder,
-                                           reinterpret_cast<uint64_t>(team()));
-
     auto _num = shape()[0];
-
     auto start = ::imex::createFloat(loc, builder, _start);
     auto stop = ::imex::createFloat(loc, builder, _start + _num * _step);
     auto num = ::imex::createIndex(loc, builder, _num);
-    auto rTyp = ::imex::ptensor::PTensorType::get(
-        shape(), imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
+    auto dtyp = jit::getPTDType(dtype());
+    auto envs = jit::mkEnvs(builder, rank(), _device, team());
 
     dm.addVal(this->guid(),
-              builder.create<::imex::ptensor::LinSpaceOp>(
-                  loc, rTyp, start, stop, num, false, nullptr, teamV),
+              builder.create<::imex::ptensor::LinSpaceOp>(loc, start, stop, num,
+                                                          false, dtyp, envs),
               [this](uint64_t rank, void *l_allocated, void *l_aligned,
                      intptr_t l_offset, const intptr_t *l_sizes,
                      const intptr_t *l_strides, void *o_allocated,
@@ -157,7 +144,7 @@ struct DeferredArange : public Deferred {
                      const intptr_t *r_sizes, const intptr_t *r_strides,
                      uint64_t *lo_allocated, uint64_t *lo_aligned) {
                 assert(rank == 1);
-                assert(l_strides[0] == 1);
+                assert(o_strides[0] == 1);
                 this->set_value(std::move(mk_tnsr(
                     reinterpret_cast<Transceiver *>(this->team()), _dtype,
                     this->shape(), l_allocated, l_aligned, l_offset, l_sizes,
@@ -178,9 +165,10 @@ struct DeferredArange : public Deferred {
 };
 
 ddptensor *Creator::arange(uint64_t start, uint64_t end, uint64_t step,
-                           DTypeId dtype, uint64_t team) {
+                           DTypeId dtype, const std::string &device,
+                           uint64_t team) {
   return new ddptensor(
-      defer<DeferredArange>(start, end, step, dtype, mkTeam(team)));
+      defer<DeferredArange>(start, end, step, dtype, device, mkTeam(team)));
 }
 
 // ***************************************************************************
@@ -192,27 +180,22 @@ struct DeferredLinspace : public Deferred {
 
   DeferredLinspace() = default;
   DeferredLinspace(double start, double end, uint64_t num, bool endpoint,
-                   DTypeId dtype, uint64_t team)
-      : Deferred(dtype, {static_cast<shape_type::value_type>(num)}, team, true),
+                   DTypeId dtype, const std::string &device, uint64_t team)
+      : Deferred(dtype, {static_cast<shape_type::value_type>(num)}, device,
+                 team),
         _start(start), _end(end), _num(num), _endpoint(endpoint) {}
 
   bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                      jit::DepManager &dm) override {
-    // ::mlir::Value
-    auto teamV = team() == 0
-                     ? ::mlir::Value()
-                     : ::imex::createIndex(loc, builder,
-                                           reinterpret_cast<uint64_t>(team()));
-
     auto start = ::imex::createFloat(loc, builder, _start);
     auto stop = ::imex::createFloat(loc, builder, _end);
     auto num = ::imex::createIndex(loc, builder, _num);
-    auto rTyp = ::imex::ptensor::PTensorType::get(
-        shape(), imex::ptensor::toMLIR(builder, jit::getPTDType(_dtype)));
+    auto dtyp = jit::getPTDType(dtype());
+    auto envs = jit::mkEnvs(builder, rank(), _device, team());
 
     dm.addVal(this->guid(),
               builder.create<::imex::ptensor::LinSpaceOp>(
-                  loc, rTyp, start, stop, num, _endpoint, nullptr, teamV),
+                  loc, start, stop, num, _endpoint, dtyp, envs),
               [this](uint64_t rank, void *l_allocated, void *l_aligned,
                      intptr_t l_offset, const intptr_t *l_sizes,
                      const intptr_t *l_strides, void *o_allocated,
@@ -244,9 +227,10 @@ struct DeferredLinspace : public Deferred {
 };
 
 ddptensor *Creator::linspace(double start, double end, uint64_t num,
-                             bool endpoint, DTypeId dtype, uint64_t team) {
-  return new ddptensor(
-      defer<DeferredLinspace>(start, end, num, endpoint, dtype, mkTeam(team)));
+                             bool endpoint, DTypeId dtype,
+                             const std::string &device, uint64_t team) {
+  return new ddptensor(defer<DeferredLinspace>(start, end, num, endpoint, dtype,
+                                               device, mkTeam(team)));
 }
 
 // ***************************************************************************
@@ -255,11 +239,12 @@ extern DTypeId DEFAULT_FLOAT;
 extern DTypeId DEFAULT_INT;
 
 std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b,
+                                                const std::string &device,
                                                 uint64_t team, DTypeId dtype) {
   if (py::isinstance<ddptensor>(b)) {
     return {b.cast<ddptensor *>(), false};
   } else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
-    return {Creator::full({}, b, dtype, team), true};
+    return {Creator::full({}, b, dtype, device, team), true};
   }
   throw std::runtime_error(
       "Invalid right operand to elementwise binary operation");

src/DDPTensorImpl.cpp

Lines changed: 20 additions & 27 deletions

@@ -22,12 +22,12 @@ DDPTensorImpl::DDPTensorImpl(
     uint64_t *lo_allocated, uint64_t *lo_aligned, rank_type owner)
     : _owner(owner), _transceiver(transceiver), _gShape(gShape),
       _lo_allocated(lo_allocated), _lo_aligned(lo_aligned),
-      _lhsHalo(gShape.size(), l_allocated, l_aligned, l_offset, l_sizes,
-               l_strides),
-      _lData(gShape.size(), o_allocated, o_aligned, o_offset, o_sizes,
-             o_strides),
-      _rhsHalo(gShape.size(), r_allocated, r_aligned, r_offset, r_sizes,
-               r_strides),
+      _lhsHalo(l_allocated ? gShape.size() : 0, l_allocated, l_aligned,
+               l_offset, l_sizes, l_strides),
+      _lData(o_allocated ? gShape.size() : 0, o_allocated, o_aligned, o_offset,
+             o_sizes, o_strides),
+      _rhsHalo(r_allocated ? gShape.size() : 0, r_allocated, r_aligned,
+               r_offset, r_sizes, r_strides),
       _dtype(dtype) {
   if (ndims() == 0) {
     _owner = REPLICATED;
@@ -183,7 +183,7 @@ int64_t DDPTensorImpl::__int__() const {
 void DDPTensorImpl::add_to_args(std::vector<void *> &args) {
   int ndims = this->ndims();
   auto storeMR = [ndims](DynMemRef &mr) -> intptr_t * {
-    intptr_t *buff = new intptr_t[dtensor_sz(ndims)];
+    intptr_t *buff = new intptr_t[memref_sz(ndims)];
     buff[0] = reinterpret_cast<intptr_t>(mr._allocated);
     buff[1] = reinterpret_cast<intptr_t>(mr._aligned);
     buff[2] = static_cast<intptr_t>(mr._offset);
@@ -192,29 +192,22 @@ void DDPTensorImpl::add_to_args(std::vector<void *> &args) {
     return buff;
   }; // FIXME memory leak?
 
-  if (_transceiver == nullptr) {
+  if (_transceiver == nullptr || ndims == 0) {
     // no-dist-mode
     args.push_back(storeMR(_lData));
   } else {
-    // transceiver/team first
-    // args.push_back(_transceiver);
-    // local tensor first
-    if (ndims > 0) {
-      args.push_back(storeMR(_lhsHalo));
-      args.push_back(storeMR(_lData));
-      args.push_back(storeMR(_rhsHalo));
-      assert(5 == memref_sz(1));
-      // local offsets last
-      auto buff = new intptr_t[dtensor_sz(1)];
-      buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
-      buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
-      buff[2] = 0;
-      buff[3] = ndims;
-      buff[4] = 1;
-      args.push_back(buff);
-    } else {
-      args.push_back(storeMR(_lData));
-    }
+    args.push_back(storeMR(_lhsHalo));
+    args.push_back(storeMR(_lData));
+    args.push_back(storeMR(_rhsHalo));
+    // local offsets last
+    auto buff = new intptr_t[memref_sz(1)];
+    assert(5 == memref_sz(1));
+    buff[0] = reinterpret_cast<intptr_t>(_lo_allocated);
+    buff[1] = reinterpret_cast<intptr_t>(_lo_aligned);
+    buff[2] = 0;
+    buff[3] = ndims;
+    buff[4] = 1;
+    args.push_back(buff);
   }
 }

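add_to_args packs each memref into a flat intptr_t buffer, and the assert(5 == memref_sz(1)) above pins the rank-1 descriptor to five slots. A minimal sketch of that layout, assuming memref_sz follows the usual MLIR descriptor size (the helper here is illustrative, not the project's definition):

    def memref_sz(ndims):
        # allocated ptr, aligned ptr, offset, then sizes[ndims] and strides[ndims],
        # matching buff[0..4] filled for the rank-1 local-offsets buffer above
        return 3 + 2 * ndims

    assert memref_sz(1) == 5
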
src/Deferred.cpp

Lines changed: 2 additions & 2 deletions

@@ -43,8 +43,8 @@ Deferred::future_type Deferred::get_future() {
                            _guid,
                            _dtype,
                            _shape,
-                           _team,
-                           _balanced};
+                           _device,
+                           _team};
 }
 
 // defer a tensor-producing computation by adding it to the queue.
