Commit f77cd39

initial sk model adapter implementation
1 parent 501d8bb commit f77cd39

6 files changed (+672 additions, −374 deletions)


python/packages/autogen-ext/pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ video-surfer = [
 grpc = [
     "grpcio~=1.62.0", # TODO: update this once we have a stable version.
 ]
+semantic-kernel = [
+    "semantic-kernel>=1.17.1",
+]
 
 [tool.hatch.build.targets.wheel]
 packages = ["src/autogen_ext"]
python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from ._sk_chat_completion_adapter import SKChatCompletionAdapter

__all__ = [
    "SKChatCompletionAdapter"
]
python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_kernel_function_from_tool.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
from semantic_kernel.functions.kernel_function import KernelFunction
from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata
from semantic_kernel.functions.kernel_parameter_metadata import KernelParameterMetadata
from semantic_kernel.functions.function_result import FunctionResult
from semantic_kernel.filters.functions.function_invocation_context import FunctionInvocationContext
from semantic_kernel.exceptions import FunctionExecutionException
from autogen_core.tools import BaseTool

class KernelFunctionFromTool(KernelFunction):
    def __init__(self, tool: BaseTool, plugin_name: str | None = None):
        # Build up KernelFunctionMetadata. You can also parse the tool's schema for parameters.
        parameters = [
            KernelParameterMetadata(
                name="args",
                description="JSON arguments for the tool",
                default_value=None,
                type_="dict",
                type_object=dict,
                is_required=True,
            )
        ]
        return_param = KernelParameterMetadata(
            name="return",
            description="Result from the tool",
            default_value=None,
            type_="str",
            type_object=str,
            is_required=False,
        )

        metadata = KernelFunctionMetadata(
            name=tool.name,
            description=tool.description,
            parameters=parameters,
            return_parameter=return_param,
            is_prompt=False,
            is_asynchronous=True,
            plugin_name=plugin_name or "",
        )
        super().__init__(metadata=metadata)
        self._tool = tool

    async def _invoke_internal(self, context: FunctionInvocationContext) -> None:
        # Extract the "args" parameter from the context
        if "args" not in context.arguments:
            raise FunctionExecutionException("Missing 'args' in FunctionInvocationContext.arguments")
        tool_args = context.arguments["args"]

        # Call the tool's run_json
        result = await self._tool.run_json(tool_args, cancellation_token=None)

        # Wrap in a FunctionResult
        context.result = FunctionResult(
            function=self.metadata,
            value=result,
            metadata={"used_arguments": tool_args},
        )

    async def _invoke_internal_stream(self, context: FunctionInvocationContext) -> None:
        # If the tool has no streaming mechanism, simply reuse _invoke_internal
        # (alternatively, raise NotImplementedError).
        await self._invoke_internal(context)
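A minimal usage sketch, not part of the commit: the echo tool below is hypothetical, and FunctionTool is autogen_core's stock wrapper for plain functions, but the plugin registration mirrors what the adapter's _sync_tools_with_kernel does further down.

import asyncio

from autogen_core.tools import FunctionTool
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel

# assumes KernelFunctionFromTool from the file above is importable

async def echo(text: str) -> str:
    return text

async def main() -> None:
    tool = FunctionTool(echo, description="Echo the input text")
    kernel = Kernel()
    plugin = KernelPlugin(name="autogen_tools")
    # register the wrapped tool the same way the adapter does
    plugin.functions[tool.name] = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
    kernel.add_plugin(plugin)

    # the wrapper exposes a single "args" parameter carrying the tool's JSON arguments
    result = await kernel.invoke(plugin_name="autogen_tools", function_name="echo", args={"text": "hi"})
    print(result)

asyncio.run(main())

Exposing one dict-typed "args" parameter keeps the wrapper generic: SK hands the raw JSON arguments straight through to the autogen tool's run_json, so no per-tool parameter schema has to be synthesized.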
python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
from typing import Any, Mapping, Optional, Sequence
from autogen_core._cancellation_token import CancellationToken
from autogen_core.models import RequestUsage, FunctionExecutionResultMessage, ModelCapabilities, AssistantMessage, SystemMessage, UserMessage, FunctionExecutionResult
from autogen_core.models import ChatCompletionClient, CreateResult, LLMMessage
from autogen_core.tools import Tool, ToolSchema
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.kernel import Kernel
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from typing_extensions import AsyncGenerator, Union
from ._kernel_function_from_tool import KernelFunctionFromTool


class SKChatCompletionAdapter(ChatCompletionClient):
    def __init__(self, sk_client: ChatCompletionClientBase):
        self._sk_client = sk_client
        self._total_prompt_tokens = 0
        self._total_completion_tokens = 0
        self._tools_plugin: Optional[KernelPlugin] = None

    def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory:
        """Convert Autogen LLMMessages to SK ChatHistory"""
        chat_history = ChatHistory()

        for msg in messages:
            if msg.type == "SystemMessage":
                chat_history.add_system_message(msg.content)

            elif msg.type == "UserMessage":
                if isinstance(msg.content, str):
                    chat_history.add_user_message(msg.content)
                else:
                    # Handle list of str/Image - would need to convert to SK content types
                    chat_history.add_user_message(str(msg.content))

            elif msg.type == "AssistantMessage":
                if isinstance(msg.content, str):
                    chat_history.add_assistant_message(msg.content)
                else:
                    # Handle function calls - would need to convert to SK function call format
                    chat_history.add_assistant_message(str(msg.content))

            elif msg.type == "FunctionExecutionResultMessage":
                for result in msg.content:
                    chat_history.add_tool_message(result.content)

        return chat_history

    def _convert_from_chat_message(self, message: ChatMessageContent, source: str = "assistant") -> LLMMessage:
        """Convert SK ChatMessageContent to Autogen LLMMessage"""
        if message.role == AuthorRole.SYSTEM:
            return SystemMessage(content=message.content)

        elif message.role == AuthorRole.USER:
            return UserMessage(content=message.content, source=source)

        elif message.role == AuthorRole.ASSISTANT:
            return AssistantMessage(content=message.content, source=source)

        elif message.role == AuthorRole.TOOL:
            return FunctionExecutionResultMessage(
                content=[FunctionExecutionResult(content=message.content, call_id="")]
            )

        raise ValueError(f"Unknown role: {message.role}")

    def _build_execution_settings(self, extra_create_args: Mapping[str, Any], tools: Sequence[Tool | ToolSchema]) -> PromptExecutionSettings:
        """Build PromptExecutionSettings from extra_create_args"""
        # Extract service_id if provided, otherwise use None
        service_id = extra_create_args.get("service_id")

        # If tools are available, configure function choice behavior with auto_invoke disabled
        function_choice_behavior = None
        if tools:
            function_choice_behavior = FunctionChoiceBehavior.NoneInvoke()

        # Create settings with remaining args as extension_data
        settings = PromptExecutionSettings(
            service_id=service_id,
            extension_data=dict(extra_create_args),
            function_choice_behavior=function_choice_behavior
        )

        return settings

    def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSchema]) -> None:
        """Sync tools with kernel by updating the plugin"""
        # Create new plugin if none exists
        if not self._tools_plugin:
            self._tools_plugin = KernelPlugin(name="autogen_tools")
            kernel.add_plugin(self._tools_plugin)

        # Get current tool names in plugin
        current_tool_names = set(self._tools_plugin.functions.keys())

        # Get new tool names (ToolSchema is a TypedDict, so its name is a key, not an attribute)
        new_tool_names = {tool.schema["name"] if isinstance(tool, Tool) else tool["name"] for tool in tools}

        # Remove tools that are no longer needed
        for tool_name in current_tool_names - new_tool_names:
            del self._tools_plugin.functions[tool_name]

        # Add or update tools
        for tool in tools:
            if isinstance(tool, Tool):
                # Convert Tool to KernelFunction using KernelFunctionFromTool
                kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
                self._tools_plugin.functions[tool.name] = kernel_function

    async def create(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        if "kernel" not in extra_create_args:
            raise ValueError("kernel is required in extra_create_args")

        kernel = extra_create_args["kernel"]
        if not isinstance(kernel, Kernel):
            raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

        chat_history = self._convert_to_chat_history(messages)

        # Build execution settings from extra args and tools
        settings = self._build_execution_settings(extra_create_args, tools)

        # Sync tools with kernel
        self._sync_tools_with_kernel(kernel, tools)

        result = await self._sk_client.get_chat_message_contents(
            chat_history,
            settings=settings,
            kernel=kernel
        )
        # Track token usage from result metadata
        prompt_tokens = 0
        completion_tokens = 0

        if result[0].metadata and 'usage' in result[0].metadata:
            usage = result[0].metadata['usage']
            prompt_tokens = getattr(usage, 'prompt_tokens', 0)
            completion_tokens = getattr(usage, 'completion_tokens', 0)

        self._total_prompt_tokens += prompt_tokens
        self._total_completion_tokens += completion_tokens

        return CreateResult(
            content=result[0].content,
            finish_reason="stop",
            usage=RequestUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens
            ),
            cached=False
        )

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
        # Very similar to create(), but orchestrates streaming.
        # 1. Convert messages -> ChatHistory
        # 2. Possibly set function-calling if needed
        # 3. Build generator that yields str segments or a final CreateResult
        #    from SK's get_streaming_chat_message_contents(...)
        raise NotImplementedError("create_stream is not implemented")

    def actual_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_prompt_tokens,
            completion_tokens=self._total_completion_tokens
        )

    def total_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_prompt_tokens,
            completion_tokens=self._total_completion_tokens
        )

    def count_tokens(self, messages: Sequence[LLMMessage]) -> int:
        chat_history = self._convert_to_chat_history(messages)
        total_tokens = 0
        for message in chat_history.messages:
            if message.metadata and 'usage' in message.metadata:
                usage = message.metadata['usage']
                total_tokens += getattr(usage, 'total_tokens', 0)
        return total_tokens

    def remaining_tokens(self, messages: Sequence[LLMMessage]) -> int:
        # Get total token count
        used_tokens = self.count_tokens(messages)
        # Assume max tokens from SK client if available, otherwise use default
        max_tokens = getattr(self._sk_client, 'max_tokens', 4096)
        return max_tokens - used_tokens

    @property
    def capabilities(self) -> ModelCapabilities:
        # Return something consistent with the underlying SK client
        return {
            "vision": False,
            "function_calling": self._sk_client.SUPPORTS_FUNCTION_CALLING,
            "json_output": False,
        }
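The create_stream stub spells out the intended flow in its comments. Below is a rough sketch of that flow, assuming SK's get_streaming_chat_message_contents yields batches of streaming chunks with a .content attribute; it is an illustration of the commented steps, not the commit's implementation, and token accounting is omitted.

async def create_stream(self, messages, tools=[], json_output=None,
                        extra_create_args={}, cancellation_token=None):
    # 1. Same setup as create(): convert messages and configure tools
    kernel = extra_create_args["kernel"]
    chat_history = self._convert_to_chat_history(messages)
    settings = self._build_execution_settings(extra_create_args, tools)
    self._sync_tools_with_kernel(kernel, tools)

    # 2. Yield str segments as they arrive
    accumulated: list[str] = []
    async for chunks in self._sk_client.get_streaming_chat_message_contents(
        chat_history, settings=settings, kernel=kernel
    ):
        for chunk in chunks:
            if chunk.content:
                accumulated.append(chunk.content)
                yield chunk.content

    # 3. Finish with a final CreateResult, mirroring create()
    yield CreateResult(
        content="".join(accumulated),
        finish_reason="stop",
        usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
        cached=False,
    )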
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
import os
import pytest
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
from semantic_kernel.kernel import Kernel
from semantic_kernel.memory.null_memory import NullMemory
from autogen_core.models import SystemMessage, UserMessage
from autogen_core.tools import BaseTool
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from pydantic import BaseModel
from autogen_core import CancellationToken

class CalculatorArgs(BaseModel):
    a: float
    b: float

class CalculatorResult(BaseModel):
    result: float

class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]):
    def __init__(self):
        super().__init__(
            args_type=CalculatorArgs,
            return_type=CalculatorResult,
            name="calculator",
            description="Add two numbers together"
        )

    async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult:
        return CalculatorResult(result=args.a + args.b)

@pytest.mark.asyncio
async def test_sk_chat_completion_with_tools():
    # Set up Azure OpenAI client with API key auth
    deployment_name = "gpt-4o-mini"
    endpoint = "https://<your-endpoint>.openai.azure.com/"
    api_version = "2024-07-18"

    # Create SK client
    sk_client = AzureChatCompletion(
        deployment_name=deployment_name,
        endpoint=endpoint,
        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    )

    # Create adapter
    adapter = SKChatCompletionAdapter(sk_client)

    # Create kernel
    kernel = Kernel(memory=NullMemory())

    # Create calculator tool instance
    tool = CalculatorTool()

    # Test messages
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is 2 + 2?", source="user"),
    ]

    # Call create with tool
    result = await adapter.create(
        messages=messages,
        tools=[tool],
        extra_create_args={"kernel": kernel}
    )

    # Verify response
    assert isinstance(result.content, str)
    assert result.finish_reason == "stop"
    assert result.usage.prompt_tokens >= 0
    assert result.usage.completion_tokens >= 0
    assert not result.cached
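Note that this is a live integration test rather than a unit test: it requires a real Azure OpenAI deployment, the endpoint placeholder filled in, and AZURE_OPENAI_API_KEY set in the environment before it will pass.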
