Skip to content

Commit 9e2bb88

Browse files
authored
fix(BA-934): Resolve GQL Agent live_stat field (#3928)
1 parent 1f77864 commit 9e2bb88

File tree

4 files changed

+58
-26
lines changed

4 files changed

+58
-26
lines changed

changes/3928.fix.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix GQL Agent `live_stat` resolver to properly parse UUID keys in JSON data as strings

src/ai/backend/common/msgpack.py

+45-24
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import os
88
import pickle
99
import uuid
10+
from collections.abc import Mapping
1011
from decimal import Decimal
1112
from pathlib import PosixPath, PurePosixPath
12-
from typing import Any
13+
from typing import Any, Callable, Optional, Protocol
1314

1415
import msgpack as _msgpack
1516
import temporenc
@@ -59,29 +60,45 @@ def _default(obj: object) -> _msgpack.ExtType:
5960
raise TypeError(f"Unknown type: {obj!r} ({type(obj)})")
6061

6162

62-
def _ext_hook(code: int, data: bytes) -> Any:
63-
match code:
64-
case ExtTypes.UUID:
65-
return uuid.UUID(bytes=data)
66-
case ExtTypes.DATETIME:
67-
return temporenc.unpackb(data).datetime()
68-
case ExtTypes.DECIMAL:
69-
return pickle.loads(data)
70-
case ExtTypes.POSIX_PATH:
71-
return PosixPath(os.fsdecode(data))
72-
case ExtTypes.PURE_POSIX_PATH:
73-
return PurePosixPath(os.fsdecode(data))
74-
case ExtTypes.ENUM:
75-
return pickle.loads(data)
76-
case ExtTypes.RESOURCE_SLOT:
77-
return pickle.loads(data)
78-
case ExtTypes.BACKENDAI_BINARY_SIZE:
79-
return pickle.loads(data)
80-
case ExtTypes.IMAGE_REF:
81-
return pickle.loads(data)
82-
return _msgpack.ExtType(code, data)
63+
class ExtFunc(Protocol):
64+
def __call__(self, data: bytes) -> Any:
65+
pass
8366

8467

68+
_DEFAULT_EXT_HOOK: Mapping[ExtTypes, ExtFunc] = {
69+
ExtTypes.UUID: lambda data: uuid.UUID(bytes=data),
70+
ExtTypes.DATETIME: lambda data: temporenc.unpackb(data).datetime(),
71+
ExtTypes.DECIMAL: lambda data: pickle.loads(data),
72+
ExtTypes.POSIX_PATH: lambda data: PosixPath(os.fsdecode(data)),
73+
ExtTypes.PURE_POSIX_PATH: lambda data: PurePosixPath(os.fsdecode(data)),
74+
ExtTypes.ENUM: lambda data: pickle.loads(data),
75+
ExtTypes.RESOURCE_SLOT: lambda data: pickle.loads(data),
76+
ExtTypes.BACKENDAI_BINARY_SIZE: lambda data: pickle.loads(data),
77+
ExtTypes.IMAGE_REF: lambda data: pickle.loads(data),
78+
}
79+
80+
81+
class _Deserializer:
82+
def __init__(self, mapping: Optional[Mapping[int, ExtFunc]] = None):
83+
self._ext_hook: dict[int, ExtFunc] = {}
84+
mapping = mapping or {}
85+
self._ext_hook = {**mapping}
86+
for ext_type, func in _DEFAULT_EXT_HOOK.items():
87+
if ext_type not in self._ext_hook:
88+
self._ext_hook[ext_type] = func
89+
90+
@property
91+
def ext_hook(self) -> Callable[[int, bytes], Any]:
92+
def _hook_callable(code: int, data: bytes) -> Any:
93+
if code in self._ext_hook:
94+
return self._ext_hook[code](data)
95+
return _msgpack.ExtType(code, data)
96+
97+
return _hook_callable
98+
99+
100+
uuid_to_str: Mapping[int, ExtFunc] = {ExtTypes.UUID: lambda data: str(uuid.UUID(bytes=data))}
101+
85102
DEFAULT_PACK_OPTS = {
86103
"use_bin_type": True, # bytes -> bin type (default for Python 3)
87104
"strict_types": True, # do not serialize subclasses using superclasses
@@ -92,7 +109,7 @@ def _ext_hook(code: int, data: bytes) -> Any:
92109
"raw": False, # assume str as UTF-8 (default for Python 3)
93110
"strict_map_key": False, # allow using UUID as map keys
94111
"use_list": False, # array -> tuple
95-
"ext_hook": _ext_hook,
112+
"ext_hook": _Deserializer().ext_hook,
96113
}
97114

98115

@@ -104,6 +121,10 @@ def packb(data: Any, **kwargs) -> bytes:
104121
return ret
105122

106123

107-
def unpackb(packed: bytes, **kwargs) -> Any:
124+
def unpackb(
125+
packed: bytes, ext_hook_mapping: Optional[Mapping[int, ExtFunc]] = None, **kwargs
126+
) -> Any:
108127
opts = {**DEFAULT_UNPACK_OPTS, **kwargs}
128+
if ext_hook_mapping is not None:
129+
opts["ext_hook"] = _Deserializer(ext_hook_mapping).ext_hook
109130
return _msgpack.unpackb(packed, **opts)

src/ai/backend/manager/models/gql_models/agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ async def _pipe_builder(r: Redis):
238238
ret = []
239239
for stat in await redis_helper.execute(ctx.redis_stat, _pipe_builder):
240240
if stat is not None:
241-
ret.append(msgpack.unpackb(stat))
241+
ret.append(msgpack.unpackb(stat, ext_hook_mapping=msgpack.uuid_to_str))
242242
else:
243243
ret.append(None)
244244

@@ -602,7 +602,7 @@ async def _pipe_builder(r: Redis):
602602
ret = []
603603
for stat in await redis_helper.execute(ctx.redis_stat, _pipe_builder):
604604
if stat is not None:
605-
ret.append(msgpack.unpackb(stat))
605+
ret.append(msgpack.unpackb(stat, ext_hook_mapping=msgpack.uuid_to_str))
606606
else:
607607
ret.append(None)
608608

tests/common/test_msgpack.py

+10
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ def test_msgpack_uuid_as_map_key():
5656
assert unpacked[device_id] == 1234
5757

5858

59+
def test_msgpack_uuid_to_str():
60+
device_id = uuid.uuid4()
61+
str_device_id = str(device_id)
62+
data = {device_id: 1234}
63+
packed = msgpack.packb(data)
64+
unpacked = msgpack.unpackb(packed, ext_hook_mapping=msgpack.uuid_to_str)
65+
assert isinstance(next(iter(unpacked.keys())), str)
66+
assert unpacked[str_device_id] == 1234
67+
68+
5969
def test_msgpack_datetime():
6070
now = datetime.now(tzutc())
6171
data = {"timestamp": now}

0 commit comments

Comments
 (0)