Skip to content

Enable semantic router reference updates #322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 214 additions & 92 deletions docs/user_guide/08_semantic_router.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/user_guide/router.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ routes:
metadata:
category: tech
priority: 1
distance_threshold: 1.0
distance_threshold: 0.71
- name: sports
references:
- who won the game last night?
Expand All @@ -19,7 +19,7 @@ routes:
metadata:
category: sports
priority: 2
distance_threshold: 0.5
distance_threshold: 0.72
- name: entertainment
references:
- what are the top movies right now?
Expand Down
1 change: 1 addition & 0 deletions redisvl/extensions/router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str):
return cls(
index={"name": name, "prefix": name}, # type: ignore
fields=[ # type: ignore
{"name": "reference_id", "type": "tag"},
{"name": "route_name", "type": "tag"},
{"name": "reference", "type": "text"},
{
Expand Down
209 changes: 201 additions & 8 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Type, Union

import redis.commands.search.reducers as reducers
import yaml
Expand All @@ -8,6 +8,7 @@
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
from redis.exceptions import ResponseError

from redisvl.exceptions import RedisModuleVersionError
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redisvl.extensions.router.schema import (
DistanceAggregationMethod,
Expand All @@ -17,10 +18,12 @@
SemanticRouterIndexSchema,
)
from redisvl.index import SearchIndex
from redisvl.query import VectorRangeQuery
from redisvl.query import FilterQuery, VectorRangeQuery
from redisvl.query.filter import Tag
from redisvl.redis.connection import RedisConnectionFactory
from redisvl.redis.utils import convert_bytes, hashify, make_dict
from redisvl.utils.log import get_logger
from redisvl.utils.utils import deprecated_argument, model_to_dict
from redisvl.utils.utils import deprecated_argument, model_to_dict, scan_by_pattern
from redisvl.utils.vectorize.base import BaseVectorizer
from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer

Expand Down Expand Up @@ -98,9 +101,41 @@ def __init__(
routes=routes,
vectorizer=vectorizer,
routing_config=routing_config,
redis_url=redis_url,
redis_client=redis_client,
)

self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)

self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore

@classmethod
def from_existing(
    cls,
    name: str,
    redis_client: Optional[Redis] = None,
    redis_url: str = "redis://localhost:6379",
    **kwargs,
) -> "SemanticRouter":
    """Return a SemanticRouter rebuilt from the configuration stored in Redis.

    The configuration is read from the JSON document at
    ``{name}:route_config`` and handed to ``cls.from_dict``.

    Args:
        name (str): Name of the router whose stored configuration to load.
        redis_client (Optional[Redis]): Existing Redis client. When given it
            takes precedence over ``redis_url``.
        redis_url (str): URL used to open a connection when no client is
            supplied.
        **kwargs: Forwarded to ``RedisConnectionFactory.get_redis_connection``
            when a connection is created from ``redis_url``.

    Returns:
        SemanticRouter: The router created by ``cls.from_dict``.

    Raises:
        RedisModuleVersionError: If creating or validating the client raises it.
        ValueError: If neither a client nor a URL is available, or if no
            configuration is stored for ``name``.
    """
    try:
        # ``redis_url`` has a truthy default, so the caller's client must be
        # checked first or it would always be replaced by a new connection.
        if redis_client is not None:
            RedisConnectionFactory.validate_sync_redis(redis_client)
        elif redis_url:
            redis_client = RedisConnectionFactory.get_redis_connection(
                redis_url=redis_url,
                **kwargs,
            )
    except RedisModuleVersionError as e:
        raise RedisModuleVersionError(
            f"Loading from existing index failed. {str(e)}"
        ) from e

    if redis_client is None:
        raise ValueError(
            "Must provide either a redis_client or a redis_url to load an existing router"
        )

    router_dict = redis_client.json().get(f"{name}:route_config")  # type: ignore
    if router_dict is None:
        raise ValueError(f"No stored configuration found for router {name}")

    return cls.from_dict(
        router_dict, redis_url=redis_url, redis_client=redis_client
    )

@deprecated_argument("dtype")
def _initialize_index(
self,
Expand All @@ -111,9 +146,11 @@ def _initialize_index(
**connection_kwargs,
):
"""Initialize the search index and handle Redis connection."""

schema = SemanticRouterIndexSchema.from_params(
self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore
)

self._index = SearchIndex(
schema=schema,
redis_client=redis_client,
Expand Down Expand Up @@ -174,10 +211,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]):
if route.name in route_thresholds:
route.distance_threshold = route_thresholds[route.name] # type: ignore

def _route_ref_key(self, route_name: str, reference: str) -> str:
@staticmethod
def _route_ref_key(index: SearchIndex, route_name: str, reference_hash: str) -> str:
"""Generate the route reference key."""
reference_hash = hashify(reference)
return f"{self._index.prefix}:{route_name}:{reference_hash}"
return f"{index.prefix}:{route_name}:{reference_hash}"

def _add_routes(self, routes: List[Route]):
"""Add routes to the router and index.
Expand All @@ -195,14 +232,18 @@ def _add_routes(self, routes: List[Route]):
)
# set route references
for i, reference in enumerate(route.references):
reference_hash = hashify(reference)
route_references.append(
{
"reference_id": reference_hash,
"route_name": route.name,
"reference": reference,
"vector": reference_vectors[i],
}
)
keys.append(self._route_ref_key(route.name, reference))
keys.append(
self._route_ref_key(self._index, route.name, reference_hash)
)

# set route if does not yet exist client side
if not self.get(route.name):
Expand Down Expand Up @@ -438,7 +479,7 @@ def remove_route(self, route_name: str) -> None:
else:
self._index.drop_keys(
[
self._route_ref_key(route.name, reference)
self._route_ref_key(self._index, route.name, hashify(reference))
for reference in route.references
]
)
Expand Down Expand Up @@ -596,3 +637,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None:
with open(fp, "w") as f:
yaml_data = self.to_dict()
yaml.dump(yaml_data, f, sort_keys=False)

# reference methods
def add_route_references(
    self,
    route_name: str,
    references: Union[str, List[str]],
) -> List[str]:
    """Add one or more references to an existing route.

    Args:
        route_name (str): The name of the route to extend.
        references (Union[str, List[str]]): The reference or list of
            references to add.

    Returns:
        List[str]: The keys returned by the index for the loaded references.

    Raises:
        ValueError: If ``route_name`` is not a route of this router. The
            check runs before anything is written to the index.
    """
    if isinstance(references, str):
        references = [references]

    # Validate first: writing to the index and only then discovering that
    # the route is unknown would leave orphaned reference keys in Redis.
    route = self.get(route_name)
    if not route:
        raise ValueError(f"Route {route_name} not found in the SemanticRouter")

    if not references:
        return []

    # embed route references as a single batch
    reference_vectors = self.vectorizer.embed_many(references, as_buffer=True)

    route_references: List[Dict[str, Any]] = []
    keys: List[str] = []

    for reference, vector in zip(references, reference_vectors):
        reference_hash = hashify(reference)
        route_references.append(
            {
                "reference_id": reference_hash,
                "route_name": route_name,
                "reference": reference,
                "vector": vector,
            }
        )
        keys.append(self._route_ref_key(self._index, route_name, reference_hash))

    keys = self._index.load(route_references, keys=keys)

    # Keys are derived from the reference hash, so re-adding a reference
    # overwrites the same key; keep the client-side list free of duplicates
    # so it stays in step with what is stored.
    for reference in references:
        if reference not in route.references:
            route.references.append(reference)

    self._update_router_state()
    return keys

@staticmethod
def _make_filter_queries(ids: List[str]) -> List[FilterQuery]:
    """Build one FilterQuery per id, each matching the ``reference_id`` tag.

    Args:
        ids (List[str]): Reference ids to look up.

    Returns:
        List[FilterQuery]: One query per id, in input order, returning the
            ``reference_id``, ``route_name`` and ``reference`` fields.
    """
    return [
        FilterQuery(
            return_fields=["reference_id", "route_name", "reference"],
            filter_expression=Tag("reference_id") == ref_id,
        )
        for ref_id in ids
    ]

def get_route_references(
    self,
    route_name: str = "",
    reference_ids: List[str] = [],
    keys: List[str] = [],
) -> List[Dict[str, Any]]:
    """Get the stored references for a route, a set of ids, or a set of keys.

    Lookup precedence is ``reference_ids``, then ``keys``, then a scan of
    every key under ``route_name``.

    Args:
        route_name (str): Name of the route whose references to fetch. Used
            to scan for keys when neither ``reference_ids`` nor ``keys`` is
            given.
        reference_ids (List[str]): Reference ids to fetch directly.
        keys (List[str]): Reference keys; the id is taken from the last
            ``:``-separated segment of each key.

    Returns:
        List[Dict[str, Any]]: The first match of every query that returned
            at least one result.

    Raises:
        ValueError: If none of ``route_name``, ``reference_ids`` or ``keys``
            is provided.
    """
    # NOTE: the list defaults are never mutated here (``keys`` is only
    # rebound), so sharing them across calls is harmless.
    if reference_ids:
        queries = self._make_filter_queries(reference_ids)
    elif keys or route_name:
        # Keys alone are sufficient: the reference id is embedded in the key.
        if not keys:
            keys = scan_by_pattern(
                self._index.client, f"{self._index.prefix}:{route_name}:*"  # type: ignore
            )

        queries = self._make_filter_queries(
            [key.split(":")[-1] for key in convert_bytes(keys)]
        )
    else:
        raise ValueError(
            "Must provide a route name, reference ids, or keys to get references"
        )

    res = self._index.batch_query(queries)

    return [r[0] for r in res if len(r) > 0]

def delete_route_references(
    self,
    route_name: str = "",
    reference_ids: List[str] = [],
    keys: List[str] = [],
) -> int:
    """Delete references from an existing semantic router route.

    Keys to delete are resolved from ``keys`` if given, otherwise from
    ``reference_ids``, otherwise by scanning every key under ``route_name``.

    Args:
        route_name (str): Name of the route whose references to delete when
            neither ``reference_ids`` nor ``keys`` is given.
        reference_ids (List[str]): Ids of the references to delete.
        keys (List[str]): Fully qualified keys
            (prefix:route:reference_id) to delete.

    Returns:
        int: Number of objects deleted, as reported by ``drop_keys``.

    Raises:
        ValueError: If no keys could be resolved, or if a key names a route
            that is not part of this router. Both checks run before
            anything is deleted.
    """
    # NOTE: the list defaults are never mutated here (``keys`` is only
    # rebound), so sharing them across calls is harmless.
    if reference_ids and not keys:
        queries = self._make_filter_queries(reference_ids)
        res = self._index.batch_query(queries)
        keys = [r[0]["id"] for r in res if len(r) > 0]
    elif not keys:
        keys = scan_by_pattern(
            self._index.client, f"{self._index.prefix}:{route_name}:*"  # type: ignore
        )

    if not keys:
        raise ValueError(f"No references found for route {route_name}")

    # Resolve owning routes and stored reference text BEFORE deleting, so
    # an unknown route aborts the call while Redis is still untouched.
    pending = []
    for key in keys:
        owner_name = key.split(":")[-2]
        route = self.get(owner_name)
        if not route:
            raise ValueError(f"Route {owner_name} not found in the SemanticRouter")
        stored = convert_bytes(self._index.client.hgetall(key))  # type: ignore
        pending.append((route, stored.get("reference")))

    deleted = self._index.drop_keys(keys)

    for route, reference in pending:
        # A key with no stored hash, or a reference already absent from the
        # client-side list, must not abort the update after keys were dropped.
        if reference in route.references:
            route.references.remove(reference)

    self._update_router_state()

    return deleted

def _update_router_state(self) -> None:
"""Update the router configuration in Redis."""
self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore
3 changes: 2 additions & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
Expand Down Expand Up @@ -833,7 +834,7 @@ def search(self, *args, **kwargs) -> "Result":
raise RedisSearchError(f"Error while searching: {str(e)}") from e

def batch_query(
self, queries: List[BaseQuery], batch_size: int = 10
self, queries: Sequence[BaseQuery], batch_size: int = 10
) -> List[List[Dict[str, Any]]]:
"""Execute a batch of queries and process results."""
results = self.batch_search(
Expand Down
4 changes: 2 additions & 2 deletions redisvl/utils/optimize/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -
run_dict[td.id] = {}
route_match = router(td.query)
if route_match and route_match.name == td.query_match:
run_dict[td.id][td.query_match] = 1
run_dict[td.id][td.query_match] = np.int64(1)
else:
run_dict[td.id][NULL_RESPONSE_KEY] = 1
run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)

return Run(run_dict)

Expand Down
5 changes: 3 additions & 2 deletions redisvl/utils/optimize/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List

import numpy as np
from ranx import Qrels

from redisvl.utils.optimize.schema import LabeledData
Expand All @@ -13,10 +14,10 @@ def _format_qrels(test_data: List[LabeledData]) -> Qrels:

for td in test_data:
if td.query_match:
qrels_dict[td.id] = {td.query_match: 1}
qrels_dict[td.id] = {td.query_match: np.int64(1)}
else:
# This is for capturing true negatives from test set
qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1}
qrels_dict[td.id] = {NULL_RESPONSE_KEY: np.int64(1)}

return Qrels(qrels_dict)

Expand Down
22 changes: 21 additions & 1 deletion redisvl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from enum import Enum
from functools import wraps
from time import time
from typing import Any, Callable, Coroutine, Dict, Optional
from typing import Any, Callable, Coroutine, Dict, Optional, Sequence
from warnings import warn

from pydantic import BaseModel
from redis import Redis
from ulid import ULID


Expand Down Expand Up @@ -213,3 +214,22 @@ def norm_l2_distance(value: float) -> float:
Normalize the L2 distance.
"""
return 1 / (1 + value)


def scan_by_pattern(
    redis_client: Redis,
    pattern: str,
) -> Sequence[str]:
    """
    Collect every key in the Redis database that matches a pattern.

    Args:
        redis_client (Redis): The Redis client instance to scan with.
        pattern (str): The pattern forwarded as ``match`` to ``scan_iter``.

    Returns:
        Sequence[str]: The matching keys after ``convert_bytes`` is applied.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # between the utils modules - confirm before hoisting to module level.
    from redisvl.redis.utils import convert_bytes

    matched_keys = list(redis_client.scan_iter(match=pattern))
    return convert_bytes(matched_keys)
2 changes: 1 addition & 1 deletion schemas/semantic_router.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: test-router
name: test-router-01JSHK4MJ79HH51PS6WEK6M9MF
routes:
- name: greeting
references:
Expand Down
Loading