Skip to content

Commit

Permalink
check route
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Nov 14, 2024
1 parent 116635f commit 4bb760b
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 12 deletions.
6 changes: 4 additions & 2 deletions examples/smart_minion/brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ async def smart_brain():

obs, score, *_ = await brain.step(
query=test_data["prompt"],
route="ldb",
route="python",
post_processing="extract_python",
entry_point=test_data["entry_point"],
check=False,
check=10,
check_route="ldb_check",
dataset="HumanEval",
metadata={"test_cases": test_data["test"]} # 添加测试用例到 metadata
)
print(obs)
Expand Down
14 changes: 7 additions & 7 deletions minion/main/check_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ class CheckRouterMinion(Minion):
async def choose_checker(self):
"""Choose appropriate checker based on input characteristics"""
try:
# First check input.check_route
# First check worker_config
if self.worker_config and self.worker_config.get('check_route', None):
checker_name = most_similar_minion(self.worker_config['check_route'], CHECK_MINION_REGISTRY.keys())
logger.info(f"Using checker from worker config: {checker_name}")
return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check"))

# Then check input.check_route
if hasattr(self.input, 'check_route') and self.input.check_route:
checker_name = most_similar_minion(self.input.check_route, CHECK_MINION_REGISTRY.keys())
logger.info(f"Using checker from input.check_route: {checker_name}")
return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check"))

# Then check worker_config
if hasattr(self, 'worker_config') and self.worker_config.get('check_route', None):
checker_name = most_similar_minion(self.worker_config['check_route'], CHECK_MINION_REGISTRY.keys())
logger.info(f"Using checker from worker config: {checker_name}")
return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check"))

# Prepare template for LLM recommendation
choose_template = Template(CHECK_ROUTE_PROMPT)
filled_template = choose_template.render(
Expand Down
1 change: 1 addition & 0 deletions minion/main/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class Input(BaseModel):
answer_protocol: str = "" # Protocol for answer formatting, should we call it answer_format?
execution_config: dict = {} # Configuration for execution, like ensemble stragety etc.
check: Union[bool,int] = True # Whether to perform validation
check_route:str = "" # Whether to perform validation

# Metadata
dataset: str = "" # Source dataset identifier
Expand Down
4 changes: 3 additions & 1 deletion minion/main/minion.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def decorator(cls):


class Minion(metaclass=SubclassHookMeta):
def __init__(self, input=None, brain=None, id=None, score_func=None, task=None, **kwargs):
def __init__(self, input=None, brain=None, id=None, score_func=None, worker_config=None, task=None, **kwargs):
if brain is None:
raise ValueError("The 'brain' parameter cannot be None.")

Expand All @@ -58,6 +58,8 @@ def __init__(self, input=None, brain=None, id=None, score_func=None, task=None,
self.brain = brain
self.followers = []
self.score_func = score_func

self.worker_config = worker_config
self.task = task

def propagate_information(self, other):
Expand Down
11 changes: 9 additions & 2 deletions minion/main/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,16 @@ def __init__(self, **kwargs):
self.python_env = self.brain.python_env

async def execute(self):
# Check post_processing setting, giving precedence to worker_config
post_processing = None
if self.worker_config and 'post_processing' in self.worker_config:
post_processing = self.worker_config['post_processing']
elif self.input.post_processing:
post_processing = self.input.post_processing

if self.input.query_type == "calculate":
return await self.execute_calculation()
elif self.input.query_type == "code_solution":
elif post_processing == "extract_python" or self.input.query_type == "code_solution":
return await self.execute_code_solution()
elif self.input.query_type == "generate":
return await self.execute_generation()
Expand Down Expand Up @@ -840,7 +847,7 @@ async def invoke_minion_and_improve(self, klass, name, max_iterations=3):
self.input.update_execution_state(current_iteration=iteration)
self.save_execution_state()

check_router_minion = CheckRouterMinion(input=self.input, brain=self.brain)
check_router_minion = CheckRouterMinion(input=self.input, brain=self.brain, worker_config=self.worker_config)
check_result = await check_router_minion.execute()

self.input.update_execution_state(check_result=check_result)
Expand Down

0 comments on commit 4bb760b

Please sign in to comment.