feat: support function call #3

Merged: 2 commits, Feb 27, 2025
8 changes: 5 additions & 3 deletions app/services/chat/message_converter.py
@@ -37,13 +37,15 @@ def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
parts = []

if isinstance(msg["content"], str):
parts.append({"text": msg["content"]})
# The Gemini API returns a 400 error when a request carries a content field whose text is empty, so check for empty content first
if msg["content"]:
parts.append({"text": msg["content"]})
elif isinstance(msg["content"], list):
for content in msg["content"]:
if isinstance(content, str):
if isinstance(content, str) and content:
parts.append({"text": content})
elif isinstance(content, dict):
if content["type"] == "text":
if content["type"] == "text" and content["text"]:
parts.append({"text": content["text"]})
elif content["type"] == "image_url":
parts.append(_convert_image(content["image_url"]["url"]))
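
To make the new guard concrete, here is a minimal, self-contained sketch; build_parts is a hypothetical helper that mirrors the converter's parts-building logic (image parts omitted for brevity), not part of the codebase:

from typing import Any, Dict, List

def build_parts(content: Any) -> List[Dict[str, Any]]:
    """Mirror of the convert() parts logic with the empty-content guard."""
    parts: List[Dict[str, Any]] = []
    if isinstance(content, str):
        if content:  # skip "": Gemini returns 400 when a text part is empty
            parts.append({"text": content})
    elif isinstance(content, list):
        for item in content:
            if isinstance(item, str) and item:
                parts.append({"text": item})
            elif isinstance(item, dict) and item.get("type") == "text" and item.get("text"):
                parts.append({"text": item["text"]})
    return parts

assert build_parts("") == []                  # empty content is dropped, not sent
assert build_parts("hi") == [{"text": "hi"}]  # non-empty content passes through
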
102 changes: 73 additions & 29 deletions app/services/chat/response_handler.py
@@ -1,7 +1,10 @@
# app/services/chat/response_handler.py

import json
import random
import string
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from typing import Dict, Any, List, Optional
import time
import uuid
from app.core.config import settings
@@ -29,40 +32,38 @@ def handle_response(self, response: Dict[str, Any], model: str, stream: bool = F


def _handle_openai_stream_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=True)
text, tool_calls = _extract_result(response, model, stream=True, gemini_format=False)
if not text and not tool_calls:
delta = {}
else:
delta = {"content": text, "role": "assistant"}
if tool_calls:
delta["tool_calls"] = tool_calls

return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"delta": {"content": text} if text else {},
"finish_reason": finish_reason
}]
"choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
}
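
For reference, a sketch of the chunk this now emits when a functionCall part arrives mid-stream; all values below are illustrative, and the id suffix and timestamp are generated at runtime:

sample_chunk = {
    "id": "chatcmpl-<uuid>",
    "object": "chat.completion.chunk",
    "created": 1740600000,
    "model": "gemini-2.0-flash",  # sample model name
    "choices": [{
        "index": 0,
        "delta": {
            "content": "",  # text is empty when the part carries only a tool call
            "role": "assistant",
            "tool_calls": [{
                "index": 0,
                "id": "call_<32 random chars>",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
            }],
        },
        "finish_reason": None,
    }],
}
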


def _handle_openai_normal_response(response: Dict[str, Any], model: str, finish_reason: str) -> Dict[str, Any]:
text = _extract_text(response, model, stream=False)
text, tool_calls = _extract_result(response, model, stream=False, gemini_format=False)
return {
"id": f"chatcmpl-{uuid.uuid4()}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason
}],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": text, "tool_calls": tool_calls},
"finish_reason": finish_reason,
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}


Expand Down Expand Up @@ -127,8 +128,8 @@ def _handle_openai_normal_image_response(image_str: str,model: str,finish_reason
}


def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) -> str:
text = ""
def _extract_result(response: Dict[str, Any], model: str, stream: bool = False, gemini_format: bool = False) -> tuple[str, List[Dict[str, Any]]]:
text, tool_calls = "", []
if stream:
if response.get("candidates"):
candidate = response["candidates"][0]
Expand Down Expand Up @@ -212,6 +213,7 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
else:
text = ""
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(parts, gemini_format)
else:
if response.get("candidates"):
candidate = response["candidates"][0]
@@ -234,23 +236,65 @@ def _extract_text(response: Dict[str, Any], model: str, stream: bool = False) ->
else:
text = ""
for part in candidate["content"]["parts"]:
text += part["text"]
text += part.get("text", "")
text = _add_search_link_text(model, candidate, text)
tool_calls = _extract_tool_calls(candidate["content"]["parts"], gemini_format)
else:
text = "暂无返回"
return text
return text, tool_calls

def _extract_tool_calls(parts: List[Dict[str, Any]], gemini_format: bool) -> List[Dict[str, Any]]:
"""提取工具调用信息"""
if not parts or not isinstance(parts, list):
return []

letters = string.ascii_lowercase + string.digits

tool_calls = list()
for i in range(len(parts)):
part = parts[i]
if not part or not isinstance(part, dict):
continue

item = part.get("functionCall", {})
if not item or not isinstance(item, dict):
continue

if gemini_format:
tool_calls.append(part)
else:
id = f"call_{''.join(random.sample(letters, 32))}"
name = item.get("name", "")
arguments = json.dumps(item.get("args", None) or {})

tool_calls.append(
{
"index": i,
"id": id,
"type": "function",
"function": {"name": name, "arguments": arguments},
}
)

return tool_calls
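
A quick illustration of _extract_tool_calls on a sample Gemini functionCall part (the tool name and arguments are invented for the example):

sample_parts = [{"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}]
# With gemini_format=False the part becomes an OpenAI-style tool call, roughly:
# [{"index": 0, "id": "call_<32 random chars>", "type": "function",
#   "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"}}]
# With gemini_format=True the raw part is passed through unchanged.
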


def _handle_gemini_stream_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response


def _handle_gemini_normal_response(response: Dict[str, Any], model: str, stream: bool) -> Dict[str, Any]:
text = _extract_text(response, model, stream=stream)
content = {"parts": [{"text": text}], "role": "model"}
text, tool_calls = _extract_result(response, model, stream=stream, gemini_format=True)
if tool_calls:
content = {"parts": tool_calls, "role": "model"}
else:
content = {"parts": [{"text": text}], "role": "model"}
response["candidates"][0]["content"] = content
return response

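
And the Gemini-format side, sketched with a sample response: when a candidate carries functionCall parts, the two handlers above rebuild the candidate's content from those parts alone; otherwise they fall back to a single text part.

sample_response = {"candidates": [{"content": {
    "parts": [{"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}],
    "role": "model",
}}]}
# After _handle_gemini_normal_response(sample_response, model="gemini-2.0-flash", stream=False),
# candidates[0]["content"] is {"parts": [<the functionCall part>], "role": "model"}.
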
6 changes: 6 additions & 0 deletions app/services/gemini_chat_service.py
@@ -31,6 +31,12 @@ def _build_tools(model: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})

if payload and isinstance(payload, dict) and "tools" in payload:
items = payload.get("tools", [])
if items and isinstance(items, list):
tools.extend(items)

return tools


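
A small illustration of the merge (the payload and model name are sample values):

payload = {"tools": [{"functionDeclarations": [{"name": "get_weather"}]}]}
# For a model named "gemini-2.0-flash-search", _build_tools(model, payload) now returns:
# [{"googleSearch": {}}, {"functionDeclarations": [{"name": "get_weather"}]}]
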
20 changes: 20 additions & 0 deletions app/services/openai_chat_service.py
@@ -1,5 +1,6 @@
# app/services/chat_service.py

from copy import deepcopy
import json
from typing import Dict, Any, AsyncGenerator, List, Union
from app.core.logger import get_openai_logger
@@ -39,6 +40,25 @@ def _build_tools(
tools.append({"code_execution": {}})
if model.endswith("-search"):
tools.append({"googleSearch": {}})

# Merge the tools carried in the request into the Gemini tools list
if request.tools:
function_declarations = []
for tool in request.tools:
if not tool or not isinstance(tool, dict):
continue

if tool.get("type", "") == "function" and tool.get("function"):
function = deepcopy(tool.get("function"))
parameters = function.get("parameters", {})
if parameters.get("type") == "object" and not parameters.get("properties", {}):
function.pop("parameters", None)

function_declarations.append(function)

if function_declarations:
tools.append({"functionDeclarations": function_declarations})

return tools


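
To see the conversion end to end, here is a sketch with a sample request.tools entry: a no-argument function whose schema is an empty object loses its parameters field before being forwarded as a functionDeclaration.

request_tools = [{
    "type": "function",
    "function": {
        "name": "ping",  # sample tool name
        "description": "A no-argument example tool",
        "parameters": {"type": "object", "properties": {}},
    },
}]
# _build_tools appends:
# {"functionDeclarations": [{"name": "ping", "description": "A no-argument example tool"}]}
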