Skip to content

Commit 7c98509

Browse files
authored
Fixes for GPU (#57)
Initial support for running on GPUS * register all MLIR passes and all dialects * isolate ddpt and mlir features; hide all non-ddpt stuff in libs; move all ddpt internals into namespace DDPT * use 32bit in examples and tests * scalars become 0d-tensors of equivalent type in ewbinops
1 parent 2d13221 commit 7c98509

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+639
-395
lines changed

CMakeLists.txt

Lines changed: 27 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
6868
#include(AddIMEX)
6969

7070
# macro for mlir root directory
71-
add_compile_definitions(CMAKE_MLIR_ROOT="${MLIR_ROOT}")
71+
add_compile_definitions(CMAKE_MLIR_ROOT="${MLIR_ROOT}" CMAKE_IMEX_ROOT="${IMEX_ROOT}")
7272

7373
#find_package(OpenMP)
7474

@@ -106,17 +106,18 @@ set(DDPTSrcs
106106
${PROJECT_SOURCE_DIR}/src/Random.cpp
107107
${PROJECT_SOURCE_DIR}/src/ReduceOp.cpp
108108
${PROJECT_SOURCE_DIR}/src/SetGetItem.cpp
109+
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
110+
${PROJECT_SOURCE_DIR}/src/Service.cpp
111+
${PROJECT_SOURCE_DIR}/src/Deferred.cpp
109112
)
110113
set(RTSrcs
114+
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
115+
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
111116
${PROJECT_SOURCE_DIR}/src/CollComm.cpp
112117
${PROJECT_SOURCE_DIR}/src/DDPTensorImpl.cpp
113-
${PROJECT_SOURCE_DIR}/src/Deferred.cpp
114118
${PROJECT_SOURCE_DIR}/src/Factory.cpp
115-
${PROJECT_SOURCE_DIR}/src/Mediator.cpp
116-
${PROJECT_SOURCE_DIR}/src/MPIMediator.cpp
117119
${PROJECT_SOURCE_DIR}/src/Registry.cpp
118-
${PROJECT_SOURCE_DIR}/src/Service.cpp
119-
${PROJECT_SOURCE_DIR}/src/jit/mlir.cpp
120+
${PROJECT_SOURCE_DIR}/src/_deferred.cpp
120121
)
121122
set(IDTRSrcs
122123
${PROJECT_SOURCE_DIR}/src/idtr.cpp
@@ -143,13 +144,17 @@ include_directories(
143144

144145
if (CMAKE_SYSTEM_NAME STREQUAL Linux)
145146
target_link_options(_ddptensor PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/export.txt")
147+
target_link_options(_ddpt_rt PRIVATE "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/export-ddpt_rt.txt")
148+
# target_link_options(idtr PRIVATE "LINKER:-fvisibility=hidden" "LINKER:--exclude-libs,All")
146149
endif()
147150

148151
#compile_options(_ddptensor PRIVATE -fopenmp)
149-
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
150-
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
151-
get_property(mlir_all_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
152-
get_property(imex_all_libs GLOBAL PROPERTY IMEX_ALL_LIBS)
152+
get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
153+
get_property(mlir_conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
154+
get_property(mlir_extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
155+
get_property(mlir_translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
156+
get_property(imex_dialect_libs GLOBAL PROPERTY IMEX_DIALECT_LIBS)
157+
get_property(imex_conversion_libs GLOBAL PROPERTY IMEX_CONVERSION_LIBS)
153158

154159
#llvm_update_compile_flags(_ddpttensor)
155160
target_link_directories(_ddptensor PRIVATE ${CONDA_PREFIX}/lib)
@@ -159,6 +164,18 @@ target_link_directories(idtr PRIVATE ${CONDA_PREFIX}/lib)
159164
target_link_libraries(_ddptensor PRIVATE
160165
# ${MKL_LIBRARIES}
161166
# tbb
167+
${mlir_dialect_libs}
168+
${mlir_conversion_libs}
169+
${mlir_extension_libs}
170+
${mlir_translation_libs}
171+
MLIROptLib
172+
MLIRExecutionEngine
173+
${imex_dialect_libs}
174+
${imex_conversion_libs}
175+
IMEXTransforms
176+
IMEXUtil
177+
LLVMX86CodeGen
178+
LLVMX86AsmParser
162179
_ddpt_rt
163180
idtr
164181
)
@@ -171,71 +188,4 @@ target_link_libraries(_ddpt_rt PRIVATE
171188
${MPI_C_LIBRARIES}
172189
# ${MKL_LIBRARIES}
173190
tbb
174-
IMEXPTensorDialect
175-
IMEXPTensorTransforms
176-
IMEXPTensorToLinalg
177-
IMEXDistDialect
178-
IMEXDistTransforms
179-
IMEXDistToStandard
180-
IMEXDistRuntimeDialect
181-
IMEXDistRuntimeTransforms
182-
IMEXUtil
183-
IMEXTransforms
184-
MLIROptLib
185-
MLIRExecutionEngine
186-
MLIRIR
187-
MLIRAffineDialect
188-
MLIRAffineToStandard
189-
MLIRAffineTransforms
190-
MLIRFuncDialect
191-
MLIRFuncToLLVM
192-
MLIRFuncTransforms
193-
MLIRLinalgDialect
194-
MLIRLinalgTransforms
195-
MLIRLLVMDialect
196-
MLIRMathDialect
197-
MLIRMathToFuncs
198-
MLIRMathToLibm
199-
MLIRMathToLLVM
200-
MLIRMathTransforms
201-
MLIRMemRefDialect
202-
MLIRMemRefToLLVM
203-
MLIRMemRefTransforms
204-
MLIROpenMPDialect
205-
MLIROpenMPToLLVM
206-
MLIROpenMPToLLVMIRTranslation
207-
MLIRReconcileUnrealizedCasts
208-
MLIRSCFDialect
209-
MLIRSCFToOpenMP
210-
MLIRSCFToControlFlow
211-
MLIRSCFTransforms
212-
MLIRShapeDialect
213-
MLIRShapeOpsTransforms
214-
MLIRShapeToStandard
215-
MLIRTosaDialect
216-
MLIRTosaToLinalg
217-
MLIRTosaToTensor
218-
MLIRTensorTransforms
219191
)
220-
# LLVM${LLVM_NATIVE_ARCH}CodeGen
221-
# LLVM${LLVM_NATIVE_ARCH}Desc
222-
# LLVMTarget
223-
# MLIRAnalysis
224-
# MLIRCallInterfaces
225-
# MLIRCastInterfaces
226-
# MLIRGPUToGPURuntimeTransforms
227-
# MLIRGPUToSPIRV
228-
# MLIRLLVMCommonConversion
229-
# MLIRLLVMToLLVMIRTranslation
230-
# MLIRLinalgTransforms
231-
# MLIRMathToLibm
232-
# MLIRMemRef
233-
# MLIRParser
234-
# MLIRPass
235-
# MLIRReconcileUnrealizedCasts
236-
# MLIRSCFToGPU
237-
# MLIRSPIRVSerialization
238-
# MLIRSPIRVTransforms
239-
# MLIRSideEffectInterfaces
240-
# MLIRTargetLLVMIRExport
241-
# MLIRTransforms

examples/stencil-2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def main():
120120
# there is certainly a more Pythonic way to initialize W,
121121
# but it will have no impact on performance.
122122
t0 = timer()
123-
W = np.zeros(((2 * r + 1), (2 * r + 1)), dtype=np.float64)
124-
B = np.zeros((n, n), dtype=np.float64)
123+
W = np.zeros(((2 * r + 1), (2 * r + 1)), dtype=np.float32)
124+
B = np.zeros((n, n), dtype=np.float32)
125125

126126
if pattern == "star":
127127
stencil_size = 4 * r + 1
@@ -143,7 +143,7 @@ def main():
143143
W[r + j, r + j] = +1.0 / (4 * j * r)
144144
W[r - j, r - j] = -1.0 / (4 * j * r)
145145

146-
A = np.numpy.fromfunction(lambda i, j: i + j, (n, n), dtype=np.float64)
146+
A = np.numpy.fromfunction(lambda i, j: i + j, (n, n), dtype=np.float32)
147147

148148
for k in range(iterations + 1):
149149
# start timer after a warmup iteration

export-ddpt_rt.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
global:
3+
*DDPT*;
4+
local: *;
5+
};

scripts/code_gen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#include <pybind11/stl.h>
2020
namespace py = pybind11;
2121
#endif
22+
23+
namespace DDPT {
2224
"""
2325
)
2426

@@ -43,7 +45,7 @@
4345
print(f' .value("{x.upper()}", {x.upper()})')
4446
print(" .export_values();\n")
4547

46-
print("}\n#endif\n")
48+
print("}\n#endif\n} // namespace DDPT")
4749

4850
# Close the file
4951
sys.stdout.close()

src/CollComm.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include "ddptensor/CollComm.hpp"
44

5+
namespace DDPT {
6+
57
void bufferize(DDPTensorImpl::ptr_type a_ptr, void *outPtr) {
68
dispatch(a_ptr->dtype(), a_ptr->data(), [&a_ptr, outPtr](auto *ptr) {
79
auto buff = static_cast<decltype(ptr)>(outPtr);
@@ -153,3 +155,4 @@ std::vector<std::vector<int>> CollComm::map(const PVSlice &n_slc,
153155
#endif // if 0
154156
return {};
155157
}
158+
} // namespace DDPT

src/Creator.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "ddptensor/Factory.hpp"
99
#include "ddptensor/Transceiver.hpp"
1010
#include "ddptensor/TypeDispatch.hpp"
11+
#include "ddptensor/jit/mlir.hpp"
1112

1213
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
1314
#include <imex/Utils/PassUtils.h>
@@ -18,6 +19,8 @@
1819
#include <mlir/Dialect/Tensor/IR/Tensor.h>
1920
#include <mlir/IR/Builders.h>
2021

22+
namespace DDPT {
23+
2124
static const char *FORCE_DIST = getenv("DDPT_FORCE_DIST");
2225

2326
inline uint64_t mkTeam(uint64_t team) {
@@ -36,8 +39,9 @@ struct DeferredFull : public Deferred {
3639
: Deferred(dtype, shape, team, true), _val(val) {}
3740

3841
template <typename T> struct ValAndDType {
39-
static ::mlir::Value op(::mlir::OpBuilder &builder, ::mlir::Location loc,
40-
const PyScalar &val, ::imex::ptensor::DType &dtyp) {
42+
static ::mlir::Value op(::mlir::OpBuilder &builder,
43+
const ::mlir::Location &loc, const PyScalar &val,
44+
::imex::ptensor::DType &dtyp) {
4145
dtyp = jit::PT_DTYPE<T>::value;
4246

4347
if (is_none(val)) {
@@ -54,7 +58,7 @@ struct DeferredFull : public Deferred {
5458
};
5559
};
5660

57-
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
61+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
5862
jit::DepManager &dm) override {
5963
::mlir::SmallVector<::mlir::Value> shp(rank());
6064
for (auto i = 0; i < rank(); ++i) {
@@ -124,7 +128,7 @@ struct DeferredArange : public Deferred {
124128
team, true),
125129
_start(start), _end(end), _step(step) {}
126130

127-
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
131+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
128132
jit::DepManager &dm) override {
129133
// ::mlir::Value
130134
auto transceiver = getTransceiver();
@@ -192,7 +196,7 @@ struct DeferredLinspace : public Deferred {
192196
: Deferred(dtype, {static_cast<shape_type::value_type>(num)}, team, true),
193197
_start(start), _end(end), _num(num), _endpoint(endpoint) {}
194198

195-
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
199+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
196200
jit::DepManager &dm) override {
197201
// ::mlir::Value
198202
auto teamV = team() == 0
@@ -247,14 +251,15 @@ ddptensor *Creator::linspace(double start, double end, uint64_t num,
247251

248252
// ***************************************************************************
249253

254+
extern DTypeId DEFAULT_FLOAT;
255+
extern DTypeId DEFAULT_INT;
256+
250257
std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b,
251-
uint64_t team) {
258+
uint64_t team, DTypeId dtype) {
252259
if (py::isinstance<ddptensor>(b)) {
253260
return {b.cast<ddptensor *>(), false};
254-
} else if (py::isinstance<py::float_>(b)) {
255-
return {Creator::full({}, b, FLOAT64, team), true};
256-
} else if (py::isinstance<py::int_>(b)) {
257-
return {Creator::full({}, b, INT64, team), true};
261+
} else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
262+
return {Creator::full({}, b, dtype, team), true};
258263
}
259264
throw std::runtime_error(
260265
"Invalid right operand to elementwise binary operation");
@@ -263,3 +268,4 @@ std::pair<ddptensor *, bool> Creator::mk_future(const py::object &b,
263268
FACTORY_INIT(DeferredFull, F_FULL);
264269
FACTORY_INIT(DeferredArange, F_ARANGE);
265270
FACTORY_INIT(DeferredLinspace, F_LINSPACE);
271+
} // namespace DDPT

src/DDPTensorImpl.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <algorithm>
1111
#include <iostream>
1212

13+
namespace DDPT {
14+
1315
DDPTensorImpl::DDPTensorImpl(
1416
Transceiver *transceiver, DTypeId dtype, shape_type gShape,
1517
void *l_allocated, void *l_aligned, intptr_t l_offset,
@@ -242,3 +244,4 @@ void DDPTensorImpl::replicate() {
242244
});
243245
set_owner(REPLICATED);
244246
}
247+
} // namespace DDPT

src/Deferred.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "include/ddptensor/Service.hpp"
1414
#include "include/ddptensor/Transceiver.hpp"
1515
#include "include/ddptensor/itac.hpp"
16+
#include "include/ddptensor/jit/mlir.hpp"
1617

1718
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
1819
#include <mlir/Dialect/Func/IR/FuncOps.h>
@@ -24,11 +25,10 @@ namespace py = pybind11;
2425

2526
#include <iostream>
2627

27-
// thread-safe FIFO queue holding deferred objects
28-
static tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred;
28+
namespace DDPT {
2929

30-
// add a deferred object to the queue
31-
void push_runable(Runable::ptr_type &&r) { _deferred.push(std::move(r)); }
30+
// thread-safe FIFO queue holding deferred objects
31+
extern tbb::concurrent_bounded_queue<Runable::ptr_type> _deferred;
3232

3333
// if needed, object/promise is broadcasted to worker processes
3434
// (for controller/worker mode)
@@ -180,3 +180,4 @@ void process_promises() {
180180
}
181181
} while (!done);
182182
}
183+
} // namespace DDPT

src/EWBinOp.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,21 @@
88
#include "ddptensor/Broadcast.hpp"
99
#include "ddptensor/Creator.hpp"
1010
#include "ddptensor/DDPTensorImpl.hpp"
11+
#include "ddptensor/Deferred.hpp"
1112
#include "ddptensor/Factory.hpp"
1213
#include "ddptensor/LinAlgOp.hpp"
1314
#include "ddptensor/Registry.hpp"
1415
#include "ddptensor/TypeDispatch.hpp"
1516
#include "ddptensor/TypePromotion.hpp"
17+
#include "ddptensor/jit/mlir.hpp"
1618

1719
#include <imex/Dialect/Dist/IR/DistOps.h>
1820
#include <imex/Dialect/PTensor/IR/PTensorOps.h>
1921
#include <mlir/Dialect/Shape/IR/Shape.h>
2022
#include <mlir/IR/Builders.h>
2123

24+
namespace DDPT {
25+
2226
// convert id of our binop to id of imex::ptensor binop
2327
static ::imex::ptensor::EWBinOpId ddpt2mlir(const EWBinOpId bop) {
2428
switch (bop) {
@@ -91,7 +95,7 @@ struct DeferredEWBinOp : public Deferred {
9195
: Deferred(a.dtype(), broadcast(a.shape(), b.shape()), a.team(), true),
9296
_a(a.guid()), _b(b.guid()), _op(op) {}
9397

94-
bool generate_mlir(::mlir::OpBuilder &builder, ::mlir::Location loc,
98+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
9599
jit::DepManager &dm) override {
96100
// FIXME the type of the result is based on a only
97101
auto av = dm.getDependent(builder, _a);
@@ -135,13 +139,20 @@ struct DeferredEWBinOp : public Deferred {
135139

136140
ddptensor *EWBinOp::op(EWBinOpId op, const py::object &a, const py::object &b) {
137141
uint64_t teama = 0, teamb = 0;
138-
if (py::isinstance<ddptensor>(a))
139-
teama = a.cast<ddptensor *>()->get().team();
140-
else if (py::isinstance<ddptensor>(b))
141-
teamb = b.cast<ddptensor *>()->get().team();
142-
auto team = teama ? teama : teamb;
143-
auto bb = Creator::mk_future(b, team);
144-
auto aa = Creator::mk_future(a, team);
142+
DTypeId dtypea = DTYPE_LAST, dtypeb = DTYPE_LAST;
143+
144+
if (py::isinstance<ddptensor>(a)) {
145+
auto tmp = a.cast<ddptensor *>()->get();
146+
teama = tmp.team();
147+
dtypea = tmp.dtype();
148+
}
149+
if (py::isinstance<ddptensor>(b)) {
150+
auto tmp = b.cast<ddptensor *>()->get();
151+
teamb = tmp.team();
152+
dtypeb = tmp.dtype();
153+
}
154+
auto aa = Creator::mk_future(a, teamb, dtypeb);
155+
auto bb = Creator::mk_future(b, teama, dtypea);
145156
if (bb.first->get().team() != aa.first->get().team()) {
146157
throw std::runtime_error(
147158
"teams of operands do not match in binary operation");
@@ -159,3 +170,4 @@ ddptensor *EWBinOp::op(EWBinOpId op, const py::object &a, const py::object &b) {
159170
}
160171

161172
FACTORY_INIT(DeferredEWBinOp, F_EWBINOP);
173+
} // namespace DDPT

0 commit comments

Comments
 (0)