1
- from functools import wraps
2
- from typing import Union , AnyStr , ByteString , List , Sequence
1
+ from functools import wraps , partial
2
+ from typing import Union , AnyStr , ByteString , List , Sequence , Any
3
3
import warnings
4
4
5
5
from redis import StrictRedis
6
6
import numpy as np
7
7
8
8
from . import utils
9
+ from .command_builder import Builder
10
+
11
+
12
+ builder = Builder ()
9
13
10
14
11
15
def enable_debug (f ):
@@ -16,7 +20,69 @@ def wrapper(*args):
16
20
return wrapper
17
21
18
22
19
- # TODO: typing to use AnyStr
23
+ class Dag :
24
+ def __init__ (self , load , persist , executor , readonly = False ):
25
+ self .result_processors = []
26
+ if readonly :
27
+ if persist :
28
+ raise RuntimeError ("READONLY requests cannot write (duh!) and should not "
29
+ "have PERSISTing values" )
30
+ self .commands = ['AI.DAGRUN_RO' ]
31
+ else :
32
+ self .commands = ['AI.DAGRUN' ]
33
+ if load :
34
+ if not isinstance (load , (list , tuple )):
35
+ self .commands += ["LOAD" , 1 , load ]
36
+ else :
37
+ self .commands += ["LOAD" , len (load ), * load ]
38
+ if persist :
39
+ if not isinstance (persist , (list , tuple )):
40
+ self .commands += ["PERSIST" , 1 , persist , '|>' ]
41
+ else :
42
+ self .commands += ["PERSIST" , len (persist ), * persist , '|>' ]
43
+ elif load :
44
+ self .commands .append ('|>' )
45
+ self .executor = executor
46
+
47
+ def tensorset (self ,
48
+ key : AnyStr ,
49
+ tensor : Union [np .ndarray , list , tuple ],
50
+ shape : Sequence [int ] = None ,
51
+ dtype : str = None ) -> Any :
52
+ args = builder .tensorset (key , tensor , shape , dtype )
53
+ self .commands .extend (args )
54
+ self .commands .append ("|>" )
55
+ self .result_processors .append (bytes .decode )
56
+ return self
57
+
58
+ def tensorget (self ,
59
+ key : AnyStr , as_numpy : bool = True ,
60
+ meta_only : bool = False ) -> Any :
61
+ args = builder .tensorget (key , as_numpy , meta_only )
62
+ self .commands .extend (args )
63
+ self .commands .append ("|>" )
64
+ self .result_processors .append (partial (utils .tensorget_postprocessor ,
65
+ as_numpy ,
66
+ meta_only ))
67
+ return self
68
+
69
+ def modelrun (self ,
70
+ name : AnyStr ,
71
+ inputs : Union [AnyStr , List [AnyStr ]],
72
+ outputs : Union [AnyStr , List [AnyStr ]]) -> Any :
73
+ args = builder .modelrun (name , inputs , outputs )
74
+ self .commands .extend (args )
75
+ self .commands .append ("|>" )
76
+ self .result_processors .append (bytes .decode )
77
+ return self
78
+
79
+ def run (self ):
80
+ results = self .executor (* self .commands )
81
+ out = []
82
+ for res , fn in zip (results , self .result_processors ):
83
+ out .append (fn (res ))
84
+ return out
85
+
20
86
21
87
class Client (StrictRedis ):
22
88
"""
@@ -27,6 +93,11 @@ def __init__(self, debug=False, *args, **kwargs):
27
93
if debug :
28
94
self .execute_command = enable_debug (super ().execute_command )
29
95
96
+ def dag (self , load : Sequence = None , persist : Sequence = None ,
97
+ readonly : bool = False ) -> Dag :
98
+ """ Special function to return a dag object """
99
+ return Dag (load , persist , self .execute_command , readonly )
100
+
30
101
def loadbackend (self , identifier : AnyStr , path : AnyStr ) -> str :
31
102
"""
32
103
RedisAI by default won't load any backends. User can either explicitly
@@ -37,7 +108,8 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
37
108
:param path: Path to the shared object of the backend
38
109
:return: byte string represents success or failure
39
110
"""
40
- return self .execute_command ('AI.CONFIG LOADBACKEND' , identifier , path ).decode ()
111
+ args = builder .loadbackend (identifier , path )
112
+ return self .execute_command (* args ).decode ()
41
113
42
114
def modelset (self ,
43
115
name : AnyStr ,
@@ -46,9 +118,9 @@ def modelset(self,
46
118
data : ByteString ,
47
119
batch : int = None ,
48
120
minbatch : int = None ,
49
- tag : str = None ,
50
- inputs : List [AnyStr ] = None ,
51
- outputs : List [AnyStr ] = None ) -> str :
121
+ tag : AnyStr = None ,
122
+ inputs : Union [ AnyStr , List [AnyStr ] ] = None ,
123
+ outputs : Union [ AnyStr , List [AnyStr ] ] = None ) -> str :
52
124
"""
53
125
Set the model on provided key.
54
126
:param name: str, Key name
@@ -66,50 +138,32 @@ def modelset(self,
66
138
67
139
:return:
68
140
"""
69
- args = ['AI.MODELSET' , name , backend , device ]
70
-
71
- if batch is not None :
72
- args += ['BATCHSIZE' , batch ]
73
- if minbatch is not None :
74
- args += ['MINBATCHSIZE' , minbatch ]
75
- if tag is not None :
76
- args += ['TAG' , tag ]
77
-
78
- if backend .upper () == 'TF' :
79
- if not (all ((inputs , outputs ))):
80
- raise ValueError (
81
- 'Require keyword arguments input and output for TF models' )
82
- args += ['INPUTS' ] + utils .listify (inputs )
83
- args += ['OUTPUTS' ] + utils .listify (outputs )
84
- args .append (data )
141
+ args = builder .modelset (name , backend , device , data ,
142
+ batch , minbatch , tag , inputs , outputs )
85
143
return self .execute_command (* args ).decode ()
86
144
87
145
def modelget (self , name : AnyStr , meta_only = False ) -> dict :
88
- args = ['AI.MODELGET' , name , 'META' ]
89
- if not meta_only :
90
- args .append ('BLOB' )
146
+ args = builder .modelget (name , meta_only )
91
147
rv = self .execute_command (* args )
92
148
return utils .list2dict (rv )
93
149
94
150
def modeldel (self , name : AnyStr ) -> str :
95
- return self .execute_command ('AI.MODELDEL' , name ).decode ()
151
+ args = builder .modeldel (name )
152
+ return self .execute_command (* args ).decode ()
96
153
97
154
def modelrun (self ,
98
155
name : AnyStr ,
99
- inputs : List [AnyStr ],
100
- outputs : List [AnyStr ]
101
- ) -> str :
102
- out = self .execute_command (
103
- 'AI.MODELRUN' , name ,
104
- 'INPUTS' , * utils .listify (inputs ),
105
- 'OUTPUTS' , * utils .listify (outputs )
106
- )
107
- return out .decode ()
156
+ inputs : Union [AnyStr , List [AnyStr ]],
157
+ outputs : Union [AnyStr , List [AnyStr ]]) -> str :
158
+ args = builder .modelrun (name , inputs , outputs )
159
+ return self .execute_command (* args ).decode ()
108
160
109
161
def modelscan (self ) -> list :
110
162
warnings .warn ("Experimental: Model List API is experimental and might change "
111
163
"in the future without any notice" , UserWarning )
112
- return utils .un_bytize (self .execute_command ("AI._MODELSCAN" ), lambda x : x .decode ())
164
+ args = builder .modelscan ()
165
+ result = self .execute_command (* args )
166
+ return utils .recursive_bytetransform (result , lambda x : x .decode ())
113
167
114
168
def tensorset (self ,
115
169
key : AnyStr ,
@@ -123,20 +177,11 @@ def tensorset(self,
123
177
:param shape: Shape of the tensor. Required if `tensor` is list or tuple
124
178
:param dtype: data type of the tensor. Required if `tensor` is list or tuple
125
179
"""
126
- if np and isinstance (tensor , np .ndarray ):
127
- dtype , shape , blob = utils .numpy2blob (tensor )
128
- args = ['AI.TENSORSET' , key , dtype , * shape , 'BLOB' , blob ]
129
- elif isinstance (tensor , (list , tuple )):
130
- if shape is None :
131
- shape = (len (tensor ),)
132
- args = ['AI.TENSORSET' , key , dtype , * shape , 'VALUES' , * tensor ]
133
- else :
134
- raise TypeError (f"``tensor`` argument must be a numpy array or a list or a "
135
- f"tuple, but got { type (tensor )} " )
180
+ args = builder .tensorset (key , tensor , shape , dtype )
136
181
return self .execute_command (* args ).decode ()
137
182
138
183
def tensorget (self ,
139
- key : str , as_numpy : bool = True ,
184
+ key : AnyStr , as_numpy : bool = True ,
140
185
meta_only : bool = False ) -> Union [dict , np .ndarray ]:
141
186
"""
142
187
Retrieve the value of a tensor from the server. By default it returns the numpy array
@@ -149,63 +194,45 @@ def tensorget(self,
149
194
only the shape and the type
150
195
:return: an instance of as_type
151
196
"""
152
- args = ['AI.TENSORGET' , key , 'META' ]
153
- if not meta_only :
154
- if as_numpy is True :
155
- args .append ('BLOB' )
156
- else :
157
- args .append ('VALUES' )
158
-
197
+ args = builder .tensorget (key , as_numpy , meta_only )
159
198
res = self .execute_command (* args )
160
- res = utils .list2dict (res )
161
- if meta_only :
162
- return res
163
- elif as_numpy is True :
164
- return utils .blob2numpy (res ['blob' ], res ['shape' ], res ['dtype' ])
165
- else :
166
- target = float if res ['dtype' ] in ('FLOAT' , 'DOUBLE' ) else int
167
- utils .un_bytize (res ['values' ], target )
168
- return res
169
-
170
- def scriptset (self , name : str , device : str , script : str , tag : str = None ) -> str :
171
- args = ['AI.SCRIPTSET' , name , device ]
172
- if tag :
173
- args += ['TAG' , tag ]
174
- args .append (script )
199
+ return utils .tensorget_postprocessor (as_numpy , meta_only , res )
200
+
201
+ def scriptset (self , name : AnyStr , device : str , script : str , tag : AnyStr = None ) -> str :
202
+ args = builder .scriptset (name , device , script , tag )
175
203
return self .execute_command (* args ).decode ()
176
204
177
205
def scriptget (self , name : AnyStr , meta_only = False ) -> dict :
178
206
# TODO scripget test
179
- args = ['AI.SCRIPTGET' , name , 'META' ]
180
- if not meta_only :
181
- args .append ('SOURCE' )
207
+ args = builder .scriptget (name , meta_only )
182
208
ret = self .execute_command (* args )
183
209
return utils .list2dict (ret )
184
210
185
- def scriptdel (self , name : str ) -> str :
186
- return self .execute_command ('AI.SCRIPTDEL' , name ).decode ()
211
+ def scriptdel (self , name : AnyStr ) -> str :
212
+ args = builder .scriptdel (name )
213
+ return self .execute_command (* args ).decode ()
187
214
188
215
def scriptrun (self ,
189
216
name : AnyStr ,
190
217
function : AnyStr ,
191
218
inputs : Union [AnyStr , Sequence [AnyStr ]],
192
219
outputs : Union [AnyStr , Sequence [AnyStr ]]
193
- ) -> AnyStr :
194
- out = self .execute_command (
195
- 'AI.SCRIPTRUN' , name , function ,
196
- 'INPUTS' , * utils .listify (inputs ),
197
- 'OUTPUTS' , * utils .listify (outputs )
198
- )
220
+ ) -> str :
221
+ args = builder .scriptrun (name , function , inputs , outputs )
222
+ out = self .execute_command (* args )
199
223
return out .decode ()
200
224
201
225
def scriptscan (self ) -> list :
202
226
warnings .warn ("Experimental: Script List API is experimental and might change "
203
227
"in the future without any notice" , UserWarning )
204
- return utils .un_bytize (self .execute_command ("AI._SCRIPTSCAN" ), lambda x : x .decode ())
228
+ args = builder .scriptscan ()
229
+ return utils .recursive_bytetransform (self .execute_command (* args ), lambda x : x .decode ())
205
230
206
- def infoget (self , key : str ) -> dict :
207
- ret = self .execute_command ('AI.INFO' , key )
231
+ def infoget (self , key : AnyStr ) -> dict :
232
+ args = builder .infoget (key )
233
+ ret = self .execute_command (* args )
208
234
return utils .list2dict (ret )
209
235
210
- def inforeset (self , key : str ) -> str :
211
- return self .execute_command ('AI.INFO' , key , 'RESETSTAT' ).decode ()
236
+ def inforeset (self , key : AnyStr ) -> str :
237
+ args = builder .inforeset (key )
238
+ return self .execute_command (* args ).decode ()
0 commit comments