Skip to content

Commit 0e25229

Browse files
author
Sherin Thomas
authored
Merge pull request #37 from RedisAI/pypifix
Quick Cleanup
2 parents 8cfd097 + 39d5492 commit 0e25229

File tree

8 files changed

+42
-17
lines changed

8 files changed

+42
-17
lines changed

.bumpversion.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 1.0.0
2+
current_version = 1.0.1
33
commit = True
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d+))?

docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
project = 'redisai-py'
22
copyright = '2020, RedisLabs'
33
author = 'RedisLabs'
4-
release = '1.0.0'
4+
release = '1.0.1'
55
extensions = ['sphinx.ext.autodoc',
66
'sphinx.ext.autosummary',
77
'sphinx.ext.extlinks',

redisai/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .client import Client
22

3-
__version__ = '1.0.0'
3+
__version__ = '1.0.1'

redisai/client.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def tensorget(self,
6262
self.commands.extend(args)
6363
self.commands.append("|>")
6464
self.result_processors.append(partial(utils.tensorget_postprocessor,
65-
as_numpy,
66-
meta_only))
65+
as_numpy=as_numpy,
66+
meta_only=meta_only))
6767
return self
6868

6969
def modelrun(self,
@@ -407,7 +407,7 @@ def tensorget(self,
407407
"""
408408
args = builder.tensorget(key, as_numpy, meta_only)
409409
res = self.execute_command(*args)
410-
return utils.tensorget_postprocessor(as_numpy, meta_only, res)
410+
return utils.tensorget_postprocessor(res, as_numpy, meta_only)
411411

412412
def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
413413
"""

redisai/command_builder.py

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ def modelset(self, name: AnyStr, backend: str, device: str, data: ByteString,
1515
batch: int, minbatch: int, tag: AnyStr,
1616
inputs: Union[AnyStr, List[AnyStr]],
1717
outputs: Union[AnyStr, List[AnyStr]]) -> Sequence:
18+
if device.upper() not in utils.allowed_devices:
19+
raise ValueError(f"Device not allowed. Use any from {utils.allowed_devices}")
20+
if backend.upper() not in utils.allowed_backends:
21+
raise ValueError(f"Backend not allowed. Use any from {utils.allowed_backends}")
1822
args = ['AI.MODELSET', name, backend, device]
1923

2024
if batch is not None:
@@ -87,6 +91,8 @@ def tensorget(self,
8791
return args
8892

8993
def scriptset(self, name: AnyStr, device: str, script: str, tag: AnyStr = None) -> Sequence:
94+
if device.upper() not in utils.allowed_devices:
95+
raise ValueError(f"Device not allowed. Use any from {utils.allowed_devices}")
9096
args = ['AI.SCRIPTSET', name, device]
9197
if tag:
9298
args += ['TAG', tag]

redisai/utils.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
'uint32': 'UINT32',
1717
'uint64': 'UINT64'}
1818

19+
allowed_devices = {'CPU', 'GPU'}
20+
allowed_backends = {'TF', 'TFLITE', 'TORCH', 'ONNX'}
21+
1922

2023
def numpy2blob(tensor: np.ndarray) -> tuple:
21-
""" Convert the numpy input from user to `Tensor` """
24+
"""Convert the numpy input from user to `Tensor`."""
2225
try:
2326
dtype = dtype_dict[str(tensor.dtype)]
2427
except KeyError:
@@ -29,7 +32,7 @@ def numpy2blob(tensor: np.ndarray) -> tuple:
2932

3033

3134
def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.ndarray:
32-
""" Convert `BLOB` result from RedisAI to `np.ndarray` """
35+
"""Convert `BLOB` result from RedisAI to `np.ndarray`."""
3336
mm = {
3437
'FLOAT': 'float32',
3538
'DOUBLE': 'float64'
@@ -40,6 +43,7 @@ def blob2numpy(value: ByteString, shape: Union[list, tuple], dtype: str) -> np.n
4043

4144

4245
def list2dict(lst):
46+
"""Convert the list from RedisAI to a dict."""
4347
if len(lst) % 2 != 0:
4448
raise RuntimeError("Can't unpack the list: {}".format(lst))
4549
out = {}
@@ -55,10 +59,8 @@ def list2dict(lst):
5559
def recursive_bytetransform(arr: List[AnyStr], target: Callable) -> list:
5660
"""
5761
Recurse value, replacing each element of b'' with the appropriate element.
58-
Function returns the same array after inplace operation which updates `arr`
5962
60-
:param target: Type of tensor | array
61-
:param arr: The array with b'' numbers or recursive array of b''
63+
Function returns the same array after inplace operation which updates `arr`
6264
"""
6365
for ix in range(len(arr)):
6466
obj = arr[ix]
@@ -70,10 +72,16 @@ def recursive_bytetransform(arr: List[AnyStr], target: Callable) -> list:
7072

7173

7274
def listify(inp: Union[str, Sequence[str]]) -> Sequence[str]:
75+
"""Wrap the ``inp`` with a list if it's not a list already."""
7376
return (inp,) if not isinstance(inp, (list, tuple)) else inp
7477

7578

76-
def tensorget_postprocessor(as_numpy, meta_only, rai_result):
79+
def tensorget_postprocessor(rai_result, as_numpy, meta_only):
80+
"""Process the tensorget output.
81+
82+
If ``as_numpy`` is True, it'll be converted to a numpy array. The required
83+
information such as datatype and shape must be in ``rai_result`` itself.
84+
"""
7785
rai_result = list2dict(rai_result)
7886
if meta_only:
7987
return rai_result

setup.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88
setup(
99
name='redisai',
10-
version='1.0.0',
10+
version='1.0.1',
1111
description='RedisAI Python Client',
1212
long_description=long_description,
13-
long_description_content_type='text/markdown',
13+
long_description_content_type='text/x-rst',
1414
url='http://github.com/RedisAI/redisai-py',
1515
author='RedisLabs',
1616
author_email='[email protected]',
1717
packages=find_packages(),
1818
install_requires=['redis', 'hiredis', 'numpy'],
19-
python_requires='>=3.2',
19+
python_requires='>=3.6',
2020
classifiers=[
2121
'Development Status :: 4 - Beta',
2222
'Intended Audience :: Developers',
@@ -26,7 +26,6 @@
2626
'Programming Language :: Python :: 3.6',
2727
'Programming Language :: Python :: 3.7',
2828
'Programming Language :: Python :: 3.8',
29-
'Topic :: Database',
30-
'Topic :: Software Development :: Testing'
29+
'Topic :: Database'
3130
]
3231
)

test/test.py

+12
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,18 @@ def test_numpy_tensor(self):
106106
with self.assertRaises(TypeError):
107107
con.tensorset('trying', stringarr)
108108

109+
def test_modelset_errors(self):
110+
model_path = os.path.join(MODEL_DIR, 'graph.pb')
111+
model_pb = load_model(model_path)
112+
con = self.get_client()
113+
with self.assertRaises(ValueError):
114+
con.modelset('m', 'tf', 'wrongdevice', model_pb,
115+
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
116+
with self.assertRaises(ValueError):
117+
con.modelset('m', 'wrongbackend', 'cpu', model_pb,
118+
inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
119+
120+
109121
def test_modelget_meta(self):
110122
model_path = os.path.join(MODEL_DIR, 'graph.pb')
111123
model_pb = load_model(model_path)

0 commit comments

Comments
 (0)