Skip to content

Commit 55f5c51

Browse files
PyTorch tracer (#243)
Co-authored-by: Noli Gerawork <[email protected]>
1 parent 66b78a7 commit 55f5c51

31 files changed

+1298
-1851
lines changed

ark/api/executor.cpp

+573-452
Large diffs are not rendered by default.

ark/api/tensor.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,27 @@ Dims Tensor::torch_strides() const {
6868
return Dims();
6969
}
7070

71+
void *Tensor::data() const {
72+
if (ref_) {
73+
return ref_->data();
74+
}
75+
return nullptr;
76+
}
77+
78+
void *Tensor::data(void *data) {
79+
if (ref_) {
80+
return ref_->data(data);
81+
}
82+
return nullptr;
83+
}
84+
85+
bool Tensor::is_external() const {
86+
if (ref_) {
87+
return ref_->is_external();
88+
}
89+
return false;
90+
}
91+
7192
std::ostream &operator<<(std::ostream &os, const Tensor &tensor) {
7293
if (tensor.is_null()) {
7394
os << "null";

ark/buffer_registry.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include "buffer_registry.hpp"
5+
6+
#include "gpu/gpu_logging.hpp"
7+
8+
namespace ark {
9+
10+
/// Returns a reference to the single registry object, which is held in a
/// function-local static and therefore created on the first call.
BufferRegistry &BufferRegistry::get_instance() {
    static BufferRegistry registry;
    return registry;
}
14+
15+
void BufferRegistry::set(size_t id, void *data, int device_id,
16+
bool is_external) {
17+
if (data != nullptr && device_id < 0) {
18+
gpuPointerAttributes attr;
19+
GLOG(gpuPointerGetAttributes(&attr, data));
20+
device_id = attr.device;
21+
}
22+
buffers_[id] =
23+
std::make_shared<BufferRegistry::Info>(data, device_id, is_external);
24+
}
25+
26+
std::shared_ptr<BufferRegistry::Info> BufferRegistry::get(size_t id) const {
27+
auto it = buffers_.find(id);
28+
if (it != buffers_.end()) {
29+
return it->second;
30+
}
31+
return nullptr;
32+
}
33+
34+
} // namespace ark

ark/buffer_registry.hpp

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#ifndef ARK_BUFFER_REGISTRY_HPP_
5+
#define ARK_BUFFER_REGISTRY_HPP_
6+
7+
#include <memory>
8+
#include <unordered_map>
9+
10+
namespace ark {
11+
12+
/// Manages addresses of all allocated buffers including externally managed
13+
/// buffers.
14+
class BufferRegistry {
15+
public:
16+
struct Info {
17+
Info(void *data, int device_id, bool is_external)
18+
: data(data), device_id(device_id), is_external(is_external) {}
19+
void *data;
20+
int device_id;
21+
bool is_external;
22+
};
23+
24+
~BufferRegistry() {}
25+
26+
static BufferRegistry &get_instance();
27+
28+
void set(size_t id, void *data, int device_id, bool is_external);
29+
30+
std::shared_ptr<Info> get(size_t id) const;
31+
32+
private:
33+
std::unordered_map<size_t, std::shared_ptr<Info>> buffers_;
34+
BufferRegistry() {}
35+
BufferRegistry(const BufferRegistry &) = delete;
36+
BufferRegistry &operator=(const BufferRegistry &) = delete;
37+
};
38+
39+
} // namespace ark
40+
41+
#endif // ARK_BUFFER_REGISTRY_HPP_

ark/codegen.cpp

+32-29
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#include <utility>
88

99
#include "ark/data_type.hpp"
10+
#include "buffer_registry.hpp"
1011
#include "env.h"
11-
#include "external_buffer_registry.hpp"
1212
#include "file_io.h"
1313
#include "logging.hpp"
1414
#include "model/model_buffer.hpp"
@@ -56,9 +56,7 @@ class CodeGenerator::Impl {
5656
public:
5757
Impl(const PlanJson &plan,
5858
const std::map<size_t, size_t> &buffer_id_to_offset,
59-
const std::map<size_t, std::pair<std::string, void *>>
60-
&buffer_id_to_kernel_arg,
61-
const std::string &name);
59+
const std::set<size_t> &extra_buffer_ids, const std::string &name);
6260
~Impl() = default;
6361

6462
private:
@@ -83,7 +81,7 @@ class CodeGenerator::Impl {
8381
friend class CodeGenerator;
8482

8583
std::map<size_t, size_t> buffer_id_to_offset_;
86-
std::map<size_t, std::pair<std::string, void *>> buffer_id_to_kernel_arg_;
84+
std::set<size_t> extra_buffer_ids_;
8785
std::string name_;
8886
int rank_;
8987
int world_size_;
@@ -94,11 +92,10 @@ class CodeGenerator::Impl {
9492

9593
CodeGenerator::Impl::Impl(const PlanJson &plan,
9694
const std::map<size_t, size_t> &buffer_id_to_offset,
97-
const std::map<size_t, std::pair<std::string, void *>>
98-
&buffer_id_to_kernel_arg,
95+
const std::set<size_t> &extra_buffer_ids,
9996
const std::string &name)
10097
: buffer_id_to_offset_(buffer_id_to_offset),
101-
buffer_id_to_kernel_arg_(buffer_id_to_kernel_arg),
98+
extra_buffer_ids_(extra_buffer_ids),
10299
name_(name) {
103100
rank_ = plan.at("Rank");
104101
world_size_ = plan.at("WorldSize");
@@ -191,8 +188,8 @@ CodeGenerator::Impl::Impl(const PlanJson &plan,
191188

192189
// Generate the global arguments
193190
std::stringstream global_args_ss, function_args_ss, arg_types_ss;
194-
for (const auto &[buf_id, kernel_arg] : buffer_id_to_kernel_arg_) {
195-
const auto &arg_name = kernel_arg.first;
191+
for (auto buf_id : extra_buffer_ids_) {
192+
std::string arg_name = "_ext_buf_" + std::to_string(buf_id);
196193
global_args_ss << "void *" << arg_name << ", ";
197194
function_args_ss << arg_name << ", ";
198195
arg_types_ss << "void *, ";
@@ -263,6 +260,7 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
263260
}
264261
ss << "__device__ void t" << task_json["Id"]
265262
<< "(char *_buf, int _idx, int _spw, @GLOBAL_ARGS@) {\n";
263+
auto &buf_reg = BufferRegistry::get_instance();
266264
op_idx = 0;
267265
for (auto &op_json : task_json["Ops"]) {
268266
auto op = ModelOp::deserialize(op_json);
@@ -273,29 +271,36 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
273271
if (arg.type_name() == "TENSOR") {
274272
auto tns = arg.value<ModelTensorRef>();
275273
size_t buffer_id = tns->buffer()->id();
276-
auto it = buffer_id_to_kernel_arg_.find(buffer_id);
277-
if (it == buffer_id_to_kernel_arg_.end()) {
278-
size_t buffer_offset = buffer_id_to_offset_.at(buffer_id);
274+
auto it = buffer_id_to_offset_.find(buffer_id);
275+
auto buf_info = buf_reg.get(buffer_id);
276+
if ((buf_info && buf_info->is_external) ||
277+
(it == buffer_id_to_offset_.end())) {
278+
ss << "(" << tns->data_type()->type_str() << "*)_ext_buf_"
279+
<< buffer_id;
280+
} else {
281+
size_t buffer_offset;
282+
buffer_offset = it->second;
279283
size_t offset = buffer_offset + ModelOffset(tns).value();
280284
ss << "(" << tns->data_type()->type_str() << "*)&_buf["
281285
<< offset << "]";
282-
} else {
283-
const auto &name = it->second.first;
284-
ss << "(" << tns->data_type()->type_str() << "*)" << name;
285286
}
286287
} else if (arg.type_name() == "OFFSET") {
287288
auto moff = arg.value<ModelOffset>();
288289
size_t buffer_id = moff.buffer_id();
289-
auto it = buffer_id_to_kernel_arg_.find(buffer_id);
290-
if (it == buffer_id_to_kernel_arg_.end()) {
291-
size_t buffer_offset = buffer_id_to_offset_.at(buffer_id);
290+
auto buf_info = buf_reg.get(buffer_id);
291+
if (buf_info && buf_info->is_external) {
292+
size_t offset = moff.value();
293+
ss << "(uint64_t)((char*)_ext_buf_" << buffer_id << " + "
294+
<< offset << ")";
295+
} else {
296+
size_t buffer_offset;
297+
auto it = buffer_id_to_offset_.find(buffer_id);
298+
if (it == buffer_id_to_offset_.end()) {
299+
ERR(InternalError, "buffer ID not found: ", buffer_id);
300+
}
301+
buffer_offset = it->second;
292302
size_t offset = buffer_offset + moff.value();
293303
ss << offset;
294-
} else {
295-
const auto &name = it->second.first;
296-
size_t offset = moff.value();
297-
ss << "(uint64_t)((char*)" << name << " + " << offset
298-
<< ")";
299304
}
300305
} else {
301306
ss << arg.serialize().begin().value();
@@ -496,11 +501,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range<size_t> &range,
496501

497502
CodeGenerator::CodeGenerator(
498503
const PlanJson &plan, const std::map<size_t, size_t> &buffer_id_to_offset,
499-
const std::map<size_t, std::pair<std::string, void *>>
500-
&buffer_id_to_kernel_arg,
501-
const std::string &name)
502-
: impl_(std::make_shared<Impl>(plan, buffer_id_to_offset,
503-
buffer_id_to_kernel_arg, name)) {}
504+
const std::set<size_t> &extra_buffer_ids, const std::string &name)
505+
: impl_(std::make_shared<Impl>(plan, buffer_id_to_offset, extra_buffer_ids,
506+
name)) {}
504507

505508
std::string CodeGenerator::code() const { return impl_->code_; }
506509

ark/codegen.hpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
#include <map>
88
#include <memory>
9+
#include <set>
910
#include <string>
10-
#include <utility>
1111

1212
#include "model/model_json.hpp"
1313

@@ -17,8 +17,7 @@ class CodeGenerator {
1717
public:
1818
CodeGenerator(const PlanJson &plan,
1919
const std::map<size_t, size_t> &buffer_id_to_offset,
20-
const std::map<size_t, std::pair<std::string, void *>>
21-
&buffer_id_to_kernel_arg,
20+
const std::set<size_t> &extra_buffer_ids,
2221
const std::string &name = "ark_kernel");
2322

2423
~CodeGenerator() = default;

ark/external_buffer_registry.cpp

-29
This file was deleted.

ark/external_buffer_registry.hpp

-31
This file was deleted.

ark/gpu/gpu_event.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77
#include "gpu/gpu_manager.hpp"
88

99
namespace ark {
10+
1011
class GpuEvent::Impl {
1112
public:
12-
Impl(bool disable_timing);
13+
Impl(int device_id, bool disable_timing);
1314
~Impl();
1415
Impl(const Impl&) = delete;
1516
Impl& operator=(const Impl&) = delete;
1617

18+
int device_id() const { return device_id_; }
1719
void record(gpuStream stream);
1820
float elapsed_msec(const GpuEvent& other) const;
1921

2022
private:
23+
int device_id_;
2124
gpuEvent event_;
2225
};
2326

24-
GpuEvent::Impl::Impl(bool disable_timing) {
27+
GpuEvent::Impl::Impl(int device_id, bool disable_timing)
28+
: device_id_(device_id) {
2529
unsigned int flags = 0;
2630
if (disable_timing) {
2731
flags |= gpuEventDisableTiming;
@@ -41,8 +45,10 @@ float GpuEvent::Impl::elapsed_msec(const GpuEvent& other) const {
4145
return elapsed;
4246
}
4347

44-
GpuEvent::GpuEvent(bool disable_timing)
45-
: pimpl_(std::make_shared<Impl>(disable_timing)) {}
48+
GpuEvent::GpuEvent(int device_id, bool disable_timing)
49+
: pimpl_(std::make_shared<Impl>(device_id, disable_timing)) {}
50+
51+
/// Returns the device ID reported by the implementation object.
int GpuEvent::device_id() const {
    return pimpl_->device_id();
}
4652

4753
/// Forwards `stream` to the implementation object's record().
void GpuEvent::record(gpuStream stream) {
    pimpl_->record(stream);
}
4854

ark/gpu/gpu_event.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ class GpuEvent {
1919
GpuEvent(const GpuEvent &) = delete;
2020
GpuEvent &operator=(const GpuEvent &) = delete;
2121

22+
int device_id() const;
2223
void record(gpuStream stream);
2324
float elapsed_msec(const GpuEvent &other) const;
2425

2526
protected:
2627
friend class GpuManager;
2728

28-
GpuEvent(bool disable_timing = false);
29+
GpuEvent(int device_id, bool disable_timing = false);
2930

3031
private:
3132
class Impl;

ark/gpu/gpu_manager.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ std::shared_ptr<GpuHostMemory> GpuManager::malloc_host(size_t bytes,
118118
}
119119

120120
std::shared_ptr<GpuEvent> GpuManager::create_event(bool disable_timing) const {
121-
return std::shared_ptr<GpuEvent>(new GpuEvent(disable_timing));
121+
return std::shared_ptr<GpuEvent>(
122+
new GpuEvent(pimpl_->gpu_id_, disable_timing));
122123
}
123124

124125
std::shared_ptr<GpuStream> GpuManager::create_stream() const {

0 commit comments

Comments
 (0)