@@ -112,13 +112,15 @@ def initialize(self) -> None:
112
112
elif self .typeinfo .fullname == "builtins.set" :
113
113
deserialized_obj = set (deserialized_obj )
114
114
115
+
115
116
comparable = all (serializer .get_by_id (elem ).comparable for elem in self .items )
116
117
117
118
super ()._initialize (deserialized_obj , comparable )
118
119
119
120
class NdarrayMemoryObject (MemoryObject ):
120
121
strategy : str = "ndarray"
121
122
items : List [PythonId ] = []
123
+ dimensions : List [int ] = []
122
124
123
125
def __init__ (self , ndarray_object : object ) -> None :
124
126
self .items : List [PythonId ] = []
@@ -129,21 +131,23 @@ def initialize(self) -> None:
129
131
self .deserialized_obj = [] # for recursive collections
130
132
self .comparable = False # for recursive collections
131
133
132
- for elem in self .obj :
133
- elem_id = serializer .write_object_to_memory (elem )
134
- self .items .append (elem_id )
135
- self .deserialized_obj .append (serializer [elem_id ])
134
+ temp_object = self .obj .copy ().flatten ()
136
135
137
- deserialized_obj = self .deserialized_obj
138
- comparable = all (serializer .get_by_id (elem ).comparable for elem in self .items )
139
- # comparable = True
136
+ self .dimensions = self .obj .shape
137
+ if temp_object .shape != (0 , ):
138
+ for elem in temp_object :
139
+ elem_id = serializer .write_object_to_memory (elem )
140
+ self .items .append (elem_id )
141
+ self .deserialized_obj .append (serializer [elem_id ])
140
142
143
+ deserialized_obj = self .deserialized_obj
144
+ comparable = all (serializer .get_by_id (elem ).comparable for elem in self .items ) if self .deserialized_obj != [] else True
141
145
super ()._initialize (deserialized_obj , comparable )
142
146
143
147
def __repr__ (self ) -> str :
144
148
if hasattr (self , "obj" ):
145
149
return str (self .obj )
146
- return f"{ self .typeinfo .kind } { self .items } "
150
+ return f"{ self .typeinfo .kind } { self .items } { self . dimensions } "
147
151
148
152
149
153
class DictMemoryObject (MemoryObject ):
@@ -413,7 +417,7 @@ def get_serializer(obj: object) -> Optional[Type[MemoryObject]]:
413
417
class ListMemoryObjectProvider (MemoryObjectProvider ):
414
418
@staticmethod
415
419
def get_serializer (obj : object ) -> Optional [Type [MemoryObject ]]:
416
- if any (type (obj ) == t for t in (list , set , tuple , frozenset )):
420
+ if any (type (obj ) == t for t in (list , set , tuple , frozenset )) and type ( obj ) != np . ndarray :
417
421
return ListMemoryObject
418
422
return None
419
423
@@ -482,13 +486,13 @@ class PythonSerializer:
482
486
visited : Set [PythonId ] = set ()
483
487
484
488
providers : List [MemoryObjectProvider ] = [
489
+ NdarrayMemoryObjectProvider ,
485
490
ListMemoryObjectProvider ,
486
491
DictMemoryObjectProvider ,
487
492
IteratorMemoryObjectProvider ,
488
493
ReduceMemoryObjectProvider ,
489
494
ReprMemoryObjectProvider ,
490
495
ReduceExMemoryObjectProvider ,
491
- NdarrayMemoryObjectProvider
492
496
]
493
497
494
498
def __new__ (cls ):
0 commit comments