From 2a0e1a0fc3b09b60d40c3833b6c5551d1fcfe8aa Mon Sep 17 00:00:00 2001
From: Changho Hwang
Date: Fri, 27 Sep 2024 09:21:41 +0000
Subject: [PATCH] fix

---
 ark/api/planner.cpp             | 58 +++++++++++++++++++++------------
 python/ark/planner.py           |  4 ++-
 python/unittest/test_planner.py | 40 +++++++++++++++++++++++
 3 files changed, 81 insertions(+), 21 deletions(-)
 create mode 100644 python/unittest/test_planner.py

diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp
index 56e0b5b0..8bf8c2f9 100644
--- a/ark/api/planner.cpp
+++ b/ark/api/planner.cpp
@@ -18,6 +18,7 @@ namespace ark {
 
 PlannerContext::PlannerContext(Model &model) : Context(model) {
     this->impl_->set("Id", this->id(), ContextType::Overwrite);
+    this->impl_->set("Sync", true, ContextType::Overwrite);
 }
 
 void PlannerContext::check_range(const std::string &key,
@@ -71,16 +72,9 @@ void PlannerContext::sram_range(int start, int end, int step) {
 }
 
 void PlannerContext::sync(bool sync) {
-    if (sync) {
-        // `true` should not overwrite `false`.
-        if (this->impl_->get("Sync") == Json(false)) {
-            LOG(WARN, "Ignoring sync(true) while sync(false) is already set");
-            return;
-        }
-        this->impl_->set("Sync", true, ContextType::Immutable);
-    } else {
-        this->impl_->set("Sync", false, ContextType::Overwrite);
-    }
+    // Sync should be always pushed with Id together.
+    this->impl_->set("Id", this->id(), ContextType::Overwrite);
+    this->impl_->set("Sync", sync, ContextType::Overwrite);
 }
 
 void PlannerContext::config(const std::string &config) {
@@ -133,7 +127,7 @@ std::string Planner::Impl::plan(bool pretty) const {
     size_t max_processor_id = 1;
     size_t max_warp_id = 1;
     size_t next_task_id = 0;
-    int prev_ctx_id = -1;
+    int merge_root = -1;
     bool first_op = true;
 
     auto get_context = [&](const ModelNodeRef &node,
@@ -142,7 +136,7 @@
             return node->context.at(key);
         } catch (const Json::out_of_range &e) {
         }
-        return Json();
+        return Json::array();
     };
 
     auto get_latest_context = [&](const ModelNodeRef &node,
@@ -231,11 +225,36 @@
             }
             size_t granularity = config.value("Granularity", 1);
-            auto ctx_id = get_latest_context(node, "Id");
-            auto ctx_sync = get_latest_context(node, "Sync");
-            int id = ctx_id.empty() ? -1 : ctx_id.get<int>();
-            bool sync = ctx_sync.empty() ? true : ctx_sync.get<bool>();
-            if (id == prev_ctx_id && !sync) {
+            auto ctx_id_list = get_context(node, "Id").get<std::vector<int>>();
+            auto ctx_sync_list = get_context(node, "Sync").get<std::vector<bool>>();
+            if (merge_root != -1) {
+                bool not_found = true;
+                for (auto ctx_id : ctx_id_list) {
+                    if (ctx_id == merge_root) {
+                        not_found = false;
+                        break;
+                    }
+                }
+                if (not_found) {
+                    merge_root = -1;
+                }
+            }
+            bool merge_this_node = (merge_root != -1);
+            if (merge_root == -1) {
+                size_t idx = 0;
+                for (; idx < ctx_sync_list.size(); idx++) {
+                    if (!ctx_sync_list[idx]) {
+                        if (ctx_id_list.size() <= idx) {
+                            ERR(InternalError,
+                                "ctx_id_list should have the same size as "
+                                "ctx_sync_list");
+                        }
+                        merge_root = ctx_id_list[idx];
+                        break;
+                    }
+                }
+            }
+            if (merge_this_node) {
                 auto &task_info = task_infos.back();
                 task_info["NumWarps"] =
                     std::max(task_info["NumWarps"].get<size_t>(), num_warps);
@@ -266,8 +285,8 @@
                 resource_group["ProcessorRange"] = {0, num_processors};
                 max_processor_id = std::max(max_processor_id, num_processors);
             } else if (ctx_processor_range_list.size() == 1 ||
-                       (id != prev_ctx_id)) {
-                auto &ctx_processor_range = ctx_processor_range_list[0];
+                       !merge_this_node) {
+                auto &ctx_processor_range = ctx_processor_range_list.back();
                 processor_group["ProcessorRange"] = ctx_processor_range;
                 resource_group["ProcessorRange"] = ctx_processor_range;
                 max_processor_id = std::max(
@@ -304,7 +323,6 @@
                     resource_group);
             }
         }
-        prev_ctx_id = id;
         first_op = false;
     }
 
diff --git a/python/ark/planner.py b/python/ark/planner.py
index 59de7a61..0fdbe6c5 100644
--- a/python/ark/planner.py
+++ b/python/ark/planner.py
@@ -184,6 +184,8 @@ def __init__(self, **kwargs):
         sync: bool = kwargs.get("sync", True)
         config: Dict[str, Any] = kwargs.get("config", None)
 
+        print(f"ctx id = {super().id()}")
+
         if prange is not None:
             self.processor_range(*prange)
         if wrange is not None:
@@ -236,4 +238,4 @@ def plan(self) -> Plan:
         """
         Generate an execution plan.
         """
-        return Plan.from_str(super().plan(pretty=False))
+        return Plan.from_str(super().plan(pretty=True))
diff --git a/python/unittest/test_planner.py b/python/unittest/test_planner.py
new file mode 100644
index 00000000..94ad3ca4
--- /dev/null
+++ b/python/unittest/test_planner.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from common import ark, pytest_ark
+
+
+@pytest_ark()
+def test_planner_processor_range():
+    input_tensor = ark.tensor([64, 64], ark.fp16)
+    other_tensor = ark.tensor([64, 64], ark.fp16)
+
+    with ark.PlannerContext(processor_range=[0, 128]):
+        with ark.PlannerContext(processor_range=[0, 8], sync=False):
+            ark.add(input_tensor, other_tensor)
+        with ark.PlannerContext(processor_range=[8, 16], sync=False):
+            ark.add(input_tensor, other_tensor)
+
+    plan = ark.Planner().plan()
+
+    pg = plan.processor_groups
+    assert len(pg) == 2
+    assert pg[0]["ProcessorRange"] == [0, 8]
+    assert pg[1]["ProcessorRange"] == [8, 16]
+
+
+@pytest_ark()
+def test_planner_sync():
+    input_tensor = ark.tensor([64, 64], ark.fp16)
+    other_tensor = ark.tensor([64, 64], ark.fp16)
+
+    with ark.PlannerContext(sync=False):
+        with ark.PlannerContext():
+            ark.add(input_tensor, other_tensor)
+        with ark.PlannerContext():
+            ark.add(input_tensor, other_tensor)
+
+    plan = ark.Planner().plan()
+
+    pg = plan.processor_groups
+    assert len(pg) == 1
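
Illustrative note on the merge semantics above: each `PlannerContext` push now records both its `Id` and its `Sync` flag, so a model node carries parallel `Id`/`Sync` arrays ordered from the outermost context inward. The planner tracks a "merge root", the id of the outermost `sync=False` context; a node is merged into the previous task group only while that root id still appears in the node's own id list, and a freshly opened root only takes effect from the next node on. The Python sketch below mirrors that bookkeeping from `Planner::Impl::plan`; `next_merge_root` is a hypothetical helper, and the example id values assume contexts are numbered in creation order.

    from typing import List, Optional, Tuple

    def next_merge_root(merge_root: Optional[int],
                        ctx_id_list: List[int],
                        ctx_sync_list: List[bool]) -> Tuple[Optional[int], bool]:
        # Returns (new merge root, whether THIS node merges into the
        # previous task group).
        if merge_root is not None and merge_root not in ctx_id_list:
            merge_root = None  # the root context was exited; stop merging
        merge_this_node = merge_root is not None
        if merge_root is None:
            # Open a new root at the outermost sync=False context, if any.
            for idx, sync in enumerate(ctx_sync_list):
                if not sync:
                    assert idx < len(ctx_id_list), (
                        "ctx_id_list should have the same size as ctx_sync_list")
                    merge_root = ctx_id_list[idx]
                    break
        return merge_root, merge_this_node

    # Walking the two ark.add() nodes from test_planner_sync, assuming the
    # outer sync=False context got id 0 and the inner ones ids 1 and 2:
    root, merged = next_merge_root(None, [0, 1], [False, True])
    assert (root, merged) == (0, False)  # first node opens the root
    root, merged = next_merge_root(root, [0, 2], [False, True])
    assert (root, merged) == (0, True)   # second node merges -> len(pg) == 1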