Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split apart component infra to allow for abstract class integration #5017

Merged
merged 10 commits into from
Jan 13, 2025
10 changes: 8 additions & 2 deletions python/packages/autogen-core/src/autogen_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
from ._closure_agent import ClosureAgent, ClosureContext
from ._component_config import (
Component,
ComponentConfigImpl,
ComponentFromConfig,
ComponentLoader,
ComponentModel,
ComponentSchemaType,
ComponentToConfig,
ComponentType,
is_component_class,
)
from ._constants import (
EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS,
Expand Down Expand Up @@ -112,10 +115,13 @@
"EVENT_LOGGER_NAME",
"TRACE_LOGGER_NAME",
"Component",
"ComponentFromConfig",
"ComponentLoader",
"ComponentConfigImpl",
"ComponentModel",
"ComponentSchemaType",
"ComponentToConfig",
"ComponentType",
"is_component_class",
"DropMessage",
"InterventionHandler",
"DefaultInterventionHandler",
Expand Down
199 changes: 113 additions & 86 deletions python/packages/autogen-core/src/autogen_core/_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import importlib
import warnings
from typing import Any, ClassVar, Dict, Generic, List, Literal, Protocol, Type, cast, overload, runtime_checkable
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Dict, Generic, Literal, Type, cast, overload

from pydantic import BaseModel
from typing_extensions import Self, TypeVar

ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
ConfigT = TypeVar("ConfigT", bound=BaseModel)
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)

T = TypeVar("T", bound=BaseModel, covariant=True)

Expand Down Expand Up @@ -47,36 +50,10 @@
}


@runtime_checkable
class ComponentConfigImpl(Protocol[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""
component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""

"""The two methods a class must implement to be a component.

Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""

def _to_config(self) -> ConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.

Returns:
T: The configuration of the component.

:meta public:
"""
...

class ComponentFromConfig(ABC, Generic[FromConfigT]):
@classmethod
def _from_config(cls, config: ConfigT) -> Self:
@abstractmethod
def _from_config(cls, config: FromConfigT) -> Self:
"""Create a new instance of the component from a configuration object.

Args:
Expand Down Expand Up @@ -104,7 +81,70 @@

:meta public:
"""
raise NotImplementedError()
raise NotImplementedError("This component does not support loading from past versions")


class ComponentToConfig(ABC, Generic[ToConfigT]):
"""The two methods a class must implement to be a component.

Args:
Protocol (ConfigT): Type which derives from :py:class:`pydantic.BaseModel`.
"""

component_type: ClassVar[ComponentType]
"""The logical type of the component."""
component_version: ClassVar[int] = 1
"""The version of the component, if schema incompatibilities are introduced this should be updated."""
component_provider_override: ClassVar[str | None] = None
"""Override the provider string for the component. This should be used to prevent internal module names being a part of the module name."""

@abstractmethod
def _to_config(self) -> ToConfigT:
"""Dump the configuration that would be requite to create a new instance of a component matching the configuration of this instance.

Returns:
T: The configuration of the component.

:meta public:
"""
...

Check warning on line 110 in python/packages/autogen-core/src/autogen_core/_component_config.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_component_config.py#L110

Added line #L110 was not covered by tests

def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.

Raises:
TypeError: If the component is a local class.

Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(

Check warning on line 127 in python/packages/autogen-core/src/autogen_core/_component_config.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_component_config.py#L127

Added line #L127 was not covered by tests
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)

if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")

if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")

obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model


ExpectedType = TypeVar("ExpectedType")
Expand Down Expand Up @@ -171,9 +211,9 @@

module_path, class_name = output
module = importlib.import_module(module_path)
component_class = cast(ComponentConfigImpl[BaseModel], module.__getattribute__(class_name))
component_class = cast(Component[BaseModel], module.__getattribute__(class_name))

if not isinstance(component_class, ComponentConfigImpl):
if not is_component_class(component_class):
raise TypeError("Invalid component class")

# We need to check the schema is valid
Expand Down Expand Up @@ -208,7 +248,31 @@
return cast(ExpectedType, instance)


class Component(ComponentConfigImpl[ConfigT], ComponentLoader, Generic[ConfigT]):
class ComponentSchemaType(Generic[ConfigT]):
# Ideally would be ClassVar[Type[ConfigT]], but this is disallowed https://github.com/python/typing/discussions/1424 (despite being valid in this context)
component_config_schema: Type[ConfigT]
"""The Pydantic model class which represents the configuration of the component."""

required_class_vars = ["component_config_schema", "component_type"]

def __init_subclass__(cls, **kwargs: Any):
super().__init_subclass__(**kwargs)

# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
)


class Component(
ComponentFromConfig[ConfigT],
ComponentToConfig[ConfigT],
ComponentSchemaType[ConfigT],
ComponentLoader,
Generic[ConfigT],
):
"""To create a component class, inherit from this class. Then implement two class variables:

- :py:attr:`component_config_schema` - A Pydantic model class which represents the configuration of the component. This is also the type parameter of Component.
Expand Down Expand Up @@ -243,55 +307,18 @@
return cls(value=config.value)
"""

required_class_vars: ClassVar[List[str]] = ["component_config_schema", "component_type"]

def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)

# TODO: validate provider is loadable
for var in cls.required_class_vars:
if not hasattr(cls, var):
warnings.warn(
f"Class variable '{var}' must be defined in {cls.__name__} to be a valid component", stacklevel=2
)

def dump_component(self) -> ComponentModel:
"""Dump the component to a model that can be loaded back in.

Raises:
TypeError: If the component is a local class.

Returns:
ComponentModel: The model representing the component.
"""
if self.component_provider_override is not None:
provider = self.component_provider_override
else:
provider = _type_to_provider_str(self.__class__)
# Warn if internal module name is used,
if "._" in provider:
warnings.warn(
"Internal module name used in provider string. This is not recommended and may cause issues in the future. Silence this warning by setting component_provider_override to this value.",
stacklevel=2,
)

if "<locals>" in provider:
raise TypeError("Cannot dump component with local class")

if not hasattr(self, "component_type"):
raise AttributeError("component_type not defined")

obj_config = self._to_config().model_dump(exclude_none=True)
model = ComponentModel(
provider=provider,
component_type=self.component_type,
version=self.component_version,
component_version=self.component_version,
description=None,
config=obj_config,
)
return model

@classmethod
def _from_config_past_version(cls, config: Dict[str, Any], version: int) -> Self:
raise NotImplementedError()
...


def is_component_class(cls: type | Any) -> bool:
return (
issubclass(cls, ComponentFromConfig)
and issubclass(cls, ComponentToConfig)
and issubclass(cls, ComponentSchemaType)
and issubclass(cls, ComponentLoader)
) or (
isinstance(cls, ComponentFromConfig)
and isinstance(cls, ComponentToConfig)
and isinstance(cls, ComponentSchemaType)
and isinstance(cls, ComponentLoader)
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from autogen_core import ComponentModel
from autogen_core._component_config import (
WELL_KNOWN_PROVIDERS,
ComponentConfigImpl,
ComponentSchemaType,
ComponentToConfig,
_type_to_provider_str, # type: ignore
)
from autogen_ext.auth.azure import AzureTokenProvider
Expand All @@ -17,10 +18,13 @@
T = TypeVar("T", bound=BaseModel)


def build_specific_component_schema(component: type[ComponentConfigImpl[T]], provider_str: str) -> Dict[str, Any]:
def build_specific_component_schema(component: type[ComponentSchemaType[T]], provider_str: str) -> Dict[str, Any]:
model = component.component_config_schema # type: ignore
model_schema = model.model_json_schema()

# We can't specify component to be the union of two types, so we assert it here
assert issubclass(component, ComponentToConfig)

component_model_schema = ComponentModel.model_json_schema()
if "$defs" not in component_model_schema:
component_model_schema["$defs"] = {}
Expand Down Expand Up @@ -70,7 +74,9 @@ def main() -> None:
for key, value in WELL_KNOWN_PROVIDERS.items():
reverse_provider_lookup_table[value].append(key)

def add_type(type: type[ComponentConfigImpl[T]]) -> None:
def add_type(type: type[ComponentSchemaType[T]]) -> None:
# We can't specify component to be the union of two types, so we assert it here
assert issubclass(type, ComponentToConfig)
canonical = type.component_provider_override or _type_to_provider_str(type)
reverse_provider_lookup_table[canonical].append(canonical)
for provider_str in reverse_provider_lookup_table[canonical]:
Expand Down
Loading