4
4
#include " codegen.hpp"
5
5
6
6
#include < list>
7
+ #include < utility>
7
8
8
9
#include " ark/data_type.hpp"
9
10
#include " env.h"
11
+ #include " external_buffer_registry.hpp"
10
12
#include " file_io.h"
11
13
#include " logging.hpp"
12
14
#include " model/model_buffer.hpp"
13
15
#include " model/model_data_type.hpp"
14
16
#include " model/model_op.hpp"
15
17
#include " model/model_tensor.hpp"
16
- #include " model_buffer_manager.hpp"
17
18
#include " range.hpp"
18
19
#include " utils/utils_math.hpp"
19
20
@@ -55,8 +56,8 @@ class CodeGenerator::Impl {
55
56
public:
56
57
Impl (const PlanJson &plan,
57
58
const std::map<size_t , size_t > &buffer_id_to_offset,
58
- const std::vector< std::string> &external_args,
59
- const std::map< size_t , std::string> &buffer_id_to_name ,
59
+ const std::map< size_t , std::pair<std:: string, void *>>
60
+ &buffer_id_to_kernel_arg ,
60
61
const std::string &name);
61
62
~Impl () = default ;
62
63
@@ -82,8 +83,7 @@ class CodeGenerator::Impl {
82
83
friend class CodeGenerator ;
83
84
84
85
std::map<size_t , size_t > buffer_id_to_offset_;
85
- std::vector<std::string> external_args_;
86
- std::map<size_t , std::string> buffer_id_to_name_;
86
+ std::map<size_t , std::pair<std::string, void *>> buffer_id_to_kernel_arg_;
87
87
std::string name_;
88
88
int rank_;
89
89
int world_size_;
@@ -92,14 +92,13 @@ class CodeGenerator::Impl {
92
92
std::string code_;
93
93
};
94
94
95
- CodeGenerator::Impl::Impl (
96
- const PlanJson &plan, const std::map<size_t , size_t > &buffer_id_to_offset,
97
- const std::vector< std::string> &external_args,
98
- const std::map< size_t , std::string> &buffer_id_to_name ,
99
- const std::string &name)
95
+ CodeGenerator::Impl::Impl (const PlanJson &plan,
96
+ const std::map<size_t , size_t > &buffer_id_to_offset,
97
+ const std::map< size_t , std::pair<std:: string, void *>>
98
+ &buffer_id_to_kernel_arg ,
99
+ const std::string &name)
100
100
: buffer_id_to_offset_(buffer_id_to_offset),
101
- external_args_ (external_args),
102
- buffer_id_to_name_(buffer_id_to_name),
101
+ buffer_id_to_kernel_arg_ (buffer_id_to_kernel_arg),
103
102
name_(name) {
104
103
rank_ = plan.at (" Rank" );
105
104
world_size_ = plan.at (" WorldSize" );
@@ -192,9 +191,10 @@ CodeGenerator::Impl::Impl(
192
191
193
192
// Generate the global arguments
194
193
std::stringstream global_args_ss, function_args_ss, arg_types_ss;
195
- for (const auto &arg : external_args_) {
196
- global_args_ss << " void *" << arg << " , " ;
197
- function_args_ss << arg << " , " ;
194
+ for (const auto &[buf_id, kernel_arg] : buffer_id_to_kernel_arg_) {
195
+ const auto &arg_name = kernel_arg.first ;
196
+ global_args_ss << " void *" << arg_name << " , " ;
197
+ function_args_ss << arg_name << " , " ;
198
198
arg_types_ss << " void *, " ;
199
199
}
200
200
std::string global_args = global_args_ss.str ();
@@ -219,7 +219,7 @@ CodeGenerator::Impl::Impl(
219
219
{" @NUM_WARPS_PER_BLOCK@" , std::to_string (num_warps_per_proc_)},
220
220
{" @DEFINITIONS@" , definitions_ss.str ()},
221
221
{" @BODY@" , body_ss.str ()},
222
- {" @NAME@" , (name_.empty () ? " " : " _ " + name_)},
222
+ // @NAME@ is empty when no name is given; otherwise it is the name
+ // preceded by an underscore. (The condition must test empty(), not
+ // !empty(), or the substitution is always empty.)
+ {" @NAME@" , (name_.empty () ? " " : " _ " + name_)},
223
223
{" @GLOBAL_ARGS@" , global_args},
224
224
{" @FUNCTION_ARGS@" , function_args},
225
225
{" @ARG_TYPES@" , arg_types},
@@ -273,29 +273,28 @@ std::string CodeGenerator::Impl::def_task(const Json &task_json) {
273
273
if (arg.type_name () == " TENSOR" ) {
274
274
auto tns = arg.value <ModelTensorRef>();
275
275
size_t buffer_id = tns->buffer ()->id ();
276
- if (buffer_id_to_name_ .find (buffer_id) ==
277
- buffer_id_to_name_ .end ()) {
276
+ auto it = buffer_id_to_kernel_arg_ .find (buffer_id);
277
+ if (it == buffer_id_to_kernel_arg_ .end ()) {
278
278
size_t buffer_offset = buffer_id_to_offset_.at (buffer_id);
279
279
size_t offset = buffer_offset + ModelOffset (tns).value ();
280
280
ss << " (" << tns->data_type ()->type_str () << " *)&_buf["
281
281
<< offset << " ]" ;
282
282
} else {
283
- ss << " ( " << tns-> data_type ()-> type_str () << " *) "
284
- << buffer_id_to_name_. at (buffer_id) ;
283
+ const auto &name = it-> second . first ;
284
+ ss << " ( " << tns-> data_type ()-> type_str () << " *) " << name ;
285
285
}
286
286
} else if (arg.type_name () == " OFFSET" ) {
287
287
auto moff = arg.value <ModelOffset>();
288
288
size_t buffer_id = moff.buffer_id ();
289
- if (buffer_id_to_name_ .find (buffer_id) ==
290
- buffer_id_to_name_ .end ()) {
289
+ auto it = buffer_id_to_kernel_arg_ .find (buffer_id);
290
+ if (it == buffer_id_to_kernel_arg_ .end ()) {
291
291
size_t buffer_offset = buffer_id_to_offset_.at (buffer_id);
292
292
size_t offset = buffer_offset + moff.value ();
293
293
ss << offset;
294
294
} else {
295
- const std::string &buffer_name =
296
- buffer_id_to_name_.at (buffer_id);
295
+ const auto &name = it->second .first ;
297
296
size_t offset = moff.value ();
298
- ss << " (uint64_t)((char*)" << buffer_name << " + " << offset
297
+ ss << " (uint64_t)((char*)" << name << " + " << offset
299
298
<< " )" ;
300
299
}
301
300
} else {
@@ -372,8 +371,7 @@ std::string CodeGenerator::Impl::resource_group(
372
371
n_slots = total_warps / num_warps_per_task;
373
372
}
374
373
if (n_slots == 0 ) {
375
- ERR (PlanError, " not enough resources for task group: " ,
376
- tg.dump ());
374
+ ERR (PlanError, " not enough resources for task group: " , tg.dump ());
377
375
}
378
376
379
377
size_t task_b = *task_range.begin ();
@@ -498,11 +496,11 @@ std::string CodeGenerator::Impl::sync_process_range(const Range<size_t> &range,
498
496
499
497
// Public constructor. In the post-change ('+') form it forwards `plan`,
// `buffer_id_to_offset`, `buffer_id_to_kernel_arg` and `name` unchanged to a
// newly created, shared `Impl`; the body itself is empty. The '-' lines show
// the previous signature, which passed `external_args` and
// `buffer_id_to_name` separately.
CodeGenerator::CodeGenerator (
500
498
const PlanJson &plan, const std::map<size_t , size_t > &buffer_id_to_offset,
501
- const std::vector< std::string> &external_args,
502
- const std::map< size_t , std::string> &buffer_id_to_name ,
499
+ const std::map< size_t , std::pair<std:: string, void *>>
500
+ &buffer_id_to_kernel_arg ,
503
501
const std::string &name)
504
- : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset, external_args,
505
- buffer_id_to_name , name)) {}
502
+ : impl_(std::make_shared<Impl>(plan, buffer_id_to_offset,
503
+ buffer_id_to_kernel_arg , name)) {}
506
504
507
505
// Returns a copy of the `code_` string held by the implementation object.
std::string CodeGenerator::code() const {
    const std::string &stored = impl_->code_;
    return stored;
}
508
506
0 commit comments