Skip to content

Commit

Permalink
Merge pull request #13 from nalbam/main
Browse files Browse the repository at this point in the history
feat: Add knowledge base retrieval functionality
  • Loading branch information
nalbam authored Aug 14, 2024
2 parents 0a936c2 + 46935fa commit f83af18
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 4 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ env:
BOT_CURSOR: ${{ vars.BOT_CURSOR }}
DYNAMODB_TABLE_NAME: ${{ vars.DYNAMODB_TABLE_NAME }}
ENABLE_IMAGE: ${{ vars.ENABLE_IMAGE }}
KB_ID: ${{ vars.KB_ID }}
MODEL_ID_IMAGE: ${{ vars.MODEL_ID_IMAGE }}
SYSTEM_MESSAGE: ${{ vars.SYSTEM_MESSAGE }}
MODEL_ID_TEXT: ${{ vars.MODEL_ID_TEXT }}
SYSTEM_MESSAGE: ${{ vars.SYSTEM_MESSAGE }}

AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
Expand Down Expand Up @@ -52,11 +53,12 @@ jobs:
echo "BOT_CURSOR=${BOT_CURSOR}" >> .env
echo "DYNAMODB_TABLE_NAME=${DYNAMODB_TABLE_NAME}" >> .env
echo "ENABLE_IMAGE=${ENABLE_IMAGE}" >> .env
echo "KB_ID=${KB_ID}" >> .env
echo "MODEL_ID_IMAGE=${MODEL_ID_IMAGE}" >> .env
echo "MODEL_ID_TEXT=${MODEL_ID_TEXT}" >> .env
echo "SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN}" >> .env
echo "SLACK_SIGNING_SECRET=${SLACK_SIGNING_SECRET}" >> .env
echo "SYSTEM_MESSAGE=${SYSTEM_MESSAGE}" >> .env
echo "MODEL_ID_TEXT=${MODEL_ID_TEXT}" >> .env
- name: Deploy to AWS Lambda 🚀
run: npx serverless deploy --region us-east-1
150 changes: 150 additions & 0 deletions bedrock/invoke_knowledge_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import boto3

from botocore.client import Config


# Client config for the Bedrock agent runtime: long (120 s) connect/read
# timeouts and no automatic retries, so a slow retrieval fails once instead
# of stacking retry attempts.
bedrock_config = Config(
    connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
)
# Runtime client for direct model invocation (invoke_model).
bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")
# Agent-runtime client for knowledge-base retrieve / retrieve_and_generate.
bedrock_agent_client = boto3.client(
    "bedrock-agent-runtime", region_name="us-east-1", config=bedrock_config
)

model_id = "anthropic.claude-v2:1"  # try claude-instant as well; for Claude v2 use "anthropic.claude-v2"
region_id = "us-east-1"  # replace with the region this script runs in

# NOTE(review): defined but not referenced anywhere in this script.
SYSTEM_MESSAGE = "답변은 한국어 해요체로 해요."


def parse_args():
    """Parse CLI arguments: -p/--prompt (query text) and -d/--debug.

    :return: argparse.Namespace with ``prompt`` and ``debug`` attributes.
    """
    # Fixed: description was copy-pasted as "invoke_claude_3" from a sibling
    # script (same bug this commit fixes in invoke_stable_diffusion.py).
    p = argparse.ArgumentParser(description="invoke_knowledge_base")
    p.add_argument("-p", "--prompt", default="안녕", help="prompt")
    p.add_argument("-d", "--debug", default="False", help="debug")
    return p.parse_args()


def retrieve(query, kbId, numberOfResults=5):
    """Run a raw retrieval query against a Bedrock knowledge base.

    :param query: Free-text query to search for.
    :param kbId: Knowledge base id.
    :param numberOfResults: Maximum number of chunks to return (default 5).
    :return: Raw ``retrieve`` response from the agent-runtime client.
    """
    search_config = {
        "vectorSearchConfiguration": {
            "numberOfResults": numberOfResults,
            # "overrideSearchType": "HYBRID",  # optional
        }
    }
    return bedrock_agent_client.retrieve(
        retrievalQuery={"text": query},
        knowledgeBaseId=kbId,
        retrievalConfiguration=search_config,
    )


def retrieveAndGenerate(
    input,
    kbId,
    sessionId=None,
    model_id="anthropic.claude-v2:1",
    region_id="us-east-1",
):
    """Call Bedrock retrieve-and-generate against a knowledge base.

    :param input: Question text. (The name shadows the builtin ``input`` but
        is kept for backward compatibility with keyword callers.)
    :param kbId: Knowledge base id.
    :param sessionId: Optional session id to continue a prior conversation;
        omitted from the API call when falsy.
    :param model_id: Foundation model id used to build the model ARN.
    :param region_id: AWS region embedded in the model ARN.
    :return: Raw ``retrieve_and_generate`` response from the agent client.
    """
    model_arn = f"arn:aws:bedrock:{region_id}::foundation-model/{model_id}"

    # Fixed: the original duplicated the whole API call in both branches of
    # an if/else that differed only by the sessionId argument.
    kwargs = {
        "input": {"text": input},
        "retrieveAndGenerateConfiguration": {
            "type": "KNOWLEDGE_BASE",
            "knowledgeBaseConfiguration": {
                "knowledgeBaseId": kbId,
                "modelArn": model_arn,
            },
        },
    }
    if sessionId:
        kwargs["sessionId"] = sessionId

    return bedrock_agent_client.retrieve_and_generate(**kwargs)


def get_contexts(retrievalResults):
    """Extract the plain-text snippet from each knowledge-base retrieval result.

    :param retrievalResults: ``retrievalResults`` list from a retrieve response.
    :return: List of the ``content.text`` strings, in the original order.
    """
    return [result["content"]["text"] for result in retrievalResults]


def main():
    """Retrieve knowledge-base context for a fixed query and answer it with
    Claude 3 Sonnet, printing the generated text.
    """
    # NOTE(review): knowledge-base id is hard-coded; consider reading it from
    # an environment variable or the (currently unused) CLI arguments.
    kb_id = "DQXVNP05K5"

    query = "kontrol의 기능 알려줘."

    # Fetch the top-3 matching chunks and keep only their text snippets.
    response = retrieve(query, kb_id, 3)
    contexts = get_contexts(response["retrievalResults"])

    # Prompt kept byte-identical to the original (including the
    # "financial advisor" persona, which looks copied from a sample).
    prompt = f"""
Human: You are a financial advisor AI system, and provides answers to questions by using fact based and statistical information when possible.
Use the following pieces of information to provide a concise answer to the question enclosed in <question> tags.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
<context>
{contexts}
</context>
<question>
{query}
</question>
The response should be specific and use statistics or numbers when possible.
Assistant:"""

    # Claude 3 Messages-API payload.
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
    ]
    sonnet_payload = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 512,
            "messages": messages,
            "temperature": 0.5,
            "top_p": 1,
        }
    )

    modelId = "anthropic.claude-3-sonnet-20240229-v1:0"  # change this to use a different model version
    response = bedrock_client.invoke_model(
        body=sonnet_payload,
        modelId=modelId,
        accept="application/json",
        contentType="application/json",
    )

    response_body = json.loads(response.get("body").read())
    print(response_body.get("content")[0]["text"])


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion bedrock/invoke_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def parse_args():
    """Parse CLI arguments for the Stable Diffusion invocation script.

    :return: argparse.Namespace with ``prompt`` (required) and ``debug``.
    """
    # The rendered diff showed both the old ("invoke_claude_3") and new
    # ArgumentParser lines; this is the clean post-commit version.
    p = argparse.ArgumentParser(description="invoke_stable_diffusion")
    p.add_argument("-p", "--prompt", default="Hello", help="prompt", required=True)
    p.add_argument("-d", "--debug", default="False", help="debug")
    return p.parse_args()
Expand Down
60 changes: 59 additions & 1 deletion handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,26 @@
import requests
import io

from botocore.client import Config

from slack_bolt import App, Say
from slack_bolt.adapter.aws_lambda import SlackRequestHandler


BOT_CURSOR = os.environ.get("BOT_CURSOR", ":robot_face:")

AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")

# Set up Slack API credentials
SLACK_BOT_TOKEN = os.environ["SLACK_BOT_TOKEN"]
SLACK_SIGNING_SECRET = os.environ["SLACK_SIGNING_SECRET"]

# Keep track of conversation history by thread and user
DYNAMODB_TABLE_NAME = os.environ.get("DYNAMODB_TABLE_NAME", "gurumi-ai-bot-context")

# Amazon Bedrock Knowledge Base ID
KB_ID = os.environ.get("KB_ID", "None")

# Amazon Bedrock Model ID
MODEL_ID_TEXT = os.environ.get("MODEL_ID_TEXT", "anthropic.claude-3")
MODEL_ID_IMAGE = os.environ.get("MODEL_ID_IMAGE", "stability.stable-diffusion-xl")
Expand Down Expand Up @@ -73,7 +81,14 @@
table = dynamodb.Table(DYNAMODB_TABLE_NAME)

# Initialize the Amazon Bedrock runtime client
bedrock = boto3.client(service_name="bedrock-runtime", region_name="us-east-1")
bedrock = boto3.client(service_name="bedrock-runtime", region_name=AWS_REGION)

bedrock_config = Config(
connect_timeout=120, read_timeout=120, retries={"max_attempts": 0}
)
bedrock_agent_client = boto3.client(
"bedrock-agent-runtime", region_name=AWS_REGION, config=bedrock_config
)


# Get the context from DynamoDB
Expand Down Expand Up @@ -161,6 +176,41 @@ def chat_update(say, channel, thread_ts, latest_ts, message="", continue_thread=
return message, latest_ts


def invoke_knowledge_base(content):
    """
    Invokes the Amazon Bedrock Knowledge Base to retrieve information using the input
    provided in the request body.

    :param content: The content that you want to use for retrieval.
    :return: List of retrieved context strings from the knowledge base.
    :raises Exception: Re-raises any error from the Bedrock agent client
        after logging it.
    """

    try:
        response = bedrock_agent_client.retrieve(
            retrievalQuery={"text": content},
            knowledgeBaseId=KB_ID,
            retrievalConfiguration={
                "vectorSearchConfiguration": {
                    "numberOfResults": 3,
                    # "overrideSearchType": "HYBRID",  # optional
                }
            },
        )

        # Keep only the text snippet of each retrieved chunk.
        return [
            result["content"]["text"] for result in response["retrievalResults"]
        ]

    except Exception as e:
        print("invoke_knowledge_base: Error: {}".format(e))

        # Fixed: bare `raise` preserves the original traceback; `raise e`
        # rewrites it from this frame.
        raise


def invoke_claude_3(content):
"""
Invokes Anthropic Claude 3 Sonnet to run an inference using the input
Expand Down Expand Up @@ -347,6 +397,14 @@ def conversation(say: Say, thread_ts, content, channel, user, client_msg_id):

prompts.append(message)

if KB_ID != "None":
chat_update(say, channel, thread_ts, latest_ts, MSG_RESPONSE)

# Get the knowledge base contexts
contexts = invoke_knowledge_base(prompt)

prompts.extend(contexts)

if prompt:
prompts.append(prompt)

Expand Down
5 changes: 5 additions & 0 deletions serverless.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ provider:
- dynamodb:*
Resource:
- "arn:aws:dynamodb:*:*:table/${self:provider.environment.DYNAMODB_TABLE_NAME}"
- Effect: Allow
Action:
- bedrock:Retrieve
Resource:
- "arn:aws:bedrock:*:*:knowledge-base/*"
- Effect: Allow
Action:
- bedrock:InvokeModel
Expand Down

0 comments on commit f83af18

Please sign in to comment.