Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make termination condition config declarative #4984

Merged
merged 10 commits into from
Jan 14, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
from ._handoff import Handoff
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
from ._termination import AndTerminationCondition, OrTerminationCondition, TerminatedException, TerminationCondition

__all__ = [
"ChatAgent",
"Response",
"Team",
"TerminatedException",
"TerminationCondition",
"AndTerminationCondition",
"OrTerminationCondition",
"TaskResult",
"TaskRunner",
"Handoff",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
from abc import ABC, abstractmethod
from typing import List, Sequence

from autogen_core import Component, ComponentBase, ComponentModel
from pydantic import BaseModel
from typing_extensions import Self

from ..messages import AgentEvent, ChatMessage, StopMessage


class TerminatedException(BaseException): ...


class TerminationCondition(ABC):
class TerminationCondition(ABC, ComponentBase[BaseModel]):
"""A stateful condition that determines when a conversation should be terminated.

A termination condition is a callable that takes a sequence of ChatMessage objects
Expand Down Expand Up @@ -43,6 +47,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_type = "termination"
# component_config_schema = BaseModel # type: ignore

@property
@abstractmethod
def terminated(self) -> bool:
Expand Down Expand Up @@ -72,14 +79,22 @@ async def reset(self) -> None:

def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an AND operation."""
return _AndTerminationCondition(self, other)
return AndTerminationCondition(self, other)

def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an OR operation."""
return _OrTerminationCondition(self, other)
return OrTerminationCondition(self, other)


class AndTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]


class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
component_config_schema = AndTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.AndTerminationCondition"

victordibia marked this conversation as resolved.
Show resolved Hide resolved
class _AndTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions
self._stop_messages: List[StopMessage] = []
Expand Down Expand Up @@ -111,8 +126,27 @@ async def reset(self) -> None:
await condition.reset()
self._stop_messages.clear()

def _to_config(self) -> AndTerminationConditionConfig:
"""Convert the AND termination condition to a config."""
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
"""Create an AND termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)


class OrTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]
"""List of termination conditions where any one being satisfied is sufficient."""


class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
component_config_schema = OrTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.OrTerminationCondition"

class _OrTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions

Expand All @@ -133,3 +167,13 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()

def _to_config(self) -> OrTerminationConditionConfig:
"""Convert the OR termination condition to a config."""
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
"""Create an OR termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
import time
from typing import List, Sequence

from autogen_core import Component
from pydantic import BaseModel
from typing_extensions import Self

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage


class StopMessageTermination(TerminationCondition):
class StopMessageTerminationConfig(BaseModel):
pass


class StopMessageTermination(TerminationCondition, Component[StopMessageTerminationConfig]):
"""Terminate the conversation if a StopMessage is received."""

component_type = "termination"
component_config_schema = StopMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.StopMessageTermination"

def __init__(self) -> None:
self._terminated = False

Expand All @@ -27,14 +39,29 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> StopMessageTerminationConfig:
return StopMessageTerminationConfig()

@classmethod
def _from_config(cls, config: StopMessageTerminationConfig) -> Self:
return cls()


class MaxMessageTermination(TerminationCondition):
class MaxMessageTerminationConfig(BaseModel):
max_messages: int


class MaxMessageTermination(TerminationCondition, Component[MaxMessageTerminationConfig]):
"""Terminate the conversation after a maximum number of messages have been exchanged.

Args:
max_messages: The maximum number of messages allowed in the conversation.
"""

component_type = "termination"
component_config_schema = MaxMessageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.MaxMessageTermination"

def __init__(self, max_messages: int) -> None:
self._max_messages = max_messages
self._message_count = 0
Expand All @@ -57,14 +84,30 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._message_count = 0

def _to_config(self) -> MaxMessageTerminationConfig:
return MaxMessageTerminationConfig(max_messages=self._max_messages)

@classmethod
def _from_config(cls, config: MaxMessageTerminationConfig) -> Self:
return cls(max_messages=config.max_messages)


class TextMentionTermination(TerminationCondition):
class TextMentionTerminationConfig(BaseModel):
text: str


class TextMentionTermination(TerminationCondition, Component[TextMentionTerminationConfig]):
"""Terminate the conversation if a specific text is mentioned.


Args:
text: The text to look for in the messages.
"""

component_type = "termination"
component_config_schema = TextMentionTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TextMentionTermination"

def __init__(self, text: str) -> None:
self._text = text
self._terminated = False
Expand All @@ -90,8 +133,21 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> TextMentionTerminationConfig:
return TextMentionTerminationConfig(text=self._text)

@classmethod
def _from_config(cls, config: TextMentionTerminationConfig) -> Self:
return cls(text=config.text)


class TokenUsageTermination(TerminationCondition):
class TokenUsageTerminationConfig(BaseModel):
max_total_token: int | None
max_prompt_token: int | None
max_completion_token: int | None


class TokenUsageTermination(TerminationCondition, Component[TokenUsageTerminationConfig]):
"""Terminate the conversation if a token usage limit is reached.

Args:
Expand All @@ -103,6 +159,10 @@ class TokenUsageTermination(TerminationCondition):
ValueError: If none of max_total_token, max_prompt_token, or max_completion_token is provided.
"""

component_type = "termination"
component_config_schema = TokenUsageTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TokenUsageTermination"

def __init__(
self,
max_total_token: int | None = None,
Expand Down Expand Up @@ -146,15 +206,38 @@ async def reset(self) -> None:
self._prompt_token_count = 0
self._completion_token_count = 0

def _to_config(self) -> TokenUsageTerminationConfig:
return TokenUsageTerminationConfig(
max_total_token=self._max_total_token,
max_prompt_token=self._max_prompt_token,
max_completion_token=self._max_completion_token,
)

@classmethod
def _from_config(cls, config: TokenUsageTerminationConfig) -> Self:
return cls(
max_total_token=config.max_total_token,
max_prompt_token=config.max_prompt_token,
max_completion_token=config.max_completion_token,
)


class HandoffTerminationConfig(BaseModel):
target: str

class HandoffTermination(TerminationCondition):

class HandoffTermination(TerminationCondition, Component[HandoffTerminationConfig]):
"""Terminate the conversation if a :class:`~autogen_agentchat.messages.HandoffMessage`
with the given target is received.

Args:
target (str): The target of the handoff message.
"""

component_type = "termination"
component_config_schema = HandoffTerminationConfig
component_provider_override = "autogen_agentchat.conditions.HandoffTermination"

def __init__(self, target: str) -> None:
self._terminated = False
self._target = target
Expand All @@ -177,14 +260,29 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> HandoffTerminationConfig:
return HandoffTerminationConfig(target=self._target)

@classmethod
def _from_config(cls, config: HandoffTerminationConfig) -> Self:
return cls(target=config.target)

class TimeoutTermination(TerminationCondition):

class TimeoutTerminationConfig(BaseModel):
timeout_seconds: float


class TimeoutTermination(TerminationCondition, Component[TimeoutTerminationConfig]):
"""Terminate the conversation after a specified duration has passed.

Args:
timeout_seconds: The maximum duration in seconds before terminating the conversation.
"""

component_type = "termination"
component_config_schema = TimeoutTerminationConfig
component_provider_override = "autogen_agentchat.conditions.TimeoutTermination"

def __init__(self, timeout_seconds: float) -> None:
self._timeout_seconds = timeout_seconds
self._start_time = time.monotonic()
Expand All @@ -209,8 +307,19 @@ async def reset(self) -> None:
self._start_time = time.monotonic()
self._terminated = False

def _to_config(self) -> TimeoutTerminationConfig:
return TimeoutTerminationConfig(timeout_seconds=self._timeout_seconds)

@classmethod
def _from_config(cls, config: TimeoutTerminationConfig) -> Self:
return cls(timeout_seconds=config.timeout_seconds)

class ExternalTermination(TerminationCondition):

class ExternalTerminationConfig(BaseModel):
pass


class ExternalTermination(TerminationCondition, Component[ExternalTerminationConfig]):
"""A termination condition that is externally controlled
by calling the :meth:`set` method.

Expand All @@ -230,6 +339,10 @@ class ExternalTermination(TerminationCondition):

"""

component_type = "termination"
component_config_schema = ExternalTerminationConfig
component_provider_override = "autogen_agentchat.conditions.ExternalTermination"

def __init__(self) -> None:
self._terminated = False
self._setted = False
Expand All @@ -254,8 +367,19 @@ async def reset(self) -> None:
self._terminated = False
self._setted = False

def _to_config(self) -> ExternalTerminationConfig:
return ExternalTerminationConfig()

@classmethod
def _from_config(cls, config: ExternalTerminationConfig) -> Self:
return cls()


class SourceMatchTermination(TerminationCondition):
class SourceMatchTerminationConfig(BaseModel):
sources: List[str]


class SourceMatchTermination(TerminationCondition, Component[SourceMatchTerminationConfig]):
"""Terminate the conversation after a specific source responds.

Args:
Expand All @@ -265,6 +389,10 @@ class SourceMatchTermination(TerminationCondition):
TerminatedException: If the termination condition has already been reached.
"""

component_type = "termination"
component_config_schema = SourceMatchTerminationConfig
component_provider_override = "autogen_agentchat.conditions.SourceMatchTermination"

def __init__(self, sources: List[str]) -> None:
self._sources = sources
self._terminated = False
Expand All @@ -286,3 +414,10 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe

async def reset(self) -> None:
self._terminated = False

def _to_config(self) -> SourceMatchTerminationConfig:
return SourceMatchTerminationConfig(sources=self._sources)

@classmethod
def _from_config(cls, config: SourceMatchTerminationConfig) -> Self:
return cls(sources=config.sources)
Loading
Loading