From 6efa787b895d4d21d69cb00738b37684cee36d43 Mon Sep 17 00:00:00 2001 From: nalbam Date: Wed, 5 Jun 2024 11:49:47 +0900 Subject: [PATCH 1/4] add allowed_channel_ids --- .env.example | 6 +- .github/workflows/push.yml | 2 + handler.py | 120 +++++++++++++++++++++++++++++++++---- 3 files changed, 114 insertions(+), 14 deletions(-) diff --git a/.env.example b/.env.example index 95c2a0e..fa1420e 100644 --- a/.env.example +++ b/.env.example @@ -8,6 +8,8 @@ DYNAMODB_TABLE_NAME="gurumi-ai-bot-context" TEXT_MODEL_ID="anthropic.claude-3-sonnet-20240229-v1:0" IMAGE_MODEL_ID="stability.stable-diffusion-xl-v1" -SYSTEM_MESSAGE="너는 최대한 정확하고 신뢰할 수 있는 정보를 알려줘. 너는 항상 사용자를 존중해." +ALLOWED_CHANNEL_IDS="C000000,C000001" -TEMPERATURE="0" +SYSTEM_MESSAGE="너는 AWSKRUG(AWS Korea User Group)에서 친절하게 도움을 주는 구름이(Gurumi)야." + +MESSAGE_MAX="4000" diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 2d7f9b8..d023492 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -13,6 +13,7 @@ env: TEXT_MODEL_ID: ${{ secrets.TEXT_MODEL_ID }} IMAGE_MODEL_ID: ${{ secrets.IMAGE_MODEL_ID }} SYSTEM_MESSAGE: ${{ vars.SYSTEM_MESSAGE }} + ALLOWED_CHANNEL_IDS: ${{ secrets.ALLOWED_CHANNEL_IDS }} jobs: deploy: @@ -50,6 +51,7 @@ jobs: echo "TEXT_MODEL_ID=${TEXT_MODEL_ID}" >> .env echo "IMAGE_MODEL_ID=${IMAGE_MODEL_ID}" >> .env echo "SYSTEM_MESSAGE=${SYSTEM_MESSAGE}" >> .env + echo "ALLOWED_CHANNEL_IDS=${ALLOWED_CHANNEL_IDS}" >> .env - name: Deploy to AWS Lambda 🚀 env: diff --git a/handler.py b/handler.py index cfb95ac..cc0ee03 100644 --- a/handler.py +++ b/handler.py @@ -5,6 +5,8 @@ import re import sys import time +import base64 +import requests from slack_bolt import App, Say from slack_bolt.adapter.aws_lambda import SlackRequestHandler @@ -25,6 +27,9 @@ ANTHROPIC_VERSION = os.environ.get("ANTHROPIC_VERSION", "bedrock-2023-05-31") ANTHROPIC_TOKENS = int(os.environ.get("ANTHROPIC_TOKENS", 1024)) +# Set up the allowed channel ID +ALLOWED_CHANNEL_IDS = os.environ.get("ALLOWED_CHANNEL_IDS", "AAA,CCC,EEE,GGG") + # Set up System messages SYSTEM_MESSAGE = os.environ.get("SYSTEM_MESSAGE", "None") @@ -115,11 +120,11 @@ def invoke_claude_3(messages): # Process and print the response result = json.loads(response.get("body").read()) - output_list = result.get("content", []) + print("response: {}".format(result)) - print(f"- The model returned {len(output_list)} response(s):") + content = result.get("content", []) - for output in output_list: + for output in content: text = output["text"] return text @@ -191,13 +196,19 @@ def conversations_replies(channel, ts, client_msg_id): # Handle the chatgpt conversation -def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id): - print("conversation: {}".format(json.dumps(prompt))) +def conversation(say: Say, thread_ts, content, channel, user, client_msg_id): + print("conversation: {}".format(json.dumps(content))) # Keep track of the latest message timestamp result = say(text=BOT_CURSOR, thread_ts=thread_ts) latest_ts = result["ts"] + prompt = content[0]["text"] + + type = "text" + # if "그려줘" in prompt: + # type = "image" + prompts = [] # Get the thread messages @@ -208,6 +219,25 @@ def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id): prompts = [reply["content"] for reply in replies if reply["content"].strip()] + # Get the image from the message + if type == "image" and len(content) > 1: + chat_update(channel, latest_ts, "이미지 감상 중... " + BOT_CURSOR) + + content[0]["text"] = "Describe the image in great detail as if viewing a photo." + + messages = [] + messages.append( + { + "role": "user", + "content": content, + }, + ) + + # Send the prompt to Bedrock + message = invoke_claude_3(messages) + + prompts.append(message) + # Send the prompt to Bedrock if prompt: prompts.append(prompt) @@ -244,6 +274,64 @@ def conversation(say: Say, thread_ts, prompt, channel, user, client_msg_id): chat_update(channel, latest_ts, message) +# Get image from URL +def get_image_from_url(image_url, token=None): + headers = {} + if token: + headers["Authorization"] = f"Bearer {token}" + + response = requests.get(image_url, headers=headers) + + if response.status_code == 200: + return response.content + else: + print("Failed to fetch image: {}".format(image_url)) + + return None + + +# Get image from Slack +def get_image_from_slack(image_url): + return get_image_from_url(image_url, SLACK_BOT_TOKEN) + + +# Get encoded image from Slack +def get_encoded_image_from_slack(image_url): + image = get_image_from_slack(image_url) + + if image: + return base64.b64encode(image).decode("utf-8") + + return None + + +# Extract content from the message +def content_from_message(prompt, event): + content = [] + content.append({"type": "text", "text": prompt}) + + if "files" in event: + files = event.get("files", []) + for file in files: + mimetype = file["mimetype"] + if mimetype.startswith("image"): + image_url = file.get("url_private") + base64_image = get_encoded_image_from_slack(image_url) + if base64_image: + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": mimetype, + "data": base64_image, + }, + } + ) + + return content + + # Handle the app_mention event @app.event("app_mention") def handle_mention(body: dict, say: Say): @@ -254,15 +342,22 @@ def handle_mention(body: dict, say: Say): if "bot_id" in event: # Ignore messages from the bot itself return - thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"] - prompt = re.sub(f"<@{bot_id}>", "", event["text"]).strip() channel = event["channel"] + + allowed_channel_ids = ALLOWED_CHANNEL_IDS.split(",") + if channel not in allowed_channel_ids: + # say("Sorry, I'm not allowed to respond in this channel.") + return + + thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"] user = event["user"] client_msg_id = event["client_msg_id"] - # content = content_from_message(prompt, event) + prompt = re.sub(f"<@{bot_id}>", "", event["text"]).strip() + + content = content_from_message(prompt, event) - conversation(say, thread_ts, prompt, channel, user, client_msg_id) + conversation(say, thread_ts, content, channel, user, client_msg_id) # Handle the DM (direct message) event @@ -275,15 +370,16 @@ def handle_message(body: dict, say: Say): if "bot_id" in event: # Ignore messages from the bot itself return - prompt = event["text"].strip() channel = event["channel"] user = event["user"] client_msg_id = event["client_msg_id"] - # content = content_from_message(prompt, event) + prompt = event["text"].strip() + + content = content_from_message(prompt, event) # Use thread_ts=None for regular messages, and user ID for DMs - conversation(say, None, prompt, channel, user, client_msg_id) + conversation(say, None, content, channel, user, client_msg_id) # Handle the Lambda function From 9165ff3173df5f37d93e2970f028c1d46a1e6c60 Mon Sep 17 00:00:00 2001 From: nalbam Date: Wed, 5 Jun 2024 12:00:51 +0900 Subject: [PATCH 2/4] add allowed channel message --- handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handler.py b/handler.py index cc0ee03..84ea30a 100644 --- a/handler.py +++ b/handler.py @@ -346,7 +346,7 @@ def handle_mention(body: dict, say: Say): allowed_channel_ids = ALLOWED_CHANNEL_IDS.split(",") if channel not in allowed_channel_ids: - # say("Sorry, I'm not allowed to respond in this channel.") + say("Sorry, I'm not allowed to respond in this channel.") return thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"] From 2fa04bb35faa92a84211b4b3e558bb9a8615acf4 Mon Sep 17 00:00:00 2001 From: nalbam Date: Wed, 5 Jun 2024 12:04:24 +0900 Subject: [PATCH 3/4] add requests --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 0233d7a..6661a16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ boto3 slack-bolt slack-sdk pillow +requests From d1a42a799815b57376683aa6245a3b1d6b3b9132 Mon Sep 17 00:00:00 2001 From: nalbam Date: Wed, 5 Jun 2024 12:12:57 +0900 Subject: [PATCH 4/4] rm allowed message --- handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/handler.py b/handler.py index 84ea30a..cc0ee03 100644 --- a/handler.py +++ b/handler.py @@ -346,7 +346,7 @@ def handle_mention(body: dict, say: Say): allowed_channel_ids = ALLOWED_CHANNEL_IDS.split(",") if channel not in allowed_channel_ids: - say("Sorry, I'm not allowed to respond in this channel.") + # say("Sorry, I'm not allowed to respond in this channel.") return thread_ts = event["thread_ts"] if "thread_ts" in event else event["ts"]