Skip to content

Commit

Permalink
2
Browse files Browse the repository at this point in the history
  • Loading branch information
femto committed Nov 14, 2024
1 parent 608da9d commit 116635f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
24 changes: 19 additions & 5 deletions minion/main/minion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from minion.utils.answer_extraction import math_equal

MINION_REGISTRY = {}
MINION_ROUTE_DOWNSTREAM = {}
WORKER_MINIONS = {}


# a dummy score that does nothing, always return 1 to shortcut the score process
Expand All @@ -27,10 +27,24 @@ def __init__(cls, name, bases, clsdict):
cls._subclassed_hook()


def register_route_downstream(cls):
# Register the class in the dictionary with its name as the key
MINION_ROUTE_DOWNSTREAM[camel_case_to_snake_case(cls.__name__)] = cls
return cls
def register_worker_minion(cls=None, *, name=None):
"""Decorator to register worker minions.
Can be used as @register_worker_minion or @register_worker_minion(name="custom_name")
Args:
cls: The class to register (when used as @register_worker_minion)
name: Optional custom name (when used as @register_worker_minion(name="custom_name"))
"""
def decorator(cls):
# Use custom name if provided, otherwise convert class name to snake_case
register_name = name if name is not None else camel_case_to_snake_case(cls.__name__)
WORKER_MINIONS[register_name] = cls
return cls

# Handle both @register_worker_minion and @register_worker_minion(name="custom_name")
if cls is None:
return decorator
return decorator(cls)


class Minion(metaclass=SubclassHookMeta):
Expand Down
20 changes: 10 additions & 10 deletions minion/main/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
from minion.main.input import Input
from minion.main.minion import (
MINION_REGISTRY,
MINION_ROUTE_DOWNSTREAM,
WORKER_MINIONS,
Minion,
register_route_downstream,
register_worker_minion,
)
from minion.main.prompt import (
ASK_PROMPT_JINJA,
Expand Down Expand Up @@ -65,7 +65,7 @@
class WorkerMinion(Minion):
pass

@register_route_downstream
@register_worker_minion
class NativeMinion(WorkerMinion):
"""native minion, directly asks llm for answer"""

Expand All @@ -84,7 +84,7 @@ async def execute(self):
return self.answer


@register_route_downstream
@register_worker_minion
class CotMinion(WorkerMinion):
"""Chain of Thought (CoT) Strategy, Ask the LLM to think step-by-step, explaining each part of the problem to enhance the accuracy of the answer. Please noted you can't access web or user's local computer, so if you need information from the web or from user's local computer, DON'T USE THIS STRATEGY."""

Expand Down Expand Up @@ -167,13 +167,13 @@ async def execute(self):
return self.answer


@register_route_downstream
@register_worker_minion
class MultiPlanMinion(WorkerMinion):
"This Strategy will first generate multiple plan, and then compare each plan, see which one is more promising to produce good result, first try most promising plan, then to less promising plan."
pass


@register_route_downstream
@register_worker_minion
class PlanMinion(WorkerMinion):
"Divide and Conquer Strategy, Divide the problem into smaller subproblems, solve each subproblem independently, and then merge the results for the final solution."

Expand Down Expand Up @@ -319,7 +319,7 @@ async def resume(self):
await self.execute()


@register_route_downstream
@register_worker_minion
class MathPlanMinion(PlanMinion):
def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -386,7 +386,7 @@ async def execute(self):
return await self.choose_minion_and_run()


@register_route_downstream
@register_worker_minion
class PythonMinion(WorkerMinion):
"This problem requires writing code to solve it, write python code to solve it"

Expand Down Expand Up @@ -529,7 +529,7 @@ def save_files(self, file_structure):
f.write(content)


@register_route_downstream
@register_worker_minion
class MathMinion(PythonMinion):
"This is a problem involve math, you need to use math tool to solve it"

Expand Down Expand Up @@ -888,7 +888,7 @@ def deserialize_function(func_str: str) -> Callable:
return dill.loads(bytes.fromhex(func_str))


@register_route_downstream
@register_worker_minion
class OptillmMinion(WorkerMinion):
"""Minion that uses Optillm approaches"""

Expand Down

0 comments on commit 116635f

Please sign in to comment.