Skip to content

Commit 7479bbf

Browse files
author
Sherin Thomas
authored
Merge pull request #17 from RedisAI/dtype_issue
Dtype issue fix
2 parents fd69104 + 3f13408 commit 7479bbf

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

redisai/client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _convert_to_num(dt, arr):
5454
if isinstance(obj, list):
5555
_convert_to_num(obj)
5656
else:
57-
if dt in (DType.float.value, DType.double.value):
57+
if dt in (DType.float, DType.double):
5858
arr[ix] = float(obj)
5959
else:
6060
arr[ix] = int(obj)
@@ -159,6 +159,8 @@ def to_numpy(self):
159159

160160
@staticmethod
161161
def _to_numpy_type(t):
162+
if isinstance(t, DType):
163+
t = t.value
162164
mm = {
163165
'FLOAT': 'float32',
164166
'DOUBLE': 'float64'
@@ -237,10 +239,11 @@ def tensorget(self, key, as_type=Tensor, meta_only=False):
237239
argname = 'META' if meta_only else as_type.ARGNAME
238240
res = self.execute_command('AI.TENSORGET', key, argname)
239241
dtype, shape = to_string(res[0]), res[1]
242+
dt = DType.__members__[dtype.lower()]
240243
if meta_only:
241-
return as_type(dtype, shape, [])
244+
return as_type(dt, shape, [])
242245
else:
243-
return as_type.from_resp(dtype, shape, res[2])
246+
return as_type.from_resp(dt, shape, res[2])
244247

245248
def scriptset(self, name, device, script):
246249
return self.execute_command('AI.SCRIPTSET', name, device.value, script)

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
#!/usr/bin/env python
33
from setuptools import setup, find_packages
44

5-
exec(open('redisai/version.py').read())
5+
exec(open('redisai/version.py', encoding='utf-8').read())
6+
with open('README.md') as f:
7+
long_description = f.read()
8+
69

710
setup(
811
name='redisai',
912
version=__version__, # comes from redisai/version.py
10-
1113
description='RedisAI Python Client',
14+
long_description=long_description,
15+
long_description_content_type='text/markdown',
1216
url='http://github.com/RedisAI/redisai-py',
1317
author='RedisLabs',
1418
author_email='[email protected]',

test/test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ def test_set_tensor(self):
4040

4141
def test_numpy_tensor(self):
4242
con = self.get_client()
43-
con.tensorset('x', np.array([2, 3]))
43+
input_array = np.array([2, 3])
44+
con.tensorset('x', input_array)
4445
values = con.tensorget('x').value
4546
self.assertEqual([2, 3], values)
47+
values = con.tensorget('x', as_type=BlobTensor)
48+
self.assertTrue(np.allclose(input_array, values.to_numpy()))
49+
ret = con.tensorset('x', values)
50+
self.assertEqual(ret, b'OK')
4651

4752
def test_run_tf_model(self):
4853
model_path = os.path.join(MODEL_DIR, 'graph.pb')

0 commit comments

Comments
 (0)