Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add extra_create_args to AssistantAgent for model client custom… #5213

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SystemMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -64,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
system_message: str | None = None
reflect_on_tool_use: bool
tool_call_summary_format: str
extra_create_args: Mapping[str, Any] | None = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is correct, since Any could not be serialized.



class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Expand Down Expand Up @@ -147,6 +148,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Available variables: `{tool_name}`, `{arguments}`, `{result}`.
For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.
memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.
extra_create_args (Mapping[str, Any] | None, optional): Additional arguments to pass to the model client during the create method call. Defaults to `None`.

Raises:
ValueError: If tool names are not unique.
Expand Down Expand Up @@ -271,6 +273,7 @@ def __init__(
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
extra_create_args: Mapping[str, Any] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
Expand Down Expand Up @@ -337,6 +340,7 @@ def __init__(
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
self._extra_create_args = extra_create_args or {}

@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
Expand Down Expand Up @@ -384,7 +388,10 @@ async def on_messages_stream(
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
llm_messages,
tools=self._tools + self._handoff_tools,
extra_create_args=self._extra_create_args,
cancellation_token=cancellation_token,
)

# Add the response to the model context.
Expand Down Expand Up @@ -465,7 +472,9 @@ async def on_messages_stream(
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
model_result = await self._model_client.create(
llm_messages, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
)
assert isinstance(model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
Expand Down Expand Up @@ -540,6 +549,7 @@ def _to_config(self) -> AssistantAgentConfig:
else None,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
extra_create_args=self._extra_create_args,
)

@classmethod
Expand All @@ -555,4 +565,5 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
system_message=config.system_message,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
extra_create_args=config.extra_create_args,
)
Loading