Refactor MultiStepAgent and improve some tests #598
- Split agent initialization into multiple methods for better readability
- Extracted tool and managed agent setup into separate methods
- Simplified the `_run` method by breaking down complex logic
- Improved code organization and reduced method complexity
- Maintained existing functionality while improving code structure

- Added a new pytest fixture for planning prompt templates
- Updated test_planning_step_first_step to use more robust testing approach
- Improved model call verification with input message tracking
- Added explicit task setting and more detailed assertion checks

- Replaced uv-specific import test with a more standard Python subprocess method
- Use sys.executable to create a clean Python process for import testing
- Simplified import test implementation while maintaining import verification

…ovements
- Added specific model_id for HfApiModel test
- Improved response assertion in HfApiModel test
- Changed TransformersModel device_map to 'cpu' for more consistent testing

- Extend tool validation to support Tool subclasses
- Update assertion to allow more flexible tool type checking
- Maintain existing tool setup functionality while increasing type flexibility
Hi, thanks for your contribution.
I left just a first comment below. I will review it in depth later.
This PR brings some great improvements, especially in terms of code readability and maintainability. I really appreciate the effort you put into refactoring: splitting larger methods into smaller, more focused ones is a solid application of the Single Responsibility Principle (SRP) and makes the code easier to understand and test.
That said, reviewing such a large PR can be quite challenging 😅. In the future, it would be great to break changes into smaller, more manageable PRs. This would not only make the review process smoother but also help catch potential issues earlier.
tests/test_import.py
Outdated
-    # Run the import statement in an isolated virtual environment
-    result = subprocess.run(
-        ["uv", "run", "--isolated", "--no-editable", "-"], input="import smolagents", text=True, capture_output=True
-    )
+    # Create a new Python process to test the import
+    cmd = [sys.executable, "-c", "import smolagents"]
+    result = subprocess.run(cmd, capture_output=True, text=True)
I think this change breaks the purpose of the test:
- test import of smolagents where the lib has been installed without extras.
The new code tests the import of smolagents in the current Python environment, where the lib was installed with all the extras.
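The point above can be checked with the standard library alone. The following is a minimal sketch (the helper name `child_prefix` is hypothetical): a child process started with `sys.executable` reports the same `sys.prefix` as its parent, which suggests it sees the same installed packages, extras included.

```python
import subprocess
import sys


def child_prefix() -> str:
    """Return sys.prefix as seen by a child process started with sys.executable."""
    result = subprocess.run(
        [sys.executable, "-c", "import sys; print(sys.prefix)"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


# The child runs on the parent's interpreter and environment, so it is not
# a fresh install without extras.
same_environment = child_prefix() == sys.prefix
```

By contrast, the original `uv run --isolated --no-editable -` invocation is described in the test as running in an isolated virtual environment, which is what lets it observe an install without extras.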
@albertvillanova totally understand - apologies for the large PR! I'll break them down into smaller ones next time 👍
- Replace exec() with ipython_shell.run_cell() for more accurate IPython session simulation
- Update test cases to access class from ipython_shell.user_ns instead of locals()
- Maintain existing test logic while improving IPython interaction accuracy
Hi, please note that there are currently conflicts with the main branch: these need to be resolved so that CI tests can run automatically.
src/smolagents/agents.py
Outdated
    def _initialize_agent(
        self,
        tools,
        model,
        prompt_templates,
        max_steps,
        tool_parser,
        add_base_tools,
        verbosity_level,
        grammar,
        managed_agents,
        step_callbacks,
        planning_interval,
        name,
        description,
        provide_run_summary,
        final_answer_checks,
    ):
I would say this function is not necessary: assigning parameters to attributes is indeed the primary purpose of the __init__ method.
I think this comment has not been addressed.
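To illustrate the suggestion, here is a minimal sketch of the split the reviewer seems to favor: plain attribute assignment stays in `__init__`, and only setup with real logic moves to a helper. `MiniAgent` and its parameters are hypothetical and much simpler than the actual `MultiStepAgent`.

```python
from typing import Callable, Dict, List, Optional


class MiniAgent:
    """Hypothetical, simplified stand-in for an agent class (not smolagents code)."""

    def __init__(
        self,
        tools: List[str],
        max_steps: int = 6,
        step_callbacks: Optional[List[Callable]] = None,
    ):
        # Plain attribute assignment stays in __init__ ...
        self.max_steps = max_steps
        self.step_callbacks = step_callbacks if step_callbacks is not None else []
        # ... while setup that involves validation gets its own helper.
        self._setup_tools(tools)

    def _setup_tools(self, tools: List[str]) -> None:
        # Validation plus dict construction: enough logic to justify a method.
        assert all(isinstance(name, str) and name for name in tools), "Tool names must be non-empty strings"
        self.tools: Dict[str, str] = {name: name.upper() for name in tools}


agent = MiniAgent(tools=["search", "final_answer"], max_steps=3)
```

A pass-through method that only forwards every constructor argument adds a second long signature to keep in sync without removing any logic from `__init__`.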
src/smolagents/agents.py
Outdated
    def _setup_managed_agents(self, managed_agents):
        self.managed_agents = {}
        if managed_agents:
            assert all(agent.name and agent.description for agent in managed_agents), (
                "All managed agents need both a name and a description!"
            )
            self.managed_agents = {agent.name: agent for agent in managed_agents}

    def _setup_tools(self, tools, add_base_tools):
        assert all(isinstance(tool, Tool) or issubclass(type(tool), Tool) for tool in tools), (
            "All elements must be Tool or a subclass of Tool"
        )
        self.tools = {tool.name: tool for tool in tools}
        if add_base_tools:
            self.tools.update(
                {
                    name: cls()
                    for name, cls in TOOL_MAPPING.items()
                    if name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent"
                }
            )
        self.tools["final_answer"] = FinalAnswerTool()
Good idea to define these 2 methods!
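One small observation on the assertion in `_setup_tools`, sketched below with hypothetical stand-in classes rather than the real smolagents `Tool`: for ordinary classes, `isinstance(tool, Tool)` already accepts instances of subclasses, so the additional `issubclass(type(tool), Tool)` clause does not appear to admit any further cases.

```python
class Tool:
    """Hypothetical minimal stand-in for the smolagents Tool base class."""

    name = "base"


class SearchTool(Tool):
    name = "search"


tool = SearchTool()

# isinstance() is already True for instances of subclasses ...
covered_by_isinstance = isinstance(tool, Tool)
# ... so the issubclass(type(tool), Tool) clause gives the same answer here.
covered_by_issubclass = issubclass(type(tool), Tool)

# A non-Tool object is rejected by both checks.
rejected_by_both = not isinstance("not a tool", Tool) and not issubclass(type("not a tool"), Tool)
```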
src/smolagents/agents.py
Outdated
    def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionStep | AgentType, None, None]:
        final_answer, self.step_number = None, 1
        while final_answer is None and self.step_number <= self.max_steps:
            step_start_time = time.time()
            memory_step = self._create_memory_step(step_start_time, images)
            try:
                final_answer = self._execute_step(task, memory_step)
            except AgentError as e:
                memory_step.error = e
            finally:
                self._finalize_step(memory_step, step_start_time)
                yield memory_step
                self.step_number += 1

        if final_answer is None and self.step_number == self.max_steps + 1:
            final_answer = self._handle_max_steps_reached(task, images, step_start_time)
            yield memory_step
        yield handle_agent_output_types(final_answer)

    def _create_memory_step(self, step_start_time: float, images: List[str] | None) -> ActionStep:
        return ActionStep(step_number=self.step_number, start_time=step_start_time, observations_images=images)

    def _execute_step(self, task: str, memory_step: ActionStep) -> Union[None, Any]:
        if self.planning_interval is not None and self.step_number % self.planning_interval == 1:
            self.planning_step(task, is_first_step=(self.step_number == 1), step=self.step_number)
        self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
        final_answer = self.step(memory_step)
        if final_answer is not None and self.final_answer_checks:
            self._validate_final_answer(final_answer)
        return final_answer

    def _validate_final_answer(self, final_answer: Any):
        for check_function in self.final_answer_checks:
            try:
                assert check_function(final_answer, self.memory)
            except Exception as e:
                raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

    def _finalize_step(self, memory_step: ActionStep, step_start_time: float):
        memory_step.end_time = time.time()
        memory_step.duration = memory_step.end_time - step_start_time
        self.memory.steps.append(memory_step)
        for callback in self.step_callbacks:
            callback(memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
                memory_step, agent=self
            )

    def _handle_max_steps_reached(self, task: str, images: List[str], step_start_time: float) -> Any:
        final_answer = self.provide_final_answer(task, images)
        final_memory_step = ActionStep(
            step_number=self.step_number, error=AgentMaxStepsError("Reached max steps.", self.logger)
        )
        final_memory_step.action_output = final_answer
        final_memory_step.end_time = time.time()
        final_memory_step.duration = final_memory_step.end_time - step_start_time
        self.memory.steps.append(final_memory_step)
        for callback in self.step_callbacks:
            callback(final_memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
                final_memory_step, agent=self
            )
        return final_answer
Also nice refactoring here!
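The control flow of the refactored loop can be sketched in isolation. The version below is a simplified illustration under stated assumptions: `StepError`, `run_steps`, and `flaky_step` are hypothetical names, step records are plain dicts, and the fallback string stands in for `provide_final_answer`. It shows that a step error is recorded rather than raised, that every step is yielded, and that the final answer is yielded last.

```python
from typing import Callable, Generator, List, Optional, Union


class StepError(Exception):
    """Hypothetical stand-in for AgentError."""


def run_steps(
    step_fn: Callable[[int], Optional[str]], max_steps: int
) -> Generator[Union[dict, str], None, None]:
    final_answer, step_number = None, 1
    while final_answer is None and step_number <= max_steps:
        record = {"step": step_number, "error": None}
        try:
            final_answer = step_fn(step_number)
        except StepError as e:
            record["error"] = str(e)  # the error is stored and the loop continues
        finally:
            yield record
            step_number += 1
    if final_answer is None:
        final_answer = "fallback answer"  # placeholder for provide_final_answer
    yield final_answer


def flaky_step(n: int) -> Optional[str]:
    # Fails on the first step, answers on the third.
    if n == 1:
        raise StepError("tool failed")
    return "done" if n == 3 else None


outputs: List[Union[dict, str]] = list(run_steps(flaky_step, max_steps=5))
```

Keeping the loop body down to "create, execute, finalize" is what makes each helper testable on its own.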
src/smolagents/agents.py
Outdated
        facts_message, plan_message = (
            self._generate_initial_plan(task) if is_first_step else self._generate_updated_plan(task, step)
        )
        self._record_planning_step(facts_message, plan_message)

    def _generate_initial_plan(self, task: str) -> Tuple[ChatMessage, ChatMessage]:
        message_prompt_facts = {
            "role": MessageRole.SYSTEM,
            "content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}],
        }
        message_prompt_task = {
            "role": MessageRole.USER,
            "content": [{"type": "text", "text": textwrap.dedent(f"Here is the task:\n```\n{task}\n```\nNow begin!")}],
        }
        input_messages = [message_prompt_facts, message_prompt_task]
        facts_message = self.model(input_messages)

        message_prompt_plan = {
            "role": MessageRole.USER,
            "content": [
                {
                    "type": "text",
                    "text": populate_template(
                        self.prompt_templates["planning"]["initial_plan"],
                        variables={
                            "task": task,
                            "tools": self.tools,
                            "managed_agents": self.managed_agents,
                            "answer_facts": facts_message.content,
                        },
                    ),
                }
            ],
        }
        plan_message = self.model([message_prompt_plan], stop_sequences=["<end_plan>"])
        return facts_message, plan_message

    def _generate_updated_plan(self, task: str, step: int) -> Tuple[ChatMessage, ChatMessage]:
        memory_messages = self.write_memory_to_messages()[1:]
        facts_update_pre = {
            "role": MessageRole.SYSTEM,
            "content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_pre_messages"]}],
        }
        facts_update_post = {
            "role": MessageRole.USER,
            "content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_post_messages"]}],
        }
        facts_message = self.model([facts_update_pre] + memory_messages + [facts_update_post])

        update_plan_pre = {
            "role": MessageRole.SYSTEM,
            "content": [
                {
                    "type": "text",
                    "text": populate_template(
                        self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
                    ),
                }
            ],
        }
        update_plan_post = {
            "role": MessageRole.USER,
            "content": [
                {
                    "type": "text",
                    "text": populate_template(
                        self.prompt_templates["planning"]["update_plan_post_messages"],
                        variables={
                            "task": task,
                            "tools": self.tools,
                            "managed_agents": self.managed_agents,
                            "facts_update": facts_message.content,
                            "remaining_steps": (self.max_steps - step),
                        },
                    ),
                }
            ],
        }
        plan_message = self.model(
            [update_plan_pre] + memory_messages + [update_plan_post], stop_sequences=["<end_plan>"]
        )
        return facts_message, plan_message

    def _record_planning_step(self, facts_message: ChatMessage, plan_message: ChatMessage):
        final_plan = textwrap.dedent(
            f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere is my new/updated plan of action to solve the task:\n```\n{plan_message.content}\n```"""
        )
        final_facts = textwrap.dedent(
            f"""Here is the updated list of the facts that I know:\n```\n{facts_message.content}\n```"""
        )
        self.memory.steps.append(
            PlanningStep(
                model_input_messages=self.input_messages,
                plan=final_plan,
                facts=final_facts,
                model_output_message_plan=plan_message,
                model_output_message_facts=facts_message,
            )
        )
        self.logger.log(Rule("[bold]Updated plan", style="orange"), Text(final_plan), level=LogLevel.INFO)
This is also a nice refactoring. Please take into account that some modifications have been incorporated into the main branch. You need to incorporate them here as well when merging the main branch and resolving the conflicts.
For example, in _generate_initial_plan, the initial facts have now been merged into a single user role (no system role anymore).
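As a rough illustration of what "a single user role" could look like, the sketch below folds the facts prompt and the task into one user message. The helper name, the message layout, and the way the two texts are joined are all assumptions; the actual template on the main branch is not shown in this thread and may differ.

```python
from typing import Dict, List


def build_initial_facts_messages(facts_prompt: str, task: str) -> List[Dict]:
    """Hypothetical helper: one user message instead of a system + user pair."""
    # Assumption: the two texts are simply concatenated; the real template may differ.
    text = f"{facts_prompt}\n\nHere is the task:\n{task}\nNow begin!"
    return [{"role": "user", "content": [{"type": "text", "text": text}]}]


messages = build_initial_facts_messages("List the facts you know.", "Count the files.")
```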
@albertvillanova I've made a lot of changes to this PR, and with new commits and test updates, it’s getting even larger. I want to ensure it's manageable for review and doesn’t create too much overhead.
Would you prefer I continue refining this PR, or would it be better to break it into smaller PRs to make reviewing easier? I’m happy to go either way—just want to do what’s most helpful for the project. Let me know your thoughts. Thanks for your time!
No worries, it is OK to continue with this PR now that we have already started...
Oh, it seems you forgot to run the quality checks: https://github.com/huggingface/smolagents/blob/main/CONTRIBUTING.md
The CI tests are red as well. Could you please address them?
tests/test_agents.py
Outdated
@@ -968,6 +968,185 @@ def __call__(
        # Test that visualization works
        manager_code_agent.visualize()

<<<<<<< HEAD
This is a remnant of a conflict that needs to be fixed.
src/smolagents/agents.py
Outdated
        self.memory.steps.append(memory_step)
        for callback in self.step_callbacks:
            # For compatibility with old callbacks that don't take the agent as an argument
            if len(inspect.signature(callback).parameters) == 1:
I think it is better to do
    if "agent" in inspect.signature(callback).parameters:
        callback(memory_step, agent=self)
    else:
        callback(memory_step)
Because there could have been callbacks with default args, for example:
    def my_callback(memory_step, some_optional_param=None):
        pass
The current check will treat it as a "new style" callback and fail. All we care about here is whether it has an agent kwarg, so let's check specifically for it.
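The suggested check can be exercised on its own. In this sketch, `dispatch_callback` and the three example callbacks are hypothetical names, and plain strings stand in for the memory step and the agent; it shows that a callback with an extra defaulted parameter is still called the old way.

```python
import inspect
from typing import Any, Callable, List


def dispatch_callback(callback: Callable, memory_step: Any, agent: Any) -> None:
    # Pass the agent only when the callback declares an `agent` parameter.
    if "agent" in inspect.signature(callback).parameters:
        callback(memory_step, agent=agent)
    else:
        callback(memory_step)


calls: List[str] = []


def old_style(memory_step):
    calls.append("old")


def old_style_with_default(memory_step, some_optional_param=None):
    # Two parameters, but no `agent`: a parameter-count check would misclassify it.
    calls.append("old+default")


def new_style(memory_step, agent):
    calls.append(f"new:{agent}")


for cb in (old_style, old_style_with_default, new_style):
    dispatch_callback(cb, memory_step="step-1", agent="agent-0")
```

One limitation of this sketch: a callback that accepts the agent only through `**kwargs` has no parameter named `agent` and would be called without it.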
src/smolagents/agents.py
Outdated
        )
        final_memory_step.action_output = final_answer
        final_memory_step.end_time = time.time()
        final_memory_step.duration = memory_step.end_time - step_start_time
-        final_memory_step.duration = memory_step.end_time - step_start_time
+        final_memory_step.duration = final_memory_step.end_time - step_start_time
A typo?
d2a734d to bda33be
There were issues with the resolution of the conflicts when merging the main branch.
Thanks.
This PR introduces several significant improvements to the test suite and agent implementation:
MultiStepAgent Refactoring
- Simplified the `_run` method by breaking down complex logic

Test Suite Enhancements
- Planning Step Testing (`test_agents.py`)
- Import Testing (`test_import.py`)
- Model Testing (`test_models.py`)

These changes improve test reliability, maintainability, and coverage while making the codebase more robust and easier to understand.