diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py b/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py index 8807fd28e50a..40c8f90ca21d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py @@ -63,9 +63,8 @@ class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]): .. code-block:: python import asyncio - from autogen_ext.models.openai import AzureOpenAIChatCompletionClient + from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.tools.graphrag import GlobalSearchTool - from azure.identity import DefaultAzureCredential, get_bearer_token_provider from autogen_agentchat.agents import AssistantAgent diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py b/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py index c977b3522336..6380f7cfbeee 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py @@ -63,9 +63,8 @@ class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]): .. code-block:: python import asyncio - from autogen_ext.models.openai import AzureOpenAIChatCompletionClient + from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.tools.graphrag import LocalSearchTool - from azure.identity import DefaultAzureCredential, get_bearer_token_provider from autogen_agentchat.agents import AssistantAgent diff --git a/python/samples/agentchat_graphrag/app.py b/python/samples/agentchat_graphrag/app.py index 736aed159d47..6a6d1fd16fe4 100644 --- a/python/samples/agentchat_graphrag/app.py +++ b/python/samples/agentchat_graphrag/app.py @@ -1,32 +1,28 @@ +import argparse import asyncio +import json +import logging +from typing import Any, Dict from autogen_agentchat.messages import TextMessage -from autogen_ext.models.openai import AzureOpenAIChatCompletionClient from autogen_ext.tools.graphrag import ( GlobalSearchTool, LocalSearchTool, ) -from azure.identity import DefaultAzureCredential, get_bearer_token_provider from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination from autogen_agentchat.teams import RoundRobinGroupChat +from autogen_core.models import ChatCompletionClient -async def main(): - # Initialize the OpenAI client - openai_client = AzureOpenAIChatCompletionClient( - model="gpt-4o-mini", - azure_endpoint="https://.openai.azure.com", - azure_deployment="gpt-4o-mini", - api_version="2024-08-01-preview", - azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default") - ) +async def main(model_config: Dict[str, Any]) -> None: + # Initialize the model client from config + model_client = ChatCompletionClient.load_component(model_config) # Set up global search tool global_tool = GlobalSearchTool.from_settings( settings_path="./settings.yaml" ) - local_tool = LocalSearchTool.from_settings( settings_path="./settings.yaml" ) @@ -35,7 +31,7 @@ async def main(): assistant_agent = AssistantAgent( name="search_assistant", tools=[global_tool, local_tool], - model_client=openai_client, + model_client=model_client, system_message=( "You are a tool selector AI assistant using the GraphRAG framework. " "Your primary task is to determine the appropriate search tool to call based on the user's query. " @@ -63,5 +59,18 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) + parser = argparse.ArgumentParser(description="Run a GraphRAG search with an agent.") + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.") + parser.add_argument( + "--model-config", type=str, help="Path to the model configuration file.", default="model_config.json" + ) + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.WARNING) + logging.getLogger("autogen_core").setLevel(logging.DEBUG) + handler = logging.FileHandler("graphrag_search.log") + logging.getLogger("autogen_core").addHandler(handler) + with open(args.model_config, "r") as f: + model_config = json.load(f) + asyncio.run(main(model_config))