7
7
import os
8
8
import pickle
9
9
import uuid
10
+ from collections .abc import Mapping
10
11
from decimal import Decimal
11
12
from pathlib import PosixPath , PurePosixPath
12
- from typing import Any
13
+ from typing import Any , Callable , Optional , Protocol
13
14
14
15
import msgpack as _msgpack
15
16
import temporenc
@@ -59,29 +60,45 @@ def _default(obj: object) -> _msgpack.ExtType:
59
60
raise TypeError (f"Unknown type: { obj !r} ({ type (obj )} )" )
60
61
61
62
62
- def _ext_hook (code : int , data : bytes ) -> Any :
63
- match code :
64
- case ExtTypes .UUID :
65
- return uuid .UUID (bytes = data )
66
- case ExtTypes .DATETIME :
67
- return temporenc .unpackb (data ).datetime ()
68
- case ExtTypes .DECIMAL :
69
- return pickle .loads (data )
70
- case ExtTypes .POSIX_PATH :
71
- return PosixPath (os .fsdecode (data ))
72
- case ExtTypes .PURE_POSIX_PATH :
73
- return PurePosixPath (os .fsdecode (data ))
74
- case ExtTypes .ENUM :
75
- return pickle .loads (data )
76
- case ExtTypes .RESOURCE_SLOT :
77
- return pickle .loads (data )
78
- case ExtTypes .BACKENDAI_BINARY_SIZE :
79
- return pickle .loads (data )
80
- case ExtTypes .IMAGE_REF :
81
- return pickle .loads (data )
82
- return _msgpack .ExtType (code , data )
63
+ class ExtFunc (Protocol ):
64
+ def __call__ (self , data : bytes ) -> Any :
65
+ pass
83
66
84
67
68
+ _DEFAULT_EXT_HOOK : Mapping [ExtTypes , ExtFunc ] = {
69
+ ExtTypes .UUID : lambda data : uuid .UUID (bytes = data ),
70
+ ExtTypes .DATETIME : lambda data : temporenc .unpackb (data ).datetime (),
71
+ ExtTypes .DECIMAL : lambda data : pickle .loads (data ),
72
+ ExtTypes .POSIX_PATH : lambda data : PosixPath (os .fsdecode (data )),
73
+ ExtTypes .PURE_POSIX_PATH : lambda data : PurePosixPath (os .fsdecode (data )),
74
+ ExtTypes .ENUM : lambda data : pickle .loads (data ),
75
+ ExtTypes .RESOURCE_SLOT : lambda data : pickle .loads (data ),
76
+ ExtTypes .BACKENDAI_BINARY_SIZE : lambda data : pickle .loads (data ),
77
+ ExtTypes .IMAGE_REF : lambda data : pickle .loads (data ),
78
+ }
79
+
80
+
81
+ class _Deserializer :
82
+ def __init__ (self , mapping : Optional [Mapping [int , ExtFunc ]] = None ):
83
+ self ._ext_hook : dict [int , ExtFunc ] = {}
84
+ mapping = mapping or {}
85
+ self ._ext_hook = {** mapping }
86
+ for ext_type , func in _DEFAULT_EXT_HOOK .items ():
87
+ if ext_type not in self ._ext_hook :
88
+ self ._ext_hook [ext_type ] = func
89
+
90
+ @property
91
+ def ext_hook (self ) -> Callable [[int , bytes ], Any ]:
92
+ def _hook_callable (code : int , data : bytes ) -> Any :
93
+ if code in self ._ext_hook :
94
+ return self ._ext_hook [code ](data )
95
+ return _msgpack .ExtType (code , data )
96
+
97
+ return _hook_callable
98
+
99
+
100
+ uuid_to_str : Mapping [int , ExtFunc ] = {ExtTypes .UUID : lambda data : str (uuid .UUID (bytes = data ))}
101
+
85
102
DEFAULT_PACK_OPTS = {
86
103
"use_bin_type" : True , # bytes -> bin type (default for Python 3)
87
104
"strict_types" : True , # do not serialize subclasses using superclasses
@@ -92,7 +109,7 @@ def _ext_hook(code: int, data: bytes) -> Any:
92
109
"raw" : False , # assume str as UTF-8 (default for Python 3)
93
110
"strict_map_key" : False , # allow using UUID as map keys
94
111
"use_list" : False , # array -> tuple
95
- "ext_hook" : _ext_hook ,
112
+ "ext_hook" : _Deserializer (). ext_hook ,
96
113
}
97
114
98
115
@@ -104,6 +121,10 @@ def packb(data: Any, **kwargs) -> bytes:
104
121
return ret
105
122
106
123
107
- def unpackb (packed : bytes , ** kwargs ) -> Any :
124
+ def unpackb (
125
+ packed : bytes , ext_hook_mapping : Optional [Mapping [int , ExtFunc ]] = None , ** kwargs
126
+ ) -> Any :
108
127
opts = {** DEFAULT_UNPACK_OPTS , ** kwargs }
128
+ if ext_hook_mapping is not None :
129
+ opts ["ext_hook" ] = _Deserializer (ext_hook_mapping ).ext_hook
109
130
return _msgpack .unpackb (packed , ** opts )
0 commit comments