Skip to content

Commit b496bec

Browse files
committed
generalise pickle and msgpack to handle arrays with any dtype
1 parent 08e24b8 commit b496bec

File tree

5 files changed

+25
-19
lines changed

5 files changed

+25
-19
lines changed

numcodecs/msgpacks.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111

1212
class MsgPack(Codec):
13-
"""Codec to encode data as msgpacked bytes. Useful for encoding an array of Python strings.
13+
"""Codec to encode data as msgpacked bytes. Useful for encoding an array of Python string
14+
objects.
1415
1516
Examples
1617
--------
@@ -33,20 +34,27 @@ class MsgPack(Codec):
3334

3435
codec_id = 'msgpack'
3536

37+
def __init__(self, encoding='utf-8'):
38+
self.encoding = encoding
39+
3640
def encode(self, buf):
37-
buf = np.asarray(buf, dtype='object')
38-
return msgpack.packb(buf.tolist(), encoding='utf-8')
41+
buf = np.asarray(buf)
42+
l = buf.tolist()
43+
l.append(buf.dtype.str)
44+
return msgpack.packb(l, encoding=self.encoding)
3945

4046
def decode(self, buf, out=None):
41-
dec = np.array(msgpack.unpackb(buf, encoding='utf-8'), dtype='object')
47+
l = msgpack.unpackb(buf, encoding=self.encoding)
48+
dec = np.array(l[:-1], dtype=l[-1])
4249
if out is not None:
4350
np.copyto(out, dec)
4451
return out
4552
else:
4653
return dec
4754

4855
def get_config(self):
49-
return dict(id=self.codec_id)
56+
return dict(id=self.codec_id,
57+
encoding=self.encoding)
5058

5159
def __repr__(self):
52-
return 'MsgPack()'
60+
return 'MsgPack(encoding=%r)' % self.encoding

numcodecs/pickles.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515

1616
class Pickle(Codec):
17-
"""Codec to encode data as pickled bytes. Useful for encoding an array of Python strings.
17+
"""Codec to encode data as pickled bytes. Useful for encoding an array of Python string
18+
objects.
1819
1920
Parameters
2021
----------
@@ -42,7 +43,6 @@ def __init__(self, protocol=pickle.HIGHEST_PROTOCOL):
4243
self.protocol = protocol
4344

4445
def encode(self, buf):
45-
buf = np.asarray(buf, dtype='object')
4646
return pickle.dumps(buf, protocol=self.protocol)
4747

4848
def decode(self, buf, out=None):

numcodecs/tests/common.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,13 @@ def compare(res):
9191
compare(out)
9292

9393

94-
def check_encode_decode_objects(arr, codec):
95-
96-
# this is a more specific test than check_encode_decode
97-
# as these require actual objects (and not bytes only)
94+
def check_encode_decode_array(arr, codec):
9895

9996
def compare(res, arr=arr):
10097

10198
assert_true(isinstance(res, np.ndarray))
10299
assert_true(res.shape == arr.shape)
103-
assert_true(res.dtype == 'object')
100+
assert_true(res.dtype == arr.dtype)
104101

105102
# numpy asserts don't compare object arrays
106103
# properly; assert that we have the same nans
@@ -117,7 +114,7 @@ def compare(res, arr=arr):
117114
dec = codec.decode(enc)
118115
compare(dec)
119116

120-
out = np.empty_like(arr, dtype='object')
117+
out = np.empty_like(arr)
121118
codec.decode(enc, out=out)
122119
compare(out)
123120

numcodecs/tests/test_msgpacks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
except ImportError: # pragma: no cover
1111
raise nose.SkipTest("msgpack-python not available")
1212

13-
from numcodecs.tests.common import check_config, check_repr, check_encode_decode_objects
13+
from numcodecs.tests.common import check_config, check_repr, check_encode_decode_array
1414

1515

1616
# object array with strings
@@ -28,7 +28,7 @@
2828
def test_encode_decode():
2929
for arr in arrays:
3030
codec = MsgPack()
31-
check_encode_decode_objects(arr, codec)
31+
check_encode_decode_array(arr, codec)
3232

3333

3434
def test_config():
@@ -37,4 +37,5 @@ def test_config():
3737

3838

3939
def test_repr():
40-
check_repr("MsgPack()")
40+
check_repr("MsgPack(encoding='utf-8')")
41+
check_repr("MsgPack(encoding='ascii')")

numcodecs/tests/test_pickles.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
from numcodecs.pickles import Pickle
9-
from numcodecs.tests.common import check_config, check_repr, check_encode_decode_objects
9+
from numcodecs.tests.common import check_config, check_repr, check_encode_decode_array
1010

1111

1212
# object array with strings
@@ -24,7 +24,7 @@
2424
def test_encode_decode():
2525
codec = Pickle()
2626
for arr in arrays:
27-
check_encode_decode_objects(arr, codec)
27+
check_encode_decode_array(arr, codec)
2828

2929

3030
def test_config():

0 commit comments

Comments
 (0)