diff --git a/examples/smart_minion/gsm8k/evalute_gsm8k_re2.py b/examples/smart_minion/gsm8k/evalute_gsm8k_re2.py index c9e808e1..8f3db3f2 100644 --- a/examples/smart_minion/gsm8k/evalute_gsm8k_re2.py +++ b/examples/smart_minion/gsm8k/evalute_gsm8k_re2.py @@ -119,7 +119,7 @@ async def save_run_info(filename, last_processed_id): last_processed_id = last_processed_item["item_id"] tasks = [] # Reset tasks after processing pbar.set_postfix({"Correct": correct, "count": count}) - pbar.update(6) + pbar.update(concurrency_count) # Save running information after each batch await save_run_info(filename=run_filename, last_processed_id=last_processed_id) diff --git a/examples/smart_minion/human_eval/evalute_human_eval.py b/examples/smart_minion/human_eval/evalute_human_eval.py index af083c0e..cc372f29 100644 --- a/examples/smart_minion/human_eval/evalute_human_eval.py +++ b/examples/smart_minion/human_eval/evalute_human_eval.py @@ -271,8 +271,10 @@ async def solve_question(question, route=None): # print(obs) return obs +#model = "gpt-4o-mini" +model = "default" -llm = create_llm_provider(config.models.get("default")) +llm = create_llm_provider(config.models.get(model)) cost_manager = CostManager() llm.cost_manager = cost_manager async def main(): @@ -282,7 +284,7 @@ async def main(): # data = await load_data_sample(file_name, samples=1055) correct, count, matched_ids, mismatched_ids = await evaluate_dataset( - data, run_filename="run_human_eval_deepseek.json", continue_process=True, concurrency_count=1 + data, run_filename=f"run_human_eval_{model}_check.json", continue_process=True, concurrency_count=60 ) print(f"Accuracy: {correct/count:.2%}") diff --git a/examples/smart_minion/human_eval/human_eval_config.json b/examples/smart_minion/human_eval/human_eval_config.json index 1b230444..d149b9f3 100644 --- a/examples/smart_minion/human_eval/human_eval_config.json +++ b/examples/smart_minion/human_eval/human_eval_config.json @@ -5,7 +5,7 @@ { "name": "cot", "count": 1, - "check": 
def normalize_response(self, response: Union[str, dict], is_answer_format: bool = False) -> Union[str, dict]:
    """Normalize an LLM response while mirroring the input type on output.

    String input is passed through ``extract_json`` (which strips Markdown
    code fences) and then parsed as JSON. If it cannot be parsed, the
    fence-stripped text is returned as-is. Dict input is used directly.

    When ``is_answer_format`` is true and the payload looks like a JSON
    schema (``{"properties": {"answer": {"default": ...}}}``), it is
    collapsed to the simple ``{"answer": <default>}`` form.

    Args:
        response: Raw response, either text or an already-parsed dict.
        is_answer_format: Whether schema-shaped payloads should be collapsed
            to the simple answer form.

    Returns:
        A JSON string when ``response`` was a string (or the fence-stripped
        text when it held no JSON); otherwise the (possibly collapsed)
        parsed object.
    """
    response_is_str = isinstance(response, str)
    if response_is_str:
        response_str = extract_json(response)
        try:
            response = json.loads(response_str)
        except json.JSONDecodeError:
            # Not JSON at all: hand back the cleaned-up text unchanged.
            return response_str

    if is_answer_format and isinstance(response, dict):
        properties = response.get("properties")
        # Only treat the payload as a schema when both levels are mappings;
        # anything else is passed through instead of raising.
        answer_spec = properties.get("answer") if isinstance(properties, dict) else None
        if isinstance(answer_spec, dict):
            response = {"answer": answer_spec.get("default", "")}

    # Single exit point so that string input always yields a JSON string,
    # including for the collapsed schema form.
    if response_is_str:
        return json.dumps(response)
    return response
def extract_json(text: str) -> str:
    """Extract JSON content from text, tolerating Markdown code fences.

    Leading ``` fences (optionally tagged ``json``, possibly nested) and one
    trailing ``` fence are stripped. The remainder is parsed as JSON; if that
    fails, the span from the first ``{`` to the last ``}`` is tried instead,
    so JSON objects embedded in surrounding prose are still found.

    Args:
        text: Input text that may be raw JSON, fenced JSON, or plain prose.

    Returns:
        The JSON re-serialized with ``json.dumps`` when parsing succeeds,
        otherwise the fence-stripped text. The result is always a string.
    """
    text = text.strip()

    # Peel off opening fences; each pass removes at least three characters,
    # so the loop always terminates.
    while text.startswith("```"):
        text = text[3:]
        first_line, newline, remainder = text.partition("\n")
        if newline and first_line.strip().lower() == "json":
            # Drop only a bare language tag. A first line that already holds
            # JSON (e.g. "```{") must be kept, not discarded.
            text = remainder
        else:
            text = text.strip()

    # Remove the closing fence.
    if text.endswith("```"):
        text = text[:-3].strip()

    # Try the whole text first, then the outermost {...} span.
    candidates = [text]
    start_brace = text.find("{")
    end_brace = text.rfind("}")
    if start_brace != -1 and end_brace > start_brace:
        candidates.append(text[start_brace:end_brace + 1])

    for candidate in candidates:
        try:
            return json.dumps(json.loads(candidate))
        except json.JSONDecodeError:
            continue

    return text
def test_normalize_response_json_string2(llm_action_node):
    # Fenced JSON, padded with whitespace, whose "feedback" text itself
    # embeds a ```python code fence.
    json_input = r'''
    ```json
{
    "feedback": "The provided function implementation generally follows the instructions and addresses the problem context effectively. However, there are a few areas that could be improved for clarity and correctness:\n\n1. **Variable Naming**: The variable `total_needed` is not used in the logic. It might be clearer to directly use `number + need` in the comparison.\n2. **Edge Cases**: The function does not explicitly handle the edge cases where `need` is 0 or `remaining` is 0. While the logic implicitly covers these cases, it would be better to explicitly mention them in the comments or handle them separately.\n3. **Clarity in Logic**: The logic could be slightly simplified by directly comparing `remaining` with `need` without calculating `total_needed` separately.\n\nSuggested Improvement:\n```python\ndef eat(number, need, remaining):\n # Calculate the total number of carrots eaten\n if remaining >= need:\n total_eaten = number + need\n carrots_left = remaining - need\n else:\n total_eaten = number + remaining\n carrots_left = 0\n \n # Return the result\n return [total_eaten, carrots_left]\n```\n\nThis version simplifies the logic and makes it clearer by directly comparing `remaining` with `need`.",
    "correct": true,
    "score": 0.9
}
```
    '''

    result = llm_action_node.normalize_response(json_input)
    # String input must come back as an extracted, re-serialized JSON string.
    assert isinstance(result, str)
    # The string must parse back into an object carrying every field.
    parsed_result = json.loads(result)
    for expected_key in ("feedback", "correct", "score"):
        assert expected_key in parsed_result
    assert parsed_result["correct"] is True
    assert parsed_result["score"] == 0.9


def test_normalize_response_plain_string(llm_action_node):
    # Plain, non-JSON text must be returned unchanged.
    plain_string = "This is a test string"
    normalized = llm_action_node.normalize_response(plain_string)
    assert normalized == plain_string


def test_normalize_response_complex_json(llm_action_node):
    # A dict carrying detailed feedback, a correctness flag and a score
    # must pass through as a dict with its values intact.
    complex_json = {
        "feedback": "The provided function implementation is generally correct...",
        "correct": True,
        "score": 0.9
    }

    result = llm_action_node.normalize_response(complex_json)
    assert isinstance(result, dict)
    expected_types = {"feedback": str, "correct": bool, "score": (int, float)}
    for field_name in expected_types:
        assert field_name in result
    for field_name, field_type in expected_types.items():
        assert isinstance(result[field_name], field_type)
    assert result["correct"] is True
    assert result["score"] == 0.9


def test_normalize_response_nested_json_string(llm_action_node):
    # Fenced JSON string whose content contains a nested code fence.
    nested_json = r'''```json
{
    "feedback": "The provided function implementation generally follows the instructions and addresses the problem context effectively. However, there are a few areas that could be improved for clarity and correctness:\n\n1. **Variable Naming**: The variable `total_needed` is not used in the logic. It might be clearer to directly use `number + need` in the comparison.\n2. **Edge Cases**: The function does not explicitly handle the edge cases where `need` is 0 or `remaining` is 0. While the logic implicitly covers these cases, it would be better to explicitly mention them in the comments or handle them separately.\n3. **Clarity in Logic**: The logic could be slightly simplified by directly comparing `remaining` with `need` without calculating `total_needed` separately.\n\nSuggested Improvement:\n```python\ndef eat(number, need, remaining):\n # Calculate the total number of carrots eaten\n if remaining >= need:\n total_eaten = number + need\n carrots_left = remaining - need\n else:\n total_eaten = number + remaining\n carrots_left = 0\n \n # Return the result\n return [total_eaten, carrots_left]\n```\n\nThis version simplifies the logic and makes it clearer by directly comparing `remaining` with `need`.",
    "correct": true,
    "score": 0.9
}
```'''

    result = llm_action_node.normalize_response(nested_json)
    # Show the normalized string, then check it decodes into an object.
    print(result)
    result = json.loads(result)
    assert isinstance(result, dict)
    for expected_key in ("feedback", "correct", "score"):
        assert expected_key in result
    # The inner code fence must survive extraction untouched.
    assert "```python" in result["feedback"]
    assert result["correct"] is True
    assert result["score"] == 0.9