Add hist multi llm #277

Open
wants to merge 10 commits into main
156 changes: 130 additions & 26 deletions functions/chain.py
@@ -1,38 +1,142 @@
# def create_health_ai_chain(llm, vector_store):
# retriever = SelfQueryRetriever.from_llm(
# llm=llm,
# vectorstore=vector_store,
# document_content_description=document_content_description,
# metadata_field_info=metadata_field_info,
# document_contents='',
# )
# health_ai_template = """
# You are a health AI agent equipped with access to diverse sources of health data,
# including research articles, nutritional information, medical archives, and more.
# Your task is to provide informed answers to user queries based on the available data.
# If you cannot find relevant information, simply state that you do not have enough data
# to answer accurately. write your response in markdown form and also add reference url
# so user can know from which source you are answering the questions.
# CONTEXT:
# {context}
# QUESTION: {question}
# YOUR ANSWER:
# """
# health_ai_prompt = ChatPromptTemplate.from_template(health_ai_template)
# chain = (
# {'context': retriever, 'question': RunnablePassthrough()}
# | health_ai_prompt
# | llm
# | StrOutputParser()
# )
# return chain
import logging
from os import environ

from get_google_docs import get_inital_prompt
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from meta import document_content_description, metadata_field_info
from store import get_vector_store


def custom_history(entire_history: list, llm_name: str):
if not entire_history or not isinstance(entire_history, list):
logging.error("Invalid 'entire_history': Must be a non-empty list.")
return []

chat_history = []
for msg in entire_history:
if not isinstance(msg, dict):
logging.warning('Skipping invalid message format: Not a dictionary.')
continue

msg_type = msg.get('type')
message_content = msg.get('message')

if msg_type == 'USER':
if message_content is None:
continue
chat_history.append(HumanMessage(content=message_content))
elif msg_type == 'AI':
if not isinstance(message_content, dict) or llm_name not in message_content:
continue
chat_history.append(AIMessage(content=message_content[llm_name]))
else:
logging.warning(f'Skipping message with unrecognized type: {msg_type}.')

return chat_history


def hist_aware_answers(llm_name, input_string, message_history):
vector_store = get_vector_store()
get_init_answer = get_inital_prompt()
init_prompt = '' if get_init_answer is None else get_init_answer
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
# add in custom user info: -----------------------------
# custom_istructions = get_custom_instructions_callable()
# user_info = " "
# if custom_istructions:
# user_info = f"""Here is some information about the user, including the user's name,
# their profile description and style instructions on how they want you to answer stylewise:
# User Name: {custom_istructions['name']}
    # Style Instructions: {custom_istructions['styleInstructions']}
# Personal Info: {custom_istructions['personalInstructions']}
# """

agent_str = """
You are a health AI agent equipped with access to diverse sources of health data,
including research articles, nutritional information, medical archives, and more.
Your task is to provide informed answers to user queries based on the available data.
If you cannot find relevant information, simply state that you do not have enough data
    to answer accurately. Write your response in Markdown and include reference URLs so the
    user can see which sources your answer draws on.
"""

context_str = """
CONTEXT:
{context}

"""
    # health_ai_template = f'{init_prompt}{agent_str}{user_info}{context_str}'
health_ai_template = f'{init_prompt}{agent_str}{context_str}'
chat_history = custom_history(message_history, llm_name)
    if llm_name == 'gpt-4':
        llm = ChatOpenAI(model='gpt-4', temperature=0.2, api_key=environ.get('OPENAI_API_KEY'))
    elif llm_name == 'gemini':
        llm = ChatGoogleGenerativeAI(
            model='gemini-1.5-pro-latest', google_api_key=environ.get('GOOGLE_API_KEY')
        )
    elif llm_name == 'claude':
        llm = ChatAnthropic(
            model='claude-3-5-sonnet-20240620', api_key=environ.get('ANTHROPIC_API_KEY')
        )
    else:
        raise ValueError(f'Unsupported llm_name: {llm_name}')
    retriever = SelfQueryRetriever.from_llm(
        llm=llm,
        vectorstore=vector_store,
        document_contents=document_content_description,
        metadata_field_info=metadata_field_info,
    )
contextualize_q_prompt = ChatPromptTemplate.from_messages(
[
('system', contextualize_q_system_prompt),
MessagesPlaceholder('chat_history'),
('human', '{input}'),
]
)
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
qa_prompt = ChatPromptTemplate.from_messages(
[('system', health_ai_template), MessagesPlaceholder('chat_history'), ('human', '{input}')]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
msg = rag_chain.invoke({'input': input_string, 'chat_history': chat_history})
return msg['answer']
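
Reviewer note: a minimal sketch of the inputs hist_aware_answers expects, inferred from custom_history above. The example messages are hypothetical, and the call assumes the vector store and the relevant API keys are configured:

from chain import hist_aware_answers

# Each history entry is a dict with a 'type' and a 'message'.
# USER entries carry a plain string; AI entries carry a dict keyed by
# model name, since each turn can store one answer per LLM.
history = [
    {'type': 'USER', 'message': 'What foods are high in vitamin D?'},
    {'type': 'AI', 'message': {'gpt-4': 'Fatty fish, egg yolks, and fortified dairy.'}},
]

answer = hist_aware_answers('gpt-4', 'Which of those are vegetarian-friendly?', history)
print(answer)
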
81 changes: 81 additions & 0 deletions functions/get_google_docs.py
@@ -0,0 +1,81 @@
import io
import os
import pickle
import re

from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload


def extract_document_id_from_url(url):
pattern = r'/d/([a-zA-Z0-9-_]+)'
matches = re.findall(pattern, url)
    if not matches:
        raise ValueError(f'No document ID found in URL: {url}')
    document_id = max(matches, key=len)
return document_id


def authenticate(credentials, scopes):
"""Obtaining auth with needed apis"""
creds = None
# The file token.pickle stores the user's access
# and refresh tokens, and is created automatically
# when the authorization flow completes for the first time.
if os.path.exists('token.pickle'):
with open('token.pickle', 'rb') as token:
creds = pickle.load(token)
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(credentials, scopes)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open('token.pickle', 'wb') as token:
pickle.dump(creds, token)

return creds


def download_file(file_id, credentials_path):
scopes = ['https://www.googleapis.com/auth/drive.readonly']
credentials = authenticate(credentials_path, scopes)
drive_service = build('drive', 'v3', credentials=credentials)

# Export the Google Docs file as plain text
export_mime_type = 'text/plain'
request = drive_service.files().export_media(fileId=file_id, mimeType=export_mime_type)

# Use a BytesIO buffer to handle the file content in memory
fh = io.BytesIO()
downloader = MediaIoBaseDownload(fh, request)
done = False
while not done:
status, done = downloader.next_chunk()
print(f'Download {int(status.progress() * 100)}%.')

# Reset the buffer's position to the beginning
fh.seek(0)

# Read the content of the buffer
content = fh.read().decode('utf-8')

return content


def get_inital_prompt():
    # The initial system prompt is maintained in this shared Google Doc.
document_id = extract_document_id_from_url(
'https://docs.google.com/document/d/1GtLyBqhk-cu8CSo4A15WTgGDbMbL4B9LLjdvBoU3234/edit'
)
# print("Document id: ", document_id)
credentials_json = 'credentials.json'

try:
content = download_file(document_id, credentials_json)
return content
except Exception as e:
print(f'An error occurred: {e}')
return None
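
A quick local smoke test for this module might look like the following sketch. It assumes credentials.json is present in the working directory and that the hard-coded Doc is readable by the authenticated account:

from get_google_docs import extract_document_id_from_url, get_inital_prompt

# The extractor picks the longest '/d/<id>' match out of the URL.
doc_id = extract_document_id_from_url(
    'https://docs.google.com/document/d/1GtLyBqhk-cu8CSo4A15WTgGDbMbL4B9LLjdvBoU3234/edit'
)
print(doc_id)

# Fetches the shared prompt document as plain text, or None on failure.
prompt = get_inital_prompt()
print(prompt[:200] if prompt else 'No prompt fetched.')
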
19 changes: 10 additions & 9 deletions functions/main.py
@@ -1,26 +1,27 @@
from json import dumps

# from handlers import get_response_from_llm
from chain import hist_aware_answers
from firebase_functions import https_fn, options


@https_fn.on_request(memory=options.MemoryOption.GB_32, cpu=8, timeout_sec=540)
def get_response_url(req: https_fn.Request) -> https_fn.Response:
query = req.get_json().get('query', '')
    llms = req.get_json().get('llms', ['gpt-4', 'gemini', 'claude'])
chat = req.get_json().get('history', [])
responses = {}
for llm in llms:
        responses[llm] = hist_aware_answers(llm, query, chat)
return https_fn.Response(dumps(responses), mimetype='application/json')


@https_fn.on_call(memory=options.MemoryOption.GB_32, cpu=8, timeout_sec=540)
def get_response(req: https_fn.CallableRequest):
query = req.data.get('query', '')
    llms = req.data.get('llms', ['gpt-4', 'gemini', 'claude'])
    chat = req.data.get('history', [])
responses = {}
for llm in llms:
        responses[llm] = hist_aware_answers(llm, query, chat)
return responses
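
For reference, a hypothetical client call against the deployed HTTP endpoint (the URL is a placeholder; the payload keys mirror the handler above):

import requests

resp = requests.post(
    'https://<region>-<project>.cloudfunctions.net/get_response_url',  # placeholder URL
    json={
        'query': 'Is intermittent fasting safe for people with diabetes?',
        'llms': ['gpt-4', 'claude'],
        'history': [],
    },
)
print(resp.json())  # one answer per requested model, e.g. {'gpt-4': '...', 'claude': '...'}
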
7 changes: 7 additions & 0 deletions functions/requirements.txt
@@ -4,3 +4,10 @@ langchain-community
langchain-openai
langchain-astradb
lark
langchain-core
langchain-google-genai
langchain-anthropic
google-auth
google-auth-oauthlib
google-api-python-client
python-dotenv