Skip to content

Commit

Permalink
apply_post_processing
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Oct 17, 2024
1 parent b1ce282 commit daf02a7
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 50 deletions.
38 changes: 18 additions & 20 deletions examples/smart_minion/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ async def smart_brain():
# print(obs)

# Get the directory of the current file
current_file_dir = os.path.dirname(os.path.abspath(__file__))
os.path.dirname(os.path.abspath(__file__))

# llm1 = LLM()
# LLM()
Expand All @@ -107,12 +107,12 @@ async def smart_brain():
# )
# print(obs)

obs, score, *_ = await brain.step(
query="Real numbers $x$ and $y$ with $x,y>1$ satisfy $\log_x(y^x)=\log_y(x^{4y})=10.$ What is the value of $xy$?",
route="cot",
dataset="aime 2024",
)
print(obs)
# obs, score, *_ = await brain.step(
# query="Real numbers $x$ and $y$ with $x,y>1$ satisfy $\log_x(y^x)=\log_y(x^{4y})=10.$ What is the value of $xy$?",
# route="cot",
# dataset="aime 2024",
# )
# print(obs)

# obs, score, *_ = await brain.step(
# query="Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them",
Expand All @@ -121,22 +121,21 @@ async def smart_brain():
# )
# print(obs)
# llm.model = "z3-" + llm.model
cache_plan = os.path.join(current_file_dir, "aime", "plan_gpt4o.12.json")

obs, score, *_ = await brain.step(
query="Define $f(x)=|| x|-\tfrac{1}{2}|$ and $g(x)=|| x|-\tfrac{1}{4}|$. Find the number of intersections of the graphs of\[y=4 g(f(\sin (2 \pi x))) \quad\text{ and }\quad x=4 g(f(\cos (3 \pi y))).\]",
route="cot",
dataset="aime 2024",
cache_plan=cache_plan,
)
print(obs)
# cache_plan = os.path.join(current_file_dir, "aime", "plan_gpt4o.12.json")
#
# obs, score, *_ = await brain.step(
# query="Define $f(x)=|| x|-\tfrac{1}{2}|$ and $g(x)=|| x|-\tfrac{1}{4}|$. Find the number of intersections of the graphs of\[y=4 g(f(\sin (2 \pi x))) \quad\text{ and }\quad x=4 g(f(\cos (3 \pi y))).\]",
# route="cot",
# dataset="aime 2024",
# cache_plan=cache_plan,
# )
# print(obs)

obs, score, *_ = await brain.step(
query='''
from typing import List def has_close_elements(numbers: List[float], threshold: float) -> bool: """ Check if in given list of numbers, are any two numbers closer to each other than given threshold. >>> has_close_elements([1.0, 2.0, 3.0], 0.5) False >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3) True """''',
route="cot",
dataset="aime 2024",
cache_plan=cache_plan,
query_type="code_solution",
)
print(obs)

Expand All @@ -152,8 +151,7 @@ async def smart_brain():
strange_sort_list([]) == []
'''""",
route="cot",
dataset="aime 2024",
cache_plan=cache_plan,
query_type="code_solution",
)
print(obs)

Expand Down
20 changes: 20 additions & 0 deletions metagpt/minion/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
from pydantic import BaseModel, Field

from metagpt.minion.symbol_table import SymbolTable
from metagpt.utils.math_utils import extract_math_answer, extract_number_from_string


class PostProcessingType(Enum):
NONE = "none"
EXTRACT_NUMBER = "extract_number_from_string"
EXTRACT_MATH_ANSWER = "extract_math_answer"


class EnsembleStrategyType(Enum):
Expand Down Expand Up @@ -108,6 +115,10 @@ class Input(BaseModel):
# 新增字段
execution_state: ExecutionState = Field(default_factory=ExecutionState)

post_processing: PostProcessingType = Field(
default=PostProcessingType.NONE, description="The type of post-processing to apply to the answer"
)

def save_state(self, file_path: str):
"""将当前状态保存到文件"""
import os
Expand Down Expand Up @@ -144,5 +155,14 @@ def context(self):
def context(self, context):
self.long_context = context

def apply_post_processing(self, raw_answer: str) -> Any:
"""Apply the specified post-processing to the raw answer."""
if self.post_processing == PostProcessingType.EXTRACT_NUMBER:
return extract_number_from_string(raw_answer)
elif self.post_processing == PostProcessingType.EXTRACT_MATH_ANSWER:
return extract_math_answer(raw_answer)
else:
return raw_answer


Task.update_forward_refs()
51 changes: 21 additions & 30 deletions metagpt/minion/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@
)
from metagpt.minion.symbol_table import Symbol
from metagpt.minion.task_graph import convert_tasks_to_graph
from metagpt.minion.utils import (
extract_math_answer,
extract_number_from_string,
most_similar_minion,
)
from metagpt.minion.utils import most_similar_minion
from metagpt.utils.custom_decoder import CustomDecoder


Expand Down Expand Up @@ -482,7 +478,7 @@ async def execute(self):
return await self.execute_calculation()
elif self.input.query_type == "code_solution":
return await self.execute_code_solution()
elif self.input.query_type == "file_creation":
elif self.input.query_type == "file_creation" or self.input.query_type == "generate":
return await self.execute_file_creation()
else:
return await self.execute_calculation() # 默认行为
Expand Down Expand Up @@ -756,32 +752,26 @@ async def execute_ensemble(self):
self.save_execution_state()

raw_answer = await self.invoke_minion(minion_name)

if minion.get("post_processing", None) == "extract_number_from_string":
result = extract_number_from_string(raw_answer)
elif minion.get("post_processing", None) == "extract_math_answer":
result = extract_math_answer(raw_answer)
else:
result = raw_answer
processed_answer = raw_answer # already handled in route minion?
# processed_answer = self.input.apply_post_processing(raw_answer)

weight = minion.get("weight", 1)

await self.update_stats(minion_name, result, raw_answer)
await self.update_stats(minion_name, processed_answer, raw_answer)

if True: # result: todo: consider how to handle result is None case
# Update the results dictionary
if result in results:
results[result] += weight
if True: # 考虑如何处理 processed_answer 为 None 的情况
if processed_answer in results:
results[processed_answer] += weight
else:
results[result] = weight
results[processed_answer] = weight

# short circuit logic, check if this result has reached the majority count
# 短路逻辑
if (
self.input.ensemble_logic["ensemble_strategy"].get("short_circuit", True)
and results[result] >= majority_count
and results[processed_answer] >= majority_count
):
self.answer = self.input.answer = result
return result # Majority found, return it
self.answer = self.input.answer = processed_answer
return processed_answer

# No result reached majority; find the result with the highest weight
most_weight = max(results.values())
Expand All @@ -804,11 +794,10 @@ async def execute(self):

if self.input.execution_state.current_minion:
# 从上次状态恢复
minion_name = self.input.execution_state.current_minion
if self.input.ensemble_logic:
await self.execute_ensemble()
else:
await self.invoke_minion(minion_name)
await self.execute_single()
else:
# 开始新的执行
await self.choose_minion_and_run()
Expand Down Expand Up @@ -916,12 +905,14 @@ async def invoke_minion_and_improve(self, klass, name, max_iterations=3):
self.input.update_execution_state(current_iteration=iteration)
self.save_execution_state()

result = await self.invoke_minion(klass)
self.answer = self.input.answer = result
await self.update_stats(name, result, result)
raw_answer = await self.invoke_minion(klass)
processed_answer = self.input.apply_post_processing(raw_answer)

self.answer = self.input.answer = processed_answer
await self.update_stats(name, processed_answer, raw_answer)

if not self.input.check:
break # Exit the loop if checking is not required
break

check_minion = CheckMinion(input=self.input, brain=self.brain)
check_result = await check_minion.execute()
Expand All @@ -930,7 +921,7 @@ async def invoke_minion_and_improve(self, klass, name, max_iterations=3):
self.save_execution_state()

if check_result and check_result["correct"]:
break # Exit the loop if the result is correct
break

return self.answer

Expand Down

0 comments on commit daf02a7

Please sign in to comment.