diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index 64c8d620153..df79231eaa7 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -21,6 +21,7 @@ ComponentToConfig, ComponentType, is_component_class, + is_component_instance, ) from ._constants import ( EVENT_LOGGER_NAME as EVENT_LOGGER_NAME_ALIAS, @@ -122,6 +123,7 @@ "ComponentToConfig", "ComponentType", "is_component_class", + "is_component_instance", "DropMessage", "InterventionHandler", "DefaultInterventionHandler", diff --git a/python/packages/autogen-core/src/autogen_core/_component_config.py b/python/packages/autogen-core/src/autogen_core/_component_config.py index 5697c922db0..629065b4fe2 100644 --- a/python/packages/autogen-core/src/autogen_core/_component_config.py +++ b/python/packages/autogen-core/src/autogen_core/_component_config.py @@ -3,7 +3,7 @@ import importlib import warnings from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Generic, Literal, Type, cast, overload +from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, overload from pydantic import BaseModel from typing_extensions import Self, TypeVar @@ -211,7 +211,7 @@ def load_component( module_path, class_name = output module = importlib.import_module(module_path) - component_class = cast(Component[BaseModel], module.__getattribute__(class_name)) + component_class = module.__getattribute__(class_name) if not is_component_class(component_class): raise TypeError("Invalid component class") @@ -232,7 +232,7 @@ def load_component( f"Tried to load component {component_class} which is on version {component_class.component_version} with a config on version {loaded_config_version} but _from_config_past_version is not implemented" ) from e else: - schema = component_class.component_config_schema + schema = component_class.component_config_schema # type: ignore validated_config = schema.model_validate(loaded_model.config) # We're allowed to use the private method here @@ -310,15 +310,19 @@ def _from_config(cls, config: Config) -> MyComponent: ... -def is_component_class(cls: type | Any) -> bool: +def is_component_instance(cls: Any) -> TypeGuard[Component[BaseModel]]: return ( - issubclass(cls, ComponentFromConfig) - and issubclass(cls, ComponentToConfig) - and issubclass(cls, ComponentSchemaType) - and issubclass(cls, ComponentLoader) - ) or ( isinstance(cls, ComponentFromConfig) and isinstance(cls, ComponentToConfig) and isinstance(cls, ComponentSchemaType) and isinstance(cls, ComponentLoader) ) + + +def is_component_class(cls: type) -> TypeGuard[Type[Component[BaseModel]]]: + return ( + issubclass(cls, ComponentFromConfig) + and issubclass(cls, ComponentToConfig) + and issubclass(cls, ComponentSchemaType) + and issubclass(cls, ComponentLoader) + )