Skip to content

Commit

Permalink
Make termination condition config declarative (#4984)
Browse files Browse the repository at this point in the history
* make termination condition declarative

* make all term conditions declarative

* make And/OrTermination top level objects in base

* add basic tests

* add tutorial notebook

* update tests and formatting

* update tests

* update declarative config with updated api.
  • Loading branch information
victordibia authored Jan 14, 2025
1 parent 9570e82 commit d883e3d
Show file tree
Hide file tree
Showing 6 changed files with 409 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
from ._handoff import Handoff
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
from ._termination import AndTerminationCondition, OrTerminationCondition, TerminatedException, TerminationCondition

__all__ = [
"ChatAgent",
"Response",
"Team",
"TerminatedException",
"TerminationCondition",
"AndTerminationCondition",
"OrTerminationCondition",
"TaskResult",
"TaskRunner",
"Handoff",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
from abc import ABC, abstractmethod
from typing import List, Sequence

from autogen_core import Component, ComponentBase, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ..messages import AgentEvent, ChatMessage, StopMessage


class TerminatedException(BaseException): ...


class TerminationCondition(ABC):
class TerminationCondition(ABC, ComponentBase[BaseModel]):
"""A stateful condition that determines when a conversation should be terminated.
A termination condition is a callable that takes a sequence of ChatMessage objects
Expand Down Expand Up @@ -43,6 +47,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_type = "termination"
# component_config_schema = BaseModel # type: ignore

@property
@abstractmethod
def terminated(self) -> bool:
Expand Down Expand Up @@ -72,14 +79,22 @@ async def reset(self) -> None:

def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an AND operation."""
return _AndTerminationCondition(self, other)
return AndTerminationCondition(self, other)

def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an OR operation."""
return _OrTerminationCondition(self, other)
return OrTerminationCondition(self, other)


class AndTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]


class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
component_config_schema = AndTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.AndTerminationCondition"

class _AndTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions
self._stop_messages: List[StopMessage] = []
Expand Down Expand Up @@ -111,8 +126,27 @@ async def reset(self) -> None:
await condition.reset()
self._stop_messages.clear()

def _to_config(self) -> AndTerminationConditionConfig:
"""Convert the AND termination condition to a config."""
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
"""Create an AND termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)


class OrTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]
"""List of termination conditions where any one being satisfied is sufficient."""


class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
component_config_schema = OrTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.OrTerminationCondition"

class _OrTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions

Expand All @@ -133,3 +167,13 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()

def _to_config(self) -> OrTerminationConditionConfig:
"""Convert the OR termination condition to a config."""
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
"""Create an OR termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
import time
from typing import List, Sequence

from autogen_core import Component
from pydantic import BaseModel
from typing_extensions import Self

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage


class StopMessageTermination(TerminationCondition):
class StopMessageTerminationConfig(BaseModel):
pass


class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
"""Terminate the conversation if a StopMessage is received."""

component_type = "termination"
component_config_schema = StopMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.StopMessageTermination"

def __init__(self) -> None:
self._terminated = False

Expand All @@ -27,14 +39,29 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> StopMessageTerminationConfig:
return StopMessageTerminationConfig()

@classmethod
def _from_config(cls, config: StopMessageTerminationConfig) -> Self:
return cls()


class MaxMessageTermination(TerminationCondition):
class MaxMessageTerminationConfig(BaseModel):
max_messages: int


class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]):
"""Terminate the conversation after a maximum number of messages have been exchanged.
Args:
max_messages: The maximum number of messages allowed in the conversation.
"""

component_type = "termination"
component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"

def __init__(self, max_messages: int) -> None:
self._max_messages = max_messages
self._message_count = 0
Expand All @@ -57,14 +84,30 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._message_count = 0

def _to_config(self) -> MaxMessageTerminationConfig:
return MaxMessageTerminationConfig(max_messages=self._max_messages)

@classmethod
def _from_config(cls, config: MaxMessageTerminationConfig) -> Self:
return cls(max_messages=config.max_messages)


class TextMentionTermination(TerminationCondition):
class TextMentionTerminationConfig(BaseModel):
text: str


class TextMentionTermination(TerminationCondition, Component[TextMentionTerminationConfig]):
"""Terminate the conversation if a specific text is mentioned.
Args:
text: The text to look for in the messages.
"""

component_type = "termination"
component_config_schema = TextMentionTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TextMentionTermination"

def __init__(self, text: str) -> None:
self._text = text
self._terminated = False
Expand All @@ -90,8 +133,21 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> TextMentionTerminationConfig:
return TextMentionTerminationConfig(text=self._text)

@classmethod
def _from_config(cls, config: TextMentionTerminationConfig) -> Self:
return cls(text=config.text)


class TokenUsageTermination(TerminationCondition):
class TokenUsageTerminationConfig(BaseModel):
max_total_token: int | None
max_prompt_token: int | None
max_completion_token: int | None


class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminationConfig]):
"""Terminate the conversation if a token usage limit is reached.
Args:
Expand All @@ -103,6 +159,10 @@ class TokenUsageTermination(TerminationCondition):
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
"""

component_type = "termination"
component_config_schema = TokenUsageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TokenUsageTermination"

def __init__(
self,
max_total_token: int | None = None,
Expand Down Expand Up @@ -146,15 +206,38 @@ async def reset(self) -> None:
self._prompt_token_count = 0
self._completion_token_count = 0

def _to_config(self) -> TokenUsageTerminationConfig:
return TokenUsageTerminationConfig(
max_total_token=self._max_total_token,
max_prompt_token=self._max_prompt_token,
max_completion_token=self._max_completion_token,
)

@classmethod
def _from_config(cls, config: TokenUsageTerminationConfig) -> Self:
return cls(
max_total_token=config.max_total_token,
max_prompt_token=config.max_prompt_token,
max_completion_token=config.max_completion_token,
)


class HandoffTerminationConfig(BaseModel):
target: str

class HandoffTermination(TerminationCondition):

class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfig]):
"""Terminate the conversation if a :class:`~autogen_agentchat.messages.HandoffMessage`
with the given target is received.
Args:
target (str): The target of the handoff message.
"""

component_type = "termination"
component_config_schema = HandoffTerminationConfig
component_provider_override = "autogen_agentchat.conditions.HandoffTermination"

def __init__(self, target: str) -> None:
self._terminated = False
self._target = target
Expand All @@ -177,14 +260,29 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> HandoffTerminationConfig:
return HandoffTerminationConfig(target=self._target)

@classmethod
def _from_config(cls, config: HandoffTerminationConfig) -> Self:
return cls(target=config.target)

class TimeoutTermination(TerminationCondition):

class TimeoutTerminationConfig(BaseModel):
timeout_seconds: float


class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfig]):
"""Terminate the conversation after a specified duration has passed.
Args:
timeout_seconds: The maximum duration in seconds before terminating the conversation.
"""

component_type = "termination"
component_config_schema = TimeoutTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TimeoutTermination"

def __init__(self, timeout_seconds: float) -> None:
self._timeout_seconds = timeout_seconds
self._start_time = time.monotonic()
Expand All @@ -209,8 +307,19 @@ async def reset(self) -> None:
self._start_time = time.monotonic()
self._terminated = False

def _to_config(self) -> TimeoutTerminationConfig:
return TimeoutTerminationConfig(timeout_seconds=self._timeout_seconds)

@classmethod
def _from_config(cls, config: TimeoutTerminationConfig) -> Self:
return cls(timeout_seconds=config.timeout_seconds)

class ExternalTermination(TerminationCondition):

class ExternalTerminationConfig(BaseModel):
pass


class ExternalTermination(TerminationCondition, Component[ExternalTerminationConfig]):
"""A termination condition that is externally controlled
by calling the :meth:`set` method.
Expand All @@ -230,6 +339,10 @@ class ExternalTermination(TerminationCondition):
"""

component_type = "termination"
component_config_schema = ExternalTerminationConfig
component_provider_override = "autogen_agentchat.conditions.ExternalTermination"

def __init__(self) -> None:
self._terminated = False
self._setted = False
Expand All @@ -254,8 +367,19 @@ async def reset(self) -> None:
self._terminated = False
self._setted = False

def _to_config(self) -> ExternalTerminationConfig:
return ExternalTerminationConfig()

@classmethod
def _from_config(cls, config: ExternalTerminationConfig) -> Self:
return cls()


class SourceMatchTermination(TerminationCondition):
class SourceMatchTerminationConfig(BaseModel):
sources: List[str]


class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminationConfig]):
"""Terminate the conversation after a specific source responds.
Args:
Expand All @@ -265,6 +389,10 @@ class SourceMatchTermination(TerminationCondition):
TerminatedException: If the termination condition has already been reached.
"""

component_type = "termination"
component_config_schema = SourceMatchTerminationConfig
component_provider_override = "autogen_agentchat.conditions.SourceMatchTermination"

def __init__(self, sources: List[str]) -> None:
self._sources = sources
self._terminated = False
Expand All @@ -286,3 +414,10 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe

async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> SourceMatchTerminationConfig:
return SourceMatchTerminationConfig(sources=self._sources)

@classmethod
def _from_config(cls, config: SourceMatchTerminationConfig) -> Self:
return cls(sources=config.sources)
Loading

0 comments on commit d883e3d

Please sign in to comment.