Commit 61d0947

Add hybrid example
1 parent c1f2c1e commit 61d0947

File tree

2 files changed: 143 additions & 0 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ These scripts for RAG:
 * [`rag_queryrewrite.py`](./rag_queryrewrite.py): Adds a query rewriting step to the RAG process, where the user's question is rewritten to improve the retrieval results.
 * [`rag_documents_ingestion.py`](./rag_ingestion.py): Ingests PDFs by using pymupdf to convert to markdown, then using Langchain to split into chunks, then using OpenAI to embed the chunks, and finally storing in a local JSON file.
 * [`rag_documents_flow.py`](./rag_pdfs.py): A RAG flow that retrieves matching results from the local JSON file created by `rag_documents_ingestion.py`.
+* [`rag_hybrid.py`](./rag_hybrid.py): A RAG flow that implements hybrid retrieval with both vector and keyword search, merges the results with Reciprocal Rank Fusion (RRF), and semantically re-ranks them with a cross-encoder model.


 ## Setting up the environment

rag_hybrid.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# pip install lunr sentence-transformers
import json
import os

import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr
from sentence_transformers import CrossEncoder

# Set up the OpenAI client to use the Azure OpenAI, Ollama, GitHub Models, or OpenAI.com API
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST")

if API_HOST == "azure":
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]

elif API_HOST == "ollama":
    client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]

elif API_HOST == "github":
    client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
    MODEL_NAME = os.environ["GITHUB_MODEL"]

else:
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]

# Index the data from the JSON - each object has id, text, and embedding
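# Illustrative shape of one ingested chunk (values below are made up, not taken from the repo):
#   {"id": "doc-3-chunk-12", "text": "...chunk text...", "embedding": [0.013, -0.074, ...]}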
with open("rag_ingested_chunks.json") as file:
    documents = json.load(file)
documents_by_id = {doc["id"]: doc for doc in documents}
index = lunr(ref="id", fields=["text"], documents=documents)


def full_text_search(query, limit):
    """
    Perform a full-text search on the indexed documents.
    """
    results = index.search(query)
    retrieved_documents = [documents_by_id[result["ref"]] for result in results[:limit]]
    return retrieved_documents


def vector_search(query, limit):
    """
    Perform a vector search on the indexed documents
    using a simple cosine similarity function.
    """

    def cosine_similarity(a, b):
        return sum(x * y for x, y in zip(a, b)) / ((sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5))

    query_embedding = client.embeddings.create(model="text-embedding-3-small", input=query).data[0].embedding
    similarities = []
    for doc in documents:
        doc_embedding = doc["embedding"]
        similarity = cosine_similarity(query_embedding, doc_embedding)
        similarities.append((doc, similarity))
    similarities.sort(key=lambda x: x[1], reverse=True)

    retrieved_documents = [doc for doc, _ in similarities[:limit]]
    return retrieved_documents


def reciprocal_rank_fusion(text_results, vector_results, alpha=0.5):
    """
    Perform Reciprocal Rank Fusion on the results from text and vector searches.
    """
    text_ids = {doc["id"] for doc in text_results}
    vector_ids = {doc["id"] for doc in vector_results}

    combined_results = []
    for doc in text_results:
        if doc["id"] in vector_ids:
            combined_results.append((doc, alpha))
        else:
            combined_results.append((doc, 1 - alpha))
    for doc in vector_results:
        if doc["id"] not in text_ids:
            combined_results.append((doc, alpha))
    combined_results.sort(key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in combined_results]


def rerank(query, retrieved_documents):
    """
    Rerank the results using a cross-encoder model.
    """
    encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = encoder.predict([(query, doc["text"]) for doc in retrieved_documents])
    return [v for _, v in sorted(zip(scores, retrieved_documents), reverse=True)]


def hybrid_search(query, limit):
    """
    Perform a hybrid search using both full-text and vector search.
    """
    text_results = full_text_search(query, limit * 2)
    vector_results = vector_search(query, limit * 2)
    combined_results = reciprocal_rank_fusion(text_results, vector_results)
    combined_results = rerank(query, combined_results)
    return combined_results[:limit]


# Get the user question
user_question = "cute gray fuzzsters"

# Search the index for the user question
retrieved_documents = hybrid_search(user_question, limit=5)
print(f"Retrieved {len(retrieved_documents)} matching documents.")
context = "\n".join([f"{doc['id']}: {doc['text']}" for doc in retrieved_documents[0:5]])

# Now we can use the matches to generate a response
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about Maya civilization.
You must use the data set to answer the questions,
you should not provide any info that is not in the provided sources.
Cite the sources you used to answer the question inside square brackets.
The sources are in the format: <id>: <text>.
"""

response = client.chat.completions.create(
    model=MODEL_NAME,
    temperature=0.3,
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {context}"},
    ],
)

print(f"\nResponse from {MODEL_NAME} on {API_HOST}: \n")
print(response.choices[0].message.content)
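
For comparison with the weighted merge in reciprocal_rank_fusion above, classic Reciprocal Rank Fusion scores each document by summing 1/(k + rank) across the result lists, so documents ranked highly by both searches accumulate the largest totals. A minimal sketch of that rank-based scoring (the rrf_merge name and k=60 constant are illustrative, not part of this commit):

def rrf_merge(text_results, vector_results, k=60):
    """Merge two ranked result lists using rank-based RRF scoring."""
    scores = {}
    docs_by_id = {}
    for results in (text_results, vector_results):
        for rank, doc in enumerate(results, start=1):
            docs_by_id[doc["id"]] = doc
            # Each list contributes 1 / (k + rank); high placement in both lists boosts the total.
            scores[doc["id"]] = scores.get(doc["id"], 0) + 1 / (k + rank)
    ranked_ids = sorted(scores, key=scores.get, reverse=True)
    return [docs_by_id[doc_id] for doc_id in ranked_ids]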
