# pip install sentence-transformers
import json
import os

import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr
from sentence_transformers import CrossEncoder

# Set up the OpenAI client to use Azure OpenAI, Ollama, GitHub Models, or the OpenAI.com API
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST")

if API_HOST == "azure":
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]

elif API_HOST == "ollama":
    client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]

elif API_HOST == "github":
    client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
    MODEL_NAME = os.environ["GITHUB_MODEL"]

else:
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]

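# Note: the client selected above is also used for embeddings in vector_search() below,
# which hard-codes the "text-embedding-3-small" model name. Whichever API_HOST is chosen
# must be able to serve that embeddings model (or the model name should be adjusted).
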
# Index the data from the JSON - each object has id, text, and embedding
with open("rag_ingested_chunks.json") as file:
    documents = json.load(file)
    documents_by_id = {doc["id"]: doc for doc in documents}
index = lunr(ref="id", fields=["text"], documents=documents)

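# Each entry in rag_ingested_chunks.json is expected to look like
# {"id": "...", "text": "...", "embedding": [...]}, with the embedding presumably
# computed at ingestion time by the same embeddings model used for queries below.

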
def full_text_search(query, limit):
    """
    Perform a full-text search on the indexed documents.
    """
    results = index.search(query)
    retrieved_documents = [documents_by_id[result["ref"]] for result in results[:limit]]
    return retrieved_documents

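# For example, full_text_search("maya calendar", limit=3) would return up to three
# documents, ordered by lunr's keyword relevance score for the query terms.

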
def vector_search(query, limit):
    """
    Perform a vector search on the indexed documents
    using a simple cosine similarity function.
    """

    def cosine_similarity(a, b):
        return sum(x * y for x, y in zip(a, b)) / ((sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5))

    query_embedding = client.embeddings.create(model="text-embedding-3-small", input=query).data[0].embedding
    similarities = []
    for doc in documents:
        doc_embedding = doc["embedding"]
        similarity = cosine_similarity(query_embedding, doc_embedding)
        similarities.append((doc, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)

    retrieved_documents = [doc for doc, _ in similarities[:limit]]
    return retrieved_documents

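# A minimal vectorized sketch of the same cosine-similarity step (an optional
# alternative, not called anywhere below; assumes numpy is installed):
def cosine_similarities_numpy(query_embedding, doc_embeddings):
    import numpy as np  # local import so the script runs without numpy unless this helper is used

    q = np.array(query_embedding)
    d = np.array(doc_embeddings)
    # Dot product of each document row with the query, divided by the product of the norms
    return (d @ q) / (np.linalg.norm(d, axis=1) * np.linalg.norm(q))

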
def reciprocal_rank_fusion(text_results, vector_results, k=60):
    """
    Perform Reciprocal Rank Fusion on the results from text and vector searches:
    each document scores 1 / (k + rank) in every list it appears in, and the
    per-list scores are summed. k=60 is the conventional smoothing constant.
    """
    scores = {}
    for result_list in (text_results, vector_results):
        for rank, doc in enumerate(result_list, start=1):
            scores[doc["id"]] = scores.get(doc["id"], 0) + 1 / (k + rank)

    sorted_ids = sorted(scores, key=scores.get, reverse=True)
    return [documents_by_id[doc_id] for doc_id in sorted_ids]

def rerank(query, retrieved_documents):
    """
    Rerank the results using a cross-encoder model.
    """
    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = encoder.predict([(query, doc["text"]) for doc in retrieved_documents])
    # Sort by score only, so tied scores never fall through to comparing the document dicts
    ranked = sorted(zip(scores, retrieved_documents), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]

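# Note: CrossEncoder(...) loads the model weights on every rerank() call; for repeated
# queries it would be cheaper to construct the encoder once at module level and reuse it.

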
def hybrid_search(query, limit):
    """
    Perform a hybrid search using both full-text and vector search.
    """
    text_results = full_text_search(query, limit * 2)
    vector_results = vector_search(query, limit * 2)
    combined_results = reciprocal_rank_fusion(text_results, vector_results)
    combined_results = rerank(query, combined_results)
    return combined_results[:limit]

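# For example, hybrid_search("maya pyramids", limit=5) retrieves 10 candidates from each
# search method, fuses the two ranked lists, reranks the fused list with the cross-encoder,
# and returns the top 5 documents.
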
# Get the user question
user_question = "cute gray fuzzsters"

# Search the index for the user question
retrieved_documents = hybrid_search(user_question, limit=5)
print(f"Retrieved {len(retrieved_documents)} matching documents.")
context = "\n".join([f"{doc['id']}: {doc['text']}" for doc in retrieved_documents[0:5]])

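# Each source is passed as "<id>: <text>", matching the citation format the system
# prompt describes below.
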
# Now we can use the matches to generate a response
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about Maya civilization.
You must use the provided sources to answer the questions;
do not provide any information that is not in those sources.
Cite the sources you used to answer the question inside square brackets.
The sources are in the format: <id>: <text>.
"""

response = client.chat.completions.create(
    model=MODEL_NAME,
    temperature=0.3,
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {context}"},
    ],
)

print(f"\nResponse from {MODEL_NAME} on {API_HOST}:\n")
print(response.choices[0].message.content)