generated from pamelafox/python-project-template
-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathrag_documents_flow.py
70 lines (58 loc) · 2.52 KB
/
rag_documents_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import os
import azure.identity
import openai
from dotenv import load_dotenv
from lunr import lunr
# Configure an OpenAI-compatible chat client for the backend selected by the
# API_HOST environment variable: GitHub Models (the default), Ollama,
# Azure OpenAI, or OpenAI.com for any other value.
load_dotenv(override=True)
API_HOST = os.environ.get("API_HOST", "github")

if API_HOST == "github":
    # GitHub Models endpoint, authenticated with the GITHUB_TOKEN value.
    client = openai.OpenAI(
        base_url="https://models.inference.ai.azure.com",
        api_key=os.environ["GITHUB_TOKEN"],
    )
    MODEL_NAME = os.environ.get("GITHUB_MODEL", "gpt-4o")
elif API_HOST == "ollama":
    # Ollama endpoint; a placeholder api_key is passed.
    client = openai.OpenAI(
        base_url=os.environ["OLLAMA_ENDPOINT"],
        api_key="nokeyneeded",
    )
    MODEL_NAME = os.environ["OLLAMA_MODEL"]
elif API_HOST == "azure":
    # Azure OpenAI: no API key is passed; the client is given a bearer-token
    # provider built from DefaultAzureCredential instead.
    token_provider = azure.identity.get_bearer_token_provider(
        azure.identity.DefaultAzureCredential(),
        "https://cognitiveservices.azure.com/.default",
    )
    client = openai.AzureOpenAI(
        api_version=os.environ["AZURE_OPENAI_VERSION"],
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        azure_ad_token_provider=token_provider,
    )
    MODEL_NAME = os.environ["AZURE_OPENAI_DEPLOYMENT"]
else:
    # Any other API_HOST value falls through to OpenAI.com.
    client = openai.OpenAI(api_key=os.environ["OPENAI_KEY"])
    MODEL_NAME = os.environ["OPENAI_MODEL"]
# Load the ingested chunks from JSON - each object has id, text, and embedding -
# and build a lunr full-text index over the "text" field only.
# The file is decoded explicitly as UTF-8 so non-ASCII text reads the same on
# every platform rather than depending on the locale's default encoding.
with open("rag_ingested_chunks.json", encoding="utf-8") as file:
    documents = json.load(file)

# Key by str(id): lunr search results report each ref as a string, so a
# non-string id in the JSON would otherwise miss the lookup by result["ref"].
documents_by_id = {str(doc["id"]): doc for doc in documents}
index = lunr(ref="id", fields=["text"], documents=documents)
# The question to answer
user_question = "where do digger bees live?"

# Run the question through the full-text index, then map each hit's ref back
# to the document it came from.
results = index.search(user_question)
retrieved_documents = [documents_by_id[hit["ref"]] for hit in results]
print(f"Retrieved {len(retrieved_documents)} matching documents, only sending the first 5.")

# Sources block for the prompt: one "<id>: <text>" line per document, capped at five.
context = "\n".join(f"{doc['id']}: {doc['text']}" for doc in retrieved_documents[:5])
# Ask the model to answer the question using only the retrieved sources.
# The system prompt restricts answers to the supplied sources and asks for
# bracketed citations by source id.
SYSTEM_MESSAGE = """
You are a helpful assistant that answers questions about insects.
You must use the data set to answer the questions,
you should not provide any info that is not in the provided sources.
Cite the sources you used to answer the question inside square brackets.
The sources are in the format: <id>: <text>.
"""

# The user turn carries the question followed by the retrieved sources.
response = client.chat.completions.create(
    messages=[
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": f"{user_question}\nSources: {context}"},
    ],
    model=MODEL_NAME,
    temperature=0.3,
)

# Show which model/backend answered, then the text of the first choice.
print(f"\nResponse from {MODEL_NAME} on {API_HOST}: \n")
print(response.choices[0].message.content)