from typing import Any, AsyncGenerator, Mapping, Optional, Sequence, Union

from autogen_core import CancellationToken
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    CreateResult,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    ModelCapabilities,
    RequestUsage,
    SystemMessage,
    UserMessage,
)
from autogen_core.tools import Tool, ToolSchema
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions.kernel_plugin import KernelPlugin
from semantic_kernel.kernel import Kernel

from ._kernel_function_from_tool import KernelFunctionFromTool


class SKChatCompletionAdapter(ChatCompletionClient):
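    """Adapter that exposes a Semantic Kernel chat completion service through
    Autogen's ChatCompletionClient interface.

    Conversions are best-effort: Autogen messages are mapped onto an SK
    ChatHistory, Autogen tools are mirrored into an SK plugin, and SK results
    are wrapped back into Autogen's CreateResult. Callers must pass a
    semantic_kernel Kernel via extra_create_args["kernel"] on every request.
    """
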
    def __init__(self, sk_client: ChatCompletionClientBase):
        self._sk_client = sk_client
        self._total_prompt_tokens = 0
        self._total_completion_tokens = 0
        self._tools_plugin: Optional[KernelPlugin] = None

    def _convert_to_chat_history(self, messages: Sequence[LLMMessage]) -> ChatHistory:
        """Convert Autogen LLMMessages to an SK ChatHistory."""
        chat_history = ChatHistory()

        for msg in messages:
            if msg.type == "SystemMessage":
                chat_history.add_system_message(msg.content)

            elif msg.type == "UserMessage":
                if isinstance(msg.content, str):
                    chat_history.add_user_message(msg.content)
                else:
                    # Multimodal content (list of str/Image) is flattened to its
                    # string form; mapping Image parts to SK content types is TODO.
                    chat_history.add_user_message(str(msg.content))

            elif msg.type == "AssistantMessage":
                if isinstance(msg.content, str):
                    chat_history.add_assistant_message(msg.content)
                else:
                    # Function calls are flattened to their string form; mapping
                    # them to SK's function-call content types is TODO.
                    chat_history.add_assistant_message(str(msg.content))

            elif msg.type == "FunctionExecutionResultMessage":
                for result in msg.content:
                    chat_history.add_tool_message(result.content)

        return chat_history

    def _convert_from_chat_message(self, message: ChatMessageContent, source: str = "assistant") -> LLMMessage:
        """Convert an SK ChatMessageContent back to an Autogen LLMMessage."""
        if message.role == AuthorRole.SYSTEM:
            return SystemMessage(content=message.content)

        elif message.role == AuthorRole.USER:
            return UserMessage(content=message.content, source=source)

        elif message.role == AuthorRole.ASSISTANT:
            return AssistantMessage(content=message.content, source=source)

        elif message.role == AuthorRole.TOOL:
            # SK tool messages carry no Autogen call id, so call_id is left empty.
            return FunctionExecutionResultMessage(
                content=[FunctionExecutionResult(content=message.content, call_id="")]
            )

        raise ValueError(f"Unknown role: {message.role}")

    def _build_execution_settings(
        self, extra_create_args: Mapping[str, Any], tools: Sequence[Tool | ToolSchema]
    ) -> PromptExecutionSettings:
        """Build PromptExecutionSettings from extra_create_args."""
        # Extract service_id if provided, otherwise use None.
        service_id = extra_create_args.get("service_id")

        # If tools are available, let the model choose functions but leave
        # invocation to the caller (Autogen executes tools itself), so
        # auto_invoke is disabled.
        function_choice_behavior = None
        if tools:
            function_choice_behavior = FunctionChoiceBehavior.Auto(auto_invoke=False)

        # Forward the remaining args as extension_data; the kernel itself is
        # not an execution setting, so it is filtered out.
        settings = PromptExecutionSettings(
            service_id=service_id,
            extension_data={k: v for k, v in extra_create_args.items() if k != "kernel"},
            function_choice_behavior=function_choice_behavior,
        )

        return settings

    def _sync_tools_with_kernel(self, kernel: Kernel, tools: Sequence[Tool | ToolSchema]) -> None:
        """Sync tools with the kernel by updating the shared plugin."""
        # Create the plugin on first use and register it with the kernel.
        if not self._tools_plugin:
            self._tools_plugin = KernelPlugin(name="autogen_tools")
            kernel.add_plugin(self._tools_plugin)

        # Tool names currently registered in the plugin.
        current_tool_names = set(self._tools_plugin.functions.keys())

        # Tool names requested for this call (a ToolSchema is a dict, a Tool an object).
        new_tool_names = {tool.name if isinstance(tool, Tool) else tool["name"] for tool in tools}

        # Remove tools that are no longer needed.
        for tool_name in current_tool_names - new_tool_names:
            del self._tools_plugin.functions[tool_name]

        # Add or update tools; bare ToolSchemas carry no callable body, so only
        # Tool instances are registered.
        for tool in tools:
            if isinstance(tool, Tool):
                # Convert the Autogen Tool to a KernelFunction.
                kernel_function = KernelFunctionFromTool(tool, plugin_name="autogen_tools")
                self._tools_plugin.functions[tool.name] = kernel_function

    async def create(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
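        """Send the converted messages to the SK service and wrap the reply.

        extra_create_args must contain "kernel" (a semantic_kernel Kernel).
        Remaining keys are forwarded to the connector as execution settings,
        e.g. extra_create_args={"kernel": kernel, "temperature": 0.3} (the
        temperature key is illustrative and connector-dependent).
        """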
        if "kernel" not in extra_create_args:
            raise ValueError("kernel is required in extra_create_args")

        kernel = extra_create_args["kernel"]
        if not isinstance(kernel, Kernel):
            raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

        chat_history = self._convert_to_chat_history(messages)

        # Build execution settings from extra args and tools.
        settings = self._build_execution_settings(extra_create_args, tools)

        # Sync tools with the kernel before the request.
        self._sync_tools_with_kernel(kernel, tools)

        result = await self._sk_client.get_chat_message_contents(
            chat_history,
            settings=settings,
            kernel=kernel,
        )

        # Track token usage from result metadata (populated by connectors such
        # as the OpenAI ones; absent metadata leaves the counts at zero).
        prompt_tokens = 0
        completion_tokens = 0

        if result[0].metadata and "usage" in result[0].metadata:
            usage = result[0].metadata["usage"]
            prompt_tokens = getattr(usage, "prompt_tokens", 0)
            completion_tokens = getattr(usage, "completion_tokens", 0)

        self._total_prompt_tokens += prompt_tokens
        self._total_completion_tokens += completion_tokens

        # NOTE: function-call responses are not mapped here; the first message's
        # text content is returned with a hard-coded "stop" finish reason.
        return CreateResult(
            content=result[0].content,
            finish_reason="stop",
            usage=RequestUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            ),
            cached=False,
        )

    async def create_stream(
        self,
        messages: Sequence[LLMMessage],
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
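        """Streaming counterpart of create(): yields text deltas, then a final CreateResult.

        A minimal sketch under stated assumptions: SK's
        get_streaming_chat_message_contents yields lists of chunks whose
        .content holds the text delta, and streaming connectors are assumed
        not to report token usage (so the final usage is zero).
        """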
        kernel = extra_create_args.get("kernel")
        if not isinstance(kernel, Kernel):
            raise ValueError("kernel (a semantic_kernel.kernel.Kernel) is required in extra_create_args")

        chat_history = self._convert_to_chat_history(messages)
        settings = self._build_execution_settings(extra_create_args, tools)
        self._sync_tools_with_kernel(kernel, tools)

        accumulated: list[str] = []
        async for chunks in self._sk_client.get_streaming_chat_message_contents(
            chat_history, settings=settings, kernel=kernel
        ):
            for chunk in chunks:
                if chunk.content:
                    accumulated.append(chunk.content)
                    yield chunk.content

        yield CreateResult(
            content="".join(accumulated),
            finish_reason="stop",
            usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
            cached=False,
        )

    def actual_usage(self) -> RequestUsage:
        # This adapter performs no caching, so actual and total usage coincide.
        return RequestUsage(
            prompt_tokens=self._total_prompt_tokens,
            completion_tokens=self._total_completion_tokens,
        )

    def total_usage(self) -> RequestUsage:
        return RequestUsage(
            prompt_tokens=self._total_prompt_tokens,
            completion_tokens=self._total_completion_tokens,
        )

    def count_tokens(self, messages: Sequence[LLMMessage]) -> int:
        # Best-effort estimate: freshly converted messages rarely carry usage
        # metadata, so this typically returns 0; a real implementation would
        # run a tokenizer over the message contents.
        chat_history = self._convert_to_chat_history(messages)
        total_tokens = 0
        for message in chat_history.messages:
            if message.metadata and "usage" in message.metadata:
                usage = message.metadata["usage"]
                total_tokens += getattr(usage, "total_tokens", 0)
        return total_tokens

    def remaining_tokens(self, messages: Sequence[LLMMessage]) -> int:
        # Subtract the (estimated) used tokens from the client's context window,
        # assuming a 4096-token default when the SK client does not expose one.
        used_tokens = self.count_tokens(messages)
        max_tokens = getattr(self._sk_client, "max_tokens", 4096)
        return max_tokens - used_tokens

    @property
    def capabilities(self) -> ModelCapabilities:
        # Mirror the underlying SK client where possible; vision and JSON output
        # support are not advertised by ChatCompletionClientBase, so they default
        # to False here.
        return {
            "vision": False,
            "function_calling": getattr(self._sk_client, "SUPPORTS_FUNCTION_CALLING", False),
            "json_output": False,
        }
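
# ---------------------------------------------------------------------------
# Usage sketch (a minimal example, not part of the adapter). Assumptions: an
# OpenAI-backed SK client configured via the OPENAI_API_KEY environment
# variable, and "gpt-4o-mini" as a placeholder model id. Run as a module
# (python -m <package>.<module>) so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion

    async def _demo() -> None:
        adapter = SKChatCompletionAdapter(OpenAIChatCompletion(ai_model_id="gpt-4o-mini"))
        result = await adapter.create(
            messages=[UserMessage(content="What is 2 + 2?", source="user")],
            extra_create_args={"kernel": Kernel()},
        )
        print(result.content)

    asyncio.run(_demo())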