Skip to content

Commit

Permalink
Split apart component infra to allow for abstract class integration (#…
Browse files Browse the repository at this point in the history
…5017)

* Split apart component infra to allow for abstract class integration

* fix is_component_class check

* make is_ functions type guards

* Simplify component creation

* undo changes

* Format
  • Loading branch information
jackgerrits authored Jan 13, 2025
1 parent 70f7e99 commit 404522b
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"\n",
"## Usage\n",
"\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"If you have a component in Python and want to get the config for it, simply call {py:meth}`~autogen_core.ComponentToConfig.dump_component` on it. The resulting object can be passed back into {py:meth}`~autogen_core.ComponentLoader.load_component` to get the component back.\n",
"\n",
"### Loading a component from a config\n",
"\n",
Expand Down Expand Up @@ -52,7 +52,7 @@
"To add component functionality to a given class:\n",
"\n",
"1. Add a call to {py:meth}`~autogen_core.Component` in the class inheritance list.\n",
"2. Implment the {py:meth}`~autogen_core.ComponentConfigImpl._to_config` and {py:meth}`~autogen_core.ComponentConfigImpl._from_config` methods\n",
"2. Implment the {py:meth}`~autogen_core.ComponentToConfig._to_config` and {py:meth}`~autogen_core.ComponentFromConfig._from_config` methods\n",
"\n",
"For example:"
]
Expand All @@ -63,15 +63,15 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core import Component\n",
"from autogen_core import Component, ComponentBase\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
"class Config(BaseModel):\n",
" value: str\n",
"\n",
"\n",
"class MyComponent(Component[Config]):\n",
"class MyComponent(ComponentBase[Config], Component[Config]):\n",
" component_type = \"custom\"\n",
" component_config_schema = Config\n",
"\n",
Expand Down Expand Up @@ -129,7 +129,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
14 changes: 12 additions & 2 deletions python/packages/autogen-core/src/autogen_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
Component,
ComponentConfigImpl,
ComponentBase,
ComponentFromConfig,
ComponentLoader,
ComponentModel,
ComponentSchemaType,
ComponentToConfig,
ComponentType,
is_component_class,
is_component_instance,
)
from ._constants import (
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
Expand Down Expand Up @@ -112,10 +117,15 @@
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
"Component",
"ComponentBase",
"ComponentFromConfig",
"ComponentLoader",
"ComponentConfigImpl",
"ComponentModel",
"ComponentSchemaType",
"ComponentToConfig",
"ComponentType",
"is_component_class",
"is_component_instance",
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",
Expand Down
222 changes: 135 additions & 87 deletions python/packages/autogen-core/src/autogen_core/_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import importlib
import warnings
from typing import Any, ClassVar, Dict, Generic, List, Literal, Protocol, Type, cast, overload, runtime_checkable
from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload

from pydantic import BaseModel
from typing_extensions import Self, TypeVar

ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
ConfigT = TypeVar("ConfigT", bound=BaseModel)
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)

T = TypeVar("T", bound=BaseModel, covariant=True)

Expand Down Expand Up @@ -47,36 +49,9 @@ def _type_to_provider_str(t: type) -> str:
}


@runtime_checkable
class ComponentConfigImpl(Protocol[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""
component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""

"""The two methods a class must implement to be a component.
Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""

def _to_config(self) -> ConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Returns:
T: The configuration of the component.
:meta public:
"""
...

class ComponentFromConfig(Generic[FromConfigT]):
@classmethod
def _from_config(cls, config: ConfigT) -> Self:
def _from_config(cls, config: FromConfigT) -> Self:
"""Create a new instance of the component from a configuration object.
Args:
Expand All @@ -87,7 +62,7 @@ def _from_config(cls, config: ConfigT) -> Self:
:meta public:
"""
...
raise NotImplementedError("This component does not support dumping to config")

@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
Expand All @@ -104,7 +79,69 @@ def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self
:meta public:
"""
raise NotImplementedError()
raise NotImplementedError("This component does not support loading from past versions")


class ComponentToConfig(Generic[ToConfigT]):
"""The two methods a class must implement to be a component.
Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""

component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""

def _to_config(self) -> ToConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.
Returns:
T: The configuration of the component.
:meta public:
"""
raise NotImplementedError("This component does not support dumping to config")

def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.
Raises:
TypeError: If the component is a local class.
Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)

if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")

if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")

obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model


ExpectedType = TypeVar("ExpectedType")
Expand Down Expand Up @@ -171,9 +208,9 @@ def load_component(

module_path, class_name = output
module = importlib.import_module(module_path)
component_class = cast(ComponentConfigImpl[BaseModel], module.__getattribute__(class_name))
component_class = module.__getattribute__(class_name)

if not isinstance(component_class, ComponentConfigImpl):
if not is_component_class(component_class):
raise TypeError("Invalid component class")

# We need to check the schema is valid
Expand All @@ -192,7 +229,7 @@ def load_component(
f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented"
) from e
else:
schema = component_class.component_config_schema
schema = component_class.component_config_schema # type: ignore
validated_config = schema.model_validate(loaded_model.config)

# We're allowed to use the private method here
Expand All @@ -208,8 +245,35 @@ def load_component(
return cast(ExpectedType, instance)


class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT]):
"""To create a component class, inherit from this class. Then implement two class variables:
class ComponentSchemaType(Generic[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""

required_class_vars = ["component_config_schema", "component_type"]

def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)

if cls.__name__ != "Component" and not cls.__name__ == "_ConcreteComponent":
# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component",
stacklevel=2,
)


class ComponentBase(ComponentToConfig[ConfigT], ComponentLoader, Generic[ConfigT]): ...


class Component(
ComponentFromConfig[ConfigT],
ComponentSchemaType[ConfigT],
Generic[ConfigT],
):
"""To create a component class, inherit from this class for the concrete class and ComponentBase on the interface. Then implement two class variables:
- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
- :py:attr:`component_type` - What is the logical type of the component.
Expand Down Expand Up @@ -243,55 +307,39 @@ def _from_config(cls, config: Config) -> MyComponent:
return cls(value=config.value)
"""

required_class_vars: ClassVar[List[str]] = ["component_config_schema", "component_type"]

def __init_subclass__(cls, **kwargs: Any) -> None:
def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)

# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
)

def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.
Raises:
TypeError: If the component is a local class.
Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)

if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")

if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")

obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model

@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
raise NotImplementedError()
if not is_component_class(cls):
warnings.warn(
f"Component class '{cls.__name__}' must subclass the following: ComponentFromConfig, ComponentToConfig, ComponentSchemaType, ComponentLoader, individually or with ComponentBase and Component. Look at the component config documentation or how OpenAIChatCompletionClient does it.",
stacklevel=2,
)


# Should never be used directly, only for type checking
class _ConcreteComponent(
ComponentFromConfig[ConfigT],
ComponentSchemaType[ConfigT],
ComponentToConfig[ConfigT],
ComponentLoader,
Generic[ConfigT],
): ...


def is_component_instance(cls: Any) -> TypeGuard[_ConcreteComponent[BaseModel]]:
return (
isinstance(cls, ComponentFromConfig)
and isinstance(cls, ComponentToConfig)
and isinstance(cls, ComponentSchemaType)
and isinstance(cls, ComponentLoader)
)


def is_component_class(cls: type) -> TypeGuard[Type[_ConcreteComponent[BaseModel]]]:
return (
issubclass(cls, ComponentFromConfig)
and issubclass(cls, ComponentToConfig)
and issubclass(cls, ComponentSchemaType)
and issubclass(cls, ComponentLoader)
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from abc import ABC, abstractmethod
from typing import Literal, Mapping, Optional, Sequence, TypeAlias

from pydantic import BaseModel
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated

from .. import CancellationToken
from .._component_config import ComponentLoader
from .._component_config import ComponentBase
from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage

Expand Down Expand Up @@ -47,7 +48,7 @@ class ModelInfo(TypedDict, total=False):
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""


class ChatCompletionClient(ABC, ComponentLoader):
class ChatCompletionClient(ComponentBase[BaseModel], ABC):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
@abstractmethod
async def create(
Expand Down
Loading

0 comments on commit 404522b

Please sign in to comment.