Skip to content

Commit

Permalink
1
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Nov 9, 2024
1 parent 4216a85 commit 47637b5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
6 changes: 3 additions & 3 deletions minion/models/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ class Plan(BaseModel):
)

class CheckResult(BaseModel):
feedback: str
correct: bool
score: float
feedback: str = ""
correct: bool = False
score: float = 0.0
8 changes: 6 additions & 2 deletions minion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from nltk.corpus import wordnet
from PIL import Image

from minion.utils.custom_decoder import CustomDecoder
from minion.utils.sanitize import sanitize


Expand Down Expand Up @@ -78,7 +79,9 @@ def extract_json(text: str) -> Union[str, dict]:

try:
# 尝试解析 JSON
return json.dumps(json.loads(text))
dict = CustomDecoder(strict=False).decode(text)
return json.dumps(dict)
#return json.dumps(json.loads(text))
except json.JSONDecodeError:
# 如果解析失败,尝试在文本中查找 JSON 对象
start_brace = text.find('{')
Expand All @@ -87,7 +90,8 @@ def extract_json(text: str) -> Union[str, dict]:
if start_brace != -1 and end_brace != -1:
try:
json_str = text[start_brace:end_brace + 1]
return json.dumps(json.loads(json_str))
dict = CustomDecoder(strict=False).decode(json_str)
return json.dumps(dict)
except json.JSONDecodeError:
pass

Expand Down
23 changes: 23 additions & 0 deletions tests/actions/test_action_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,29 @@ def test_normalize_response_json_string2(llm_action_node):
assert "score" in parsed_result
assert parsed_result["correct"] is True
assert parsed_result["score"] == 0.9
def test_normalize_response_json_string3(llm_action_node):
# 测试JSON字符串输入
json_input = r'''
```json
{
"feedback": "a\(\)",
"correct": true,
"score": 0.9
}
```
'''

result = llm_action_node.normalize_response(json_input)
# 验证返回的是提取并格式化后的JSON字符串
assert isinstance(result, str)
# 确保可以被解析回JSON对象
print(result)
parsed_result = json.loads(result)
assert "feedback" in parsed_result
assert "correct" in parsed_result
assert "score" in parsed_result
assert parsed_result["correct"] is True
assert parsed_result["score"] == 0.9

def test_normalize_response_dict_with_answer(llm_action_node):
# 测试包含answer字段的字典
Expand Down

0 comments on commit 47637b5

Please sign in to comment.