Skip to content

Commit

Permalink
feat: add extra_create_args to AssistantAgent for model client custom…
Browse files Browse the repository at this point in the history
…ization
  • Loading branch information
gagb committed Jan 27, 2025
1 parent e582072 commit 1e94783
Showing 1 changed file with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
system_message: str | None = None
reflect_on_tool_use: bool
tool_call_summary_format: str
extra_create_args: Mapping[str, Any] | None = None


class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Expand Down Expand Up @@ -147,6 +148,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Available variables: `{tool_name}`, `{arguments}`, `{result}`.
For example, `"{tool_name}: {result}"` will create a summary like `"tool_name: result"`.
memory (Sequence[Memory] | None, optional): The memory store to use for the agent. Defaults to `None`.
extra_create_args (Mapping[str, Any] | None, optional): Additional arguments to pass to the model client during the create method call. Defaults to `None`.
Raises:
ValueError: If tool names are not unique.
Expand Down Expand Up @@ -271,6 +273,7 @@ def __init__(
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
extra_create_args: Mapping[str, Any] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
Expand Down Expand Up @@ -337,6 +340,7 @@ def __init__(
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
self._extra_create_args = extra_create_args or {}

@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
Expand Down Expand Up @@ -384,7 +388,7 @@ async def on_messages_stream(
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
llm_messages, tools=self._tools + self._handoff_tools, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
)

# Add the response to the model context.
Expand Down Expand Up @@ -465,7 +469,9 @@ async def on_messages_stream(
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
model_result = await self._model_client.create(
llm_messages, extra_create_args=self._extra_create_args, cancellation_token=cancellation_token
)
assert isinstance(model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
Expand Down Expand Up @@ -540,6 +546,7 @@ def _to_config(self) -> AssistantAgentConfig:
else None,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
extra_create_args=self._extra_create_args,
)

@classmethod
Expand All @@ -555,4 +562,5 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
system_message=config.system_message,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
extra_create_args=config.extra_create_args,
)

0 comments on commit 1e94783

Please sign in to comment.