"""
Chroma datastore support for the ChatGPT retrieval plugin.

Consult the Chroma docs and GitHub repo for more information:
- https://docs.trychroma.com/usage-guide?lang=py
- https://github.com/chroma-core/chroma
- https://www.trychroma.com/
"""

import os
from datetime import datetime
from typing import Dict, List, Optional

import chromadb

from datastore.datastore import DataStore
from models.models import (
    Document,
    DocumentChunk,
    DocumentChunkMetadata,
    DocumentChunkWithScore,
    DocumentMetadataFilter,
    QueryResult,
    QueryWithEmbedding,
    Source,
)
from services.chunks import get_document_chunks

# Parse CHROMA_IN_MEMORY into a real boolean up front; using the raw string
# would make a value like "False" truthy.
CHROMA_IN_MEMORY = os.environ.get("CHROMA_IN_MEMORY", "True").lower() in ("true", "1")
CHROMA_PERSISTENCE_DIR = os.environ.get("CHROMA_PERSISTENCE_DIR", "openai")
CHROMA_HOST = os.environ.get("CHROMA_HOST", "http://127.0.0.1")
CHROMA_PORT = os.environ.get("CHROMA_PORT", "8000")
CHROMA_COLLECTION = os.environ.get("CHROMA_COLLECTION", "openaiembeddings")
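
# A sketch of overriding these settings to target a remote Chroma server
# instead of the in-memory default (values are illustrative, not part of the
# plugin's documented configuration):
#
#   export CHROMA_IN_MEMORY=False
#   export CHROMA_HOST=http://chroma.internal.example
#   export CHROMA_PORT=8000
#   export CHROMA_COLLECTION=openaiembeddings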


class ChromaDataStore(DataStore):
    def __init__(
        self,
        in_memory: bool = CHROMA_IN_MEMORY,
        persistence_dir: Optional[str] = CHROMA_PERSISTENCE_DIR,
        collection_name: str = CHROMA_COLLECTION,
        host: str = CHROMA_HOST,
        port: str = CHROMA_PORT,
        client: Optional[chromadb.Client] = None,
    ):
        # Prefer an explicitly supplied client; otherwise create either a
        # local (optionally persistent) client or a REST client pointed at a
        # running Chroma server.
        if client:
            self._client = client
        else:
            if in_memory:
                settings = (
                    chromadb.config.Settings(
                        chroma_db_impl="duckdb+parquet",
                        persist_directory=persistence_dir,
                    )
                    if persistence_dir
                    else chromadb.config.Settings()
                )

                self._client = chromadb.Client(settings=settings)
            else:
                self._client = chromadb.Client(
                    settings=chromadb.config.Settings(
                        chroma_api_impl="rest",
                        chroma_server_host=host,
                        chroma_server_http_port=port,
                    )
                )
        # The plugin computes its own embeddings, so no Chroma embedding
        # function is attached to the collection.
        self._collection = self._client.get_or_create_collection(
            name=collection_name,
            embedding_function=None,
        )

    async def upsert(
        self, documents: List[Document], chunk_token_size: Optional[int] = None
    ) -> List[str]:
        """
        Takes in a list of documents and inserts them into the database. If an id already exists, the document is updated.
        Returns a list of document ids.
        """

        chunks = get_document_chunks(documents, chunk_token_size)

        # Chroma has a true upsert, so we don't need to delete first
        return await self._upsert(chunks)

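    # Hypothetical call site (a sketch; `documents` is a List[Document] built
    # by the caller, not shown here):
    #
    #   ids = await datastore.upsert(documents)
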
    async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
        """
        Takes in a dict mapping document ids to lists of document chunks and inserts them into the database.
        Returns a list of document ids.
        """

        # Flatten the per-document chunk lists into the parallel lists that
        # Chroma's upsert expects.
        flat_chunks = [chunk for chunk_list in chunks.values() for chunk in chunk_list]
        self._collection.upsert(
            ids=[chunk.id for chunk in flat_chunks],
            embeddings=[chunk.embedding for chunk in flat_chunks],
            documents=[chunk.text for chunk in flat_chunks],
            metadatas=[
                self._process_metadata_for_storage(chunk.metadata)
                for chunk in flat_chunks
            ],
        )
        return list(chunks.keys())

    def _where_from_query_filter(self, query_filter: DocumentMetadataFilter) -> Dict:
        # Copy exact-match fields straight through; source and the date
        # bounds need the special handling below.
        output = {
            k: v
            for (k, v) in query_filter.dict().items()
            if v is not None and k != "start_date" and k != "end_date" and k != "source"
        }
        if query_filter.source:
            output["source"] = query_filter.source.value
        # Date bounds are compared against the stored epoch timestamp.
        if query_filter.start_date and query_filter.end_date:
            output["$and"] = [
                {
                    "created_at": {
                        "$gte": int(
                            datetime.fromisoformat(query_filter.start_date).timestamp()
                        )
                    }
                },
                {
                    "created_at": {
                        "$lte": int(
                            datetime.fromisoformat(query_filter.end_date).timestamp()
                        )
                    }
                },
            ]
        elif query_filter.start_date:
            output["created_at"] = {
                "$gte": int(datetime.fromisoformat(query_filter.start_date).timestamp())
            }
        elif query_filter.end_date:
            output["created_at"] = {
                "$lte": int(datetime.fromisoformat(query_filter.end_date).timestamp())
            }

        return output

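    # For illustration, a filter with source=Source.email,
    # start_date="2023-01-01T00:00:00" and end_date="2023-02-01T00:00:00"
    # yields a where clause shaped like this (timestamps shown assume a UTC
    # process clock):
    #
    #   {
    #       "source": "email",
    #       "$and": [
    #           {"created_at": {"$gte": 1672531200}},
    #           {"created_at": {"$lte": 1675209600}},
    #       ],
    #   }
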
    def _process_metadata_for_storage(self, metadata: DocumentChunkMetadata) -> Dict:
        stored_metadata = {}
        if metadata.source:
            stored_metadata["source"] = metadata.source.value
        if metadata.source_id:
            stored_metadata["source_id"] = metadata.source_id
        if metadata.url:
            stored_metadata["url"] = metadata.url
        if metadata.created_at:
            stored_metadata["created_at"] = int(
                datetime.fromisoformat(metadata.created_at).timestamp()
            )
        if metadata.author:
            stored_metadata["author"] = metadata.author
        if metadata.document_id:
            stored_metadata["document_id"] = metadata.document_id

        return stored_metadata

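    # Round-trip sketch: metadata with source=Source.file and
    # created_at="2023-01-01T00:00:00" is stored roughly as
    #   {"source": "file", "created_at": 1672531200}
    # (the exact timestamp depends on the process timezone, since
    # datetime.fromisoformat on a naive string uses local time), and
    # _process_metadata_from_storage below reverses the conversion.
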
    def _process_metadata_from_storage(self, metadata: Dict) -> DocumentChunkMetadata:
        return DocumentChunkMetadata(
            source=Source(metadata["source"]) if "source" in metadata else None,
            source_id=metadata.get("source_id", None),
            url=metadata.get("url", None),
            created_at=datetime.fromtimestamp(metadata["created_at"]).isoformat()
            if "created_at" in metadata
            else None,
            author=metadata.get("author", None),
            document_id=metadata.get("document_id", None),
        )

    async def _query(self, queries: List[QueryWithEmbedding]) -> List[QueryResult]:
        """
        Takes in a list of queries with embeddings and filters and returns a list of query results with matching document chunks and scores.
        """
        results = [
            self._collection.query(
                query_embeddings=[query.embedding],
                include=["documents", "distances", "metadatas"],  # embeddings
                n_results=min(query.top_k, self._collection.count()),
                where=(
                    self._where_from_query_filter(query.filter) if query.filter else {}
                ),
            )
            for query in queries
        ]

        output = []
        for query, result in zip(queries, results):
            inner_results = []
            # Chroma returns each field as a list per query embedding; unpack
            # the single inner list for this query.
            (ids,) = result["ids"]
            # (embeddings,) = result["embeddings"]
            (documents,) = result["documents"]
            (metadatas,) = result["metadatas"]
            (distances,) = result["distances"]
            for id_, text, metadata, distance in zip(
                ids,
                documents,
                metadatas,
                distances,  # embeddings (https://github.com/openai/chatgpt-retrieval-plugin/pull/59#discussion_r1154985153)
            ):
                inner_results.append(
                    DocumentChunkWithScore(
                        id=id_,
                        text=text,
                        metadata=self._process_metadata_from_storage(metadata),
                        # embedding=embedding,
                        # Chroma reports a distance, so lower scores mean
                        # closer matches.
                        score=distance,
                    )
                )
            output.append(QueryResult(query=query.query, results=inner_results))

        return output

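    # Hypothetical direct query (a sketch; the embedding is a placeholder
    # vector rather than a real OpenAI embedding):
    #
    #   results = await datastore._query(
    #       [QueryWithEmbedding(query="hello", embedding=[0.1, 0.2, 0.3], top_k=3)]
    #   )
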
    async def delete(
        self,
        ids: Optional[List[str]] = None,
        filter: Optional[DocumentMetadataFilter] = None,
        delete_all: Optional[bool] = None,
    ) -> bool:
        """
        Removes vectors by ids, filter, or everything in the datastore.
        Multiple parameters can be used at once.
        Returns whether the operation was successful.
        """
        if delete_all:
            # Deleting with no constraints clears the whole collection.
            self._collection.delete()
            return True

        if not ids and not filter:
            # Guard against an unbound where_clause below: with no ids, no
            # filter, and delete_all unset there is nothing to delete.
            return False

        if ids and len(ids) > 0:
            if len(ids) > 1:
                where_clause = {"$or": [{"document_id": id_} for id_ in ids]}
            else:
                (id_,) = ids
                where_clause = {"document_id": id_}

            if filter:
                where_clause = {
                    "$and": [self._where_from_query_filter(filter), where_clause]
                }
        elif filter:
            where_clause = self._where_from_query_filter(filter)

        self._collection.delete(where=where_clause)
        return True
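

# ---------------------------------------------------------------------------
# Illustrative smoke test (a sketch, not part of the plugin). It exercises
# the private _upsert/_query paths directly with hand-rolled placeholder
# embeddings so that nothing here calls the OpenAI embedding service.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        store = ChromaDataStore(in_memory=True, persistence_dir=None)
        chunk = DocumentChunk(
            id="doc1_0",
            text="hello chroma",
            embedding=[0.1, 0.2, 0.3],  # placeholder, not a real embedding
            metadata=DocumentChunkMetadata(document_id="doc1"),
        )
        await store._upsert({"doc1": [chunk]})
        results = await store._query(
            [QueryWithEmbedding(query="hello", embedding=[0.1, 0.2, 0.3], top_k=1)]
        )
        print(results[0].results[0].text)  # -> "hello chroma"
        await store.delete(ids=["doc1"])

    asyncio.run(_demo())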