Integrate with langchain for sqlite-vss implementation
philippe2803 committed Feb 26, 2024
1 parent 90bd0a1 commit 7907d33
Showing 5 changed files with 69 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,5 +1,5 @@
contentmap.db

/scratch

# Byte-compiled / optimized / DLL files
__pycache__/
3 changes: 3 additions & 0 deletions Dockerfile
@@ -13,6 +13,9 @@ RUN pip install poetry
 RUN poetry config virtualenvs.create false
 RUN poetry install
 
+RUN python3 -c 'from sentence_transformers import SentenceTransformer; embedder = SentenceTransformer("all-MiniLM-L6-v2")'
+
+
 ADD . /app
 
 CMD ["pytest", "./tests"]
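
The added RUN step pre-downloads the all-MiniLM-L6-v2 weights while the image is built, so the pytest run in CMD can load the model from the local cache instead of fetching it when the container starts. A minimal sketch of the effect at runtime, assuming sentence-transformers is installed by the poetry install step above:

# Sketch only: the model should resolve from the cache created during the
# image build rather than being downloaded again.
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-MiniLM-L6-v2")
vector = embedder.encode("sqlite-vss integration for contentmap")
print(len(vector))  # all-MiniLM-L6-v2 yields 384-dimensional embeddings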
36 changes: 30 additions & 6 deletions contentmap/vss.py
@@ -26,19 +26,21 @@ def __init__(self,
         if not connection:
             self.connection = SQLiteVSS.create_connection(db_file)
 
-    def load(self):
-        # content table must be there
-        assert self.table_exists(table_name="content")
-
         embedding_function = SentenceTransformerEmbeddings(
             model_name="all-MiniLM-L6-v2"
         )
-        vss = SQLiteVSS(
+        self.vss = SQLiteVSS(
             table="content_chunks",
             embedding=embedding_function,
             connection=self.connection
         )
-        return vss
+
+    def load(self):
+        # content table must be there
+        assert self.table_exists(table_name="content")
+        texts, metadatas = self.prepare_texts_and_metadatas()
+        self.vss.add_texts(texts=texts, metadatas=metadatas)
+        return self.vss
 
     def table_exists(self, table_name: str) -> bool:
         res = self.connection.execute(f"""
@@ -50,3 +52,25 @@ def table_exists(self, table_name: str) -> bool:
         if len(rows) == 1:
             return True
         return False
+
+    def prepare_texts_and_metadatas(self):
+        cursor = self.connection.cursor()
+        result = cursor.execute("SELECT content, url FROM content")
+        rows = result.fetchall()
+
+        # based on Anyscale analysis (https://t.ly/yjgxQ), it looks like the
+        # sweet spot is 700 chunk size and 50 chunk overlap
+        text_splitter = CharacterTextSplitter(chunk_size=700, chunk_overlap=50)
+
+        texts = []
+        metadatas = []
+        for row in rows:
+            chunks = text_splitter.split_text(row["content"])
+            chunk_metadatas = [{"url": row["url"]} for _ in chunks]
+            texts += chunks
+            metadatas += chunk_metadatas
+
+        return texts, metadatas
+
+    def similarity_search(self, *args, **kwargs):
+        return self.vss.similarity_search(*args, **kwargs)
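
Taken together, the vss.py changes wire the embedding function and the content_chunks table up in __init__, leave load() to chunk the content table and embed it, and delegate similarity_search() to the underlying SQLiteVSS store. A minimal usage sketch based on the methods in this diff; the db_file path is a placeholder and must point to a contentmap database that already contains a content table:

from contentmap.vss import ContentMapVSS

cm_vss = ContentMapVSS(db_file="contentmap.db")  # placeholder path
cm_vss.load()  # chunks rows from `content` and embeds them into `content_chunks`
docs = cm_vss.similarity_search(query="who is Mistral ai company?", k=2)
for doc in docs:
    # results are LangChain Documents, so metadata carries the source url
    print(doc.metadata["url"])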
9 changes: 7 additions & 2 deletions tests/conftest.py
@@ -1,5 +1,7 @@
 import pytest
 import os
+import os.path as op
+import logging
 
 
 @pytest.fixture(autouse=True)
@@ -8,6 +10,9 @@ def remove_created_database_after_test():
     # Setup logic
     yield # this is where the testing happens
     # Teardown logic
-    if os.path.exists("contentmap.db"):
-        os.remove("contentmap.db")
+
+    contentmap_db_path = op.join(op.dirname(__file__), "contentmap.db")
+    if op.exists(contentmap_db_path):
+        logging.info('Destroying mock sqlite content instance')
+        os.remove(contentmap_db_path)
 
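The teardown now resolves the fixture database from the test module's own directory rather than the process working directory, so tests/contentmap.db is cleaned up no matter where pytest is invoked from. Roughly, as an illustration of the path resolution only:

import os.path as op

# __file__ is tests/conftest.py, so this resolves to <repo>/tests/contentmap.db
# even when pytest runs from a different working directory.
contentmap_db_path = op.join(op.dirname(__file__), "contentmap.db")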
31 changes: 28 additions & 3 deletions tests/test_vss.py
@@ -17,9 +17,34 @@ def test_assertion_content_not_exists(self):
 
 class TestVssTablesCreation:
 
-    db = build_fixture_db()
-
     def test_vss_instance(self):
-        cm_vss = ContentMapVSS(db_file=self.db)
+        db = build_fixture_db()
+        cm_vss = ContentMapVSS(db_file=db)
         cm_vss.load()
         assert cm_vss.table_exists("content_chunks")
+
+    def test_prepare_texts_and_metadatas(self):
+        db = build_fixture_db()
+        cm_vss = ContentMapVSS(db_file=db)
+        texts, metadatas = cm_vss.prepare_texts_and_metadatas()
+        assert len(texts) == len(metadatas) >= 1
+
+    def test_chunk_table(self):
+        db = build_fixture_db()
+        cm_vss = ContentMapVSS(db_file=db)
+        cm_vss.load()
+        assert cm_vss.table_exists("content_chunks")
+        cursor = cm_vss.connection.cursor()
+        res = cursor.execute("SELECT * FROM content_chunks")
+        rows = res.fetchall()
+        assert len(rows) >= 15
+
+    def test_similarity_search(self):
+        db = build_fixture_db()
+        cm_vss = ContentMapVSS(db_file=db)
+        cm_vss.load()
+        data = cm_vss.similarity_search(query="who is Mistral ai company?", k=2)
+        assert len(data) == 2
+        metadatas = [doc.metadata for doc in data]
+        for meta in metadatas:
+            assert meta.get("url") == "https://philippeoger.com/pages/ai-scene-in-europe-last-week/"
