From 7896ac20da83bde6e624bcbec69421458961544f Mon Sep 17 00:00:00 2001 From: berkecanrizai <63911408+berkecanrizai@users.noreply.github.com> Date: Fri, 15 Mar 2024 18:17:10 +0300 Subject: [PATCH] refactor demo rag app, add config GitOrigin-RevId: 79fd4efdd8641ea1762882e29b232a1df4d3af06 --- .../pipelines/demo-question-answering/app.py | 398 ++++++++++-------- .../demo-question-answering/config.yaml | 39 ++ 2 files changed, 253 insertions(+), 184 deletions(-) create mode 100644 examples/pipelines/demo-question-answering/config.yaml diff --git a/examples/pipelines/demo-question-answering/app.py b/examples/pipelines/demo-question-answering/app.py index 50432fc..2c19578 100644 --- a/examples/pipelines/demo-question-answering/app.py +++ b/examples/pipelines/demo-question-answering/app.py @@ -1,100 +1,27 @@ import json +import sys from enum import Enum +import click import pathway as pw +import pathway.io.fs as io_fs +import pathway.io.gdrive as io_gdrive +import yaml from dotenv import load_dotenv -from pathway.internals.asynchronous import DiskCache, ExponentialBackoffRetryStrategy +from pathway.internals.udfs import DiskCache, ExponentialBackoffRetryStrategy from pathway.xpacks.llm import embedders, llms, prompts from pathway.xpacks.llm.parsers import ParseUnstructured from pathway.xpacks.llm.splitters import TokenCountSplitter from pathway.xpacks.llm.vector_store import VectorStoreServer +load_dotenv() + class AIResponseType(Enum): SHORT = "short" LONG = "long" -load_dotenv() - -embedder = embedders.OpenAIEmbedder( - model="text-embedding-ada-002", - cache_strategy=DiskCache(), -) - -host = "0.0.0.0" -port = 8000 - -data_sources = [] - -folder = pw.io.fs.read( - "data", - format="binary", - mode="streaming", - with_metadata=True, -) - -data_sources.append(folder) - -# drive_folder = pw.io.gdrive.read( -# object_id="YOUR FOLDER ID", -# with_metadata=True, -# service_user_credentials_file="secret.json", -# refresh_interval=30, -# ) - -# data_sources.append(drive_folder) - - -text_splitter = TokenCountSplitter(max_tokens=400) - - -vector_server = VectorStoreServer( - *data_sources, - embedder=embedder, - splitter=text_splitter, - parser=ParseUnstructured(), -) - - -chat = llms.OpenAIChat( - model="gpt-3.5-turbo", - retry_strategy=ExponentialBackoffRetryStrategy(max_retries=6), - cache_strategy=DiskCache(), - temperature=0.05, -) - - -class PWAIQuery(pw.Schema): - prompt: str - filters: str | None = pw.column_definition(default_value=None) - model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") - response_type: str = pw.column_definition(default_value="short") # short | long - openai_api_key: str - - -pw_ai_endpoint = "/v1/pw_ai_answer" - - -class SummarizeQuery(pw.Schema): - text_list: list[str] - model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") - openai_api_key: str - - -summarize_endpoint = "/v1/pw_ai_summary" - - -class AggregateQuery(pw.Schema): - question: str - answers: list[str] - model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") - openai_api_key: str - - -aggregate_endpoint = "/v1/pw_ai_aggregate_responses" - - def _unwrap_udf(func): if isinstance(func, pw.UDF): return func.__wrapped__ @@ -102,11 +29,13 @@ def _unwrap_udf(func): @pw.udf -def gpt_respond(prompt, docs, filter, response_type) -> str: +def prep_rag_prompt( + prompt: str, docs: list[pw.Json], filter: str | None, response_type: str +) -> str: if filter is None: return prompt - docs = docs.value + docs = docs.value # type: ignore try: docs = [{"text": doc["text"], 
"path": doc["metadata"]["path"]} for doc in docs] @@ -122,7 +51,7 @@ def gpt_respond(prompt, docs, filter, response_type) -> str: @pw.udf -def prompt_aggregate(question, answers): +def prompt_aggregate(question: str, answers: list[str]) -> str: summary_data = "\n".join(answers) summaries_str = json.dumps(summary_data, indent=2) @@ -139,120 +68,221 @@ def prompt_aggregate(question, answers): return prompt -def run( - with_cache: bool = True, - cache_backend: pw.persistence.Backend | None = pw.persistence.Backend.filesystem( - "./Cache" - ), -): - webserver = pw.io.http.PathwayWebserver(host=host, port=port) - # Vectorserver - - def serve(route, schema, handler): - queries, writer = pw.io.http.rest_connector( - webserver=webserver, - route=route, - schema=schema, - autocommit_duration_ms=50, - delete_completed_queries=True, +def data_sources(source_configs) -> list[pw.Table]: + sources = [] + for source_config in source_configs: + if source_config["kind"] == "local": + source = io_fs.read( + **source_config["config"], + format="binary", + with_metadata=True, + ) + sources.append(source) + elif source_config["kind"] == "gdrive": + source = io_gdrive.read( + **source_config["config"], + with_metadata=True, + ) + sources.append(source) + elif source_config["kind"] == "sharepoint": + try: + import pathway.xpacks.connectors.sharepoint as io_sp + + source = io_sp.read(**source_config["config"], with_metadata=True) + sources.append(source) + except ImportError: + print( + "The Pathway Sharepoint connector is part of the commercial offering, " + "please contact us for a commercial license." + ) + sys.exit(1) + + return sources + + +class PathwayRAG: + class PWAIQuerySchema(pw.Schema): + prompt: str + filters: str | None = pw.column_definition(default_value=None) + model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") + response_type: str = pw.column_definition(default_value="short") # short | long + + class SummarizeQuerySchema(pw.Schema): + text_list: list[str] + model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") + + class AggregateQuerySchema(pw.Schema): + question: str + answers: list[str] + model: str | None = pw.column_definition(default_value="gpt-3.5-turbo") + + def __init__( + self, + *docs: pw.Table, + llm: pw.UDF, + embedder: pw.UDF, + splitter: pw.UDF, + parser: pw.UDF = ParseUnstructured(), + doc_post_processors=None, + ) -> None: + self.llm = llm + + self.embedder = embedder + + self.vector_server = VectorStoreServer( + *docs, + embedder=embedder, + splitter=splitter, + parser=parser, + doc_post_processors=doc_post_processors, ) - writer(handler(queries)) - - serve( - "/v1/retrieve", vector_server.RetrieveQuerySchema, vector_server.retrieve_query - ) - serve( - "/v1/statistics", - vector_server.StatisticsQuerySchema, - vector_server.statistics_query, - ) - serve( - "/v1/pw_list_documents", - vector_server.InputsQuerySchema, - vector_server.inputs_query, - ) - - gpt_queries, gpt_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route=pw_ai_endpoint, - schema=PWAIQuery, - autocommit_duration_ms=50, - delete_completed_queries=True, - ) - gpt_results = gpt_queries + vector_server.retrieve_query( - gpt_queries.select( - metadata_filter=pw.this.filters, - filepath_globpattern=pw.cast(str | None, None), - query=pw.this.prompt, - k=6, + @pw.table_transformer + def pw_ai_query(self, pw_ai_queries: pw.Table[PWAIQuerySchema]) -> pw.Table: + """Main function for RAG applications that answer questions + based on available information.""" + + 
pw_ai_results = pw_ai_queries + self.vector_server.retrieve_query( + pw_ai_queries.select( + metadata_filter=pw.this.filters, + filepath_globpattern=pw.cast(str | None, None), + query=pw.this.prompt, + k=6, + ) + ).select( + docs=pw.this.result, ) - ).select( - docs=pw.this.result, - ) - gpt_results += gpt_results.select( - rag_prompt=gpt_respond( - pw.this.prompt, pw.this.docs, pw.this.filters, pw.this.response_type + pw_ai_results += pw_ai_results.select( + rag_prompt=prep_rag_prompt( + pw.this.prompt, pw.this.docs, pw.this.filters, pw.this.response_type + ) ) - ) - gpt_results += gpt_results.select( - result=chat( - llms.prompt_chat_single_qa(pw.this.rag_prompt), - model=pw.this.model, - api_key=pw.this.openai_api_key, + pw_ai_results += pw_ai_results.select( + result=self.llm( + llms.prompt_chat_single_qa(pw.this.rag_prompt), + model=pw.this.model, + ) + ) + return pw_ai_results + + @pw.table_transformer + def summarize_query( + self, summarize_queries: pw.Table[SummarizeQuerySchema] + ) -> pw.Table: + summarize_results = summarize_queries.select( + pw.this.model, + prompt=prompts.prompt_summarize(pw.this.text_list), + ) + summarize_results += summarize_results.select( + result=self.llm( + llms.prompt_chat_single_qa(pw.this.prompt), + model=pw.this.model, + ) + ) + return summarize_results + + @pw.table_transformer + def aggregate_query( + self, aggregate_queries: pw.Table[AggregateQuerySchema] + ) -> pw.Table: + aggregate_results = aggregate_queries.select( + pw.this.model, + prompt=prompt_aggregate(pw.this.question, pw.this.answers), + ) + aggregate_results += aggregate_results.select( + result=self.llm( + llms.prompt_chat_single_qa(pw.this.prompt), + model=pw.this.model, + ) + ) + return aggregate_results + + def build_server(self, host: str, port: int) -> None: + """Adds HTTP connectors to input tables""" + + webserver = pw.io.http.PathwayWebserver(host=host, port=port) + + # connect http endpoint to output writer + def serve(route, schema, handler): + queries, writer = pw.io.http.rest_connector( + webserver=webserver, + route=route, + schema=schema, + autocommit_duration_ms=50, + delete_completed_queries=True, + ) + writer(handler(queries)) + + serve( + "/v1/retrieve", + self.vector_server.RetrieveQuerySchema, + self.vector_server.retrieve_query, + ) + serve( + "/v1/statistics", + self.vector_server.StatisticsQuerySchema, + self.vector_server.statistics_query, + ) + serve( + "/v1/pw_list_documents", + self.vector_server.InputsQuerySchema, + self.vector_server.inputs_query, + ) + serve("/v1/pw_ai_answer", self.PWAIQuerySchema, self.pw_ai_query) + serve( + "/v1/pw_ai_summary", + self.SummarizeQuerySchema, + self.summarize_query, + ) + serve( + "/v1/pw_ai_aggregate_responses", + self.AggregateQuerySchema, + self.aggregate_query, ) - ) - summarize_queries, summarize_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route=summarize_endpoint, - schema=SummarizeQuery, - autocommit_duration_ms=50, - delete_completed_queries=True, - ) - summarize_results = summarize_queries.select( - pw.this.model, - pw.this.openai_api_key, - prompt=prompts.prompt_summarize(pw.this.text_list), - ) - summarize_results += summarize_results.select( - result=chat( - llms.prompt_chat_single_qa(pw.this.prompt), - model=pw.this.model, - api_key=pw.this.openai_api_key, - ) - ) +@click.command() +@click.option("--config_file", default="config.yaml", help="Config file to be used.") +def run( + config_file: str = "config.yaml", +): + with open(config_file) as config_f: + configuration = 
yaml.safe_load(config_f) + + GPT_MODEL = configuration["llm_config"]["model"] - aggregate_queries, aggregate_response_writer = pw.io.http.rest_connector( - webserver=webserver, - route=aggregate_endpoint, - schema=AggregateQuery, - autocommit_duration_ms=50, - delete_completed_queries=True, + embedder = embedders.OpenAIEmbedder( + model="text-embedding-ada-002", + cache_strategy=DiskCache(), ) - aggregate_results = aggregate_queries.select( - pw.this.model, - pw.this.openai_api_key, - prompt=prompt_aggregate(pw.this.question, pw.this.answers), + text_splitter = TokenCountSplitter(max_tokens=400) + + chat = llms.OpenAIChat( + model=GPT_MODEL, + retry_strategy=ExponentialBackoffRetryStrategy(max_retries=6), + cache_strategy=DiskCache(), + temperature=0.05, ) - aggregate_results += aggregate_results.select( - result=chat( - llms.prompt_chat_single_qa(pw.this.prompt), - model=pw.this.model, - api_key=pw.this.openai_api_key, - ) + + host_config = configuration["host_config"] + host, port = host_config["host"], host_config["port"] + + rag_app = PathwayRAG( + *data_sources(configuration["sources"]), + embedder=embedder, + llm=chat, + splitter=text_splitter, ) - gpt_response_writer(gpt_results) - summarize_response_writer(summarize_results) - aggregate_response_writer(aggregate_results) + rag_app.build_server(host=host, port=port) - if with_cache: - if cache_backend is None: - raise ValueError("Cache usage was requested but the backend is unspecified") + if configuration["cache_options"].get("with_cache", True): + print("Running with cache enabled.") + cache_backend = pw.persistence.Backend.filesystem( + configuration["cache_options"].get("cache_folder", "./Cache") + ) persistence_config = pw.persistence.Config.simple_config( cache_backend, persistence_mode=pw.PersistenceMode.UDF_CACHING, @@ -267,4 +297,4 @@ def serve(route, schema, handler): if __name__ == "__main__": - run(with_cache=True) + run() diff --git a/examples/pipelines/demo-question-answering/config.yaml b/examples/pipelines/demo-question-answering/config.yaml new file mode 100644 index 0000000..facbe8f --- /dev/null +++ b/examples/pipelines/demo-question-answering/config.yaml @@ -0,0 +1,39 @@ +llm_config: + model: "gpt-3.5-turbo" +host_config: + host: "0.0.0.0" + port: 8000 +cache_options: + with_cache: True + cache_folder: "./Cache" +sources: + - local_files: + kind: local + config: + # Please refer to + # https://pathway.com/developers/api-docs/pathway-io/fs#pathway.io.fs.read + # for options definition + path: "data/" + # - google_drive_folder: + # kind: gdrive + # config: + # # Please refer to + # # https://pathway.com/developers/api-docs/pathway-io/gdrive#pathway.io.gdrive.read + # # for options definition + # # Please follow https://pathway.com/developers/user-guide/connectors/gdrive-connector/#setting-up-google-drive + # # for instructions on getting credentials + # object_id: "1cULDv2OaViJBmOfG5WB0oWcgayNrGtVs" # folder used in the managed demo + # service_user_credentials_file: SERVICE_CREDENTIALS + # refresh_interval: 5 + # - sharepoint_folder: + # kind: sharepoint + # config: + # # The sharepoint is part of our commercial offering, please contact us to use it + # # Please contact here: `contact@pathway.com` + # root_path: ROOT_PATH + # url: SHAREPOINT_URL + # tenant: SHAREPOINT_TENANT + # client_id: SHAREPOINT_CLIENT_ID + # cert_path: SHAREPOINT.pem + # thumbprint: SHAREPOINT_THUMBPRINT + # refresh_interval: 5
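
Once the patch is applied, the pipeline can be started with `python app.py --config_file config.yaml`; the endpoints registered in `build_server` are then served over plain HTTP on the host and port from `host_config` in config.yaml. Below is a minimal client sketch, not part of the patch itself: it assumes the server is running with the defaults above, that it is reachable via localhost, and that the `requests` package is installed; the question text is only an illustration.

    import requests

    # Ask the RAG endpoint a question. The JSON fields follow PWAIQuerySchema:
    # "prompt" is required; "filters", "model" and "response_type" have defaults.
    response = requests.post(
        "http://localhost:8000/v1/pw_ai_answer",
        json={"prompt": "What do the indexed documents say?"},
    )
    print(response.json())

The same pattern applies to the other routes wired up in `build_server` (`/v1/retrieve`, `/v1/statistics`, `/v1/pw_list_documents`, `/v1/pw_ai_summary`, `/v1/pw_ai_aggregate_responses`); their request schemas are defined on the VectorStoreServer and PathwayRAG classes shown in the diff.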