Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 27, 2024
1 parent 97ab2df commit 2a0e1a0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 21 deletions.
58 changes: 38 additions & 20 deletions ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace ark {

// Construct a planner context scoped to `model`.
// Pushes this context's id and a default Sync=true into the context store so
// that ops created under this context can later be grouped by the planner.
// NOTE(review): Overwrite appears to push/replace the latest value for the
// key (vs. Immutable elsewhere in this file) — confirm in Context impl.
PlannerContext::PlannerContext(Model &model) : Context(model) {
    this->impl_->set("Id", this->id(), ContextType::Overwrite);
    // Default to synchronized scheduling; sync(false) can relax this later.
    this->impl_->set("Sync", true, ContextType::Overwrite);
}

void PlannerContext::check_range(const std::string &key,
Expand Down Expand Up @@ -71,16 +72,9 @@ void PlannerContext::sram_range(int start, int end, int step) {
}

/// Set whether ops under this context require a synchronization barrier
/// between them.
///
/// @param sync If false, adjacent ops sharing this context may be merged
///             into the same task/processor group by the planner.
void PlannerContext::sync(bool sync) {
    // The scraped diff glued the deleted pre-commit body (the
    // sync(true)/Immutable branch with the "should not overwrite" warning)
    // onto the new body; keeping both would push "Sync" twice for one "Id"
    // and desynchronize the parallel Id/Sync lists the planner iterates.
    // Sync should be always pushed with Id together so each Sync entry can
    // be paired with the context id it belongs to.
    this->impl_->set("Id", this->id(), ContextType::Overwrite);
    this->impl_->set("Sync", sync, ContextType::Overwrite);
}

void PlannerContext::config(const std::string &config) {
Expand Down Expand Up @@ -133,7 +127,7 @@ std::string Planner::Impl::plan(bool pretty) const {
size_t max_processor_id = 1;
size_t max_warp_id = 1;
size_t next_task_id = 0;
int prev_ctx_id = -1;
int merge_root = -1;
bool first_op = true;

auto get_context = [&](const ModelNodeRef &node,
Expand All @@ -142,7 +136,7 @@ std::string Planner::Impl::plan(bool pretty) const {
return node->context.at(key);
} catch (const Json::out_of_range &e) {
}
return Json();
return Json::array();
};

auto get_latest_context = [&](const ModelNodeRef &node,
Expand Down Expand Up @@ -231,11 +225,36 @@ std::string Planner::Impl::plan(bool pretty) const {
}

size_t granularity = config.value("Granularity", 1);
auto ctx_id = get_latest_context(node, "Id");
auto ctx_sync = get_latest_context(node, "Sync");
int id = ctx_id.empty() ? -1 : ctx_id.get<int>();
bool sync = ctx_sync.empty() ? true : ctx_sync.get<bool>();
if (id == prev_ctx_id && !sync) {
auto ctx_id_list = get_context(node, "Id").get<std::vector<int>>();
auto ctx_sync_list = get_context(node, "Sync").get<std::vector<bool>>();
if (merge_root != -1) {
bool not_found = true;
for (auto ctx_id : ctx_id_list) {
if (ctx_id == merge_root) {
not_found = false;
break;
}
}
if (not_found) {
merge_root = -1;
}
}
bool merge_this_node = (merge_root != -1);
if (merge_root == -1) {
size_t idx = 0;
for (; idx < ctx_sync_list.size(); idx++) {
if (!ctx_sync_list[idx]) {
if (ctx_id_list.size() <= idx) {
ERR(InternalError,
"ctx_id_list should have the same size as "
"ctx_sync_list");
}
merge_root = ctx_id_list[idx];
break;
}
}
}
if (merge_this_node) {
auto &task_info = task_infos.back();
task_info["NumWarps"] =
std::max(task_info["NumWarps"].get<size_t>(), num_warps);
Expand Down Expand Up @@ -266,8 +285,8 @@ std::string Planner::Impl::plan(bool pretty) const {
resource_group["ProcessorRange"] = {0, num_processors};
max_processor_id = std::max(max_processor_id, num_processors);
} else if (ctx_processor_range_list.size() == 1 ||
(id != prev_ctx_id)) {
auto &ctx_processor_range = ctx_processor_range_list[0];
!merge_this_node) {
auto &ctx_processor_range = ctx_processor_range_list.back();
processor_group["ProcessorRange"] = ctx_processor_range;
resource_group["ProcessorRange"] = ctx_processor_range;
max_processor_id = std::max(
Expand Down Expand Up @@ -304,7 +323,6 @@ std::string Planner::Impl::plan(bool pretty) const {
resource_group);
}
}
prev_ctx_id = id;
first_op = false;
}

Expand Down
4 changes: 3 additions & 1 deletion python/ark/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def __init__(self, **kwargs):
sync: bool = kwargs.get("sync", True)
config: Dict[str, Any] = kwargs.get("config", None)

print(f"ctx id = {super().id()}")

if prange is not None:
self.processor_range(*prange)
if wrange is not None:
Expand Down Expand Up @@ -236,4 +238,4 @@ def plan(self) -> Plan:
"""
Generate an execution plan.
"""
return Plan.from_str(super().plan(pretty=False))
return Plan.from_str(super().plan(pretty=True))
40 changes: 40 additions & 0 deletions python/unittest/test_planner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from common import ark, pytest_ark


@pytest_ark()
def test_planner_processor_range():
    # Two adds pinned to disjoint processor sub-ranges of the outer
    # [0, 128) context; each should surface as its own processor group
    # carrying its own range.
    lhs = ark.tensor([64, 64], ark.fp16)
    rhs = ark.tensor([64, 64], ark.fp16)

    with ark.PlannerContext(processor_range=[0, 128]):
        with ark.PlannerContext(processor_range=[0, 8], sync=False):
            ark.add(lhs, rhs)
        with ark.PlannerContext(processor_range=[8, 16], sync=False):
            ark.add(lhs, rhs)

    groups = ark.Planner().plan().processor_groups

    assert [g["ProcessorRange"] for g in groups] == [[0, 8], [8, 16]]


@pytest_ark()
def test_planner_sync():
    # With sync=False on the outer context, the two adds issued in the
    # inner contexts should collapse into a single processor group.
    a = ark.tensor([64, 64], ark.fp16)
    b = ark.tensor([64, 64], ark.fp16)

    with ark.PlannerContext(sync=False):
        with ark.PlannerContext():
            ark.add(a, b)
        with ark.PlannerContext():
            ark.add(a, b)

    generated_plan = ark.Planner().plan()

    assert len(generated_plan.processor_groups) == 1

0 comments on commit 2a0e1a0

Please sign in to comment.