forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from philippe2803/integration-sqlite-vec
Preparing sqlite-vec as a vector store partner
- Loading branch information
Showing
5 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
227 changes: 227 additions & 0 deletions
227
libs/community/langchain_community/vectorstores/sqlitevec.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
from __future__ import annotations | ||
|
||
import json | ||
import logging | ||
import warnings | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Iterable, | ||
List, | ||
Optional, | ||
Tuple, | ||
Type, | ||
) | ||
|
||
from langchain_core.documents import Document | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.vectorstores import VectorStore | ||
|
||
if TYPE_CHECKING: | ||
import sqlite3 | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SQLiteVec(VectorStore):
    """SQLite with the ``sqlite-vec`` extension as a vector database.

    To use, you should have the ``sqlite-vec`` python package installed.

    Rows are read back by column name (``row["text"]`` etc.), so a
    caller-supplied connection needs ``row_factory = sqlite3.Row`` (this is
    what :meth:`create_connection` configures).

    Example:
        .. code-block:: python

            from langchain_community.vectorstores import SQLiteVec
            from langchain_community.embeddings.openai import OpenAIEmbeddings
            ...
    """

    def __init__(
        self,
        table: str,
        connection: Optional[sqlite3.Connection],
        embedding: Embeddings,
        db_file: str = "vec.db",
    ):
        """Initialize with sqlite client with sqlite-vec extension.

        Args:
            table: Name of the table holding texts, metadata and embeddings.
                It is interpolated into SQL as an identifier and must
                therefore come from trusted code, never from user input.
            connection: Existing SQLite connection; when falsy a new one is
                opened on ``db_file`` via :meth:`create_connection`.
            embedding: Embedding model used for documents and queries.
            db_file: Database path, only used when ``connection`` is falsy.

        Raises:
            ImportError: If the ``sqlite-vec`` package is not installed.
        """
        try:
            import sqlite_vec  # noqa  # pylint: disable=unused-import
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "Could not import sqlite-vec python package. "
                "Please install it with `pip install sqlite-vec`."
            ) from exc

        if not connection:
            connection = self.create_connection(db_file)

        if not isinstance(embedding, Embeddings):
            warnings.warn("embeddings input must be Embeddings object.")

        self._connection = connection
        self._table = table
        self._embedding = embedding

        self.create_table_if_not_exists()

    def create_table_if_not_exists(self) -> None:
        """Create the data table, the ``vec0`` virtual table and the sync trigger.

        The virtual table ``vec_<table>`` is sized from
        :meth:`get_dimensionality`, and the trigger copies every row inserted
        into ``<table>`` over to ``vec_<table>``.
        """
        self._connection.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self._table}
            (
                rowid INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT,
                metadata BLOB,
                text_embedding BLOB
            )
            ;
            """
        )
        self._connection.execute(
            f"""
            CREATE VIRTUAL TABLE IF NOT EXISTS vec_{self._table} USING vec0(
                text_embedding FLOAT[{self.get_dimensionality()}]
            );
            """
        )
        # NOTE(review): the trigger name is not table-specific, so a second
        # table in the same database would find `embed_text` already present
        # and get no trigger of its own. Renaming it needs a migration for
        # existing databases (two triggers would double-insert), so the name
        # is kept as-is here.
        self._connection.execute(
            f"""
            CREATE TRIGGER IF NOT EXISTS embed_text
            AFTER INSERT ON {self._table}
            BEGIN
                INSERT INTO vec_{self._table}(rowid, text_embedding)
                VALUES (new.rowid, new.text_embedding)
                ;
            END;
            """
        )
        self._connection.commit()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add more texts to the vectorstore index.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            The rowids of the rows inserted by this call.
            NOTE(review): these are the integer rowids SQLite returns, despite
            the ``List[str]`` annotation inherited from the base interface.
        """
        max_id = self._connection.execute(
            f"SELECT max(rowid) as rowid FROM {self._table}"
        ).fetchone()["rowid"]
        if max_id is None:  # no text added yet
            max_id = 0

        # Materialize once: `texts` may be a one-shot iterator, and it is
        # needed for embedding, for default metadatas and for the insert.
        texts = list(texts)
        embeds = self._embedding.embed_documents(texts)
        if not metadatas:
            metadatas = [{} for _ in texts]
        data_input = [
            (text, json.dumps(metadata), json.dumps(embed))
            for text, metadata, embed in zip(texts, metadatas, embeds)
        ]
        self._connection.executemany(
            f"INSERT INTO {self._table}(text, metadata, text_embedding) "
            f"VALUES (?,?,?)",
            data_input,
        )
        self._connection.commit()
        # pulling every id we just inserted
        results = self._connection.execute(
            f"SELECT rowid FROM {self._table} WHERE rowid > ?",
            (max_id,),
        )
        return [row["rowid"] for row in results]

    def similarity_search_with_score_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return the ``k`` nearest documents to ``embedding`` with distances.

        Args:
            embedding: Query vector.
            k: Number of neighbours requested from the ``vec0`` table.

        Returns:
            ``(Document, distance)`` tuples ordered by ascending distance, as
            reported by the virtual table's ``distance`` column.
        """
        # Only the table name (an identifier) is interpolated; the query
        # vector and k are bound as parameters rather than spliced into SQL.
        sql_query = f"""
            SELECT
                text,
                metadata,
                distance
            FROM {self._table} e
            INNER JOIN vec_{self._table} v on v.rowid = e.rowid
            WHERE v.text_embedding MATCH ?
                AND k = ?
            ORDER BY distance
        """
        cursor = self._connection.cursor()
        cursor.execute(sql_query, (json.dumps(embedding), k))
        results = cursor.fetchall()

        documents = []
        for row in results:
            raw_metadata = row["metadata"]
            # A NULL/empty metadata cell is treated as "no metadata".
            metadata = (json.loads(raw_metadata) if raw_metadata else None) or {}
            doc = Document(page_content=row["text"], metadata=metadata)
            documents.append((doc, row["distance"]))

        return documents

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, each paired with its distance."""
        embedding = self._embedding.embed_query(query)
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return documents

    def similarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to the given embedding vector."""
        documents = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k
        )
        return [doc for doc, _ in documents]

    @classmethod
    def from_texts(
        cls: Type[SQLiteVec],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        table: str = "langchain",
        db_file: str = "vec.db",
        **kwargs: Any,
    ) -> SQLiteVec:
        """Return VectorStore initialized from texts and embeddings.

        Opens a new connection on ``db_file``, creates the store on ``table``
        and inserts ``texts`` (with ``metadatas``) into it.
        """
        connection = cls.create_connection(db_file)
        vec = cls(
            table=table, connection=connection, db_file=db_file, embedding=embedding
        )
        vec.add_texts(texts=texts, metadatas=metadatas)
        return vec

    @staticmethod
    def create_connection(db_file: str) -> sqlite3.Connection:
        """Open ``db_file`` with ``sqlite3.Row`` rows and sqlite-vec loaded."""
        import sqlite3

        import sqlite_vec

        connection = sqlite3.connect(db_file)
        connection.row_factory = sqlite3.Row
        connection.enable_load_extension(True)
        try:
            sqlite_vec.load(connection)
        finally:
            # Never leave extension loading enabled, even if the load fails.
            connection.enable_load_extension(False)
        return connection

    def get_dimensionality(self) -> int:
        """
        Function that does a dummy embedding to figure out how many dimensions
        this embedding function returns. Needed for the virtual table DDL.
        """
        dummy_text = "This is a dummy text"
        dummy_embedding = self._embedding.embed_query(dummy_text)
        return len(dummy_embedding)
58 changes: 58 additions & 0 deletions
58
libs/community/tests/integration_tests/vectorstores/test_sqlitevec.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import List, Optional | ||
|
||
import pytest | ||
from langchain_core.documents import Document | ||
|
||
from langchain_community.vectorstores import SQLiteVec | ||
from tests.integration_tests.vectorstores.fake_embeddings import ( | ||
FakeEmbeddings, | ||
fake_texts, | ||
) | ||
|
||
|
||
def _sqlite_vec_from_texts(
    metadatas: Optional[List[dict]] = None, drop: bool = True
) -> SQLiteVec:
    """Build an in-memory ``SQLiteVec`` store over ``fake_texts``.

    The store uses ``FakeEmbeddings`` and the table name ``test``. ``drop``
    is accepted but not used by this helper.
    """
    store_options = {
        "metadatas": metadatas,
        "table": "test",
        "db_file": ":memory:",
    }
    return SQLiteVec.from_texts(fake_texts, FakeEmbeddings(), **store_options)
|
||
|
||
@pytest.mark.requires("sqlite-vec")
def test_sqlitevec() -> None:
    """Build a store and check the single nearest hit for "foo"."""
    store = _sqlite_vec_from_texts()
    expected = [Document(page_content="foo", metadata={})]
    assert store.similarity_search("foo", k=1) == expected
|
||
|
||
@pytest.mark.requires("sqlite-vec")
def test_sqlitevec_with_score() -> None:
    """Search with scores: documents come back in order, distances increase."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _sqlite_vec_from_texts(metadatas=metadatas)

    hits = store.similarity_search_with_score("foo", k=3)
    found_docs = [doc for doc, _ in hits]
    scores = [score for _, score in hits]

    expected_docs = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert found_docs == expected_docs
    assert scores[0] < scores[1] < scores[2]
|
||
|
||
@pytest.mark.requires("sqlite-vec")
def test_sqlitevec_add_extra() -> None:
    """Adding the same texts a second time doubles the searchable rows."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _sqlite_vec_from_texts(metadatas=metadatas)

    store.add_texts(texts, metadatas)

    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,6 +76,7 @@ | |
"Relyt", | ||
"Rockset", | ||
"SKLearnVectorStore", | ||
"SQLiteVec", | ||
"SQLiteVSS", | ||
"ScaNN", | ||
"SemaDB", | ||
|