diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb index 6d21c33d2a87..1f335d0b59e8 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/component-config.ipynb @@ -20,7 +20,7 @@ "\n", "## Usage\n", "\n", - "If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n", + "If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n", "\n", "### Loading a component from a config\n", "\n", @@ -52,7 +52,7 @@ "To add component functionality to a given class:\n", "\n", "1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n", - "2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n", + "2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n", "\n", "For example:" ] @@ -63,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen_core import Component\n", + "from autogen_core import Component, ComponentBase\n", "from pydantic import BaseModel\n", "\n", "\n", @@ -71,7 +71,7 @@ " value: str\n", "\n", "\n", - "class MyComponent(Component[Config]):\n", + "class MyComponent(ComponentBase[Config], Component[Config]):\n", " component_type = \"custom\"\n", " component_config_schema = Config\n", "\n", @@ -129,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.5" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index c9d12872dd9d..478ecc422e03 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -14,10 +14,15 @@ from ._closure_agent import ClosureAgent, ClosureContext from ._component_config import ( Component, - ComponentConfigImpl, + ComponentBase, + ComponentFromConfig, ComponentLoader, ComponentModel, + ComponentSchemaType, + ComponentToConfig, ComponentType, + is_component_class, + is_component_instance, ) from ._constants import ( EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS, @@ -112,10 +117,15 @@ "EVENT_LOGGER_NAME", "TRACE_LOGGER_NAME", "Component", + "ComponentBase", + "ComponentFromConfig", "ComponentLoader", - "ComponentConfigImpl", "ComponentModel", + "ComponentSchemaType", + "ComponentToConfig", "ComponentType", + "is_component_class", + "is_component_instance", "DropMessage", "InterventionHandler", "DefaultInterventionHandler", diff --git a/python/packages/autogen-core/src/autogen_core/_component_config.py b/python/packages/autogen-core/src/autogen_core/_component_config.py index 1045282921f2..f64165703832 100644 --- a/python/packages/autogen-core/src/autogen_core/_component_config.py +++ b/python/packages/autogen-core/src/autogen_core/_component_config.py @@ -2,13 +2,15 @@ import importlib import warnings -from typing import Any, ClassVar, Dict, Generic, List, Literal, Protocol, Type, cast, overload, runtime_checkable +from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload from pydantic import BaseModel from typing_extensions import Self, TypeVar ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str ConfigT = TypeVar("ConfigT", bound=BaseModel) +FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True) +ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True) T = TypeVar("T", bound=BaseModel, covariant=True) @@ -47,36 +49,9 @@ def _type_to_provider_str(t: type) -> str: } -@runtime_checkable -class ComponentConfigImpl(Protocol[ConfigT]): - # Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context) - component_config_schema: Type[ConfigT] - """The Pydantic model class which represents the configuration of the component.""" - component_type: ClassVar[ComponentType] - """The logical type of the component.""" - component_version: ClassVar[int] = 1 - """The version of the component, if schema incompatibilities are introduced this should be updated.""" - component_provider_override: ClassVar[str | None] = None - """Override the provider string for the component. This should be used to prevent internal module names being a part of the module name.""" - - """The two methods a class must implement to be a component. - - Args: - Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`. - """ - - def _to_config(self) -> ConfigT: - """Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance. - - Returns: - T: The configuration of the component. - - :meta public: - """ - ... - +class ComponentFromConfig(Generic[FromConfigT]): @classmethod - def _from_config(cls, config: ConfigT) -> Self: + def _from_config(cls, config: FromConfigT) -> Self: """Create a new instance of the component from a configuration object. Args: @@ -87,7 +62,7 @@ def _from_config(cls, config: ConfigT) -> Self: :meta public: """ - ... + raise NotImplementedError("This component does not support dumping to config") @classmethod def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self: @@ -104,7 +79,69 @@ def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self :meta public: """ - raise NotImplementedError() + raise NotImplementedError("This component does not support loading from past versions") + + +class ComponentToConfig(Generic[ToConfigT]): + """The two methods a class must implement to be a component. + + Args: + Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`. + """ + + component_type: ClassVar[ComponentType] + """The logical type of the component.""" + component_version: ClassVar[int] = 1 + """The version of the component, if schema incompatibilities are introduced this should be updated.""" + component_provider_override: ClassVar[str | None] = None + """Override the provider string for the component. This should be used to prevent internal module names being a part of the module name.""" + + def _to_config(self) -> ToConfigT: + """Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance. + + Returns: + T: The configuration of the component. + + :meta public: + """ + raise NotImplementedError("This component does not support dumping to config") + + def dump_component(self) -> ComponentModel: + """Dump the component to a model that can be loaded back in. + + Raises: + TypeError: If the component is a local class. + + Returns: + ComponentModel: The model representing the component. + """ + if self.component_provider_override is not None: + provider = self.component_provider_override + else: + provider = _type_to_provider_str(self.__class__) + # Warn if internal module name is used, + if "._" in provider: + warnings.warn( + "Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.", + stacklevel=2, + ) + + if "" in provider: + raise TypeError("Cannot dump component with local class") + + if not hasattr(self, "component_type"): + raise AttributeError("component_type not defined") + + obj_config = self._to_config().model_dump(exclude_none=True) + model = ComponentModel( + provider=provider, + component_type=self.component_type, + version=self.component_version, + component_version=self.component_version, + description=None, + config=obj_config, + ) + return model ExpectedType = TypeVar("ExpectedType") @@ -171,9 +208,9 @@ def load_component( module_path, class_name = output module = importlib.import_module(module_path) - component_class = cast(ComponentConfigImpl[BaseModel], module.__getattribute__(class_name)) + component_class = module.__getattribute__(class_name) - if not isinstance(component_class, ComponentConfigImpl): + if not is_component_class(component_class): raise TypeError("Invalid component class") # We need to check the schema is valid @@ -192,7 +229,7 @@ def load_component( f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented" ) from e else: - schema = component_class.component_config_schema + schema = component_class.component_config_schema # type: ignore validated_config = schema.model_validate(loaded_model.config) # We're allowed to use the private method here @@ -208,8 +245,35 @@ def load_component( return cast(ExpectedType, instance) -class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT]): - """To create a component class, inherit from this class. Then implement two class variables: +class ComponentSchemaType(Generic[ConfigT]): + # Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context) + component_config_schema: Type[ConfigT] + """The Pydantic model class which represents the configuration of the component.""" + + required_class_vars = ["component_config_schema", "component_type"] + + def __init_subclass__(cls, **kwargs: Any): + super().__init_subclass__(**kwargs) + + if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent": + # TODO: validate provider is loadable + for var in cls.required_class_vars: + if not hasattr(cls, var): + warnings.warn( + f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", + stacklevel=2, + ) + + +class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ... + + +class Component( + ComponentFromConfig[ConfigT], + ComponentSchemaType[ConfigT], + Generic[ConfigT], +): + """To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables: - :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component. - :py:attr:`component_type` - What is the logical type of the component. @@ -243,55 +307,39 @@ def _from_config(cls, config: Config) -> MyComponent: return cls(value=config.value) """ - required_class_vars: ClassVar[List[str]] = ["component_config_schema", "component_type"] - - def __init_subclass__(cls, **kwargs: Any) -> None: + def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) - # TODO: validate provider is loadable - for var in cls.required_class_vars: - if not hasattr(cls, var): - warnings.warn( - f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2 - ) - - def dump_component(self) -> ComponentModel: - """Dump the component to a model that can be loaded back in. - - Raises: - TypeError: If the component is a local class. - - Returns: - ComponentModel: The model representing the component. - """ - if self.component_provider_override is not None: - provider = self.component_provider_override - else: - provider = _type_to_provider_str(self.__class__) - # Warn if internal module name is used, - if "._" in provider: - warnings.warn( - "Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.", - stacklevel=2, - ) - - if "" in provider: - raise TypeError("Cannot dump component with local class") - - if not hasattr(self, "component_type"): - raise AttributeError("component_type not defined") - - obj_config = self._to_config().model_dump(exclude_none=True) - model = ComponentModel( - provider=provider, - component_type=self.component_type, - version=self.component_version, - component_version=self.component_version, - description=None, - config=obj_config, - ) - return model - - @classmethod - def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self: - raise NotImplementedError() + if not is_component_class(cls): + warnings.warn( + f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.", + stacklevel=2, + ) + + +# Should never be used directly, only for type checking +class _ConcreteComponent( + ComponentFromConfig[ConfigT], + ComponentSchemaType[ConfigT], + ComponentToConfig[ConfigT], + ComponentLoader, + Generic[ConfigT], +): ... + + +def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]: + return ( + isinstance(cls, ComponentFromConfig) + and isinstance(cls, ComponentToConfig) + and isinstance(cls, ComponentSchemaType) + and isinstance(cls, ComponentLoader) + ) + + +def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]: + return ( + issubclass(cls, ComponentFromConfig) + and issubclass(cls, ComponentToConfig) + and issubclass(cls, ComponentSchemaType) + and issubclass(cls, ComponentLoader) + ) diff --git a/python/packages/autogen-core/src/autogen_core/models/_model_client.py b/python/packages/autogen-core/src/autogen_core/models/_model_client.py index a952ad43458c..356fad5487ca 100644 --- a/python/packages/autogen-core/src/autogen_core/models/_model_client.py +++ b/python/packages/autogen-core/src/autogen_core/models/_model_client.py @@ -4,10 +4,11 @@ from abc import ABC, abstractmethod from typing import Literal, Mapping, Optional, Sequence, TypeAlias +from pydantic import BaseModel from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated from .. import CancellationToken -from .._component_config import ComponentLoader +from .._component_config import ComponentBase from ..tools import Tool, ToolSchema from ._types import CreateResult, LLMMessage, RequestUsage @@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False): """Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family.""" -class ChatCompletionClient(ABC, ComponentLoader): +class ChatCompletionClient(ComponentBase[BaseModel], ABC): # Caching has to be handled internally as they can depend on the create args that were stored in the constructor @abstractmethod async def create( diff --git a/python/packages/autogen-core/tests/test_component_config.py b/python/packages/autogen-core/tests/test_component_config.py index fe726227acc1..d59fde59c1b6 100644 --- a/python/packages/autogen-core/tests/test_component_config.py +++ b/python/packages/autogen-core/tests/test_component_config.py @@ -4,7 +4,7 @@ from typing import Any, Dict import pytest -from autogen_core import Component, ComponentLoader, ComponentModel +from autogen_core import Component, ComponentBase, ComponentLoader, ComponentModel from autogen_core._component_config import _type_to_provider_str # type: ignore from autogen_core.models import ChatCompletionClient from autogen_test_utils import MyInnerComponent, MyOuterComponent @@ -16,7 +16,7 @@ class MyConfig(BaseModel): info: str -class MyComponent(Component[MyConfig]): +class MyComponent(ComponentBase[MyConfig], Component[MyConfig]): component_config_schema = MyConfig component_type = "custom" @@ -95,7 +95,7 @@ def test_cannot_import_locals() -> None: class InvalidModelClientConfig(BaseModel): info: str - class MyInvalidModelClient(Component[InvalidModelClientConfig]): + class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): component_config_schema = InvalidModelClientConfig component_type = "model" @@ -119,7 +119,7 @@ class InvalidModelClientConfig(BaseModel): info: str -class MyInvalidModelClient(Component[InvalidModelClientConfig]): +class MyInvalidModelClient(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): component_config_schema = InvalidModelClientConfig component_type = "model" @@ -143,7 +143,7 @@ def test_type_error_on_creation() -> None: with pytest.warns(UserWarning): - class MyInvalidMissingAttrs(Component[InvalidModelClientConfig]): + class MyInvalidMissingAttrs(ComponentBase[InvalidModelClientConfig], Component[InvalidModelClientConfig]): def __init__(self, info: str): self.info = info @@ -189,7 +189,7 @@ def test_config_optional_values() -> None: assert component.__class__ == MyComponent -class ConfigProviderOverrided(Component[MyConfig]): +class ConfigProviderOverrided(ComponentBase[MyConfig], Component[MyConfig]): component_provider_override = "InvalidButStillOverridden" component_config_schema = MyConfig component_type = "custom" @@ -215,7 +215,7 @@ class MyConfig2(BaseModel): info2: str -class ComponentNonOneVersion(Component[MyConfig2]): +class ComponentNonOneVersion(ComponentBase[MyConfig2], Component[MyConfig2]): component_config_schema = MyConfig2 component_version = 2 component_type = "custom" @@ -231,7 +231,7 @@ def _from_config(cls, config: MyConfig2) -> Self: return cls(info=config.info2) -class ComponentNonOneVersionWithUpgrade(Component[MyConfig2]): +class ComponentNonOneVersionWithUpgrade(ComponentBase[MyConfig2], Component[MyConfig2]): component_config_schema = MyConfig2 component_version = 2 component_type = "custom" diff --git a/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py index 96732a4194d8..08de1e723cd5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/auth/azure/__init__.py @@ -1,6 +1,6 @@ from typing import List -from autogen_core import Component +from autogen_core import Component, ComponentBase from pydantic import BaseModel from typing_extensions import Self @@ -13,7 +13,7 @@ class TokenProviderConfig(BaseModel): scopes: List[str] -class AzureTokenProvider(Component[TokenProviderConfig]): +class AzureTokenProvider(ComponentBase[TokenProviderConfig], Component[TokenProviderConfig]): component_type = "token_provider" component_config_schema = TokenProviderConfig component_provider_override = "autogen_ext.auth.azure.AzureTokenProvider" diff --git a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py index e3539ac9ede9..b917194b1d82 100644 --- a/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py +++ b/python/packages/autogen-test-utils/src/autogen_test_utils/__init__.py @@ -6,13 +6,14 @@ from autogen_core import ( BaseAgent, Component, + ComponentBase, + ComponentModel, DefaultTopicId, MessageContext, RoutedAgent, default_subscription, message_handler, ) -from autogen_core._component_config import ComponentModel from pydantic import BaseModel @@ -76,7 +77,7 @@ class MyInnerConfig(BaseModel): inner_message: str -class MyInnerComponent(Component[MyInnerConfig]): +class MyInnerComponent(ComponentBase[MyInnerConfig], Component[MyInnerConfig]): component_config_schema = MyInnerConfig component_type = "custom" @@ -96,7 +97,7 @@ class MyOuterConfig(BaseModel): inner_class: ComponentModel -class MyOuterComponent(Component[MyOuterConfig]): +class MyOuterComponent(ComponentBase[MyOuterConfig], Component[MyOuterConfig]): component_config_schema = MyOuterConfig component_type = "custom" diff --git a/python/packages/component-schema-gen/src/component_schema_gen/__main__.py b/python/packages/component-schema-gen/src/component_schema_gen/__main__.py index bf0a21f1f141..810d5ec84455 100644 --- a/python/packages/component-schema-gen/src/component_schema_gen/__main__.py +++ b/python/packages/component-schema-gen/src/component_schema_gen/__main__.py @@ -5,7 +5,8 @@ from autogen_core import ComponentModel from autogen_core._component_config import ( WELL_KNOWN_PROVIDERS, - ComponentConfigImpl, + ComponentSchemaType, + ComponentToConfig, _type_to_provider_str, # type: ignore ) from autogen_ext.auth.azure import AzureTokenProvider @@ -17,10 +18,13 @@ T = TypeVar("T", bound=BaseModel) -def build_specific_component_schema(component: type[ComponentConfigImpl[T]], provider_str: str) -> Dict[str, Any]: +def build_specific_component_schema(component: type[ComponentSchemaType[T]], provider_str: str) -> Dict[str, Any]: model = component.component_config_schema # type: ignore model_schema = model.model_json_schema() + # We can't specify component to be the union of two types, so we assert it here + assert issubclass(component, ComponentToConfig) + component_model_schema = ComponentModel.model_json_schema() if "$defs" not in component_model_schema: component_model_schema["$defs"] = {} @@ -70,7 +74,9 @@ def main() -> None: for key, value in WELL_KNOWN_PROVIDERS.items(): reverse_provider_lookup_table[value].append(key) - def add_type(type: type[ComponentConfigImpl[T]]) -> None: + def add_type(type: type[ComponentSchemaType[T]]) -> None: + # We can't specify component to be the union of two types, so we assert it here + assert issubclass(type, ComponentToConfig) canonical = type.component_provider_override or _type_to_provider_str(type) reverse_provider_lookup_table[canonical].append(canonical) for provider_str in reverse_provider_lookup_table[canonical]: