Skip to content

Commit 3ff14c1

Browse files
committed
Add the ability to specify allowed tools for MCP servers
Signed-off-by: Donnie Adams <[email protected]>
1 parent 664d93c commit 3ff14c1

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

pkg/mcp/loader.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ type ServerConfig struct {
5656
BaseURL string `json:"baseURL,omitempty"`
5757
Headers []string `json:"headers"`
5858
Scope string `json:"scope"`
59+
AllowedTools []string `json:"allowedTools"`
5960
}
6061

6162
func (s *ServerConfig) GetBaseURL() string {
@@ -99,7 +100,7 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool,
99100
}
100101

101102
for server := range maps.Keys(servers.MCPServers) {
102-
tools, err := l.LoadSession(ctx, servers.MCPServers[server], tool.Name)
103+
tools, err := l.LoadTools(ctx, servers.MCPServers[server], tool.Name)
103104
if err != nil {
104105
return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err)
105106
}
@@ -111,13 +112,17 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool,
111112
return nil, fmt.Errorf("no MCP server configuration found in tool instructions: %s", configData)
112113
}
113114

114-
func (l *Local) LoadSession(ctx context.Context, server ServerConfig, toolName string) ([]types.Tool, error) {
115+
func (l *Local) LoadTools(ctx context.Context, server ServerConfig, toolName string) ([]types.Tool, error) {
116+
allowedTools := server.AllowedTools
117+
// Reset so we don't start a new MCP server, no reason to if one is already running and the allowed tools change.
118+
server.AllowedTools = nil
119+
115120
session, err := l.loadSession(server)
116121
if err != nil {
117122
return nil, err
118123
}
119124

120-
return l.sessionToTools(ctx, session, toolName)
125+
return l.sessionToTools(ctx, session, toolName, allowedTools)
121126
}
122127

123128
func (l *Local) Close() error {
@@ -148,7 +153,9 @@ func (l *Local) Close() error {
148153
return errors.Join(errs...)
149154
}
150155

151-
func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string) ([]types.Tool, error) {
156+
func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName string, allowedTools []string) ([]types.Tool, error) {
157+
allToolsAllowed := len(allowedTools) == 0 || slices.Contains(allowedTools, "*")
158+
152159
tools, err := session.Client.ListTools(ctx, mcp.ListToolsRequest{})
153160
if err != nil {
154161
return nil, fmt.Errorf("failed to list tools: %w", err)
@@ -158,6 +165,10 @@ func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName s
158165
var toolNames []string
159166

160167
for _, tool := range tools.Tools {
168+
if !allToolsAllowed && !slices.Contains(allowedTools, tool.Name) {
169+
continue
170+
}
171+
161172
var schema openapi3.Schema
162173

163174
schemaData, err := json.Marshal(tool.InputSchema)

0 commit comments

Comments
 (0)