Skip to content

Commit 3f713bc

Browse files
author
Sherin Thomas
authored
Merge pull request #33 from RedisAI/dag
Dag implementation
2 parents 00c3c2a + 61d8c6e commit 3f713bc

File tree

5 files changed

+399
-101
lines changed

5 files changed

+399
-101
lines changed

redisai/client.py

Lines changed: 113 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from functools import wraps
2-
from typing import Union, AnyStr, ByteString, List, Sequence
1+
from functools import wraps, partial
2+
from typing import Union, AnyStr, ByteString, List, Sequence, Any
33
import warnings
44

55
from redis import StrictRedis
66
import numpy as np
77

88
from . import utils
9+
from .command_builder import Builder
10+
11+
12+
builder = Builder()
913

1014

1115
def enable_debug(f):
@@ -16,7 +20,69 @@ def wrapper(*args):
1620
return wrapper
1721

1822

19-
# TODO: typing to use AnyStr
23+
class Dag:
24+
def __init__(self, load, persist, executor, readonly=False):
25+
self.result_processors = []
26+
if readonly:
27+
if persist:
28+
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
29+
"have PERSISTing values")
30+
self.commands = ['AI.DAGRUN_RO']
31+
else:
32+
self.commands = ['AI.DAGRUN']
33+
if load:
34+
if not isinstance(load, (list, tuple)):
35+
self.commands += ["LOAD", 1, load]
36+
else:
37+
self.commands += ["LOAD", len(load), *load]
38+
if persist:
39+
if not isinstance(persist, (list, tuple)):
40+
self.commands += ["PERSIST", 1, persist, '|>']
41+
else:
42+
self.commands += ["PERSIST", len(persist), *persist, '|>']
43+
elif load:
44+
self.commands.append('|>')
45+
self.executor = executor
46+
47+
def tensorset(self,
48+
key: AnyStr,
49+
tensor: Union[np.ndarray, list, tuple],
50+
shape: Sequence[int] = None,
51+
dtype: str = None) -> Any:
52+
args = builder.tensorset(key, tensor, shape, dtype)
53+
self.commands.extend(args)
54+
self.commands.append("|>")
55+
self.result_processors.append(bytes.decode)
56+
return self
57+
58+
def tensorget(self,
59+
key: AnyStr, as_numpy: bool = True,
60+
meta_only: bool = False) -> Any:
61+
args = builder.tensorget(key, as_numpy, meta_only)
62+
self.commands.extend(args)
63+
self.commands.append("|>")
64+
self.result_processors.append(partial(utils.tensorget_postprocessor,
65+
as_numpy,
66+
meta_only))
67+
return self
68+
69+
def modelrun(self,
70+
name: AnyStr,
71+
inputs: Union[AnyStr, List[AnyStr]],
72+
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
73+
args = builder.modelrun(name, inputs, outputs)
74+
self.commands.extend(args)
75+
self.commands.append("|>")
76+
self.result_processors.append(bytes.decode)
77+
return self
78+
79+
def run(self):
80+
results = self.executor(*self.commands)
81+
out = []
82+
for res, fn in zip(results, self.result_processors):
83+
out.append(fn(res))
84+
return out
85+
2086

2187
class Client(StrictRedis):
2288
"""
@@ -27,6 +93,11 @@ def __init__(self, debug=False, *args, **kwargs):
2793
if debug:
2894
self.execute_command = enable_debug(super().execute_command)
2995

96+
def dag(self, load: Sequence = None, persist: Sequence = None,
97+
readonly: bool = False) -> Dag:
98+
""" Special function to return a dag object """
99+
return Dag(load, persist, self.execute_command, readonly)
100+
30101
def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
31102
"""
32103
RedisAI by default won't load any backends. User can either explicitly
@@ -37,7 +108,8 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
37108
:param path: Path to the shared object of the backend
38109
:return: byte string represents success or failure
39110
"""
40-
return self.execute_command('AI.CONFIG LOADBACKEND', identifier, path).decode()
111+
args = builder.loadbackend(identifier, path)
112+
return self.execute_command(*args).decode()
41113

42114
def modelset(self,
43115
name: AnyStr,
@@ -46,9 +118,9 @@ def modelset(self,
46118
data: ByteString,
47119
batch: int = None,
48120
minbatch: int = None,
49-
tag: str = None,
50-
inputs: List[AnyStr] = None,
51-
outputs: List[AnyStr] = None) -> str:
121+
tag: AnyStr = None,
122+
inputs: Union[AnyStr, List[AnyStr]] = None,
123+
outputs: Union[AnyStr, List[AnyStr]] = None) -> str:
52124
"""
53125
Set the model on provided key.
54126
:param name: str, Key name
@@ -66,50 +138,32 @@ def modelset(self,
66138
67139
:return:
68140
"""
69-
args = ['AI.MODELSET', name, backend, device]
70-
71-
if batch is not None:
72-
args += ['BATCHSIZE', batch]
73-
if minbatch is not None:
74-
args += ['MINBATCHSIZE', minbatch]
75-
if tag is not None:
76-
args += ['TAG', tag]
77-
78-
if backend.upper() == 'TF':
79-
if not(all((inputs, outputs))):
80-
raise ValueError(
81-
'Require keyword arguments input and output for TF models')
82-
args += ['INPUTS'] + utils.listify(inputs)
83-
args += ['OUTPUTS'] + utils.listify(outputs)
84-
args.append(data)
141+
args = builder.modelset(name, backend, device, data,
142+
batch, minbatch, tag, inputs, outputs)
85143
return self.execute_command(*args).decode()
86144

87145
def modelget(self, name: AnyStr, meta_only=False) -> dict:
88-
args = ['AI.MODELGET', name, 'META']
89-
if not meta_only:
90-
args.append('BLOB')
146+
args = builder.modelget(name, meta_only)
91147
rv = self.execute_command(*args)
92148
return utils.list2dict(rv)
93149

94150
def modeldel(self, name: AnyStr) -> str:
95-
return self.execute_command('AI.MODELDEL', name).decode()
151+
args = builder.modeldel(name)
152+
return self.execute_command(*args).decode()
96153

97154
def modelrun(self,
98155
name: AnyStr,
99-
inputs: List[AnyStr],
100-
outputs: List[AnyStr]
101-
) -> str:
102-
out = self.execute_command(
103-
'AI.MODELRUN', name,
104-
'INPUTS', *utils.listify(inputs),
105-
'OUTPUTS', *utils.listify(outputs)
106-
)
107-
return out.decode()
156+
inputs: Union[AnyStr, List[AnyStr]],
157+
outputs: Union[AnyStr, List[AnyStr]]) -> str:
158+
args = builder.modelrun(name, inputs, outputs)
159+
return self.execute_command(*args).decode()
108160

109161
def modelscan(self) -> list:
110162
warnings.warn("Experimental: Model List API is experimental and might change "
111163
"in the future without any notice", UserWarning)
112-
return utils.un_bytize(self.execute_command("AI._MODELSCAN"), lambda x: x.decode())
164+
args = builder.modelscan()
165+
result = self.execute_command(*args)
166+
return utils.recursive_bytetransform(result, lambda x: x.decode())
113167

114168
def tensorset(self,
115169
key: AnyStr,
@@ -123,20 +177,11 @@ def tensorset(self,
123177
:param shape: Shape of the tensor. Required if `tensor` is list or tuple
124178
:param dtype: data type of the tensor. Required if `tensor` is list or tuple
125179
"""
126-
if np and isinstance(tensor, np.ndarray):
127-
dtype, shape, blob = utils.numpy2blob(tensor)
128-
args = ['AI.TENSORSET', key, dtype, *shape, 'BLOB', blob]
129-
elif isinstance(tensor, (list, tuple)):
130-
if shape is None:
131-
shape = (len(tensor),)
132-
args = ['AI.TENSORSET', key, dtype, *shape, 'VALUES', *tensor]
133-
else:
134-
raise TypeError(f"``tensor`` argument must be a numpy array or a list or a "
135-
f"tuple, but got {type(tensor)}")
180+
args = builder.tensorset(key, tensor, shape, dtype)
136181
return self.execute_command(*args).decode()
137182

138183
def tensorget(self,
139-
key: str, as_numpy: bool = True,
184+
key: AnyStr, as_numpy: bool = True,
140185
meta_only: bool = False) -> Union[dict, np.ndarray]:
141186
"""
142187
Retrieve the value of a tensor from the server. By default it returns the numpy array
@@ -149,63 +194,45 @@ def tensorget(self,
149194
only the shape and the type
150195
:return: an instance of as_type
151196
"""
152-
args = ['AI.TENSORGET', key, 'META']
153-
if not meta_only:
154-
if as_numpy is True:
155-
args.append('BLOB')
156-
else:
157-
args.append('VALUES')
158-
197+
args = builder.tensorget(key, as_numpy, meta_only)
159198
res = self.execute_command(*args)
160-
res = utils.list2dict(res)
161-
if meta_only:
162-
return res
163-
elif as_numpy is True:
164-
return utils.blob2numpy(res['blob'], res['shape'], res['dtype'])
165-
else:
166-
target = float if res['dtype'] in ('FLOAT', 'DOUBLE') else int
167-
utils.un_bytize(res['values'], target)
168-
return res
169-
170-
def scriptset(self, name: str, device: str, script: str, tag: str = None) -> str:
171-
args = ['AI.SCRIPTSET', name, device]
172-
if tag:
173-
args += ['TAG', tag]
174-
args.append(script)
199+
return utils.tensorget_postprocessor(as_numpy, meta_only, res)
200+
201+
def scriptset(self, name: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
202+
args = builder.scriptset(name, device, script, tag)
175203
return self.execute_command(*args).decode()
176204

177205
def scriptget(self, name: AnyStr, meta_only=False) -> dict:
178206
# TODO scripget test
179-
args = ['AI.SCRIPTGET', name, 'META']
180-
if not meta_only:
181-
args.append('SOURCE')
207+
args = builder.scriptget(name, meta_only)
182208
ret = self.execute_command(*args)
183209
return utils.list2dict(ret)
184210

185-
def scriptdel(self, name: str) -> str:
186-
return self.execute_command('AI.SCRIPTDEL', name).decode()
211+
def scriptdel(self, name: AnyStr) -> str:
212+
args = builder.scriptdel(name)
213+
return self.execute_command(*args).decode()
187214

188215
def scriptrun(self,
189216
name: AnyStr,
190217
function: AnyStr,
191218
inputs: Union[AnyStr, Sequence[AnyStr]],
192219
outputs: Union[AnyStr, Sequence[AnyStr]]
193-
) -> AnyStr:
194-
out = self.execute_command(
195-
'AI.SCRIPTRUN', name, function,
196-
'INPUTS', *utils.listify(inputs),
197-
'OUTPUTS', *utils.listify(outputs)
198-
)
220+
) -> str:
221+
args = builder.scriptrun(name, function, inputs, outputs)
222+
out = self.execute_command(*args)
199223
return out.decode()
200224

201225
def scriptscan(self) -> list:
202226
warnings.warn("Experimental: Script List API is experimental and might change "
203227
"in the future without any notice", UserWarning)
204-
return utils.un_bytize(self.execute_command("AI._SCRIPTSCAN"), lambda x: x.decode())
228+
args = builder.scriptscan()
229+
return utils.recursive_bytetransform(self.execute_command(*args), lambda x: x.decode())
205230

206-
def infoget(self, key: str) -> dict:
207-
ret = self.execute_command('AI.INFO', key)
231+
def infoget(self, key: AnyStr) -> dict:
232+
args = builder.infoget(key)
233+
ret = self.execute_command(*args)
208234
return utils.list2dict(ret)
209235

210-
def inforeset(self, key: str) -> str:
211-
return self.execute_command('AI.INFO', key, 'RESETSTAT').decode()
236+
def inforeset(self, key: AnyStr) -> str:
237+
args = builder.inforeset(key)
238+
return self.execute_command(*args).decode()

0 commit comments

Comments
 (0)