|
3 | 3 | import warnings |
4 | 4 |
|
5 | 5 | from redis import StrictRedis |
6 | | -from redis.client import Pipeline as RedisPipeline |
7 | 6 | import numpy as np |
8 | 7 |
|
9 | | -from . import command_builder as builder |
10 | | -from .postprocessor import Processor |
| 8 | +from redisai import command_builder as builder |
| 9 | +from redisai.dag import Dag |
| 10 | +from redisai.pipeline import Pipeline |
| 11 | +from redisai.postprocessor import Processor |
11 | 12 |
|
12 | 13 |
|
13 | 14 | processor = Processor() |
@@ -37,6 +38,9 @@ class Client(StrictRedis): |
37 | 38 | >>> from redisai import Client |
38 | 39 | >>> con = Client(host='localhost', port=6379) |
39 | 40 | """ |
| 41 | + REDISAI_COMMANDS_RESPONSE_CALLBACKS = { |
| 42 | + } |
| 43 | + |
40 | 44 | def __init__(self, debug=False, enable_postprocess=True, *args, **kwargs): |
41 | 45 | super().__init__(*args, **kwargs) |
42 | 46 | if debug: |
@@ -575,115 +579,6 @@ def inforeset(self, key: AnyStr) -> str: |
575 | 579 | return res if not self.enable_postprocess else processor.inforeset(res) |
576 | 580 |
|
577 | 581 |
|
578 | | -class Pipeline(RedisPipeline, Client): |
579 | | - def __init__(self, enable_postprocess, *args, **kwargs): |
580 | | - warnings.warn("Pipeling AI commands through this client is experimental.", |
581 | | - UserWarning) |
582 | | - self.enable_postprocess = False |
583 | | - if enable_postprocess: |
584 | | - warnings.warn("Postprocessing is enabled but not allowed in pipelines." |
585 | | - "Disable postprocessing to remove this warning.", UserWarning) |
586 | | - self.tensorget_processors = [] |
587 | | - super().__init__(*args, **kwargs) |
588 | | - |
589 | | - def dag(self, *args, **kwargs): |
590 | | - raise RuntimeError("Pipeline object doesn't allow DAG creation currently") |
591 | | - |
592 | | - def tensorget(self, key, as_numpy=True, as_numpy_mutable=False, meta_only=False): |
593 | | - self.tensorget_processors.append(partial(processor.tensorget, |
594 | | - as_numpy=as_numpy, |
595 | | - as_numpy_mutable=as_numpy_mutable, |
596 | | - meta_only=meta_only)) |
597 | | - return super().tensorget(key, as_numpy, as_numpy_mutable, meta_only) |
598 | | - |
599 | | - def _execute_transaction(self, *args, **kwargs): |
600 | | - # TODO: Blocking commands like MODELRUN, SCRIPTRUN and DAGRUN won't work |
601 | | - res = super()._execute_transaction(*args, **kwargs) |
602 | | - for i in range(len(res)): |
603 | | - # tensorget will have minimum 4 values if meta_only = True |
604 | | - if isinstance(res[i], list) and len(res[i]) >= 4: |
605 | | - res[i] = self.tensorget_processors.pop(0)(res[i]) |
606 | | - return res |
607 | | - |
608 | | - def _execute_pipeline(self, *args, **kwargs): |
609 | | - res = super()._execute_pipeline(*args, **kwargs) |
610 | | - for i in range(len(res)): |
611 | | - # tensorget will have minimum 4 values if meta_only = True |
612 | | - if isinstance(res[i], list) and len(res[i]) >= 4: |
613 | | - res[i] = self.tensorget_processors.pop(0)(res[i]) |
614 | | - return res |
615 | | - |
616 | | - |
617 | | -class Dag: |
618 | | - def __init__(self, load, persist, executor, readonly=False, postprocess=True): |
619 | | - self.result_processors = [] |
620 | | - self.enable_postprocess = True |
621 | | - if readonly: |
622 | | - if persist: |
623 | | - raise RuntimeError("READONLY requests cannot write (duh!) and should not " |
624 | | - "have PERSISTing values") |
625 | | - self.commands = ['AI.DAGRUN_RO'] |
626 | | - else: |
627 | | - self.commands = ['AI.DAGRUN'] |
628 | | - if load: |
629 | | - if not isinstance(load, (list, tuple)): |
630 | | - self.commands += ["LOAD", 1, load] |
631 | | - else: |
632 | | - self.commands += ["LOAD", len(load), *load] |
633 | | - if persist: |
634 | | - if not isinstance(persist, (list, tuple)): |
635 | | - self.commands += ["PERSIST", 1, persist, '|>'] |
636 | | - else: |
637 | | - self.commands += ["PERSIST", len(persist), *persist, '|>'] |
638 | | - else: |
639 | | - self.commands.append('|>') |
640 | | - self.executor = executor |
641 | | - |
642 | | - def tensorset(self, |
643 | | - key: AnyStr, |
644 | | - tensor: Union[np.ndarray, list, tuple], |
645 | | - shape: Sequence[int] = None, |
646 | | - dtype: str = None) -> Any: |
647 | | - args = builder.tensorset(key, tensor, shape, dtype) |
648 | | - self.commands.extend(args) |
649 | | - self.commands.append("|>") |
650 | | - self.result_processors.append(bytes.decode) |
651 | | - return self |
652 | | - |
653 | | - def tensorget(self, |
654 | | - key: AnyStr, as_numpy: bool = True, as_numpy_mutable: bool = False, |
655 | | - meta_only: bool = False) -> Any: |
656 | | - args = builder.tensorget(key, as_numpy, as_numpy_mutable) |
657 | | - self.commands.extend(args) |
658 | | - self.commands.append("|>") |
659 | | - self.result_processors.append(partial(processor.tensorget, |
660 | | - as_numpy=as_numpy, |
661 | | - as_numpy_mutable=as_numpy_mutable, |
662 | | - meta_only=meta_only)) |
663 | | - return self |
664 | | - |
665 | | - def modelrun(self, |
666 | | - key: AnyStr, |
667 | | - inputs: Union[AnyStr, List[AnyStr]], |
668 | | - outputs: Union[AnyStr, List[AnyStr]]) -> Any: |
669 | | - args = builder.modelrun(key, inputs, outputs) |
670 | | - self.commands.extend(args) |
671 | | - self.commands.append("|>") |
672 | | - self.result_processors.append(bytes.decode) |
673 | | - return self |
674 | | - |
675 | | - def run(self): |
676 | | - commands = self.commands[:-1] # removing the last "|> |
677 | | - results = self.executor(*commands) |
678 | | - if self.enable_postprocess: |
679 | | - out = [] |
680 | | - for res, fn in zip(results, self.result_processors): |
681 | | - out.append(fn(res)) |
682 | | - else: |
683 | | - out = results |
684 | | - return out |
685 | | - |
686 | | - |
687 | 582 | def enable_debug(f): |
688 | 583 | @wraps(f) |
689 | 584 | def wrapper(*args): |
|
0 commit comments