import os
import logging
from typing import Optional, Union

from dotenv import load_dotenv, find_dotenv

from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    get_response_synthesizer,
)
from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode, MetadataMode
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.ollama import OllamaEmbedding
# enable if you are using openai
# from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.ollama import Ollama
# enable if you are using openai
# from llama_index.llms.openai import OpenAI
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.base.response.schema import (
    Response,
    StreamingResponse,
    AsyncStreamingResponse,
    PydanticResponse,
)
from llama_parse import LlamaParse
import qdrant_client

_ = load_dotenv(find_dotenv())
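# the following environment variables are read below: DB_URL, DB_API_KEY,
# COLLECTION_NAME, OLLAMA_EMBED_MODEL, OLLAMA_BASE_URL, OLLAMA_LLM_MODEL,
# and llama_cloud (the LlamaParse API key)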

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class RAGWithHyDeEngine:
    RESPONSE_TYPE = Union[
        Response, StreamingResponse, AsyncStreamingResponse, PydanticResponse
    ]

    def __init__(self, data_path: str, chunk_size: int = 512, chunk_overlap: int = 200,
                 similarity_top_k: int = 3):
        # load the local data directory and chunk the data for further processing
        self.docs = self._docs_with_llama_parse(data_path=data_path)
        self.text_parser = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

        # create a local Qdrant vector store
        logger.info("initializing the vector store related objects")
        self.client = qdrant_client.QdrantClient(url=os.environ['DB_URL'], api_key=os.environ['DB_API_KEY'])
        self.vector_store = QdrantVectorStore(client=self.client, collection_name=os.environ['COLLECTION_NAME'])

        # use your preferred vector embeddings model
        logger.info("initializing the OllamaEmbedding")
        embed_model = OllamaEmbedding(model_name=os.environ['OLLAMA_EMBED_MODEL'],
                                      base_url=os.environ['OLLAMA_BASE_URL'])
        # openai embeddings, embedding_model_name="text-embedding-3-large"
        # embed_model = OpenAIEmbedding(embed_batch_size=10, model=embedding_model_name)

        # use your preferred LLM
        llm = Ollama(model=os.environ['OLLAMA_LLM_MODEL'], base_url=os.environ['OLLAMA_BASE_URL'], request_timeout=600)
        # llm = OpenAI(model="gpt-4o")

        logger.info("initializing the global settings")
        Settings.embed_model = embed_model
        Settings.llm = llm
        Settings.transformations = [self.text_parser]

        self.text_chunks = []
        self.doc_ids = []
        self.nodes = []

        self.similarity_top_k = similarity_top_k
        self.hyde_query_engine: Optional[TransformQueryEngine] = None

        # preprocess the data: chunking, node creation, metadata, embeddings
        self._pre_process()

    def _docs_with_llama_parse(self, data_path: str):
        # set up the LlamaParse parser
        parser = LlamaParse(
            result_type="markdown",  # "markdown" and "text" are available
            api_key=os.environ.get('llama_cloud')
        )

        # use SimpleDirectoryReader with LlamaParse as the PDF extractor;
        # file_extractor keys are file suffixes, including the leading dot
        file_extractor = {".pdf": parser}
        documents = SimpleDirectoryReader(input_dir=data_path, file_extractor=file_extractor).load_data(
            show_progress=True)
        return documents

    def _pre_process(self):
        logger.info("enumerating docs")
        for doc_idx, doc in enumerate(self.docs):
            curr_text_chunks = self.text_parser.split_text(doc.text)
            self.text_chunks.extend(curr_text_chunks)
            self.doc_ids.extend([doc_idx] * len(curr_text_chunks))

        logger.info("enumerating text_chunks")
        for idx, text_chunk in enumerate(self.text_chunks):
            node = TextNode(text=text_chunk)
            src_doc = self.docs[self.doc_ids[idx]]
            node.metadata = src_doc.metadata
            self.nodes.append(node)

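        # embeddings are computed up front so each node carries its own vector;
        # nodes that already have an embedding are not re-embedded at index time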
        logger.info("enumerating nodes")
        for node in self.nodes:
            node_embedding = Settings.embed_model.get_text_embedding(
                node.get_content(metadata_mode=MetadataMode.ALL)
            )
            node.embedding = node_embedding

        # create the vector store index over the nodes and build the retriever
        self._create_index_and_retriever()

    def _create_index_and_retriever(self):
        logger.info("initializing the storage context")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        logger.info("indexing the nodes in VectorStoreIndex")
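        # only embed and index the nodes when the collection does not exist yet;
        # otherwise attach to the already-populated Qdrant collection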
        if not self.client.collection_exists(collection_name=os.environ['COLLECTION_NAME']):
            index = VectorStoreIndex(
                nodes=self.nodes,
                storage_context=storage_context,
                transformations=Settings.transformations,
            )
        else:
            index = VectorStoreIndex.from_vector_store(vector_store=self.vector_store)

        logger.info(f"initializing the VectorIndexRetriever with top_k as {self.similarity_top_k}")
        vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=self.similarity_top_k)
        response_synthesizer = get_response_synthesizer()
        logger.info("creating the RetrieverQueryEngine instance")
        vector_query_engine = RetrieverQueryEngine(
            retriever=vector_retriever,
            response_synthesizer=response_synthesizer,
        )
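        # HyDE (Hypothetical Document Embeddings): the LLM first writes a
        # hypothetical answer to the query, and retrieval runs against that
        # answer's embedding; include_original=True keeps the raw query too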
        logger.info("creating the HyDEQueryTransform instance")
        hyde = HyDEQueryTransform(include_original=True)
        hyde_query_engine = TransformQueryEngine(vector_query_engine, hyde)

        self.hyde_query_engine = hyde_query_engine

    def query(self, query_string: str) -> RESPONSE_TYPE:
        try:
            response = self.hyde_query_engine.query(str_or_query_bundle=query_string)
            return response
        except Exception as e:
            logger.error(f'Error during inference: {e}')
            raise
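

# a minimal usage sketch; the "./data" directory and the question below are
# hypothetical, and the environment variables listed at the top must be set
if __name__ == "__main__":
    engine = RAGWithHyDeEngine(data_path="./data", similarity_top_k=3)
    result = engine.query("What does the indexed document say about HyDE?")
    print(result)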