Skip to content

Commit 02affe5

Browse files
authored
feat(core): Add common methods to list/dict (#9)
1 parent c869192 commit 02affe5

File tree

18 files changed

+393
-71
lines changed

18 files changed

+393
-71
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
*.dSYM
33
build
44
build-cpp-tests
5+
python/mlc/_version.py

cpp/c_api_tests.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,8 @@ struct FieldSetter {
311311

312312
MLC_REGISTER_FUNC("mlc.testing.FieldGet").set_body([](ObjectRef root, const char *target_name) {
313313
Any ret;
314-
MLCTypeInfo *info = ::mlc::Lib::GetTypeInfo(root.GetTypeIndex());
315314
try {
316-
::mlc::core::VisitFields(root.get(), info, FieldGetter{target_name, &ret});
315+
::mlc::core::VisitFields(root.get(), nullptr, FieldGetter{target_name, &ret});
317316
} catch (FieldFoundException &) {
318317
return ret;
319318
}
@@ -322,9 +321,8 @@ MLC_REGISTER_FUNC("mlc.testing.FieldGet").set_body([](ObjectRef root, const char
322321
});
323322

324323
MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char *target_name, Any src) {
325-
MLCTypeInfo *info = ::mlc::Lib::GetTypeInfo(root.GetTypeIndex());
326324
try {
327-
::mlc::core::VisitFields(root.get(), info, FieldSetter{target_name, src});
325+
::mlc::core::VisitFields(root.get(), nullptr, FieldSetter{target_name, src});
328326
} catch (FieldFoundException &) {
329327
return;
330328
}

include/mlc/core/all.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
#ifndef MLC_CORE_ALL_H_
22
#define MLC_CORE_ALL_H_
3-
#include "./dict.h" // IWYU pragma: export
4-
#include "./error.h" // IWYU pragma: export
5-
#include "./field_visitor.h" // IWYU pragma: export
6-
#include "./func.h" // IWYU pragma: export
7-
#include "./func_details.h" // IWYU pragma: export
8-
#include "./list.h" // IWYU pragma: export
9-
#include "./object.h" // IWYU pragma: export
10-
#include "./object_path.h" // IWYU pragma: export
11-
#include "./reflection.h" // IWYU pragma: export
12-
#include "./str.h" // IWYU pragma: export
13-
#include "./typing.h" // IWYU pragma: export
14-
#include "./utils.h" // IWYU pragma: export
3+
#include "./dict.h" // IWYU pragma: export
4+
#include "./error.h" // IWYU pragma: export
5+
#include "./func.h" // IWYU pragma: export
6+
#include "./func_details.h" // IWYU pragma: export
7+
#include "./list.h" // IWYU pragma: export
8+
#include "./object.h" // IWYU pragma: export
9+
#include "./object_path.h" // IWYU pragma: export
10+
#include "./reflection.h" // IWYU pragma: export
11+
#include "./str.h" // IWYU pragma: export
12+
#include "./typing.h" // IWYU pragma: export
13+
#include "./utils.h" // IWYU pragma: export
14+
#include "./visitor.h" // IWYU pragma: export
1515

1616
namespace mlc {
1717
namespace core {

include/mlc/core/dict.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ struct UDictObj : protected ::mlc::core::DictBase {
7373
MLC_INLINE iterator find(const Any &key) { return iterator(IterStateMut{this, Acc::Find(this, key)}); }
7474
MLC_INLINE const_iterator find(const Any &key) const { return const_iterator(IterStateConst{this, Acc::Find(this, key)}); }
7575
MLC_INLINE void erase(const Any &key) { Acc::Erase(this, key); }
76-
MLC_INLINE void erase(const iterator &it) { Acc::Erase(this, it.i.i); }
77-
MLC_INLINE void erase(const const_iterator &it) { Acc::Erase(this, it.i.i); }
76+
MLC_INLINE void erase(const iterator &it) { Acc::_Erase(this, it.i.i); }
77+
MLC_INLINE void erase(const const_iterator &it) { Acc::_Erase(this, it.i.i); }
7878
// clang-format on
7979
template <typename K, typename V> MLC_INLINE_NO_MSVC DictObj<K, V> *AsTyped() const;
8080

@@ -149,6 +149,8 @@ struct UDict : public ObjectRef {
149149
.StaticFn("__init__", FromAnyTuple)
150150
.MemFn("__str__", &UDictObj::__str__)
151151
.MemFn("__getitem__", ::mlc::core::DictBase::Accessor<UDictObj>::GetItem)
152+
.MemFn("__setitem__", ::mlc::core::DictBase::Accessor<UDictObj>::SetItem)
153+
.MemFn("__delitem__", ::mlc::core::DictBase::Accessor<UDictObj>::Erase)
152154
.MemFn("__iter_get_key__", ::mlc::core::DictBase::Accessor<UDictObj>::GetKey)
153155
.MemFn("__iter_get_value__", ::mlc::core::DictBase::Accessor<UDictObj>::GetValue)
154156
.MemFn("__iter_advance__", ::mlc::core::DictBase::Accessor<UDictObj>::Advance);

include/mlc/core/dict_base.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct DictBase : public MLCDict {
3939
inline static void WithCapacity(TDictObj *self, int64_t new_cap);
4040
inline static KVPair *InsertOrLookup(TDictObj *self, Any key);
4141
inline static KVPair *TryInsertOrLookup(TDictObj *self, MLCAny *key);
42-
inline static void Erase(TDictObj *self, const Any &key);
43-
inline static void Erase(TDictObj *self, int64_t index);
42+
inline static Any Erase(TDictObj *self, const Any &key);
43+
inline static void _Erase(TDictObj *self, int64_t index);
4444
inline static Any &At(TDictObj *self, const Any &key);
4545
inline static const Any &At(const TDictObj *self, const Any &key);
4646
inline static Any &Bracket(TDictObj *self, const Any &key) {
@@ -52,6 +52,8 @@ struct DictBase : public MLCDict {
5252
inline static BlockIter Prev(const TDictObj *self, BlockIter iter);
5353
inline static void New(int32_t num_args, const AnyView *args, Any *any_ret);
5454
inline static Any GetItem(TDictObj *self, Any key) { return self->at(key); }
55+
inline static void SetItem(TDictObj *self, Any key, Any value) { (*self)[key] = value; }
56+
inline static Any DelItem(TDictObj *self, Any key) { return Erase(self, key); }
5557
inline static Any GetKey(TDictObj *self, int64_t i) { return IterStateMut{self, i}.At().first; }
5658
inline static Any GetValue(TDictObj *self, int64_t i) { return IterStateMut{self, i}.At().second; }
5759
inline static int64_t Advance(TDictObj *self, int64_t i) { return IterStateMut{self, i}.Add().i; }
@@ -298,17 +300,19 @@ inline DictBase::KVPair *DictBase::Accessor<TDictObj>::TryInsertOrLookup(TDictOb
298300
}
299301

300302
template <typename TDictObj> //
301-
inline void DictBase::Accessor<TDictObj>::Erase(TDictObj *self, const Any &key) {
303+
inline Any DictBase::Accessor<TDictObj>::Erase(TDictObj *self, const Any &key) {
302304
BlockIter iter = TSelf::Lookup(self, key);
303305
if (!iter.IsNone()) {
304-
TSelf::Erase(self, iter.i);
305-
} else {
306-
MLC_THROW(KeyError) << key;
306+
Any ret = static_cast<Any &>(iter.Data().second);
307+
TSelf::_Erase(self, iter.i);
308+
return ret;
307309
}
310+
MLC_THROW(KeyError) << key;
311+
MLC_UNREACHABLE();
308312
}
309313

310314
template <typename TDictObj> //
311-
inline void DictBase::Accessor<TDictObj>::Erase(TDictObj *self, int64_t index) {
315+
inline void DictBase::Accessor<TDictObj>::_Erase(TDictObj *self, int64_t index) {
312316
DictBase *self_base = static_cast<DictBase *>(self);
313317
BlockIter iter = BlockIter::FromIndex(self_base, index);
314318
if (uint64_t offset = iter.Offset(); offset != 0) {

include/mlc/core/list.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,24 @@ struct UList : public ObjectRef {
157157
.FieldReadOnly("data", &MLCList::data)
158158
.StaticFn("__init__", FromAnyTuple)
159159
.MemFn("__str__", &UListObj::__str__)
160-
.MemFn("__iter_at__", &::mlc::core::ListBase::Accessor<UListObj>::At);
160+
.MemFn("__iter_at__", &::mlc::core::ListBase::Accessor<UListObj>::At)
161+
.MemFn("_append", &UListObj::push_back)
162+
.MemFn("_insert", [](UListObj *self, int64_t i, Any data) { self->insert(i, data); })
163+
.MemFn("_extend",
164+
[](int32_t num_args, const AnyView *args, Any *) {
165+
if (!args[0].IsInstance<UListObj>()) {
166+
MLC_THROW(TypeError) << "First argument must be a list";
167+
}
168+
UListObj *self = args[0];
169+
self->insert(self->size(), args + 1, args + num_args);
170+
})
171+
.MemFn("_pop",
172+
[](UListObj *self, int64_t i) {
173+
Any ret = self->operator[](i);
174+
self->erase(i);
175+
return ret;
176+
})
177+
.MemFn("_clear", &UListObj::clear);
161178
};
162179

163180
template <typename T> struct ListObj : protected UListObj {

include/mlc/core/field_visitor.h renamed to include/mlc/core/visitor.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ namespace core {
1515
void ReportTypeFieldError(const char *type_key, MLCTypeField *field);
1616

1717
template <typename Visitor> inline void VisitFields(Object *root, MLCTypeInfo *info, Visitor &&visitor) {
18+
if (root == nullptr) {
19+
MLC_THROW(ValueError) << "Root is nullptr";
20+
}
21+
if (info == nullptr) {
22+
info = Lib::GetTypeInfo(root->GetTypeIndex());
23+
}
1824
MLCTypeField *fields = info->fields;
1925
MLCTypeField *field = fields;
2026
for (; field->name != nullptr; ++field) {
@@ -90,6 +96,12 @@ template <typename Visitor> inline void VisitFields(Object *root, MLCTypeInfo *i
9096
}
9197

9298
template <typename Visitor> inline void VisitStructure(Object *root, MLCTypeInfo *info, Visitor &&visitor) {
99+
if (root == nullptr) {
100+
MLC_THROW(ValueError) << "Root is nullptr";
101+
}
102+
if (info == nullptr) {
103+
info = Lib::GetTypeInfo(root->GetTypeIndex());
104+
}
93105
if (info->structure_kind == 0) {
94106
MLC_THROW(TypeError) << "Structure is not defined for type: " << info->type_key;
95107
}

include/mlc/printer/ast.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,11 @@ struct LiteralObj : public ::mlc::Object {
152152
}; // struct LiteralObj
153153

154154
struct Literal : public ::mlc::printer::Expr {
155-
static Literal Bool(bool value) { return Literal(mlc::List<ObjectPath>(), Any(value)); }
156-
static Literal Int(int64_t value) { return Literal(mlc::List<ObjectPath>(), Any(value)); }
157-
static Literal Str(mlc::Str value) { return Literal(mlc::List<ObjectPath>(), Any(value)); }
158-
static Literal Float(double value) { return Literal(mlc::List<ObjectPath>(), Any(value)); }
159-
static Literal Null() { return Literal(mlc::List<ObjectPath>(), Any()); }
155+
static Literal Bool(bool value, List<ObjectPath> source_paths = {}) { return Literal(source_paths, Any(value)); }
156+
static Literal Int(int64_t value, List<ObjectPath> source_paths = {}) { return Literal(source_paths, Any(value)); }
157+
static Literal Str(mlc::Str value, List<ObjectPath> source_paths = {}) { return Literal(source_paths, Any(value)); }
158+
static Literal Float(double value, List<ObjectPath> source_paths = {}) { return Literal(source_paths, Any(value)); }
159+
static Literal Null(List<ObjectPath> source_paths = {}) { return Literal(source_paths, Any()); }
160160

161161
MLC_DEF_OBJ_REF(MLC_EXPORTS, Literal, LiteralObj, ::mlc::printer::Expr)
162162
.Field("source_paths", &LiteralObj::source_paths)

include/mlc/printer/ir_printer.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,26 @@ struct IRPrinterObj : public Object {
108108
return (*it).second->creator();
109109
}
110110

111-
Any operator()(const Optional<ObjectRef> &opt_obj, const ObjectPath &path) const {
112-
if (!opt_obj.has_value()) {
113-
return Literal::Null();
111+
Any operator()(Any source, ObjectPath path) const {
112+
if (source.type_index == kMLCNone) {
113+
return Literal::Null({path});
114114
}
115-
Node ret = ::mlc::Lib::IRPrint(opt_obj.value(), this, path);
115+
if (source.type_index == kMLCBool) {
116+
return Literal::Bool(source.operator bool(), {path});
117+
}
118+
if (source.type_index == kMLCInt) {
119+
return Literal::Int(source.operator int64_t(), {path});
120+
}
121+
if (source.type_index == kMLCStr || source.type_index == kMLCRawStr) {
122+
return Literal::Str(source.operator Str(), {path});
123+
}
124+
if (source.type_index == kMLCFloat) {
125+
return Literal::Float(source.operator double(), {path});
126+
}
127+
if (source.type_index < kMLCStaticObjectBegin) {
128+
MLC_THROW(ValueError) << "Unsupported type: " << source;
129+
}
130+
Node ret = ::mlc::Lib::IRPrint(source.operator Object *(), this, path);
116131
ret->source_paths->push_back(path);
117132
return ret;
118133
}

pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
[project]
22
name = "mlc-python"
3-
version = "0.1.1"
3+
dynamic = ["version"]
44
dependencies = [
55
'numpy >= 1.22',
66
'ml-dtypes >= 0.1',
77
'Pygments>=2.4.0',
88
'colorama',
99
'setuptools ; platform_system == "Windows"',
1010
]
11-
description = ""
11+
description = "Python-first Development for AI Compilers"
1212
requires-python = ">=3.9"
1313
classifiers = [
1414
"Programming Language :: Python :: 3",
@@ -38,10 +38,14 @@ dev = [
3838
]
3939

4040
[build-system]
41-
requires = ["scikit-build-core>=0.9.8", "cython"]
41+
requires = ["scikit-build-core>=0.9.8", "cython", "setuptools-scm"]
4242
build-backend = "scikit_build_core.build"
4343

44+
[tool.setuptools_scm]
45+
version_file = "python/mlc/_version.py"
46+
4447
[tool.scikit-build]
48+
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
4549
build.targets = ["mlc_py", "mlc_registry"]
4650
build.verbose = true
4751
cmake.build-type = "RelWithDebInfo"

0 commit comments

Comments
 (0)