Skip to content

Commit

Permalink
add docstring and format
Browse files Browse the repository at this point in the history
  • Loading branch information
lspinheiro committed Jan 12, 2025
1 parent 107d743 commit e5a894e
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Mapping, Optional, Sequence
from typing import Any, Literal, Mapping, Optional, Sequence

from autogen_core import FunctionCall
from autogen_core._cancellation_token import CancellationToken
Expand Down Expand Up @@ -33,6 +33,122 @@


class SKChatCompletionAdapter(ChatCompletionClient):
"""
SKChatCompletionAdapter is an adapter that allows using Semantic Kernel model clients
as Autogen ChatCompletion clients. This makes it possible to seamlessly integrate
Semantic Kernel connectors (e.g., Azure OpenAI, Google Gemini, Ollama, etc.) into
Autogen agents that rely on a ChatCompletionClient interface.
By leveraging this adapter, you can:
- Pass in a `Kernel` and any supported Semantic Kernel `ChatCompletionClientBase` connector.
- Provide tools (via Autogen `Tool` or `ToolSchema`) for function calls during chat completion.
- Stream responses or retrieve them in a single request.
Args:
sk_client (ChatCompletionClientBase):
The Semantic Kernel client to wrap (e.g., AzureChatCompletion, GoogleAIChatCompletion, OllamaChatCompletion).
Example usage:
.. code-block:: python
import asyncio
from semantic_kernel import Kernel
from semantic_kernel.memory.null_memory import NullMemory
from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
from semantic_kernel.connectors.ai.google.google_ai import GoogleAIChatCompletion
from semantic_kernel.connectors.ai.ollama import OllamaChatCompletion
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
from autogen_core.models import SystemMessage, UserMessage, LLMMessage
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from pydantic import BaseModel
# 1) Basic tool definition (for demonstration)
class CalculatorArgs(BaseModel):
a: float
b: float
class CalculatorResult(BaseModel):
result: float
class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]):
def __init__(self) -> None:
super().__init__(
args_type=CalculatorArgs,
return_type=CalculatorResult,
name="calculator",
description="Add two numbers together",
)
async def run(self, args: CalculatorArgs, cancellation_token: CancellationToken) -> CalculatorResult:
return CalculatorResult(result=args.a + args.b)
async def main():
# 2) Create a Semantic Kernel instance (with null memory for simplicity)
kernel = Kernel(memory=NullMemory())
# ----------------------------------------------------------------
# Example A: Azure OpenAI
# ----------------------------------------------------------------
deployment_name = "<AZURE_OPENAI_DEPLOYMENT_NAME>"
endpoint = "<AZURE_OPENAI_ENDPOINT>"
api_key = "<AZURE_OPENAI_API_KEY>"
azure_client = AzureChatCompletion(deployment_name=deployment_name, endpoint=endpoint, api_key=api_key)
azure_adapter = SKChatCompletionAdapter(sk_client=azure_client)
# ----------------------------------------------------------------
# Example B: Google Gemini
# ----------------------------------------------------------------
google_api_key = "<GCP_API_KEY>"
google_model = "gemini-1.5-flash"
google_client = GoogleAIChatCompletion(model=google_model, api_key=google_api_key)
google_adapter = SKChatCompletionAdapter(sk_client=google_client)
# ----------------------------------------------------------------
# Example C: Ollama (local Llama-based model)
# ----------------------------------------------------------------
ollama_client = OllamaChatCompletion(
service_id="ollama", # custom ID
host="http://localhost:11434",
ai_model_id="llama3.1",
)
ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client)
# 3) Create a tool and register it with the kernel
calc_tool = CalculatorTool()
# 4) Prepare messages for a chat completion
messages: list[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What is 2 + 2?"),
]
# 5) Invoke chat completion with the Azure adapter (as an example)
# Provide the kernel in extra_create_args, and pass the tool.
# The same pattern applies to Google or Ollama adapters.
result = await azure_adapter.create(messages=messages, tools=[calc_tool], extra_create_args={"kernel": kernel})
# Print or use the result
print("Result content:", result.content)
print("Finish reason:", result.finish_reason)
# Note: Tools are invoked if the model calls them (function calls).
# If the model simply returns text, you get text.
# You can also stream with `create_stream(...)` using the same approach.
if __name__ == "__main__":
asyncio.run(main())
"""

def __init__(self, sk_client: ChatCompletionClientBase):
self._sk_client = sk_client
self._total_prompt_tokens = 0
Expand Down Expand Up @@ -197,7 +313,7 @@ async def create(
content: Union[str, list[FunctionCall]]
if any(isinstance(item, FunctionCallContent) for item in result[0].items):
content = self._process_tool_calls(result[0])
finish_reason = "function_calls"
finish_reason: Literal["function_calls", "stop"] = "function_calls"
else:
content = result[0].content
finish_reason = "stop"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
from autogen_core import CancellationToken
from autogen_core.models import CreateResult, SystemMessage, UserMessage
from autogen_core.models import CreateResult, LLMMessage, SystemMessage, UserMessage
from autogen_core.tools import BaseTool
from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
from pydantic import BaseModel
Expand All @@ -21,7 +21,7 @@ class CalculatorResult(BaseModel):


class CalculatorTool(BaseTool[CalculatorArgs, CalculatorResult]):
def __init__(self):
def __init__(self) -> None:
super().__init__(
args_type=CalculatorArgs,
return_type=CalculatorResult,
Expand Down Expand Up @@ -58,7 +58,7 @@ async def test_sk_chat_completion_with_tools(sk_client: AzureChatCompletion) ->
tool = CalculatorTool()

# Test messages
messages = [
messages: list[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What is 2 + 2?", source="user"),
]
Expand All @@ -81,7 +81,7 @@ async def test_sk_chat_completion_without_tools(sk_client: AzureChatCompletion)
kernel = Kernel(memory=NullMemory())

# Test messages
messages = [
messages: list[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Say hello!", source="user"),
]
Expand All @@ -107,7 +107,7 @@ async def test_sk_chat_completion_stream_with_tools(sk_client: AzureChatCompleti
tool = CalculatorTool()

# Test messages
messages = [
messages: list[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="What is 2 + 2?", source="user"),
]
Expand Down Expand Up @@ -135,7 +135,7 @@ async def test_sk_chat_completion_stream_without_tools(sk_client: AzureChatCompl
kernel = Kernel(memory=NullMemory())

# Test messages
messages = [
messages: list[LLMMessage] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Say hello!", source="user"),
]
Expand Down

0 comments on commit e5a894e

Please sign in to comment.