Commit 95ffe75
Enable embedding caching on all vectorizers (#320)
Adds support for attaching an optional `EmbeddingsCache` to the `BaseVectorizer` class.

- Refactored the subclass vectorizers to implement private embed methods, letting the base class handle the cache wrapper logic.
- Fixed some circular imports.
- Fixed async client handling in the cache subclasses (caught during testing).
- Handled some typing checks and pydantic details related to private and custom attributes.

TODO in a separate PR:

- Add embeddings caching to our testing suite (CI/CD speed up??)
- Add embeddings caching to our SemanticRouter
1 parent 2e8b167 commit 95ffe75

20 files changed: +1879 −1399 lines
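The core of the change is a template-method pattern on `BaseVectorizer`: the public `embed` call handles cache lookups and writes, and each subclass supplies only its model-specific private embed method. The sketch below illustrates the wrapper logic described in the commit message; it is not the actual redisvl source, and the `cache.get`/`cache.set` signatures are taken from the notebook examples in this diff.

```python
# Illustrative sketch of the cache-wrapper logic described in the commit
# message -- not the actual redisvl implementation.
from typing import List, Optional


class BaseVectorizer:
    def __init__(self, model: str, cache: Optional["EmbeddingsCache"] = None):
        self.model = model
        self.cache = cache  # optional EmbeddingsCache attached to any vectorizer

    def embed(self, text: str, skip_cache: bool = False) -> List[float]:
        # Return a cached embedding if one exists for this text/model pair
        if self.cache and not skip_cache:
            if cached := self.cache.get(text=text, model_name=self.model):
                return cached["embedding"]
        # Otherwise delegate to the subclass's private embed method
        embedding = self._embed(text)
        # Store the fresh embedding for future lookups
        if self.cache and not skip_cache:
            self.cache.set(text=text, model_name=self.model, embedding=embedding)
        return embedding

    def _embed(self, text: str) -> List[float]:
        raise NotImplementedError("implemented by each vectorizer subclass")
```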

docs/user_guide/10_embeddings_cache.ipynb (+47 −71)

@@ -51,13 +51,23 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/tyler.hutcherson/Library/Caches/pypoetry/virtualenvs/redisvl-VnTEShF2-py3.13/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. Falling back to non-compiled mode.\n"
+ ]
+ }
+ ],
  "source": [
  "# Initialize the vectorizer\n",
  "vectorizer = HFTextVectorizer(\n",
- " model=\"sentence-transformers/all-mpnet-base-v2\",\n",
+ " model=\"redis/langcache-embed-v1\",\n",
  " cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n",
  ")"
  ]
@@ -103,21 +113,21 @@
  },
  {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Stored with key: embedcache:059d...\n"
+ "Stored with key: embedcache:909f...\n"
  ]
  }
  ],
  "source": [
  "# Text to embed\n",
  "text = \"What is machine learning?\"\n",
- "model_name = \"sentence-transformers/all-mpnet-base-v2\"\n",
+ "model_name = \"redis/langcache-embed-v1\"\n",
  "\n",
  "# Generate the embedding\n",
  "embedding = vectorizer.embed(text)\n",
@@ -147,15 +157,15 @@
  },
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 6,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
  "Found in cache: What is machine learning?\n",
- "Model: sentence-transformers/all-mpnet-base-v2\n",
+ "Model: redis/langcache-embed-v1\n",
  "Metadata: {'category': 'ai', 'source': 'user_query'}\n",
  "Embedding shape: (768,)\n"
  ]
@@ -184,7 +194,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [
  {
@@ -218,7 +228,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
  "metadata": {},
  "outputs": [
  {
@@ -251,14 +261,14 @@
  },
  {
  "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Stored with key: embedcache:059d...\n",
+ "Stored with key: embedcache:909f...\n",
  "Exists by key: True\n",
  "Retrieved by key: What is machine learning?\n"
  ]
@@ -297,7 +307,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [
  {
@@ -382,7 +392,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
  "metadata": {},
  "outputs": [
  {
@@ -430,7 +440,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
  "metadata": {},
  "outputs": [
  {
@@ -484,7 +494,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
  "metadata": {},
  "outputs": [
  {
@@ -533,18 +543,13 @@
  },
  {
  "cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "Computing embedding for: What is artificial intelligence?\n",
- "Computing embedding for: How does machine learning work?\n",
- "Found in cache: What is artificial intelligence?\n",
- "Computing embedding for: What are neural networks?\n",
- "Found in cache: How does machine learning work?\n",
  "\n",
  "Statistics:\n",
  "Total queries: 5\n",
@@ -562,25 +567,11 @@
  " ttl=3600 # 1 hour TTL\n",
  ")\n",
  "\n",
- "# Function to get embedding with caching\n",
- "def get_cached_embedding(text, model_name):\n",
- " # Check if it's in the cache first\n",
- " if cached_result := example_cache.get(text=text, model_name=model_name):\n",
- " print(f\"Found in cache: {text}\")\n",
- " return cached_result[\"embedding\"]\n",
- " \n",
- " # Not in cache, compute the embedding\n",
- " print(f\"Computing embedding for: {text}\")\n",
- " embedding = vectorizer.embed(text)\n",
- " \n",
- " # Store in cache\n",
- " example_cache.set(\n",
- " text=text,\n",
- " model_name=model_name,\n",
- " embedding=embedding,\n",
- " )\n",
- " \n",
- " return embedding\n",
+ "vectorizer = HFTextVectorizer(\n",
+ " model=model_name,\n",
+ " cache=example_cache,\n",
+ " cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n",
+ ")\n",
  "\n",
  "# Simulate processing a stream of queries\n",
  "queries = [\n",
@@ -604,7 +595,7 @@
  " cache_hits += 1\n",
  " \n",
  " # Get embedding (will compute or use cache)\n",
- " embedding = get_cached_embedding(query, model_name)\n",
+ " embedding = vectorizer.embed(query)\n",
  "\n",
  "# Report statistics\n",
  "cache_misses = total_queries - cache_hits\n",
@@ -632,72 +623,57 @@
  },
  {
  "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
  "Benchmarking without caching:\n",
- "Time taken without caching: 0.0940 seconds\n",
- "Average time per embedding: 0.0094 seconds\n",
+ "Time taken without caching: 0.4735 seconds\n",
+ "Average time per embedding: 0.0474 seconds\n",
  "\n",
  "Benchmarking with caching:\n",
- "Time taken with caching: 0.0237 seconds\n",
- "Average time per embedding: 0.0024 seconds\n",
+ "Time taken with caching: 0.0663 seconds\n",
+ "Average time per embedding: 0.0066 seconds\n",
  "\n",
  "Performance comparison:\n",
- "Speedup with caching: 3.96x faster\n",
- "Time saved: 0.0703 seconds (74.8%)\n",
- "Latency reduction: 0.0070 seconds per query\n"
+ "Speedup with caching: 7.14x faster\n",
+ "Time saved: 0.4073 seconds (86.0%)\n",
+ "Latency reduction: 0.0407 seconds per query\n"
  ]
  }
  ],
  "source": [
  "# Text to use for benchmarking\n",
  "benchmark_text = \"This is a benchmark text to measure the performance of embedding caching.\"\n",
- "benchmark_model = \"sentence-transformers/all-mpnet-base-v2\"\n",
  "\n",
  "# Create a fresh cache for benchmarking\n",
  "benchmark_cache = EmbeddingsCache(\n",
  " name=\"benchmark_cache\",\n",
  " redis_url=\"redis://localhost:6379\",\n",
  " ttl=3600 # 1 hour TTL\n",
  ")\n",
- "\n",
- "# Function to get embeddings without caching\n",
- "def get_embedding_without_cache(text, model_name):\n",
- " return vectorizer.embed(text)\n",
- "\n",
- "# Function to get embeddings with caching\n",
- "def get_embedding_with_cache(text, model_name):\n",
- " if cached_result := benchmark_cache.get(text=text, model_name=model_name):\n",
- " return cached_result[\"embedding\"]\n",
- " \n",
- " embedding = vectorizer.embed(text)\n",
- " benchmark_cache.set(\n",
- " text=text,\n",
- " model_name=model_name,\n",
- " embedding=embedding\n",
- " )\n",
- " return embedding\n",
+ "vectorizer.cache = benchmark_cache\n",
  "\n",
  "# Number of iterations for the benchmark\n",
  "n_iterations = 10\n",
  "\n",
  "# Benchmark without caching\n",
  "print(\"Benchmarking without caching:\")\n",
  "start_time = time.time()\n",
- "get_embedding_without_cache(benchmark_text, benchmark_model)\n",
+ "for _ in range(n_iterations):\n",
+ " embedding = vectorizer.embed(text, skip_cache=True)\n",
  "no_cache_time = time.time() - start_time\n",
  "print(f\"Time taken without caching: {no_cache_time:.4f} seconds\")\n",
  "print(f\"Average time per embedding: {no_cache_time/n_iterations:.4f} seconds\")\n",
  "\n",
  "# Benchmark with caching\n",
  "print(\"\\nBenchmarking with caching:\")\n",
  "start_time = time.time()\n",
- "get_embedding_with_cache(benchmark_text, benchmark_model)\n",
+ "for _ in range(n_iterations):\n",
+ " embedding = vectorizer.embed(text)\n",
  "cache_time = time.time() - start_time\n",
  "print(f\"Time taken with caching: {cache_time:.4f} seconds\")\n",
  "print(f\"Average time per embedding: {cache_time/n_iterations:.4f} seconds\")\n",
@@ -785,7 +761,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.13.2"
  }
  },
  "nbformat": 4,

poetry.lock (+4 −4)

(Generated file; diff not rendered.)

redisvl/extensions/cache/__init__.py (+1 −3)

@@ -6,7 +6,5 @@
  """
 
  from redisvl.extensions.cache.base import BaseCache
- from redisvl.extensions.cache.embeddings import EmbeddingsCache
- from redisvl.extensions.cache.llm import SemanticCache
 
- __all__ = ["BaseCache", "EmbeddingsCache", "SemanticCache"]
+ __all__ = ["BaseCache"]
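Since the package `__init__` no longer re-exports the cache implementations (part of the circular-import fix), callers import them from their defining submodules, exactly as the removed re-export lines show:

```python
# Import cache classes directly from their defining modules,
# not from the redisvl.extensions.cache package root
from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.extensions.cache.llm import SemanticCache
```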

redisvl/extensions/cache/base.py (+11 −4)

@@ -9,6 +9,8 @@
 from redis import Redis
 from redis.asyncio import Redis as AsyncRedis
 
+from redisvl.redis.connection import RedisConnectionFactory
+
 
 class BaseCache:
     """Base abstract cache interface for all RedisVL caches.
@@ -121,10 +123,15 @@ async def _get_async_redis_client(self) -> AsyncRedis:
             AsyncRedis: An async Redis client instance.
         """
         if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
-            # Create new async Redis client
-            url = self.redis_kwargs["redis_url"]
-            kwargs = self.redis_kwargs["connection_kwargs"]
-            self._async_redis_client = AsyncRedis.from_url(url, **kwargs)  # type: ignore
+            client = self.redis_kwargs.get("redis_client")
+            if isinstance(client, Redis):
+                self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
+                    client
+                )
+            else:
+                url = self.redis_kwargs["redis_url"]
+                kwargs = self.redis_kwargs["connection_kwargs"]
+                self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs)  # type: ignore
         return self._async_redis_client
 
     def expire(self, key: str, ttl: Optional[int] = None) -> None:
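The async fix above matters when a cache is constructed around an existing synchronous client: the old code unconditionally rebuilt an async client from `redis_url`, which a client-constructed cache may not have. A rough sketch of the new behavior; it assumes the cache constructor forwards a `redis_client` kwarg into `redis_kwargs`, and it pokes at the private `_get_async_redis_client` purely for illustration:

```python
import asyncio

from redis import Redis
from redisvl.extensions.cache.embeddings import EmbeddingsCache


async def main():
    # Build the cache from an existing sync client, without a redis_url
    sync_client = Redis.from_url("redis://localhost:6379")
    cache = EmbeddingsCache(name="embedcache", redis_client=sync_client)

    # The lazily-created async client is now derived from the sync client
    # via RedisConnectionFactory.sync_to_async_redis, instead of always
    # being rebuilt from a URL.
    async_client = await cache._get_async_redis_client()
    print(type(async_client))  # <class 'redis.asyncio.client.Redis'>


asyncio.run(main())
```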
