Skip to content

Commit

Permalink
Improve handling for kernel plugin from file.
Browse files Browse the repository at this point in the history
  • Loading branch information
moonbox3 committed Jan 24, 2025
1 parent 4d29656 commit cec8e1a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 2 additions & 0 deletions python/semantic_kernel/functions/kernel_function_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def add_plugin(
return self.plugins[plugin.name]
if not plugin_name:
raise ValueError("plugin_name must be provided if a plugin is not supplied.")
if not isinstance(plugin_name, str):
raise TypeError("plugin_name must be a string.")
if plugin:
self.plugins[plugin_name] = KernelPlugin.from_object(
plugin_name=plugin_name, plugin_instance=plugin, description=description
Expand Down
11 changes: 10 additions & 1 deletion python/semantic_kernel/functions/kernel_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,16 @@ def from_python_file(
for name, cls_instance in inspect.getmembers(module, inspect.isclass):
if cls_instance.__module__ != module_name:
continue
instance = getattr(module, name)(**class_init_arguments.get(name, {}) if class_init_arguments else {})
# Check whether this class has at least one @kernel_function decorated method
has_kernel_function = False
for _, method in inspect.getmembers(cls_instance, inspect.isfunction):
if getattr(method, "__kernel_function__", False):
has_kernel_function = True
break
if not has_kernel_function:
continue
init_args = class_init_arguments.get(name, {}) if class_init_arguments else {}
instance = getattr(module, name)(**init_args)
return cls.from_object(plugin_name=plugin_name, description=description, plugin_instance=instance)
raise PluginInitializationError(f"No class found in file: {py_file}")

Expand Down
6 changes: 6 additions & 0 deletions python/tests/unit/kernel/test_kernel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import os
from pathlib import Path
from typing import Union
from unittest.mock import AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -483,6 +484,11 @@ def test_plugin_name_error(kernel: Kernel):
kernel.add_plugin(" ", None)


def test_plugin_name_not_string_error(kernel: Kernel):
with pytest.raises(TypeError):
kernel.add_plugin(" ", plugin_name=Path(__file__).parent)


def test_plugins_add_plugins(kernel: Kernel):
plugin1 = KernelPlugin(name="TestPlugin")
plugin2 = KernelPlugin(name="TestPlugin2")
Expand Down

0 comments on commit cec8e1a

Please sign in to comment.