From 4bb760b4edac9396641d867dd0ba81df39506598 Mon Sep 17 00:00:00 2001 From: femto Date: Thu, 14 Nov 2024 16:34:08 +0800 Subject: [PATCH] check route --- examples/smart_minion/brain.py | 6 ++++-- minion/main/check_route.py | 14 +++++++------- minion/main/input.py | 1 + minion/main/minion.py | 4 +++- minion/main/worker.py | 11 +++++++++-- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/smart_minion/brain.py b/examples/smart_minion/brain.py index b65a4c11..4fec0ecd 100644 --- a/examples/smart_minion/brain.py +++ b/examples/smart_minion/brain.py @@ -46,10 +46,12 @@ async def smart_brain(): obs, score, *_ = await brain.step( query=test_data["prompt"], - route="ldb", + route="python", post_processing="extract_python", entry_point=test_data["entry_point"], - check=False, + check=10, + check_route="ldb_check", + dataset="HumanEval", metadata={"test_cases": test_data["test"]} # 添加测试用例到 metadata ) print(obs) diff --git a/minion/main/check_route.py b/minion/main/check_route.py index a49fd3f7..31136c58 100644 --- a/minion/main/check_route.py +++ b/minion/main/check_route.py @@ -58,18 +58,18 @@ class CheckRouterMinion(Minion): async def choose_checker(self): """Choose appropriate checker based on input characteristics""" try: - # First check input.check_route + # First check worker_config + if self.worker_config and self.worker_config.get('check_route', None): + checker_name = most_similar_minion(self.worker_config['check_route'], CHECK_MINION_REGISTRY.keys()) + logger.info(f"Using checker from worker config: {checker_name}") + return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check")) + + # Then check input.check_route if hasattr(self.input, 'check_route') and self.input.check_route: checker_name = most_similar_minion(self.input.check_route, CHECK_MINION_REGISTRY.keys()) logger.info(f"Using checker from input.check_route: {checker_name}") return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check")) - # Then check worker_config - if hasattr(self, 'worker_config') and self.worker_config.get('check_route', None): - checker_name = most_similar_minion(self.worker_config['check_route'], CHECK_MINION_REGISTRY.keys()) - logger.info(f"Using checker from worker config: {checker_name}") - return CHECK_MINION_REGISTRY.get(checker_name, CHECK_MINION_REGISTRY.get("check")) - # Prepare template for LLM recommendation choose_template = Template(CHECK_ROUTE_PROMPT) filled_template = choose_template.render( diff --git a/minion/main/input.py b/minion/main/input.py index 60cd7273..be889757 100644 --- a/minion/main/input.py +++ b/minion/main/input.py @@ -100,6 +100,7 @@ class Input(BaseModel): answer_protocol: str = "" # Protocol for answer formatting, should we call it answer_format? execution_config: dict = {} # Configuration for execution, like ensemble stragety etc. check: Union[bool,int] = True # Whether to perform validation + check_route:str = "" # Whether to perform validation # Metadata dataset: str = "" # Source dataset identifier diff --git a/minion/main/minion.py b/minion/main/minion.py index e63ff4de..c178cc51 100644 --- a/minion/main/minion.py +++ b/minion/main/minion.py @@ -48,7 +48,7 @@ def decorator(cls): class Minion(metaclass=SubclassHookMeta): - def __init__(self, input=None, brain=None, id=None, score_func=None, task=None, **kwargs): + def __init__(self, input=None, brain=None, id=None, score_func=None, worker_config=None, task=None, **kwargs): if brain is None: raise ValueError("The 'brain' parameter cannot be None.") @@ -58,6 +58,8 @@ def __init__(self, input=None, brain=None, id=None, score_func=None, task=None, self.brain = brain self.followers = [] self.score_func = score_func + + self.worker_config = worker_config self.task = task def propagate_information(self, other): diff --git a/minion/main/worker.py b/minion/main/worker.py index 9d671bc4..0f29bb83 100644 --- a/minion/main/worker.py +++ b/minion/main/worker.py @@ -395,9 +395,16 @@ def __init__(self, **kwargs): self.python_env = self.brain.python_env async def execute(self): + # Check post_processing setting, giving precedence to worker_config + post_processing = None + if self.worker_config and 'post_processing' in self.worker_config: + post_processing = self.worker_config['post_processing'] + elif self.input.post_processing: + post_processing = self.input.post_processing + if self.input.query_type == "calculate": return await self.execute_calculation() - elif self.input.query_type == "code_solution": + elif post_processing == "extract_python" or self.input.query_type == "code_solution": return await self.execute_code_solution() elif self.input.query_type == "generate": return await self.execute_generation() @@ -840,7 +847,7 @@ async def invoke_minion_and_improve(self, klass, name, max_iterations=3): self.input.update_execution_state(current_iteration=iteration) self.save_execution_state() - check_router_minion = CheckRouterMinion(input=self.input, brain=self.brain) + check_router_minion = CheckRouterMinion(input=self.input, brain=self.brain, worker_config=self.worker_config) check_result = await check_router_minion.execute() self.input.update_execution_state(check_result=check_result)