diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index d023492..7edf2fc 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -14,6 +14,7 @@ env: IMAGE_MODEL_ID: ${{ secrets.IMAGE_MODEL_ID }} SYSTEM_MESSAGE: ${{ vars.SYSTEM_MESSAGE }} ALLOWED_CHANNEL_IDS: ${{ secrets.ALLOWED_CHANNEL_IDS }} + ENABLE_IMAGE: ${{ secrets.ENABLE_IMAGE }} jobs: deploy: @@ -52,6 +53,7 @@ jobs: echo "IMAGE_MODEL_ID=${IMAGE_MODEL_ID}" >> .env echo "SYSTEM_MESSAGE=${SYSTEM_MESSAGE}" >> .env echo "ALLOWED_CHANNEL_IDS=${ALLOWED_CHANNEL_IDS}" >> .env + echo "ENABLE_IMAGE=${ENABLE_IMAGE}" >> .env - name: Deploy to AWS Lambda 🚀 env: diff --git a/bedrock/invoke_claude_3.py b/bedrock/invoke_claude_3.py index 5737106..95f50f6 100644 --- a/bedrock/invoke_claude_3.py +++ b/bedrock/invoke_claude_3.py @@ -58,11 +58,11 @@ def invoke_claude_3(prompt): ) # Process and print the response - result = json.loads(response.get("body").read()) + body = json.loads(response.get("body").read()) - # print("response: {}".format(result)) + # print("response: {}".format(body)) - content = result.get("content", []) + content = body.get("content", []) for output in content: print(output["text"]) diff --git a/bedrock/invoke_claude_3_image.py b/bedrock/invoke_claude_3_image.py index f202050..d728f66 100644 --- a/bedrock/invoke_claude_3_image.py +++ b/bedrock/invoke_claude_3_image.py @@ -73,11 +73,11 @@ def invoke_claude_3(prompt): ) # Process and print the response - result = json.loads(response.get("body").read()) + body = json.loads(response.get("body").read()) - # print("response: {}".format(result)) + # print("response: {}".format(body)) - content = result.get("content", []) + content = body.get("content", []) for output in content: print(output["text"]) diff --git a/bedrock/invoke_stable_diffusion.py b/bedrock/invoke_stable_diffusion.py index 1f57a53..5de37c3 100644 --- a/bedrock/invoke_stable_diffusion.py +++ b/bedrock/invoke_stable_diffusion.py @@ -56,8 +56,12 @@ def invoke_stable_diffusion(prompt, seed=0, style_preset="photographic"): body=json.dumps(body), ) - response_body = json.loads(response["body"].read()) - base64_image = response_body.get("artifacts")[0].get("base64") + body = json.loads(response["body"].read()) + + # body["artifacts"][0]["base64"] = None + # print("response: {}".format(body)) + + base64_image = body.get("artifacts")[0].get("base64") base64_bytes = base64_image.encode("ascii") image_bytes = base64.b64decode(base64_bytes) diff --git a/handler.py b/handler.py index cc0ee03..5f06562 100644 --- a/handler.py +++ b/handler.py @@ -7,6 +7,9 @@ import time import base64 import requests +import io + +from PIL import Image from slack_bolt import App, Say from slack_bolt.adapter.aws_lambda import SlackRequestHandler @@ -30,6 +33,8 @@ # Set up the allowed channel ID ALLOWED_CHANNEL_IDS = os.environ.get("ALLOWED_CHANNEL_IDS", "AAA,CCC,EEE,GGG") +ENABLE_IMAGE = os.environ.get("ENABLE_IMAGE", "False") + # Set up System messages SYSTEM_MESSAGE = os.environ.get("SYSTEM_MESSAGE", "None") @@ -91,7 +96,7 @@ def chat_update(channel, ts, message, blocks=None): app.client.chat_update(channel=channel, ts=ts, text=message, blocks=blocks) -def invoke_claude_3(messages): +def invoke_claude_3(content): """ Invokes Anthropic Claude 3 Sonnet to run an inference using the input provided in the request body. @@ -100,13 +105,16 @@ def invoke_claude_3(messages): :return: Inference response from the model. """ - text = "" - try: body = { "anthropic_version": ANTHROPIC_VERSION, "max_tokens": ANTHROPIC_TOKENS, - "messages": messages, + "messages": [ + { + "role": "user", + "content": content, + }, + ], } if SYSTEM_MESSAGE != "None": @@ -118,13 +126,13 @@ def invoke_claude_3(messages): ) # Process and print the response - result = json.loads(response.get("body").read()) + body = json.loads(response.get("body").read()) - print("response: {}".format(result)) + print("response: {}".format(body)) - content = result.get("content", []) + result = body.get("content", []) - for output in content: + for output in result: text = output["text"] return text @@ -132,16 +140,57 @@ def invoke_claude_3(messages): except Exception as e: print("Error: {}".format(e)) + return None + + +def invoke_stable_diffusion(prompt, seed=0, style_preset="photographic"): + """ + Invokes the Stability.ai Stable Diffusion XL model to create an image using + the input provided in the request body. + + :param prompt: The prompt that you want Stable Diffusion to use for image generation. + :param seed: Random noise seed (omit this option or use 0 for a random seed) + :param style_preset: Pass in a style preset to guide the image model towards + a particular style. + :return: Base64-encoded inference response from the model. + """ + + try: + # The different model providers have individual request and response formats. + # For the format, ranges, and available style_presets of Stable Diffusion models refer to: + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-stability-diffusion.html + + body = { + "text_prompts": [{"text": prompt}], + "seed": seed, + "cfg_scale": 10, + "steps": 30, + "samples": 1, + } + + if style_preset: + body["style_preset"] = style_preset + + response = bedrock.invoke_model( + modelId=IMAGE_MODEL_ID, + body=json.dumps(body), + ) + + body = json.loads(response["body"].read()) + + base64_image = body.get("artifacts")[0].get("base64") + base64_bytes = base64_image.encode("ascii") + image_bytes = base64.b64decode(base64_bytes) -# Reply to the message -def reply_text(messages, channel, ts, user): - message = invoke_claude_3(messages) + image = Image.open(io.BytesIO(image_bytes)) + image.show() - message = message.replace("**", "*") + return image_bytes - chat_update(channel, ts, message) + except Exception as e: + print("Error: {}".format(e)) - return message + return None # Get thread messages using conversations.replies API method @@ -206,8 +255,10 @@ def conversation(say: Say, thread_ts, content, channel, user, client_msg_id): prompt = content[0]["text"] type = "text" - # if "그려줘" in prompt: - # type = "image" + if ENABLE_IMAGE == "True" and "그려줘" in prompt: + type = "image" + + print("conversation: {}".format(type)) prompts = [] @@ -225,52 +276,60 @@ def conversation(say: Say, thread_ts, content, channel, user, client_msg_id): content[0]["text"] = "Describe the image in great detail as if viewing a photo." - messages = [] - messages.append( - { - "role": "user", - "content": content, - }, - ) - # Send the prompt to Bedrock - message = invoke_claude_3(messages) + message = invoke_claude_3(content) prompts.append(message) - # Send the prompt to Bedrock if prompt: prompts.append(prompt) - # Send the prompt to Bedrock - try: - chat_update(channel, latest_ts, "응답 기다리는 중... " + BOT_CURSOR) + if type == "image": + chat_update(channel, latest_ts, "이미지 생성 준비 중... " + BOT_CURSOR) - messages = [] - messages.append( - { - "role": "user", - "content": [ - { - "type": "text", - "text": "\n\n\n".join(prompts), - } - ], - }, + prompts.append( + "Convert the above sentence into a command for stable-diffusion to generate an image within 1000 characters. Just give me a prompt." ) - print("conversation: {}".format(messages)) + prompt = "\n\n\n".join(prompts) + + content = [] + content.append({"type": "text", "text": prompt}) # Send the prompt to Bedrock - message = reply_text(messages, channel, latest_ts, user) + message = invoke_claude_3(content) - print("conversation: {}".format(message)) + chat_update(channel, latest_ts, "이미지 그리는 중... " + BOT_CURSOR) - except Exception as e: - print("conversation: Error handling message: {}".format(e)) + image = invoke_stable_diffusion(message) + + # Update the message in Slack + chat_update(channel, latest_ts, message) + + if image: + # Send the image to Slack + app.client.files_upload_v2( + channels=channel, + file=io.BytesIO(image), + title="Generated Image", + filename="image.jpg", + initial_comment="Here is the generated image.", + thread_ts=latest_ts, + ) + + else: + chat_update(channel, latest_ts, "응답 기다리는 중... " + BOT_CURSOR) + + prompt = "\n\n\n".join(prompts) + + content[0]["text"] = prompt + + # Send the prompt to Bedrock + message = invoke_claude_3(content) - message = f"```{e}```" + message = message.replace("**", "*") + # Update the message in Slack chat_update(channel, latest_ts, message)