Skip to content

Commit

Permalink
A few fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang committed Sep 27, 2024
1 parent b9f35d9 commit dfae17b
Show file tree
Hide file tree
Showing 13 changed files with 58 additions and 30 deletions.
4 changes: 4 additions & 0 deletions ark/api/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ void Context::set(const std::string& key, const std::string& value,
this->impl_->set(key, value_json, type);
}

std::string Context::dump() const {
return this->impl_->dump().dump();
}

} // namespace ark
2 changes: 1 addition & 1 deletion ark/api/planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace ark {

PlannerContext::PlannerContext(Model &model) : Context(model) {
this->impl_->set("Id", this->id(), ContextType::Immutable);
this->impl_->set("Id", this->id(), ContextType::Overwrite);
}

void PlannerContext::check_range(const std::string &key,
Expand Down
4 changes: 4 additions & 0 deletions ark/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,8 @@ bool Context::Impl::has(const std::string& key) const {
return context_manager_->has(key);
}

Json Context::Impl::dump() const {
return context_manager_->dump();
}

} // namespace ark
2 changes: 2 additions & 0 deletions ark/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Context::Impl {

bool has(const std::string& key) const;

Json dump() const;

protected:
friend class Context;

Expand Down
7 changes: 5 additions & 2 deletions ark/include/ark/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ enum class ContextType {
class Context {
public:
///
/// Construct an empty context for the given model.
/// Context handler of the given model.
///
/// @param model The model to create the context for.
/// @param model The model to manipulate the context for.
///
Context(Model& model);

Expand Down Expand Up @@ -78,6 +78,9 @@ class Context {
void set(const std::string& key, const std::string& value,
ContextType type = ContextType::Overwrite);

/// Return the entire context stacks as a JSON format string.
std::string dump() const;

protected:
friend class PlannerContext;

Expand Down
4 changes: 4 additions & 0 deletions ark/model/model_context_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,8 @@ Json ModelContextManager::get(const std::string& key) const {
return context_stack_->get(key);
}

Json ModelContextManager::dump() const {
return context_stack_->dump();
}

} // namespace ark
2 changes: 2 additions & 0 deletions ark/model/model_context_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ModelContextManager {

Json get(const std::string& key) const;

Json dump() const;

private:
std::shared_ptr<ModelGraphContextStack> context_stack_;
std::vector<std::string> keys_;
Expand Down
11 changes: 11 additions & 0 deletions ark/model/model_graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ std::map<std::string, Json> ModelGraphContextStack::get_all() const {
return cur;
}

Json ModelGraphContextStack::dump() const {
Json j;
for (const auto &pair : this->storage_) {
j[pair.first] = Json::array();
for (const auto &value : pair.second) {
j[pair.first].emplace_back(*value);
}
}
return j;
}

ModelGraph::Impl::Impl(const ModelGraph::Impl &other) { *this = other; }

ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) {
Expand Down
2 changes: 2 additions & 0 deletions ark/model/model_graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class ModelGraphContextStack {
Json get(const std::string &key) const;

std::map<std::string, Json> get_all() const;

Json dump() const;
};

class ModelGraph::Impl {
Expand Down
9 changes: 8 additions & 1 deletion ark/model/model_json.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,16 @@ static void verify_format_plan(const Json &json) {
"NumWarpsPerProcessor",
"TaskInfos",
"ProcessorGroups"};
if (!json.is_object()) {
std::string dumped = json.dump();
if (dumped.size() > 100) {
dumped = dumped.substr(0, 100) + "...";
}
ERR(PlanError, "Plan should be a JSON object. Given: ", dumped);
}
for (const auto &field : required_fields) {
if (!json.contains(field)) {
ERR(PlanError, field + " not found");
ERR(PlanError, field, " not found");
}
}
if (!json.at("TaskInfos").is_array()) {
Expand Down
28 changes: 3 additions & 25 deletions examples/llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,9 @@ def forward(self, x):
x2 = ark.mul(x, x)
with Context(config={"Tile": [1], "ImplType": "WarpWise"}):
mean = ark.reduce_mean(x2, axis=-1)
with Context(
config={
"NumWarps": 1,
"SramBytes": 0,
"Tile": [64, 1],
}
):
mean = ark.add(mean, self.eps)
rrms = ark.rsqrt(mean)
mean = ark.add(mean, self.eps)
rrms = ark.rsqrt(mean)

with Context(
warp_range=[0, 8],
sync=False,
Expand Down Expand Up @@ -307,22 +301,6 @@ def forward(self, x):
return ark.matmul(x, self.weight, transpose_other=True)


# def tester(ref_func):
# def decorator(func):
# def wrapper(*args, **kwargs):
# data = []
# kdata = {}
# for arg in args:
# if isinstance(arg, ark.Tensor):
# rand_data =
# ref_outputs = ref_func(*args, **kwargs)
# outputs = func(*args, **kwargs)
# return outputs

# return wrapper
# return decorator


class Silu(ark.Module):
"""
Silu activation function, silu(x) = x * sigmoid(x)
Expand Down
9 changes: 9 additions & 0 deletions python/ark/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,15 @@ def __init__(self, **kwargs):
if config is not None:
self.config(json.dumps(config))

def dump(self) -> str:
"""
Dump the context stack.
Returns:
str: The context stack in JSON format.
"""
return super().dump()

def __enter__(self) -> "PlannerContext":
"""
Enter the plan manager.
Expand Down
4 changes: 3 additions & 1 deletion python/planner_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ namespace py = pybind11;
void register_planner(py::module &m) {
py::class_<ark::PlannerContext>(m, "CorePlannerContext")
.def(py::init<ark::Model &>())
.def("id", &ark::PlannerContext::id)
.def("processor_range", &ark::PlannerContext::processor_range,
py::arg("start"), py::arg("end"), py::arg("step") = 1)
.def("warp_range", &ark::PlannerContext::warp_range, py::arg("start"),
py::arg("end"), py::arg("step") = 1)
.def("sram_range", &ark::PlannerContext::sram_range, py::arg("start"),
py::arg("end"), py::arg("step") = 1)
.def("sync", &ark::PlannerContext::sync, py::arg("sync"))
.def("config", &ark::PlannerContext::config, py::arg("config"));
.def("config", &ark::PlannerContext::config, py::arg("config"))
.def("dump", &ark::PlannerContext::dump);

py::class_<ark::Planner>(m, "CorePlanner")
.def(py::init<const ark::Model &, int>())
Expand Down

0 comments on commit dfae17b

Please sign in to comment.