@@ -3,11 +3,12 @@
 import warnings
 
 from redis import StrictRedis
-from redis.client import Pipeline as RedisPipeline
 import numpy as np
 
-from . import command_builder as builder
-from .postprocessor import Processor
+from redisai import command_builder as builder
+from redisai.dag import Dag
+from redisai.pipeline import Pipeline
+from redisai.postprocessor import Processor
 
 
 processor = Processor()
@@ -37,6 +38,9 @@ class Client(StrictRedis):
     >>> from redisai import Client
     >>> con = Client(host='localhost', port=6379)
     """
+    REDISAI_COMMANDS_RESPONSE_CALLBACKS = {
+    }
+
     def __init__(self, debug=False, enable_postprocess=True, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if debug:
@@ -575,115 +579,6 @@ def inforeset(self, key: AnyStr) -> str:
         return res if not self.enable_postprocess else processor.inforeset(res)
 
 
-class Pipeline(RedisPipeline, Client):
-    def __init__(self, enable_postprocess, *args, **kwargs):
-        warnings.warn("Pipeling AI commands through this client is experimental.",
-                      UserWarning)
-        self.enable_postprocess = False
-        if enable_postprocess:
-            warnings.warn("Postprocessing is enabled but not allowed in pipelines."
-                          "Disable postprocessing to remove this warning.", UserWarning)
-        self.tensorget_processors = []
-        super().__init__(*args, **kwargs)
-
-    def dag(self, *args, **kwargs):
-        raise RuntimeError("Pipeline object doesn't allow DAG creation currently")
-
-    def tensorget(self, key, as_numpy=True, as_numpy_mutable=False, meta_only=False):
-        self.tensorget_processors.append(partial(processor.tensorget,
-                                                 as_numpy=as_numpy,
-                                                 as_numpy_mutable=as_numpy_mutable,
-                                                 meta_only=meta_only))
-        return super().tensorget(key, as_numpy, as_numpy_mutable, meta_only)
-
-    def _execute_transaction(self, *args, **kwargs):
-        # TODO: Blocking commands like MODELRUN, SCRIPTRUN and DAGRUN won't work
-        res = super()._execute_transaction(*args, **kwargs)
-        for i in range(len(res)):
-            # tensorget will have minimum 4 values if meta_only = True
-            if isinstance(res[i], list) and len(res[i]) >= 4:
-                res[i] = self.tensorget_processors.pop(0)(res[i])
-        return res
-
-    def _execute_pipeline(self, *args, **kwargs):
-        res = super()._execute_pipeline(*args, **kwargs)
-        for i in range(len(res)):
-            # tensorget will have minimum 4 values if meta_only = True
-            if isinstance(res[i], list) and len(res[i]) >= 4:
-                res[i] = self.tensorget_processors.pop(0)(res[i])
-        return res
-
-
-class Dag:
-    def __init__(self, load, persist, executor, readonly=False, postprocess=True):
-        self.result_processors = []
-        self.enable_postprocess = True
-        if readonly:
-            if persist:
-                raise RuntimeError("READONLY requests cannot write (duh!) and should not "
-                                   "have PERSISTing values")
-            self.commands = ['AI.DAGRUN_RO']
-        else:
-            self.commands = ['AI.DAGRUN']
-        if load:
-            if not isinstance(load, (list, tuple)):
-                self.commands += ["LOAD", 1, load]
-            else:
-                self.commands += ["LOAD", len(load), *load]
-        if persist:
-            if not isinstance(persist, (list, tuple)):
-                self.commands += ["PERSIST", 1, persist, '|>']
-            else:
-                self.commands += ["PERSIST", len(persist), *persist, '|>']
-        else:
-            self.commands.append('|>')
-        self.executor = executor
-
-    def tensorset(self,
-                  key: AnyStr,
-                  tensor: Union[np.ndarray, list, tuple],
-                  shape: Sequence[int] = None,
-                  dtype: str = None) -> Any:
-        args = builder.tensorset(key, tensor, shape, dtype)
-        self.commands.extend(args)
-        self.commands.append("|>")
-        self.result_processors.append(bytes.decode)
-        return self
-
-    def tensorget(self,
-                  key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False,
-                  meta_only: bool = False) -> Any:
-        args = builder.tensorget(key, as_numpy, as_numpy_mutable)
-        self.commands.extend(args)
-        self.commands.append("|>")
-        self.result_processors.append(partial(processor.tensorget,
-                                              as_numpy=as_numpy,
-                                              as_numpy_mutable=as_numpy_mutable,
-                                              meta_only=meta_only))
-        return self
-
-    def modelrun(self,
-                 key: AnyStr,
-                 inputs: Union[AnyStr, List[AnyStr]],
-                 outputs: Union[AnyStr, List[AnyStr]]) -> Any:
-        args = builder.modelrun(key, inputs, outputs)
-        self.commands.extend(args)
-        self.commands.append("|>")
-        self.result_processors.append(bytes.decode)
-        return self
-
-    def run(self):
-        commands = self.commands[:-1]  # removing the last "|>
-        results = self.executor(*commands)
-        if self.enable_postprocess:
-            out = []
-            for res, fn in zip(results, self.result_processors):
-                out.append(fn(res))
-        else:
-            out = results
-        return out
-
-
 def enable_debug(f):
     @wraps(f)
     def wrapper(*args):
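
For context on the Dag builder that this commit moves into redisai.dag: it accumulates sub-commands separated by the '|>' token and submits them as a single AI.DAGRUN (or AI.DAGRUN_RO, when read-only) call through the executor it is constructed with. A minimal usage sketch follows, assuming Client.dag() forwards load, persist and readonly to this builder; the key names, model, and tensor values are illustrative only.

```python
import numpy as np
from redisai import Client

con = Client(host='localhost', port=6379)

# 'a' is assumed to already exist in the keyspace and 'mymodel' to have been
# stored beforehand; 'out' is written back to the keyspace when the DAG runs.
dag = con.dag(load=['a'], persist=['out'])
dag.tensorset('b', np.array([2.0, 3.0], dtype=np.float32))
dag.modelrun('mymodel', inputs=['a', 'b'], outputs=['out'])
dag.tensorget('out', as_numpy=True)

# run() drops the trailing '|>' separator, sends one AI.DAGRUN command, and
# applies the queued result processors: 'OK' strings for tensorset/modelrun,
# a numpy array for the final tensorget.
replies = dag.run()
out_tensor = replies[-1]
```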
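The Pipeline subclass moving to redisai.pipeline behaves like a regular redis-py pipeline: commands are queued client-side and sent in one round trip, with postprocessing disabled except for tensorget, whose replies are still converted by the stored processors. A rough sketch, assuming the client exposes a pipeline() factory that returns this class; keys and values are illustrative.

```python
import numpy as np
from redisai import Client

con = Client(host='localhost', port=6379)

# Nothing is sent to the server until execute() is called.
pipe = con.pipeline()
pipe.tensorset('x', np.array([1.0, 2.0], dtype=np.float32))
pipe.tensorget('x', as_numpy=True)

# One reply per queued command: the tensorset reply comes back raw, while the
# tensorget reply is turned back into a numpy array by the queued processor.
set_reply, x = pipe.execute()
```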