Skip to content

Commit f34eb8f

Browse files
committed
fix docstring client imports
1 parent 4431968 commit f34eb8f

File tree

3 files changed

+25
-18
lines changed

3 files changed

+25
-18
lines changed

python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
6363
.. code-block:: python
6464
6565
import asyncio
66-
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
66+
from autogen_ext.models.openai import OpenAIChatCompletionClient
6767
from autogen_ext.tools.graphrag import GlobalSearchTool
68-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
6968
from autogen_agentchat.agents import AssistantAgent
7069
7170

python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
6363
.. code-block:: python
6464
6565
import asyncio
66-
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
66+
from autogen_ext.models.openai import OpenAIChatCompletionClient
6767
from autogen_ext.tools.graphrag import LocalSearchTool
68-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
6968
from autogen_agentchat.agents import AssistantAgent
7069
7170

python/samples/agentchat_graphrag/app.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,28 @@
1+
import argparse
12
import asyncio
3+
import json
4+
import logging
5+
from typing import Any, Dict
26
from autogen_agentchat.messages import TextMessage
3-
from autogen_ext.models.openai import AzureOpenAIChatCompletionClient
47
from autogen_ext.tools.graphrag import (
58
GlobalSearchTool,
69
LocalSearchTool,
710
)
8-
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
911
from autogen_agentchat.agents import AssistantAgent
1012
from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
1113
from autogen_agentchat.teams import RoundRobinGroupChat
14+
from autogen_core.models import ChatCompletionClient
1215

1316

14-
async def main():
15-
# Initialize the OpenAI client
16-
openai_client = AzureOpenAIChatCompletionClient(
17-
model="gpt-4o-mini",
18-
azure_endpoint="https://<resource-name>.openai.azure.com",
19-
azure_deployment="gpt-4o-mini",
20-
api_version="2024-08-01-preview",
21-
azure_ad_token_provider=get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default")
22-
)
17+
async def main(model_config: Dict[str, Any]) -> None:
18+
# Initialize the model client from config
19+
model_client = ChatCompletionClient.load_component(model_config)
2320

2421
# Set up global search tool
2522
global_tool = GlobalSearchTool.from_settings(
2623
settings_path="./settings.yaml"
2724
)
2825

29-
3026
local_tool = LocalSearchTool.from_settings(
3127
settings_path="./settings.yaml"
3228
)
@@ -35,7 +31,7 @@ async def main():
3531
assistant_agent = AssistantAgent(
3632
name="search_assistant",
3733
tools=[global_tool, local_tool],
38-
model_client=openai_client,
34+
model_client=model_client,
3935
system_message=(
4036
"You are a tool selector AI assistant using the GraphRAG framework. "
4137
"Your primary task is to determine the appropriate search tool to call based on the user's query. "
@@ -63,5 +59,18 @@ async def main():
6359

6460

6561
if __name__ == "__main__":
66-
asyncio.run(main())
62+
parser = argparse.ArgumentParser(description="Run a GraphRAG search with an agent.")
63+
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.")
64+
parser.add_argument(
65+
"--model-config", type=str, help="Path to the model configuration file.", default="model_config.json"
66+
)
67+
args = parser.parse_args()
68+
if args.verbose:
69+
logging.basicConfig(level=logging.WARNING)
70+
logging.getLogger("autogen_core").setLevel(logging.DEBUG)
71+
handler = logging.FileHandler("graphrag_search.log")
72+
logging.getLogger("autogen_core").addHandler(handler)
6773

74+
with open(args.model_config, "r") as f:
75+
model_config = json.load(f)
76+
asyncio.run(main(model_config))

0 commit comments

Comments
 (0)