Skip to content

Commit 4337876

Browse files
[Andrew Polyakov] Add support for NDArray in python package
1 parent e2cff8d commit 4337876

File tree

5 files changed

+100
-24
lines changed

5 files changed

+100
-24
lines changed

utbot-python-executor/src/main/python/utbot_executor/utbot_executor/deep_serialization/deep_serialization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def serialize_objects_dump(objs: List[Any], clear_visited: bool = False) -> Tupl
5050
serializer.write_object_to_memory(obj)
5151
for obj in objs
5252
]
53+
5354
return ids, serializer.memory, serialize_memory_dump(serializer.memory)
5455

5556

utbot-python-executor/src/main/python/utbot_executor/utbot_executor/deep_serialization/json_converter.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
ListMemoryObject,
1212
DictMemoryObject,
1313
ReduceMemoryObject,
14-
MemoryDump, IteratorMemoryObject,
14+
MemoryDump, IteratorMemoryObject, NdarrayMemoryObject,
1515
)
1616
from utbot_executor.deep_serialization.utils import PythonId, TypeInfo
17+
import numpy as np
1718

1819

1920
class MemoryObjectEncoder(json.JSONEncoder):
@@ -27,6 +28,10 @@ def default(self, o):
2728
}
2829
if isinstance(o, ReprMemoryObject):
2930
base_json["value"] = o.value
31+
elif isinstance(o, NdarrayMemoryObject):
32+
base_json["items"] = o.items
33+
base_json["comparable"] = True
34+
# raise AttributeError(base_json)
3035
elif isinstance(o, (ListMemoryObject, DictMemoryObject)):
3136
base_json["items"] = o.items
3237
elif isinstance(o, IteratorMemoryObject):
@@ -53,6 +58,9 @@ def default(self, o):
5358
"kind": o.kind,
5459
"module": o.module,
5560
}
61+
if isinstance(o, np.ndarray):
62+
# raise TypeError(f'Object {o.tolist()}, type: {type(o)}')
63+
return all(o.tolist())
5664
return json.JSONEncoder.default(self, o)
5765

5866

@@ -75,6 +83,16 @@ def as_reduce_object(dct: Dict) -> Union[MemoryObject, Dict]:
7583
)
7684
obj.comparable = dct["comparable"]
7785
return obj
86+
87+
if dct["strategy"] == "ndarray":
88+
obj = NdarrayMemoryObject.__new__(NdarrayMemoryObject)
89+
obj.items = dct["items"]
90+
obj.typeinfo = TypeInfo(
91+
kind=dct["typeinfo"]["kind"], module=dct["typeinfo"]["module"]
92+
)
93+
obj.comparable = dct["comparable"]
94+
return obj
95+
7896
if dct["strategy"] == "dict":
7997
obj = DictMemoryObject.__new__(DictMemoryObject)
8098
obj.items = dct["items"]
@@ -138,6 +156,10 @@ def reload_id(self) -> MemoryDump:
138156
new_memory_object.items = [
139157
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
140158
]
159+
elif isinstance(new_memory_object, NdarrayMemoryObject):
160+
new_memory_object.items = [
161+
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
162+
]
141163
elif isinstance(new_memory_object, IteratorMemoryObject):
142164
new_memory_object.items = [
143165
self.dump_id_to_real_id[id_] for id_ in new_memory_object.items
@@ -198,6 +220,17 @@ def load_object(self, python_id: PythonId) -> object:
198220

199221
for item in dump_object.items:
200222
real_object.append(self.load_object(item))
223+
elif isinstance(dump_object, NdarrayMemoryObject):
224+
# print(f"Hi", file=sys.stderr)
225+
real_object = []
226+
227+
id_ = PythonId(str(id(real_object)))
228+
self.dump_id_to_real_id[python_id] = id_
229+
self.memory[id_] = real_object
230+
231+
for item in dump_object.items:
232+
real_object = np.append(real_object, self.load_object(item))
233+
# real_object.append()
201234
elif isinstance(dump_object, DictMemoryObject):
202235
real_object = {}
203236

@@ -250,7 +283,7 @@ def load_object(self, python_id: PythonId) -> object:
250283
for key, dictitem in dictitems.items():
251284
real_object[key] = dictitem
252285
else:
253-
raise TypeError(f"Invalid type {dump_object}")
286+
raise TypeError(f"Invalid type {dump_object}, type: {type(dump_object)}")
254287

255288
id_ = PythonId(str(id(real_object)))
256289
self.dump_id_to_real_id[python_id] = id_
@@ -279,6 +312,10 @@ def main():
279312
"builtins.tuple",
280313
"builtins.bytes",
281314
"builtins.type",
315+
"numpy.ndarray"
282316
]
283317
)
284318
print(loader.load_object(PythonId("140239390887040")))
319+
320+
if __name__ == '__main__':
321+
main()

utbot-python-executor/src/main/python/utbot_executor/utbot_executor/deep_serialization/memory_objects.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing
99
from itertools import zip_longest
1010
from typing import Any, Callable, Dict, List, Optional, Set, Type, Iterable
11+
import numpy as np
1112

1213
from utbot_executor.deep_serialization.config import PICKLE_PROTO
1314
from utbot_executor.deep_serialization.iterator_wrapper import IteratorWrapper
@@ -41,7 +42,7 @@ def __init__(self, obj: object) -> None:
4142
self.id_ = PythonId(str(id(self.obj)))
4243

4344
def _initialize(
44-
self, deserialized_obj: object = None, comparable: bool = True
45+
self, deserialized_obj: object = None, comparable: bool = True
4546
) -> None:
4647
self.deserialized_obj = deserialized_obj
4748
self.comparable = comparable
@@ -115,6 +116,30 @@ def initialize(self) -> None:
115116

116117
super()._initialize(deserialized_obj, comparable)
117118

119+
class NdarrayMemoryObject(MemoryObject):
120+
strategy: str = "ndarray"
121+
items: List[PythonId] = []
122+
123+
def __init__(self, ndarray_object: object) -> None:
124+
self.items: List[PythonId] = []
125+
super().__init__(ndarray_object)
126+
127+
def initialize(self) -> None:
128+
serializer = PythonSerializer()
129+
self.deserialized_obj = [] # for recursive collections
130+
self.comparable = False # for recursive collections
131+
132+
for elem in self.obj:
133+
elem_id = serializer.write_object_to_memory(elem)
134+
self.items.append(elem_id)
135+
self.deserialized_obj.append(serializer[elem_id])
136+
137+
deserialized_obj = self.deserialized_obj
138+
comparable = all(serializer.get_by_id(elem).comparable for elem in self.items)
139+
# comparable = True
140+
141+
super()._initialize(deserialized_obj, comparable)
142+
118143
def __repr__(self) -> str:
119144
if hasattr(self, "obj"):
120145
return str(self.obj)
@@ -264,10 +289,10 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
264289

265290
is_reconstructor = constructor_kind.qualname == "copyreg._reconstructor"
266291
is_reduce_user_type = (
267-
len(self.reduce_value[1]) == 3
268-
and isinstance(self.reduce_value[1][0], type(self.obj))
269-
and self.reduce_value[1][1] is object
270-
and self.reduce_value[1][2] is None
292+
len(self.reduce_value[1]) == 3
293+
and isinstance(self.reduce_value[1][0], type(self.obj))
294+
and self.reduce_value[1][1] is object
295+
and self.reduce_value[1][2] is None
271296
)
272297
is_reduce_ex_user_type = len(self.reduce_value[1]) == 1 and isinstance(
273298
self.reduce_value[1][0], type(self.obj)
@@ -294,8 +319,8 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
294319
len(inspect.signature(init_method).parameters),
295320
)
296321
if (
297-
not init_from_object
298-
and len(inspect.signature(init_method).parameters) == 1
322+
not init_from_object
323+
and len(inspect.signature(init_method).parameters) == 1
299324
) or init_from_object:
300325
logging.debug("init with one argument! %s", init_method)
301326
constructor_arguments = []
@@ -317,9 +342,9 @@ def constructor_builder(self) -> typing.Tuple[typing.Any, typing.Callable]:
317342
if is_reconstructor and is_user_type:
318343
constructor_arguments = self.reduce_value[1]
319344
if (
320-
len(constructor_arguments) == 3
321-
and constructor_arguments[-1] is None
322-
and constructor_arguments[-2] == object
345+
len(constructor_arguments) == 3
346+
and constructor_arguments[-1] is None
347+
and constructor_arguments[-2] == object
323348
):
324349
del constructor_arguments[1:]
325350
callable_constructor = object.__new__
@@ -392,6 +417,12 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
392417
return ListMemoryObject
393418
return None
394419

420+
class NdarrayMemoryObjectProvider(MemoryObjectProvider):
421+
@staticmethod
422+
def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
423+
if type(obj) == np.ndarray:
424+
return NdarrayMemoryObject
425+
return None
395426

396427
class DictMemoryObjectProvider(MemoryObjectProvider):
397428
@staticmethod
@@ -425,6 +456,7 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
425456
return None
426457

427458

459+
428460
class ReprMemoryObjectProvider(MemoryObjectProvider):
429461
@staticmethod
430462
def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
@@ -456,6 +488,7 @@ class PythonSerializer:
456488
ReduceMemoryObjectProvider,
457489
ReprMemoryObjectProvider,
458490
ReduceExMemoryObjectProvider,
491+
NdarrayMemoryObjectProvider
459492
]
460493

461494
def __new__(cls):

utbot-python-executor/src/main/python/utbot_executor/utbot_executor/executor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545

4646
def _update_states(
47-
init_memory_dump: MemoryDump, before_memory_dump: MemoryDump
47+
init_memory_dump: MemoryDump, before_memory_dump: MemoryDump
4848
) -> MemoryDump:
4949
for id_, obj in before_memory_dump.objects.items():
5050
if id_ in init_memory_dump.objects:
@@ -246,9 +246,9 @@ def run_pickle_function(self, request: ExecutionRequest) -> ExecutionResponse:
246246

247247

248248
def _serialize_state(
249-
args: List[Any],
250-
kwargs: Dict[str, Any],
251-
result: Any = None,
249+
args: List[Any],
250+
kwargs: Dict[str, Any],
251+
result: Any = None,
252252
) -> Tuple[List[PythonId], Dict[str, PythonId], PythonId, MemoryDump, str]:
253253
"""Serialize objects from args, kwargs and result.
254254
@@ -267,13 +267,13 @@ def _serialize_state(
267267

268268

269269
def _run_calculate_function_value(
270-
function: types.FunctionType,
271-
args: List[Any],
272-
kwargs: Dict[str, Any],
273-
fullpath: str,
274-
state_init: str,
275-
tracer: UtTracer,
276-
state_assertions: bool,
270+
function: types.FunctionType,
271+
args: List[Any],
272+
kwargs: Dict[str, Any],
273+
fullpath: str,
274+
state_init: str,
275+
tracer: UtTracer,
276+
state_assertions: bool,
277277
) -> ExecutionResponse:
278278
"""Calculate function evaluation result.
279279

utbot-python-executor/src/main/python/utbot_executor/utbot_executor/memory_compressor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from utbot_executor.deep_serialization.memory_objects import MemoryDump
44
from utbot_executor.deep_serialization.utils import PythonId
55

6+
import numpy as np
7+
68

79
def compress_memory(
810
ids: typing.List[PythonId],
@@ -13,7 +15,10 @@ def compress_memory(
1315
for id_ in ids:
1416
if id_ in state_before.objects and id_ in state_after.objects:
1517
try:
16-
if state_before.objects[id_].obj != state_after.objects[id_].obj:
18+
if isinstance(state_before.objects[id_].obj, np.ndarray) or isinstance(state_after.objects[id_].obj, np.ndarray):
19+
if (state_before.objects[id_].obj != state_after.objects[id_].obj).all():
20+
diff_ids.append(id_)
21+
elif state_before.objects[id_].obj != state_after.objects[id_].obj:
1722
diff_ids.append(id_)
1823
except AttributeError as _:
1924
pass

0 commit comments

Comments
 (0)