6
6
#include < cstring>
7
7
#include < iostream>
8
8
#include < memory>
9
- #include < mlc/ffi/ffi.hpp>
10
9
#include < string>
11
10
#include < unordered_map>
12
11
#include < vector>
13
12
14
13
namespace mlc {
15
- namespace ffi {
16
14
namespace registry {
17
15
18
16
struct TypeTable ;
@@ -31,6 +29,32 @@ struct TypeInfoWrapper {
31
29
~TypeInfoWrapper () { this ->Reset (); }
32
30
};
33
31
32
+ template <typename T> struct PODGetterSetter {
33
+ static int32_t Getter (void *addr, MLCAny *ret) {
34
+ MLC_SAFE_CALL_BEGIN ();
35
+ *static_cast <Any *>(ret) = *static_cast <T *>(addr);
36
+ MLC_SAFE_CALL_END (static_cast <Any *>(ret));
37
+ }
38
+ static int32_t Setter (void *addr, MLCAny *src) {
39
+ MLC_SAFE_CALL_BEGIN ();
40
+ *static_cast <T *>(addr) = (static_cast <Any *>(src))->operator T ();
41
+ MLC_SAFE_CALL_END (static_cast <Any *>(src));
42
+ }
43
+ };
44
+
45
+ template <> struct PODGetterSetter <std::nullptr_t > {
46
+ static int32_t Getter (void *, MLCAny *ret) {
47
+ MLC_SAFE_CALL_BEGIN ();
48
+ *static_cast <Any *>(ret) = nullptr ;
49
+ MLC_SAFE_CALL_END (static_cast <Any *>(ret));
50
+ }
51
+ static int32_t Setter (void *addr, MLCAny *src) {
52
+ MLC_SAFE_CALL_BEGIN ();
53
+ *static_cast <void **>(addr) = nullptr ;
54
+ MLC_SAFE_CALL_END (static_cast <Any *>(src));
55
+ }
56
+ };
57
+
34
58
struct TypeTable {
35
59
using ObjPtr = std::unique_ptr<MLCObject, void (*)(MLCObject *)>;
36
60
@@ -39,15 +63,15 @@ struct TypeTable {
39
63
std::unordered_map<std::string, MLCTypeInfo *> type_key_to_info;
40
64
std::unordered_map<std::string, std::unordered_map<int32_t , FuncObj *>> vtable;
41
65
std::unordered_map<std::string, FuncObj *> global_funcs;
42
- std::unordered_map<const void *, details ::PODArray> pool_pod_array;
66
+ std::unordered_map<const void *, ::mlc::base ::PODArray> pool_pod_array;
43
67
std::unordered_map<const void *, ObjPtr> pool_obj_ptr;
44
68
std::unordered_map<std::string, std::unique_ptr<DSOLibrary>> dso_library;
45
69
46
70
template <typename PODType> inline PODType *NewArray (int64_t size) {
47
71
if (size == 0 ) {
48
72
return nullptr ;
49
73
}
50
- details ::PODArray owned (static_cast <void *>(std::malloc (size * sizeof (PODType))), std::free);
74
+ ::mlc::base ::PODArray owned (static_cast <void *>(std::malloc (size * sizeof (PODType))), std::free);
51
75
PODType *ptr = reinterpret_cast <PODType *>(owned.get ());
52
76
auto [it, success] = this ->pool_pod_array .emplace (ptr, std::move (owned));
53
77
if (!success) {
@@ -80,8 +104,8 @@ struct TypeTable {
80
104
std::abort ();
81
105
}
82
106
MLCObject *source_casted = reinterpret_cast <MLCObject *>(source);
83
- ::mlc::ffi::details ::IncRef (source_casted);
84
- it->second = ObjPtr (source_casted, ::mlc::ffi::details ::DecRef);
107
+ ::mlc::base ::IncRef (source_casted);
108
+ it->second = ObjPtr (source_casted, ::mlc::base ::DecRef);
85
109
}
86
110
}
87
111
@@ -131,7 +155,8 @@ struct TypeTable {
131
155
return (self == nullptr ) ? TypeTable::Global () : static_cast <TypeTable *>(self);
132
156
}
133
157
134
- MLCTypeInfo *TypeRegister (int32_t parent_type_index, int32_t type_index, const char *type_key) {
158
+ MLCTypeInfo *TypeRegister (int32_t parent_type_index, int32_t type_index, const char *type_key,
159
+ MLCAttrGetterSetter getter, MLCAttrGetterSetter setter) {
135
160
// Step 1.Check if the type is already registered
136
161
if (auto it = this ->type_key_to_info .find (type_key); it != this ->type_key_to_info .end ()) {
137
162
MLCTypeInfo *ret = it->second ;
@@ -156,6 +181,8 @@ struct TypeTable {
156
181
info->type_index = type_index;
157
182
info->type_key = this ->NewArray (type_key);
158
183
info->type_depth = (parent == nullptr ) ? 0 : (parent->type_depth + 1 );
184
+ info->getter = getter;
185
+ info->setter = setter;
159
186
info->fields = nullptr ;
160
187
info->methods = nullptr ;
161
188
info->type_ancestors = this ->NewArray <int32_t >(info->type_depth );
@@ -209,24 +236,24 @@ struct TypeTable {
209
236
};
210
237
211
238
struct _POD_REG {
212
- inline static const int32_t _none = details ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCNone ))
213
- .Method(" __str__" , &PODTraits<std::nullptr_t >::__str__);
214
- inline static const int32_t _int = details ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCInt ))
215
- .Method(" __str__" , &PODTraits<int64_t >::__str__);
216
- inline static const int32_t _float = details ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCFloat ))
217
- .Method(" __str__" , &PODTraits<double >::__str__);
218
- inline static const int32_t _ptr = details ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCPtr ))
219
- .Method(" __str__" , &PODTraits<void *>::__str__);
239
+ inline static const int32_t _none = base ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCNone ))
240
+ .Method(" __str__" , &base:: PODTraits<std::nullptr_t >::__str__);
241
+ inline static const int32_t _int = base ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCInt ))
242
+ .Method(" __str__" , &base:: PODTraits<int64_t >::__str__);
243
+ inline static const int32_t _float = base ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCFloat ))
244
+ .Method(" __str__" , &base:: PODTraits<double >::__str__);
245
+ inline static const int32_t _ptr = base ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCPtr ))
246
+ .Method(" __str__" , &base:: PODTraits<void *>::__str__);
220
247
inline static const int32_t _device =
221
- details ::ReflectionHelper (static_cast <int32_t >(MLCTypeIndex::kMLCDevice ))
222
- .Method(" __str__" , &PODTraits<DLDevice>::__str__)
248
+ base ::ReflectionHelper (static_cast <int32_t >(MLCTypeIndex::kMLCDevice ))
249
+ .Method(" __str__" , &base:: PODTraits<DLDevice>::__str__)
223
250
.Method(" __init__" , [](AnyView device) { return device.operator DLDevice (); });
224
251
inline static const int32_t _dtype =
225
- details ::ReflectionHelper (static_cast <int32_t >(MLCTypeIndex::kMLCDataType ))
226
- .Method(" __str__" , &PODTraits<DLDataType>::__str__)
252
+ base ::ReflectionHelper (static_cast <int32_t >(MLCTypeIndex::kMLCDataType ))
253
+ .Method(" __str__" , &base:: PODTraits<DLDataType>::__str__)
227
254
.Method(" __init__" , [](AnyView dtype) { return dtype.operator DLDataType (); });
228
- inline static const int32_t _str = details ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCRawStr ))
229
- .Method(" __str__" , &PODTraits<const char *>::__str__);
255
+ inline static const int32_t _str = base ::ReflectionHelper(static_cast <int32_t >(MLCTypeIndex::kMLCRawStr ))
256
+ .Method(" __str__" , &base:: PODTraits<const char *>::__str__);
230
257
};
231
258
232
259
inline TypeTable *TypeTable::New () {
@@ -235,7 +262,8 @@ inline TypeTable *TypeTable::New() {
235
262
self->type_key_to_info .reserve (1024 );
236
263
self->num_types = static_cast <int32_t >(MLCTypeIndex::kMLCDynObjectBegin );
237
264
#define MLC_TYPE_TABLE_INIT_TYPE (TypeIndex, UnderlyingType, Self ) \
238
- Self->TypeRegister (-1 , static_cast <int32_t >(TypeIndex), PODTraits<UnderlyingType>::Type2Str ());
265
+ Self->TypeRegister (-1 , static_cast <int32_t >(TypeIndex), ::mlc::base::PODTraits<UnderlyingType>::Type2Str (), \
266
+ PODGetterSetter<UnderlyingType>::Getter, PODGetterSetter<UnderlyingType>::Setter);
239
267
240
268
MLC_TYPE_TABLE_INIT_TYPE (MLCTypeIndex::kMLCNone , std::nullptr_t , self);
241
269
MLC_TYPE_TABLE_INIT_TYPE (MLCTypeIndex::kMLCInt , int64_t , self);
@@ -315,7 +343,6 @@ inline void TypeInfoWrapper::SetMethods(int64_t new_num_methods, MLCTypeMethod *
315
343
}
316
344
317
345
} // namespace registry
318
- } // namespace ffi
319
346
} // namespace mlc
320
347
321
348
#endif // MLC_REGISTRY_H_
0 commit comments