Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Nov 9, 2024
1 parent ed79af5 commit 480c928
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 43 deletions.
2 changes: 1 addition & 1 deletion examples/smart_minion/gsm8k/evalute_gsm8k_re2.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def save_run_info(filename, last_processed_id):
last_processed_id = last_processed_item["item_id"]
tasks = [] # Reset tasks after processing
pbar.set_postfix({"Correct": correct, "count": count})
pbar.update(6)
pbar.update(concurrency_count)

# Save running information after each batch
await save_run_info(filename=run_filename, last_processed_id=last_processed_id)
Expand Down
6 changes: 4 additions & 2 deletions examples/smart_minion/human_eval/evalute_human_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,10 @@ async def solve_question(question, route=None):
# print(obs)
return obs

#model = "gpt-4o-mini"
model = "default"

llm = create_llm_provider(config.models.get("default"))
llm = create_llm_provider(config.models.get(model))
cost_manager = CostManager()
llm.cost_manager = cost_manager
async def main():
Expand All @@ -282,7 +284,7 @@ async def main():
# data = await load_data_sample(file_name, samples=1055)

correct, count, matched_ids, mismatched_ids = await evaluate_dataset(
data, run_filename="run_human_eval_deepseek.json", continue_process=True, concurrency_count=1
data, run_filename=f"run_human_eval_{model}_check.json", continue_process=True, concurrency_count=60
)

print(f"Accuracy: {correct/count:.2%}")
Expand Down
2 changes: 1 addition & 1 deletion examples/smart_minion/human_eval/human_eval_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
{
"name": "cot",
"count": 1,
"check": false,
"check": true,
"post_processing": "extract_python"
}
],
Expand Down
41 changes: 14 additions & 27 deletions minion/actions/action_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict
from typing import Any, Optional, List, Dict, Union
import json

from tenacity import retry, stop_after_attempt, retry_if_exception_type
Expand Down Expand Up @@ -39,35 +39,22 @@ async def execute(self, messages: List[Message], **kwargs) -> Any:

return response

def normalize_response(self, response: Dict[Any, Any] | str, is_answer_format = False) -> Dict[str, str]:

# 初始化response_is_str标志
response_is_str = isinstance(response, str)

# 如果响应是字符串,尝试解析为JSON
if response_is_str:
def normalize_response(self, response: Union[str, dict], is_answer_format: bool = False) -> Union[str, dict]:
"""规范化响应格式"""
response_is_str = False
if isinstance(response, str):
response_is_str = True
# 使用更新后的 extract_json 函数处理响应
response_str = extract_json(response)
try:
response = json.loads(response_str)
response = json.loads(response_str)
except json.JSONDecodeError:
# 如果解析失败,将字符串作为原样返回
return response

# 如果响应已经是简单格式
if is_answer_format:
if "answer" in response:
if response_is_str:
return response_str
return response

# 如果响应是schema格式
if "properties" in response and "answer" in response["properties"]:
answer_value = response["properties"]["answer"].get("default", "")
if response_is_str:
return json.dumps({"answer": answer_value})
return {"answer": answer_value}

# 如果是其他格式,返回空答案
return response_str

# 处理 schema 格式
if is_answer_format and isinstance(response, dict) and "properties" in response:
if "answer" in response["properties"]:
return {"answer": response["properties"]["answer"].get("default", "")}
if response_is_str:
return json.dumps(response)
return response
Expand Down
55 changes: 44 additions & 11 deletions minion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from difflib import SequenceMatcher
from pathlib import Path
from typing import Optional
from typing import Optional, Union

import aiofiles
from nltk.corpus import wordnet
Expand Down Expand Up @@ -49,16 +49,49 @@ def recursive_replace(obj):
return recursive_replace(config)


def extract_json(text):
# Regular expression pattern to match all content between ```json and ```
pattern = r"```json\s*([\s\S]*?)\s*```"

# Find all matches in the input text
matches = re.findall(pattern, text)

if matches:
return matches[0]
else:
def extract_json(text: str) -> Union[str, dict]:
"""
从文本中提取 JSON 内容,支持处理嵌套的代码块
Args:
text: 输入文本,可能包含 JSON 字符串或被 ``` 包裹的 JSON
Returns:
解析后的 JSON 对象或原始字符串
"""
text = text.strip()

# 处理被多层 ``` 包裹的情况
while text.startswith('```'):
# 移除开头的 ```
text = text[3:]
# 检查是否有语言标识符(如 json)
first_line_end = text.find('\n')
first_line = text[:first_line_end].strip() if first_line_end != -1 else text
if first_line.lower() == 'json' or first_line.startswith('{'):
text = text[first_line_end + 1:] if first_line_end != -1 else text
else:
text = text.strip()

# 移除结尾的 ```
if text.endswith('```'):
text = text[:-3].strip()

try:
# 尝试解析 JSON
return json.dumps(json.loads(text))
except json.JSONDecodeError:
# 如果解析失败,尝试在文本中查找 JSON 对象
start_brace = text.find('{')
end_brace = text.rfind('}')

if start_brace != -1 and end_brace != -1:
try:
json_str = text[start_brace:end_brace + 1]
return json.dumps(json.loads(json_str))
except json.JSONDecodeError:
pass

return text

def extract_last_number(text: str):
Expand Down
68 changes: 67 additions & 1 deletion tests/actions/test_action_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ def test_normalize_response_json_string(llm_action_node):
assert parsed_result["correct"] is True
assert parsed_result["score"] == 1


def test_normalize_response_json_string2(llm_action_node):
# 测试JSON字符串输入
json_input = r'''
```json
{
"feedback": "The provided function implementation generally follows the instructions and addresses the problem context effectively. However, there are a few areas that could be improved for clarity and correctness:\n\n1. **Variable Naming**: The variable `total_needed` is not used in the logic. It might be clearer to directly use `number + need` in the comparison.\n2. **Edge Cases**: The function does not explicitly handle the edge cases where `need` is 0 or `remaining` is 0. While the logic implicitly covers these cases, it would be better to explicitly mention them in the comments or handle them separately.\n3. **Clarity in Logic**: The logic could be slightly simplified by directly comparing `remaining` with `need` without calculating `total_needed` separately.\n\nSuggested Improvement:\n```python\ndef eat(number, need, remaining):\n # Calculate the total number of carrots eaten\n if remaining >= need:\n total_eaten = number + need\n carrots_left = remaining - need\n else:\n total_eaten = number + remaining\n carrots_left = 0\n \n # Return the result\n return [total_eaten, carrots_left]\n```\n\nThis version simplifies the logic and makes it clearer by directly comparing `remaining` with `need`.",
"correct": true,
"score": 0.9
}
```
'''

result = llm_action_node.normalize_response(json_input)
# 验证返回的是提取并格式化后的JSON字符串
assert isinstance(result, str)
# 确保可以被解析回JSON对象
parsed_result = json.loads(result)
assert "feedback" in parsed_result
assert "correct" in parsed_result
assert "score" in parsed_result
assert parsed_result["correct"] is True
assert parsed_result["score"] == 0.9

def test_normalize_response_dict_with_answer(llm_action_node):
# 测试包含answer字段的字典
input_dict = {"answer": "test answer"}
Expand Down Expand Up @@ -83,4 +107,46 @@ def test_normalize_response_plain_string(llm_action_node):
# 测试普通字符串输入
plain_string = "This is a test string"
result = llm_action_node.normalize_response(plain_string)
assert result == plain_string
assert result == plain_string

def test_normalize_response_complex_json(llm_action_node):
# 测试包含详细反馈、正确性和分数的复杂 JSON 响应
complex_json = {
"feedback": "The provided function implementation is generally correct...",
"correct": True,
"score": 0.9
}

result = llm_action_node.normalize_response(complex_json)
assert isinstance(result, dict)
assert "feedback" in result
assert "correct" in result
assert "score" in result
assert isinstance(result["feedback"], str)
assert isinstance(result["correct"], bool)
assert isinstance(result["score"], (int, float))
assert result["correct"] is True
assert result["score"] == 0.9

def test_normalize_response_nested_json_string(llm_action_node):
# 测试嵌套引号的 JSON 字符串输入
nested_json = r'''```json
{
"feedback": "The provided function implementation generally follows the instructions and addresses the problem context effectively. However, there are a few areas that could be improved for clarity and correctness:\n\n1. **Variable Naming**: The variable `total_needed` is not used in the logic. It might be clearer to directly use `number + need` in the comparison.\n2. **Edge Cases**: The function does not explicitly handle the edge cases where `need` is 0 or `remaining` is 0. While the logic implicitly covers these cases, it would be better to explicitly mention them in the comments or handle them separately.\n3. **Clarity in Logic**: The logic could be slightly simplified by directly comparing `remaining` with `need` without calculating `total_needed` separately.\n\nSuggested Improvement:\n```python\ndef eat(number, need, remaining):\n # Calculate the total number of carrots eaten\n if remaining >= need:\n total_eaten = number + need\n carrots_left = remaining - need\n else:\n total_eaten = number + remaining\n carrots_left = 0\n \n # Return the result\n return [total_eaten, carrots_left]\n```\n\nThis version simplifies the logic and makes it clearer by directly comparing `remaining` with `need`.",
"correct": true,
"score": 0.9
}
```'''

result = llm_action_node.normalize_response(nested_json)
# 验证返回的是提取并格式化后的JSON字符串
print(result)
result = json.loads(result)
assert isinstance(result, dict)
assert "feedback" in result
assert "correct" in result
assert "score" in result
# 确保内部的代码块被正确保留
assert "```python" in result["feedback"]
assert result["correct"] is True
assert result["score"] == 0.9

0 comments on commit 480c928

Please sign in to comment.