
Refactor MultiStepAgent and improve some tests #598

Merged
Merged 12 commits into huggingface:main on Feb 20, 2025

Conversation

colesmcintosh
Contributor

This PR introduces several significant improvements to the test suite and agent implementation:

MultiStepAgent Refactoring

  • Split agent initialization logic into multiple methods for improved readability
  • Extracted tool and managed agent setup into dedicated methods
  • Simplified the _run method by breaking down complex logic
  • Enhanced code organization while maintaining existing functionality

Test Suite Enhancements

  1. Planning Step Testing (test_agents.py)

    • Added new pytest fixture for planning prompt templates
    • Implemented more robust testing approach for planning step verification
    • Enhanced model call verification with input message tracking
    • Added explicit task setting and detailed assertion checks
  2. Import Testing (test_import.py)

    • Replaced UV-specific import test with standard Python subprocess method
    • Improved import verification using sys.executable
    • Simplified implementation while maintaining thorough import validation
  3. Model Testing (test_models.py)

    • Added specific model_id for HfApiModel test
    • Improved response assertion in HfApiModel test
    • Changed TransformersModel device_map to 'cpu' for consistent testing

These changes improve test reliability, maintainability, and coverage while making the codebase more robust and easier to understand.
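The subprocess-based import check described above can be sketched as follows. This is a minimal sketch, not the PR's actual test code; `can_import` is a hypothetical helper introduced here for illustration.

```python
import subprocess
import sys

def can_import(module_name: str) -> bool:
    """Run `import <module_name>` in a fresh Python process spawned
    from the same interpreter (sys.executable), so the check does not
    depend on modules already loaded in the test process."""
    cmd = [sys.executable, "-c", f"import {module_name}"]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode == 0
```

Note, as discussed in the review below, that this approach exercises the current Python environment rather than an isolated install.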

- Split agent initialization into multiple methods for better readability
- Extracted tool and managed agent setup into separate methods
- Simplified the `_run` method by breaking down complex logic
- Improved code organization and reduced method complexity
- Maintained existing functionality while improving code structure
- Added a new pytest fixture for planning prompt templates
- Updated test_planning_step_first_step to use more robust testing approach
- Improved model call verification with input message tracking
- Added explicit task setting and more detailed assertion checks
- Replaced uv-specific import test with a more standard Python subprocess method
- Use sys.executable to create a clean Python process for import testing
- Simplified import test implementation while maintaining import verification

- Added specific model_id for HfApiModel test
- Improved response assertion in HfApiModel test
- Changed TransformersModel device_map to 'cpu' for more consistent testing
@colesmcintosh colesmcintosh marked this pull request as draft February 11, 2025 04:17
@colesmcintosh colesmcintosh marked this pull request as ready for review February 11, 2025 04:17
- Extend tool validation to support Tool subclasses
- Update assertion to allow more flexible tool type checking
- Maintain existing tool setup functionality while increasing type flexibility
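The relaxed tool-type check described above can be sketched with stand-in classes (these are placeholders, not the real smolagents types). Note that `isinstance` already accepts instances of subclasses, so a single check covers both cases.

```python
class Tool:
    """Stand-in for smolagents' Tool base class."""
    name = "tool"

class SearchTool(Tool):
    """A hypothetical Tool subclass."""
    name = "search"

tools = [Tool(), SearchTool()]
# isinstance() covers both direct instances and subclass instances.
assert all(isinstance(tool, Tool) for tool in tools)
```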
Member

@albertvillanova albertvillanova left a comment

Hi, thanks for your contribution.

I left just a first comment below. I will review it in depth later.

This PR brings some great improvements, especially in terms of code readability and maintainability. I really appreciate the effort you put into refactoring; splitting larger methods into smaller, more focused ones is a solid application of the Single Responsibility Principle (SRP) and makes the code easier to understand and test.

That said, reviewing such a large PR can be quite challenging 😅. In the future, it would be great to break changes into smaller, more manageable PRs. This would not only make the review process smoother but also help catch potential issues earlier.

Comment on lines 5 to 9
# Run the import statement in an isolated virtual environment
result = subprocess.run(
["uv", "run", "--isolated", "--no-editable", "-"], input="import smolagents", text=True, capture_output=True
)
# Create a new Python process to test the import
cmd = [sys.executable, "-c", "import smolagents"]
result = subprocess.run(cmd, capture_output=True, text=True)

Member

I think this change breaks the purpose of the test:

  • test import of smolagents where the lib has been installed without extras.

The new code tests the import of smolagents in the current Python environment, where the lib was installed with all the extras.

@colesmcintosh
Contributor Author

@albertvillanova totally understand - apologies for the large PR! I'll break them down into smaller ones next time 👍

- Replace exec() with ipython_shell.run_cell() for more accurate IPython session simulation
- Update test cases to access class from ipython_shell.user_ns instead of locals()
- Maintain existing test logic while improving IPython interaction accuracy
Member

@albertvillanova albertvillanova left a comment

Hi, please note that there are currently conflicts with the main branch: these need to be resolved so that CI tests can run automatically.

Comment on lines 142 to 159
def _initialize_agent(
self,
tools,
model,
prompt_templates,
max_steps,
tool_parser,
add_base_tools,
verbosity_level,
grammar,
managed_agents,
step_callbacks,
planning_interval,
name,
description,
provide_run_summary,
final_answer_checks,
):
Member

I would say this function is not necessary: assigning parameters to attributes is indeed the primary purpose of the __init__ method.

Member

I think this comment has not been addressed.

Comment on lines 178 to 199
def _setup_managed_agents(self, managed_agents):
self.managed_agents = {}
if managed_agents:
assert all(agent.name and agent.description for agent in managed_agents), (
"All managed agents need both a name and a description!"
)
self.managed_agents = {agent.name: agent for agent in managed_agents}

def _setup_tools(self, tools, add_base_tools):
assert all(isinstance(tool, Tool) or issubclass(type(tool), Tool) for tool in tools), (
"All elements must be Tool or a subclass of Tool"
)
self.tools = {tool.name: tool for tool in tools}
if add_base_tools:
self.tools.update(
{
name: cls()
for name, cls in TOOL_MAPPING.items()
if name != "python_interpreter" or self.__class__.__name__ == "ToolCallingAgent"
}
)
self.tools["final_answer"] = FinalAnswerTool()
Member

Good idea to define these 2 methods!

Comment on lines 201 to 261
def _run(self, task: str, images: List[str] | None = None) -> Generator[ActionStep | AgentType, None, None]:
final_answer, self.step_number = None, 1
while final_answer is None and self.step_number <= self.max_steps:
step_start_time = time.time()
memory_step = self._create_memory_step(step_start_time, images)
try:
final_answer = self._execute_step(task, memory_step)
except AgentError as e:
memory_step.error = e
finally:
self._finalize_step(memory_step, step_start_time)
yield memory_step
self.step_number += 1

if final_answer is None and self.step_number == self.max_steps + 1:
final_answer = self._handle_max_steps_reached(task, images, step_start_time)
yield memory_step
yield handle_agent_output_types(final_answer)

def _create_memory_step(self, step_start_time: float, images: List[str] | None) -> ActionStep:
return ActionStep(step_number=self.step_number, start_time=step_start_time, observations_images=images)

def _execute_step(self, task: str, memory_step: ActionStep) -> Union[None, Any]:
if self.planning_interval is not None and self.step_number % self.planning_interval == 1:
self.planning_step(task, is_first_step=(self.step_number == 1), step=self.step_number)
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
final_answer = self.step(memory_step)
if final_answer is not None and self.final_answer_checks:
self._validate_final_answer(final_answer)
return final_answer

def _validate_final_answer(self, final_answer: Any):
for check_function in self.final_answer_checks:
try:
assert check_function(final_answer, self.memory)
except Exception as e:
raise AgentError(f"Check {check_function.__name__} failed with error: {e}", self.logger)

def _finalize_step(self, memory_step: ActionStep, step_start_time: float):
memory_step.end_time = time.time()
memory_step.duration = memory_step.end_time - step_start_time
self.memory.steps.append(memory_step)
for callback in self.step_callbacks:
callback(memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
memory_step, agent=self
)

def _handle_max_steps_reached(self, task: str, images: List[str], step_start_time: float) -> Any:
final_answer = self.provide_final_answer(task, images)
final_memory_step = ActionStep(
step_number=self.step_number, error=AgentMaxStepsError("Reached max steps.", self.logger)
)
final_memory_step.action_output = final_answer
final_memory_step.end_time = time.time()
final_memory_step.duration = final_memory_step.end_time - step_start_time
self.memory.steps.append(final_memory_step)
for callback in self.step_callbacks:
callback(final_memory_step) if len(inspect.signature(callback).parameters) == 1 else callback(
final_memory_step, agent=self
)
return final_answer
Member

Also nice refactoring here!
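The final-answer validation pattern from `_validate_final_answer` above can be sketched standalone. `AgentError` and `is_nonempty` below are stand-ins introduced for illustration; the real `AgentError` lives in smolagents and takes a logger argument.

```python
class AgentError(Exception):
    """Stand-in for smolagents' AgentError (without the logger argument)."""

def validate_final_answer(final_answer, memory, checks):
    # Each check returns a truthy value on success; any failure
    # (falsy return or raised exception) is surfaced as an AgentError.
    for check_function in checks:
        try:
            assert check_function(final_answer, memory)
        except Exception as e:
            raise AgentError(f"Check {check_function.__name__} failed with error: {e}")

def is_nonempty(answer, memory):
    return bool(answer)

validate_final_answer("42", None, [is_nonempty])  # passes silently
```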

Comment on lines 263 to 363
facts_message, plan_message = (
self._generate_initial_plan(task) if is_first_step else self._generate_updated_plan(task, step)
)
self._record_planning_step(facts_message, plan_message)

def _generate_initial_plan(self, task: str) -> Tuple[ChatMessage, ChatMessage]:
message_prompt_facts = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["initial_facts"]}],
}
message_prompt_task = {
"role": MessageRole.USER,
"content": [{"type": "text", "text": textwrap.dedent(f"Here is the task:\n```\n{task}\n```\nNow begin!")}],
}
input_messages = [message_prompt_facts, message_prompt_task]
facts_message = self.model(input_messages)

message_prompt_plan = {
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["initial_plan"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"answer_facts": facts_message.content,
},
),
}
],
}
plan_message = self.model([message_prompt_plan], stop_sequences=["<end_plan>"])
return facts_message, plan_message

def _generate_updated_plan(self, task: str, step: int) -> Tuple[ChatMessage, ChatMessage]:
memory_messages = self.write_memory_to_messages()[1:]
facts_update_pre = {
"role": MessageRole.SYSTEM,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_pre_messages"]}],
}
facts_update_post = {
"role": MessageRole.USER,
"content": [{"type": "text", "text": self.prompt_templates["planning"]["update_facts_post_messages"]}],
}
facts_message = self.model([facts_update_pre] + memory_messages + [facts_update_post])

update_plan_pre = {
"role": MessageRole.SYSTEM,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_pre_messages"], variables={"task": task}
),
}
],
}
update_plan_post = {
"role": MessageRole.USER,
"content": [
{
"type": "text",
"text": populate_template(
self.prompt_templates["planning"]["update_plan_post_messages"],
variables={
"task": task,
"tools": self.tools,
"managed_agents": self.managed_agents,
"facts_update": facts_message.content,
"remaining_steps": (self.max_steps - step),
},
),
}
],
}
plan_message = self.model(
[update_plan_pre] + memory_messages + [update_plan_post], stop_sequences=["<end_plan>"]
)
return facts_message, plan_message

def _record_planning_step(self, facts_message: ChatMessage, plan_message: ChatMessage):
final_plan = textwrap.dedent(
f"""I still need to solve the task I was given:\n```\n{self.task}\n```\n\nHere is my new/updated plan of action to solve the task:\n```\n{plan_message.content}\n```"""
)
final_facts = textwrap.dedent(
f"""Here is the updated list of the facts that I know:\n```\n{facts_message.content}\n```"""
)
self.memory.steps.append(
PlanningStep(
model_input_messages=self.input_messages,
plan=final_plan,
facts=final_facts,
model_output_message_plan=plan_message,
model_output_message_facts=facts_message,
)
)
self.logger.log(Rule("[bold]Updated plan", style="orange"), Text(final_plan), level=LogLevel.INFO)
Member

This is also a nice refactoring. Please take into account that some modifications have been incorporated into the main branch; you will need to incorporate them here as well when merging the main branch and resolving the conflicts.

For example, in _generate_initial_plan, the initial facts have now been merged into a single user role (no system role anymore)
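For illustration only, the merged message might look like the following sketch. The strings are placeholders; the real text comes from `prompt_templates["planning"]["initial_facts"]` and the task.

```python
# Hypothetical sketch: the initial-facts prompt and the task now ship as a
# single user-role message instead of a separate system message.
initial_facts_template = "List the facts you know about the task."  # placeholder
task = "example task"  # placeholder

message_prompt_facts = {
    "role": "user",
    "content": [
        {"type": "text", "text": f"{initial_facts_template}\nHere is the task:\n```\n{task}\n```\nNow begin!"}
    ],
}
```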

Contributor Author

@albertvillanova I've made a lot of changes to this PR, and with new commits and test updates, it’s getting even larger. I want to ensure it's manageable for review and doesn’t create too much overhead.

Would you prefer I continue refining this PR, or would it be better to break it into smaller PRs to make reviewing easier? I’m happy to go either way—just want to do what’s most helpful for the project. Let me know your thoughts. Thanks for your time!

Member

No worries, it is OK to continue with this PR now that we have already started...

@albertvillanova
Member

Oh, it seems you forgot to run the quality checks: https://github.com/huggingface/smolagents/blob/main/CONTRIBUTING.md

@albertvillanova
Member

albertvillanova commented Feb 18, 2025

The CI tests are red as well. Could you please address them?

@@ -968,6 +968,185 @@ def __call__(
# Test that visualization works
manager_code_agent.visualize()

<<<<<<< HEAD
Member

This is a remnant of a conflict that needs to be fixed.

self.memory.steps.append(memory_step)
for callback in self.step_callbacks:
# For compatibility with old callbacks that don't take the agent as an argument
if len(inspect.signature(callback).parameters) == 1:
Contributor

I think it is better to do

if "agent" in inspect.signature(callback).parameters:
    callback(memory_step, agent=self)
else:
    callback(memory_step)

Because there could have been callbacks with default args, for example:

def my_callback(memory_step, some_optional_param=None):
    pass

This check will treat it as a "new style" callback and fail. All we care about here is whether it has an agent kwarg, so let's check specifically for it.
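The suggested check can be sketched as a small dispatcher. `dispatch_callback` and the sample callbacks are hypothetical names introduced for illustration; only the `"agent" in inspect.signature(...).parameters` test mirrors the suggestion.

```python
import inspect

def dispatch_callback(callback, memory_step, agent):
    # Check specifically for an `agent` parameter instead of counting
    # parameters, so old-style callbacks with extra optional args still work.
    if "agent" in inspect.signature(callback).parameters:
        callback(memory_step, agent=agent)
    else:
        callback(memory_step)

calls = []

def old_style(memory_step, some_optional_param=None):
    calls.append(("old", memory_step))

def new_style(memory_step, agent=None):
    calls.append(("new", memory_step, agent))

dispatch_callback(old_style, "step1", agent="A")  # dispatched without agent
dispatch_callback(new_style, "step2", agent="A")  # dispatched with agent=...
```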

)
final_memory_step.action_output = final_answer
final_memory_step.end_time = time.time()
final_memory_step.duration = memory_step.end_time - step_start_time
Contributor

Suggested change
final_memory_step.duration = memory_step.end_time - step_start_time
final_memory_step.duration = final_memory_step.end_time - step_start_time

A typo?

@albertvillanova
Member

albertvillanova commented Feb 20, 2025

There were issues with the resolution of the conflicts when merging the main branch.
I have fixed them and reverted some changes.

Member

@albertvillanova albertvillanova left a comment

Thanks.

@albertvillanova albertvillanova changed the title Test Suite Improvements and MultiStepAgent Refactoring Refactor MultiStepAgent and improve some tests Feb 20, 2025
@albertvillanova albertvillanova merged commit 0a9b7b4 into huggingface:main Feb 20, 2025
3 checks passed