diff --git a/ark/api/context.cpp b/ark/api/context.cpp index 76baedc8..702247dd 100644 --- a/ark/api/context.cpp +++ b/ark/api/context.cpp @@ -29,4 +29,8 @@ void Context::set(const std::string& key, const std::string& value, this->impl_->set(key, value_json, type); } +std::string Context::dump() const { + return this->impl_->dump().dump(); +} + } // namespace ark diff --git a/ark/api/planner.cpp b/ark/api/planner.cpp index 1c117af9..506dcaff 100644 --- a/ark/api/planner.cpp +++ b/ark/api/planner.cpp @@ -17,7 +17,7 @@ namespace ark { PlannerContext::PlannerContext(Model &model) : Context(model) { - this->impl_->set("Id", this->id(), ContextType::Immutable); + this->impl_->set("Id", this->id(), ContextType::Overwrite); } void PlannerContext::check_range(const std::string &key, diff --git a/ark/context_impl.cpp b/ark/context_impl.cpp index 9a2692ea..c4f95f2c 100644 --- a/ark/context_impl.cpp +++ b/ark/context_impl.cpp @@ -52,4 +52,8 @@ bool Context::Impl::has(const std::string& key) const { return context_manager_->has(key); } +Json Context::Impl::dump() const { + return context_manager_->dump(); +} + } // namespace ark diff --git a/ark/context_impl.hpp b/ark/context_impl.hpp index 1a77891b..73fcae92 100644 --- a/ark/context_impl.hpp +++ b/ark/context_impl.hpp @@ -21,6 +21,8 @@ class Context::Impl { bool has(const std::string& key) const; + Json dump() const; + protected: friend class Context; diff --git a/ark/include/ark/context.hpp b/ark/include/ark/context.hpp index f3eef283..aaa22bd3 100644 --- a/ark/include/ark/context.hpp +++ b/ark/include/ark/context.hpp @@ -17,9 +17,9 @@ enum class ContextType { class Context { public: /// - /// Construct an empty context for the given model. + /// Context handler of the given model. /// - /// @param model The model to create the context for. + /// @param model The model to manipulate the context for. /// Context(Model& model); @@ -78,6 +78,9 @@ class Context { void set(const std::string& key, const std::string& value, ContextType type = ContextType::Overwrite); + /// Return the entire context stacks as a JSON format string. + std::string dump() const; + protected: friend class PlannerContext; diff --git a/ark/model/model_context_manager.cpp b/ark/model/model_context_manager.cpp index f1bb62e9..799cce78 100644 --- a/ark/model/model_context_manager.cpp +++ b/ark/model/model_context_manager.cpp @@ -27,4 +27,8 @@ Json ModelContextManager::get(const std::string& key) const { return context_stack_->get(key); } +Json ModelContextManager::dump() const { + return context_stack_->dump(); +} + } // namespace ark diff --git a/ark/model/model_context_manager.hpp b/ark/model/model_context_manager.hpp index 6aa91692..4dc246fe 100644 --- a/ark/model/model_context_manager.hpp +++ b/ark/model/model_context_manager.hpp @@ -24,6 +24,8 @@ class ModelContextManager { Json get(const std::string& key) const; + Json dump() const; + private: std::shared_ptr context_stack_; std::vector keys_; diff --git a/ark/model/model_graph_impl.cpp b/ark/model/model_graph_impl.cpp index 7c1ea3fb..b7717ecd 100644 --- a/ark/model/model_graph_impl.cpp +++ b/ark/model/model_graph_impl.cpp @@ -62,6 +62,17 @@ std::map ModelGraphContextStack::get_all() const { return cur; } +Json ModelGraphContextStack::dump() const { + Json j; + for (const auto &pair : this->storage_) { + j[pair.first] = Json::array(); + for (const auto &value : pair.second) { + j[pair.first].emplace_back(*value); + } + } + return j; +} + ModelGraph::Impl::Impl(const ModelGraph::Impl &other) { *this = other; } ModelGraph::Impl &ModelGraph::Impl::operator=(const ModelGraph::Impl &other) { diff --git a/ark/model/model_graph_impl.hpp b/ark/model/model_graph_impl.hpp index 62944f99..5cd60d03 100644 --- a/ark/model/model_graph_impl.hpp +++ b/ark/model/model_graph_impl.hpp @@ -39,6 +39,8 @@ class ModelGraphContextStack { Json get(const std::string &key) const; std::map get_all() const; + + Json dump() const; }; class ModelGraph::Impl { diff --git a/ark/model/model_json.cpp b/ark/model/model_json.cpp index dad62cb4..31fb24d5 100644 --- a/ark/model/model_json.cpp +++ b/ark/model/model_json.cpp @@ -302,9 +302,16 @@ static void verify_format_plan(const Json &json) { "NumWarpsPerProcessor", "TaskInfos", "ProcessorGroups"}; + if (!json.is_object()) { + std::string dumped = json.dump(); + if (dumped.size() > 100) { + dumped = dumped.substr(0, 100) + "..."; + } + ERR(PlanError, "Plan should be a JSON object. Given: ", dumped); + } for (const auto &field : required_fields) { if (!json.contains(field)) { - ERR(PlanError, field + " not found"); + ERR(PlanError, field, " not found"); } } if (!json.at("TaskInfos").is_array()) { diff --git a/examples/llama/model.py b/examples/llama/model.py index 3d18190b..f80d68e5 100644 --- a/examples/llama/model.py +++ b/examples/llama/model.py @@ -104,15 +104,9 @@ def forward(self, x): x2 = ark.mul(x, x) with Context(config={"Tile": [1], "ImplType": "WarpWise"}): mean = ark.reduce_mean(x2, axis=-1) - with Context( - config={ - "NumWarps": 1, - "SramBytes": 0, - "Tile": [64, 1], - } - ): - mean = ark.add(mean, self.eps) - rrms = ark.rsqrt(mean) + mean = ark.add(mean, self.eps) + rrms = ark.rsqrt(mean) + with Context( warp_range=[0, 8], sync=False, @@ -307,22 +301,6 @@ def forward(self, x): return ark.matmul(x, self.weight, transpose_other=True) -# def tester(ref_func): -# def decorator(func): -# def wrapper(*args, **kwargs): -# data = [] -# kdata = {} -# for arg in args: -# if isinstance(arg, ark.Tensor): -# rand_data = -# ref_outputs = ref_func(*args, **kwargs) -# outputs = func(*args, **kwargs) -# return outputs - -# return wrapper -# return decorator - - class Silu(ark.Module): """ Silu activation function, silu(x) = x * sigmoid(x) diff --git a/python/ark/planner.py b/python/ark/planner.py index 3c82719b..59de7a61 100644 --- a/python/ark/planner.py +++ b/python/ark/planner.py @@ -195,6 +195,15 @@ def __init__(self, **kwargs): if config is not None: self.config(json.dumps(config)) + def dump(self) -> str: + """ + Dump the context stack. + + Returns: + str: The context stack in JSON format. + """ + return super().dump() + def __enter__(self) -> "PlannerContext": """ Enter the plan manager. diff --git a/python/planner_py.cpp b/python/planner_py.cpp index f0af0fa3..b43a8fdd 100644 --- a/python/planner_py.cpp +++ b/python/planner_py.cpp @@ -13,6 +13,7 @@ namespace py = pybind11; void register_planner(py::module &m) { py::class_(m, "CorePlannerContext") .def(py::init()) + .def("id", &ark::PlannerContext::id) .def("processor_range", &ark::PlannerContext::processor_range, py::arg("start"), py::arg("end"), py::arg("step") = 1) .def("warp_range", &ark::PlannerContext::warp_range, py::arg("start"), @@ -20,7 +21,8 @@ void register_planner(py::module &m) { .def("sram_range", &ark::PlannerContext::sram_range, py::arg("start"), py::arg("end"), py::arg("step") = 1) .def("sync", &ark::PlannerContext::sync, py::arg("sync")) - .def("config", &ark::PlannerContext::config, py::arg("config")); + .def("config", &ark::PlannerContext::config, py::arg("config")) + .def("dump", &ark::PlannerContext::dump); py::class_(m, "CorePlanner") .def(py::init())