From 97ab2dfc1ec59d29b13d283b325a79d14cb1ada5 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Fri, 27 Sep 2024 04:26:56 +0000
Subject: [PATCH] multiple resource groups

Let the planner emit multiple resource groups per processor group: when
nested ProcessorRange contexts are given, one processor group is opened
for the outermost range and each op appends a resource group that uses
its innermost range. To expose the full context stack to the planner,
node contexts now store the dumped stack (Json) instead of only the
latest value per key.
---
 ark/api/planner.cpp            | 54 +++++++++++++++++--------
 ark/model/model_graph_impl.cpp | 14 +------
 ark/model/model_graph_impl.hpp |  2 -
 ark/model/model_node.hpp       |  2 +-
 examples/llama/model.py        | 72 ++++++++++++++++++----------
 5 files changed, 78 insertions(+), 66 deletions(-)

diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp
index 506dcaff..56e0b5b0 100644
--- a/ark/api/planner.cpp
+++ b/ark/api/planner.cpp
@@ -138,17 +138,25 @@ std::string Planner::Impl::plan(bool pretty) const {
 
     auto get_context = [&](const ModelNodeRef &node,
                            const std::string &key) -> Json {
-        if (node->context.find(key) != node->context.end()) {
+        try {
             return node->context.at(key);
+        } catch (const Json::out_of_range &e) {
         }
         return Json();
     };
 
+    auto get_latest_context = [&](const ModelNodeRef &node,
+                                  const std::string &key) -> Json {
+        auto ctx = get_context(node, key);
+        if (ctx.empty()) return Json();
+        return ctx.back();
+    };
+
     for (const auto &node : model_.nodes()) {
         const auto &op = node->op;
         if (op->is_virtual()) continue;
 
-        auto ctx_config = get_context(node, "Config");
+        auto ctx_config = get_latest_context(node, "Config");
 
         Json config;
         if (!ctx_config.empty()) {
@@ -223,8 +231,8 @@ std::string Planner::Impl::plan(bool pretty) const {
         }
 
         size_t granularity = config.value("Granularity", 1);
-        auto ctx_id = get_context(node, "Id");
-        auto ctx_sync = get_context(node, "Sync");
+        auto ctx_id = get_latest_context(node, "Id");
+        auto ctx_sync = get_latest_context(node, "Sync");
         int id = ctx_id.empty() ? -1 : ctx_id.get<int>();
         bool sync = ctx_sync.empty() ? true : ctx_sync.get<bool>();
         if (id == prev_ctx_id && !sync) {
@@ -245,24 +253,31 @@ std::string Planner::Impl::plan(bool pretty) const {
             task_info["Ops"][0]["Config"] = config;
             task_infos.push_back(task_info);
 
-            auto ctx_processor_range = get_context(node, "ProcessorRange");
-            auto ctx_warp_range = get_context(node, "WarpRange");
-            auto ctx_sram_range = get_context(node, "SramRange");
+            auto ctx_processor_range_list = get_context(node, "ProcessorRange");
+            auto ctx_warp_range = get_latest_context(node, "WarpRange");
+            auto ctx_sram_range = get_latest_context(node, "SramRange");
 
             Json processor_group;
-            if (!ctx_processor_range.empty()) {
+            Json resource_group;
+            bool new_processor_group = true;
+            if (ctx_processor_range_list.empty()) {
+                size_t num_processors = std::min(num_sm, num_tasks);
+                processor_group["ProcessorRange"] = {0, num_processors};
+                resource_group["ProcessorRange"] = {0, num_processors};
+                max_processor_id = std::max(max_processor_id, num_processors);
+            } else if (ctx_processor_range_list.size() == 1 ||
+                       (id != prev_ctx_id)) {
+                auto &ctx_processor_range = ctx_processor_range_list[0];
                 processor_group["ProcessorRange"] = ctx_processor_range;
+                resource_group["ProcessorRange"] = ctx_processor_range;
                 max_processor_id = std::max(
                     max_processor_id, ctx_processor_range[1].get<size_t>());
             } else {
-                size_t num_processors = std::min(num_sm, num_tasks);
-                processor_group["ProcessorRange"] = {0, num_processors};
-                max_processor_id = std::max(max_processor_id, num_processors);
+                new_processor_group = false;
+                resource_group["ProcessorRange"] =
+                    ctx_processor_range_list.back();
             }
 
-            Json resource_group;
-            resource_group["ProcessorRange"] =
-                processor_group["ProcessorRange"];
             if (!ctx_warp_range.empty()) {
                 resource_group["WarpRange"] = ctx_warp_range;
                 max_warp_id =
@@ -280,9 +295,14 @@ std::string Planner::Impl::plan(bool pretty) const {
                  {"TaskRange", {0, num_tasks}},
                  {"Granularity", granularity}}};
 
-            processor_group["ResourceGroups"] = Json::array();
-            processor_group["ResourceGroups"].push_back(resource_group);
-            processor_groups.push_back(processor_group);
+            if (new_processor_group) {
+                processor_group["ResourceGroups"] = Json::array();
+                processor_group["ResourceGroups"].push_back(resource_group);
+                processor_groups.push_back(processor_group);
+            } else {
+                processor_groups.back()["ResourceGroups"].push_back(
+                    resource_group);
+            }
         }
         prev_ctx_id = id;
         first_op = false;
diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp
index b7717ecd..7c72a7dd 100644
--- a/ark/model/model_graph_impl.cpp
+++ b/ark/model/model_graph_impl.cpp
@@ -52,18 +52,8 @@ Json ModelGraphContextStack::get(const std::string &key) const {
     return Json();
 }
 
-std::map<std::string, Json> ModelGraphContextStack::get_all() const {
-    std::map<std::string, Json> cur;
-    for (const auto &pair : this->storage_) {
-        if (!pair.second.empty()) {
-            cur[pair.first] = *pair.second.back();
-        }
-    }
-    return cur;
-}
-
 Json ModelGraphContextStack::dump() const {
-    Json j;
+    Json j = Json::object();
     for (const auto &pair : this->storage_) {
         j[pair.first] = Json::array();
         for (const auto &value : pair.second) {
@@ -227,7 +217,7 @@ ModelNodeRef ModelGraph::Impl::add_op(ModelOpRef op) {
         producer->consumers.push_back(node);
     }
 
-    node->context = context_stack_->get_all();
+    node->context = context_stack_->dump();
 
     nodes_.push_back(node);
     return node;
diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp
index 5cd60d03..b9646d05 100644
--- a/ark/model/model_graph_impl.hpp
+++ b/ark/model/model_graph_impl.hpp
@@ -38,8 +38,6 @@ class ModelGraphContextStack {
 
     Json get(const std::string &key) const;
 
-    std::map<std::string, Json> get_all() const;
-
     Json dump() const;
 };
 
diff --git a/ark/model/model_node.hpp b/ark/model/model_node.hpp
index ca97f454..43787567 100644
--- a/ark/model/model_node.hpp
+++ b/ark/model/model_node.hpp
@@ -28,7 +28,7 @@ class ModelNode {
     UniqueList<ModelNodeRef> producers;
 
     /// Graph context of this node.
-    std::map<std::string, Json> context;
+    Json context;
 };
 
 }  // namespace ark
diff --git a/examples/llama/model.py b/examples/llama/model.py
index f80d68e5..57ff7d9b 100644
--- a/examples/llama/model.py
+++ b/examples/llama/model.py
@@ -495,27 +495,29 @@ def forward(
         scores = ark.tensor([bsz, self.n_local_heads, seqlen, seqlen], dtype=self.dtype)
         scores_shards = ark.sharding(scores, axis=1, dim_per_shard=1)
         results = []
-        with Context(
-            warp_range=[0, 8],
-            sram_range=[0, 49344],
-            sync=False,
-            config={
-                "NumWarps": 4,
-                "Granularity": 2,
-                "SramBytes": 24672,
-                "Tile": [256, 128],
-            },
-        ):
+        with Context(processor_range=[0, 304]):
             for i in range(len(scores_shards)):
-                xq_shard_reshaped = ark.reshape(xq_shards[i], [bsz, 1, seqlen, self.head_dim])
-                keys_shard_reshaped = ark.reshape(keys_shards[i], [bsz, 1, seqlen, self.head_dim])
-                scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
-                res = ark.matmul(xq_shard_reshaped, keys_shard_reshaped, scores_shard_reshaped, transpose_other=True)
-                res = ark.mul(res, 1.0 / math.sqrt(self.head_dim), res)
-                if mask is not None:
-                    res = ark.add(res, mask, res)
+                with Context(
+                    processor_range=[i * 8, (i + 1) * 8],
+                    warp_range=[0, 8],
+                    sram_range=[0, 49344],
+                    sync=False,
+                    config={
+                        "NumWarps": 4,
+                        "Granularity": 2,
+                        "SramBytes": 24672,
+                        "Tile": [256, 128],
+                    },
+                ):
+                    xq_shard_reshaped = ark.reshape(xq_shards[i], [bsz, 1, seqlen, self.head_dim])
+                    keys_shard_reshaped = ark.reshape(keys_shards[i], [bsz, 1, seqlen, self.head_dim])
+                    scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
+                    res = ark.matmul(xq_shard_reshaped, keys_shard_reshaped, scores_shard_reshaped, transpose_other=True)
+                    res = ark.mul(res, 1.0 / math.sqrt(self.head_dim), res)
+                    if mask is not None:
+                        res = ark.add(res, mask, res)
                 results.append(res)
-        scores = ark.identity(scores, deps=results)
+            scores = ark.identity(scores, deps=results)
 
         def softmax(scores):
             with Context(
@@ -546,22 +548,24 @@ def softmax(scores):
 
         output_shards = ark.sharding(output, axis=2, dim_per_shard=1)
         results = []
-        with Context(
-            warp_range=[0, 4],
-            sram_range=[0, 24672],
-            sync=False,
-            config={
-                "NumWarps": 4,
-                "SramBytes": 24672,
-                "Tile": [256, 128],
-            },
-        ):
+        with Context(processor_range=[0, 304]):
             for i in range(len(output_shards)):
-                values_shard_reshaped = ark.reshape(values_shards[i], [bsz, 1, seqlen, self.head_dim])
-                scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
-                output_shard_reshaped = ark.reshape(output_shards[i], [bsz, 1, seqlen, self.head_dim])
-                res = ark.matmul(scores_shard_reshaped, values_shard_reshaped, output_shard_reshaped)
-                results.append(res)
+                with Context(
+                    processor_range=[i * 8, (i + 1) * 8],
+                    warp_range=[0, 4],
+                    sram_range=[0, 24672],
+                    sync=False,
+                    config={
+                        "NumWarps": 4,
+                        "SramBytes": 24672,
+                        "Tile": [256, 128],
+                    },
+                ):
+                    values_shard_reshaped = ark.reshape(values_shards[i], [bsz, 1, seqlen, self.head_dim])
+                    scores_shard_reshaped = ark.reshape(scores_shards[i], [bsz, 1, seqlen, seqlen])
+                    output_shard_reshaped = ark.reshape(output_shards[i], [bsz, 1, seqlen, self.head_dim])
+                    res = ark.matmul(scores_shard_reshaped, values_shard_reshaped, output_shard_reshaped)
+                    results.append(res)
         output = ark.identity(output, deps=results)
         output = ark.reshape(
             output, [bsz, seqlen, self.head_dim * self.n_local_heads]
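
Below is a minimal sketch of the usage pattern this change enables, distilled
from the examples/llama/model.py hunks above. Shapes, ranges, and the absence
of an explicit dtype are illustrative, and `Context` is assumed importable
from `ark` as the example uses it. The outer context pins the processor
group's SM range; each inner context contributes one resource group within
that group, so independent shards run on disjoint SM ranges instead of each
opening its own processor group.

import ark
from ark import Context  # assumed import path, as used in examples/llama/model.py

# Eight independent shard matmuls, eight SMs each (illustrative shapes).
a = ark.tensor([1, 8, 256, 128])
b = ark.tensor([1, 8, 256, 128])
c = ark.tensor([1, 8, 256, 256])
a_shards = ark.sharding(a, axis=1, dim_per_shard=1)
b_shards = ark.sharding(b, axis=1, dim_per_shard=1)
c_shards = ark.sharding(c, axis=1, dim_per_shard=1)

results = []
with Context(processor_range=[0, 304]):  # outermost range: the processor group
    for i in range(len(c_shards)):
        with Context(
            processor_range=[i * 8, (i + 1) * 8],  # innermost range: resource group i
            sync=False,  # shards are independent; no sync between them
        ):
            # Output tensor is passed as the third argument, as in the example.
            results.append(
                ark.matmul(a_shards[i], b_shards[i], c_shards[i],
                           transpose_other=True))
c = ark.identity(c, deps=results)  # join the shard dependencies

Per the planner change above, the outermost range becomes the processor
group's ProcessorRange, while each op contributes a resource group that uses
its innermost range, appended to processor_groups.back()["ResourceGroups"]
when the group already exists.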