Commit eca8317 (1 parent: 3a65fc5)
Showing 74 changed files with 1,206 additions and 298 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,3 +37,4 @@ docs/_build/ | |
.env | ||
tmp/ | ||
examples/data/cache/dense | ||
examples/data/*.csv |
Binary file not shown.
This file was deleted.
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
AWS_ACCESS_KEY_ID=[... AWS creds if caching through S3 ...] | ||
AWS_SECRET_ACCESS_KEY=[... AWS creds if caching through S3 ...] | ||
|
||
LEPTON_API_KEY=[... Lepton key (obtainable at dashboard.lepton.ai) if running SSAs on Aitomatic services ...] | ||
OPENAI_API_KEY=[... OpenAI creds if running SSAs directly on OpenAI services ...] |
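The template above lists credentials that only matter for particular setups (S3 caching, Aitomatic services, or OpenAI directly). As a minimal sketch, a script could check which variables are set before starting a run. The helper `missing_credentials` and the grouping of keys into `S3_KEYS` and `LLM_KEYS` are assumptions for illustration and are not part of this commit.

```python
import os

# Variable names come from the .env template; the grouping is an assumption.
S3_KEYS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
LLM_KEYS = ("LEPTON_API_KEY", "OPENAI_API_KEY")


def missing_credentials(keys, env=None):
    """Return the names in `keys` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [key for key in keys if not env.get(key)]


if __name__ == "__main__":
    missing_s3 = missing_credentials(S3_KEYS)
    if missing_s3:
        print("S3 caching credentials are incomplete:", missing_s3)
    # At least one LLM key is needed, depending on where the SSAs run.
    if len(missing_credentials(LLM_KEYS)) == len(LLM_KEYS):
        print("Neither LEPTON_API_KEY nor OPENAI_API_KEY is set")
```

Passing `env` explicitly keeps the check easy to exercise without touching the real environment.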
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.FinanceBench/ | ||
.streamlit/secrets.toml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
LEPTON_API_KEY = '...' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
run-streamlit: | ||
	@streamlit run streamlit-main.py --server.allowRunOnSave=true --server.runOnSave=true | |
|
||
solve: | ||
	@poetry run python3 ssa_fb/prob_solve.py ${id} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
<!-- markdownlint-disable MD043 --> | ||
|
||
# OpenSSA-FinanceBench benchmarking | ||
|
||
This example benchmarks `OpenSSA` performance | |
on the `FinanceBench` dataset. | |
|
||
## [`FinanceBench` Dataset](https://github.com/patronus-ai/financebench/blob/main/financebench_sample_150.csv) | ||
|
||
## Running the Aitomatic SSA benchmarking project | |
|
||
Have Python 3.10-3.11 installed. | ||
|
||
Have Poetry installed: __`make get-poetry`__. | ||
|
||
__Install__ project, and update its dependencies from time to time: | ||
__`make install`__. | ||
|
||
Create `.env` file following the `.env.template` and fill in necessary credentials. | ||
|
||
__Solve__ the problem corresponding to a specific `financebench_id`: | ||
__`make solve id=...`__. | ||
|
||
- refer to `FinanceBench` dataset above for `financebench_id`s | ||
and corresponding information | ||
|
||
## Notes to Aitomatic Developers | ||
|
||
The OpenSSA dependency for this benchmarking project is from the `experimental` | ||
branch of the private [SSA](https://github.com/aitomatic/ssa) repository. | ||
Hence, all improvements to OpenSSA during this project must be | ||
committed/pushed/merged into that repository and branch. |
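The README asks the reader to pick a `financebench_id` from the dataset before running `make solve id=...`. A minimal sketch of that lookup follows. It assumes the dataset CSV has a `financebench_id` column alongside the `doc_name` and `question` columns used by the scripts in this commit. The sample rows are placeholders invented for illustration, not real dataset entries, and `find_problem` is a hypothetical helper.

```python
import csv
import io

# Placeholder rows standing in for financebench_sample_150.csv; the values
# are invented for illustration and are not real dataset entries.
SAMPLE_CSV = """financebench_id,doc_name,question
example_id_1,EXAMPLE_DOC_A,Placeholder question A?
example_id_2,EXAMPLE_DOC_B,Placeholder question B?
"""


def find_problem(csv_text, financebench_id):
    """Return the row dict whose financebench_id matches, or None."""
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row["financebench_id"] == financebench_id:
            return row
    return None


if __name__ == "__main__":
    print(find_problem(SAMPLE_CSV, "example_id_2"))
```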
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[virtualenvs] | ||
create = true | ||
in-project = true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
OpenSSA[contrib] @ https://GitHub.com/Aitomatic/OpenSSA/archive/main.zip |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
### Instructions to run the FinanceBench dataset with OpenSSA OODA | |
|
||
|
||
|
||
1. Download financial reports | |
|
||
``` | ||
python data.py | ||
``` | ||
|
||
|
||
2. Load documents and run Q&A on a set of questions. | |
|
||
``` | ||
python qa.py # standard RAG | ||
python ooda-qa.py # run with ooda | ||
``` | ||
|
||
3. Auto-resume if the run was incomplete or stopped in the middle. | |
|
||
``` | ||
python qa.py | ||
python ooda-qa.py | ||
``` | ||
|
||
|
||
4. Output | ||
|
||
``` | ||
tmp/finance-bench/output.csv | ||
``` |
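Steps 2 and 3 above use the same commands because a rerun skips questions that already have an answer. The idea can be sketched in plain Python as below. The function `run_with_resume` and its `solve` and `save` callables are hypothetical stand-ins for illustration; the actual scripts work on a pandas DataFrame and a CSV file.

```python
def run_with_resume(rows, solve, save):
    """Answer only rows that are downloadable and still unanswered.

    rows  -- list of dicts with "status", "question" and "answer" keys
    solve -- callable mapping a question to an answer (hypothetical stand-in)
    save  -- callable persisting all rows; called after each new answer
    """
    processed = 0
    for row in rows:
        if row["status"] == "ok" and not row["answer"]:
            row["answer"] = solve(row["question"])
            save(rows)  # checkpoint so an interrupted run can pick up here
            processed += 1
    return processed


if __name__ == "__main__":
    demo_rows = [
        {"status": "ok", "question": "q1", "answer": "already answered"},
        {"status": "error", "question": "q2", "answer": ""},
        {"status": "ok", "question": "q3", "answer": ""},
    ]
    count = run_with_resume(demo_rows, str.upper, lambda rows: None)
    print(f"newly answered rows: {count}")
```

Saving after every answer costs extra writes, but it bounds the work lost to an interruption to a single question.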
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | |
from pathlib import Path | |
import pandas as pd | |
import requests | |
from requests.exceptions import HTTPError | |
|
||
# Note: the JOHNSON&JOHNSON document is not downloadable | |
|
||
FINANCEBENCH_METADATA_URL: str = ( | |
    "https://raw.githubusercontent.com/patronus-ai/" | |
    "financebench/main/financebench_sample_150.csv" | |
) | |
|
||
|
||
def download_dataset(): | |
    # Read the CSV file | |
    df_finbench = pd.read_csv(FINANCEBENCH_METADATA_URL) | |
    df_finbench["status"] = "ok" | |
|
||
    base_directory = "tmp/finance-bench/docs" | |
    count = 0 | |
    for index, row in df_finbench.iterrows(): | |
        doc_name = row["doc_name"] | |
        doc_link = row["doc_link"] | |
|
||
        # Create a subdirectory for each document | |
        doc_directory = os.path.join(base_directory, doc_name) | |
        if not os.path.exists(doc_directory): | |
            os.makedirs(doc_directory) | |
|
||
        # Path for the PDF file | |
        file_path = os.path.join(doc_directory, f"{doc_name}.pdf") | |
|
||
        # Check if the file has already been downloaded | |
        if not Path(file_path).is_file(): | |
            try: | |
                # Download the file | |
                response = requests.get(doc_link, timeout=30) | |
                response.raise_for_status()  # Raises HTTPError if the request returned an unsuccessful status code | |
|
||
                # Write the content to a file | |
                with open(file_path, "wb") as file: | |
                    file.write(response.content) | |
                print(f"Downloaded and saved: {file_path}") | |
                count += 1 | |
            except HTTPError as e: | |
                df_finbench.loc[index, "status"] = "error"  # Update the status to 'error' | |
                print(f"Error downloading {file_path}: {e}") | |
        else: | |
            print(f"File already exists, skipping: {file_path}") | |
|
||
    dataset_directory = "tmp/finance-bench" | |
    df_finbench.to_csv(os.path.join(dataset_directory, "finance_bench_dataset.csv"), index=False) | |
    print(f"All files processed. Total files downloaded: {count}") | |
|
||
|
||
if __name__ == "__main__": | |
    download_dataset() |
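The download loop above is idempotent: it skips files that already exist and records a per-document status instead of aborting. A minimal sketch of that pattern, separated from the network so it can be run offline, is shown below. The function `download_if_missing` and its `fetch` parameter are hypothetical and not part of this commit. The sketch catches `OSError`, which is broader than the `HTTPError` handled in the script.

```python
import tempfile
from pathlib import Path


def download_if_missing(url, file_path, fetch):
    """Write fetch(url) to file_path unless the file already exists.

    Returns "skipped", "downloaded" or "error". `fetch` is a hypothetical
    callable returning bytes; it may raise OSError on failure.
    """
    path = Path(file_path)
    if path.is_file():
        return "skipped"
    try:
        content = fetch(url)
    except OSError:
        return "error"
    # Create the per-document directory only once there is content to store.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(content)
    return "downloaded"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "EXAMPLE_DOC" / "EXAMPLE_DOC.pdf"
        print(download_if_missing("placeholder-url", target, lambda url: b"%PDF"))
        print(download_if_missing("placeholder-url", target, lambda url: b"%PDF"))
```

With `requests`, connection failures and timeouts raise subclasses of `requests.exceptions.RequestException`, which itself derives from `OSError`, so a `fetch` built on `requests.get` would be covered by the same `except` clause.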
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import os | |
from loguru import logger | |
import nest_asyncio | |
import pandas as pd | |
from openssa.utils.utils import Utils | |
from openssa.core.ooda_rag.solver import OodaSSA | |
|
||
nest_asyncio.apply() | |
|
||
PATH: str = "./tmp/finance-bench/docs" | |
FINANCEBENCH_CSV: str = "./tmp/finance-bench/finance_bench_dataset.csv" | |
OUTPUT_DIRECTORY: str = "tmp/finance-bench/output" | |
OUTPUT_FILE_NAME: str = "ooda_rag_output.csv" | |
|
||
|
||
@Utils.timeit | |
def process_doc(doc_name: str, question: str) -> str: | |
    ssa = OodaSSA(enable_generative=True) | |
    resource = os.path.join(PATH, doc_name) | |
    ssa.activate_resources(resource) | |
    solution = ssa.solve(question) | |
    return solution | |
|
||
|
||
def run(): | |
    output_file_path = os.path.join(OUTPUT_DIRECTORY, OUTPUT_FILE_NAME) | |
    answer_column_name = "ooda_answer" | |
    # Check if the output file exists, and read from it if available (load cache) | |
    if os.path.exists(output_file_path): | |
        df_finbench = pd.read_csv(output_file_path) | |
    else: | |
        # If the output file does not exist, read from the original dataset | |
        df_finbench = pd.read_csv(FINANCEBENCH_CSV) | |
    if answer_column_name not in df_finbench.columns: | |
        df_finbench[answer_column_name] = "" | |
    df_finbench = df_finbench.fillna("") | |
|
||
    if not os.path.exists(OUTPUT_DIRECTORY): | |
        os.makedirs(OUTPUT_DIRECTORY) | |
|
||
    for index, row in df_finbench.iterrows(): | |
        logger.info(f"Processing row {index} of {len(df_finbench)} : {row['doc_name']}") | |
        if row["status"] == "ok" and not row[answer_column_name]: | |
            doc_name = row["doc_name"] | |
            question = row["question"] | |
            answer = process_doc(doc_name, question) | |
            df_finbench.loc[index, answer_column_name] = answer | |
            # Save progress as cache after processing each row | |
            df_finbench.to_csv(output_file_path, index=False) | |
            print(f"complete index {index} of {len(df_finbench)}") | |
|
||
    # If any answer contains "empty response" (case-insensitive), replace it with "file reading error" | |
    df_finbench.loc[ | |
        df_finbench[answer_column_name].str.lower().str.contains("empty response"), | |
        answer_column_name, | |
    ] = "file reading error" | |
|
||
    df_finbench.to_csv(output_file_path, index=False) | |
    print("Processing complete. Output saved to:", output_file_path) | |
|
||
|
||
def clean_up(): | |
    file_path = os.path.join(OUTPUT_DIRECTORY, OUTPUT_FILE_NAME) | |
    df_data = pd.read_csv(file_path) | |
    filtered_df = df_data[ | |
        ~df_data["ooda_answer"].isna() | |
        & (df_data["ooda_answer"] != "file reading error") | |
    ] | |
    clean_output_file_path = os.path.join( | |
        OUTPUT_DIRECTORY, "filtered_ooda_rag_output.csv" | |
    ) | |
    filtered_df.to_csv(clean_output_file_path, index=False) | |
    print(f"Filtered data saved to {clean_output_file_path}") | |
|
||
|
||
if __name__ == "__main__": | |
    run() |
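The script above post-processes answers in two stages: `run()` relabels any answer containing "empty response" as "file reading error", and `clean_up()` drops missing answers and those error rows. A minimal plain-Python sketch of the same two stages follows. The names `normalize_answer` and `usable_answers` are hypothetical, and the sketch treats `None` and empty strings as missing, where the pandas version relies on `isna()`.

```python
FILE_ERROR = "file reading error"


def normalize_answer(answer):
    """Map any answer containing "empty response" (case-insensitive) to FILE_ERROR."""
    return FILE_ERROR if "empty response" in answer.lower() else answer


def usable_answers(answers):
    """Drop missing answers (None or empty) and file reading errors."""
    return [a for a in answers if a and a != FILE_ERROR]


if __name__ == "__main__":
    raw = ["A placeholder answer", "Empty Response", ""]
    print(usable_answers([normalize_answer(a) for a in raw]))
```

As shown in the diff, the `__main__` block only calls `run()`, so `clean_up()` would have to be invoked separately to produce the filtered file.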