From 1ce0bba6c7510820ad3ee5066d83a24df6e8f06f Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 13 Jan 2025 15:34:25 -0500 Subject: [PATCH] Simplify component creation --- .../autogen_agentchat/base/_termination.py | 57 +++++++++++++++++-- .../framework/component-config.ipynb | 10 ++-- .../autogen-core/src/autogen_core/__init__.py | 2 + .../src/autogen_core/_component_config.py | 55 +++++++++++------- .../src/autogen_core/models/_model_client.py | 5 +- .../tests/test_component_config.py | 16 +++--- .../src/autogen_ext/auth/azure/__init__.py | 4 +- .../src/autogen_test_utils/__init__.py | 7 ++- 8 files changed, 111 insertions(+), 45 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py index 8975c75aad12..bc436687fb84 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py @@ -1,6 +1,11 @@ import asyncio from abc import ABC, abstractmethod -from typing import List, Sequence +from typing import Any, List, Sequence + +from autogen_core import Component, ComponentModel +from autogen_core._component_config import ComponentBase +from pydantic import BaseModel +from typing_extensions import Self from ..messages import AgentEvent, ChatMessage, StopMessage @@ -8,7 +13,7 @@ class TerminatedException(BaseException): ... -class TerminationCondition(ABC): +class TerminationCondition(ComponentBase[BaseModel], ABC): """A stateful condition that determines when a conversation should be terminated. A termination condition is a callable that takes a sequence of ChatMessage objects @@ -43,6 +48,9 @@ async def main() -> None: asyncio.run(main()) """ + component_type = "termination" + # component_config_schema = BaseModel # type: ignore + @property @abstractmethod def terminated(self) -> bool: @@ -72,14 +80,22 @@ async def reset(self) -> None: def __and__(self, other: "TerminationCondition") -> "TerminationCondition": """Combine two termination conditions with an AND operation.""" - return _AndTerminationCondition(self, other) + return AndTerminationCondition(self, other) def __or__(self, other: "TerminationCondition") -> "TerminationCondition": """Combine two termination conditions with an OR operation.""" - return _OrTerminationCondition(self, other) + return OrTerminationCondition(self, other) + +class AndTerminationConditionConfig(BaseModel): + conditions: List[ComponentModel] + + +class AndTerminationCondition(TerminationCondition, Component[AndTerminationConditionConfig]): + component_config_schema = AndTerminationConditionConfig + component_type = "termination" + component_provider_override = "autogen_agentchat.base.AndTerminationCondition" -class _AndTerminationCondition(TerminationCondition): def __init__(self, *conditions: TerminationCondition) -> None: self._conditions = conditions self._stop_messages: List[StopMessage] = [] @@ -111,8 +127,27 @@ async def reset(self) -> None: await condition.reset() self._stop_messages.clear() + def _to_config(self) -> AndTerminationConditionConfig: + """Convert the AND termination condition to a config.""" + return AndTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions]) + + @classmethod + def _from_config(cls, config: AndTerminationConditionConfig) -> Self: + """Create an AND termination condition from a config.""" + conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions] + return cls(*conditions) + + +class OrTerminationConditionConfig(BaseModel): + conditions: List[ComponentModel] + """List of termination conditions where any one being satisfied is sufficient.""" + + +class OrTerminationCondition(TerminationCondition, Component[OrTerminationConditionConfig]): + component_config_schema = OrTerminationConditionConfig + component_type = "termination" + component_provider_override = "autogen_agentchat.base.OrTerminationCondition" -class _OrTerminationCondition(TerminationCondition): def __init__(self, *conditions: TerminationCondition) -> None: self._conditions = conditions @@ -133,3 +168,13 @@ async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMe async def reset(self) -> None: for condition in self._conditions: await condition.reset() + + def _to_config(self) -> OrTerminationConditionConfig: + """Convert the OR termination condition to a config.""" + return OrTerminationConditionConfig(conditions=[condition.dump_component() for condition in self._conditions]) + + @classmethod + def _from_config(cls, config: OrTerminationConditionConfig) -> Self: + """Create an OR termination condition from a config.""" + conditions = [TerminationCondition.load_component(condition_model) for condition_model in config.conditions] + return cls(*conditions) diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb index 6d21c33d2a87..1f335d0b59e8 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb @@ -20,7 +20,7 @@ "\n", "## Usage\n", "\n", - "If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n", + "If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n", "\n", "### Loading a component from a config\n", "\n", @@ -52,7 +52,7 @@ "To add component functionality to a given class:\n", "\n", "1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n", - "2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n", + "2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n", "\n", "For example:" ] @@ -63,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen_core import Component\n", + "from autogen_core import Component, ComponentBase\n", "from pydantic import BaseModel\n", "\n", "\n", @@ -71,7 +71,7 @@ " value: str\n", "\n", "\n", - "class MyComponent(Component[Config]):\n", + "class MyComponent(ComponentBase[Config], Component[Config]):\n", " component_type = \"custom\"\n", " component_config_schema = Config\n", "\n", @@ -129,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index df79231eaa7f..478ecc422e03 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -14,6 +14,7 @@ from ._closure_agent import ClosureAgent, ClosureContext from ._component_config import ( Component, + ComponentBase, ComponentFromConfig, ComponentLoader, ComponentModel, @@ -116,6 +117,7 @@ "EVENT_LOGGER_NAME", "TRACE_LOGGER_NAME", "Component", + "ComponentBase", "ComponentFromConfig", "ComponentLoader", "ComponentModel", diff --git a/python/packages/autogen-core/src/autogen_core/_component_config.py b/python/packages/autogen-core/src/autogen_core/_component_config.py index 629065b4fe2d..f64165703832 100644 --- a/python/packages/autogen-core/src/autogen_core/_component_config.py +++ b/python/packages/autogen-core/src/autogen_core/_component_config.py @@ -2,7 +2,6 @@ import importlib import warnings -from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload from pydantic import BaseModel @@ -50,9 +49,8 @@ def _type_to_provider_str(t: type) -> str: } -class ComponentFromConfig(ABC, Generic[FromConfigT]): +class ComponentFromConfig(Generic[FromConfigT]): @classmethod - @abstractmethod def _from_config(cls, config: FromConfigT) -> Self: """Create a new instance of the component from a configuration object. @@ -64,7 +62,7 @@ def _from_config(cls, config: FromConfigT) -> Self: :meta public: """ - ... + raise NotImplementedError("This component does not support dumping to config") @classmethod def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self: @@ -84,7 +82,7 @@ def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self raise NotImplementedError("This component does not support loading from past versions") -class ComponentToConfig(ABC, Generic[ToConfigT]): +class ComponentToConfig(Generic[ToConfigT]): """The two methods a class must implement to be a component. Args: @@ -98,7 +96,6 @@ class ComponentToConfig(ABC, Generic[ToConfigT]): component_provider_override: ClassVar[str | None] = None """Override the provider string for the component. This should be used to prevent internal module names being a part of the module name.""" - @abstractmethod def _to_config(self) -> ToConfigT: """Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance. @@ -107,7 +104,7 @@ def _to_config(self) -> ToConfigT: :meta public: """ - ... + raise NotImplementedError("This component does not support dumping to config") def dump_component(self) -> ComponentModel: """Dump the component to a model that can be loaded back in. @@ -258,22 +255,25 @@ class ComponentSchemaType(Generic[ConfigT]): def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) - # TODO: validate provider is loadable - for var in cls.required_class_vars: - if not hasattr(cls, var): - warnings.warn( - f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2 - ) + if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent": + # TODO: validate provider is loadable + for var in cls.required_class_vars: + if not hasattr(cls, var): + warnings.warn( + f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", + stacklevel=2, + ) + + +class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ... class Component( ComponentFromConfig[ConfigT], - ComponentToConfig[ConfigT], ComponentSchemaType[ConfigT], - ComponentLoader, Generic[ConfigT], ): - """To create a component class, inherit from this class. Then implement two class variables: + """To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables: - :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component. - :py:attr:`component_type` - What is the logical type of the component. @@ -307,10 +307,27 @@ def _from_config(cls, config: Config) -> MyComponent: return cls(value=config.value) """ - ... + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + + if not is_component_class(cls): + warnings.warn( + f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.", + stacklevel=2, + ) + + +# Should never be used directly, only for type checking +class _ConcreteComponent( + ComponentFromConfig[ConfigT], + ComponentSchemaType[ConfigT], + ComponentToConfig[ConfigT], + ComponentLoader, + Generic[ConfigT], +): ... -def is_component_instance(cls: Any) -> TypeGuard[Component[BaseModel]]: +def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]: return ( isinstance(cls, ComponentFromConfig) and isinstance(cls, ComponentToConfig) @@ -319,7 +336,7 @@ def is_component_instance(cls: Any) -> TypeGuard[Component[BaseModel]]: ) -def is_component_class(cls: type) -> TypeGuard[Type[Component[BaseModel]]]: +def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]: return ( issubclass(cls, ComponentFromConfig) and issubclass(cls, ComponentToConfig) diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py index a952ad43458c..356fad5487ca 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py +++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py @@ -4,10 +4,11 @@ from abc import ABC, abstractmethod from typing import Literal, Mapping, Optional, Sequence, TypeAlias +from pydantic import BaseModel from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated from .. import CancellationToken -from .._component_config import ComponentLoader +from .._component_config import ComponentBase from ..tools import Tool, ToolSchema from ._types import CreateResult, LLMMessage, RequestUsage @@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False): """Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family.""" -class ChatCompletionClient(ABC, ComponentLoader): +class ChatCompletionClient(ComponentBase[BaseModel], ABC): # Caching has to be handled internally as they can depend on the create args that were stored in the constructor @abstractmethod async def create( diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py index fe726227acc1..d59fde59c1b6 100644 --- a/python/packages/autogen-core/tests/test_component_config.py +++ b/python/packages/autogen-core/tests/test_component_config.py @@ -4,7 +4,7 @@ from typing import Any, Dict import pytest -from autogen_core import Component, ComponentLoader, ComponentModel +from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel from autogen_core._component_config import _type_to_provider_str # type: ignore from autogen_core.models import ChatCompletionClient from autogen_test_utils import MyInnerComponent, MyOuterComponent @@ -16,7 +16,7 @@ class MyConfig(BaseModel): info: str -class MyComponent(Component[MyConfig]): +class MyComponent(ComponentBase[MyConfig], Component[MyConfig]): component_config_schema = MyConfig component_type = "custom" @@ -95,7 +95,7 @@ def test_cannot_import_locals() -> None: class InvalidModelClientConfig(BaseModel): info: str - class MyInvalidModelClient(Component[InvalidModelClientConfig]): + class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): component_config_schema = InvalidModelClientConfig component_type = "model" @@ -119,7 +119,7 @@ class InvalidModelClientConfig(BaseModel): info: str -class MyInvalidModelClient(Component[InvalidModelClientConfig]): +class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): component_config_schema = InvalidModelClientConfig component_type = "model" @@ -143,7 +143,7 @@ def test_type_error_on_creation() -> None: with pytest.warns(UserWarning): - class MyInvalidMissingAttrs(Component[InvalidModelClientConfig]): + class MyInvalidMissingAttrs(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): def __init__(self, info: str): self.info = info @@ -189,7 +189,7 @@ def test_config_optional_values() -> None: assert component.__class__ == MyComponent -class ConfigProviderOverrided(Component[MyConfig]): +class ConfigProviderOverrided(ComponentBase[MyConfig], Component[MyConfig]): component_provider_override = "InvalidButStillOverridden" component_config_schema = MyConfig component_type = "custom" @@ -215,7 +215,7 @@ class MyConfig2(BaseModel): info2: str -class ComponentNonOneVersion(Component[MyConfig2]): +class ComponentNonOneVersion(ComponentBase[MyConfig2], Component[MyConfig2]): component_config_schema = MyConfig2 component_version = 2 component_type = "custom" @@ -231,7 +231,7 @@ def _from_config(cls, config: MyConfig2) -> Self: return cls(info=config.info2) -class ComponentNonOneVersionWithUpgrade(Component[MyConfig2]): +class ComponentNonOneVersionWithUpgrade(ComponentBase[MyConfig2], Component[MyConfig2]): component_config_schema = MyConfig2 component_version = 2 component_type = "custom" diff --git a/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py index 607a3e01c3a4..6ddf5c83bef8 100644 --- a/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py @@ -1,6 +1,6 @@ from typing import List -from autogen_core import Component +from autogen_core import Component, ComponentBase from pydantic import BaseModel from typing_extensions import Self @@ -13,7 +13,7 @@ class TokenProviderConfig(BaseModel): scopes: List[str] -class AzureTokenProvider(Component[TokenProviderConfig]): +class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]): component_type = "token_provider" component_config_schema = TokenProviderConfig component_provider_override = "autogen_ext.models.openai.AzureTokenProvider" diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py index e3539ac9ede9..b917194b1d82 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py @@ -6,13 +6,14 @@ from autogen_core import ( BaseAgent, Component, + ComponentBase, + ComponentModel, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler, ) -from autogen_core._component_config import ComponentModel from pydantic import BaseModel @@ -76,7 +77,7 @@ class MyInnerConfig(BaseModel): inner_message: str -class MyInnerComponent(Component[MyInnerConfig]): +class MyInnerComponent(ComponentBase[MyInnerConfig], Component[MyInnerConfig]): component_config_schema = MyInnerConfig component_type = "custom" @@ -96,7 +97,7 @@ class MyOuterConfig(BaseModel): inner_class: ComponentModel -class MyOuterComponent(Component[MyOuterConfig]): +class MyOuterComponent(ComponentBase[MyOuterConfig], Component[MyOuterConfig]): component_config_schema = MyOuterConfig component_type = "custom"