Skip to content

Commit

Permalink
Simplify component creation
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Jan 13, 2025
1 parent e1d0fae commit 1ce0bba
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 45 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import Any, List, Sequence

from autogen_core import Component, ComponentModel
from autogen_core._component_config import ComponentBase
from pydantic import BaseModel
from typing_extensions import Self

from ..messages import AgentEvent, ChatMessage, StopMessage


class TerminatedException(BaseException): ...


class TerminationCondition(ABC):
class TerminationCondition(ComponentBase[BaseModel], ABC):
"""A stateful condition that determines when a conversation should be terminated.
A termination condition is a callable that takes a sequence of ChatMessage objects
Expand Down Expand Up @@ -43,6 +48,9 @@ async def main() -> None:
asyncio.run(main())
"""

component_type = "termination"
# component_config_schema = BaseModel # type: ignore

@property
@abstractmethod
def terminated(self) -> bool:
Expand Down Expand Up @@ -72,14 +80,22 @@ async def reset(self) -> None:

def __and__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an AND operation."""
return _AndTerminationCondition(self, other)
return AndTerminationCondition(self, other)

def __or__(self, other: "TerminationCondition") -> "TerminationCondition":
"""Combine two termination conditions with an OR operation."""
return _OrTerminationCondition(self, other)
return OrTerminationCondition(self, other)


class AndTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]


class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]):
component_config_schema = AndTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.AndTerminationCondition"

class _AndTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions
self._stop_messages: List[StopMessage] = []
Expand Down Expand Up @@ -111,8 +127,27 @@ async def reset(self) -> None:
await condition.reset()
self._stop_messages.clear()

def _to_config(self) -> AndTerminationConditionConfig:
"""Convert the AND termination condition to a config."""
return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: AndTerminationConditionConfig) -> Self:
"""Create an AND termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)


class OrTerminationConditionConfig(BaseModel):
conditions: List[ComponentModel]
"""List of termination conditions where any one being satisfied is sufficient."""


class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]):
component_config_schema = OrTerminationConditionConfig
component_type = "termination"
component_provider_override = "autogen_agentchat.base.OrTerminationCondition"

class _OrTerminationCondition(TerminationCondition):
def __init__(self, *conditions: TerminationCondition) -> None:
self._conditions = conditions

Expand All @@ -133,3 +168,13 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe
async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()

def _to_config(self) -> OrTerminationConditionConfig:
"""Convert the OR termination condition to a config."""
return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions])

@classmethod
def _from_config(cls, config: OrTerminationConditionConfig) -> Self:
"""Create an OR termination condition from a config."""
conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions]
return cls(*conditions)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"\n",
"## Usage\n",
"\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"\n",
"### Loading a component from a config\n",
"\n",
Expand Down Expand Up @@ -52,7 +52,7 @@
"To add component functionality to a given class:\n",
"\n",
"1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n",
"2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n",
"2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n",
"\n",
"For example:"
]
Expand All @@ -63,15 +63,15 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core import Component\n",
"from autogen_core import Component, ComponentBase\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Config(BaseModel):\n",
" value: str\n",
"\n",
"\n",
"class MyComponent(Component[Config]):\n",
"class MyComponent(ComponentBase[Config], Component[Config]):\n",
" component_type = \"custom\"\n",
" component_config_schema = Config\n",
"\n",
Expand Down Expand Up @@ -129,7 +129,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions python/packages/autogen-core/src/autogen_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
Component,
ComponentBase,
ComponentFromConfig,
ComponentLoader,
ComponentModel,
Expand Down Expand Up @@ -116,6 +117,7 @@
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
"Component",
"ComponentBase",
"ComponentFromConfig",
"ComponentLoader",
"ComponentModel",
Expand Down
55 changes: 36 additions & 19 deletions python/packages/autogen-core/src/autogen_core/_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import importlib
import warnings
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload

from pydantic import BaseModel
Expand Down Expand Up @@ -50,9 +49,8 @@ def _type_to_provider_str(t: type) -> str:
}


class ComponentFromConfig(ABC, Generic[FromConfigT]):
class ComponentFromConfig(Generic[FromConfigT]):
@classmethod
@abstractmethod
def _from_config(cls, config: FromConfigT) -> Self:
"""Create a new instance of the component from a configuration object.
Expand All @@ -64,7 +62,7 @@ def _from_config(cls, config: FromConfigT) -> Self:
:meta public:
"""
...
raise NotImplementedError("This component does not support dumping to config")

@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
Expand All @@ -84,7 +82,7 @@ def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self
raise NotImplementedError("This component does not support loading from past versions")


class ComponentToConfig(ABC, Generic[ToConfigT]):
class ComponentToConfig(Generic[ToConfigT]):
"""The two methods a class must implement to be a component.
Args:
Expand All @@ -98,7 +96,6 @@ class ComponentToConfig(ABC, Generic[ToConfigT]):
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""

@abstractmethod
def _to_config(self) -> ToConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Expand All @@ -107,7 +104,7 @@ def _to_config(self) -> ToConfigT:
:meta public:
"""
...
raise NotImplementedError("This component does not support dumping to config")

def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.
Expand Down Expand Up @@ -258,22 +255,25 @@ class ComponentSchemaType(Generic[ConfigT]):
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)

# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
)
if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
stacklevel=2,
)


class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...


class Component(
ComponentFromConfig[ConfigT],
ComponentToConfig[ConfigT],
ComponentSchemaType[ConfigT],
ComponentLoader,
Generic[ConfigT],
):
"""To create a component class, inherit from this class. Then implement two class variables:
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
- :py:attr:`component_type` - What is the logical type of the component.
Expand Down Expand Up @@ -307,10 +307,27 @@ def _from_config(cls, config: Config) -> MyComponent:
return cls(value=config.value)
"""

...
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)

if not is_component_class(cls):
warnings.warn(
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
stacklevel=2,
)


# Should never be used directly, only for type checking
class _ConcreteComponent(
ComponentFromConfig[ConfigT],
ComponentSchemaType[ConfigT],
ComponentToConfig[ConfigT],
ComponentLoader,
Generic[ConfigT],
): ...


def is_component_instance(cls: Any) -> TypeGuard[Component[BaseModel]]:
def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
return (
isinstance(cls, ComponentFromConfig)
and isinstance(cls, ComponentToConfig)
Expand All @@ -319,7 +336,7 @@ def is_component_instance(cls: Any) -> TypeGuard[Component[BaseModel]]:
)


def is_component_class(cls: type) -> TypeGuard[Type[Component[BaseModel]]]:
def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
return (
issubclass(cls, ComponentFromConfig)
and issubclass(cls, ComponentToConfig)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from abc import ABC, abstractmethod
from typing import Literal, Mapping, Optional, Sequence, TypeAlias

from pydantic import BaseModel
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated

from .. import CancellationToken
from .._component_config import ComponentLoader
from .._component_config import ComponentBase
from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage

Expand Down Expand Up @@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False):
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""


class ChatCompletionClient(ABC, ComponentLoader):
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
@abstractmethod
async def create(
Expand Down
16 changes: 8 additions & 8 deletions python/packages/autogen-core/tests/test_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict

import pytest
from autogen_core import Component, ComponentLoader, ComponentModel
from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel
from autogen_core._component_config import _type_to_provider_str # type: ignore
from autogen_core.models import ChatCompletionClient
from autogen_test_utils import MyInnerComponent, MyOuterComponent
Expand All @@ -16,7 +16,7 @@ class MyConfig(BaseModel):
info: str


class MyComponent(Component[MyConfig]):
class MyComponent(ComponentBase[MyConfig], Component[MyConfig]):
component_config_schema = MyConfig
component_type = "custom"

Expand Down Expand Up @@ -95,7 +95,7 @@ def test_cannot_import_locals() -> None:
class InvalidModelClientConfig(BaseModel):
info: str

class MyInvalidModelClient(Component[InvalidModelClientConfig]):
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
component_config_schema = InvalidModelClientConfig
component_type = "model"

Expand All @@ -119,7 +119,7 @@ class InvalidModelClientConfig(BaseModel):
info: str


class MyInvalidModelClient(Component[InvalidModelClientConfig]):
class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
component_config_schema = InvalidModelClientConfig
component_type = "model"

Expand All @@ -143,7 +143,7 @@ def test_type_error_on_creation() -> None:

with pytest.warns(UserWarning):

class MyInvalidMissingAttrs(Component[InvalidModelClientConfig]):
class MyInvalidMissingAttrs(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]):
def __init__(self, info: str):
self.info = info

Expand Down Expand Up @@ -189,7 +189,7 @@ def test_config_optional_values() -> None:
assert component.__class__ == MyComponent


class ConfigProviderOverrided(Component[MyConfig]):
class ConfigProviderOverrided(ComponentBase[MyConfig], Component[MyConfig]):
component_provider_override = "InvalidButStillOverridden"
component_config_schema = MyConfig
component_type = "custom"
Expand All @@ -215,7 +215,7 @@ class MyConfig2(BaseModel):
info2: str


class ComponentNonOneVersion(Component[MyConfig2]):
class ComponentNonOneVersion(ComponentBase[MyConfig2], Component[MyConfig2]):
component_config_schema = MyConfig2
component_version = 2
component_type = "custom"
Expand All @@ -231,7 +231,7 @@ def _from_config(cls, config: MyConfig2) -> Self:
return cls(info=config.info2)


class ComponentNonOneVersionWithUpgrade(Component[MyConfig2]):
class ComponentNonOneVersionWithUpgrade(ComponentBase[MyConfig2], Component[MyConfig2]):
component_config_schema = MyConfig2
component_version = 2
component_type = "custom"
Expand Down
Loading

0 comments on commit 1ce0bba

Please sign in to comment.