Skip to content

Commit 52edd19

Browse files
committed
linted
1 parent e54b0e6 commit 52edd19

File tree

5 files changed

+168
-152
lines changed

5 files changed

+168
-152
lines changed

Diff for: defog/llm/utils_mcp.py

+39-46
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,14 @@ def __init__(self, model_name=None):
4141
self.anthropic = None
4242
self.openai = None
4343
self.gemini = None
44-
44+
4545
if self.model_provider == "anthropic":
4646
self.anthropic = AsyncAnthropic()
4747
elif self.model_provider == "openai":
4848
self.openai = AsyncOpenAI()
4949
elif self.model_provider == "gemini":
5050
from google import genai
51+
5152
self.gemini = genai.Client(api_key=os.getenv("GEMINI_API_KEY", ""))
5253
else:
5354
raise ValueError(f"Unsupported model provider: {self.model_provider}")
@@ -265,20 +266,22 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
265266

266267
# Convert tools format for Gemini
267268
from google.genai import types
269+
268270
gemini_tools = []
269271
for tool in available_tools:
270272
# Make a deep copy of the input schema to avoid modifying the original
271273
import copy
274+
272275
input_schema = copy.deepcopy(tool["input_schema"])
273-
276+
274277
# Change all "type" values to uppercase as required by Gemini
275278
if "type" in input_schema:
276279
input_schema["type"] = input_schema["type"].upper()
277280
if "properties" in input_schema:
278281
for prop in input_schema["properties"].values():
279282
if "type" in prop:
280283
prop["type"] = prop["type"].upper()
281-
284+
282285
func_spec = {
283286
"name": tool["name"],
284287
"description": tool["description"],
@@ -295,18 +298,18 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
295298
"max_output_tokens": self.max_tokens,
296299
"tools": gemini_tools,
297300
}
298-
301+
299302
response = await self.gemini.aio.models.generate_content(
300303
model=self.model_name,
301304
contents=messages,
302305
config=types.GenerateContentConfig(**request_params),
303306
)
304-
307+
305308
try:
306309
response_text = response.text
307310
except Exception:
308311
response_text = None
309-
312+
310313
function_calls = getattr(response, "function_calls", [])
311314
except Exception as e:
312315
error_msg = f"Error calling Gemini API: {str(e)}"
@@ -323,69 +326,69 @@ async def _process_gemini_query(self, messages: list, available_tools: list):
323326
else:
324327
tool_args = function_call.args
325328
tool_id = function_call.name + "_" + str(len(self.tool_outputs))
326-
329+
327330
# Add tool call to message history
328331
tool_call_content = response.candidates[0].content
329332
self._add_to_message_history(tool_call_content, messages)
330-
333+
331334
# Handle the tool call
332335
result, result_text = await self._handle_tool_call(tool_name, tool_args)
333-
336+
334337
# Add tool result to message history
335338
tool_result_message = types.Content(
336339
role="function",
337340
parts=[
338341
types.Part.from_function_response(
339-
name=tool_name,
340-
response={"result": result_text}
342+
name=tool_name, response={"result": result_text}
341343
)
342-
]
344+
],
343345
)
344346
self._add_to_message_history(tool_result_message, messages)
345347

346348
# Add tool result to tool outputs
347-
self.tool_outputs.append({
348-
"tool_call_id": tool_id,
349-
"name": tool_name,
350-
"args": tool_args,
351-
"result": result_text,
352-
"text": response_text
353-
})
354-
349+
self.tool_outputs.append(
350+
{
351+
"tool_call_id": tool_id,
352+
"name": tool_name,
353+
"args": tool_args,
354+
"result": result_text,
355+
"text": response_text,
356+
}
357+
)
358+
355359
# Get next response from Gemini
356360
try:
357361
response = await self.gemini.aio.models.generate_content(
358362
model=self.model_name,
359363
contents=messages,
360-
config=types.GenerateContentConfig(**request_params)
364+
config=types.GenerateContentConfig(**request_params),
361365
)
362366

363367
try:
364368
response_text = response.text
365369
except Exception:
366370
response_text = None
367-
371+
368372
# Extract function calls
369373
function_calls = getattr(response, "function_calls", [])
370374
except Exception as e:
371375
error_msg = f"Error calling Gemini API: {str(e)}"
372376
print(error_msg)
373377
return error_msg
374-
378+
375379
# If no more function calls, break
376380
if not function_calls:
377381
break
378-
382+
379383
# Final response with no tool calls
380384
final_text = response_text
381-
385+
382386
# Add final assistant response to message history
383387
final_message = types.Content(
384-
role="model",
385-
parts=[types.Part.from_text(final_text)]
388+
role="model", parts=[types.Part.from_text(final_text)]
386389
)
387390
self.message_history.append(final_message)
388-
391+
389392
return final_text
390393

391394
async def _process_anthropic_query(self, messages: list, available_tools: list):
@@ -667,9 +670,7 @@ async def _process_prompt_templates(self, query: str) -> str:
667670
)
668671
return query
669672
except Exception as e:
670-
print(
671-
f"Error processing prompt template /{command}: {str(e)}"
672-
)
673+
print(f"Error processing prompt template /{command}: {str(e)}")
673674
return query
674675

675676
elif f"/{command}" in query:
@@ -695,9 +696,7 @@ async def _process_prompt_templates(self, query: str) -> str:
695696
)
696697
return query_text
697698
except Exception as e:
698-
print(
699-
f"Error processing prompt template /{command}: {str(e)}"
700-
)
699+
print(f"Error processing prompt template /{command}: {str(e)}")
701700
return query_text
702701
except Exception as e:
703702
print(f"Unexpected error processing prompt template: {str(e)}")
@@ -725,13 +724,13 @@ async def process_query(self, query: str) -> tuple[str, list[str]]:
725724
# Add user query to message history (format depends on provider)
726725
if self.model_provider == "gemini":
727726
from google.genai import types
727+
728728
user_message = types.Content(
729-
role="user",
730-
parts=[types.Part.from_text(query)]
729+
role="user", parts=[types.Part.from_text(query)]
731730
)
732731
else:
733732
user_message = {"role": "user", "content": query}
734-
733+
735734
self.message_history.append(user_message)
736735

737736
# Use full message history for context
@@ -814,9 +813,7 @@ async def _connect_to_mcp_sse_server(self, server_name: str, server_url: str):
814813
self.all_tools.append(tool)
815814
self.tool_to_server[tool.name] = server_name
816815
except Exception as e:
817-
print(
818-
f"Failed to list tools from server '{server_name}': {str(e)}"
819-
)
816+
print(f"Failed to list tools from server '{server_name}': {str(e)}")
820817
raise
821818

822819
# List and register available prompts
@@ -898,9 +895,7 @@ async def _connect_to_mcp_stdio_server(
898895
self.all_tools.append(tool)
899896
self.tool_to_server[tool.name] = server_name
900897
except Exception as e:
901-
print(
902-
f"Failed to list tools from server '{server_name}': {str(e)}"
903-
)
898+
print(f"Failed to list tools from server '{server_name}': {str(e)}")
904899
raise
905900

906901
# List and register available prompts
@@ -925,9 +920,7 @@ async def _connect_to_mcp_stdio_server(
925920
f"Timeout connecting to server '{server_name}': {str(e)}"
926921
)
927922
else:
928-
print(
929-
f"Failed to connect to server '{server_name}': {str(e)}"
930-
)
923+
print(f"Failed to connect to server '{server_name}': {str(e)}")
931924
raise
932925

933926
async def connect_to_server_from_config(self, config: dict):

Diff for: tests/mcp/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file makes the tests/mcp directory a Python package

Diff for: tests/mcp/mcp_arithmetic_sse.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
host = "0.0.0.0"
44
port = 8001
55
mcp = FastMCP(
6-
name = "arithmetic_sse",
7-
host = host,
8-
port = port,
9-
)
6+
name="arithmetic_sse",
7+
host=host,
8+
port=port,
9+
)
10+
1011

1112
@mcp.tool()
1213
def add(a: int, b: int) -> int:
1314
"""Add two numbers together"""
1415
return a + b
1516

17+
1618
@mcp.tool()
1719
def multiply(a: int, b: int) -> int:
1820
"""Multiply two numbers together"""

Diff for: tests/mcp/mcp_arithmetic_stdio.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33

44
mcp = FastMCP(
5-
name = "arithmetic_sse",
6-
)
5+
name="arithmetic_sse",
6+
)
7+
78

89
@mcp.tool()
910
def add(a: int, b: int) -> int:
1011
"""Add two numbers together"""
1112
return a + b
1213

14+
1315
@mcp.tool()
1416
def multiply(a: int, b: int) -> int:
1517
"""Multiply two numbers together"""

0 commit comments

Comments
 (0)