Skip to content

Commit 00c3c2a

Browse files
author
Sherin Thomas
authored
Merge pull request #32 from RedisAI/cleanup
Consistency and clean up
2 parents e3331f9 + 7536b62 commit 00c3c2a

File tree

12 files changed

+312
-287
lines changed

12 files changed

+312
-287
lines changed

.circleci/config.yml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ version: 2
77
jobs:
88
build:
99
docker:
10-
- image: circleci/python:3.7.1
10+
- image: circleci/python:3.6.9
1111
- image: redisai/redisai:edge
1212

1313
working_directory: ~/repo
@@ -17,7 +17,7 @@ jobs:
1717

1818
- restore_cache: # Download and cache dependencies
1919
keys:
20-
- v1-dependencies-{{ checksum "requirements.txt" }}
20+
- v1-dependencies-{{ checksum "test-requirements.txt" }}
2121
# fallback to using the latest cache if no exact match is found
2222
- v1-dependencies-
2323

@@ -26,19 +26,17 @@ jobs:
2626
command: |
2727
virtualenv venv
2828
. venv/bin/activate
29-
pip install -r requirements.txt
29+
pip install -r test-requirements.txt
3030
3131
- save_cache:
3232
paths:
3333
- ./venv
34-
key: v1-dependencies-{{ checksum "requirements.txt" }}
34+
key: v1-dependencies-{{ checksum "test-requirements.txt" }}
3535

3636
- run:
3737
name: run tests
3838
command: |
3939
. venv/bin/activate
40-
pip install -r test-requirements.txt
41-
pip install nose codecov
4240
nosetests --with-coverage -vsx test
4341
codecov
4442
@@ -48,7 +46,7 @@ jobs:
4846

4947
build_nightly:
5048
docker:
51-
- image: circleci/python:3.7.1
49+
- image: circleci/python:3.6.9
5250
- image: redisai/redisai:edge
5351

5452
working_directory: ~/repo
@@ -58,7 +56,7 @@ jobs:
5856

5957
- restore_cache: # Download and cache dependencies
6058
keys:
61-
- v1-dependencies-{{ checksum "requirements.txt" }}
59+
- v1-dependencies-{{ checksum "test-requirements.txt" }}
6260
# fallback to using the latest cache if no exact match is found
6361
- v1-dependencies-
6462

@@ -67,19 +65,17 @@ jobs:
6765
command: |
6866
virtualenv venv
6967
. venv/bin/activate
70-
pip install -r requirements.txt
68+
pip install -r test-requirements.txt
7169
7270
- save_cache:
7371
paths:
7472
- ./venv
75-
key: v1-dependencies-{{ checksum "requirements.txt" }}
73+
key: v1-dependencies-{{ checksum "test-requirements.txt" }}
7674

7775
- run:
7876
name: run tests
7977
command: |
8078
. venv/bin/activate
81-
pip install -r test-requirements.txt
82-
pip install nose
8379
nosetests -vsx test
8480
8581
# no need for store_artifacts on nightly builds

redisai/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from .version import __version__
22
from .client import Client
3-
from .constants import DType, Device, Backend

redisai/client.py

Lines changed: 120 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
1-
from redis import StrictRedis
2-
from typing import Union, Any, AnyStr, ByteString, Sequence
3-
from .containers import Script, Model, Tensor
1+
from functools import wraps
2+
from typing import Union, AnyStr, ByteString, List, Sequence
43
import warnings
54

6-
try:
7-
import numpy as np
8-
except ImportError:
9-
np = None
5+
from redis import StrictRedis
6+
import numpy as np
7+
8+
from . import utils
9+
1010

11-
from .constants import Backend, Device, DType
12-
from .utils import str_or_strsequence, to_string, list_to_dict
13-
from . import convert
11+
def enable_debug(f):
12+
@wraps(f)
13+
def wrapper(*args):
14+
print(*args)
15+
return f(*args)
16+
return wrapper
1417

1518

19+
# TODO: typing to use AnyStr
20+
1621
class Client(StrictRedis):
1722
"""
1823
RedisAI client that can call Redis with RedisAI specific commands
1924
"""
20-
def loadbackend(self, identifier: AnyStr, path: AnyStr) -> AnyStr:
25+
def __init__(self, debug=False, *args, **kwargs):
26+
super().__init__(*args, **kwargs)
27+
if debug:
28+
self.execute_command = enable_debug(super().execute_command)
29+
30+
def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
2131
"""
2232
RedisAI by default won't load any backends. User can either explicitly
2333
load the backend by using this function or let RedisAI load the required
@@ -27,20 +37,36 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> AnyStr:
2737
:param path: Path to the shared object of the backend
2838
:return: byte string represents success or failure
2939
"""
30-
return self.execute_command('AI.CONFIG LOADBACKEND', identifier, path)
40+
return self.execute_command('AI.CONFIG LOADBACKEND', identifier, path).decode()
3141

3242
def modelset(self,
3343
name: AnyStr,
34-
backend: Backend,
35-
device: Device,
44+
backend: str,
45+
device: str,
3646
data: ByteString,
3747
batch: int = None,
3848
minbatch: int = None,
3949
tag: str = None,
40-
inputs: Union[AnyStr, Sequence[AnyStr]] = None,
41-
outputs: Union[AnyStr, Sequence[AnyStr]] = None
42-
) -> AnyStr:
43-
args = ['AI.MODELSET', name, backend.value, device.value]
50+
inputs: List[AnyStr] = None,
51+
outputs: List[AnyStr] = None) -> str:
52+
"""
53+
Set the model on provided key.
54+
:param name: str, Key name
55+
:param backend: str, Backend name. Allowed backends are TF, TORCH, TFLITE, ONNX
56+
:param device: str, Device name. Allowed devices are CPU and GPU
57+
:param data: bytes, Model graph read as bytestring
58+
:param batch: int, Number of batches for doing autobatching
59+
:param minbatch: int, Minimum number of samples required in a batch for model
60+
execution
61+
:param tag: str, Any string that will be saved in RedisAI as tags for the model
62+
:param inputs: list, List of strings that represents the input nodes in the graph.
63+
Required only Tensorflow graphs
64+
:param outputs: list, List of strings that represents the output nodes in the graph
65+
Required only for Tensorflow graphs
66+
67+
:return:
68+
"""
69+
args = ['AI.MODELSET', name, backend, device]
4470

4571
if batch is not None:
4672
args += ['BATCHSIZE', batch]
@@ -49,48 +75,47 @@ def modelset(self,
4975
if tag is not None:
5076
args += ['TAG', tag]
5177

52-
if backend == Backend.tf:
78+
if backend.upper() == 'TF':
5379
if not(all((inputs, outputs))):
5480
raise ValueError(
5581
'Require keyword arguments input and output for TF models')
56-
args += ['INPUTS'] + str_or_strsequence(inputs)
57-
args += ['OUTPUTS'] + str_or_strsequence(outputs)
58-
args += [data]
59-
return self.execute_command(*args)
60-
61-
def modelget(self, name: AnyStr, meta_only=False) -> Model:
62-
argname = 'META' if meta_only else 'BLOB'
63-
rv = self.execute_command('AI.MODELGET', name, argname)
64-
rv = list_to_dict(rv)
65-
return Model(
66-
rv.get('blob'),
67-
Device(rv['device']),
68-
Backend(rv['backend']),
69-
rv['tag'])
70-
71-
def modeldel(self, name: AnyStr) -> AnyStr:
72-
return self.execute_command('AI.MODELDEL', name)
82+
args += ['INPUTS'] + utils.listify(inputs)
83+
args += ['OUTPUTS'] + utils.listify(outputs)
84+
args.append(data)
85+
return self.execute_command(*args).decode()
86+
87+
def modelget(self, name: AnyStr, meta_only=False) -> dict:
88+
args = ['AI.MODELGET', name, 'META']
89+
if not meta_only:
90+
args.append('BLOB')
91+
rv = self.execute_command(*args)
92+
return utils.list2dict(rv)
93+
94+
def modeldel(self, name: AnyStr) -> str:
95+
return self.execute_command('AI.MODELDEL', name).decode()
7396

7497
def modelrun(self,
7598
name: AnyStr,
76-
inputs: Union[AnyStr, Sequence[AnyStr]],
77-
outputs: Union[AnyStr, Sequence[AnyStr]]
78-
) -> AnyStr:
79-
args = ['AI.MODELRUN', name]
80-
args += ['INPUTS'] + str_or_strsequence(inputs)
81-
args += ['OUTPUTS'] + str_or_strsequence(outputs)
82-
return self.execute_command(*args)
83-
84-
def modelist(self):
99+
inputs: List[AnyStr],
100+
outputs: List[AnyStr]
101+
) -> str:
102+
out = self.execute_command(
103+
'AI.MODELRUN', name,
104+
'INPUTS', *utils.listify(inputs),
105+
'OUTPUTS', *utils.listify(outputs)
106+
)
107+
return out.decode()
108+
109+
def modelscan(self) -> list:
85110
warnings.warn("Experimental: Model List API is experimental and might change "
86111
"in the future without any notice", UserWarning)
87-
return self.execute_command("AI._MODELLIST")
112+
return utils.un_bytize(self.execute_command("AI._MODELSCAN"), lambda x: x.decode())
88113

89114
def tensorset(self,
90115
key: AnyStr,
91116
tensor: Union[np.ndarray, list, tuple],
92117
shape: Sequence[int] = None,
93-
dtype: Union[DType, type] = None) -> Any:
118+
dtype: str = None) -> str:
94119
"""
95120
Set the values of the tensor on the server using the provided Tensor object
96121
:param key: The name of the tensor
@@ -99,20 +124,20 @@ def tensorset(self,
99124
:param dtype: data type of the tensor. Required if `tensor` is list or tuple
100125
"""
101126
if np and isinstance(tensor, np.ndarray):
102-
tensor = convert.from_numpy(tensor)
103-
args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, tensor.value]
127+
dtype, shape, blob = utils.numpy2blob(tensor)
128+
args = ['AI.TENSORSET', key, dtype, *shape, 'BLOB', blob]
104129
elif isinstance(tensor, (list, tuple)):
105130
if shape is None:
106131
shape = (len(tensor),)
107-
if not isinstance(dtype, DType):
108-
dtype = DType.__members__[np.dtype(dtype).name]
109-
tensor = convert.from_sequence(tensor, shape, dtype)
110-
args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, *tensor.value]
111-
return self.execute_command(*args)
132+
args = ['AI.TENSORSET', key, dtype, *shape, 'VALUES', *tensor]
133+
else:
134+
raise TypeError(f"``tensor`` argument must be a numpy array or a list or a "
135+
f"tuple, but got {type(tensor)}")
136+
return self.execute_command(*args).decode()
112137

113138
def tensorget(self,
114-
key: AnyStr, as_numpy: bool = True,
115-
meta_only: bool = False) -> Union[Tensor, np.ndarray]:
139+
key: str, as_numpy: bool = True,
140+
meta_only: bool = False) -> Union[dict, np.ndarray]:
116141
"""
117142
Retrieve the value of a tensor from the server. By default it returns the numpy array
118143
but it can be controlled using `as_type` argument and `meta_only` argument.
@@ -124,57 +149,63 @@ def tensorget(self,
124149
only the shape and the type
125150
:return: an instance of as_type
126151
"""
152+
args = ['AI.TENSORGET', key, 'META']
153+
if not meta_only:
154+
if as_numpy is True:
155+
args.append('BLOB')
156+
else:
157+
args.append('VALUES')
158+
159+
res = self.execute_command(*args)
160+
res = utils.list2dict(res)
127161
if meta_only:
128-
argname = 'META'
162+
return res
129163
elif as_numpy is True:
130-
argname = 'BLOB'
164+
return utils.blob2numpy(res['blob'], res['shape'], res['dtype'])
131165
else:
132-
argname = 'VALUES'
166+
target = float if res['dtype'] in ('FLOAT', 'DOUBLE') else int
167+
utils.un_bytize(res['values'], target)
168+
return res
133169

134-
res = self.execute_command('AI.TENSORGET', key, argname)
135-
dtype, shape = to_string(res[0]), res[1]
136-
if meta_only:
137-
return convert.to_sequence([], shape, dtype)
138-
if as_numpy is True:
139-
return convert.to_numpy(res[2], shape, dtype)
140-
else:
141-
return convert.to_sequence(res[2], shape, dtype)
142-
143-
def scriptset(self, name: AnyStr, device: Device, script: AnyStr, tag: str = None) -> AnyStr:
144-
args = ['AI.SCRIPTSET', name, device.value]
170+
def scriptset(self, name: str, device: str, script: str, tag: str = None) -> str:
171+
args = ['AI.SCRIPTSET', name, device]
145172
if tag:
146173
args += ['TAG', tag]
147-
args += [script]
148-
return self.execute_command(*args)
174+
args.append(script)
175+
return self.execute_command(*args).decode()
149176

150-
def scriptget(self, name: AnyStr) -> Script:
151-
ret = self.execute_command('AI.SCRIPTGET', name)
152-
ret = list_to_dict(ret)
153-
return Script(ret['source'], Device(ret['device']), ret['tag'])
177+
def scriptget(self, name: AnyStr, meta_only=False) -> dict:
178+
# TODO scripget test
179+
args = ['AI.SCRIPTGET', name, 'META']
180+
if not meta_only:
181+
args.append('SOURCE')
182+
ret = self.execute_command(*args)
183+
return utils.list2dict(ret)
154184

155-
def scriptdel(self, name):
156-
return self.execute_command('AI.SCRIPTDEL', name)
185+
def scriptdel(self, name: str) -> str:
186+
return self.execute_command('AI.SCRIPTDEL', name).decode()
157187

158188
def scriptrun(self,
159189
name: AnyStr,
160190
function: AnyStr,
161191
inputs: Union[AnyStr, Sequence[AnyStr]],
162192
outputs: Union[AnyStr, Sequence[AnyStr]]
163193
) -> AnyStr:
164-
args = ['AI.SCRIPTRUN', name, function, 'INPUTS']
165-
args += str_or_strsequence(inputs)
166-
args += ['OUTPUTS']
167-
args += str_or_strsequence(outputs)
168-
return self.execute_command(*args)
169-
170-
def scriptlist(self):
194+
out = self.execute_command(
195+
'AI.SCRIPTRUN', name, function,
196+
'INPUTS', *utils.listify(inputs),
197+
'OUTPUTS', *utils.listify(outputs)
198+
)
199+
return out.decode()
200+
201+
def scriptscan(self) -> list:
171202
warnings.warn("Experimental: Script List API is experimental and might change "
172203
"in the future without any notice", UserWarning)
173-
return self.execute_command("AI._SCRIPTLIST")
204+
return utils.un_bytize(self.execute_command("AI._SCRIPTSCAN"), lambda x: x.decode())
174205

175206
def infoget(self, key: str) -> dict:
176207
ret = self.execute_command('AI.INFO', key)
177-
return list_to_dict(ret)
208+
return utils.list2dict(ret)
178209

179-
def inforeset(self, key: str) -> dict:
180-
return self.execute_command('AI.INFO', key, 'RESETSTAT')
210+
def inforeset(self, key: str) -> str:
211+
return self.execute_command('AI.INFO', key, 'RESETSTAT').decode()

0 commit comments

Comments
 (0)