Skip to content

Commit 3da8c99

Browse files
committed
restructuring
1 parent f70af12 commit 3da8c99

File tree

5 files changed

+141
-113
lines changed

5 files changed

+141
-113
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ redisai.egg-info
88
build/
99
dist/
1010
docs/_build/
11+
.DS_Store

redisai/client.py

+7-112
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import warnings
44

55
from redis import StrictRedis
6-
from redis.client import Pipeline as RedisPipeline
76
import numpy as np
87

9-
from . import command_builder as builder
10-
from .postprocessor import Processor
8+
from redisai import command_builder as builder
9+
from redisai.dag import Dag
10+
from redisai.pipeline import Pipeline
11+
from redisai.postprocessor import Processor
1112

1213

1314
processor = Processor()
@@ -37,6 +38,9 @@ class Client(StrictRedis):
3738
>>> from redisai import Client
3839
>>> con = Client(host='localhost', port=6379)
3940
"""
41+
REDISAI_COMMANDS_RESPONSE_CALLBACKS = {
42+
}
43+
4044
def __init__(self, debug=False, enable_postprocess=True, *args, **kwargs):
4145
super().__init__(*args, **kwargs)
4246
if debug:
@@ -575,115 +579,6 @@ def inforeset(self, key: AnyStr) -> str:
575579
return res if not self.enable_postprocess else processor.inforeset(res)
576580

577581

578-
class Pipeline(RedisPipeline, Client):
579-
def __init__(self, enable_postprocess, *args, **kwargs):
580-
warnings.warn("Pipeling AI commands through this client is experimental.",
581-
UserWarning)
582-
self.enable_postprocess = False
583-
if enable_postprocess:
584-
warnings.warn("Postprocessing is enabled but not allowed in pipelines."
585-
"Disable postprocessing to remove this warning.", UserWarning)
586-
self.tensorget_processors = []
587-
super().__init__(*args, **kwargs)
588-
589-
def dag(self, *args, **kwargs):
590-
raise RuntimeError("Pipeline object doesn't allow DAG creation currently")
591-
592-
def tensorget(self, key, as_numpy=True, as_numpy_mutable=False, meta_only=False):
593-
self.tensorget_processors.append(partial(processor.tensorget,
594-
as_numpy=as_numpy,
595-
as_numpy_mutable=as_numpy_mutable,
596-
meta_only=meta_only))
597-
return super().tensorget(key, as_numpy, as_numpy_mutable, meta_only)
598-
599-
def _execute_transaction(self, *args, **kwargs):
600-
# TODO: Blocking commands like MODELRUN, SCRIPTRUN and DAGRUN won't work
601-
res = super()._execute_transaction(*args, **kwargs)
602-
for i in range(len(res)):
603-
# tensorget will have minimum 4 values if meta_only = True
604-
if isinstance(res[i], list) and len(res[i]) >= 4:
605-
res[i] = self.tensorget_processors.pop(0)(res[i])
606-
return res
607-
608-
def _execute_pipeline(self, *args, **kwargs):
609-
res = super()._execute_pipeline(*args, **kwargs)
610-
for i in range(len(res)):
611-
# tensorget will have minimum 4 values if meta_only = True
612-
if isinstance(res[i], list) and len(res[i]) >= 4:
613-
res[i] = self.tensorget_processors.pop(0)(res[i])
614-
return res
615-
616-
617-
class Dag:
618-
def __init__(self, load, persist, executor, readonly=False, postprocess=True):
619-
self.result_processors = []
620-
self.enable_postprocess = True
621-
if readonly:
622-
if persist:
623-
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
624-
"have PERSISTing values")
625-
self.commands = ['AI.DAGRUN_RO']
626-
else:
627-
self.commands = ['AI.DAGRUN']
628-
if load:
629-
if not isinstance(load, (list, tuple)):
630-
self.commands += ["LOAD", 1, load]
631-
else:
632-
self.commands += ["LOAD", len(load), *load]
633-
if persist:
634-
if not isinstance(persist, (list, tuple)):
635-
self.commands += ["PERSIST", 1, persist, '|>']
636-
else:
637-
self.commands += ["PERSIST", len(persist), *persist, '|>']
638-
else:
639-
self.commands.append('|>')
640-
self.executor = executor
641-
642-
def tensorset(self,
643-
key: AnyStr,
644-
tensor: Union[np.ndarray, list, tuple],
645-
shape: Sequence[int] = None,
646-
dtype: str = None) -> Any:
647-
args = builder.tensorset(key, tensor, shape, dtype)
648-
self.commands.extend(args)
649-
self.commands.append("|>")
650-
self.result_processors.append(bytes.decode)
651-
return self
652-
653-
def tensorget(self,
654-
key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False,
655-
meta_only: bool = False) -> Any:
656-
args = builder.tensorget(key, as_numpy, as_numpy_mutable)
657-
self.commands.extend(args)
658-
self.commands.append("|>")
659-
self.result_processors.append(partial(processor.tensorget,
660-
as_numpy=as_numpy,
661-
as_numpy_mutable=as_numpy_mutable,
662-
meta_only=meta_only))
663-
return self
664-
665-
def modelrun(self,
666-
key: AnyStr,
667-
inputs: Union[AnyStr, List[AnyStr]],
668-
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
669-
args = builder.modelrun(key, inputs, outputs)
670-
self.commands.extend(args)
671-
self.commands.append("|>")
672-
self.result_processors.append(bytes.decode)
673-
return self
674-
675-
def run(self):
676-
commands = self.commands[:-1] # removing the last "|>
677-
results = self.executor(*commands)
678-
if self.enable_postprocess:
679-
out = []
680-
for res, fn in zip(results, self.result_processors):
681-
out.append(fn(res))
682-
else:
683-
out = results
684-
return out
685-
686-
687582
def enable_debug(f):
688583
@wraps(f)
689584
def wrapper(*args):

redisai/dag.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from functools import partial
2+
from typing import AnyStr, Union, Sequence, Any, List
3+
4+
import numpy as np
5+
6+
from redisai.postprocessor import Processor
7+
from redisai import command_builder as builder
8+
9+
10+
processor = Processor()
11+
12+
13+
class Dag:
14+
def __init__(self, load, persist, executor, readonly=False, postprocess=True):
15+
self.result_processors = []
16+
self.enable_postprocess = True
17+
if readonly:
18+
if persist:
19+
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
20+
"have PERSISTing values")
21+
self.commands = ['AI.DAGRUN_RO']
22+
else:
23+
self.commands = ['AI.DAGRUN']
24+
if load:
25+
if not isinstance(load, (list, tuple)):
26+
self.commands += ["LOAD", 1, load]
27+
else:
28+
self.commands += ["LOAD", len(load), *load]
29+
if persist:
30+
if not isinstance(persist, (list, tuple)):
31+
self.commands += ["PERSIST", 1, persist, '|>']
32+
else:
33+
self.commands += ["PERSIST", len(persist), *persist, '|>']
34+
else:
35+
self.commands.append('|>')
36+
self.executor = executor
37+
38+
def tensorset(self,
39+
key: AnyStr,
40+
tensor: Union[np.ndarray, list, tuple],
41+
shape: Sequence[int] = None,
42+
dtype: str = None) -> Any:
43+
args = builder.tensorset(key, tensor, shape, dtype)
44+
self.commands.extend(args)
45+
self.commands.append("|>")
46+
self.result_processors.append(bytes.decode)
47+
return self
48+
49+
def tensorget(self,
50+
key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False,
51+
meta_only: bool = False) -> Any:
52+
args = builder.tensorget(key, as_numpy, as_numpy_mutable)
53+
self.commands.extend(args)
54+
self.commands.append("|>")
55+
self.result_processors.append(partial(processor.tensorget,
56+
as_numpy=as_numpy,
57+
as_numpy_mutable=as_numpy_mutable,
58+
meta_only=meta_only))
59+
return self
60+
61+
def modelrun(self,
62+
key: AnyStr,
63+
inputs: Union[AnyStr, List[AnyStr]],
64+
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
65+
args = builder.modelrun(key, inputs, outputs)
66+
self.commands.extend(args)
67+
self.commands.append("|>")
68+
self.result_processors.append(bytes.decode)
69+
return self
70+
71+
def run(self):
72+
commands = self.commands[:-1] # removing the last "|>
73+
results = self.executor(*commands)
74+
if self.enable_postprocess:
75+
out = []
76+
for res, fn in zip(results, self.result_processors):
77+
out.append(fn(res))
78+
else:
79+
out = results
80+
return out

redisai/pipeline.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import warnings
2+
from functools import partial
3+
from typing import AnyStr, Union, Sequence
4+
5+
import numpy as np
6+
7+
from redisai import command_builder as builder
8+
import redis
9+
from redisai.postprocessor import Processor
10+
11+
12+
processor = Processor()
13+
14+
15+
class Pipeline(redis.client.Pipeline):
16+
def __init__(self, enable_postprocess, *args, **kwargs):
17+
self.enable_postprocess = enable_postprocess
18+
self.tensorget_processors = []
19+
self.tensorset_processors = []
20+
super().__init__(*args, **kwargs)
21+
22+
def tensorget(self, key, as_numpy=True, as_numpy_mutable=False, meta_only=False):
23+
self.tensorget_processors.append(partial(processor.tensorget,
24+
as_numpy=as_numpy,
25+
as_numpy_mutable=as_numpy_mutable,
26+
meta_only=meta_only))
27+
args = builder.tensorget(key, as_numpy, meta_only)
28+
return self.execute_command(*args)
29+
30+
def tensorset(self, key: AnyStr,
31+
tensor: Union[np.ndarray, list, tuple],
32+
shape: Sequence[int] = None,
33+
dtype: str = None) -> str:
34+
args = builder.tensorset(key, tensor, shape, dtype)
35+
return self.execute_command(*args)
36+
37+
def _execute_transaction(self, *args, **kwargs):
38+
res = super()._execute_transaction(*args, **kwargs)
39+
for i in range(len(res)):
40+
# tensorget will have minimum 4 values if meta_only = True
41+
if isinstance(res[i], list) and len(res[i]) >= 4:
42+
res[i] = self.tensorget_processors.pop(0)(res[i])
43+
return res
44+
45+
def _execute_pipeline(self, *args, **kwargs):
46+
res = super()._execute_pipeline(*args, **kwargs)
47+
for i in range(len(res)):
48+
# tensorget will have minimum 4 values if meta_only = True
49+
if isinstance(res[i], list) and len(res[i]) >= 4:
50+
res[i] = self.tensorget_processors.pop(0)(res[i])
51+
return res

test/test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ def test_dagrun_with_load(self):
338338
self.assertEqual(expected, result)
339339
self.assertRaises(ResponseError, con.tensorget, 'b')
340340

341-
def test_dagrun_with_persist(self):
341+
def dagrun_with_persist(self):
342+
# TODO: disabling for now
342343
con = self.get_client()
343344

344345
with self.assertRaises(ResponseError):

0 commit comments

Comments
 (0)