Commit 12b3156

femto committed Nov 8, 2024
1 parent 729a63d commit 12b3156

Showing 2 changed files with 18 additions and 24 deletions.
35 changes: 14 additions & 21 deletions minion/actions/action_node.py
@@ -39,16 +39,8 @@ async def execute(self, messages: List[Message], **kwargs) -> Any:
 
         return response
 
-    def normalize_response(self, response: Dict[Any, Any] | str) -> Dict[str, str]:
-        """
-        Convert a complex JSON schema response into the simple answer format
-
-        Args:
-            response: the response dict or string returned by the LLM
-        Returns:
-            the normalized answer-format dict
-        """
+    def normalize_response(self, response: Dict[Any, Any] | str, is_answer_format = False) -> Dict[str, str]:
         # Initialize the response_is_str flag
         response_is_str = isinstance(response, str)

@@ -62,17 +54,18 @@ def normalize_response(self, response: Dict[Any, Any] | str) -> Dict[str, str]:
             return response
 
         # If the response is already in the simple format
-        if "answer" in response:
-            if response_is_str:
-                return response_str
-            return response
-
-        # If the response is in schema format
-        if "properties" in response and "answer" in response["properties"]:
-            answer_value = response["properties"]["answer"].get("default", "")
-            if response_is_str:
-                return json.dumps({"answer": answer_value})
-            return {"answer": answer_value}
+        if is_answer_format:
+            if "answer" in response:
+                if response_is_str:
+                    return response_str
+                return response
+
+            # If the response is in schema format
+            if "properties" in response and "answer" in response["properties"]:
+                answer_value = response["properties"]["answer"].get("default", "")
+                if response_is_str:
+                    return json.dumps({"answer": answer_value})
+                return {"answer": answer_value}
 
         # Any other format: return an empty answer
         if response_is_str:
@@ -84,7 +77,7 @@ def normalize_response(self, response: Dict[Any, Any] | str) -> Dict[str, str]:
     #     reraise=True
     # )
     async def execute_answer(self, messages, **kwargs):
-        result = await self.execute(messages, response_format=Answer, output_raw_parser=self.normalize_response, **kwargs)
+        result = await self.execute(messages, response_format=Answer, output_raw_parser=lambda x: self.normalize_response(x, is_answer_format=True), **kwargs)
         return result.answer


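In short, normalize_response now collapses a JSON-schema-shaped response into the simple {"answer": ...} form only when the caller opts in via is_answer_format, and execute_answer opts in by binding the flag through a lambda (functools.partial(self.normalize_response, is_answer_format=True) would work equally well). Below is an illustrative, dict-only sketch of the gated path; it is not the repository's exact code, which also accepts string responses handled in context lines not shown in this diff.

from typing import Any, Dict


def normalize_answer(response: Dict[Any, Any], is_answer_format: bool = False) -> Dict[str, str]:
    if is_answer_format:
        # Already in the simple format: pass it through unchanged.
        if "answer" in response:
            return response
        # Schema format: pull the default value out of properties.answer.
        if "properties" in response and "answer" in response["properties"]:
            return {"answer": response["properties"]["answer"].get("default", "")}
    # Anything else falls back to an empty answer in this sketch.
    return {"answer": ""}


schema_response = {"properties": {"answer": {"type": "string", "default": "42"}}}
print(normalize_answer(schema_response, is_answer_format=True))            # {'answer': '42'}
print(normalize_answer({"answer": "already simple"}, is_answer_format=True))  # {'answer': 'already simple'}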
7 changes: 4 additions & 3 deletions minion/providers/base_llm.py
@@ -12,15 +12,16 @@ class BaseLLM(ABC):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.cost_manager = CostManager()
-        self._setup()
+        self._setup_retry_config()
         self.generate = self.retry_decorator(self.generate)
         self.generate_stream = self.retry_decorator(self.generate_stream)
+        self._setup()
+
 
     @abstractmethod
     def _setup(self) -> None:
         """Initialize the concrete LLM client"""
-        self._setup_retry_config()
-        #pass
+        pass
 
     def _setup_retry_config(self):
         from tenacity import retry_if_exception_type
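The base_llm.py change reorders initialization: the retry configuration is built and generate/generate_stream are wrapped before the provider-specific _setup() runs, so a subclass's _setup() only creates its client and no longer calls _setup_retry_config() itself. The following self-contained sketch shows that ordering; the toy _with_retry loop stands in for the real tenacity-based retry_decorator, and the class names (BaseLLMSketch, DummyLLM) are hypothetical.

from abc import ABC, abstractmethod


class BaseLLMSketch(ABC):
    def __init__(self, config):
        self.config = config
        self._setup_retry_config()                       # 1. build the retry policy first
        self.generate = self._with_retry(self.generate)  # 2. wrap the public entry point (the real class also wraps generate_stream)
        self._setup()                                    # 3. let the provider create its client last

    def _setup_retry_config(self):
        self._max_attempts = 3  # stand-in for the tenacity-based configuration

    def _with_retry(self, fn):
        def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(self._max_attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception as exc:  # the real code retries only specific API errors
                    last_exc = exc
            raise last_exc
        return wrapper

    @abstractmethod
    def _setup(self) -> None:
        """Create the concrete client; retry is already configured by __init__."""


class DummyLLM(BaseLLMSketch):
    def _setup(self) -> None:
        self.client = "fake-client"  # no _setup_retry_config() call needed here anymore

    def generate(self, prompt: str) -> str:
        return f"echo: {prompt}"


print(DummyLLM(config=None).generate("hi"))  # echo: hi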
