# rag_documents_hybrid.py
import json
import os

import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr
from sentence_transformers import CrossEncoder

# Set up the OpenAI client to use Azure OpenAI, OpenAI.com, Ollama, or GitHub Models
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")

if API_HOST == "azure":
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]
elif API_HOST == "ollama":
    client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]
elif API_HOST == "github":
    client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
    MODEL_NAME = os.getenv("GITHUB_MODEL", "gpt-4o")
else:
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]
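
# For reference, a hypothetical .env for the Azure branch might look like the
# following (the variable names come from the branches above; the values are
# purely illustrative placeholders):
#   API_HOST=azure
#   AZURE_OPENAI_VERSION=<api-version>
#   AZURE_OPENAI_ENDPOINT=<your-endpoint-url>
#   AZURE_OPENAI_DEPLOYMENT=<your-deployment-name>
# The other branches read OLLAMA_ENDPOINT/OLLAMA_MODEL, GITHUB_TOKEN/GITHUB_MODEL,
# or OPENAI_KEY/OPENAI_MODEL instead.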

# Index the data from the JSON - each object has id, text, and embedding
with open("rag_ingested_chunks.json") as file:
    documents = json.load(file)
documents_by_id = {doc["id"]: doc for doc in documents}
index = lunr(ref="id", fields=["text"], documents=documents)
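
# Note (assumption): lunr exposes each search hit's "ref" as a string, so the "id"
# values in rag_ingested_chunks.json are expected to already be strings; otherwise
# the documents_by_id lookup in full_text_search below would need str(doc["id"]) keys.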


def full_text_search(query, limit):
    """
    Perform a full-text search on the indexed documents.
    """
    results = index.search(query)
    retrieved_documents = [documents_by_id[result["ref"]] for result in results[:limit]]
    return retrieved_documents


def vector_search(query, limit):
    """
    Perform a vector search on the indexed documents
    using a simple cosine similarity function.
    """

    def cosine_similarity(a, b):
        return sum(x * y for x, y in zip(a, b)) / ((sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5))

    query_embedding = client.embeddings.create(model="text-embedding-3-small", input=query).data[0].embedding
    similarities = []
    for doc in documents:
        doc_embedding = doc["embedding"]
        similarity = cosine_similarity(query_embedding, doc_embedding)
        similarities.append((doc, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)
    retrieved_documents = [doc for doc, _ in similarities[:limit]]
    return retrieved_documents
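

# Illustrative alternative (a sketch, not part of the original script): the same
# cosine-similarity ranking expressed with numpy, assuming numpy is available in the
# environment (it is pulled in by sentence-transformers but not imported above).
# Also note that "text-embedding-3-small" must match the model used to embed the
# chunks in rag_ingested_chunks.json, or the similarity scores are meaningless.
def vector_search_numpy(query, limit):
    import numpy as np

    # Embed the query, then score every document embedding against it in one shot
    query_embedding = client.embeddings.create(model="text-embedding-3-small", input=query).data[0].embedding
    doc_matrix = np.array([doc["embedding"] for doc in documents])
    query_vec = np.array(query_embedding)
    similarities = doc_matrix @ query_vec / (np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec))
    # Indices of the highest similarities, descending
    top_indices = np.argsort(similarities)[::-1][:limit]
    return [documents[i] for i in top_indices]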


def reciprocal_rank_fusion(text_results, vector_results, k=60):
    """
    Perform Reciprocal Rank Fusion (RRF) on the results from text and vector searches,
    based on the algorithm described here:
    https://learn.microsoft.com/azure/search/hybrid-search-ranking#how-rrf-ranking-works
    """
    scores = {}
    for i, doc in enumerate(text_results):
        if doc["id"] not in scores:
            scores[doc["id"]] = 0
        scores[doc["id"]] += 1 / (i + k)
    for i, doc in enumerate(vector_results):
        if doc["id"] not in scores:
            scores[doc["id"]] = 0
        scores[doc["id"]] += 1 / (i + k)
    scored_documents = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    retrieved_documents = [documents_by_id[doc_id] for doc_id, _ in scored_documents]
    return retrieved_documents
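
# Worked example of the RRF scoring above (with the default k=60): a document ranked
# 1st in the text results (i=0) and 3rd in the vector results (i=2) scores
# 1/60 + 1/62 ≈ 0.0328, while a document ranked 2nd in only one list scores
# 1/61 ≈ 0.0164, so documents that appear near the top of both lists win.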


def rerank(query, retrieved_documents):
    """
    Rerank the results using a cross-encoder model.
    """
    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = encoder.predict([(query, doc["text"]) for doc in retrieved_documents])
    # Sort by score only, so tied scores never fall back to comparing the document dicts
    scored_documents = [doc for _, doc in sorted(zip(scores, retrieved_documents), key=lambda x: x[0], reverse=True)]
    return scored_documents
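
# Note: constructing the CrossEncoder inside rerank() reloads the model on every call;
# for repeated queries it could instead be created once at module level and reused.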


def hybrid_search(query, limit):
    """
    Perform a hybrid search using both full-text and vector search.
    """
    text_results = full_text_search(query, limit * 2)
    vector_results = vector_search(query, limit * 2)
    fused_results = reciprocal_rank_fusion(text_results, vector_results)
    reranked_results = rerank(query, fused_results)
    return reranked_results[:limit]
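
# Note: each retriever returns limit * 2 candidates so that reciprocal rank fusion and
# the cross-encoder reranker have a larger pool to pick the final `limit` results from.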


# Get the user question
user_question = "cute gray fuzzy bee"

# Search the index for the user question
retrieved_documents = hybrid_search(user_question, limit=5)
print(f"Retrieved {len(retrieved_documents)} matching documents.")
context = "\n".join([f"{doc['id']}: {doc['text']}" for doc in retrieved_documents[0:5]])

# Now we can use the matches to generate a response
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about insects.
You must use the provided sources to answer the questions;
do not provide any information that is not in the provided sources.
Cite the sources you used to answer the question inside square brackets.
The sources are in the format: <id>: <text>.
"""

response = client.chat.completions.create(
    model=MODEL_NAME,
    temperature=0.3,
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {context}"},
    ],
)

print(f"\nResponse from {MODEL_NAME} on {API_HOST}: \n")
print(response.choices[0].message.content)