Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 27, 2024
1 parent 2a0e1a0 commit dae1bb2
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 70 deletions.
114 changes: 71 additions & 43 deletions ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
namespace ark {

PlannerContext::PlannerContext(Model &model) : Context(model) {
this->impl_->set("Id", this->id(), ContextType::Overwrite);
this->impl_->set("Sync", true, ContextType::Overwrite);
this->impl_->set("Id", id());
Json val;
val.push_back(id());
val.push_back(true);
this->impl_->set("Sync", val);
}

void PlannerContext::check_range(const std::string &key,
Expand All @@ -28,7 +31,7 @@ void PlannerContext::check_range(const std::string &key,
// ok
return;
}
auto prev_vec = prev.get<std::vector<int>>();
auto prev_vec = prev[1].get<std::vector<int>>();
if (prev_vec.size() < 2 || prev_vec.size() > 3) {
ERR(InternalError, "unexpected");
}
Expand All @@ -42,43 +45,56 @@ void PlannerContext::check_range(const std::string &key,

void PlannerContext::processor_range(int start, int end, int step) {
check_range("ProcessorRange", {start, end, step});
Json val;
val.push_back(id());
if (step == 1) {
this->impl_->set("ProcessorRange", {start, end},
ContextType::Overwrite);
val.push_back({start, end});
this->impl_->set("ProcessorRange", {id(), {start, end}});
} else {
this->impl_->set("ProcessorRange", {start, end, step},
ContextType::Overwrite);
val.push_back({start, end, step});
this->impl_->set("ProcessorRange", {id(), {start, end, step}});
}
}

void PlannerContext::warp_range(int start, int end, int step) {
check_range("WarpRange", {start, end, step});
Json val;
val.push_back(id());
if (step == 1) {
this->impl_->set("WarpRange", {start, end}, ContextType::Overwrite);
val.push_back({start, end});
this->impl_->set("WarpRange", {id(), {start, end}});
} else {
this->impl_->set("WarpRange", {start, end, step},
ContextType::Overwrite);
val.push_back({start, end, step});
this->impl_->set("WarpRange", {id(), {start, end, step}});
}
}

void PlannerContext::sram_range(int start, int end, int step) {
check_range("SramRange", {start, end, step});
Json val;
val.push_back(id());
if (step == 1) {
this->impl_->set("SramRange", {start, end}, ContextType::Overwrite);
val.push_back({start, end});
this->impl_->set("SramRange", {id(), {start, end}});
} else {
this->impl_->set("SramRange", {start, end, step},
ContextType::Overwrite);
val.push_back({start, end, step});
this->impl_->set("SramRange", {id(), {start, end, step}});
}
}

void PlannerContext::sync(bool sync) {
// Sync should be always pushed with Id together.
this->impl_->set("Id", this->id(), ContextType::Overwrite);
this->impl_->set("Sync", sync, ContextType::Overwrite);
Json val;
val.push_back(id());
val.push_back(sync);
this->impl_->set("Sync", val);
}

void PlannerContext::config(const std::string &config) {
this->impl_->set("Config", Json::parse(config), ContextType::Extend);
Json val;
val.push_back(id());
val.push_back(Json::parse(config));
this->impl_->set("Config", val);
}

class Planner::Impl {
Expand Down Expand Up @@ -128,6 +144,7 @@ std::string Planner::Impl::plan(bool pretty) const {
size_t max_warp_id = 1;
size_t next_task_id = 0;
int merge_root = -1;
int processor_group_root = -1;
bool first_op = true;

auto get_context = [&](const ModelNodeRef &node,
Expand All @@ -150,12 +167,15 @@ std::string Planner::Impl::plan(bool pretty) const {
const auto &op = node->op;
if (op->is_virtual()) continue;

auto ctx_config = get_latest_context(node, "Config");

Json config;
if (!ctx_config.empty()) {
config = ctx_config;
} else if (!config_rules_.empty()) {
Json config = Json::object();
for (auto &obj : get_context(node, "Config")) {
LOG(INFO, obj.dump());
auto &items = obj[1];
for (auto &item : items.items()) {
config[item.key()] = item.value();
}
}
if (config.empty() && !config_rules_.empty()) {
const std::string op_str = op->serialize().dump();
for (auto &rule : config_rules_) {
auto config_str = rule(op_str, gpu_info.arch->name());
Expand Down Expand Up @@ -225,8 +245,8 @@ std::string Planner::Impl::plan(bool pretty) const {
}

size_t granularity = config.value("Granularity", 1);
auto ctx_id_list = get_context(node, "Id").get<std::vector<int>>();
auto ctx_sync_list = get_context(node, "Sync").get<std::vector<bool>>();
auto ctx_id_list = get_context(node, "Id");
auto ctx_sync_list = get_context(node, "Sync");
if (merge_root != -1) {
bool not_found = true;
for (auto ctx_id : ctx_id_list) {
Expand All @@ -241,15 +261,11 @@ std::string Planner::Impl::plan(bool pretty) const {
}
bool merge_this_node = (merge_root != -1);
if (merge_root == -1) {
size_t idx = 0;
for (; idx < ctx_sync_list.size(); idx++) {
if (!ctx_sync_list[idx]) {
if (ctx_id_list.size() <= idx) {
ERR(InternalError,
"ctx_id_list should have the same size as "
"ctx_sync_list");
}
merge_root = ctx_id_list[idx];
for (auto &item : ctx_sync_list) {
auto &ctx_id = item[0];
auto &sync = item[1];
if (!sync) {
merge_root = ctx_id;
break;
}
}
Expand Down Expand Up @@ -279,34 +295,46 @@ std::string Planner::Impl::plan(bool pretty) const {
Json processor_group;
Json resource_group;
bool new_processor_group = true;
bool id_found = false;
for (auto &item : ctx_processor_range_list) {
if (item[0] == processor_group_root) {
id_found = true;
break;
}
}
if (!id_found) {
processor_group_root = -1;
}
if (ctx_processor_range_list.size() > 2) {
ERR(UnsupportedError, "ProcessorRange list size > 2");
}
if (ctx_processor_range_list.empty()) {
size_t num_processors = std::min(num_sm, num_tasks);
processor_group["ProcessorRange"] = {0, num_processors};
resource_group["ProcessorRange"] = {0, num_processors};
max_processor_id = std::max(max_processor_id, num_processors);
} else if (ctx_processor_range_list.size() == 1 ||
!merge_this_node) {
auto &ctx_processor_range = ctx_processor_range_list.back();
processor_group["ProcessorRange"] = ctx_processor_range;
resource_group["ProcessorRange"] = ctx_processor_range;
} else if (processor_group_root == -1) {
processor_group_root = ctx_processor_range_list.front()[0];
processor_group["ProcessorRange"] = ctx_processor_range_list.front()[1];
resource_group["ProcessorRange"] = ctx_processor_range_list.back()[1];
max_processor_id = std::max(
max_processor_id, ctx_processor_range[1].get<size_t>());
max_processor_id, ctx_processor_range_list.front()[1][1].get<size_t>());
} else {
new_processor_group = false;
resource_group["ProcessorRange"] =
ctx_processor_range_list.back();
ctx_processor_range_list.back()[1];
}

if (!ctx_warp_range.empty()) {
resource_group["WarpRange"] = ctx_warp_range;
resource_group["WarpRange"] = ctx_warp_range[1];
max_warp_id =
std::max(max_warp_id, ctx_warp_range[1].get<size_t>());
std::max(max_warp_id, ctx_warp_range[1][1].get<size_t>());
} else {
resource_group["WarpRange"] = {0, num_warps};
max_warp_id = std::max(max_warp_id, num_warps);
}
if (!ctx_sram_range.empty()) {
resource_group["SramRange"] = ctx_sram_range;
resource_group["SramRange"] = ctx_sram_range[1];
} else {
resource_group["SramRange"] = {0, sram_bytes};
}
Expand Down
2 changes: 1 addition & 1 deletion ark/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Context::Impl {

Json get(const std::string& key) const;

void set(const std::string& key, const Json& value_json, ContextType type);
void set(const std::string& key, const Json& value_json, ContextType type = ContextType::Overwrite);

bool has(const std::string& key) const;

Expand Down
31 changes: 11 additions & 20 deletions examples/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,10 @@ def forward(self, x):
mean = ark.reduce_mean(x2, axis=-1)
mean = ark.add(mean, self.eps)
rrms = ark.rsqrt(mean)

with Context(
warp_range=[0, 8],
sync=False,
config={
"NumWarps": 1,
"SramBytes": 0,
"Tile": [1, 4096],
"Granularity": 7,
},
):
x = ark.mul(x, rrms)
x = ark.mul(x, self.weight, x)
return ark.cast(x, self.dtype)
with Context(config={"Tile": [1, 4096]}):
x = ark.mul(x, rrms)
x = ark.mul(x, self.weight, x)
return ark.cast(x, self.dtype)


class ColumnParallelLinear(ark.Module):
Expand Down Expand Up @@ -668,10 +658,11 @@ def forward(
freqs_cis: ark.Tensor,
mask: Optional[ark.Tensor],
):
h = self.tok_embeddings(tokens)
with Context(warp_range=[0, 8]):
h = self.tok_embeddings(tokens)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h)
return output
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h)
return output
53 changes: 50 additions & 3 deletions python/ark/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import sys
import time
from typing import Optional, List

from .runtime import Runtime
from .planner import Plan
Expand All @@ -26,10 +27,12 @@ def run(
iter: int = 1000,
loop_mode: bool = True,
profile_processor_groups: bool = False,
target_processor_groups: Optional[List[int]] = None,
):
sys.stderr.write(
f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n"
)
if target_processor_groups is None:
sys.stderr.write(
f"End-to-end: {timeit(self.plan, iter, loop_mode):.6f} seconds/iter\n"
)

if not profile_processor_groups:
return
Expand All @@ -44,8 +47,52 @@ def run(
"ProcessorGroups": [None],
}
for i in range(num_processor_groups):
if target_processor_groups is not None and i not in target_processor_groups:
continue
new_plan["ProcessorGroups"][0] = self.plan.processor_groups[i]
lat_per_iter = timeit(Plan(new_plan), iter, loop_mode)
sys.stderr.write(
f"Processor group {i}: {lat_per_iter:.6f} seconds/iter\n"
)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="ARK Profiler")
parser.add_argument(
"--iter",
type=int,
default=1000,
help="Number of iterations to run for each measurement",
)
parser.add_argument(
"--loop_mode",
action="store_true",
help="Use loop mode to measure end-to-end latency",
)
parser.add_argument(
"--profile_processor_groups",
action="store_true",
help="Profile processor groups",
)
parser.add_argument(
"--target_processor_groups",
type=str,
help="Target processor groups to profile",
)
parser.add_argument("--plan", type=str, help="Path to the plan file", required=True)
args = parser.parse_args()

target_processor_groups = None
if args.target_processor_groups is not None:
target_processor_groups = list(map(int, args.target_processor_groups.split(",")))

plan = Plan.from_file(args.plan)
profiler = Profiler(plan)
profiler.run(
iter=args.iter,
loop_mode=args.loop_mode,
profile_processor_groups=args.profile_processor_groups,
target_processor_groups=target_processor_groups,
)
6 changes: 3 additions & 3 deletions python/unittest/test_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_planner_processor_range():
plan = ark.Planner().plan()

pg = plan.processor_groups
assert len(pg) == 2
assert pg[0]["ProcessorRange"] == [0, 8]
assert pg[1]["ProcessorRange"] == [8, 16]
assert len(pg) == 1
assert pg[0]["ResourceGroups"][0]["ProcessorRange"] == [0, 8]
assert pg[0]["ResourceGroups"][1]["ProcessorRange"] == [8, 16]


@pytest_ark()
Expand Down

0 comments on commit dae1bb2

Please sign in to comment.