generated from pamelafox/python-project-template
-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathrag_queryrewrite.py
96 lines (77 loc) · 3.63 KB
/
rag_queryrewrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import csv
import os
import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr
# Setup the OpenAI client to use either Azure, OpenAI.com, or Ollama API
load_dotenv(override=True)
API_HOST = os.getenv("API_HOST", "github")
if API_HOST == "azure":
token_provider = azure.identity.get_bearer_token_provider(
azure.identity.DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
)
client = openai.AzureOpenAI(
api_version=os.environ["AZURE_OPENAI_VERSION"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
azure_ad_token_provider=token_provider,
)
MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]
elif API_HOST == "ollama":
client = openai.OpenAI(base_url=os.environ["OLLAMA_ENDPOINT"], api_key="nokeyneeded")
MODEL_NAME = os.environ["OLLAMA_MODEL"]
elif API_HOST == "github":
client = openai.OpenAI(base_url="https://models.inference.ai.azure.com", api_key=os.environ["GITHUB_TOKEN"])
MODEL_NAME = os.getenv("GITHUB_MODEL", "gpt-4o")
else:
client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
MODEL_NAME = os.environ["OPENAI_MODEL"]
# Index the data from the CSV
with open("hybrid.csv") as file:
reader = csv.reader(file)
rows = list(reader)
documents = [{"id": (i + 1), "body": " ".join(row)} for i, row in enumerate(rows[1:])]
index = lunr(ref="id", fields=["body"], documents=documents)
def search(query):
# Search the index for the user question
results = index.search(query)
matching_rows = [rows[int(result["ref"])] for result in results]
# Format as a markdown table, since language models understand markdown
matches_table = " | ".join(rows[0]) + "\n" + " | ".join(" --- " for _ in range(len(rows[0]))) + "\n"
matches_table += "\n".join(" | ".join(row) for row in matching_rows)
return matches_table
QUERY_REWRITE_SYSTEM_MESSAGE = """
You are a helpful assistant that rewrites user questions into good keyword queries
for an index of CSV rows with these columns: vehicle, year, msrp, acceleration, mpg, class.
Good keyword queries don't have any punctuation, and are all lowercase.
You will be given the user's new question and the conversation history.
Respond with ONLY the suggested keyword query, no other text.
"""
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about cars based off a hybrid car data set.
You must use the data set to answer the questions, you should not provide any info that is not in the provided sources.
"""
messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
while True:
question = input("\nYour question about electric cars: ")
# Rewrite the query to fix typos and incorporate past context
response = client.chat.completions.create(
model=MODEL_NAME,
temperature=0.05,
messages=[
{"role": "system", "content": QUERY_REWRITE_SYSTEM_MESSAGE},
{"role": "user", "content": f"New user question:{question}\n\nConversation history:{messages}"},
],
)
search_query = response.choices[0].message.content
print(f"Rewritten query: {search_query}")
# Search the CSV for the question
matches = search(search_query)
print("Found matches:\n", matches)
# Use the matches to generate a response
messages.append({"role": "user", "content": f"{question}\nSources: {matches}"})
response = client.chat.completions.create(model=MODEL_NAME, temperature=0.3, messages=messages)
bot_response = response.choices[0].message.content
messages.append({"role": "assistant", "content": bot_response})
print(f"\nResponse from {API_HOST} {MODEL_NAME}: \n")
print(bot_response)