7
7
#include < utility>
8
8
9
9
#include " ark/data_type.hpp"
10
+ #include " buffer_registry.hpp"
10
11
#include " env.h"
11
- #include " external_buffer_registry.hpp"
12
12
#include " file_io.h"
13
13
#include " logging.hpp"
14
14
#include " model/model_buffer.hpp"
@@ -56,9 +56,7 @@ class CodeGenerator::Impl {
56
56
public:
57
57
Impl (const PlanJson &plan,
58
58
const std::map<size_t , size_t > &buffer_id_to_offset,
59
- const std::map<size_t , std::pair<std::string, void *>>
60
- &buffer_id_to_kernel_arg,
61
- const std::string &name);
59
+ const std::set<size_t > &extra_buffer_ids, const std::string &name);
62
60
~Impl () = default ;
63
61
64
62
private:
@@ -83,7 +81,7 @@ class CodeGenerator::Impl {
83
81
friend class CodeGenerator ;
84
82
85
83
std::map<size_t , size_t > buffer_id_to_offset_;
86
- std::map <size_t , std::pair<std::string, void *>> buffer_id_to_kernel_arg_ ;
84
+ std::set <size_t > extra_buffer_ids_ ;
87
85
std::string name_;
88
86
int rank_;
89
87
int world_size_;
@@ -94,11 +92,10 @@ class CodeGenerator::Impl {
94
92
95
93
CodeGenerator::Impl::Impl (const PlanJson &plan,
96
94
const std::map<size_t , size_t > &buffer_id_to_offset,
97
- const std::map<size_t , std::pair<std::string, void *>>
98
- &buffer_id_to_kernel_arg,
95
+ const std::set<size_t > &extra_buffer_ids,
99
96
const std::string &name)
100
97
: buffer_id_to_offset_(buffer_id_to_offset),
101
- buffer_id_to_kernel_arg_ (buffer_id_to_kernel_arg ),
98
+ extra_buffer_ids_ (extra_buffer_ids ),
102
99
name_(name) {
103
100
rank_ = plan.at (" Rank" );
104
101
world_size_ = plan.at (" WorldSize" );
@@ -191,8 +188,8 @@ CodeGenerator::Impl::Impl(const PlanJson &plan,
191
188
192
189
// Generate the global arguments
193
190
std::stringstream global_args_ss, function_args_ss, arg_types_ss;
194
- for (const auto &[ buf_id, kernel_arg] : buffer_id_to_kernel_arg_ ) {
195
- const auto & arg_name = kernel_arg. first ;
191
+ for (auto buf_id : extra_buffer_ids_ ) {
192
+ std::string arg_name = " _ext_buf_ " + std::to_string (buf_id) ;
196
193
global_args_ss << " void *" << arg_name << " , " ;
197
194
function_args_ss << arg_name << " , " ;
198
195
arg_types_ss << " void *, " ;
@@ -263,6 +260,7 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
263
260
}
264
261
ss << " __device__ void t" << task_json[" Id" ]
265
262
<< " (char *_buf, int _idx, int _spw, @GLOBAL_ARGS@) {\n " ;
263
+ auto &buf_reg = BufferRegistry::get_instance ();
266
264
op_idx = 0 ;
267
265
for (auto &op_json : task_json[" Ops" ]) {
268
266
auto op = ModelOp::deserialize (op_json);
@@ -273,29 +271,36 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
273
271
if (arg.type_name () == " TENSOR" ) {
274
272
auto tns = arg.value <ModelTensorRef>();
275
273
size_t buffer_id = tns->buffer ()->id ();
276
- auto it = buffer_id_to_kernel_arg_.find (buffer_id);
277
- if (it == buffer_id_to_kernel_arg_.end ()) {
278
- size_t buffer_offset = buffer_id_to_offset_.at (buffer_id);
274
+ auto it = buffer_id_to_offset_.find (buffer_id);
275
+ auto buf_info = buf_reg.get (buffer_id);
276
+ if ((buf_info && buf_info->is_external ) ||
277
+ (it == buffer_id_to_offset_.end ())) {
278
+ ss << " (" << tns->data_type ()->type_str () << " *)_ext_buf_"
279
+ << buffer_id;
280
+ } else {
281
+ size_t buffer_offset;
282
+ buffer_offset = it->second ;
279
283
size_t offset = buffer_offset + ModelOffset (tns).value ();
280
284
ss << " (" << tns->data_type ()->type_str () << " *)&_buf["
281
285
<< offset << " ]" ;
282
- } else {
283
- const auto &name = it->second .first ;
284
- ss << " (" << tns->data_type ()->type_str () << " *)" << name;
285
286
}
286
287
} else if (arg.type_name () == " OFFSET" ) {
287
288
auto moff = arg.value <ModelOffset>();
288
289
size_t buffer_id = moff.buffer_id ();
289
- auto it = buffer_id_to_kernel_arg_.find (buffer_id);
290
- if (it == buffer_id_to_kernel_arg_.end ()) {
291
- size_t buffer_offset = buffer_id_to_offset_.at (buffer_id);
290
+ auto buf_info = buf_reg.get (buffer_id);
291
+ if (buf_info && buf_info->is_external ) {
292
+ size_t offset = moff.value ();
293
+ ss << " (uint64_t)((char*)_ext_buf_" << buffer_id << " + "
294
+ << offset << " )" ;
295
+ } else {
296
+ size_t buffer_offset;
297
+ auto it = buffer_id_to_offset_.find (buffer_id);
298
+ if (it == buffer_id_to_offset_.end ()) {
299
+ ERR (InternalError, " buffer ID not found: " , buffer_id);
300
+ }
301
+ buffer_offset = it->second ;
292
302
size_t offset = buffer_offset + moff.value ();
293
303
ss << offset;
294
- } else {
295
- const auto &name = it->second .first ;
296
- size_t offset = moff.value ();
297
- ss << " (uint64_t)((char*)" << name << " + " << offset
298
- << " )" ;
299
304
}
300
305
} else {
301
306
ss << arg.serialize ().begin ().value ();
@@ -496,11 +501,9 @@ std::string CodeGenerator::Impl::sync_process_range(const Range<size_t> &range,
496
501
497
502
CodeGenerator::CodeGenerator (
498
503
const PlanJson &plan, const std::map<size_t , size_t > &buffer_id_to_offset,
499
- const std::map<size_t , std::pair<std::string, void *>>
500
- &buffer_id_to_kernel_arg,
501
- const std::string &name)
502
- : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset,
503
- buffer_id_to_kernel_arg, name)) {}
504
+ const std::set<size_t > &extra_buffer_ids, const std::string &name)
505
+ : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset, extra_buffer_ids,
506
+ name)) {}
504
507
505
508
std::string CodeGenerator::code () const { return impl_->code_ ; }
506
509
0 commit comments