Skip to content

Commit

Permalink
xml_simple
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Nov 9, 2024
1 parent af361b2 commit 88cfde6
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 76 deletions.
64 changes: 60 additions & 4 deletions examples/smart_minion/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

async def smart_brain():
# 使用从 minion/__init__.py 导入的 config 对象
llm_config = config.models.get("default")
model = "default"
model = "llama3.2"
llm_config = config.models.get(model)

llm = create_llm_provider(llm_config)

Expand Down Expand Up @@ -92,16 +94,70 @@ async def smart_brain():
# )
# print(obs)

# obs, score, *_ = await brain.step(
# query="\ndef circular_shift(x, shift):\n \"\"\"Circular shift the digits of the integer x, shift the digits right by shift\n and return the result as a string.\n If shift > number of digits, return digits reversed.\n >>> circular_shift(12, 1)\n \"21\"\n >>> circular_shift(12, 2)\n \"12\"\n \"\"\"\n",
# route="cot",
# post_processing="extract_python",
#
# #check_route="doctest"
# )
# print(obs)

obs, score, *_ = await brain.step(
query="\ndef circular_shift(x, shift):\n \"\"\"Circular shift the digits of the integer x, shift the digits right by shift\n and return the result as a string.\n If shift > number of digits, return digits reversed.\n >>> circular_shift(12, 1)\n \"21\"\n >>> circular_shift(12, 2)\n \"12\"\n \"\"\"\n",
query="""extract the feedback from the following:
<root>
<feedback>
The provided solution for the `sort_array` function generally follows the instructions and addresses the problem requirements effectively. However, there are a few areas that could be improved for clarity and robustness.
1. **Edge Case Handling**: The function correctly handles the edge cases where the array is empty or contains only one element. This is a good practice.
2. **Sum Calculation and Sorting Logic**: The logic to determine the sum of the first and last elements and then sorting based on whether the sum is odd or even is correctly implemented. This aligns with the problem's requirements.
3. **Return Value**: The function returns a sorted copy of the array without modifying the original array, which is consistent with the problem's note.
4. **Clarity and Readability**: The code is clear and readable, with appropriate comments explaining the logic. However, the comments could be more detailed to explain the reasoning behind each step.
5. **Potential Improvement**: While the current implementation is correct, it could be slightly optimized by avoiding the need to calculate `first_value` and `last_value` separately. Instead, the sum could be calculated directly within the condition. This would make the code slightly more concise.
Suggested Improvement:
```python
def sort_array(array):
if len(array) <= 1:
return array[:] # Return a copy of the array if it's empty or has one element
sum_first_last = array[0] + array[-1]
if sum_first_last % 2 == 0:
# Sum is even, sort in descending order
return sorted(array, reverse=True)
else:
# Sum is odd, sort in ascending order
return sorted(array)
```
This version maintains the same functionality but is slightly more concise.
</feedback>
<correct>true</correct>
<score>1.0</score>
</root>
""",
route="native",
check=False

# check_route="doctest"
)
print(obs)

obs, score, *_ = await brain.step(
query="\ndef sort_array(array):\n \"\"\"\n Given an array of non-negative integers, return a copy of the given array after sorting,\n you will sort the given array in ascending order if the sum( first index value, last index value) is odd,\n or sort it in descending order if the sum( first index value, last index value) is even.\n\n Note:\n * don't change the given array.\n\n Examples:\n * sort_array([]) => []\n * sort_array([5]) => [5]\n * sort_array([2, 4, 3, 0, 1, 5]) => [0, 1, 2, 3, 4, 5]\n * sort_array([2, 4, 3, 0, 1, 5, 6]) => [6, 5, 4, 3, 2, 1, 0]\n \"\"\"\n",
route="cot",
post_processing="extract_python",

#check_route="doctest"
# check_route="doctest"
)
print(obs)


# obs, score, *_ = await brain.step(
# query="solve \log_{\sqrt{5}}{125\sqrt{5}}",
# route="cot",
Expand Down
179 changes: 154 additions & 25 deletions minion/actions/lmp_action_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Union, List, Optional, Type
from pydantic import BaseModel
import json
import xml.etree.ElementTree as ET
import re

import ell
from tenacity import retry, stop_after_attempt, retry_if_exception_type
Expand All @@ -26,7 +28,7 @@ def ell_call(self, ret):
"""You are a helpful assistant."""
return ret

async def execute(self, messages: Union[str, Message, List[Message]], response_format: Optional[Union[Type[BaseModel], dict]] = None, output_raw_parser=None, **kwargs) -> Any:
async def execute(self, messages: Union[str, Message, List[Message]], response_format: Optional[Union[Type[BaseModel], dict]] = None, output_raw_parser=None, format="json", **kwargs) -> Any:
# 添加 input_parser 处理
if self.input_parser:
messages = self.input_parser(messages)
Expand All @@ -48,44 +50,171 @@ async def execute(self, messages: Union[str, Message, List[Message]], response_f

# 创建示例数据
example = response_format.model_construct()
example_json = example.model_dump_json(indent=4)

if format == "json":
example_str = example.model_dump_json(indent=4)
prompt = (
f"Please provide the response in JSON format as per the following schema:\n"
f"{schema_with_indent}\n\n"
f"Here's an example of the expected format:\n"
f"{example_str}\n\n"
f"Please ensure your response follows this exact schema format."
)
api_params['response_format'] = { "type": "json_object" }
else: # format == "xml" or format == "xml_simple"
example_dict = example.model_dump()
example_xml = self._dict_to_xml_example(example_dict)
prompt = (
f"""Construct an XML response that adheres to the specified schema below.
Schema Structure Example:
{example_xml}
Required JSON Schema Compliance:
{schema_with_indent}
Your response should be:
Well-formed XML: Ensure it follows XML syntax rules.
Schema-compliant: Each element, attribute, and data type must match the JSON schema requirements.
Error-free for Parsing: Escape all special characters and ensure compatibility for JSON conversion.
Provide a final XML structure that aligns seamlessly with both the XML and JSON schema constraints."""
)
api_params['response_format'] = { "type": "text" }

if isinstance(messages, str):
messages = [user(messages)]
elif isinstance(messages, Message):
messages = [messages]

# 添加带有 schema 和示例的提示
prompt = (
f"Please provide the response in JSON format as per the following schema:\n"
f"{schema_with_indent}\n\n"
f"Here's an example of the expected format:\n"
f"{example_json}\n\n"
f"Please ensure your response follows this exact schema format."
)

messages.append(user(content=prompt))

api_params['response_format'] = { "type": "json_object" }
elif isinstance(response_format, dict):
# If response_format is a dictionary, pass it as is
api_params['response_format'] = response_format

# response = self.ell_call(messages, client=self.llm.client_ell, api_params=api_params)
# response = response.text
response = await super().execute(messages, **api_params)

if isinstance(response_format, type) and issubclass(response_format, BaseModel):
if format == "xml" or format == "xml_simple":
# 清理响应中的代码块标记
response = response.strip()
if response.startswith('```xml'):
response = response[6:] # 移除开头的 ```xml
if response.endswith('```'):
response = response[:-3] # 移除结尾的 ```
response = response.strip()

# 确保响应是有效的 XML
if not response.strip().startswith('<?xml'):
response = f'<?xml version="1.0" encoding="UTF-8"?>\n{response}'

# 根据format选择解析方式
if format == "xml_simple":
response = self._simple_xml_to_json(response_format, response)
else:
response = self._xml_to_json(response)
response = self.normalize_response(response)

if original_response_format and isinstance(original_response_format, type) and issubclass(original_response_format, BaseModel):
response = original_response_format.model_validate_json(response)
# 判断 response pydantic model 是否只有一个 field
# if len(original_response_format.model_fields) == 1:
# field_name = list(original_response_format.model_fields.keys())[0]
# response = getattr(response, field_name)
#判断 response pydantic model 是否只有一个 field
# if original_response_format == Answer:
# response = response.answer

if self.output_parser:
response = self.output_parser(response)
return response

def _dict_to_xml_example(self, data, root_name="root"):
"""Helper method to convert a dictionary to XML example string."""
if isinstance(data, dict):
elements = []
for key, value in data.items():
elements.append(f"<{key}>{self._dict_to_xml_example(value, key)}</{key}>")
if root_name == "root":
return f'<?xml version="1.0" encoding="UTF-8"?>\n<{root_name}>\n {" ".join(elements)}\n</{root_name}>'
return "\n".join(elements)
elif isinstance(data, list):
elements = []
item_name = root_name[:-1] if root_name.endswith('s') else 'item'
for item in data:
elements.append(f"<{item_name}>{self._dict_to_xml_example(item, item_name)}</{item_name}>")
return "\n".join(elements)
else:
return str(data) if data is not None else ""

def _xml_to_json(self, xml_str: str) -> str:
"""Convert XML string to JSON string compatible with Pydantic model."""
# 移除 XML 声明
xml_str = re.sub(r'<\?xml[^>]+\?>', '', xml_str).strip()

# 解析 XML
root = ET.fromstring(xml_str)

# 如果根元素是 'root',我们需要提取其子元素
if root.tag == 'root':
result = {}
for child in root:
result[child.tag] = self._process_xml_element(child)
else:
result = self._process_xml_element(root)

return json.dumps(result)

def _process_xml_element(self, element: ET.Element) -> Any:
"""递归处理 XML 元素"""
# 如果元素没有子元素且有文本
if len(element) == 0:
text = element.text.strip() if element.text else ""
# 尝试转换布尔值
if text.lower() == 'true':
return True
elif text.lower() == 'false':
return False
# 尝试转换数字
try:
if '.' in text:
return float(text)
return int(text)
except ValueError:
return text

# 如果元素有子元素
result = {}
for child in element:
# 处理列表情况(相同标签的多个元素)
if child.tag in result:
if not isinstance(result[child.tag], list):
result[child.tag] = [result[child.tag]]
result[child.tag].append(self._process_xml_element(child))
else:
result[child.tag] = self._process_xml_element(child)

return result

def _simple_xml_to_json(self, response_format, xml_str: str) -> str:
"""使用简单的正则表达式解析XML"""
# 移除XML声明
xml_str = re.sub(r'<\?xml[^>]+\?>', '', xml_str).strip()

# 获取schema中定义的所有字段
schema = response_format.model_json_schema()
fields = schema.get('properties', {}).keys()

result = {}
for field in fields:
# 使用非贪婪匹配来提取标签内容
pattern = f"<{field}>(.*?)</{field}>"
match = re.search(pattern, xml_str, re.DOTALL)
if match:
value = match.group(1).strip()
# 尝试转换布尔值
if value.lower() == 'true':
result[field] = True
elif value.lower() == 'false':
result[field] = False
else:
# 尝试转换数字
try:
if '.' in value:
result[field] = float(value)
else:
result[field] = int(value)
except ValueError:
result[field] = value

return json.dumps(result)
15 changes: 9 additions & 6 deletions minion/main/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,17 @@ async def choose_mind(self, input):
# Create the filled template
filled_template = mind_template.render(minds=self.minds.values(), input=input)

result = await self.lmp_action_node.execute_answer(filled_template)
try:
result = await self.lmp_action_node.execute_answer(filled_template)

# Ensure the result is a valid mind ID
if result not in self.minds:
result = "left_mind"
#raise ValueError(f"Invalid mind ID returned: {result}")
# Ensure the result is a valid mind ID
if result not in self.minds:
result = "left_mind"
#raise ValueError(f"Invalid mind ID returned: {result}")

return result
return result
except Exception as e:
return "left_mind" #tmp for llama3.2 which can't return valid json


Mind.model_rebuild()
31 changes: 15 additions & 16 deletions minion/main/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,21 @@ def __init__(self, **kwargs):
self.input.instruction = "let's think step by step to verify this answer"

async def execute(self):
for _ in range(3):
prompt = Template(CHECK_PROMPT)
prompt = prompt.render(input=self.input)

node = LmpActionNode(self.brain.llm)
result = await node.execute(prompt, response_format=CheckResult)

self.answer_node = result
self.answer = self.input.feedback = {
"feedback": result.feedback,
"correct": result.correct,
"score": result.score
}

if result:
return self.answer
prompt = Template(CHECK_PROMPT)
prompt = prompt.render(input=self.input)

node = LmpActionNode(self.brain.llm)
result = await node.execute(prompt, response_format=CheckResult, format="xml_simple")

self.answer_node = result
self.answer = self.input.feedback = {
"feedback": result.feedback,
"correct": result.correct,
"score": result.score
}

if result:
return self.answer

@register_check_minion
class TestMinion(CheckMinion):
Expand Down
Loading

0 comments on commit 88cfde6

Please sign in to comment.