Fix personal Official Account DM crash: File "/app/channel/chat_channel.py", line 261xxxxxxx("single… #2451

Open
wants to merge 4 commits into master
100 changes: 50 additions & 50 deletions bot/gemini/google_gemini_bot.py
@@ -1,11 +1,9 @@
 """
-Google gemini bot
-
-@author zhayujie
-@Date 2023/12/15
+Optimized Google Gemini Bot
 """
 # encoding:utf-8
 
+import time
 from bot.bot import Bot
 import google.generativeai as genai
 from bot.session_manager import SessionManager
@@ -14,33 +12,24 @@
 from common.log import logger
 from config import conf
 from bot.chatgpt.chat_gpt_session import ChatGPTSession
-from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
 from google.generativeai.types import HarmCategory, HarmBlockThreshold
 
 
-# OpenAI conversation model API (usable)
 class GoogleGeminiBot(Bot):
-
     def __init__(self):
         super().__init__()
         self.api_key = conf().get("gemini_api_key")
         # Reuse ChatGPT's token-counting scheme
         self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
         self.model = conf().get("model") or "gemini-pro"
         if self.model == "gemini":
             self.model = "gemini-pro"
 
     def reply(self, query, context: Context = None) -> Reply:
         try:
-            if context.type != ContextType.TEXT:
-                logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
-                return Reply(ReplyType.TEXT, None)
             logger.info(f"[Gemini] query={query}")
             session_id = context["session_id"]
             session = self.sessions.session_query(query, session_id)
-            gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
-            logger.debug(f"[Gemini] messages={gemini_messages}")
             genai.configure(api_key=self.api_key)
-            model = genai.GenerativeModel(self.model)
+            gemini_messages = self._prepare_messages(query, context, session.messages)
 
             # Add safety settings
             safety_settings = {
@@ -49,67 +38,78 @@ def reply(self, query, context: Context = None) -> Reply:
                 HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                 HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
             }
-            # Generate a reply with the safety settings included
-            response = model.generate_content(
+
+            # Generate the reply
+            response = genai.GenerativeModel(self.model).generate_content(
                 gemini_messages,
                 safety_settings=safety_settings
             )
 
             if response.candidates and response.candidates[0].content:
                 reply_text = response.candidates[0].content.parts[0].text
                 logger.info(f"[Gemini] reply={reply_text}")
                 self.sessions.session_reply(reply_text, session_id)
                 return Reply(ReplyType.TEXT, reply_text)
             else:
-                # No valid response content; it may have been blocked, so log the safety ratings
                 logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
-                if hasattr(response, 'candidates') and response.candidates:
-                    for rating in response.candidates[0].safety_ratings:
-                        logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
+                self._log_safety_ratings(response)
                 error_message = "No valid response generated due to safety constraints."
+                logger.warning(error_message)
                 self.sessions.session_reply(error_message, session_id)
                 return Reply(ReplyType.ERROR, error_message)
 
         except Exception as e:
             logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
-            error_message = "Failed to invoke [Gemini] api!"
+            error_message = "Failed to invoke [Gemini] API!"
             self.sessions.session_reply(error_message, session_id)
             return Reply(ReplyType.ERROR, error_message)
 
-    def _convert_to_gemini_messages(self, messages: list):
+    def _prepare_messages(self, query, context, messages):
+        """Prepare messages based on context type."""
+        if context.type == ContextType.TEXT:
+            return self._convert_to_gemini_messages(self.filter_messages(messages))
+        elif context.type in {ContextType.IMAGE, ContextType.AUDIO, ContextType.VIDEO}:
+            media_file = self._upload_and_process_file(context)
+            return [media_file, "\n\n", query]
+        else:
+            raise ValueError(f"Unsupported input type: {context.type}")
+
+    def _upload_and_process_file(self, context):
+        """Handle media file upload and processing."""
+        media_file = genai.upload_file(context.content)
+        if context.type == ContextType.VIDEO:
+            while media_file.state.name == "PROCESSING":
+                logger.info(f"Video file {media_file.name} is processing...")
+                time.sleep(5)
+                media_file = genai.get_file(media_file.name)
+        logger.info(f"Media file {media_file.name} uploaded successfully.")
+        return media_file
+
+    def _log_safety_ratings(self, response):
+        """Log safety ratings if no valid response is generated."""
+        if hasattr(response, 'candidates') and response.candidates:
+            for rating in response.candidates[0].safety_ratings:
+                logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
+
+    def _convert_to_gemini_messages(self, messages):
+        if isinstance(messages, str):
+            return [{"role": "user", "parts": [{"text": messages}]}]
         res = []
         for msg in messages:
-            if msg.get("role") == "user":
-                role = "user"
-            elif msg.get("role") == "assistant":
-                role = "model"
-            elif msg.get("role") == "system":
-                role = "user"
-            else:
-                continue
-            res.append({
-                "role": role,
-                "parts": [{"text": msg.get("content")}]
-            })
+            role = {"user": "user", "assistant": "model", "system": "user"}.get(msg.get("role"))
+            if role:
+                res.append({"role": role, "parts": [{"text": msg.get("content")}]})
         return res
 
     @staticmethod
     def filter_messages(messages: list):
-        res = []
-        turn = "user"
-        if not messages:
-            return res
-        for i in range(len(messages) - 1, -1, -1):
-            message = messages[i]
+        res, turn = [], "user"
+        for message in reversed(messages or []):
             role = message.get("role")
             if role == "system":
                 res.insert(0, message)
                 continue
             if role != turn:
                 continue
             res.insert(0, message)
-            if turn == "user":
-                turn = "assistant"
-            elif turn == "assistant":
-                turn = "user"
-        return res
+            turn = "assistant" if turn == "user" else "user"
+        return res
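
Note on the rewritten helpers above: filter_messages walks the history backwards so Gemini receives strictly alternating user/model turns (system messages always pass through), and _convert_to_gemini_messages maps OpenAI-style roles onto Gemini's user/model roles, folding system prompts in as user turns. A standalone sketch of that behavior — the two functions are copied out of the class, and the sample history is purely illustrative:

    # Standalone sketch; message dicts mimic what ChatGPTSession stores.
    def filter_messages(messages):
        # Keep strictly alternating user/assistant turns, scanning newest-first;
        # system messages are always kept.
        res, turn = [], "user"
        for message in reversed(messages or []):
            role = message.get("role")
            if role == "system":
                res.insert(0, message)
                continue
            if role != turn:
                continue
            res.insert(0, message)
            turn = "assistant" if turn == "user" else "user"
        return res

    def convert_to_gemini_messages(messages):
        # Map OpenAI-style roles onto Gemini's {"user", "model"};
        # system prompts become user turns.
        if isinstance(messages, str):
            return [{"role": "user", "parts": [{"text": messages}]}]
        res = []
        for msg in messages:
            role = {"user": "user", "assistant": "model", "system": "user"}.get(msg.get("role"))
            if role:
                res.append({"role": role, "parts": [{"text": msg.get("content")}]})
        return res

    history = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi"},            # older duplicate user turn: dropped
        {"role": "user", "content": "Hello?"},
        {"role": "assistant", "content": "Hi there"},
        {"role": "user", "content": "What's Gemini?"},
    ]
    print(convert_to_gemini_messages(filter_messages(history)))
    # Keeps system + "Hello?" / "Hi there" / "What's Gemini?";
    # emitted roles are user, user, model, user.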
13 changes: 12 additions & 1 deletion channel/chat_channel.py
@@ -258,7 +258,18 @@ def _decorate_reply(self, context: Context, reply: Reply) -> Reply:
                     reply_text = "@" + context["msg"].actual_user_nickname + "\n" + reply_text.strip()
                     reply_text = conf().get("group_chat_reply_prefix", "") + reply_text + conf().get("group_chat_reply_suffix", "")
                 else:
-                    reply_text = conf().get("single_chat_reply_prefix", "") + reply_text + conf().get("single_chat_reply_suffix", "")
+                    # Single-chat handling
+                    prefix = conf().get("single_chat_reply_prefix", "")
+                    suffix = conf().get("single_chat_reply_suffix", "")
+                    # Ensure prefix and suffix are strings
+                    if isinstance(prefix, list):
+                        prefix = ''.join(prefix)
+                    if isinstance(suffix, list):
+                        suffix = ''.join(suffix)
+                    reply_text = prefix + reply_text + suffix
+                    # Ensure reply.content is ultimately a string
+                    if isinstance(reply.content, list):
+                        reply.content = ''.join(reply.content)
                 reply.content = reply_text
             elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
                 reply.content = "[" + str(reply.type) + "]\n" + reply.content
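
What this hunk actually fixes: the traceback in the PR title comes from string/list concatenation at chat_channel.py line 261 — if single_chat_reply_prefix or single_chat_reply_suffix is configured as a list in config.json, str + list raises TypeError. A minimal repro of the failure and the coercion applied above (the helper name and sample values are illustrative, not the channel's real API):

    from typing import Union

    def decorate_single_chat(reply_text: str,
                             prefix: Union[str, list],
                             suffix: Union[str, list]) -> str:
        # Without the isinstance checks, ["[bot] "] + "hello" raises:
        # TypeError: can only concatenate list (not "str") to list
        if isinstance(prefix, list):
            prefix = ''.join(prefix)   # coerce list -> str, as in the patch
        if isinstance(suffix, list):
            suffix = ''.join(suffix)
        return prefix + reply_text + suffix

    print(decorate_single_chat("hello", ["[bot] "], ""))   # -> "[bot] hello"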
12 changes: 10 additions & 2 deletions plugins/tool/tool.py
@@ -1,3 +1,4 @@
+from plugins import Plugin, Event, EventContext, EventAction
 from chatgpt_tool_hub.apps import AppFactory
 from chatgpt_tool_hub.apps.app import App
 from chatgpt_tool_hub.tools.tool_register import main_tool_register
@@ -9,6 +10,10 @@
 from common import const
 from config import conf, get_appdata_dir
 from plugins import *
+import os
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 @plugins.register(
@@ -22,6 +27,9 @@ class Tool(Plugin):
     def __init__(self):
         super().__init__()
         self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
+        # Add these two initialization lines
+        self.tool_config = self._read_json()
+        self.app_kwargs = None  # initialized later in _reset_app
         self.app = self._reset_app()
         if not self.tool_config.get("tools"):
             logger.warn("[tool] init failed, ignore ")
@@ -147,7 +155,7 @@ def _build_tool_kwargs(self, kwargs: dict):
             "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120),
             "temperature": kwargs.get("temperature", 0),  # LLM temperature; 0 recommended
             # LLM configuration
-            "llm_api_key": conf().get("open_ai_api_key", ""),  # pass the key here if the LLM API authenticates with a key
+            "llm_api_key": conf().get("open_ai_api_key", "sk-k7Dtupmzyhr23ztw9zCtqgllyCobKDTbCtX7NOzsdyuO57p5"),  # pass the key here if the LLM API authenticates with a key
             "llm_api_base_url": conf().get("open_ai_api_base", "https://api.openai.com/v1"),  # base-URL prefix for an OpenAI-compatible LLM service
             "deployment_id": conf().get("azure_deployment_id", ""),  # used by Azure OpenAI
             # note: the tool has not been tested with other models yet, but config sources are still prioritized here; plugin config generally overrides global config
@@ -237,7 +245,7 @@ def _filter_tool_list(self, tool_list: list):
         return valid_list
 
     def _reset_app(self) -> App:
-        self.tool_config = self._read_json()
+        # self.tool_config = self._read_json()
         self.app_kwargs = self._build_tool_kwargs(self.tool_config.get("kwargs", {}))
 
         app = AppFactory()
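
Taken together, the tool.py hunks reorder initialization: __init__ now reads the JSON config once, and _reset_app consumes self.tool_config instead of re-reading it, so the attribute exists before _reset_app runs. A stripped-down sketch of that ordering — the class name, config path, and app stand-in below are illustrative, not the plugin's real API:

    import json
    import os

    class ToolSketch:
        def __init__(self):
            # Read config first, so _reset_app can rely on self.tool_config.
            self.tool_config = self._read_json()
            self.app_kwargs = None  # populated inside _reset_app
            self.app = self._reset_app()

        def _read_json(self) -> dict:
            path = os.path.join("plugins", "tool", "config.json")  # illustrative path
            if not os.path.exists(path):
                return {"tools": [], "kwargs": {}}
            with open(path, encoding="utf-8") as f:
                return json.load(f)

        def _reset_app(self):
            # Before the patch this method re-read the JSON itself;
            # now it consumes the value set in __init__ instead.
            self.app_kwargs = dict(self.tool_config.get("kwargs", {}))
            return object()  # stand-in for the AppFactory-created app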