@@ -627,7 +627,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
627
627
}
628
628
// `task.visited` was `False`
629
629
int64_t task_index = static_cast <int64_t >(tasks.size ()) - 1 ;
630
- if (type_info-> type_index == kMLCList ) {
630
+ if (lhs-> IsInstance <UListObj>() ) {
631
631
UListObj *lhs_list = reinterpret_cast <UListObj *>(lhs);
632
632
UListObj *rhs_list = reinterpret_cast <UListObj *>(rhs);
633
633
int64_t lhs_size = lhs_list->size ();
@@ -639,7 +639,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
639
639
auto &err = tasks[task_index].err = std::make_unique<std::ostringstream>();
640
640
(*err) << " List length mismatch: " << lhs_size << " vs " << rhs_size;
641
641
}
642
- } else if (type_info-> type_index == kMLCDict ) {
642
+ } else if (lhs-> IsInstance <UDictObj>() ) {
643
643
UDictObj *lhs_dict = reinterpret_cast <UDictObj *>(lhs);
644
644
UDictObj *rhs_dict = reinterpret_cast <UDictObj *>(rhs);
645
645
std::vector<AnyView> not_found_lhs_keys;
@@ -892,13 +892,13 @@ inline uint64_t StructuralHashImpl(Object *obj) {
892
892
task.index_in_result_hashes = result_hashes.size ();
893
893
}
894
894
// `task.visited` was `False`
895
- if (type_info-> type_index == kMLCList ) {
895
+ if (obj-> IsInstance <UListObj>() ) {
896
896
UListObj *list = reinterpret_cast <UListObj *>(obj);
897
897
hash_value = HashCombine (hash_value, list->size ());
898
898
for (int64_t i = list->size () - 1 ; i >= 0 ; --i) {
899
899
Visitor::EnqueueAny (&tasks, bind_free_vars, &list->at (i));
900
900
}
901
- } else if (type_info-> type_index == kMLCDict ) {
901
+ } else if (obj-> IsInstance <UDictObj>() ) {
902
902
UDictObj *dict = reinterpret_cast <UDictObj *>(obj);
903
903
hash_value = HashCombine (hash_value, dict->size ());
904
904
struct KVPair {
@@ -1151,8 +1151,16 @@ inline Any CopyDeepImpl(AnyView source) {
1151
1151
} else if (object->IsInstance <StrObj>() || object->IsInstance <ErrorObj>() || object->IsInstance <FuncObj>() ||
1152
1152
object->IsInstance <TensorObj>()) {
1153
1153
ret = object;
1154
- } else if (object->IsInstance <OpaqueObj>()) {
1155
- MLC_THROW (TypeError) << " Cannot copy `mlc.Opaque` of type: " << object->DynCast <OpaqueObj>()->opaque_type_name ;
1154
+ } else if (OpaqueObj *opaque = object->as <OpaqueObj>()) {
1155
+ std::string func_name = " Opaque.deepcopy." ;
1156
+ func_name += opaque->opaque_type_name ;
1157
+ FuncObj *func = Func::GetGlobal (func_name.c_str (), true );
1158
+ if (func == nullptr ) {
1159
+ MLC_THROW (ValueError) << " Cannot deepcopy `mlc.Opaque` of type: " << opaque->opaque_type_name
1160
+ << " ; Use `mlc.Func.register(\" " << func_name
1161
+ << " \" )(deepcopy_func)` to register a deepcopy method" ;
1162
+ }
1163
+ ret = (*func)(object);
1156
1164
} else {
1157
1165
fields.clear ();
1158
1166
VisitFields (object, type_info, Copier{&orig2copy, &fields});
@@ -1627,8 +1635,8 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_o
1627
1635
MLC_THROW (ValueError) << " Invalid reference when parsing type `" << type_keys[json_type_index]
1628
1636
<< " `: referring #" << k << " at #" << i << " . v = " << value;
1629
1637
}
1630
- } else if (arg. type_index == kMLCList ) {
1631
- (*list)[j] = invoke_init (arg. operator UList ());
1638
+ } else if (UListObj *arg_list = arg. as <UListObj>() ) {
1639
+ (*list)[j] = invoke_init (UList (arg_list ));
1632
1640
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
1633
1641
arg.type_index == kMLCNone ) {
1634
1642
// Do nothing
0 commit comments