From dae1bb2afb09f3ccfaac93eb805984fe99325474 Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Fri, 27 Sep 2024 11:50:49 +0000
Subject: [PATCH] fix

---
 ark/api/planner.cpp             | 114 ++++++++++++++++++++------------
 ark/context_impl.hpp            |   2 +-
 examples/llama/model.py         |  31 +++------
 python/ark/profiler.py          |  53 ++++++++++++++-
 python/unittest/test_planner.py |   6 +-
 5 files changed, 136 insertions(+), 70 deletions(-)

diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp
index 8bf8c2f9..54bc1f89 100644
--- a/ark/api/planner.cpp
+++ b/ark/api/planner.cpp
@@ -17,8 +17,11 @@ namespace ark {
 
 PlannerContext::PlannerContext(Model &model) : Context(model) {
-    this->impl_->set("Id", this->id(), ContextType::Overwrite);
-    this->impl_->set("Sync", true, ContextType::Overwrite);
+    this->impl_->set("Id", id());
+    Json val;
+    val.push_back(id());
+    val.push_back(true);
+    this->impl_->set("Sync", val);
 }
 
 void PlannerContext::check_range(const std::string &key,
@@ -28,7 +31,7 @@ void PlannerContext::check_range(const std::string &key,
         // ok
         return;
     }
-    auto prev_vec = prev.get<std::vector<int>>();
+    auto prev_vec = prev[1].get<std::vector<int>>();
     if (prev_vec.size() < 2 || prev_vec.size() > 3) {
         ERR(InternalError, "unexpected");
     }
@@ -42,43 +45,56 @@ void PlannerContext::processor_range(int start, int end, int step) {
     check_range("ProcessorRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("ProcessorRange", {start, end},
-                         ContextType::Overwrite);
+        val.push_back({start, end});
+        this->impl_->set("ProcessorRange", {id(), {start, end}});
     } else {
-        this->impl_->set("ProcessorRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
+        this->impl_->set("ProcessorRange", {id(), {start, end, step}});
     }
 }
 
 void PlannerContext::warp_range(int start, int end, int step) {
     check_range("WarpRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("WarpRange", {start, end}, ContextType::Overwrite);
+        val.push_back({start, end});
+        this->impl_->set("WarpRange", {id(), {start, end}});
     } else {
-        this->impl_->set("WarpRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
+        this->impl_->set("WarpRange", {id(), {start, end, step}});
     }
 }
 
 void PlannerContext::sram_range(int start, int end, int step) {
     check_range("SramRange", {start, end, step});
+    Json val;
+    val.push_back(id());
     if (step == 1) {
-        this->impl_->set("SramRange", {start, end}, ContextType::Overwrite);
+        val.push_back({start, end});
+        this->impl_->set("SramRange", {id(), {start, end}});
     } else {
-        this->impl_->set("SramRange", {start, end, step},
-                         ContextType::Overwrite);
+        val.push_back({start, end, step});
+        this->impl_->set("SramRange", {id(), {start, end, step}});
     }
 }
 
 void PlannerContext::sync(bool sync) {
     // Sync should be always pushed with Id together.
- this->impl_->set("Id", this->id(), ContextType::Overwrite); - this->impl_->set("Sync", sync, ContextType::Overwrite); + Json val; + val.push_back(id()); + val.push_back(sync); + this->impl_->set("Sync", val); } void PlannerContext::config(const std::string &config) { - this->impl_->set("Config", Json::parse(config), ContextType::Extend); + Json val; + val.push_back(id()); + val.push_back(Json::parse(config)); + this->impl_->set("Config", val); } class Planner::Impl { @@ -128,6 +144,7 @@ std::string Planner::Impl::plan(bool pretty) const { size_t max_warp_id = 1; size_t next_task_id = 0; int merge_root = -1; + int processor_group_root = -1; bool first_op = true; auto get_context = [&](const ModelNodeRef &node, @@ -150,12 +167,15 @@ std::string Planner::Impl::plan(bool pretty) const { const auto &op = node->op; if (op->is_virtual()) continue; - auto ctx_config = get_latest_context(node, "Config"); - - Json config; - if (!ctx_config.empty()) { - config = ctx_config; - } else if (!config_rules_.empty()) { + Json config = Json::object(); + for (auto &obj : get_context(node, "Config")) { + LOG(INFO, obj.dump()); + auto &items = obj[1]; + for (auto &item : items.items()) { + config[item.key()] = item.value(); + } + } + if (config.empty() && !config_rules_.empty()) { const std::string op_str = op->serialize().dump(); for (auto &rule : config_rules_) { auto config_str = rule(op_str, gpu_info.arch->name()); @@ -225,8 +245,8 @@ std::string Planner::Impl::plan(bool pretty) const { } size_t granularity = config.value("Granularity", 1); - auto ctx_id_list = get_context(node, "Id").get>(); - auto ctx_sync_list = get_context(node, "Sync").get>(); + auto ctx_id_list = get_context(node, "Id"); + auto ctx_sync_list = get_context(node, "Sync"); if (merge_root != -1) { bool not_found = true; for (auto ctx_id : ctx_id_list) { @@ -241,15 +261,11 @@ std::string Planner::Impl::plan(bool pretty) const { } bool merge_this_node = (merge_root != -1); if (merge_root == -1) { - size_t idx = 0; - for (; idx < ctx_sync_list.size(); idx++) { - if (!ctx_sync_list[idx]) { - if (ctx_id_list.size() <= idx) { - ERR(InternalError, - "ctx_id_list should have the same size as " - "ctx_sync_list"); - } - merge_root = ctx_id_list[idx]; + for (auto &item : ctx_sync_list) { + auto &ctx_id = item[0]; + auto &sync = item[1]; + if (!sync) { + merge_root = ctx_id; break; } } @@ -279,34 +295,46 @@ std::string Planner::Impl::plan(bool pretty) const { Json processor_group; Json resource_group; bool new_processor_group = true; + bool id_found = false; + for (auto &item : ctx_processor_range_list) { + if (item[0] == processor_group_root) { + id_found = true; + break; + } + } + if (!id_found) { + processor_group_root = -1; + } + if (ctx_processor_range_list.size() > 2) { + ERR(UnsupportedError, "ProcessorRange list size > 2"); + } if (ctx_processor_range_list.empty()) { size_t num_processors = std::min(num_sm, num_tasks); processor_group["ProcessorRange"] = {0, num_processors}; resource_group["ProcessorRange"] = {0, num_processors}; max_processor_id = std::max(max_processor_id, num_processors); - } else if (ctx_processor_range_list.size() == 1 || - !merge_this_node) { - auto &ctx_processor_range = ctx_processor_range_list.back(); - processor_group["ProcessorRange"] = ctx_processor_range; - resource_group["ProcessorRange"] = ctx_processor_range; + } else if (processor_group_root == -1) { + processor_group_root = ctx_processor_range_list.front()[0]; + processor_group["ProcessorRange"] = ctx_processor_range_list.front()[1]; + 
resource_group["ProcessorRange"] = ctx_processor_range_list.back()[1]; max_processor_id = std::max( - max_processor_id, ctx_processor_range[1].get()); + max_processor_id, ctx_processor_range_list.front()[1][1].get()); } else { new_processor_group = false; resource_group["ProcessorRange"] = - ctx_processor_range_list.back(); + ctx_processor_range_list.back()[1]; } if (!ctx_warp_range.empty()) { - resource_group["WarpRange"] = ctx_warp_range; + resource_group["WarpRange"] = ctx_warp_range[1]; max_warp_id = - std::max(max_warp_id, ctx_warp_range[1].get()); + std::max(max_warp_id, ctx_warp_range[1][1].get()); } else { resource_group["WarpRange"] = {0, num_warps}; max_warp_id = std::max(max_warp_id, num_warps); } if (!ctx_sram_range.empty()) { - resource_group["SramRange"] = ctx_sram_range; + resource_group["SramRange"] = ctx_sram_range[1]; } else { resource_group["SramRange"] = {0, sram_bytes}; } diff --git a/ark/context_impl.hpp b/ark/context_impl.hpp index 73fcae92..b7935329 100644 --- a/ark/context_impl.hpp +++ b/ark/context_impl.hpp @@ -17,7 +17,7 @@ class Context::Impl { Json get(const std::string& key) const; - void set(const std::string& key, const Json& value_json, ContextType type); + void set(const std::string& key, const Json& value_json, ContextType type = ContextType::Overwrite); bool has(const std::string& key) const; diff --git a/examples/llama/model.py b/examples/llama/model.py index 57ff7d9b..b69bcf2f 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -106,20 +106,10 @@ def forward(self, x): mean = ark.reduce_mean(x2, axis=-1) mean = ark.add(mean, self.eps) rrms = ark.rsqrt(mean) - - with Context( - warp_range=[0, 8], - sync=False, - config={ - "NumWarps": 1, - "SramBytes": 0, - "Tile": [1, 4096], - "Granularity": 7, - }, - ): - x = ark.mul(x, rrms) - x = ark.mul(x, self.weight, x) - return ark.cast(x, self.dtype) + with Context(config={"Tile": [1, 4096]}): + x = ark.mul(x, rrms) + x = ark.mul(x, self.weight, x) + return ark.cast(x, self.dtype) class ColumnParallelLinear(ark.Module): @@ -668,10 +658,11 @@ def forward( freqs_cis: ark.Tensor, mask: Optional[ark.Tensor], ): - h = self.tok_embeddings(tokens) + with Context(warp_range=[0, 8]): + h = self.tok_embeddings(tokens) - for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h) - output = self.output(h) - return output + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h) + return output diff --git a/python/ark/profiler.py b/python/ark/profiler.py index e47f5b7a..f3ed5504 100644 --- a/python/ark/profiler.py +++ b/python/ark/profiler.py @@ -3,6 +3,7 @@ import sys import time +from typing import Optional, List from .runtime import Runtime from .planner import Plan @@ -26,10 +27,12 @@ def run( iter: int = 1000, loop_mode: bool = True, profile_processor_groups: bool = False, + target_processor_groups: Optional[List[int]] = None, ): - sys.stderr.write( - f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n" - ) + if target_processor_groups is None: + sys.stderr.write( + f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n" + ) if not profile_processor_groups: return @@ -44,8 +47,52 @@ def run( "ProcessorGroups": [None], } for i in range(num_processor_groups): + if target_processor_groups is not None and i not in target_processor_groups: + continue new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i] lat_per_iter = timeit(Plan(new_plan), iter, loop_mode) sys.stderr.write( 
f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n" ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="ARK Profiler") + parser.add_argument( + "--iter", + type=int, + default=1000, + help="Number of iterations to run for each measurement", + ) + parser.add_argument( + "--loop_mode", + action="store_true", + help="Use loop mode to measure end-to-end latency", + ) + parser.add_argument( + "--profile_processor_groups", + action="store_true", + help="Profile processor groups", + ) + parser.add_argument( + "--target_processor_groups", + type=str, + help="Target processor groups to profile", + ) + parser.add_argument("--plan", type=str, help="Path to the plan file", required=True) + args = parser.parse_args() + + target_processor_groups = None + if args.target_processor_groups is not None: + target_processor_groups = list(map(int, args.target_processor_groups.split(","))) + + plan = Plan.from_file(args.plan) + profiler = Profiler(plan) + profiler.run( + iter=args.iter, + loop_mode=args.loop_mode, + profile_processor_groups=args.profile_processor_groups, + target_processor_groups=target_processor_groups, + ) diff --git a/python/unittest/test_planner.py b/python/unittest/test_planner.py index 94ad3ca4..0a739c71 100644 --- a/python/unittest/test_planner.py +++ b/python/unittest/test_planner.py @@ -18,9 +18,9 @@ def test_planner_processor_range(): plan = ark.Planner().plan() pg = plan.processor_groups - assert len(pg) == 2 - assert pg[0]["ProcessorRange"] == [0, 8] - assert pg[1]["ProcessorRange"] == [8, 16] + assert len(pg) == 1 + assert pg[0]["ResourceGroups"][0]["ProcessorRange"] == [0, 8] + assert pg[0]["ResourceGroups"][1]["ProcessorRange"] == [8, 16] @pytest_ark()