Skip to content

Commit 50c1fd4

Browse files
authored
feat(core): Support deepcopy in opaque (#76)
1 parent 45fae3e commit 50c1fd4

File tree

3 files changed

+54
-22
lines changed

3 files changed

+54
-22
lines changed

cpp/structure.cc

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
627627
}
628628
// `task.visited` was `False`
629629
int64_t task_index = static_cast<int64_t>(tasks.size()) - 1;
630-
if (type_info->type_index == kMLCList) {
630+
if (lhs->IsInstance<UListObj>()) {
631631
UListObj *lhs_list = reinterpret_cast<UListObj *>(lhs);
632632
UListObj *rhs_list = reinterpret_cast<UListObj *>(rhs);
633633
int64_t lhs_size = lhs_list->size();
@@ -639,7 +639,7 @@ inline void StructuralEqualImpl(Object *lhs, Object *rhs, bool bind_free_vars) {
639639
auto &err = tasks[task_index].err = std::make_unique<std::ostringstream>();
640640
(*err) << "List length mismatch: " << lhs_size << " vs " << rhs_size;
641641
}
642-
} else if (type_info->type_index == kMLCDict) {
642+
} else if (lhs->IsInstance<UDictObj>()) {
643643
UDictObj *lhs_dict = reinterpret_cast<UDictObj *>(lhs);
644644
UDictObj *rhs_dict = reinterpret_cast<UDictObj *>(rhs);
645645
std::vector<AnyView> not_found_lhs_keys;
@@ -892,13 +892,13 @@ inline uint64_t StructuralHashImpl(Object *obj) {
892892
task.index_in_result_hashes = result_hashes.size();
893893
}
894894
// `task.visited` was `False`
895-
if (type_info->type_index == kMLCList) {
895+
if (obj->IsInstance<UListObj>()) {
896896
UListObj *list = reinterpret_cast<UListObj *>(obj);
897897
hash_value = HashCombine(hash_value, list->size());
898898
for (int64_t i = list->size() - 1; i >= 0; --i) {
899899
Visitor::EnqueueAny(&tasks, bind_free_vars, &list->at(i));
900900
}
901-
} else if (type_info->type_index == kMLCDict) {
901+
} else if (obj->IsInstance<UDictObj>()) {
902902
UDictObj *dict = reinterpret_cast<UDictObj *>(obj);
903903
hash_value = HashCombine(hash_value, dict->size());
904904
struct KVPair {
@@ -1151,8 +1151,16 @@ inline Any CopyDeepImpl(AnyView source) {
11511151
} else if (object->IsInstance<StrObj>() || object->IsInstance<ErrorObj>() || object->IsInstance<FuncObj>() ||
11521152
object->IsInstance<TensorObj>()) {
11531153
ret = object;
1154-
} else if (object->IsInstance<OpaqueObj>()) {
1155-
MLC_THROW(TypeError) << "Cannot copy `mlc.Opaque` of type: " << object->DynCast<OpaqueObj>()->opaque_type_name;
1154+
} else if (OpaqueObj *opaque = object->as<OpaqueObj>()) {
1155+
std::string func_name = "Opaque.deepcopy.";
1156+
func_name += opaque->opaque_type_name;
1157+
FuncObj *func = Func::GetGlobal(func_name.c_str(), true);
1158+
if (func == nullptr) {
1159+
MLC_THROW(ValueError) << "Cannot deepcopy `mlc.Opaque` of type: " << opaque->opaque_type_name
1160+
<< "; Use `mlc.Func.register(\"" << func_name
1161+
<< "\")(deepcopy_func)` to register a deepcopy method";
1162+
}
1163+
ret = (*func)(object);
11561164
} else {
11571165
fields.clear();
11581166
VisitFields(object, type_info, Copier{&orig2copy, &fields});
@@ -1627,8 +1635,8 @@ inline Any Deserialize(const char *json_str, int64_t json_str_len, FuncObj *fn_o
16271635
MLC_THROW(ValueError) << "Invalid reference when parsing type `" << type_keys[json_type_index]
16281636
<< "`: referring #" << k << " at #" << i << ". v = " << value;
16291637
}
1630-
} else if (arg.type_index == kMLCList) {
1631-
(*list)[j] = invoke_init(arg.operator UList());
1638+
} else if (UListObj *arg_list = arg.as<UListObj>()) {
1639+
(*list)[j] = invoke_init(UList(arg_list));
16321640
} else if (arg.type_index == kMLCStr || arg.type_index == kMLCBool || arg.type_index == kMLCFloat ||
16331641
arg.type_index == kMLCNone) {
16341642
// Do nothing

python/mlc/core/opaque.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import copy
34
from collections.abc import Callable
4-
from typing import Any
5+
from typing import Any, Literal
56

67
from mlc._cython import Ptr, c_class_core, func_register, opaque_init, register_opauqe_type
78

@@ -18,17 +19,33 @@ def __init__(self, instance: Any) -> None:
1819
@staticmethod
1920
def register(
2021
ty: type,
21-
eq_s: Callable | None = None,
22-
hash_s: Callable | None = None,
22+
eq_s: Callable | Literal["default"] | None = "default",
23+
hash_s: Callable | Literal["default"] | None = "default",
24+
deepcopy: Callable | Literal["default"] | None = "default",
2325
) -> None:
2426
register_opauqe_type(ty)
2527
name = ty.__module__ + "." + ty.__name__
26-
if eq_s is not None:
27-
assert callable(eq_s)
28+
29+
if isinstance(eq_s, str) and eq_s == "default":
30+
func_register(f"Opaque.eq_s.{name}", False, lambda a, b: a == b)
31+
elif callable(eq_s):
2832
func_register(f"Opaque.eq_s.{name}", False, eq_s)
29-
if hash_s is not None:
30-
assert callable(hash_s)
33+
else:
34+
assert eq_s is None, "eq_s must be a callable, a literal 'default', or None"
35+
36+
if isinstance(hash_s, str) and hash_s == "default":
37+
func_register(f"Opaque.hash_s.{name}", False, lambda a: hash(a))
38+
elif callable(hash_s):
3139
func_register(f"Opaque.hash_s.{name}", False, hash_s)
40+
else:
41+
assert hash_s is None, "hash_s must be a callable, a literal 'default', or None"
42+
43+
if isinstance(deepcopy, str) and deepcopy == "default":
44+
func_register(f"Opaque.deepcopy.{name}", False, lambda a: copy.deepcopy(a))
45+
elif callable(deepcopy):
46+
func_register(f"Opaque.deepcopy.{name}", False, deepcopy)
47+
else:
48+
assert deepcopy is None, "deepcopy must be a callable, a literal 'default', or None"
3249

3350

3451
def _default_serialize(opaques: list[Any]) -> str:

tests/python/test_core_opaque.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import json
23
from typing import Any
34

@@ -17,19 +18,15 @@ def __init__(self, a: int) -> None:
1718
def __call__(self, x: int) -> int:
1819
return x + self.a
1920

20-
def eq_s(self, other: Any) -> bool:
21+
def __eq__(self, other: Any) -> bool:
2122
return isinstance(self, MyType) and isinstance(other, MyType) and self.a == other.a
2223

23-
def hash_s(self) -> int:
24+
def __hash__(self) -> int:
2425
assert isinstance(self, MyType)
2526
return hash((MyType, self.a))
2627

2728

28-
mlc.Opaque.register(
29-
MyType,
30-
eq_s=MyType.eq_s,
31-
hash_s=MyType.hash_s,
32-
)
29+
mlc.Opaque.register(MyType)
3330

3431

3532
@mlc.dataclasses.py_class(structure="bind")
@@ -124,3 +121,13 @@ def test_opaque_serialize_with_alias() -> None:
124121
assert obj_2.field[3].a == 30
125122
assert obj_2.field[4].a == 20
126123
assert obj_2.field[5].a == 10
124+
125+
126+
def test_opaque_deepcopy() -> None:
127+
a = MyType(a=10)
128+
obj_1 = Wrapper(field=a)
129+
obj_2 = copy.deepcopy(obj_1)
130+
assert isinstance(obj_2.field, MyType)
131+
assert obj_2.field.a == 10
132+
assert obj_1 is not obj_2
133+
assert obj_1.field is not obj_2.field

0 commit comments

Comments
 (0)