diff --git a/.gitignore b/.gitignore
index 9da3b2b85..d9cd28efd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,3 +37,4 @@ docs/_build/
 .env
 tmp/
 examples/data/cache/dense
+examples/data/*.csv
diff --git a/.ruff.toml b/.ruff.toml
index 8ab1c7a9c..0b91a8b0b 100644
--- a/.ruff.toml
+++ b/.ruff.toml
@@ -48,6 +48,7 @@ ignore = [
   "D107",  # missing docstring in `__init__`
   "D200",  # one-line docstring should fit on one line
   "D202",  # no blank lines allowed after function docstring
+  "D203",  # one blank line before class
   "D204",  # 1 blank line required after class docstring
   "D205",  # 1 blank line required between summary line and description
   "D212",  # multi-line docstring summary should start at the first line
@@ -66,10 +67,12 @@ ignore = [
   "EM102",  # exception must not use an f-string literal, assign to variable first
   "ERA001",  # found commented-out code
   "F401",  # imported but unused
+  "FA102",  # missing `from __future__ import annotations`, but uses PEP 604 union
   "FBT001",  # boolean-typed positional argument in function definition
   "FBT002",  # boolean default positional argument in function definition
   "FBT003",  # boolean positional value in function call
   "FIX002",  # line contains TODO, consider resolving the issue
+  "FURB101",  # `open` and `read` should be replaced by `Path(file_path).read_bytes()`
   "I001",  # import block is un-sorted or un-formatted
   "INP001",  # file is part of an implicit namespace package; add an `__init__.py`
   "LOG009",  # use of undocumented `logging.WARN` constant
@@ -106,11 +109,16 @@ ignore = [
   "RUF013",  # PEP 484 prohibits implicit `Optional`
   "RUF017",  # avoid quadratic list summation
   "RUF018",  # avoid assignment expressions in `assert` statements
+  "RUF022",  # `__all__` is not sorted
+  "RUF100",  # unused `noqa` directive
   "S101",  # use of `assert` detected
   "S106",  # possible hardcoded password assigned to argument
+  "S110",  # `try`-`except`-`pass` detected, consider logging the exception
   "S605",  # starting a process with a shell, possible injection detected
+  "S607",  # starting a process with a partial executable path
   "SIM102",  # use a single `if` statement instead of nested `if` statements
   "SIM108",  # use ternary operator `temp = temp["content"] if isinstance(temp, dict) else temp.content` instead of `if`-`else`-block
+  "SIM105",  # use `contextlib.suppress(Exception)` instead of `try`-`except`-`pass`
   "SIM112",  # use capitalized environment variable
   "SIM401",  # use `item.get("role", "assistant")` instead of an `if` block
   "SLF001",  # private member accessed
@@ -123,5 +131,4 @@ ignore = [
   "UP007",  # use `X | Y` for type annotations
   "UP035",  # `typing.[X]` is deprecated, use `[x]` instead
   "UP039",  # unnecessary parentheses after class definition
-  "D203",  # one-blank-line-before-class
 ]
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
index 8e6b1c832..747814822 100644
--- a/.vscode/extensions.json
+++ b/.vscode/extensions.json
@@ -9,8 +9,11 @@
   ],
   "unwantedRecommendations": [
-    "ms-python.mypy-type-checker",  // MyPy Type Checker (Microsoft)
+    "ms-python.autopep8",  // AutoPEP8 (Microsoft)
+    "ms-python.black-formatter",  // Black Formatter (Microsoft)
+    "eeyore.yapf",  // Yet Another Python Formatter
+    "ms-python.mypy-type-checker",  // MyPy Type Checker (Microsoft)
     "matangover.mypy"  // MyPy
   ]
 }
diff --git a/Makefile b/Makefile
index 664dd2563..7476f3154 100644
--- a/Makefile
+++ b/Makefile
@@ -36,7 +36,7 @@ install:
 	poetry install --extras=contrib --with=docs --with=lint --with=test

 install-editable:
-	python3
-m pip install -e ".[contrib]" --upgrade # LINTING @@ -56,7 +56,6 @@ lint-pylint: lint-ruff: # docs.astral.sh/ruff/linter poetry run ruff check $(LIB_DIR) $(DOCS_DIR) $(EXAMPLES_DIR) $(TESTS_DIR) \ - --show-source \ --output-format text \ --target-version py310 \ --preview \ diff --git a/README.md b/README.md index 89f3e060e..aa2faaa00 100644 --- a/README.md +++ b/README.md @@ -147,7 +147,7 @@ Head to [Lepton](https://dashboard.lepton.ai/) to get your API key. * Select `API tokens` * Copy `` -In terminal, run +In terminal, run ```bash= export LEPTON_API_KEY= ``` diff --git a/examples/data/docs/m290/m290.pdf b/examples/data/docs/m290/m290.pdf deleted file mode 100644 index 6cd5b2dd1..000000000 Binary files a/examples/data/docs/m290/m290.pdf and /dev/null differ diff --git a/examples/data/notebook_qa_comparion.csv b/examples/data/notebook_qa_comparion.csv deleted file mode 100644 index fe48aa1f0..000000000 --- a/examples/data/notebook_qa_comparion.csv +++ /dev/null @@ -1,11 +0,0 @@ -question,standard vector index,rag ooda -do I need heat treatment for the 316L with M 290?,"Yes, heat treatment is required for the 316L with M 290.","Yes, heat treatment is required for 316L when using the M 290 equipment." -Which steel can I print?,"You can print EOS CaseHardeningSteel 20MnCrZr9030-0004, EOS MaragingSteel MS19011-0016, EOS StainlessSteel 17-4PH9011-0041, EOS StainlessSteel 2549030-0007, EOS StainlessSteel 316L9011-0032, EOS StainlessSteel CX9011-0037, and EOS StainlessSteel PH19011-0019.","You can print with the following steels using the EOS M 290 system: EOS CaseHardeningSteel 20MnCrZr, EOS MaragingSteel MS1, EOS StainlessSteel 17-4PH, EOS StainlessSteel 254, EOS StainlessSteel 316L, EOS StainlessSteel CX, EOS StainlessSteel PH1, and EOS ToolSteel 1.2709." -can I print a part 50 cm long in the M 290?,"Yes, you can print a part 50 cm long in the M 290.","The conversation regarding the printing of a 50 cm long part using an EOS M 290 3D printer includes conflicting information. Initially, there was ambiguity since no specific details about the maximum build size of the EOS M 290 were provided. When the user asked whether a part 50 cm long could be printed, the response was that it is possible to print a part 50 cm long in the M 290. However, no specific dimensions were provided to support this claim. - -In the discussion, there is a mention of researching the build size of the M 290 printer, which indicates that the appropriate next step would involve finding out the printer's build volume capabilities to accurately assess if a 50 cm long part is printable. However, the actual build volume specifications of the EOS M 290 are not detailed within the conversation. - -Without the explicit dimensions of the printer's build volume, there are two contrasting statuses: one suggesting that printing a 50 cm part is possible, and the other, implied by the absence of hard data, suggesting that no conclusion can be drawn. - -Given the conflicting information and the lack of detailed specifications, it is not possible to accurately determine if the EOS M 290 can print a part 50 cm long based solely on the information provided in the conversation. To resolve this, the exact build dimensions of the M 290 need to be consulted from an accurate source." -Is it preferable to use nitrogen when printing with titanium?,"Yes, it is preferable to use nitrogen when printing with titanium.","No, it is not preferable to use nitrogen when printing with titanium on the EOS M 290. 
Titanium powder and titanium condensate can react exothermically with nitrogen, presenting a fire and explosion hazard. The machine should only be operated with argon as the inert gas when using titanium powder." diff --git a/examples/data/qa_comparion.csv b/examples/data/qa_comparion.csv deleted file mode 100644 index 8fd0691c6..000000000 --- a/examples/data/qa_comparion.csv +++ /dev/null @@ -1,4 +0,0 @@ -question,stand vector index,rag ooda -do I need heat treatment for the 316L with M 290?,"Yes, heat treatment is required for the 316L with M 290.","The dialogue contains mixed messages regarding the necessity of heat treatment for 316L stainless steel when subjected to the M 290 process. Two responses indicate that heat treatment is not required based on the defined properties of 316L stainless steel treated with this process, while another response claims that it is required. - -Given that the information provided is contradictory and there are no additional details or context that might clarify the situation, I cannot produce an accurate and definitive answer to the user's initial question. Tools available in the system provided different outcomes, and without further information, it's impossible to conclude whether heat treatment is necessary for 316L stainless steel in the context of the M 290 process." diff --git a/examples/data/qa_standard_agent_comparion.csv b/examples/data/qa_standard_agent_comparion.csv deleted file mode 100644 index f6a7bda23..000000000 --- a/examples/data/qa_standard_agent_comparion.csv +++ /dev/null @@ -1,33 +0,0 @@ -question,standard agent,rag ooda -do I need heat treatment for the 316L with M290?," Based on the context information provided, it is not necessary to perform heat treatment on 316L with M290. This is because the metal powder category for 316L is listed in the table as ""Metal powder, EOS StainlessSteel 316L"", which indicates that it is already suitable for building using the M290 process without the need for heat treatment.","The conversation presents a query about whether heat treatment is necessary for 316L stainless steel when used with an EOS M290 3D printer. The assistant's initial response, meant to guide further research, was to investigate the requirement of heat treatment for 316L post-printing and to look for best practices for post-processing such components made with an EOS M290 printer. - -However, subsequent messages from the system indicate research results that suggest there is no direct requirement of heat treatment for 316L stainless steel after using EOS M290 3D printing technology. The system also outlines best practices for post-processing, which include cleaning but do not specifically mention heat treatment. - -Another message generated by the system refers to reasons why heat treatment might be considered, such as relieving residual stresses, increasing hardness, and improving corrosion resistance. However, this message also mentions the lack of specific context or detailed information related to the initial question. - -Given that the system's tools seem to indicate no direct need for heat treatment for 316L when used with an EOS M290 based on the provided context, and without additional details or application-specific requirements, the synthesized answer to the user's initial question is: - -Based on the available information in the conversation, there is no explicit mention that heat treatment is required for 316L stainless steel after using an EOS M290 3D printer. 
It is suggested that in general, heat treatment may not be necessary unless there are particular requirements for the finished product that would demand such post-processing. For a definitive answer, it would be important to consider specific application needs or consult with material and manufacturing experts or guidelines from the EOS M290 manufacturer." -Which steel can I print?," Based on the provided context information, you can print EOS CaseHardeningSteel 20MnCrZr X using the EOS M 290 laser-sintering system.","To synthesize the final answer to the user's initial question about which steels are available for printing using a specific 3D printer, we have gathered information from the conversation indicating the specific printer model as the ""EOS M 290."" - -The tool performed a research based on this context and provided a list of steels that are available for printing using the EOS M 290 3D printer. According to the tool's findings, the compatible steels for printing with this machine are: - -- EOS CaseHardeningSteel 20MnCr5 -- EOS MaragingSteel MS1 -- EOS StainlessSteel 316L -- EOS StainlessSteel 17-4PH -- EOS StainlessSteel CX -- EOS StainlessSteel PH1 -- EOS ToolSteel H13 -- EOS ToolSteel MS1 -- EOS NickelAlloy HX -- EOS NickelAlloy IN718 -- EOS NickelAlloy IN625 -- EOS CobaltChrome SP2 -- EOS Titanium Ti64 -- EOS Titanium Ti64ELI -- EOS Aluminium AlSi10Mg - -This information was generated from the context provided and should be confirmed by consulting the manufacturer's specifications or user manuals for the most updated and accurate list of compatible materials. It is critical to ensure the steel chosen aligns with the printer's capabilities and the intended application of the printed components. Manufacturers may update the range of compatible materials over time, so direct confirmation from EOS or the latest guidance documents for the EOS M 290 is advised. - -Therefore, the synthesized answer, based on the provided dialog, is that a variety of steel materials are available for use in the EOS M 290, including various types of stainless steel, tool steel, and others as listed above. For a definitive and current list, consultation of the manufacturer's official documentation is recommended." diff --git a/examples/financebench/.env.template b/examples/financebench/.env.template new file mode 100644 index 000000000..374561b08 --- /dev/null +++ b/examples/financebench/.env.template @@ -0,0 +1,5 @@ +AWS_ACCESS_KEY_ID=[... AWS creds if caching through S3 ...] +AWS_SECRET_ACCESS_KEY=[... AWS creds if caching through S3 ...] + +LEPTON_API_KEY=[... Lepton key (obtainable at dashboard.lepton.ai) if running SSAs on Aitomatic services ...] +OPENAI_API_KEY=[... OpenAI creds if running SSAs directly on OpenAI services ...] diff --git a/examples/financebench/.gitignore b/examples/financebench/.gitignore new file mode 100644 index 000000000..ff3d0e57c --- /dev/null +++ b/examples/financebench/.gitignore @@ -0,0 +1,2 @@ +.FinanceBench/ +.streamlit/secrets.toml diff --git a/examples/financebench/.streamlit/secrets.toml.template b/examples/financebench/.streamlit/secrets.toml.template new file mode 100644 index 000000000..5d432f9ee --- /dev/null +++ b/examples/financebench/.streamlit/secrets.toml.template @@ -0,0 +1 @@ +LEPTON_API_KEY = '...' 
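For orientation, the credential templates above are consumed at runtime roughly as follows. This is a minimal sketch assuming `python-dotenv` (which the `ssa_fb` modules later in this diff import); the variable names come from `.env.template`, and everything else is illustrative:

```python
# Minimal sketch: how the .env / secrets templates above are typically consumed.
# Variable names are from .env.template; the error handling is illustrative.
import os

from dotenv import load_dotenv

load_dotenv()  # read .env from the working directory into os.environ

openai_key = os.environ.get("OPENAI_API_KEY")  # running SSAs directly on OpenAI
lepton_key = os.environ.get("LEPTON_API_KEY")  # running SSAs on Aitomatic/Lepton

if not (openai_key or lepton_key):
    raise RuntimeError("No API credentials found; fill in .env per .env.template")
```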
diff --git a/examples/financebench/Makefile b/examples/financebench/Makefile
new file mode 100644
index 000000000..864103c49
--- /dev/null
+++ b/examples/financebench/Makefile
@@ -0,0 +1,5 @@
+run-streamlit:
+	@streamlit run streamlit-main.py --server.allowRunOnSave=true --server.runOnSave=true
+
+solve:
+	@poetry run python3 ssa_fb/prob_solve.py ${id}
diff --git a/examples/financebench/README.md b/examples/financebench/README.md
new file mode 100644
index 000000000..bec751b85
--- /dev/null
+++ b/examples/financebench/README.md
@@ -0,0 +1,32 @@
+
+# OpenSSA-FinanceBench benchmarking
+
+This example benchmarks `OpenSSA` performance
+on the `FinanceBench` dataset.
+
+## [`FinanceBench` Dataset](https://github.com/patronus-ai/financebench/blob/main/financebench_sample_150.csv)
+
+## Running the Aitomatic SSA benchmarking project
+
+Have Python 3.10 or 3.11 installed.
+
+Have Poetry installed: __`make get-poetry`__.
+
+__Install__ the project, and update its dependencies from time to time:
+__`make install`__.
+
+Create a `.env` file following `.env.template` and fill in the necessary credentials.
+
+__Solve__ the problem corresponding to a specific `financebench_id`:
+__`make solve id=...`__.
+
+- refer to the `FinanceBench` dataset above for valid `financebench_id`s
+and corresponding information
+
+## Notes to Aitomatic Developers
+
+The OpenSSA dependency for this benchmarking project is from the `experimental`
+branch of the private [SSA](https://github.com/aitomatic/ssa) repository.
+Hence, all improvements to OpenSSA during this project must be
+committed/pushed/merged into that repository and branch.
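The `make solve` target above is a thin wrapper; the same entry point can also be invoked programmatically. A minimal sketch using the `solve` helper from `ssa_fb/prob_solve.py` (added later in this diff); the ID below is a placeholder, with real values taken from the `FinanceBench` CSV:

```python
# Illustrative: invoke the FinanceBench solver without the Makefile wrapper.
# `solve` is defined in ssa_fb/prob_solve.py below; the ID is a placeholder.
from ssa_fb.prob_solve import solve

answer = solve("financebench_id_00001")  # placeholder ID; or solve("all") for a full sweep
print(answer)
```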
diff --git a/examples/financebench/poetry.toml b/examples/financebench/poetry.toml
new file mode 100644
index 000000000..53b35d370
--- /dev/null
+++ b/examples/financebench/poetry.toml
@@ -0,0 +1,3 @@
+[virtualenvs]
+create = true
+in-project = true
diff --git a/examples/financebench/requirements.txt b/examples/financebench/requirements.txt
new file mode 100644
index 000000000..dd81406d5
--- /dev/null
+++ b/examples/financebench/requirements.txt
@@ -0,0 +1 @@
+OpenSSA[contrib] @ https://GitHub.com/Aitomatic/OpenSSA/archive/main.zip
diff --git a/examples/financebench/scripts/README.md b/examples/financebench/scripts/README.md
new file mode 100644
index 000000000..65a082342
--- /dev/null
+++ b/examples/financebench/scripts/README.md
@@ -0,0 +1,30 @@
+### Instructions for running the FinanceBench dataset with OpenSSA OODA
+
+1. Download financial reports
+
+```
+    python data.py
+```
+
+2. Load documents and run Q&A on a set of questions.
+
+```
+    python qa.py       # standard RAG
+    python ooda_qa.py  # run with OODA
+```
+
+3. Re-run the same commands to auto-resume a run that was incomplete or stopped in the middle.
+
+```
+    python qa.py
+    python ooda_qa.py
+```
+
+4. Output
+
+```
+    tmp/finance-bench/output/
+```
diff --git a/examples/financebench/scripts/data.py b/examples/financebench/scripts/data.py
new file mode 100644
index 000000000..5e2eaf4c4
--- /dev/null
+++ b/examples/financebench/scripts/data.py
@@ -0,0 +1,58 @@
+import os
+from pathlib import Path
+import pandas as pd
+import requests
+from requests.exceptions import HTTPError
+
+# Note: the JOHNSON&JOHNSON documents are not downloadable
+
+FINANCEBENCH_METADATA_URL: str = (
+    "https://raw.githubusercontent.com/patronus-ai/"
+    "financebench/main/financebench_sample_150.csv"
+)
+
+
+def download_dataset():
+    # Read the CSV file
+    df_finbench = pd.read_csv(FINANCEBENCH_METADATA_URL)
+    df_finbench["status"] = "ok"
+
+    base_directory = "tmp/finance-bench/docs"
+    count = 0
+    for index, row in df_finbench.iterrows():
+        doc_name = row["doc_name"]
+        doc_link = row["doc_link"]
+
+        # Create a subdirectory for each document
+        doc_directory = os.path.join(base_directory, doc_name)
+        if not os.path.exists(doc_directory):
+            os.makedirs(doc_directory)
+
+        # Path for the PDF file
+        file_path = os.path.join(doc_directory, f"{doc_name}.pdf")
+
+        # Check if the file has already been downloaded
+        if not Path(file_path).is_file():
+            try:
+                # Download the file
+                response = requests.get(doc_link, timeout=30)
+                response.raise_for_status()  # raises an HTTPError if the request returned an unsuccessful status code
+
+                # Write the content to a file
+                with open(file_path, "wb") as file:
+                    file.write(response.content)
+                print(f"Downloaded and saved: {file_path}")
+                count += 1
+            except HTTPError as e:
+                df_finbench.loc[index, "status"] = "error"  # Update the status to 'error'
+                print(f"Error downloading {file_path}: {e}")
+        else:
+            print(f"File already exists, skipping: {file_path}")
+
+    dataset_directory = "tmp/finance-bench"
+    df_finbench.to_csv(os.path.join(dataset_directory, "finance_bench_dataset.csv"), index=False)
+    print(f"All files processed. Total files downloaded: {count}")
+
+
+if __name__ == "__main__":
+    download_dataset()
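Before running the QA scripts below, the download step can be sanity-checked against the metadata CSV it writes. A small illustrative snippet, using the path and column names from `data.py` above:

```python
# Illustrative sanity check of the download step above.
import pandas as pd

df = pd.read_csv("tmp/finance-bench/finance_bench_dataset.csv")
print(df["status"].value_counts())  # rows marked "ok" vs "error"
print(df.loc[df["status"] == "error", "doc_name"].tolist())  # failed downloads
```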
diff --git a/examples/financebench/scripts/ooda_qa.py b/examples/financebench/scripts/ooda_qa.py
new file mode 100644
index 000000000..e5c514379
--- /dev/null
+++ b/examples/financebench/scripts/ooda_qa.py
@@ -0,0 +1,77 @@
+import os
+from loguru import logger
+import nest_asyncio
+import pandas as pd
+from openssa.utils.utils import Utils
+from openssa.core.ooda_rag.solver import OodaSSA
+
+nest_asyncio.apply()
+
+PATH: str = "./tmp/finance-bench/docs"
+FINANCEBENCH_CSV: str = "./tmp/finance-bench/finance_bench_dataset.csv"
+OUTPUT_DIRECTORY: str = "tmp/finance-bench/output"
+OUTPUT_FILE_NAME: str = "ooda_rag_output.csv"
+
+
+@Utils.timeit
+def process_doc(doc_name: str, question: str) -> str:
+    ssa = OodaSSA(enable_generative=True)
+    resource = os.path.join(PATH, doc_name)
+    ssa.activate_resources(resource)
+    solution = ssa.solve(question)
+    return solution
+
+
+def run():
+    output_file_path = os.path.join(OUTPUT_DIRECTORY, OUTPUT_FILE_NAME)
+    answer_column_name = "ooda_answer"
+    # Check if the output file exists, and read from it if available (load cache)
+    if os.path.exists(output_file_path):
+        df_finbench = pd.read_csv(output_file_path)
+    else:
+        # If the output file does not exist, read from the original dataset
+        df_finbench = pd.read_csv(FINANCEBENCH_CSV)
+        if answer_column_name not in df_finbench.columns:
+            df_finbench[answer_column_name] = ""
+    df_finbench = df_finbench.fillna("")
+
+    if not os.path.exists(OUTPUT_DIRECTORY):
+        os.makedirs(OUTPUT_DIRECTORY)
+
+    for index, row in df_finbench.iterrows():
+        logger.info(f"Processing row {index} of {len(df_finbench)} : {row['doc_name']}")
+        if row["status"] == "ok" and not row[answer_column_name]:
+            doc_name = row["doc_name"]
+            question = row["question"]
+            answer = process_doc(doc_name, question)
+            df_finbench.loc[index, answer_column_name] = answer
+            # Save progress as cache after processing each row
+            df_finbench.to_csv(output_file_path, index=False)
+            print(f"Completed index {index} of {len(df_finbench)}")
+
+    # If any answer contains "Empty Response", mark it as a file reading error
+    df_finbench.loc[
+        df_finbench[answer_column_name].str.lower().str.contains("empty response"),
+        answer_column_name,
+    ] = "file reading error"
+
+    df_finbench.to_csv(output_file_path, index=False)
+    print("Processing complete. Output saved to:", output_file_path)
+
+
+def clean_up():
+    file_path = os.path.join(OUTPUT_DIRECTORY, OUTPUT_FILE_NAME)
+    df_data = pd.read_csv(file_path)
+    filtered_df = df_data[
+        ~df_data["ooda_answer"].isna()
+        & (df_data["ooda_answer"] != "file reading error")
+    ]
+    clean_output_file_path = os.path.join(
+        OUTPUT_DIRECTORY, "filtered_ooda_rag_output.csv"
+    )
+    filtered_df.to_csv(clean_output_file_path, index=False)
+    print(f"Filtered data saved to {clean_output_file_path}")
+
+
+if __name__ == "__main__":
+    run()
diff --git a/examples/financebench/scripts/qa.py b/examples/financebench/scripts/qa.py
new file mode 100644
index 000000000..e2a9db6ee
--- /dev/null
+++ b/examples/financebench/scripts/qa.py
@@ -0,0 +1,52 @@
+import os
+from loguru import logger
+import nest_asyncio
+import pandas as pd
+from openssa.core.ooda_rag.custom import CustomSSM
+from openssa.utils.utils import Utils
+
+nest_asyncio.apply()
+
+PATH = "./tmp/finance-bench/docs"
+FINANCEBENCH_CSV: str = "./tmp/finance-bench/finance_bench_dataset.csv"
+OUTPUT_DIRECTORY = "tmp/finance-bench/output"
+
+
+@Utils.timeit
+def process_doc(doc_name: str, question: str) -> str:
+    ssm = CustomSSM()
+    ssm.read_directory(os.path.join(PATH, doc_name))
+    rag_answer = ssm.discuss(question).get("content")
+    return rag_answer
+
+
+def run():
+    output_file_path = os.path.join(OUTPUT_DIRECTORY, "standard_rag_output.csv")
+    answer_column_name = "rag_answer"
+    # Check if the output file exists, and read from it if available
+    if os.path.exists(output_file_path):
+        df_finbench = pd.read_csv(output_file_path)
+    else:
+        # If the output file does not exist, read from the original dataset
+        df_finbench = pd.read_csv(FINANCEBENCH_CSV)
+        if answer_column_name not in df_finbench.columns:
+            df_finbench[answer_column_name] = ""
+
+    if not os.path.exists(OUTPUT_DIRECTORY):
+        os.makedirs(OUTPUT_DIRECTORY)
+
+    for index, row in df_finbench.iterrows():
+        logger.info(f"Processing row {index} of {len(df_finbench)} : {row['doc_name']}")
+        if row["status"] == "ok" and not row[answer_column_name]:
+            doc_name = row["doc_name"]
+            question = row["question"]
+            answer = process_doc(doc_name, question)
+            df_finbench.loc[index, answer_column_name] = answer
+            # Save progress after processing each row
+            df_finbench.to_csv(output_file_path, index=False)
+
+    print("Processing complete. Output saved to:", output_file_path)
+
+
+if __name__ == "__main__":
+    run()
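The two scripts above share the same primitives; for a one-off question against a single document, the underlying calls look roughly like this. This sketch uses the `CustomSSM` API as imported in `qa.py`; the document folder is a placeholder for any directory created by `data.py`:

```python
# Illustrative single-document query via the same CustomSSM API as qa.py.
from openssa.core.ooda_rag.custom import CustomSSM

ssm = CustomSSM()
ssm.read_directory("tmp/finance-bench/docs/3M_2018_10K")  # placeholder doc dir
print(ssm.discuss("What was the total revenue?").get("content"))
```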
diff --git a/examples/financebench/ssa_fb/data.py b/examples/financebench/ssa_fb/data.py
new file mode 100644
index 000000000..1be363553
--- /dev/null
+++ b/examples/financebench/ssa_fb/data.py
@@ -0,0 +1,51 @@
+from collections.abc import Sequence
+from functools import cache
+from pathlib import Path
+
+from pandas import DataFrame, read_csv
+import requests
+
+from dotenv import load_dotenv
+load_dotenv()
+
+# pylint: disable=wrong-import-position
+from openssa.utils.fs import FileSource  # noqa: E402
+
+
+__all__: Sequence[str] = (
+    'META_DF', 'DOC_LINKS_BY_NAME',
+    'LOCAL_CACHE_DIR_PATH', 'LOCAL_CACHE_DOCS_DIR_PATH', 'OUTPUT_FILE_PATH',
+    'get_or_create_cached_dir_path',
+)
+
+
+FINANCEBENCH_METADATA_URL: str = 'https://raw.githubusercontent.com/patronus-ai/financebench/main/financebench_sample_150.csv'
+
+META_DF: DataFrame = read_csv(FINANCEBENCH_METADATA_URL)
+DOC_NAMES: list[str] = sorted(META_DF.doc_name.unique())
+DOC_LINKS_BY_NAME: dict[str, str] = dict(zip(META_DF.doc_name, META_DF.doc_link))
+QAS_BY_ID: dict[str, tuple[str, str]] = dict(zip(META_DF.financebench_id, zip(META_DF.question, META_DF.answer)))
+QS_BY_ID: dict[str, str] = {i: qa[0] for i, qa in QAS_BY_ID.items()}
+QAIDS_BY_DOC_NAME: dict[str, list[str]] = META_DF.groupby('doc_name').apply(lambda df: df.financebench_id.to_list())
+
+LOCAL_CACHE_DIR_PATH: Path = Path(__file__).parent.parent / '.FinanceBench'
+LOCAL_CACHE_DOCS_DIR_PATH: Path = LOCAL_CACHE_DIR_PATH / 'docs'
+OUTPUT_FILE_PATH: Path = LOCAL_CACHE_DIR_PATH / 'output.csv'
+
+
+@cache
+def get_or_create_cached_dir_path(doc_name: str) -> str:
+    dir_path: Path = LOCAL_CACHE_DOCS_DIR_PATH / doc_name
+
+    doc_path: Path = dir_path / f'{doc_name}.pdf'
+
+    if not (file_src := FileSource(path=str(doc_path))).is_single_file:
+        file_src.fs.write_bytes(path=file_src.native_path,
+                                value=requests.get(url=DOC_LINKS_BY_NAME[doc_name], stream=True, timeout=9).content)
+
+    return str(dir_path)
+
+
+@cache
+def cached_file_path(doc_name: str) -> str:
+    return str(Path(get_or_create_cached_dir_path(doc_name)) / f'{doc_name}.pdf')
diff --git a/examples/financebench/ssa_fb/heuristics.py b/examples/financebench/ssa_fb/heuristics.py
new file mode 100644
index 000000000..c0fc8f3b4
--- /dev/null
+++ b/examples/financebench/ssa_fb/heuristics.py
@@ -0,0 +1,3 @@
+THREE_FIN_STATEMENTS_HEURISTICS: str = (
+    'pay attention to Balance Sheet, Cash Flow Statement and Profit & Loss (P&L) Statement'
+)
diff --git a/examples/financebench/ssa_fb/prob_solve.py b/examples/financebench/ssa_fb/prob_solve.py
new file mode 100644
index 000000000..b108ec9ce
--- /dev/null
+++ b/examples/financebench/ssa_fb/prob_solve.py
@@ -0,0 +1,61 @@
+# pylint: disable=bare-except,no-name-in-module,wrong-import-position
+
+
+from argparse import ArgumentParser
+from pathlib import Path
+import sys
+
+import nest_asyncio
+from pandas import DataFrame, read_csv
+from tqdm import tqdm
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+from ssa_fb.data import META_DF, OUTPUT_FILE_PATH  # noqa: E402
+from ssa_fb.ssas import get_or_create_ooda_ssa  # noqa: E402
+
+
+def solve(financebench_id: str) -> str:
+    if financebench_id.lower().strip() == 'all':
+        nest_asyncio.apply()
+
+        for _financebench_id in tqdm(META_DF.financebench_id):
+            print(solve(_financebench_id))
+
+        return None
+
+    matching_fbid_row_num: int = (META_DF.financebench_id == financebench_id).idxmax()
+
+    problem: str = META_DF.at[matching_fbid_row_num, 'question']
+    print(f'PROBLEM: "{problem}"')
+
+    doc_name: str = META_DF.at[matching_fbid_row_num, 'doc_name']
+    ssa = get_or_create_ooda_ssa(doc_name)
+
+    print('SOLVING...')
+    try:
+        solution = ssa.solve(problem)
+    except:  # noqa: E722
+        return None
+
+    if OUTPUT_FILE_PATH.is_file():
+        output_df: DataFrame = read_csv(OUTPUT_FILE_PATH)
+
+    else:
+        output_df: DataFrame = META_DF[['financebench_id', 'doc_name',
+                                        'question', 'evidence_text', 'page_number', 'answer']]
+        output_df['OodaSSA-answer'] = ''
+
+    output_df.loc[matching_fbid_row_num, 'OodaSSA-answer'] = solution
+    output_df.to_csv(OUTPUT_FILE_PATH, index=False)
+
+    return solution
+
+
+if __name__ == '__main__':
+    arg_parser = ArgumentParser()
+    arg_parser.add_argument('financebench_id')
+
+    args = arg_parser.parse_args()
+
+    print(solve(args.financebench_id))
diff --git a/examples/financebench/ssa_fb/ssas.py b/examples/financebench/ssa_fb/ssas.py
new file mode 100644
index 000000000..39ef1a822
--- /dev/null
+++ b/examples/financebench/ssa_fb/ssas.py
@@ -0,0 +1,26 @@
+from functools import cache
+
+from dotenv import load_dotenv
+load_dotenv()
+
+# pylint: disable=wrong-import-order,wrong-import-position
+from openssa import OodaSSA, TaskDecompositionHeuristic  # noqa: E402
+
+from ssa_fb.data import get_or_create_cached_dir_path  # noqa: E402
+from ssa_fb.heuristics import (  # noqa: E402
+    THREE_FIN_STATEMENTS_HEURISTICS,
+)
+
+
+@cache
+def get_or_create_ooda_ssa(doc_name: str,
+                           guiding_heuristics: str = THREE_FIN_STATEMENTS_HEURISTICS) -> OodaSSA:
+    print(f'Getting OODA-RAG SSA based on {doc_name}...')
+
+    ssa = OodaSSA(task_heuristics=TaskDecompositionHeuristic({}),
+                  highest_priority_heuristic=guiding_heuristics,
+                  enable_generative=True)
+
+    ssa.activate_resources(get_or_create_cached_dir_path(doc_name))
+
+    return ssa
diff --git a/examples/financebench/streamlit-main.py b/examples/financebench/streamlit-main.py
new file mode 100644
index 000000000..4cd61c0e3
--- /dev/null
+++ b/examples/financebench/streamlit-main.py
@@ -0,0 +1,105 @@
+# pylint: disable=bare-except,invalid-name,no-name-in-module,wrong-import-position
+
+
+import base64
+from pathlib import Path
+import sys
+
+import nest_asyncio
+import streamlit as st
+
+sys.path.insert(1, str(Path(__file__).parent.parent.parent))  # to use OpenSSA in same repo
+
+from ssa_fb.data import DOC_LINKS_BY_NAME, QS_BY_ID, QAIDS_BY_DOC_NAME, cached_file_path  # noqa: E402
+from ssa_fb.prob_solve import solve  # noqa: E402
+
+
+DOC_NAMES: list[str] = [
+    'ACTIVISIONBLIZZARD_2019_10K',
+    'AMCOR_2023_10K',
+    'AMD_2015_10K', 'AMD_2022_10K',
+    'BESTBUY_2017_10K', 'BESTBUY_2023_10K',
+    'BLOCK_2016_10K',
+    'BOEING_2022_10K',
+    'COCACOLA_2021_10K',
+    'CORNING_2021_10K',
+    'CVSHEALTH_2018_10K',
+    'GENERALMILLS_2022_10K',
+    'LOCKHEEDMARTIN_2022_10K',
+    'MGMRESORTS_2020_10K',
+    'NIKE_2018_10K', 'NIKE_2021_10K',
+    'PAYPAL_2022_10K',
+    'PEPSICO_2023_8K_dated-2023-05-30',
+    'PFIZER_2021_10K',
+    'VERIZON_2022_10K',
+    'WALMART_2020_10K',
+]
+
+
+nest_asyncio.apply()
+
+
+def display_pdf(file_path):
+    # Opening file from file path
+    with open(file_path, 'rb') as f:
+        base64_pdf = base64.b64encode(f.read()).decode('utf-8')
+
+    # Embedding PDF in HTML
+    pdf_display = F'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf"></iframe>'
+
+    # Displaying File
+    st.markdown(pdf_display, unsafe_allow_html=True)
+
+
+st.set_page_config(page_title='Analyses of SEC Filings (`FinanceBench` Dataset) with aiVA',
+                   page_icon=None,
+                   layout='centered',
+                   initial_sidebar_state='auto',
+                   menu_items=None)
+
+
+st.title('Analyses of SEC Filings (`FinanceBench` Dataset) with aiVA')
+
+
+if 'doc_name' not in
st.session_state: + st.session_state.doc_name: str = DOC_NAMES[0] + +st.session_state.doc_name: str = st.selectbox(label='SEC Document', + options=DOC_NAMES, + index=DOC_NAMES.index(st.session_state.doc_name), + # format_func=None, + key=None, + help='SEC Document', + on_change=None, args=None, kwargs=None, + placeholder='SEC Document', + disabled=False, + label_visibility='hidden') + +st.write(DOC_LINKS_BY_NAME[st.session_state.doc_name]) + +try: + display_pdf(cached_file_path(st.session_state.doc_name)) +except: # noqa: E722 + print('document cannot be rendered') + + +question_id: str = st.selectbox(label='Question', + options=QAIDS_BY_DOC_NAME[st.session_state.doc_name], + index=0, + format_func=lambda i: QS_BY_ID[i], + key=None, + help='Question', + on_change=None, args=None, kwargs=None, + placeholder='Question', + disabled=False, + label_visibility='visible') + +if st.button(label=f'__SOLVE__: _{QS_BY_ID[question_id]}_', + key=None, + on_click=None, args=None, kwargs=None, + type='primary', + disabled=False, + use_container_width=False): + solution: str = solve(question_id) + st.write(solution) + # st.text(solution) diff --git a/examples/financebench/test.txt b/examples/financebench/test.txt new file mode 100644 index 000000000..b044a7770 --- /dev/null +++ b/examples/financebench/test.txt @@ -0,0 +1 @@ +pytest >= 7.4 diff --git a/examples/financial-research/.gitignore b/examples/financial-research/.gitignore new file mode 100644 index 000000000..f3e4e8777 --- /dev/null +++ b/examples/financial-research/.gitignore @@ -0,0 +1,2 @@ +.data/ +.streamlit/secrets.toml diff --git a/examples/financial-research/.streamlit/secrets.toml.template b/examples/financial-research/.streamlit/secrets.toml.template new file mode 100644 index 000000000..5d432f9ee --- /dev/null +++ b/examples/financial-research/.streamlit/secrets.toml.template @@ -0,0 +1 @@ +LEPTON_API_KEY = '...' 
diff --git a/examples/financial-research/Makefile b/examples/financial-research/Makefile
new file mode 100644
index 000000000..864103c49
--- /dev/null
+++ b/examples/financial-research/Makefile
@@ -0,0 +1,5 @@
+run-streamlit:
+	@streamlit run streamlit-main.py --server.allowRunOnSave=true --server.runOnSave=true
+
+solve:
+	@poetry run python3 prob_solve.py ${id}
diff --git a/examples/financial-research/dataproc.py b/examples/financial-research/dataproc.py
new file mode 100644
index 000000000..341c88912
--- /dev/null
+++ b/examples/financial-research/dataproc.py
@@ -0,0 +1,12 @@
+from functools import cache
+from pathlib import Path
+
+
+LOCAL_CACHE_DIR_PATH: Path = Path(__file__).parent / '.data'
+
+
+@cache
+def get_or_create_cached_dir_path(company: str) -> str:
+    dir_path: Path = LOCAL_CACHE_DIR_PATH / company
+    dir_path.mkdir(parents=True, exist_ok=True)
+    return str(dir_path)
diff --git a/examples/financial-research/heuristics.py b/examples/financial-research/heuristics.py
new file mode 100644
index 000000000..c0fc8f3b4
--- /dev/null
+++ b/examples/financial-research/heuristics.py
@@ -0,0 +1,3 @@
+THREE_FIN_STATEMENTS_HEURISTICS: str = (
+    'pay attention to Balance Sheet, Cash Flow Statement and Profit & Loss (P&L) Statement'
+)
diff --git a/examples/financial-research/prob_solve.py b/examples/financial-research/prob_solve.py
new file mode 100644
index 000000000..eca4e67c8
--- /dev/null
+++ b/examples/financial-research/prob_solve.py
@@ -0,0 +1,28 @@
+# pylint: disable=bare-except,no-name-in-module,wrong-import-position
+
+
+from pathlib import Path
+import sys
+
+import streamlit as st
+
+import nest_asyncio
+
+sys.path.append(str(Path(__file__).parent.parent))
+
+
+from ssas import get_or_create_ooda_ssa  # noqa: E402
+
+
+nest_asyncio.apply()
+
+
+def solve(objective: str, company: str) -> str:
+    problem: str = objective + company
+    print(f'PROBLEM: "{problem}"')
+
+    ssa = get_or_create_ooda_ssa(company)
+
+    st.write('__Integrating Findings for Recommendation...__')
+    solution = ssa.solve(problem)
+    return solution
diff --git a/examples/financial-research/requirements.txt b/examples/financial-research/requirements.txt
new file mode 100644
index 000000000..9dd0aca25
--- /dev/null
+++ b/examples/financial-research/requirements.txt
@@ -0,0 +1,7 @@
+# OpenSSA
+
+Mechanize >= 0.4.9
+PDFKit >= 1.0.0
+PyHTML2PDF >= 0.0.7
+Pyppeteer >= 1.0.2
+XHTML2PDF >= 0.2.14
diff --git a/examples/financial-research/ssas.py b/examples/financial-research/ssas.py
new file mode 100644
index 000000000..d8d8f4e39
--- /dev/null
+++ b/examples/financial-research/ssas.py
@@ -0,0 +1,31 @@
+from pathlib import Path
+
+import streamlit as st
+
+# pylint: disable=wrong-import-order,wrong-import-position
+from openssa import OodaSSA, TaskDecompositionHeuristic  # noqa: E402
+
+from dataproc import get_or_create_cached_dir_path  # noqa: E402
+from heuristics import (  # noqa: E402
+    THREE_FIN_STATEMENTS_HEURISTICS,
+)
+
+
+def get_or_create_ooda_ssa(company: str,
+                           guiding_heuristics: str = THREE_FIN_STATEMENTS_HEURISTICS) -> OodaSSA:
+    print(f'Getting OODA-RAG SSA based on {company}...')
+
+    ssa = OodaSSA(task_heuristics=TaskDecompositionHeuristic({}),
+                  highest_priority_heuristic=guiding_heuristics,
+                  enable_generative=True)
+
+    st.write('__Gathering Relevant Resources for Analysis...__')
+
+    ssa.activate_resources((resources_path := get_or_create_cached_dir_path(company)), re_index=True)
+
+    for resource in (p.stem
+                     for p in Path(resources_path).iterdir()
+                     if p.is_file() and (not p.stem.startswith('.'))):
+        st.write(f'- _{resource}_')
+
+    return ssa
diff --git a/examples/financial-research/streamlit-main.py b/examples/financial-research/streamlit-main.py
new file mode 100644
index 000000000..e70ca4ba3
--- /dev/null
+++ b/examples/financial-research/streamlit-main.py
@@ -0,0 +1,113 @@
+# pylint: disable=bare-except,invalid-name,no-name-in-module,wrong-import-position
+
+
+from pathlib import Path
+import sys
+
+from mechanize import Browser
+import nest_asyncio
+from pyhtml2pdf import converter
+import streamlit as st
+
+sys.path.insert(1, str(Path(__file__).parent.parent.parent))  # to use OpenSSA in same repo
+
+from dataproc import get_or_create_cached_dir_path  # noqa: E402
+from prob_solve import solve  # noqa: E402
+
+
+DEFAULT_OBJECTIVE: str = (
+    'I want to allocate some capital to higher-risk investments in high-growth computing hardware companies.\n\n'
+
+    'I want companies with:\n'
+    '- at least 30% growth in year-to-date revenue compared to last fiscal year,\n'
+    '- at least 50% gross margin,\n'
+    '- at most 40x Price-over-forward-Earnings ratio,\n'
+    '- products used in AI-related computing.\n\n'
+
+    "I DON'T want companies with 1-2 top customers, combined, making up more than 25% of revenue.\n\n"
+
+    'Evaluate whether the following is a good fit for me: '
+)
+
+
+COMPANIES: list[str] = [
+    'Arista Networks (ANET)',
+    'ARM Holdings (ARM)',
+    'Super Micro Computer, Inc. (SMCI)',
+    'Uber (UBER)',
+]
+
+
+nest_asyncio.apply()
+
+
+st.set_page_config(page_title='Automated Financial Research by AI with Planning & Reasoning',
+                   page_icon=None,
+                   layout='centered',
+                   initial_sidebar_state='auto',
+                   menu_items=None)
+
+st.title('Automated Financial Research by AI with Planning & Reasoning')
+
+
+st.write('__RESEARCH OBJECTIVE__:')
+
+if 'objective' not in st.session_state:
+    st.session_state.objective: str = DEFAULT_OBJECTIVE
+
+st.session_state.objective: str = st.text_area(label='Financial Research Objective',
+                                               value=st.session_state.objective,
+                                               height=9,
+                                               max_chars=None,
+                                               key=None,
+                                               help='Financial Research Objective',
+                                               on_change=None, args=None, kwargs=None,
+                                               placeholder='Financial Research Objective',
+                                               disabled=False,
+                                               label_visibility='collapsed')
+
+if 'company' not in st.session_state:
+    st.session_state.company: str = COMPANIES[0]
+
+st.session_state.company: str = st.selectbox(label='Watchlist',
+                                             options=COMPANIES,
+                                             index=COMPANIES.index(st.session_state.company),
+                                             # format_func=None,
+                                             key=None,
+                                             help='Company',
+                                             on_change=None, args=None, kwargs=None,
+                                             placeholder='Company',
+                                             disabled=False,
+                                             label_visibility='visible')
+
+webpages_to_incl: str = st.text_area(label='Sources to Include in Research',
+                                     value='',
+                                     height=9,
+                                     max_chars=None,
+                                     key=None,
+                                     help='Sources to Include in Research',
+                                     on_change=None, args=None, kwargs=None,
+                                     placeholder='Sources to Include in Research',
+                                     disabled=False,
+                                     label_visibility='visible')
+
+
+if st.button(label=f'__RESEARCH & RECOMMEND__: _{st.session_state.company}_',
+             key=None,
+             on_click=None, args=None, kwargs=None,
+             type='primary',
+             disabled=False,
+             use_container_width=False):
+    for url in (i.strip() for i in webpages_to_incl.split('\n')):
+        if url:
+            br = Browser()
+            br.set_handle_robots(False)
+            br.open(url)
+            st.write(f'_Reading "{br.title()}"..._')
+            converter.convert(source=url,
+                              target=f'{get_or_create_cached_dir_path(st.session_state.company)}/{br.title()}.pdf',
+                              timeout=0.1,
+                              install_driver=False)
+
+    solution: str = solve(st.session_state.objective, st.session_state.company)
+    st.write(solution)
diff --git
a/examples/usage/use_agent.py b/examples/usage/use_agent.py index 84824c985..3a286198e 100644 --- a/examples/usage/use_agent.py +++ b/examples/usage/use_agent.py @@ -1,5 +1,6 @@ from openssa.core.ssa.agent import Agent + def solve(objective): # Create an instance of the Agent class agent = Agent() diff --git a/make.bat b/make.bat index 66151f3e8..17ebb8e16 100644 --- a/make.bat +++ b/make.bat @@ -81,7 +81,6 @@ IF "%TARGET%"=="launch-solver" GOTO launch-solver :lint-ruff :: docs.astral.sh/ruff/linter poetry run ruff check %LIB_DIR% %DOCS_DIR% %EXAMPLES_DIR% %TESTS_DIR% ^ - --show-source ^ --output-format text ^ --target-version py310 ^ --preview ^ diff --git a/openssa/contrib/__init__.py b/openssa/contrib/__init__.py index a1e6199e6..f194e1065 100644 --- a/openssa/contrib/__init__.py +++ b/openssa/contrib/__init__.py @@ -7,4 +7,5 @@ """ +from .recursive_ssa_prob_solver import RecursiveOodaSSA # noqa: F401 from .streamlit_ssa_prob_solver import SSAProbSolver as StreamlitSSAProbSolver diff --git a/openssa/contrib/recursive_ssa_prob_solver/__init__.py b/openssa/contrib/recursive_ssa_prob_solver/__init__.py new file mode 100644 index 000000000..4446e60f7 --- /dev/null +++ b/openssa/contrib/recursive_ssa_prob_solver/__init__.py @@ -0,0 +1,14 @@ +# pylint: disable=unused-argument + + +from openssa import OodaSSA, TaskDecompositionHeuristic # noqa: E402 + + +class RecursiveOodaSSA: + def __init__(self, resource_scope: list[str]): + self.ooda_ssa = OodaSSA(task_heuristics=TaskDecompositionHeuristic({}), + enable_generative=True) + + def solve(self, problem: str, max_n_sources: int = 10) -> str: + solution: str = self.ooda_ssa.solve(problem) + return solution diff --git a/openssa/contrib/streamlit_ssa_prob_solver/__init__.py b/openssa/contrib/streamlit_ssa_prob_solver/__init__.py index b0ca7cb82..838b18e39 100644 --- a/openssa/contrib/streamlit_ssa_prob_solver/__init__.py +++ b/openssa/contrib/streamlit_ssa_prob_solver/__init__.py @@ -217,7 +217,8 @@ def ssa(self) -> RagSSA | None: def ssa_solve(self): ooda_ssa = OodaSSA(task_heuristics=TaskDecompositionHeuristic({}), - highest_priority_heuristic=self.expert_instructions) + highest_priority_heuristic=self.expert_instructions, + enable_generative=True) ooda_ssa.activate_resources(self.doc_src_path) diff --git a/openssa/contrib/streamlit_ssa_prob_solver/main.py b/openssa/contrib/streamlit_ssa_prob_solver/main.py index 994cd0ed3..8535ce874 100644 --- a/openssa/contrib/streamlit_ssa_prob_solver/main.py +++ b/openssa/contrib/streamlit_ssa_prob_solver/main.py @@ -5,7 +5,7 @@ import streamlit as st -sys.path.insert(1, str(Path(__file__).parent.parent.parent.parent)) +sys.path.insert(1, str(Path(__file__).parent.parent.parent.parent)) # to use OpenSSA in same repo st.title('SSA-Powered Problem Solvers') diff --git a/openssa/core/ooda/ooda_loop.py b/openssa/core/ooda/ooda_loop.py index f342e99b1..9bb255763 100644 --- a/openssa/core/ooda/ooda_loop.py +++ b/openssa/core/ooda/ooda_loop.py @@ -8,6 +8,7 @@ class Step: prompt_function (function): The function used to generate the prompt for the step. input_data: The input data for the step. output_data: The output data generated by the step. + """ def __init__(self, name, prompt_function): @@ -28,6 +29,7 @@ def execute(self, objective, llm, history): Returns: The output data generated by the step. 
+ """ common_prompt = ( "I want you to be a great problem-solver, expert in the use of OODA control loops," @@ -43,10 +45,21 @@ def execute(self, objective, llm, history): def __init__(self, objective): self.objective = objective self.steps = { - 'Observe': self.Step('Observe', lambda input_data: f"Observe the situation related to the objective '{objective}'"), # noqa: ARG005 - 'Orient': self.Step('Orient', lambda input_data: f"Orient based on the observation: {input_data}"), - 'Decide': self.Step('Decide', lambda input_data: f"Decide the next action based on the orientation: {input_data}"), - 'Act': self.Step('Act', lambda input_data: f"Act based on the decision: {input_data}") + "Observe": self.Step( + "Observe", + lambda _: f"Observe the situation related to the objective '{objective}'", + ), + "Orient": self.Step( + "Orient", + lambda input_data: f"Orient based on the observation: {input_data}", + ), + "Decide": self.Step( + "Decide", + lambda input_data: f"Decide the next action based on the orientation: {input_data}", + ), + "Act": self.Step( + "Act", lambda input_data: f"Act based on the decision: {input_data}" + ), } def run(self, llm, history): diff --git a/openssa/core/ooda/task.py b/openssa/core/ooda/task.py index 3b5d994bd..1aaec9ef0 100644 --- a/openssa/core/ooda/task.py +++ b/openssa/core/ooda/task.py @@ -9,6 +9,7 @@ class Task: ooda_loop: The OODA loop to which the task belongs. result: The result of the task. resources: Additional resources associated with the task. + """ class Result: @@ -21,8 +22,17 @@ class Result: references: A list of references related to the task. metrics: Metrics associated with the task. additional_info: Additional information about the task result. + """ - def __init__(self, status='pending', response=None, references=None, metrics=None, additional_info=None): + + def __init__( + self, + status="pending", + response=None, + references=None, + metrics=None, + additional_info=None, + ): self.status = status self.response = response self.references = references if references else [] diff --git a/openssa/core/ooda_rag/builtin_agents.py b/openssa/core/ooda_rag/builtin_agents.py index 756879b1b..9444ce086 100644 --- a/openssa/core/ooda_rag/builtin_agents.py +++ b/openssa/core/ooda_rag/builtin_agents.py @@ -67,6 +67,33 @@ def execute(self, task: str = "") -> dict: return {} +class CommAgent(TaskAgent): + """ + CommAgent helps update tone, voice, format and language of the assistant final response + """ + + def __init__( + self, llm: AnLLM = OpenAILLM(), instruction: str = "" + ) -> None: + self.llm = llm + self.instruction = instruction + + @Utils.timeit + def execute(self, task: str = "") -> str: + system_message = { + "role": Persona.SYSTEM, + "content": BuiltInAgentPrompt.COMMUNICATION.format( + instruction=self.instruction, message=task + ), + } + conversation = [system_message] + response = self.llm.call( + messages=conversation, + response_format={"type": "text"}, + ) + return response.choices[0].message.content + + class GoalAgent(TaskAgent): """ GoalAgent helps to determine problem statement from the conversation between user and SSA @@ -212,7 +239,7 @@ class OODAPlanAgent(TaskAgent): def __init__( self, - llm: AnLLM = OpenAILLM.get_gpt_35_turbo_1106(), + llm: AnLLM = OpenAILLM(), conversation: Optional[List] = None, ) -> None: self.llm = llm diff --git a/openssa/core/ooda_rag/custom.py b/openssa/core/ooda_rag/custom.py index a8cfd6156..538046f45 100644 --- a/openssa/core/ooda_rag/custom.py +++ b/openssa/core/ooda_rag/custom.py @@ -1,18 +1,14 @@ 
-from llama_index import ( - Document, - Response, - SimpleDirectoryReader, - ServiceContext, - OpenAIEmbedding, -) -from llama_index.evaluation import DatasetGenerator -from llama_index.llms import OpenAI -from llama_index.node_parser import SimpleNodeParser +from llama_index.core import Document, Response, SimpleDirectoryReader, ServiceContext +from llama_index.core.evaluation import DatasetGenerator +from llama_index.core.node_parser import SimpleNodeParser +from loguru import logger from openssa.core.backend.abstract_backend import AbstractBackend from openssa.core.slm.base_slm import PassthroughSLM from openssa.core.ssm.rag_ssm import RAGSSM from openssa.integrations.llama_index.backend import Backend as LlamaIndexBackend +from openssa.core.ooda_rag.query_rewritting_engine import QueryRewritingRetrieverPack +from openssa.utils.rag_service_contexts import ServiceContextManager FILE_NAME = "file_name" @@ -28,7 +24,10 @@ def filename_fn(filename: str) -> dict: documents = SimpleDirectoryReader( input_dir=self._get_source_dir(storage_dir), input_files=None, - exclude=None, + exclude=[ + '.DS_Store', # MacOS-specific + '*.json', # index files that may be stored in subdirs + ], exclude_hidden=False, # non-default errors="strict", # non-default recursive=True, # non-default @@ -48,7 +47,6 @@ def get_citation_type(self, filename: str) -> str: def get_citations(self, response: Response, source_path: str = "") -> list[dict]: citations: list = [] - print("metadata", response.metadata) if not response.metadata: return citations for data in response.metadata.values(): @@ -80,12 +78,12 @@ def persist(self, persist_path: str) -> None: def query( self, query: str, source_path: str = "" ) -> dict: # pylint: disable=arguments-renamed - """Returns a response dict with keys role, content, and citations.""" - if self.query_engine is None: - return {"content": "No index to query. 
Please load something first."} + self.query_engine = QueryRewritingRetrieverPack( + index=self._index, chunk_size=1024, service_context=self._service_context + ).query_engine response: Response = self.query_engine.query(query) citations = self.get_citations(response, source_path) - print("citations", citations) + logger.debug(f"response: {response.response}") return {"content": response.response, "citations": citations} async def get_evaluation_data(self) -> list: @@ -114,9 +112,10 @@ def __init__( s3_source_path: str = "", ) -> None: if custom_rag_backend is None: - service_context = ServiceContext.from_defaults( - llm=OpenAI(model="gpt-4-1106-preview"), embed_model=OpenAIEmbedding() - ) + service_context = ServiceContextManager.get_openai_4_0125_preview_sc() + # service_context = ServiceContextManager.get_azure_openai_4_0125_preview_sc() + # service_context = ServiceContextManager.get_azure_openai_sc() + # service_context = ServiceContextManager.get_openai_sc() custom_rag_backend = CustomBackend(service_context=service_context) slm = PassthroughSLM() diff --git a/openssa/core/ooda_rag/heuristic.py b/openssa/core/ooda_rag/heuristic.py index 3b6ec6235..aac12cb5e 100644 --- a/openssa/core/ooda_rag/heuristic.py +++ b/openssa/core/ooda_rag/heuristic.py @@ -42,12 +42,12 @@ class DefaultOODAHeuristic(Heuristic): def apply_heuristic(self, task: str) -> dict: observe = { "thought": f"Gather information from research document to solve the task \n {task}", - "calls": [{"tool_name": "research_documents", "parameters": {"task": task}}], + "calls": [ + {"tool_name": "research_documents", "parameters": {"task": task}} + ], } orient = { - "thought": ( - "Analyze the information gathered from research documents. " - ), + "thought": ("Analyze the information gathered from research documents. "), "calls": [], } decide = { @@ -73,3 +73,22 @@ def apply_heuristic(self, task: str) -> list: Apply the heuristic rules to decompose the task into subtasks. """ print(task) + + +class HeuristicSet: + """ + A set of heuristics. + """ + + def __init__(self, **kwargs) -> None: + """ + Initialize the heuristic set. 
+        """
+        self.task_heuristics = kwargs.get(
+            "task_heuristics", TaskDecompositionHeuristic({})
+        )
+        self.ooda_heuristics = kwargs.get("ooda_heuristics", DefaultOODAHeuristic())
+        self.highest_priority_heuristic = kwargs.get("highest_priority_heuristic", "")
+        self.comm_heuristic = kwargs.get("comm_heuristic", "")
+        self.ask_user_heuristic = kwargs.get("ask_user_heuristic", "")
+        self.goal_heuristics = kwargs.get("goal_heuristics", "")
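Since `Solver` (in `ooda_rag.py` below) now takes a single `HeuristicSet` instead of separate heuristic arguments, wiring looks roughly like this; a sketch only, with illustrative heuristic strings:

```python
# Illustrative wiring of the new HeuristicSet into Solver (see ooda_rag.py below).
from openssa.core.ooda_rag.heuristic import HeuristicSet, TaskDecompositionHeuristic
from openssa.core.ooda_rag.ooda_rag import Solver

heuristic_set = HeuristicSet(
    task_heuristics=TaskDecompositionHeuristic({}),
    highest_priority_heuristic="treat the three financial statements as the source of truth",  # illustrative
    comm_heuristic="answer concisely, in plain English",  # non-empty value triggers CommAgent post-processing
)
solver = Solver(heuristic_set=heuristic_set, enable_generative=True)
```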
diff --git a/openssa/core/ooda_rag/ooda_rag.py b/openssa/core/ooda_rag/ooda_rag.py
index 0f5c9d2e9..50db79454 100644
--- a/openssa/core/ooda_rag/ooda_rag.py
+++ b/openssa/core/ooda_rag/ooda_rag.py
@@ -1,4 +1,3 @@
-import json
 import uuid
 from typing import List, Optional
 from loguru import logger
@@ -6,11 +5,10 @@
 from openssa.core.ooda_rag.notifier import Notifier, SimpleNotifier, EventTypes
 from openssa.core.ooda_rag.heuristic import (
     Heuristic,
-    TaskDecompositionHeuristic,
-    DefaultOODAHeuristic,
+    HeuristicSet,
 )
 from openssa.core.ooda_rag.tools import Tool
-from openssa.core.ooda_rag.builtin_agents import OODAPlanAgent, Persona
+from openssa.core.ooda_rag.builtin_agents import CommAgent, Persona
 from openssa.utils.utils import Utils
 from openssa.utils.llms import OpenAILLM, AnLLM
@@ -55,9 +53,9 @@ def execute_task(self, history: History) -> None:
             data={"uuid": self.uuid, "task-name": self.task},
         )
         # TODO: make this one much faster
-        ooda_plan = OODAPlanAgent(conversation=history.get_history()).execute(self.task)
-        if not ooda_plan:
-            ooda_plan = self.ooda_heuristics.apply_heuristic(self.task)
+        # ooda_plan = OODAPlanAgent(conversation=history.get_history()).execute(self.task)
+        # if not ooda_plan:
+        ooda_plan = self.ooda_heuristics.apply_heuristic(self.task)
         self.check_resource_call(ooda_plan)
         self._execute_step(ooda_plan["observe"], history, "observe")
         self._execute_step(ooda_plan["orient"], history, "orient")
@@ -69,7 +67,7 @@ def check_resource_call(self, ooda_plan: dict) -> None:
         for step in steps:
             calls = ooda_plan.get(step, {}).get("calls", [])
             for call in calls:
-                if call.get("tool_name", "") == "research_documents":
+                if call.get("tool_name", "") in {"research_documents", "web_search"}:
                     return
         observe = ooda_plan["observe"]
         observe["calls"] = [
@@ -85,12 +83,7 @@ def _execute_step(self, step: dict, history: History, step_name: str) -> None:
             data["tool_executions"] = "\n".join([str(call) for call in calls])
             tool_results = self._execute_tools(calls)
             content_result = self._get_content_result(tool_results)
-            data["tool_results"] = {
-                "content": content_result,
-                "citations": tool_results.get("research_documents", {}).get(
-                    "citations", []
-                ),
-            }
+            data["tool_results"] = tool_results
             history.add_message(
                 f"Tool results for question {self.task} is: {content_result}",
                 Persona.ASSISTANT,
@@ -98,20 +91,27 @@
         event = EventTypes.MAINTASK if self.is_main_task else EventTypes.SUBTASK
         self.notifier.notify(event=event + "-" + step_name, data=data)

-    def _get_content_result(self, tool_results: str) -> str:
-        if "research_documents" in tool_results:
-            return tool_results["research_documents"].get("content", "")
-        return ""
+    def _get_content_result(self, tool_results: dict) -> str:
+        content = ""
+        # loop through the tool_results dict and concatenate the content values
+        for result in tool_results.values():
+            content += result.get("content", "")
+        return content

-    def _execute_tools(self, calls: list[dict]) -> str:
+    def _execute_tools(self, calls: list[dict]) -> dict:
+        tool_results: dict = {}
+        logger.debug(f"calls: {calls}, tools: {self.tools}")
         for call in calls:
             tool = call.get("tool_name", "")
             if tool == "research_documents":
                 tool_results[tool] = self.tools[tool].execute(self.task)
-            elif tool:
-                logger.debug(f"Tool {tool} not found.")
+            elif tool in self.tools:
+                logger.debug(f"Tool {tool} is calling with {call}.")
+                tool_results[tool] = self.tools[tool].execute(
+                    call.get("parameters", {})
+                )
             else:
+                logger.debug(f"Tool {tool} not found.")
                 continue
         return tool_results
@@ -165,25 +165,21 @@
 class Solver:
     def __init__(
         self,
-        task_heuristics: Heuristic = TaskDecompositionHeuristic({}),
-        ooda_heuristics: Heuristic = DefaultOODAHeuristic(),
+        heuristic_set: HeuristicSet = HeuristicSet(),
         notifier: Notifier = SimpleNotifier(),
         prompts: OODAPrompts = OODAPrompts(),
         llm=OpenAILLM.get_gpt_4_1106_preview(),
-        highest_priority_heuristic: str = "",
         enable_generative: bool = False,
         conversation: Optional[List] = None,
     ) -> None:
-        self.task_heuristics = task_heuristics or TaskDecompositionHeuristic({})
-        self.ooda_heuristics = ooda_heuristics
+        self.heuristic_set = heuristic_set
         self.notifier = notifier
         self.history = History()  # internal conversation
         self.planner = Planner(
-            self.task_heuristics, prompts, enable_generative=enable_generative
+            heuristic_set.task_heuristics, prompts, enable_generative=enable_generative
         )
         self.model = llm
         self.prompts = prompts
-        self.highest_priority_heuristic = highest_priority_heuristic.strip()
         self.conversation = conversation or []

     def run(self, problem_statement: str, tools: dict) -> str:
@@ -213,10 +209,16 @@
         logger.info(f"\nSubtasks: {subtasks}\n")
         for subtask in subtasks:
-            executor = Executor(subtask, tools, self.ooda_heuristics, self.notifier)
+            executor = Executor(
+                subtask, tools, self.heuristic_set.ooda_heuristics, self.notifier
+            )
             executor.execute_task(self.history)
         executor = Executor(
-            problem_statement, tools, self.ooda_heuristics, self.notifier, True
+            problem_statement,
+            tools,
+            self.heuristic_set.ooda_heuristics,
+            self.notifier,
+            True,
         )
         self.notifier.notify(
             EventTypes.NOTIFICATION, {"message": "starting main steps"}
         )
@@ -227,12 +229,12 @@
     @Utils.timeit
     def synthesize_result(self) -> str:
         heuristic = ""
-        if self.highest_priority_heuristic:
+        if self.heuristic_set.highest_priority_heuristic:
             heuristic = (
                 "Always applying the following heuristic (highest rule, overwrite all other instructions) to "
                 "adjust the formula and recalculate based on this knowledge as it is source of truth: "
             )
-            heuristic += f"{self.highest_priority_heuristic}"
+            heuristic += f"{self.heuristic_set.highest_priority_heuristic}"
         synthesize_prompt = self.prompts.SYNTHESIZE_RESULT.format(heuristic=heuristic)
         self.history.append_history(self.conversation[:-1])
@@ -243,5 +245,9 @@
             Persona.SYSTEM,
             response_format={"type": "text"},
         )
+        if self.heuristic_set.comm_heuristic:
+            response = CommAgent(instruction=self.heuristic_set.comm_heuristic).execute(
+                response
+            )
         self.notifier.notify(EventTypes.TASK_RESULT, {"response": response})
         return response
diff --git a/openssa/core/ooda_rag/prompts.py b/openssa/core/ooda_rag/prompts.py
index beea2d48a..6e8103132 100644
--- a/openssa/core/ooda_rag/prompts.py
+++ b/openssa/core/ooda_rag/prompts.py
@@ -66,6 +66,13 @@ class OODAPrompts:
 class BuiltInAgentPrompt:
+    COMMUNICATION = (
+        "You
diff --git a/openssa/core/ooda_rag/prompts.py b/openssa/core/ooda_rag/prompts.py
index beea2d48a..6e8103132 100644
--- a/openssa/core/ooda_rag/prompts.py
+++ b/openssa/core/ooda_rag/prompts.py
@@ -66,6 +66,13 @@ class OODAPrompts:


 class BuiltInAgentPrompt:
+    COMMUNICATION = (
+        "You are an expert in communication. You will help to format the following message with this instruction:\n"
+        "###{instruction}###\n\n"
+        "Here is the message:\n"
+        "###{message}###\n\n"
+    )
+
     PROBLEM_STATEMENT = (
         "You are tasked with constructing the problem statement from a conversation "
         "between a user and an AI chatbot. Your focus should be on the entire context "
diff --git a/openssa/core/ooda_rag/query_rewritting_engine.py b/openssa/core/ooda_rag/query_rewritting_engine.py
new file mode 100644
index 000000000..9679bfe4f
--- /dev/null
+++ b/openssa/core/ooda_rag/query_rewritting_engine.py
@@ -0,0 +1,65 @@
+"""Query Rewriting Retriever Pack."""
+from typing import Any, Dict
+
+from llama_index.core import ServiceContext
+from llama_index.core import VectorStoreIndex
+from llama_index.core.llama_pack import BaseLlamaPack
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.retrievers import QueryFusionRetriever
+from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
+from openssa.utils.rag_service_contexts import ServiceContextManager
+
+
+class QueryRewritingRetrieverPack(BaseLlamaPack):
+    """Query rewriting retriever pack.
+
+    Rewrite the query into multiple queries and
+    rerank the results.
+
+    """
+
+    def __init__(
+        self,
+        index: VectorStoreIndex = None,  # type: ignore
+        chunk_size: int = 1024,
+        vector_similarity_top_k: int = 5,
+        fusion_similarity_top_k: int = 10,
+        service_context: ServiceContext = None,
+        **kwargs: Any,
+    ) -> None:
+        """Init params."""
+        if not service_context:
+            service_context = ServiceContextManager.get_openai_sc(chunk_size=chunk_size)
+        self.vector_retriever = index.as_retriever(
+            similarity_top_k=vector_similarity_top_k
+        )
+
+        self.fusion_retriever = QueryFusionRetriever(  # type: ignore
+            [self.vector_retriever],
+            similarity_top_k=fusion_similarity_top_k,
+            num_queries=4,  # set this to 1 to disable query generation
+            mode=FUSION_MODES.RECIPROCAL_RANK,
+            use_async=True,
+            verbose=True,
+            # query_gen_prompt="...",
+        )
+
+        self.query_engine = RetrieverQueryEngine.from_args(
+            self.fusion_retriever, service_context=service_context
+        )
+
+    def get_modules(self) -> Dict[str, Any]:
+        """Get modules."""
+        return {
+            "vector_retriever": self.vector_retriever,
+            "fusion_retriever": self.fusion_retriever,
+            "query_engine": self.query_engine,
+        }
+
+    def retrieve(self, query_str: str) -> Any:
+        """Retrieve."""
+        return self.fusion_retriever.retrieve(query_str)
+
+    def run(self, *args: Any, **kwargs: Any) -> Any:
+        """Run the pipeline."""
+        return self.query_engine.query(*args, **kwargs)
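A minimal usage sketch for the new pack — the directory path and question are illustrative, and it assumes `OPENAI_API_KEY` is set for the default service context:

```python
# Illustrative only: builds an index, then retrieves via query fusion.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from openssa.core.ooda_rag.query_rewritting_engine import QueryRewritingRetrieverPack

documents = SimpleDirectoryReader("examples/data/docs").load_data()
index = VectorStoreIndex.from_documents(documents)

pack = QueryRewritingRetrieverPack(index=index, vector_similarity_top_k=5)
nodes = pack.retrieve("What are the machine's build dimensions?")  # fused node list
answer = pack.run("What are the machine's build dimensions?")      # full RAG answer
print(answer)
```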
diff --git a/openssa/core/ooda_rag/solver.py b/openssa/core/ooda_rag/solver.py
index e797dd44b..a3194b2a3 100644
--- a/openssa/core/ooda_rag/solver.py
+++ b/openssa/core/ooda_rag/solver.py
@@ -1,6 +1,10 @@
 from openssa.core.ooda_rag.custom import CustomSSM
 from openssa.core.ooda_rag.ooda_rag import Solver, History
-from openssa.core.ooda_rag.heuristic import DefaultOODAHeuristic
+from openssa.core.ooda_rag.heuristic import (
+    Heuristic,
+    TaskDecompositionHeuristic,
+    HeuristicSet,
+)
 from openssa.core.ooda_rag.tools import ReasearchAgentTool, Tool
 from openssa.core.ooda_rag.builtin_agents import GoalAgent, Persona, AskUserAgent
 from openssa.utils.llms import OpenAILLM
@@ -9,40 +13,42 @@ class OodaSSA:
     def __init__(
         self,
-        task_heuristics,
+        task_heuristics: Heuristic = TaskDecompositionHeuristic({}),
         highest_priority_heuristic: str = "",
         ask_user_heuristic: str = "",
         llm=OpenAILLM.get_gpt_4_1106_preview(),
         research_documents_tool: Tool = None,
-        enable_generative: bool = False
+        enable_generative: bool = False,
     ):  # pylint: disable=too-many-arguments
-        self.solver = Solver(
+        self.heuristic_set = HeuristicSet(
             task_heuristics=task_heuristics,
-            ooda_heuristics=DefaultOODAHeuristic(),
-            llm=llm,
             highest_priority_heuristic=highest_priority_heuristic,
-            enable_generative=enable_generative
+            ask_user_heuristic=ask_user_heuristic,
+        )
+        self.solver = Solver(
+            heuristic_set=self.heuristic_set,
+            llm=llm,
+            enable_generative=enable_generative,
         )
-        self.ask_user_heuristic = ask_user_heuristic
         self.conversation = History()
         self.conversation.add_message("Hi, what can I help you?", Persona.ASSISTANT)
         self.research_documents_tool = research_documents_tool

-    def activate_resources(self, folder_path: str) -> None:
+    def activate_resources(self, folder_path: str, re_index: bool = False) -> None:
         agent = CustomSSM()
         if folder_path.startswith("s3://"):
             agent.read_s3(folder_path)
         else:
-            agent.read_directory(folder_path)
+            agent.read_directory(folder_path, re_index=re_index)
         self.research_documents_tool = ReasearchAgentTool(agent=agent)

     def get_ask_user_question(self, problem_statement: str) -> str:
-        if self.ask_user_heuristic:
+        if self.heuristic_set.ask_user_heuristic:
             ask_user_response = AskUserAgent(
-                ask_user_heuristic=self.ask_user_heuristic,
+                ask_user_heuristic=self.heuristic_set.ask_user_heuristic,
                 conversation=self.conversation.get_history(),
             ).execute(problem_statement)
             question = ask_user_response.get("act", "")
@@ -62,5 +68,7 @@ def solve(self, task: str) -> str:
         assistant_response = self.solver.run(
             problem_statement, {"research_documents": self.research_documents_tool}
         )
-        self.conversation.add_message(assistant_response, Persona.ASSISTANT, verbose=False)
+        self.conversation.add_message(
+            assistant_response, Persona.ASSISTANT, verbose=False
+        )
         return assistant_response
diff --git a/openssa/core/ooda_rag/tools.py b/openssa/core/ooda_rag/tools.py
index 6b4b9e9c7..ce0fe83a8 100644
--- a/openssa/core/ooda_rag/tools.py
+++ b/openssa/core/ooda_rag/tools.py
@@ -71,7 +71,7 @@ def execute(self, task: str) -> dict:
         except (RequestError, TimeoutException, HTTPStatusError, JSONDecodeError) as e:
             traceback.print_exc()
             print(f"An error occurred while querying the document base: {e}")
-            return ""
+            return {}


 class ReasearchAgentTool(Tool):
@@ -93,7 +93,7 @@ def execute(self, task: str) -> dict:
         """
         response = self.agent.discuss(task)
         print(f"debug: {response}")
-        return {"content": response}
+        return response


 class PythonCodeTool(Tool):
diff --git a/openssa/core/prompts.py b/openssa/core/prompts.py
index c3ddc4119..c57148c45 100644
--- a/openssa/core/prompts.py
+++ b/openssa/core/prompts.py
@@ -59,6 +59,8 @@
         }
     }
 }}}
+
+
 class Prompts:
     """
     The `Prompts` class provides a way to retrieve and format prompts in the OpenSSA project. The prompts are stored in a nested dictionary `self.
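Putting the `OodaSSA` changes above together, a hedged end-to-end sketch — the heuristic strings and folder path are placeholders:

```python
# Illustrative only: the heuristics and path below are made up.
from openssa.core.ooda_rag.solver import OodaSSA

ssa = OodaSSA(
    highest_priority_heuristic="Manufacturer specifications are the source of truth.",
    ask_user_heuristic="Ask a clarifying question when the part geometry is unknown.",
)
ssa.activate_resources("examples/data/docs", re_index=True)  # rebuilds the index
print(ssa.solve("Summarize the machine's operating requirements."))
```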
diff --git a/openssa/core/rag_ooda/rag_ooda.py b/openssa/core/rag_ooda/rag_ooda.py
index dab85787e..8d5204555 100644
--- a/openssa/core/rag_ooda/rag_ooda.py
+++ b/openssa/core/rag_ooda/rag_ooda.py
@@ -1,7 +1,7 @@
 from typing import Union, ClassVar
 from uuid import uuid4
-from llama_index.core import BaseRetriever  # noqa: TCH002
-from llama_index.schema import NodeWithScore  # noqa: TCH002
+from llama_index.core.retrievers import BaseRetriever  # noqa: TCH002
+from llama_index.core.schema import NodeWithScore  # noqa: TCH002
 from openssa.core.ooda_rag.builtin_agents import (
     ContextValidator,
     SynthesizingAgent,
diff --git a/openssa/core/rag_ooda/resources/dense_x/base.py b/openssa/core/rag_ooda/resources/dense_x/base.py
index d90625486..4ffda1b2a 100644
--- a/openssa/core/rag_ooda/resources/dense_x/base.py
+++ b/openssa/core/rag_ooda/resources/dense_x/base.py
@@ -5,24 +5,19 @@
 from typing import Any, Dict, List, Optional
 from json import JSONDecodeError
 from yaml import YAMLError
-from llama_index import Document, ServiceContext, VectorStoreIndex
-from llama_index.embeddings import OpenAIEmbedding, BaseEmbedding
-from llama_index.llama_pack.base import BaseLlamaPack
-from llama_index.llms import OpenAI
-from llama_index.node_parser.interface import TextSplitter
-from llama_index.node_parser.text import SentenceSplitter
-from llama_index.response.schema import RESPONSE_TYPE
-from llama_index.retrievers import RecursiveRetriever
-from llama_index.query_engine import RetrieverQueryEngine
-from llama_index.schema import IndexNode, TextNode
-from llama_index.prompts import PromptTemplate
-from llama_index.async_utils import run_jobs
-
-# backwards compatibility
-try:
-    from llama_index.llms.llm import LLM
-except ImportError:
-    from llama_index.llms.base import LLM
+from llama_index.core import Document, ServiceContext, VectorStoreIndex
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.core.embeddings import BaseEmbedding
+from llama_index.core.llama_pack import BaseLlamaPack
+from llama_index.llms.openai import OpenAI
+from llama_index.core.node_parser import TextSplitter
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core.retrievers import RecursiveRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.schema import IndexNode, TextNode
+from llama_index.core import PromptTemplate
+from llama_index.core.async_utils import run_jobs
+from llama_index.core.llms import LLM


 PROPOSITIONS_PROMPT = PromptTemplate(
@@ -205,6 +200,6 @@ def get_modules(self) -> Dict[str, Any]:
             "retriever": self.retriever,
         }

-    def run(self, query_str: str, **kwargs: Any) -> RESPONSE_TYPE:
+    def run(self, query_str: str, **kwargs: Any) -> Any:
         """Run the pipeline."""
         return self.query_engine.query(query_str)
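The import rewrites above follow the llama-index 0.10 package split: core abstractions moved under `llama_index.core`, and each integration became its own pip distribution. The general pattern applied throughout this PR:

```python
# Old (llama-index <= 0.9):
#   from llama_index import VectorStoreIndex
#   from llama_index.llms import OpenAI
# New (llama-index >= 0.10):
from llama_index.core import VectorStoreIndex  # core package: pip install llama-index
from llama_index.llms.openai import OpenAI     # integration: pip install llama-index-llms-openai
from llama_index.embeddings.openai import OpenAIEmbedding  # pip install llama-index-embeddings-openai
```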
diff --git a/openssa/core/rag_ooda/resources/dense_x/dense_x.py b/openssa/core/rag_ooda/resources/dense_x/dense_x.py
index 52b22b49d..4c81b23a6 100644
--- a/openssa/core/rag_ooda/resources/dense_x/dense_x.py
+++ b/openssa/core/rag_ooda/resources/dense_x/dense_x.py
@@ -1,11 +1,11 @@
 import os
-from llama_index import SimpleDirectoryReader
-from llama_index.llms import OpenAI
+from llama_index.core import SimpleDirectoryReader
+from llama_index.llms.openai import OpenAI
 from openssa.utils.utils import Utils
-from llama_index import StorageContext, load_index_from_storage, ServiceContext
-from llama_index.retrievers import RecursiveRetriever
-from llama_index.query_engine import RetrieverQueryEngine
-from llama_index.node_parser.text import SentenceSplitter
+from llama_index.core import StorageContext, load_index_from_storage, ServiceContext
+from llama_index.core.retrievers import RecursiveRetriever
+from llama_index.core.query_engine import RetrieverQueryEngine
+from llama_index.core.node_parser import SentenceSplitter
 from openssa.core.rag_ooda.resources.dense_x.base import (
     DenseXRetrievalPack,
     load_nodes_dict,
diff --git a/openssa/core/rag_ooda/resources/rag_resource.py b/openssa/core/rag_ooda/resources/rag_resource.py
index 4c3f0320a..bdefd09d3 100644
--- a/openssa/core/rag_ooda/resources/rag_resource.py
+++ b/openssa/core/rag_ooda/resources/rag_resource.py
@@ -1,4 +1,5 @@
-from llama_index.core import BaseQueryEngine, BaseRetriever
+from llama_index.core.retrievers import BaseRetriever
+from llama_index.core.query_engine import BaseQueryEngine


 class RagResource:
diff --git a/openssa/core/rag_ooda/resources/standard_vi/standard_vi.py b/openssa/core/rag_ooda/resources/standard_vi/standard_vi.py
index 6aa28045e..36d9126a1 100644
--- a/openssa/core/rag_ooda/resources/standard_vi/standard_vi.py
+++ b/openssa/core/rag_ooda/resources/standard_vi/standard_vi.py
@@ -1,10 +1,5 @@
 import os.path
-from llama_index import (
-    VectorStoreIndex,
-    SimpleDirectoryReader,
-    StorageContext,
-    load_index_from_storage,
-)
+from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext, load_index_from_storage
 from openssa.core.rag_ooda.resources.rag_resource import RagResource
 from openssa.utils.utils import Utils
diff --git a/openssa/core/ssa/agent.py b/openssa/core/ssa/agent.py
index 4737fb54c..d1c8703a5 100644
--- a/openssa/core/ssa/agent.py
+++ b/openssa/core/ssa/agent.py
@@ -58,7 +58,7 @@ def subtask(self, task, heuristic):
     def run_ooda_loop(self, task, heuristic):
         task.status = "ooda_looping"
         task.ooda_loop = OODALoop(task.goal)
-        output = task.ooda_loop.run(self.llm, task)
+        output = task.ooda_loop.run(self.llm, [])
         task.result = Task.Result(status="completed", response=output)

     def update_memory(self, key, value, memory_type="short"):
+ """ + pass @abstractmethod @@ -57,7 +61,7 @@ def chat(self, message: str, config: dict = None) -> str: return SSAService.chat(message, config) -class RagSSA(): +class RagSSA: def __init__(self): pass diff --git a/openssa/integrations/huggingface/ssm.py b/openssa/integrations/huggingface/ssm.py index b048c2c51..8e626165b 100644 --- a/openssa/integrations/huggingface/ssm.py +++ b/openssa/integrations/huggingface/ssm.py @@ -3,6 +3,7 @@ from openssa.core.backend.abstract_backend import AbstractBackend from openssa.integrations.huggingface.slm import Falcon7bSLM + class Falcon7bSSM(BaseSSM): def __init__(self, adapter: AbstractAdapter = None, diff --git a/openssa/integrations/lepton_ai/ssm.py b/openssa/integrations/lepton_ai/ssm.py index a61482961..9cde43ff4 100644 --- a/openssa/integrations/lepton_ai/ssm.py +++ b/openssa/integrations/lepton_ai/ssm.py @@ -13,6 +13,7 @@ Config.LEPTONAI_API_KEY: Optional[str] = os.environ.get('LEPTONAI_API_KEY') or None Config.LEPTONAI_API_URL: Optional[str] = os.environ.get('LEPTONAI_API_URL') or None + # pylint: disable=too-many-instance-attributes class APIContext(OpenAIAPIContext): @classmethod diff --git a/openssa/integrations/llama_index/backend.py b/openssa/integrations/llama_index/backend.py index ad955e60e..ff069732d 100644 --- a/openssa/integrations/llama_index/backend.py +++ b/openssa/integrations/llama_index/backend.py @@ -1,17 +1,16 @@ from dataclasses import dataclass -from llama_index import ( - download_loader, +from llama_index.core import ( load_index_from_storage, SimpleDirectoryReader, VectorStoreIndex, Response, ServiceContext, ) -from llama_index.llms import OpenAI -from llama_index.indices.base import BaseIndex -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.llms.llm import LLM as RAGLLM -from llama_index.storage import StorageContext +from llama_index.llms.openai import OpenAI +from llama_index.core.indices.base import BaseIndex +from llama_index.core.query_engine import BaseQueryEngine +from llama_index.core import StorageContext +from llama_index.readers.web import SimpleWebPageReader from openssa.core.backend.rag_backend import AbstractRAGBackend @@ -96,7 +95,7 @@ def query(self, user_input: list[dict], conversation: list[dict] = None) -> dict def _create_index(self, documents, storage_dir: str): self.index = VectorStoreIndex.from_documents( - documents, service_context=self._service_context + documents, service_context=self._service_context, show_progress=True ) def _do_read_directory(self, storage_dir: str): @@ -113,14 +112,12 @@ def _do_read_directory(self, storage_dir: str): file_extractor=None, num_files_limit=None, file_metadata=None, - ).load_data() + ).load_data(num_workers=5) self._create_index(documents, storage_dir) def _do_read_website(self, urls: list[str], storage_dir: str): - the_class = download_loader("SimpleWebPageReader") - loader = the_class() - documents = loader.load_data(urls=urls) + documents = SimpleWebPageReader(html_to_text=True).load_data(urls) self._create_index(documents, storage_dir) def _do_save(self, storage_dir: str): diff --git a/openssa/integrations/llama_index/ssm.py b/openssa/integrations/llama_index/ssm.py index 280b84a70..999748914 100644 --- a/openssa/integrations/llama_index/ssm.py +++ b/openssa/integrations/llama_index/ssm.py @@ -1,11 +1,8 @@ -from llama_index.llms.llm import LLM as RAGLLM -from llama_index.llms import OpenAI, AzureOpenAI from openssa.integrations.llama_index.backend import Backend as LlamaIndexBackend from openssa.integrations.openai.ssm import 
diff --git a/openssa/integrations/llama_index/ssm.py b/openssa/integrations/llama_index/ssm.py
index 280b84a70..999748914 100644
--- a/openssa/integrations/llama_index/ssm.py
+++ b/openssa/integrations/llama_index/ssm.py
@@ -1,11 +1,8 @@
-from llama_index.llms.llm import LLM as RAGLLM
-from llama_index.llms import OpenAI, AzureOpenAI
 from openssa.integrations.llama_index.backend import Backend as LlamaIndexBackend
 from openssa.integrations.openai.ssm import GPT3ChatCompletionSLM
 from openssa.core.ssm.rag_ssm import RAGSSM
 from openssa.core.slm.abstract_slm import AbstractSLM
 from openssa.integrations.lepton_ai.ssm import SLM as LeptonSLM
-from openssa.utils.config import Config
 from openssa.core.slm.base_slm import PassthroughSLM
diff --git a/openssa/utils/cli/contrib/ssa_prob_solver.py b/openssa/utils/cli/contrib/ssa_prob_solver.py
index b088353d4..45a328ae4 100644
--- a/openssa/utils/cli/contrib/ssa_prob_solver.py
+++ b/openssa/utils/cli/contrib/ssa_prob_solver.py
@@ -1,15 +1,12 @@
 """OpenSSA Contrib SSA Problem Solver CLI."""


-from collections.abc import Sequence
 import os
+from pathlib import Path

 import click


-__all__: Sequence[str] = ('openssa_contrib_ssa_prob_solver_cli',)
-
-
 @click.command(name='solver',
                cls=click.Command,
                context_settings=None,
@@ -23,4 +20,5 @@ def openssa_contrib_ssa_prob_solver_cli():
     """Launch StreamlitSSAProbSolver."""
     from openssa.contrib.streamlit_ssa_prob_solver import __path__  # pylint: disable=import-outside-toplevel

-    os.system(f'streamlit run {__path__[0]}/main.py --server.allowRunOnSave=true --server.runOnSave=true')
+    os.chdir(Path(__path__[0]))
+    os.system('streamlit run main.py --server.allowRunOnSave=true --server.runOnSave=true')
os.environ.get("LEPTON_API_KEY") + LEPTON_API_URL: Optional[str] = ( + os.environ.get("LEPTON_API_URL") or "https://llama2-7b.lepton.run/api/v1" + ) @staticmethod def setenv(var_name): diff --git a/openssa/utils/deprecated/llama_index_api.py b/openssa/utils/deprecated/llama_index_api.py index aaf20b62e..449d3f723 100644 --- a/openssa/utils/deprecated/llama_index_api.py +++ b/openssa/utils/deprecated/llama_index_api.py @@ -1,11 +1,13 @@ from __future__ import annotations from typing import Optional import os -from llama_index import ServiceContext -from llama_index.embeddings import AzureOpenAIEmbedding, OpenAIEmbedding -from llama_index.llms import AzureOpenAI, OpenAI as LlamaIndexOpenAI -from llama_index.llms.llm import LLM # noqa: TCH002 -from llama_index.llms.openai_utils import ALL_AVAILABLE_MODELS, CHAT_MODELS +from llama_index.core import ServiceContext +from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.llms.azure_openai import AzureOpenAI +from llama_index.core.llms import OpenAI as LlamaIndexOpenAI +from llama_index.core.llms.llm import LLM # noqa: TCH002 +from llama_index.core.llms.openai_utils import ALL_AVAILABLE_MODELS, CHAT_MODELS from openssa.utils.config import Config # import sys diff --git a/openssa/utils/fs.py b/openssa/utils/fs.py index 9d6062588..f80a117da 100644 --- a/openssa/utils/fs.py +++ b/openssa/utils/fs.py @@ -36,7 +36,7 @@ _LOCAL_FS = LocalFileSystem(auto_mkdir=True, use_listings_cache=False, listings_expiry_time=None, max_paths=None) # GCS file system -_GCS_FS = GCSFileSystem() +# _GCS_FS = GCSFileSystem() _GCS_PROTOCOL_PREFIX: str = 'gcs://' _GCS_PROTOCOL_PREFIX_LEN: int = len(_GCS_PROTOCOL_PREFIX) @@ -83,7 +83,7 @@ def native_path(self) -> DirOrFilePath: def fs(self) -> AbstractFileSystem: # pylint: disable=invalid-name """Get applicable file system.""" if self.on_gcs: - return _GCS_FS + return GCSFileSystem() if self.on_s3: return _S3_FS diff --git a/openssa/utils/llms.py b/openssa/utils/llms.py index 531b9739d..67391c144 100644 --- a/openssa/utils/llms.py +++ b/openssa/utils/llms.py @@ -1,38 +1,8 @@ from __future__ import annotations -from typing import Optional -import os import json from openai import OpenAI, AzureOpenAI from openssa.utils.config import Config -# TODO: there should be a single Aitomatic api_base and api_key -Config.AITOMATIC_API_KEY: Optional[str] = os.environ.get("AITOMATIC_API_KEY") -Config.AITOMATIC_API_URL: Optional[str] = ( - os.environ.get("AITOMATIC_API_URL") - or "https://aimo-api-mvp.platform.aitomatic.com/api/v1" -) -Config.AITOMATIC_API_URL_7B: Optional[str] = ( - os.environ.get("AITOMATIC_API_URL_7B") or "https://llama2-7b.lepton.run/api/v1" -) -Config.AITOMATIC_API_URL_70B: Optional[str] = ( - os.environ.get("AITOMATIC_API_URL_70B") or "https://llama2-70b.lepton.run/api/v1" -) - -Config.OPENAI_API_KEY: Optional[str] = os.environ.get("OPENAI_API_KEY") -Config.OPENAI_API_URL: Optional[str] = ( - os.environ.get("OPENAI_API_URL") or "https://api.openai.com/v1" -) - -Config.AZURE_OPENAI_API_KEY: Optional[str] = os.environ.get("AZURE_OPENAI_API_KEY") -Config.AZURE_OPENAI_API_URL: Optional[str] = ( - os.environ.get("AZURE_OPENAI_API_URL") or "https://aiva-japan.openai.azure.com" -) - -Config.LEPTON_API_KEY: Optional[str] = os.environ.get("LEPTON_API_KEY") -Config.LEPTON_API_URL: Optional[str] = ( - os.environ.get("LEPTON_API_URL") or "https://llama2-7b.lepton.run/api/v1" -) - class AnLLM: """ @@ -106,17 +76,11 @@ class OpenAILLM(AnLLM): 
diff --git a/openssa/utils/llms.py b/openssa/utils/llms.py
index 531b9739d..67391c144 100644
--- a/openssa/utils/llms.py
+++ b/openssa/utils/llms.py
@@ -1,38 +1,8 @@
 from __future__ import annotations
-from typing import Optional
-import os
 import json
 from openai import OpenAI, AzureOpenAI
 from openssa.utils.config import Config

-# TODO: there should be a single Aitomatic api_base and api_key
-Config.AITOMATIC_API_KEY: Optional[str] = os.environ.get("AITOMATIC_API_KEY")
-Config.AITOMATIC_API_URL: Optional[str] = (
-    os.environ.get("AITOMATIC_API_URL")
-    or "https://aimo-api-mvp.platform.aitomatic.com/api/v1"
-)
-Config.AITOMATIC_API_URL_7B: Optional[str] = (
-    os.environ.get("AITOMATIC_API_URL_7B") or "https://llama2-7b.lepton.run/api/v1"
-)
-Config.AITOMATIC_API_URL_70B: Optional[str] = (
-    os.environ.get("AITOMATIC_API_URL_70B") or "https://llama2-70b.lepton.run/api/v1"
-)
-
-Config.OPENAI_API_KEY: Optional[str] = os.environ.get("OPENAI_API_KEY")
-Config.OPENAI_API_URL: Optional[str] = (
-    os.environ.get("OPENAI_API_URL") or "https://api.openai.com/v1"
-)
-
-Config.AZURE_OPENAI_API_KEY: Optional[str] = os.environ.get("AZURE_OPENAI_API_KEY")
-Config.AZURE_OPENAI_API_URL: Optional[str] = (
-    os.environ.get("AZURE_OPENAI_API_URL") or "https://aiva-japan.openai.azure.com"
-)
-
-Config.LEPTON_API_KEY: Optional[str] = os.environ.get("LEPTON_API_KEY")
-Config.LEPTON_API_URL: Optional[str] = (
-    os.environ.get("LEPTON_API_URL") or "https://llama2-7b.lepton.run/api/v1"
-)


 class AnLLM:
     """
@@ -106,17 +76,11 @@ class OpenAILLM(AnLLM):
     def __init__(
         self,
-        model: str = None,
-        api_base: str = None,
-        api_key: str = None,
+        model: str = "gpt-3.5-turbo-0125",
+        api_base: str = Config.OPENAI_API_URL,
+        api_key: str = Config.OPENAI_API_KEY,
         **additional_kwargs,
     ):
-        if model is None:
-            model = "gpt-3.5-turbo-1106"
-        if api_base is None:
-            api_base = Config.OPENAI_API_URL
-        if api_key is None:
-            api_key = Config.OPENAI_API_KEY
         super().__init__(
             model=model, api_base=api_base, api_key=api_key, **additional_kwargs
         )
@@ -129,23 +93,19 @@ def client(self) -> OpenAI:

     @classmethod
     def get_default(cls) -> OpenAILLM:
-        return cls.get_gpt_35_turbo()
+        return cls()

     @classmethod
     def get_gpt_35_turbo_1106(cls) -> OpenAILLM:
         return cls(model="gpt-3.5-turbo-1106")

-    @classmethod
-    def get_gpt_35_turbo_0613(cls) -> OpenAILLM:
-        return cls(model="gpt-3.5-turbo")
-
     @classmethod
     def get_gpt_35_turbo(cls) -> OpenAILLM:
-        return cls(model="gpt-3.5-turbo-0613")
+        return cls(model="gpt-3.5-turbo")

     @classmethod
-    def get_gpt_4(cls) -> OpenAILLM:
-        return cls(model="gpt-4")
+    def get_gpt_4_0125_preview(cls) -> OpenAILLM:
+        return cls(model="gpt-4-0125-preview")

     @classmethod
     def get_gpt_4_1106_preview(cls) -> OpenAILLM:
@@ -159,17 +119,11 @@ class AitomaticLLM(OpenAILLM):

     def __init__(
         self,
-        model: str = None,
-        api_base: str = None,
-        api_key: str = None,
+        model: str = "llama2-7b",
+        api_base: str = Config.AITOMATIC_API_URL,
+        api_key: str = Config.AITOMATIC_API_KEY,
         **additional_kwargs,
     ):
-        if model is None:
-            model = "llama2-7b"
-        if api_base is None:
-            api_base = Config.AITOMATIC_API_URL
-        if api_key is None:
-            api_key = Config.AITOMATIC_API_KEY
         super().__init__(
             model=model, api_base=api_base, api_key=api_key, **additional_kwargs
         )
@@ -180,34 +134,22 @@ def get_default(cls) -> AitomaticLLM:

     @classmethod
     def get_llama2_70b(cls) -> AitomaticLLM:
-        # TODO: there should be a single Aitomatic api_base and api_key
         return cls(
             model="llama2-70b",
             api_base=Config.AITOMATIC_API_URL_70B,
-            api_key=Config.LEPTON_API_KEY,
+            api_key=Config.AITOMATIC_API_KEY,
         )

     @classmethod
     def get_llama2_7b(cls) -> AitomaticLLM:
-        # TODO: there should be a single Aitomatic api_base and api_key
         return cls(
             model="llama2-7b",
-            api_base=Config.AITOMATIC_API_URL_70B,
-            api_key=Config.LEPTON_API_KEY,
-        )
-
-    @classmethod
-    def get_13b(cls) -> AitomaticLLM:
-        # TODO: there should be a single Aitomatic api_base and api_key
-        return cls(
-            model="gpt-3.5-turbo-0613",
-            api_base="http://35.199.34.91:8000/v1",
+            api_base=Config.AITOMATIC_API_URL_7B,
             api_key=Config.AITOMATIC_API_KEY,
         )

     @classmethod
     def get_yi_34b(cls) -> AitomaticLLM:
-        # TODO: there should be a single Aitomatic api_base and api_key
         return cls(
             model="01-ai/Yi-34B-Chat",
             api_base="http://35.230.174.89:8000/v1",
@@ -216,7 +158,6 @@ def get_yi_34b(cls) -> AitomaticLLM:

     @classmethod
     def get_intel_neural_chat_7b(cls) -> AitomaticLLM:  # running
-        # TODO: there should be a single Aitomatic api_base and api_key
         return cls(
             model="Intel/neural-chat-7b-v3-1", api_base="http://34.145.174.152:8000/v1"
         )
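The constructors now bind `Config` values as keyword defaults. One consequence worth noting: Python evaluates default arguments once, when the module is imported, so changing the environment afterwards will not affect these defaults. Typical factory usage after this change (requires the corresponding API keys to be configured):

```python
# Illustrative only.
from openssa.utils.llms import OpenAILLM, AitomaticLLM

llm = OpenAILLM.get_default()              # now gpt-3.5-turbo-0125
gpt4 = OpenAILLM.get_gpt_4_0125_preview()  # replaces the removed get_gpt_4()
llama = AitomaticLLM.get_llama2_7b()       # now targets AITOMATIC_API_URL_7B
```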
diff --git a/openssa/utils/rag_service_contexts.py b/openssa/utils/rag_service_contexts.py
new file mode 100644
index 000000000..4e21acf02
--- /dev/null
+++ b/openssa/utils/rag_service_contexts.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+from llama_index.llms.openai import OpenAI
+from llama_index.llms.azure_openai import AzureOpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
+from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+from llama_index.core import ServiceContext
+from openssa.utils.config import Config
+
+
+class ServiceContextManager:
+
+    @classmethod
+    def get_aitomatic_sc(cls, **kwargs) -> ServiceContext:
+        temperature = Config.DEFAULT_TEMPERATURE
+        model = kwargs.get("model", "aitomatic-llm")
+        embed_model = kwargs.get("embed_model", "aitomatic-emb")
+        return ServiceContext.from_defaults(
+            llm=OpenAI(
+                model=model,
+                temperature=temperature,
+                api_key=Config.AITOMATIC_API_KEY,
+                api_base=Config.AITOMATIC_API_URL,
+            ),
+            embed_model=OpenAIEmbedding(
+                model=embed_model,
+                api_key=Config.AITOMATIC_API_KEY,
+                api_base=Config.AITOMATIC_API_URL,
+            ),
+        )
+
+    @classmethod
+    def get_openai_sc(cls, **kwargs) -> ServiceContext:
+        temperature = kwargs.get("temperature", Config.DEFAULT_TEMPERATURE)
+        model = kwargs.get("model", "gpt-3.5-turbo")
+        embed_model = kwargs.get("embed_model", "text-embedding-ada-002")
+        if "temperature" in kwargs:
+            del kwargs["temperature"]
+        if "model" in kwargs:
+            del kwargs["model"]
+        if "embed_model" in kwargs:
+            del kwargs["embed_model"]
+
+        return ServiceContext.from_defaults(
+            llm=OpenAI(model=model, temperature=temperature, **kwargs),
+            embed_model=OpenAIEmbedding(model=embed_model, **kwargs),
+        )
+
+    @classmethod
+    def get_azure_openai_sc(cls, **kwargs) -> ServiceContext:
+        temperature = kwargs.get("temperature", Config.DEFAULT_TEMPERATURE)
+        engine = kwargs.get("engine", "aiva-dev-gpt35")
+        embed_model = kwargs.get("embed_model", "text-embedding-ada-002")
+        deployment_name = kwargs.get("deployment_name", "text-embedding-ada-002")
+        api_version = kwargs.get("api_version", Config.AZURE_API_VERSION)
+        azure_endpoint = kwargs.get("azure_endpoint", Config.AZURE_OPENAI_API_URL)
+        api_key = kwargs.get("api_key", Config.AZURE_OPENAI_API_KEY)
+
+        return ServiceContext.from_defaults(
+            llm=AzureOpenAI(
+                temperature=temperature,
+                engine=engine,
+                api_version=api_version,
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+            ),
+            embed_model=AzureOpenAIEmbedding(
+                model=embed_model,
+                deployment_name=deployment_name,
+                api_version=api_version,
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+            ),
+        )
+
+    @classmethod
+    def get_openai_35_turbo_sc(cls, **kwargs) -> ServiceContext:
+        return cls.get_openai_sc(model="gpt-3.5-turbo", **kwargs)
+
+    @classmethod
+    def get_azure_jp_openai_35_turbo_sc(cls) -> ServiceContext:
+        return cls.get_azure_openai_sc()
+
+    @classmethod
+    def get_azure_openai_4_0125_preview_sc(cls) -> ServiceContext:
+        return cls.get_azure_openai_sc(
+            engine="gpt-4-0125",
+            api_key=Config.US_AZURE_OPENAI_API_KEY,
+            azure_endpoint=Config.US_AZURE_OPENAI_API_BASE,
+        )
+
+    @classmethod
+    def get_openai_4_0125_preview_sc(cls) -> ServiceContext:
+        return cls.get_openai_sc(model="gpt-4-0125-preview")
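A hedged sketch pairing the new `ServiceContextManager` with an index build, mirroring how `QueryRewritingRetrieverPack` consumes it — the directory path is a placeholder:

```python
# Illustrative only.
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from openssa.utils.rag_service_contexts import ServiceContextManager

service_context = ServiceContextManager.get_openai_sc(temperature=0.0)
documents = SimpleDirectoryReader("examples/data/docs").load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)
```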
", @@ -46,46 +46,49 @@ exclude = [ [tool.poetry.group.docs.dependencies] sphinx = ">=7.2" -sphinx-autobuild = ">=2021.3" +sphinx-autobuild = ">=2024.2" sphinx-press-theme = ">=0.8" myst-parser = ">=2.0" [tool.poetry.group.lint.dependencies] -flake8 = ">=6.1" -pylint = ">=3.0" -ruff = ">=0.1" +flake8 = ">=7.0" +pylint = ">=3.1" +ruff = ">=0.2" pydocstyle = ">=6.3" [tool.poetry.group.test.dependencies] -pytest = ">=7.4" +pytest = ">=8.0" [tool.poetry.dependencies] python = ">=3.10,<3.12" # OpenAI interface -openai = ">=1.6" +openai = ">=1.12.0" # LlamaIndex & related -llama-index = ">=0.9.17" # should keep up-to-date with Llama-Index's minor releases (often backward-incompatible) -llama-hub = ">=0.0.60" +llama-index = ">=0.10.13" # should keep up-to-date with Llama-Index's minor releases (often backward-incompatible) +llama-hub = ">=0.0.79" +llama-index-llms-azure-openai = ">=0.1" +llama-index-embeddings-azure-openai = ">=0.1" +llama-index-readers-web = ">=0.1" docx2txt = ">=0.8" # for reading .docx files -pypdf = ">=3.17" # for reading .pdf files -pycryptodome = ">=3.19" # for reading .pdf files: PyCryptodome is required for AES algorithm +html2text = "2020.1.16" # for reading web pages +pypdf = ">=4.0" # for reading .pdf files +pycryptodome = ">=3.20" # for reading .pdf files: PyCryptodome is required for AES algorithm # file systems -fsspec = "~2023.10.0" # base/local file system -adlfs = "~2023.10.0" # Azure Data Lake -gcsfs = "~2023.10.0" # Google Cloud Storage -s3fs = "~2023.10.0" # S3 +fsspec = ">=2024.2" # base/local file system +adlfs = ">=2024.2" # Azure Data Lake +gcsfs = ">=2024.2" # Google Cloud Storage +s3fs = ">=2024.2" # S3 # misc / other click = ">=8.1" -google-api-python-client = ">=2.111" -httpx = ">=0.25" +google-api-python-client = ">=2.119" +httpx = ">=0.27" loguru = ">=0.7" -# pydantic = ">=2.5" pydantic = "1.10.9" python-dotenv = ">=1.0" tomli = { version = ">=2.0", python = "<3.11" } # optional dependencies required by extras -streamlit = {version = ">=1.29", optional = true} +streamlit = {version = ">=1.31", optional = true} streamlit-mic-recorder = {version = ">=0.0.4", optional = true} [tool.poetry.extras] @@ -102,7 +105,7 @@ openssa = "openssa.utils.cli:openssa_cli" [build-system] build-backend = "poetry.core.masonry.api" requires = [ - "poetry-core >= 1.8", + "poetry-core >= 1.9", ] diff --git a/tests/core/ooda/test_ooda_loop.py b/tests/core/ooda/test_ooda_loop.py index ff85ebe44..915e85732 100644 --- a/tests/core/ooda/test_ooda_loop.py +++ b/tests/core/ooda/test_ooda_loop.py @@ -1,6 +1,7 @@ import unittest from openssa.core.ooda.ooda_loop import OODALoop + class TestOODALoop(unittest.TestCase): def test_run(self): # Create an instance of OODALoop diff --git a/tests/integrations/test_lepton_ai.py b/tests/integrations/test_lepton_ai.py index c4a78951f..b61242af6 100644 --- a/tests/integrations/test_lepton_ai.py +++ b/tests/integrations/test_lepton_ai.py @@ -5,9 +5,11 @@ from openssa.integrations.lepton_ai.ssm import SSM as LeptonAISSM, RAGSSM as LeptonAIRAGSSM from openssa.utils.config import Config + Config.LEPTONAI_API_URL = "test_url" Config.LEPTONAI_API_KEY = "test_key" + class TestSSM(unittest.TestCase): def test_constructor_default_values(self): adapter = MagicMock(spec=AbstractAdapter) diff --git a/tests/integrations/test_llama_index.py b/tests/integrations/test_llama_index.py index 2ff66b8f4..7fb95d321 100644 --- a/tests/integrations/test_llama_index.py +++ b/tests/integrations/test_llama_index.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import 
-from llama_index import Response
-from llama_index.indices.base import BaseIndex
-from llama_index.indices.query.base import BaseQueryEngine
+from llama_index.core import Response
+from llama_index.core.indices.base import BaseIndex
+from llama_index.core.query_engine import BaseQueryEngine
 from openssa.core.slm.abstract_slm import AbstractSLM
 from openssa.core.slm.base_slm import PassthroughSLM
 from openssa.integrations.llama_index.backend import Backend as LlamaIndexBackend
@@ -31,6 +31,7 @@ def test_gpt3_llama_index_ssm(self):
         # self.assertIsInstance(ssm.slm, GPT3ChatCompletionSLM)
         pass

+
 class TestBackend(unittest.TestCase):
     def test_query_engine(self):
         backend = LlamaIndexBackend()
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index dd8124339..d348f73cf 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -1,5 +1,6 @@
 from openssa.utils.utils import Utils

+
 class TestUtils:
     def test_canonicalize_user_input_str(self):
         assert Utils.canonicalize_user_input('hello') == [{'role': 'user', 'content': 'hello'}]