Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sema4ai/bin/create_env/conda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ dependencies:
- sema4ai-http-helper==2.1.0
- sema4ai-common==0.1.0
- ruff==0.11.11
- mcp==1.12.0
1 change: 1 addition & 0 deletions sema4ai/bin/create_env/conda_vscode_darwin_amd64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dependencies:
- sema4ai-http-helper==2.1.0
- sema4ai-common==0.1.0
- ruff==0.11.11
- mcp==1.12.0
1 change: 1 addition & 0 deletions sema4ai/bin/create_env/conda_vscode_darwin_arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dependencies:
- sema4ai-http-helper==2.1.0
- sema4ai-common==0.1.0
- ruff==0.11.11
- mcp==1.12.0
1 change: 1 addition & 0 deletions sema4ai/bin/create_env/conda_vscode_linux_amd64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dependencies:
- sema4ai-http-helper==2.1.0
- sema4ai-common==0.1.0
- ruff==0.11.11
- mcp==1.12.0
1 change: 1 addition & 0 deletions sema4ai/bin/create_env/conda_vscode_windows_amd64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ dependencies:
- sema4ai-http-helper==2.1.0
- sema4ai-common==0.1.0
- ruff==0.11.11
- mcp==1.12.0
762 changes: 644 additions & 118 deletions sema4ai/poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion sema4ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ msgspec = "^0.19"
pyyaml = "^6"
playwright = "<2"
pillow = "^10.2.0"
pywin32 = { version = ">=300,<304", platform = "win32", python = "!=3.8.1" }
pywin32 = { version = ">=300,<307", platform = "win32", python = "!=3.8.1" }
psutil = "^5.9.0"
comtypes = "^1.2"
pyscreeze = "^0.1.30"
Expand All @@ -46,6 +46,7 @@ tree-sitter-yaml = "^0.7.0"
ruamel-yaml = "^0.18.10"
sema4ai-http-helper = "^2.1.0"
sema4ai-common = "^0.1.0"
mcp = "~1.10.1"

[tool.poetry.group.dev.dependencies]
sema4ai-python-ls-core = { path = "../sema4ai-python-ls-core/", develop = true }
Expand Down
231 changes: 225 additions & 6 deletions sema4ai/src/sema4ai_code/robocorp_language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,223 @@ def _fix_wrong_agent_import(self, agent_dir, monitor: IMonitor) -> ActionResultD
def m_fix_wrong_agent_import(self, agent_dir) -> ActionResultDict:
return require_monitor(partial(self._fix_wrong_agent_import, agent_dir))

def _validate_mcp_server_config(
self, mcp_server_config: MCP_SERVER_CONFIG, agent_dir: str
) -> ActionResultDict:
"""
Validates the MCP server configuration based on transport type.
"""
transport = mcp_server_config.get("transport", "")
if not transport:
return ActionResult.make_failure("Transport type is required").as_dict()

if transport == "stdio":
command_line = mcp_server_config.get("commandLine")
if not command_line:
return ActionResult.make_failure(
"Command line is required for STDIO transport"
).as_dict()

# Validate cwd if provided
cwd = mcp_server_config.get("cwd")
if not cwd:
return ActionResult.make_failure(
"Working directory is required for STDIO transport"
).as_dict()

cwd_path = Path(cwd)
if not cwd_path.is_absolute():
cwd_path = Path(agent_dir) / cwd

if not cwd_path.exists():
return ActionResult.make_failure(
f"Working directory does not exist: {cwd}"
).as_dict()

elif transport in ["streamable-http", "sse", "auto"]:
url = mcp_server_config.get("url")
if not url:
return ActionResult.make_failure(
"Server URL is required for HTTP-based transports"
).as_dict()

# Basic URL validation
if not url.startswith(("http://", "https://")):
return ActionResult.make_failure(
"URL must start with http:// or https://"
).as_dict()
else:
return ActionResult.make_failure(
f"Unsupported transport type: {transport}"
).as_dict()

return ActionResult(success=True, message=None).as_dict()

def m_test_mcp_server(
self, mcp_server_config: MCP_SERVER_CONFIG, agent_dir: str
) -> ActionResultDict:
return require_monitor(
partial(self._test_mcp_server, mcp_server_config, agent_dir)
)

def _test_mcp_server(
self, mcp_server_config: MCP_SERVER_CONFIG, agent_dir: str, monitor: IMonitor
) -> ActionResultDict:
import asyncio
import shlex

from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client

# Validate configuration first
validation_result = self._validate_mcp_server_config(
mcp_server_config, agent_dir
)
if not validation_result["success"]:
return validation_result

def _process_dynamic_vars(vars: dict) -> dict:
processed_env = {}
for key, value in vars.items():
if isinstance(value, str):
# Plain text value
processed_env[key] = value
elif isinstance(value, dict):
if value.get("type") == "string" and value.get("default"):
processed_env[key] = value["default"]

return processed_env

def _extract_root_exception(exc: Exception) -> str:
"""Extract the root cause from ExceptionGroup or return the original exception message."""
if hasattr(exc, "exceptions"): # ExceptionGroup
# Get the first exception from the group and recurse
for sub_exc in exc.exceptions:
return _extract_root_exception(sub_exc)
return str(exc)

async def _handle_client_session(read, write) -> int:
async with ClientSession(read, write) as session:
await session.initialize()
tools_result = await session.list_tools()

return len(tools_result.tools) if tools_result.tools else 0

async def _list_server_tools(mcp_server_config: dict, transport: str) -> dict:
try:
if transport == "sse":
async with sse_client(**mcp_server_config) as (read, write):
tools_count = await _handle_client_session(read, write)

elif transport == "stdio":
async with stdio_client(
StdioServerParameters(**mcp_server_config)
) as (
read,
write,
):
tools_count = await _handle_client_session(read, write)

else:
async with streamablehttp_client(**mcp_server_config) as (
read,
write,
_,
):
tools_count = await _handle_client_session(read, write)
except Exception as e:
return {
"success": False,
"message": f"Failed to validate {transport} transport: {_extract_root_exception(e)}, Config: {mcp_server_config}",
}
return {
"success": True,
"message": f"Successfully connected to MCP server via {transport}. Found {tools_count} tools.",
}

async def validate_server():
transport = (
mcp_server_config["transport"]
if mcp_server_config["transport"] != "auto"
else "streamable-http"
)
try:
if transport == "stdio":
command_line_str = mcp_server_config.get("commandLine", "")
try:
command_line = shlex.split(command_line_str)
except Exception as e:
return {
"success": False,
"message": f"Failed to parse command line: {e}",
}

cwd = mcp_server_config["cwd"]
if cwd:
cwd_path = Path(cwd)
if not cwd_path.is_absolute():
cwd_path = Path(agent_dir) / cwd
cwd = str(cwd_path)

env_vars = mcp_server_config.get("env")
processed_env = (
_process_dynamic_vars(env_vars) if env_vars else None
)

server_config = {
"command": command_line[0],
"args": command_line[1:] if len(command_line) > 1 else [],
"cwd": cwd,
"env": processed_env,
}
return await _list_server_tools(server_config, transport)

elif transport in ["sse", "streamable-http"]:
headers = mcp_server_config.get("headers", {})
processed_headers = (
_process_dynamic_vars(headers) if headers else {}
)

server_config = {
"url": mcp_server_config.get("url"),
"headers": processed_headers,
}
return await _list_server_tools(server_config, transport)

else:
return {
"success": False,
"message": f"Unsupported transport type: {mcp_server_config['transport']}",
}
except Exception as e:
return {
"success": False,
"message": f"Failed to validate MCP server: {_extract_root_exception(e)}",
}

try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(validate_server())
finally:
loop.close()

if result["success"]:
return ActionResult(
success=True,
message=result["message"],
).as_dict()
else:
return ActionResult.make_failure(result["message"]).as_dict()
except Exception as e:
log.exception("Error validating MCP server")
return ActionResult.make_failure(
f"Error validating MCP server: {e}"
).as_dict()

def m_add_mcp_server(
self, agent_dir: str, mcp_server_config: MCP_SERVER_CONFIG
) -> ActionResultDict:
Expand All @@ -2423,6 +2640,13 @@ def _add_mcp_server(
from ruamel.yaml import YAML, CommentedSeq
from ruamel.yaml.scalarstring import DoubleQuotedScalarString

# Validate configuration first
validation_result = self._validate_mcp_server_config(
mcp_server_config, agent_dir
)
if not validation_result["success"]:
return validation_result

try:
agent_spec_path = Path(agent_dir) / "agent-spec.yaml"
if not agent_spec_path.exists():
Expand Down Expand Up @@ -2462,12 +2686,7 @@ def _add_mcp_server(
mcp_server_entry["transport"] = mcp_server_config["transport"]

if mcp_server_config["transport"] == "stdio":
command_line_str = mcp_server_config.get("commandLine", "")
if not command_line_str:
return ActionResult.make_failure(
"Command line is required for STDIO transport"
).as_dict()

command_line_str = mcp_server_config["commandLine"]
try:
command_line = shlex.split(command_line_str)
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ send_chat_message_with_options:
messages:
default: []
items:
additionalProperties: true
type: object
title: Messages
type: array
Expand All @@ -73,6 +74,7 @@ send_chat_message_with_options:
items:
properties:
filter:
additionalProperties: true
description: A filter object to apply to the search results.
title: Filter
type: object
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from sema4ai_code_tests.protocols import IRobocorpLanguageServerClient
from sema4ai_ls_core.basic import implements
from sema4ai_ls_core.unittest_tools.language_server_client import LanguageServerClient

Expand All @@ -8,6 +7,7 @@
UploadNewRobotParamsDict,
UploadRobotParamsDict,
)
from sema4ai_code_tests.protocols import IRobocorpLanguageServerClient


class RobocorpLanguageServerClient(LanguageServerClient):
Expand Down
Loading
Loading