Skip to content

Commit 9fff523

Browse files
committed
Add way more RAG demos
1 parent 2584b63 commit 9fff523

11 files changed

+196245
-24
lines changed

data/California_carpenter_bee.pdf

316 KB
Binary file not shown.

data/Centris_pallida.pdf

970 KB
Binary file not shown.

data/Western_honey_bee.pdf

1010 KB
Binary file not shown.

rag.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
import csv
import os

import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr

# Set up the OpenAI client to use either Azure, OpenAI.com, GitHub Models, or Ollama,
# selected by the API_HOST environment variable (loaded from .env).
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST")

if API_HOST == "azure":
    # Keyless auth: exchange the developer's Azure credential for a Cognitive Services token.
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]

elif API_HOST == "ollama":
    # Ollama exposes an OpenAI-compatible endpoint; the API key is ignored but required by the client.
    client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]

elif API_HOST == "github":
    client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
    MODEL_NAME = os.environ["GITHUB_MODEL"]

else:
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]

# Index the data from the CSV with lunr (a pure-Python lexical search engine).
# newline="" is required by the csv module; encoding is pinned so results don't
# depend on the platform default.
with open("hybrid.csv", encoding="utf-8", newline="") as file:
    reader = csv.reader(file)
    rows = list(reader)
# rows[0] is the header row; documents are 1-indexed so that document id i maps to rows[i].
documents = [{"id": (i + 1), "body": " ".join(row)} for i, row in enumerate(rows[1:])]
index = lunr(ref="id", fields=["body"], documents=documents)

# Get the user question
user_question = "how fast is the prius v?"

# Search the index for the user question.
# lunr returns refs as strings, so convert back to int to index into rows.
results = index.search(user_question)
matching_rows = [rows[int(result["ref"])] for result in results]
if not matching_rows:
    # With no retrieved sources the model would be answering unsourced; make that visible.
    print("No matches found in the index for this question; the model will receive an empty sources table.")

# Format as a markdown table, since language models understand markdown
matches_table = " | ".join(rows[0]) + "\n" + " | ".join(" --- " for _ in rows[0]) + "\n"
matches_table += "\n".join(" | ".join(row) for row in matching_rows)

print("Found matches:")
print(matches_table)

# Now we can use the matches to generate a response
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about cars based off a hybrid car data set.
You must use the data set to answer the questions, you should not provide any info that is not in the provided sources.
"""

response = client.chat.completions.create(
    model=MODEL_NAME,
    temperature=0.3,
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {matches_table}"},
    ],
)

print(f"\nResponse from {API_HOST}: \n")
print(response.choices[0].message.content)

rag_hybrid.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
import json
import os

import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr

# Set up the OpenAI client to use either Azure, OpenAI.com, GitHub Models, or Ollama,
# selected by the API_HOST environment variable (loaded from .env).
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST")

if API_HOST == "azure":
    # Keyless auth: exchange the developer's Azure credential for a Cognitive Services token.
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]

elif API_HOST == "ollama":
    # Ollama exposes an OpenAI-compatible endpoint; the API key is ignored but required by the client.
    client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
    MODEL_NAME = os.environ["OLLAMA_MODEL"]

elif API_HOST == "github":
    client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
    MODEL_NAME = os.environ["GITHUB_MODEL"]

else:
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]

# Index the data from the JSON - each object has id, text, and embedding
with open("rag_ingested_chunks.json", encoding="utf-8") as file:
    documents = json.load(file)
documents_by_id = {doc["id"]: doc for doc in documents}
index = lunr(ref="id", fields=["text"], documents=documents)

# Get the user question
user_question = "where do digger bees live?"

# Search the index for the user question.
# NOTE(review): lunr returns result["ref"] as a string — assumes the ids in the
# ingested JSON are strings too, otherwise this lookup would KeyError; verify against the ingestion script.
results = index.search(user_question)
retrieved_documents = [documents_by_id[result["ref"]] for result in results]
# Cap the context at 5 chunks to keep the prompt small.
sent_count = min(5, len(retrieved_documents))
print(f"Retrieved {len(retrieved_documents)} matching documents, only sending the first {sent_count}.")
context = "\n".join(f"{doc['id']}: {doc['text']}" for doc in retrieved_documents[:5])

# Now we can use the matches to generate a response.
# The prompt domain matches the ingested sources (the bee articles added with this demo);
# the previous wording referenced an unrelated data set.
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about bees based off provided article chunks.
You must use the data set to answer the questions,
you should not provide any info that is not in the provided sources.
Cite the sources you used to answer the question inside square brackets.
The sources are in the format: <id>: <text>.
"""

response = client.chat.completions.create(
    model=MODEL_NAME,
    temperature=0.3,
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {context}"},
    ],
)

print(f"\nResponse from {MODEL_NAME} on {API_HOST}: \n")
print(response.choices[0].message.content)

0 commit comments

Comments
 (0)