Skip to content

Commit 71650a5

Browse files
authored
first redis test (#122)
* first redis test * fixed test * pr comment * added docs
1 parent 15b1169 commit 71650a5

File tree

4 files changed

+100
-31
lines changed

4 files changed

+100
-31
lines changed

datastore/providers/redis_datastore.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,6 @@
4545
{"name": "ReJSON", "ver": 20404}
4646
]
4747
REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&*()\-+=~\/ ]")
48-
REDIS_SEARCH_SCHEMA = {
49-
"document_id": TagField("$.document_id", as_name="document_id"),
50-
"metadata": {
51-
# "source_id": TagField("$.metadata.source_id", as_name="source_id"),
52-
"source": TagField("$.metadata.source", as_name="source"),
53-
# "author": TextField("$.metadata.author", as_name="author"),
54-
# "created_at": NumericField("$.metadata.created_at", as_name="created_at"),
55-
},
56-
"embedding": VectorField(
57-
"$.embedding",
58-
REDIS_INDEX_TYPE,
59-
{
60-
"TYPE": "FLOAT64",
61-
"DIM": VECTOR_DIMENSION,
62-
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
63-
},
64-
as_name="embedding",
65-
),
66-
}
6748

6849
# Helper functions
6950
def unpack_schema(d: dict):
@@ -82,22 +63,23 @@ async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
8263
error_message = "You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. " \
8364
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
8465
logging.error(error_message)
85-
raise ValueError(error_message)
66+
raise AttributeError(error_message)
8667

8768

8869

8970
class RedisDataStore(DataStore):
90-
def __init__(self, client: redis.Redis):
71+
def __init__(self, client: redis.Redis, redisearch_schema):
9172
self.client = client
73+
self._schema = redisearch_schema
9274
# Init default metadata with sentinel values in case the document written has no metadata
9375
self._default_metadata = {
94-
field: "_null_" for field in REDIS_SEARCH_SCHEMA["metadata"]
76+
field: "_null_" for field in redisearch_schema["metadata"]
9577
}
9678

9779
### Redis Helper Methods ###
9880

9981
@classmethod
100-
async def init(cls):
82+
async def init(cls, **kwargs):
10183
"""
10284
Setup the index if it does not exist.
10385
"""
@@ -112,7 +94,27 @@ async def init(cls):
11294
raise e
11395

11496
await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)
115-
97+
98+
dim = kwargs.get("dim", VECTOR_DIMENSION)
99+
redisearch_schema = {
100+
"document_id": TagField("$.document_id", as_name="document_id"),
101+
"metadata": {
102+
"source_id": TagField("$.metadata.source_id", as_name="source_id"),
103+
"source": TagField("$.metadata.source", as_name="source"),
104+
"author": TextField("$.metadata.author", as_name="author"),
105+
"created_at": NumericField("$.metadata.created_at", as_name="created_at"),
106+
},
107+
"embedding": VectorField(
108+
"$.embedding",
109+
REDIS_INDEX_TYPE,
110+
{
111+
"TYPE": "FLOAT64",
112+
"DIM": dim,
113+
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
114+
},
115+
as_name="embedding",
116+
),
117+
}
116118
try:
117119
# Check for existence of RediSearch Index
118120
await client.ft(REDIS_INDEX_NAME).info()
@@ -123,11 +125,12 @@ async def init(cls):
123125
definition = IndexDefinition(
124126
prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON
125127
)
126-
fields = list(unpack_schema(REDIS_SEARCH_SCHEMA))
128+
fields = list(unpack_schema(redisearch_schema))
129+
logging.info(f"Creating index with fields: {fields}")
127130
await client.ft(REDIS_INDEX_NAME).create_index(
128131
fields=fields, definition=definition
129132
)
130-
return cls(client)
133+
return cls(client, redisearch_schema)
131134

132135
@staticmethod
133136
def _redis_key(document_id: str, chunk_id: str) -> str:
@@ -217,20 +220,21 @@ def _typ_to_str(typ, field, value) -> str: # type: ignore
217220

218221
# Build filter
219222
if query.filter:
223+
redisearch_schema = self._schema
220224
for field, value in query.filter.__dict__.items():
221225
if not value:
222226
continue
223-
if field in REDIS_SEARCH_SCHEMA:
224-
filter_str += _typ_to_str(REDIS_SEARCH_SCHEMA[field], field, value)
225-
elif field in REDIS_SEARCH_SCHEMA["metadata"]:
227+
if field in redisearch_schema:
228+
filter_str += _typ_to_str(redisearch_schema[field], field, value)
229+
elif field in redisearch_schema["metadata"]:
226230
if field == "source": # handle the enum
227231
value = value.value
228232
filter_str += _typ_to_str(
229-
REDIS_SEARCH_SCHEMA["metadata"][field], field, value
233+
redisearch_schema["metadata"][field], field, value
230234
)
231235
elif field in ["start_date", "end_date"]:
232236
filter_str += _typ_to_str(
233-
REDIS_SEARCH_SCHEMA["metadata"]["created_at"], field, value
237+
redisearch_schema["metadata"]["created_at"], field, value
234238
)
235239

236240
# Postprocess filter string

docs/providers/redis/setup.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,17 @@
2121
| `REDIS_DOC_PREFIX` | Optional | Redis key prefix for the index | `doc` |
2222
| `REDIS_DISTANCE_METRIC` | Optional | Vector similarity distance metric | `COSINE` |
2323
| `REDIS_INDEX_TYPE` | Optional | [Vector index algorithm type](https://redis.io/docs/stack/search/reference/vectors/#creation-attributes-per-algorithm) | `FLAT` |
24+
25+
26+
## Redis Datastore development & testing
27+
To test your changes to the Redis Datastore, first start a Redis Stack server in Docker (it runs in the foreground), then run the datastore tests from a second terminal:
28+
29+
```bash
30+
# Run the Redis stack docker image
31+
docker run -it --rm -p 6379:6379 redis/redis-stack-server:latest
32+
```
33+
34+
```bash
35+
# Run the Redis datastore tests
36+
poetry run pytest -s ./tests/datastore/providers/redis/test_redis_datastore.py
37+
```

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ pytest-asyncio = "^0.20.3"
4141
[build-system]
4242
requires = ["poetry-core"]
4343
build-backend = "poetry.core.masonry.api"
44+
45+
[tool.pytest.ini_options]
46+
pythonpath = [
47+
"."
48+
]
49+
asyncio_mode="auto"
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from datastore.providers.redis_datastore import RedisDataStore
2+
import datastore.providers.redis_datastore as static_redis
3+
from models.models import DocumentChunk, DocumentChunkMetadata, QueryWithEmbedding, Source
4+
import pytest
5+
import redis.asyncio as redis
6+
import numpy as np
7+
8+
@pytest.fixture
async def redis_datastore():
    """Provide a RedisDataStore initialised with a 5-dimensional vector index."""
    # dim=5 overrides the default vector dimension so the tests can use
    # tiny hand-built embeddings.
    datastore = await RedisDataStore.init(dim=5)
    return datastore
11+
12+
13+
def create_embedding(i, dim):
    """Return a dim-length list of floats: 0.1 everywhere, except the last slot.

    The last component is set to ``i + 1 / 10`` so that each index ``i``
    yields a distinct vector.
    """
    base = np.asarray([0.1] * dim, dtype=np.float64)
    vec = base.tolist()
    # NOTE(review): operator precedence makes this i + 0.1, not (i + 1) / 10 --
    # confirm that is the intended value.
    last = dim - 1
    vec[last] = i + 1 / 10
    return vec
17+
18+
def create_document_chunk(i, dim):
    """Build the i-th sample DocumentChunk carrying a dim-length embedding."""
    embedding = create_embedding(i, dim)
    metadata = DocumentChunkMetadata(
        source=Source.file, created_at="1970-01-01", document_id=f"doc-{i}"
    )
    return DocumentChunk(
        id=f"first-doc_{i}",
        text=f"Lorem ipsum {i}",
        embedding=embedding,
        metadata=metadata,
    )
27+
28+
def create_document_chunks(n, dim):
    """Return n sample chunks (indices 0..n-1) in a dict under the "docs" key."""
    chunks = []
    for idx in range(n):
        chunks.append(create_document_chunk(idx, dim))
    return {"docs": chunks}
31+
32+
@pytest.mark.asyncio
async def test_redis_upsert_query(redis_datastore):
    """Upsert ten sample chunks, query with chunk 0's embedding, check the top five."""
    dim = 5
    await redis_datastore._upsert(create_document_chunks(10, dim))

    probe = QueryWithEmbedding(
        query="Lorem ipsum 0",
        top_k=5,
        embedding=create_embedding(0, dim),
    )
    query_results = await redis_datastore._query(queries=[probe])

    # One query was sent, so exactly one result set is expected back.
    assert 1 == len(query_results)
    hits = query_results[0].results
    # The i-th hit is expected to be the chunk built from index i.
    for i in range(5):
        assert f"Lorem ipsum {i}" == hits[i].text
        assert f"doc-{i}" == hits[i].id

0 commit comments

Comments
 (0)