Skip to content

Commit

Permalink
Merge pull request #1 from nalbam/main
Browse files Browse the repository at this point in the history
add allowed_channel_ids
  • Loading branch information
nalbam authored Jun 5, 2024
2 parents 5d4fc89 + d1a42a7 commit 847ad9f
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 14 deletions.
6 changes: 4 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ DYNAMODB_TABLE_NAME="gurumi-ai-bot-context"
TEXT_MODEL_ID="anthropic.claude-3-sonnet-20240229-v1:0"
IMAGE_MODEL_ID="stability.stable-diffusion-xl-v1"

SYSTEM_MESSAGE="너는 최대한 정확하고 신뢰할 수 있는 정보를 알려줘. 너는 항상 사용자를 존중해."
ALLOWED_CHANNEL_IDS="C000000,C000001"

TEMPERATURE="0"
SYSTEM_MESSAGE="너는 AWSKRUG(AWS Korea User Group)에서 친절하게 도움을 주는 구름이(Gurumi)야."

MESSAGE_MAX="4000"
2 changes: 2 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ env:
TEXT_MODEL_ID: ${{ secrets.TEXT_MODEL_ID }}
IMAGE_MODEL_ID: ${{ secrets.IMAGE_MODEL_ID }}
SYSTEM_MESSAGE: ${{ vars.SYSTEM_MESSAGE }}
ALLOWED_CHANNEL_IDS: ${{ secrets.ALLOWED_CHANNEL_IDS }}

jobs:
deploy:
Expand Down Expand Up @@ -50,6 +51,7 @@ jobs:
echo "TEXT_MODEL_ID=${TEXT_MODEL_ID}" >> .env
echo "IMAGE_MODEL_ID=${IMAGE_MODEL_ID}" >> .env
echo "SYSTEM_MESSAGE=${SYSTEM_MESSAGE}" >> .env
echo "ALLOWED_CHANNEL_IDS=${ALLOWED_CHANNEL_IDS}" >> .env
- name: Deploy to AWS Lambda 🚀
env:
Expand Down
120 changes: 108 additions & 12 deletions handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import re
import sys
import time
import base64
import requests

from slack_bolt import App, Say
from slack_bolt.adapter.aws_lambda import SlackRequestHandler
Expand All @@ -25,6 +27,9 @@
ANTHROPIC_VERSION = os.environ.get("ANTHROPIC_VERSION", "bedrock-2023-05-31")
ANTHROPIC_TOKENS = int(os.environ.get("ANTHROPIC_TOKENS", 1024))

# Set up the allowed channel ID
ALLOWED_CHANNEL_IDS = os.environ.get("ALLOWED_CHANNEL_IDS", "AAA,CCC,EEE,GGG")

# Set up System messages
SYSTEM_MESSAGE = os.environ.get("SYSTEM_MESSAGE", "None")

Expand Down Expand Up @@ -115,11 +120,11 @@ def invoke_claude_3(messages):
# Process and print the response
result = json.loads(response.get("body").read())

output_list = result.get("content", [])
print("response: {}".format(result))

print(f"- The model returned {len(output_list)} response(s):")
content = result.get("content", [])

for output in output_list:
for output in content:
text = output["text"]

return text
Expand Down Expand Up @@ -191,13 +196,19 @@ def conversations_replies(channel, ts, client_msg_id):


# Handle the chatgpt conversation
def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id):
print("conversation: {}".format(json.dumps(prompt)))
def conversation(say: Say, thread_ts, content, channel, user, client_msg_id):
print("conversation: {}".format(json.dumps(content)))

# Keep track of the latest message timestamp
result = say(text=BOT_CURSOR, thread_ts=thread_ts)
latest_ts = result["ts"]

prompt = content[0]["text"]

type = "text"
# if "그려줘" in prompt:
# type = "image"

prompts = []

# Get the thread messages
Expand All @@ -208,6 +219,25 @@ def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id):

prompts = [reply["content"] for reply in replies if reply["content"].strip()]

# Get the image from the message
if type == "image" and len(content) > 1:
chat_update(channel, latest_ts, "이미지 감상 중... " + BOT_CURSOR)

content[0]["text"] = "Describe the image in great detail as if viewing a photo."

messages = []
messages.append(
{
"role": "user",
"content": content,
},
)

# Send the prompt to Bedrock
message = invoke_claude_3(messages)

prompts.append(message)

# Send the prompt to Bedrock
if prompt:
prompts.append(prompt)
Expand Down Expand Up @@ -244,6 +274,64 @@ def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id):
chat_update(channel, latest_ts, message)


# Get image from URL
def get_image_from_url(image_url, token=None):
headers = {}
if token:
headers["Authorization"] = f"Bearer {token}"

response = requests.get(image_url, headers=headers)

if response.status_code == 200:
return response.content
else:
print("Failed to fetch image: {}".format(image_url))

return None


# Get image from Slack
def get_image_from_slack(image_url):
return get_image_from_url(image_url, SLACK_BOT_TOKEN)


# Get encoded image from Slack
def get_encoded_image_from_slack(image_url):
image = get_image_from_slack(image_url)

if image:
return base64.b64encode(image).decode("utf-8")

return None


# Extract content from the message
def content_from_message(prompt, event):
content = []
content.append({"type": "text", "text": prompt})

if "files" in event:
files = event.get("files", [])
for file in files:
mimetype = file["mimetype"]
if mimetype.startswith("image"):
image_url = file.get("url_private")
base64_image = get_encoded_image_from_slack(image_url)
if base64_image:
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mimetype,
"data": base64_image,
},
}
)

return content


# Handle the app_mention event
@app.event("app_mention")
def handle_mention(body: dict, say: Say):
Expand All @@ -254,15 +342,22 @@ def handle_mention(body: dict, say: Say):
if "bot_id" in event: # Ignore messages from the bot itself
return

thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"]
prompt = re.sub(f"<@{bot_id}>", "", event["text"]).strip()
channel = event["channel"]

allowed_channel_ids = ALLOWED_CHANNEL_IDS.split(",")
if channel not in allowed_channel_ids:
# say("Sorry, I'm not allowed to respond in this channel.")
return

thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"]
user = event["user"]
client_msg_id = event["client_msg_id"]

# content = content_from_message(prompt, event)
prompt = re.sub(f"<@{bot_id}>", "", event["text"]).strip()

content = content_from_message(prompt, event)

conversation(say, thread_ts, prompt, channel, user, client_msg_id)
conversation(say, thread_ts, content, channel, user, client_msg_id)


# Handle the DM (direct message) event
Expand All @@ -275,15 +370,16 @@ def handle_message(body: dict, say: Say):
if "bot_id" in event: # Ignore messages from the bot itself
return

prompt = event["text"].strip()
channel = event["channel"]
user = event["user"]
client_msg_id = event["client_msg_id"]

# content = content_from_message(prompt, event)
prompt = event["text"].strip()

content = content_from_message(prompt, event)

# Use thread_ts=None for regular messages, and user ID for DMs
conversation(say, None, prompt, channel, user, client_msg_id)
conversation(say, None, content, channel, user, client_msg_id)


# Handle the Lambda function
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ boto3
slack-bolt
slack-sdk
pillow
requests

0 comments on commit 847ad9f

Please sign in to comment.