Commit c1f2c1e

Some renames
1 parent 9fff523 commit c1f2c1e

5 files changed: +32 -30 lines changed


README.md

Lines changed: 8 additions & 0 deletions
@@ -16,6 +16,14 @@ Plus these scripts to demonstrate additional features:
 * [`chat_langchain.py`](./chat_langchain.py): Uses the langchain SDK to generate chat completions. [Learn more from Langchain docs](https://python.langchain.com/docs/get_started/quickstart)
 * [`chat_llamaindex.py`](./chat_llamaindex.py): Uses the LlamaIndex SDK to generate chat completions. [Learn more from LlamaIndex docs](https://docs.llamaindex.ai/en/stable/)
 
+These scripts demonstrate Retrieval Augmented Generation (RAG):
+
+* [`rag_csv.py`](./rag_csv.py): Retrieves matching results from a CSV file and uses them to answer the user's question.
+* [`rag_multiturn.py`](./rag_multiturn.py): The same idea, but with a back-and-forth chat interface using `input()`, which keeps track of past messages and sends them with each chat completion call.
+* [`rag_queryrewrite.py`](./rag_queryrewrite.py): Adds a query rewriting step to the RAG process, where the user's question is rewritten to improve the retrieval results.
+* [`rag_documents_ingestion.py`](./rag_documents_ingestion.py): Ingests PDFs by using pymupdf to convert them to Markdown, then using Langchain to split them into chunks, then using OpenAI to embed the chunks, and finally storing them in a local JSON file.
+* [`rag_documents_flow.py`](./rag_documents_flow.py): A RAG flow that retrieves matching results from the local JSON file created by `rag_documents_ingestion.py`.
+
 ## Setting up the environment
 
 If you open this up in a Dev Container or GitHub Codespaces, everything will be setup for you.
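The `rag_documents_ingestion.py` bullet above describes a four-step pipeline (PDF to Markdown, split into chunks, embed the chunks, store as JSON). Below is a minimal sketch of that pipeline; the `pymupdf4llm` helper, the `RecursiveCharacterTextSplitter` chunker, the `text-embedding-3-small` model, and the file names are assumptions for illustration, not details taken from this repo.

```python
import json

import openai
import pymupdf4llm
from langchain_text_splitters import RecursiveCharacterTextSplitter

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

# 1. Convert a PDF to Markdown text (pymupdf4llm wraps pymupdf for this step)
md_text = pymupdf4llm.to_markdown("example.pdf")

# 2. Split the Markdown into overlapping chunks
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_text(md_text)

# 3. Embed all chunks in one call
embeddings = client.embeddings.create(model="text-embedding-3-small", input=chunks)

# 4. Store each chunk next to its embedding in a local JSON file
docs = [
    {"text": chunk, "embedding": item.embedding}
    for chunk, item in zip(chunks, embeddings.data)
]
with open("rag_ingested_chunks.json", "w") as f:
    json.dump(docs, f, indent=4)
```

`rag_documents_flow.py` would then presumably load the same JSON file, embed the user's question with the same model, and rank chunks by vector similarity before building the prompt; the exact retrieval logic is not shown in this commit.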

rag.py renamed to rag_csv.py

File renamed without changes.
File renamed without changes.
File renamed without changes.

retrieval_augmented_generation.py

Lines changed: 24 additions & 30 deletions
@@ -33,35 +33,29 @@
 MODEL_NAME = os.environ["OPENAI_MODEL"]
 
 
-def search(query):
-    # Open the CSV and store in a list
-    with open("hybrid.csv") as file:
-        reader = csv.reader(file)
-        rows = list(reader)
-
-    # Normalize the user question to replace punctuation and make lowercase
-    normalized_message = query.lower().replace("?", "").replace("(", " ").replace(")", " ")
-    # Search the CSV for user question using very naive search
-    words = normalized_message.split()
-    matching_rows = []
-    for row in rows[1:]:
-        # if the word matches any word in row, add the row to the matches
-        if any(word in row[0].lower().split() for word in words) or any(
-            word in row[5].lower().split() for word in words
-        ):
-            matching_rows.append(row)
-    # Format as a markdown table, since language models understand markdown
-    matches_table = " | ".join(rows[0]) + "\n" + " | ".join(" --- " for _ in range(len(rows[0]))) + "\n"
-    matches_table += "\n".join(" | ".join(row) for row in matches)
-    return matches_table
-
-
-user_question = "how fast is the prius v?"
-
-matches = search(user_question)
-
-print("Found matches:")
-print(matches)
+USER_MESSAGE = "how fast is the prius v?"
+
+# Open the CSV and store in a list
+with open("hybrid.csv") as file:
+    reader = csv.reader(file)
+    rows = list(reader)
+
+# Normalize the user question to replace punctuation and make lowercase
+normalized_message = USER_MESSAGE.lower().replace("?", "").replace("(", " ").replace(")", " ")
+
+# Search the CSV for user question using very naive search
+words = normalized_message.split()
+matches = []
+for row in rows[1:]:
+    # if the word matches any word in row, add the row to the matches
+    if any(word in row[0].lower().split() for word in words) or any(word in row[5].lower().split() for word in words):
+        matches.append(row)
+
+# Format as a markdown table, since language models understand markdown
+matches_table = " | ".join(rows[0]) + "\n" + " | ".join(" --- " for _ in range(len(rows[0]))) + "\n"
+matches_table += "\n".join(" | ".join(row) for row in matches)
+print(f"Found {len(matches)} matches:")
+print(matches_table)
 
 # Now we can use the matches to generate a response
 SYSTEM_MESSAGE = """

@@ -74,7 +68,7 @@ def search(query):
     temperature=0.3,
     messages=[
         {"role": "system", "content": SYSTEM_MESSAGE},
-        {"role": "user", "content": user_question + "\nSources: " + matches},
+        {"role": "user", "content": USER_MESSAGE + "\nSources: " + matches_table},
     ],
 )
 
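The diff above is a single-turn flow: one hard-coded question, one retrieval, one chat completion. The `rag_multiturn.py` script described in the README wraps the same idea in an `input()` loop that keeps past messages. A rough sketch, assuming the same `client`, `MODEL_NAME`, and `SYSTEM_MESSAGE` as in the script above, plus a hypothetical `retrieve_matches()` helper that returns a markdown table like `matches_table`:

```python
# Minimal multi-turn sketch; `client`, MODEL_NAME, and SYSTEM_MESSAGE come from the
# script above, and retrieve_matches() is a hypothetical helper that runs the same
# naive CSV keyword search and returns a markdown table of matching rows.
messages = [{"role": "system", "content": SYSTEM_MESSAGE}]

while True:
    question = input("\nYour question: ")
    if not question.strip():
        break

    # Retrieve sources for this turn and attach them to the user message
    matches_table = retrieve_matches(question)
    messages.append({"role": "user", "content": question + "\nSources: " + matches_table})

    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.3,
        messages=messages,  # the full history, so the model sees earlier turns
    )
    answer = response.choices[0].message.content
    print(answer)

    # Keep the assistant reply in the history for the next turn
    messages.append({"role": "assistant", "content": answer})
```

Appending the assistant reply to `messages` is what lets later completions take earlier turns into account.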

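Likewise, the `rag_queryrewrite.py` bullet says the user's question is rewritten before retrieval. One common way to do that is a small chat completion that turns the conversational question into search keywords; the prompt wording and the `rewrite_query()` name below are illustrative assumptions, not code from the repo.

```python
# Sketch of a query rewriting step; reuses `client` and MODEL_NAME from the script
# above. The prompt text and function name are assumptions for illustration.
REWRITE_PROMPT = (
    "Rewrite the user's question into a short keyword search query "
    "that would match rows in a car data table. Return only the query."
)

def rewrite_query(question: str) -> str:
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.0,
        messages=[
            {"role": "system", "content": REWRITE_PROMPT},
            {"role": "user", "content": question},
        ],
    )
    return response.choices[0].message.content.strip()

# The rewritten query, not the raw question, then feeds the keyword search above:
search_query = rewrite_query("how fast is the prius v?")
```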