1
1
from pathlib import Path
2
- from typing import Any , Dict , List , Optional , Type
2
+ from typing import Any , Dict , List , Optional , Type , Union
3
3
4
4
import redis .commands .search .reducers as reducers
5
5
import yaml
8
8
from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
9
9
from redis .exceptions import ResponseError
10
10
11
+ from redisvl .exceptions import RedisModuleVersionError
11
12
from redisvl .extensions .constants import ROUTE_VECTOR_FIELD_NAME
12
13
from redisvl .extensions .router .schema import (
13
14
DistanceAggregationMethod ,
17
18
SemanticRouterIndexSchema ,
18
19
)
19
20
from redisvl .index import SearchIndex
20
- from redisvl .query import VectorRangeQuery
21
+ from redisvl .query import FilterQuery , VectorRangeQuery
22
+ from redisvl .query .filter import Tag
23
+ from redisvl .redis .connection import RedisConnectionFactory
21
24
from redisvl .redis .utils import convert_bytes , hashify , make_dict
22
25
from redisvl .utils .log import get_logger
23
- from redisvl .utils .utils import deprecated_argument , model_to_dict
26
+ from redisvl .utils .utils import deprecated_argument , model_to_dict , scan_by_pattern
24
27
from redisvl .utils .vectorize .base import BaseVectorizer
25
28
from redisvl .utils .vectorize .text .huggingface import HFTextVectorizer
26
29
@@ -98,9 +101,41 @@ def __init__(
98
101
routes = routes ,
99
102
vectorizer = vectorizer ,
100
103
routing_config = routing_config ,
104
+ redis_url = redis_url ,
105
+ redis_client = redis_client ,
101
106
)
107
+
102
108
self ._initialize_index (redis_client , redis_url , overwrite , ** connection_kwargs )
103
109
110
+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
111
+
112
+ @classmethod
113
+ def from_existing (
114
+ cls ,
115
+ name : str ,
116
+ redis_client : Optional [Redis ] = None ,
117
+ redis_url : str = "redis://localhost:6379" ,
118
+ ** kwargs ,
119
+ ) -> "SemanticRouter" :
120
+ """Return SemanticRouter instance from existing index."""
121
+ try :
122
+ if redis_url :
123
+ redis_client = RedisConnectionFactory .get_redis_connection (
124
+ redis_url = redis_url ,
125
+ ** kwargs ,
126
+ )
127
+ elif redis_client :
128
+ RedisConnectionFactory .validate_sync_redis (redis_client )
129
+ except RedisModuleVersionError as e :
130
+ raise RedisModuleVersionError (
131
+ f"Loading from existing index failed. { str (e )} "
132
+ )
133
+
134
+ router_dict = redis_client .json ().get (f"{ name } :route_config" ) # type: ignore
135
+ return cls .from_dict (
136
+ router_dict , redis_url = redis_url , redis_client = redis_client
137
+ )
138
+
104
139
@deprecated_argument ("dtype" )
105
140
def _initialize_index (
106
141
self ,
@@ -111,9 +146,11 @@ def _initialize_index(
111
146
** connection_kwargs ,
112
147
):
113
148
"""Initialize the search index and handle Redis connection."""
149
+
114
150
schema = SemanticRouterIndexSchema .from_params (
115
151
self .name , self .vectorizer .dims , self .vectorizer .dtype # type: ignore
116
152
)
153
+
117
154
self ._index = SearchIndex (
118
155
schema = schema ,
119
156
redis_client = redis_client ,
@@ -174,10 +211,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
174
211
if route .name in route_thresholds :
175
212
route .distance_threshold = route_thresholds [route .name ] # type: ignore
176
213
177
- def _route_ref_key (self , route_name : str , reference : str ) -> str :
214
+ @staticmethod
215
+ def _route_ref_key (index : SearchIndex , route_name : str , reference_hash : str ) -> str :
178
216
"""Generate the route reference key."""
179
- reference_hash = hashify (reference )
180
- return f"{ self ._index .prefix } :{ route_name } :{ reference_hash } "
217
+ return f"{ index .prefix } :{ route_name } :{ reference_hash } "
181
218
182
219
def _add_routes (self , routes : List [Route ]):
183
220
"""Add routes to the router and index.
@@ -195,14 +232,18 @@ def _add_routes(self, routes: List[Route]):
195
232
)
196
233
# set route references
197
234
for i , reference in enumerate (route .references ):
235
+ reference_hash = hashify (reference )
198
236
route_references .append (
199
237
{
238
+ "reference_id" : reference_hash ,
200
239
"route_name" : route .name ,
201
240
"reference" : reference ,
202
241
"vector" : reference_vectors [i ],
203
242
}
204
243
)
205
- keys .append (self ._route_ref_key (route .name , reference ))
244
+ keys .append (
245
+ self ._route_ref_key (self ._index , route .name , reference_hash )
246
+ )
206
247
207
248
# set route if does not yet exist client side
208
249
if not self .get (route .name ):
@@ -438,7 +479,7 @@ def remove_route(self, route_name: str) -> None:
438
479
else :
439
480
self ._index .drop_keys (
440
481
[
441
- self ._route_ref_key (route .name , reference )
482
+ self ._route_ref_key (self . _index , route .name , hashify ( reference ) )
442
483
for reference in route .references
443
484
]
444
485
)
@@ -596,3 +637,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
596
637
with open (fp , "w" ) as f :
597
638
yaml_data = self .to_dict ()
598
639
yaml .dump (yaml_data , f , sort_keys = False )
640
+
641
+ # reference methods
642
+ def add_route_references (
643
+ self ,
644
+ route_name : str ,
645
+ references : Union [str , List [str ]],
646
+ ) -> List [str ]:
647
+ """Add a reference(s) to an existing route.
648
+
649
+ Args:
650
+ router_name (str): The name of the router.
651
+ references (Union[str, List[str]]): The reference or list of references to add.
652
+
653
+ Returns:
654
+ List[str]: The list of added references keys.
655
+ """
656
+
657
+ if isinstance (references , str ):
658
+ references = [references ]
659
+
660
+ route_references : List [Dict [str , Any ]] = []
661
+ keys : List [str ] = []
662
+
663
+ # embed route references as a single batch
664
+ reference_vectors = self .vectorizer .embed_many (references , as_buffer = True )
665
+
666
+ # set route references
667
+ for i , reference in enumerate (references ):
668
+ reference_hash = hashify (reference )
669
+
670
+ route_references .append (
671
+ {
672
+ "reference_id" : reference_hash ,
673
+ "route_name" : route_name ,
674
+ "reference" : reference ,
675
+ "vector" : reference_vectors [i ],
676
+ }
677
+ )
678
+ keys .append (self ._route_ref_key (self ._index , route_name , reference_hash ))
679
+
680
+ keys = self ._index .load (route_references , keys = keys )
681
+
682
+ route = self .get (route_name )
683
+ if not route :
684
+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
685
+ route .references .extend (references )
686
+ self ._update_router_state ()
687
+ return keys
688
+
689
+ @staticmethod
690
+ def _make_filter_queries (ids : List [str ]) -> List [FilterQuery ]:
691
+ """Create a filter query for the given ids."""
692
+
693
+ queries = []
694
+
695
+ for id in ids :
696
+ fe = Tag ("reference_id" ) == id
697
+ fq = FilterQuery (
698
+ return_fields = ["reference_id" , "route_name" , "reference" ],
699
+ filter_expression = fe ,
700
+ )
701
+ queries .append (fq )
702
+
703
+ return queries
704
+
705
+ def get_route_references (
706
+ self ,
707
+ route_name : str = "" ,
708
+ reference_ids : List [str ] = [],
709
+ keys : List [str ] = [],
710
+ ) -> List [Dict [str , Any ]]:
711
+ """Get references for an existing route route.
712
+
713
+ Args:
714
+ router_name (str): The name of the router.
715
+ references (Union[str, List[str]]): The reference or list of references to add.
716
+
717
+ Returns:
718
+ List[Dict[str, Any]]]: Reference objects stored
719
+ """
720
+
721
+ if reference_ids :
722
+ queries = self ._make_filter_queries (reference_ids )
723
+ elif route_name :
724
+ if not keys :
725
+ keys = scan_by_pattern (
726
+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
727
+ )
728
+
729
+ queries = self ._make_filter_queries (
730
+ [key .split (":" )[- 1 ] for key in convert_bytes (keys )]
731
+ )
732
+ else :
733
+ raise ValueError (
734
+ "Must provide a route name, reference ids, or keys to get references"
735
+ )
736
+
737
+ res = self ._index .batch_query (queries )
738
+
739
+ return [r [0 ] for r in res if len (r ) > 0 ]
740
+
741
+ def delete_route_references (
742
+ self ,
743
+ route_name : str = "" ,
744
+ reference_ids : List [str ] = [],
745
+ keys : List [str ] = [],
746
+ ) -> int :
747
+ """Get references for an existing semantic router route.
748
+
749
+ Args:
750
+ router_name Optional(str): The name of the router.
751
+ reference_ids Optional(List[str]]): The reference or list of references to delete.
752
+ keys Optional(List[str]]): List of fully qualified keys (prefix:router:reference_id) to delete.
753
+
754
+ Returns:
755
+ int: Number of objects deleted
756
+ """
757
+
758
+ if reference_ids and not keys :
759
+ queries = self ._make_filter_queries (reference_ids )
760
+ res = self ._index .batch_query (queries )
761
+ keys = [r [0 ]["id" ] for r in res if len (r ) > 0 ]
762
+ elif not keys :
763
+ keys = scan_by_pattern (
764
+ self ._index .client , f"{ self ._index .prefix } :{ route_name } :*" # type: ignore
765
+ )
766
+
767
+ if not keys :
768
+ raise ValueError (f"No references found for route { route_name } " )
769
+
770
+ to_be_deleted = []
771
+ for key in keys :
772
+ route_name = key .split (":" )[- 2 ]
773
+ to_be_deleted .append (
774
+ (route_name , convert_bytes (self ._index .client .hgetall (key ))) # type: ignore
775
+ )
776
+
777
+ deleted = self ._index .drop_keys (keys )
778
+
779
+ for route_name , delete in to_be_deleted :
780
+ route = self .get (route_name )
781
+ if not route :
782
+ raise ValueError (f"Route { route_name } not found in the SemanticRouter" )
783
+ route .references .remove (delete ["reference" ])
784
+
785
+ self ._update_router_state ()
786
+
787
+ return deleted
788
+
789
+ def _update_router_state (self ) -> None :
790
+ """Update the router configuration in Redis."""
791
+ self ._index .client .json ().set (f"{ self .name } :route_config" , f"." , self .to_dict ()) # type: ignore
0 commit comments