@@ -41,13 +41,14 @@ def __init__(self, model_name=None):
         self.anthropic = None
         self.openai = None
         self.gemini = None
-
+
         if self.model_provider == "anthropic":
             self.anthropic = AsyncAnthropic()
         elif self.model_provider == "openai":
             self.openai = AsyncOpenAI()
         elif self.model_provider == "gemini":
             from google import genai
+
             self.gemini = genai.Client(api_key=os.getenv("GEMINI_API_KEY", ""))
         else:
             raise ValueError(f"Unsupported model provider: {self.model_provider}")
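
Side note on this hunk: the constructor instantiates only the client for the selected provider and leaves the other two attributes as None. A minimal usage sketch, assuming the enclosing class is named `MCPClient` and that `model_provider` is derived from `model_name` (the class name and model name are illustrative, not confirmed by this diff):

```python
import os

# Hypothetical: GEMINI_API_KEY must be set before the constructor runs,
# since genai.Client reads it via os.getenv in the hunk above.
os.environ["GEMINI_API_KEY"] = "your-key"  # placeholder value

client = MCPClient(model_name="gemini-2.0-flash")  # illustrative model name
assert client.anthropic is None and client.openai is None
assert client.gemini is not None
```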
@@ -265,20 +266,22 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
 
         # Convert tools format for Gemini
         from google.genai import types
+
         gemini_tools = []
         for tool in available_tools:
             # Make a deep copy of the input schema to avoid modifying the original
             import copy
+
             input_schema = copy.deepcopy(tool["input_schema"])
-
+
             # Change all "type" values to uppercase as required by Gemini
             if "type" in input_schema:
                 input_schema["type"] = input_schema["type"].upper()
             if "properties" in input_schema:
                 for prop in input_schema["properties"].values():
                     if "type" in prop:
                         prop["type"] = prop["type"].upper()
-
+
             func_spec = {
                 "name": tool["name"],
                 "description": tool["description"],
@@ -295,18 +298,18 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
                 "max_output_tokens": self.max_tokens,
                 "tools": gemini_tools,
             }
-
+
             response = await self.gemini.aio.models.generate_content(
                 model=self.model_name,
                 contents=messages,
                 config=types.GenerateContentConfig(**request_params),
             )
-
+
             try:
                 response_text = response.text
             except Exception:
                 response_text = None
-
+
             function_calls = getattr(response, "function_calls", [])
         except Exception as e:
             error_msg = f"Error calling Gemini API: {str(e)}"
@@ -323,69 +326,69 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
             else:
                 tool_args = function_call.args
             tool_id = function_call.name + "_" + str(len(self.tool_outputs))
-
+
             # Add tool call to message history
             tool_call_content = response.candidates[0].content
             self._add_to_message_history(tool_call_content, messages)
-
+
             # Handle the tool call
             result, result_text = await self._handle_tool_call(tool_name, tool_args)
-
+
             # Add tool result to message history
             tool_result_message = types.Content(
                 role="function",
                 parts=[
                     types.Part.from_function_response(
-                        name=tool_name,
-                        response={"result": result_text}
+                        name=tool_name, response={"result": result_text}
                     )
-                ]
+                ],
             )
             self._add_to_message_history(tool_result_message, messages)
 
             # Add tool result to tool outputs
-            self.tool_outputs.append({
-                "tool_call_id": tool_id,
-                "name": tool_name,
-                "args": tool_args,
-                "result": result_text,
-                "text": response_text
-            })
-
+            self.tool_outputs.append(
+                {
+                    "tool_call_id": tool_id,
+                    "name": tool_name,
+                    "args": tool_args,
+                    "result": result_text,
+                    "text": response_text,
+                }
+            )
+
             # Get next response from Gemini
             try:
                 response = await self.gemini.aio.models.generate_content(
                     model=self.model_name,
                     contents=messages,
-                    config=types.GenerateContentConfig(**request_params)
+                    config=types.GenerateContentConfig(**request_params),
                 )
 
                 try:
                     response_text = response.text
                 except Exception:
                     response_text = None
-
+
                 # Extract function calls
                 function_calls = getattr(response, "function_calls", [])
             except Exception as e:
                 error_msg = f"Error calling Gemini API: {str(e)}"
                 print(error_msg)
                 return error_msg
-
+
             # If no more function calls, break
             if not function_calls:
                 break
-
+
         # Final response with no tool calls
         final_text = response_text
-
+
         # Add final assistant response to message history
         final_message = types.Content(
-            role="model",
-            parts=[types.Part.from_text(final_text)]
+            role="model", parts=[types.Part.from_text(final_text)]
         )
         self.message_history.append(final_message)
-
+
         return final_text
 
     async def _process_anthropic_query(self, messages: list, available_tools: list):
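
For reference, each record appended to `self.tool_outputs` in this loop has the shape below; a sketch with invented values (the tool name and args are made up):

```python
# "tool_call_id" is the tool name plus a running index, as built in this hunk:
# function_call.name + "_" + str(len(self.tool_outputs))
example_tool_output = {
    "tool_call_id": "get_weather_0",
    "name": "get_weather",
    "args": {"city": "Berlin"},
    "result": "Sunny, 21C",
    "text": None,  # model text emitted alongside the call, if any
}
```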
@@ -667,9 +670,7 @@ async def _process_prompt_templates(self, query: str) -> str:
                     )
                     return query
                 except Exception as e:
-                    print(
-                        f"Error processing prompt template /{command}: {str(e)}"
-                    )
+                    print(f"Error processing prompt template /{command}: {str(e)}")
                     return query
 
             elif f"/{command}" in query:
@@ -695,9 +696,7 @@ async def _process_prompt_templates(self, query: str) -> str:
                     )
                     return query_text
                 except Exception as e:
-                    print(
-                        f"Error processing prompt template /{command}: {str(e)}"
-                    )
+                    print(f"Error processing prompt template /{command}: {str(e)}")
                     return query_text
         except Exception as e:
             print(f"Unexpected error processing prompt template: {str(e)}")
@@ -725,13 +724,13 @@ async def process_query(self, query: str) -> tuple[str, list[str]]:
         # Add user query to message history (format depends on provider)
         if self.model_provider == "gemini":
             from google.genai import types
+
             user_message = types.Content(
-                role="user",
-                parts=[types.Part.from_text(query)]
+                role="user", parts=[types.Part.from_text(query)]
             )
         else:
             user_message = {"role": "user", "content": query}
-
+
         self.message_history.append(user_message)
 
         # Use full message history for context
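
Note that the two branches above leave differently typed entries in `message_history`: the Gemini branch stores a `types.Content` object, the other providers a plain dict. A minimal sketch of both shapes, mirroring the calls in this hunk (the query text is illustrative):

```python
from google.genai import types

# Gemini branch: a google.genai Content object
gemini_msg = types.Content(role="user", parts=[types.Part.from_text("hello")])

# Anthropic/OpenAI branch: a plain role/content dict
dict_msg = {"role": "user", "content": "hello"}
```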
@@ -814,9 +813,7 @@ async def _connect_to_mcp_sse_server(self, server_name: str, server_url: str):
                 self.all_tools.append(tool)
                 self.tool_to_server[tool.name] = server_name
         except Exception as e:
-            print(
-                f"Failed to list tools from server '{server_name}': {str(e)}"
-            )
+            print(f"Failed to list tools from server '{server_name}': {str(e)}")
             raise
 
         # List and register available prompts
@@ -898,9 +895,7 @@ async def _connect_to_mcp_stdio_server(
                 self.all_tools.append(tool)
                 self.tool_to_server[tool.name] = server_name
         except Exception as e:
-            print(
-                f"Failed to list tools from server '{server_name}': {str(e)}"
-            )
+            print(f"Failed to list tools from server '{server_name}': {str(e)}")
             raise
 
         # List and register available prompts
@@ -925,9 +920,7 @@ async def _connect_to_mcp_stdio_server(
                     f"Timeout connecting to server '{server_name}': {str(e)}"
                 )
             else:
-                print(
-                    f"Failed to connect to server '{server_name}': {str(e)}"
-                )
+                print(f"Failed to connect to server '{server_name}': {str(e)}")
             raise
 
     async def connect_to_server_from_config(self, config: dict):