Skip to content

Commit 184f521

Browse files
authored
Enable semantic router reference updates (#322)
This PR adds methods so that you can easily get, add, and delete route references.
1 parent dc657d2 commit 184f521

File tree

11 files changed

+555
-117
lines changed

11 files changed

+555
-117
lines changed

docs/user_guide/08_semantic_router.ipynb

+214-92
Large diffs are not rendered by default.

docs/user_guide/router.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ routes:
88
metadata:
99
category: tech
1010
priority: 1
11-
distance_threshold: 1.0
11+
distance_threshold: 0.71
1212
- name: sports
1313
references:
1414
- who won the game last night?
@@ -19,7 +19,7 @@ routes:
1919
metadata:
2020
category: sports
2121
priority: 2
22-
distance_threshold: 0.5
22+
distance_threshold: 0.72
2323
- name: entertainment
2424
references:
2525
- what are the top movies right now?

redisvl/extensions/router/schema.py

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str):
100100
return cls(
101101
index={"name": name, "prefix": name}, # type: ignore
102102
fields=[ # type: ignore
103+
{"name": "reference_id", "type": "tag"},
103104
{"name": "route_name", "type": "tag"},
104105
{"name": "reference", "type": "text"},
105106
{

redisvl/extensions/router/semantic.py

+201-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Any, Dict, List, Optional, Type
2+
from typing import Any, Dict, List, Optional, Type, Union
33

44
import redis.commands.search.reducers as reducers
55
import yaml
@@ -8,6 +8,7 @@
88
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
99
from redis.exceptions import ResponseError
1010

11+
from redisvl.exceptions import RedisModuleVersionError
1112
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
1213
from redisvl.extensions.router.schema import (
1314
DistanceAggregationMethod,
@@ -17,10 +18,12 @@
1718
SemanticRouterIndexSchema,
1819
)
1920
from redisvl.index import SearchIndex
20-
from redisvl.query import VectorRangeQuery
21+
from redisvl.query import FilterQuery, VectorRangeQuery
22+
from redisvl.query.filter import Tag
23+
from redisvl.redis.connection import RedisConnectionFactory
2124
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2225
from redisvl.utils.log import get_logger
23-
from redisvl.utils.utils import deprecated_argument, model_to_dict
26+
from redisvl.utils.utils import deprecated_argument, model_to_dict, scan_by_pattern
2427
from redisvl.utils.vectorize.base import BaseVectorizer
2528
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
2629

@@ -98,9 +101,41 @@ def __init__(
98101
routes=routes,
99102
vectorizer=vectorizer,
100103
routing_config=routing_config,
104+
redis_url=redis_url,
105+
redis_client=redis_client,
101106
)
107+
102108
self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)
103109

110+
self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore
111+
112+
@classmethod
def from_existing(
    cls,
    name: str,
    redis_client: Optional[Redis] = None,
    redis_url: str = "redis://localhost:6379",
    **kwargs,
) -> "SemanticRouter":
    """Return a SemanticRouter instance rebuilt from a config stored in Redis.

    The router configuration is read from the JSON document stored at
    ``{name}:route_config`` and passed to ``cls.from_dict``.

    Args:
        name (str): Name of the router whose stored configuration is loaded.
        redis_client (Optional[Redis]): Existing Redis client. When provided
            it takes precedence over ``redis_url``.
        redis_url (str): URL used to open a connection when no client is
            provided.
        **kwargs: Extra arguments forwarded to the connection factory when a
            new connection is created from ``redis_url``.

    Returns:
        SemanticRouter: Router built from the stored configuration.

    Raises:
        RedisModuleVersionError: Re-raised with added context when raised
            while connecting to or validating the Redis client.
        ValueError: If neither a client nor a URL is available, or if no
            configuration is stored for ``name``.
    """
    try:
        # An explicit client must win: ``redis_url`` has a non-empty default,
        # so testing it first would silently discard the caller's client.
        if redis_client is not None:
            RedisConnectionFactory.validate_sync_redis(redis_client)
        elif redis_url:
            redis_client = RedisConnectionFactory.get_redis_connection(
                redis_url=redis_url,
                **kwargs,
            )
        else:
            raise ValueError(
                "Must provide either a redis_client or a redis_url to load a router"
            )
    except RedisModuleVersionError as e:
        raise RedisModuleVersionError(
            f"Loading from existing index failed. {str(e)}"
        ) from e

    config_key = f"{name}:route_config"
    router_dict = redis_client.json().get(config_key)  # type: ignore
    if router_dict is None:
        # Fail with a clear message instead of an opaque error in from_dict.
        raise ValueError(f"No router configuration found at key {config_key}")

    return cls.from_dict(
        router_dict, redis_url=redis_url, redis_client=redis_client
    )
138+
104139
@deprecated_argument("dtype")
105140
def _initialize_index(
106141
self,
@@ -111,9 +146,11 @@ def _initialize_index(
111146
**connection_kwargs,
112147
):
113148
"""Initialize the search index and handle Redis connection."""
149+
114150
schema = SemanticRouterIndexSchema.from_params(
115151
self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore
116152
)
153+
117154
self._index = SearchIndex(
118155
schema=schema,
119156
redis_client=redis_client,
@@ -174,10 +211,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
174211
if route.name in route_thresholds:
175212
route.distance_threshold = route_thresholds[route.name] # type: ignore
176213

177-
def _route_ref_key(self, route_name: str, reference: str) -> str:
214+
@staticmethod
215+
def _route_ref_key(index: SearchIndex, route_name: str, reference_hash: str) -> str:
178216
"""Generate the route reference key."""
179-
reference_hash = hashify(reference)
180-
return f"{self._index.prefix}:{route_name}:{reference_hash}"
217+
return f"{index.prefix}:{route_name}:{reference_hash}"
181218

182219
def _add_routes(self, routes: List[Route]):
183220
"""Add routes to the router and index.
@@ -195,14 +232,18 @@ def _add_routes(self, routes: List[Route]):
195232
)
196233
# set route references
197234
for i, reference in enumerate(route.references):
235+
reference_hash = hashify(reference)
198236
route_references.append(
199237
{
238+
"reference_id": reference_hash,
200239
"route_name": route.name,
201240
"reference": reference,
202241
"vector": reference_vectors[i],
203242
}
204243
)
205-
keys.append(self._route_ref_key(route.name, reference))
244+
keys.append(
245+
self._route_ref_key(self._index, route.name, reference_hash)
246+
)
206247

207248
# set route if does not yet exist client side
208249
if not self.get(route.name):
@@ -438,7 +479,7 @@ def remove_route(self, route_name: str) -> None:
438479
else:
439480
self._index.drop_keys(
440481
[
441-
self._route_ref_key(route.name, reference)
482+
self._route_ref_key(self._index, route.name, hashify(reference))
442483
for reference in route.references
443484
]
444485
)
@@ -596,3 +637,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
596637
with open(fp, "w") as f:
597638
yaml_data = self.to_dict()
598639
yaml.dump(yaml_data, f, sort_keys=False)
640+
641+
# reference methods
642+
def add_route_references(
    self,
    route_name: str,
    references: Union[str, List[str]],
) -> List[str]:
    """Add one or more references to an existing route.

    The references are embedded in a single batch, written to the index, and
    appended to the in-memory route before the router configuration is
    re-saved.

    Args:
        route_name (str): The name of the route to extend.
        references (Union[str, List[str]]): The reference or list of
            references to add.

    Returns:
        List[str]: The keys of the added references, as returned by the
        index load call. Empty when no references were given.

    Raises:
        ValueError: If ``route_name`` is not a route of this router.
    """
    if isinstance(references, str):
        references = [references]

    # Validate the route BEFORE embedding or writing anything, so a bad
    # route name cannot leave orphaned reference keys behind in Redis.
    route = self.get(route_name)
    if not route:
        raise ValueError(f"Route {route_name} not found in the SemanticRouter")

    if not references:
        return []

    # embed route references as a single batch
    reference_vectors = self.vectorizer.embed_many(references, as_buffer=True)

    route_references: List[Dict[str, Any]] = []
    keys: List[str] = []

    for reference, vector in zip(references, reference_vectors):
        reference_hash = hashify(reference)
        route_references.append(
            {
                "reference_id": reference_hash,
                "route_name": route_name,
                "reference": reference,
                "vector": vector,
            }
        )
        keys.append(self._route_ref_key(self._index, route_name, reference_hash))

    keys = self._index.load(route_references, keys=keys)

    # The key is derived from the reference hash, so re-adding a reference
    # overwrites the same key; keep the in-memory list free of duplicates
    # so it stays in step with what is stored.
    for reference in references:
        if reference not in route.references:
            route.references.append(reference)

    self._update_router_state()
    return keys
688+
689+
@staticmethod
def _make_filter_queries(ids: List[str]) -> List[FilterQuery]:
    """Build one FilterQuery per reference id.

    Each query filters on the ``reference_id`` tag field and returns the
    ``reference_id``, ``route_name`` and ``reference`` fields.

    Args:
        ids (List[str]): Reference ids to look up.

    Returns:
        List[FilterQuery]: One query per id, in the same order as ``ids``.
    """
    return [
        FilterQuery(
            return_fields=["reference_id", "route_name", "reference"],
            filter_expression=Tag("reference_id") == reference_id,
        )
        for reference_id in ids
    ]
704+
705+
def get_route_references(
    self,
    route_name: str = "",
    reference_ids: Optional[List[str]] = None,
    keys: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """Get stored references for an existing route.

    Lookup precedence: ``reference_ids`` if given; otherwise ``keys`` if
    given; otherwise every key matching ``{prefix}:{route_name}:*``.

    Args:
        route_name (str): Name of the route whose references are fetched
            when neither ``reference_ids`` nor ``keys`` is given.
        reference_ids (Optional[List[str]]): Reference ids to fetch.
        keys (Optional[List[str]]): Fully qualified keys; the reference id
            is taken from the last ``:``-separated segment of each key.

    Returns:
        List[Dict[str, Any]]: The first result of each query that matched
        anything; ids with no match are omitted.

    Raises:
        ValueError: If none of ``route_name``, ``reference_ids`` or ``keys``
            is provided.
    """
    if reference_ids:
        queries = self._make_filter_queries(reference_ids)
    elif route_name or keys:
        # ``keys`` alone is enough: the reference id is embedded in the key.
        if not keys:
            keys = scan_by_pattern(
                self._index.client, f"{self._index.prefix}:{route_name}:*"  # type: ignore
            )

        queries = self._make_filter_queries(
            [key.split(":")[-1] for key in convert_bytes(keys)]
        )
    else:
        raise ValueError(
            "Must provide a route name, reference ids, or keys to get references"
        )

    res = self._index.batch_query(queries)

    return [r[0] for r in res if len(r) > 0]
740+
741+
def delete_route_references(
    self,
    route_name: str = "",
    reference_ids: Optional[List[str]] = None,
    keys: Optional[List[str]] = None,
) -> int:
    """Delete references from an existing semantic router route.

    Keys to delete are resolved in this order: explicit ``keys``; otherwise
    the keys found by querying ``reference_ids``; otherwise every key
    matching ``{prefix}:{route_name}:*``.

    Args:
        route_name (str): Name of the route whose references are deleted
            when neither ``reference_ids`` nor ``keys`` is given.
        reference_ids (Optional[List[str]]): Reference ids to delete.
        keys (Optional[List[str]]): Fully qualified keys
            (prefix:route:reference_id) to delete.

    Returns:
        int: Number of objects deleted, as reported by the index.

    Raises:
        ValueError: If no keys could be resolved, or if a key refers to a
            route that is not part of this router.
    """
    if reference_ids and not keys:
        queries = self._make_filter_queries(reference_ids)
        res = self._index.batch_query(queries)
        keys = [r[0]["id"] for r in res if len(r) > 0]
    elif not keys:
        keys = scan_by_pattern(
            self._index.client, f"{self._index.prefix}:{route_name}:*"  # type: ignore
        )

    if not keys:
        raise ValueError(f"No references found for route {route_name}")

    # Resolve and validate every affected route BEFORE dropping anything, so
    # an unknown route cannot leave Redis and the stored config out of sync.
    pending = []
    for key in keys:
        # assumes keys look like prefix:route:reference_id, i.e. the route
        # name is the second-to-last ":" segment
        key_route_name = key.split(":")[-2]
        route = self.get(key_route_name)
        if not route:
            raise ValueError(
                f"Route {key_route_name} not found in the SemanticRouter"
            )
        stored = convert_bytes(self._index.client.hgetall(key))  # type: ignore
        pending.append((route, stored))

    deleted = self._index.drop_keys(keys)

    for route, stored in pending:
        reference = stored.get("reference")
        # Tolerate a key that no longer exists or a reference already absent
        # from the in-memory route; the keys have been dropped at this point.
        if reference in route.references:
            route.references.remove(reference)

    self._update_router_state()

    return deleted
788+
789+
def _update_router_state(self) -> None:
    """Write the router's current configuration to Redis as a JSON document.

    The result of ``self.to_dict()`` is stored at the root path of the key
    ``{name}:route_config``.
    """
    config_key = f"{self.name}:route_config"
    router_config = self.to_dict()
    self._index.client.json().set(config_key, ".", router_config)  # type: ignore

redisvl/index/index.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Iterable,
1515
List,
1616
Optional,
17+
Sequence,
1718
Tuple,
1819
Union,
1920
)
@@ -833,7 +834,7 @@ def search(self, *args, **kwargs) -> "Result":
833834
raise RedisSearchError(f"Error while searching: {str(e)}") from e
834835

835836
def batch_query(
836-
self, queries: List[BaseQuery], batch_size: int = 10
837+
self, queries: Sequence[BaseQuery], batch_size: int = 10
837838
) -> List[List[Dict[str, Any]]]:
838839
"""Execute a batch of queries and process results."""
839840
results = self.batch_search(

redisvl/utils/optimize/router.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -
1818
run_dict[td.id] = {}
1919
route_match = router(td.query)
2020
if route_match and route_match.name == td.query_match:
21-
run_dict[td.id][td.query_match] = 1
21+
run_dict[td.id][td.query_match] = np.int64(1)
2222
else:
23-
run_dict[td.id][NULL_RESPONSE_KEY] = 1
23+
run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)
2424

2525
return Run(run_dict)
2626

redisvl/utils/optimize/utils.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import List
22

3+
import numpy as np
34
from ranx import Qrels
45

56
from redisvl.utils.optimize.schema import LabeledData
@@ -13,10 +14,10 @@ def _format_qrels(test_data: List[LabeledData]) -> Qrels:
1314

1415
for td in test_data:
1516
if td.query_match:
16-
qrels_dict[td.id] = {td.query_match: 1}
17+
qrels_dict[td.id] = {td.query_match: np.int64(1)}
1718
else:
1819
# This is for capturing true negatives from test set
19-
qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1}
20+
qrels_dict[td.id] = {NULL_RESPONSE_KEY: np.int64(1)}
2021

2122
return Qrels(qrels_dict)
2223

redisvl/utils/utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from enum import Enum
88
from functools import wraps
99
from time import time
10-
from typing import Any, Callable, Coroutine, Dict, Optional
10+
from typing import Any, Callable, Coroutine, Dict, Optional, Sequence
1111
from warnings import warn
1212

1313
from pydantic import BaseModel
14+
from redis import Redis
1415
from ulid import ULID
1516

1617

@@ -213,3 +214,22 @@ def norm_l2_distance(value: float) -> float:
213214
Normalize the L2 distance.
214215
"""
215216
return 1 / (1 + value)
217+
218+
219+
def scan_by_pattern(
    redis_client: Redis,
    pattern: str,
) -> Sequence[str]:
    """
    Collect the keys in Redis that match a glob-style pattern.

    Args:
        redis_client (Redis): The Redis client instance to scan with.
        pattern (str): The pattern to match keys against.

    Returns:
        Sequence[str]: The matching keys after being passed through
        ``convert_bytes``.
    """
    # Imported here rather than at module top, as in the original.
    from redisvl.redis.utils import convert_bytes

    matched_keys = list(redis_client.scan_iter(match=pattern))
    return convert_bytes(matched_keys)

schemas/semantic_router.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: test-router
1+
name: test-router-01JSHK4MJ79HH51PS6WEK6M9MF
22
routes:
33
- name: greeting
44
references:

0 commit comments

Comments
 (0)