Skip to content

Add MCP server tool filtering support to agents-js #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changeset/hungry-suns-search.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
'@openai/agents-realtime': patch
'@openai/agents-core': patch
---

agents-core, agents-realtime: add MCP tool-filtering support (fixes #162)
21 changes: 21 additions & 0 deletions docs/src/content/docs/guides/mcp.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,27 @@ For **Streamable HTTP** and **Stdio** servers, each time an `Agent` runs it may

Only enable this if you're confident the tool list won't change. To invalidate the cache later, call `invalidateToolsCache()` on the server instance.

### Tool filtering

You can restrict which tools are exposed from each server. Pass either a static filter
using `createMCPToolStaticFilter` or a custom function:

```ts
const server = new MCPServerStdio({
fullCommand: 'my-server',
toolFilter: createMCPToolStaticFilter({
allowed: ['safe_tool'],
blocked: ['danger_tool'],
}),
});

const dynamicServer = new MCPServerStreamableHttp({
url: 'http://localhost:3000',
toolFilter: ({ runContext }, tool) =>
runContext.context.allowAll || tool.name !== 'admin',
});
```

## Further reading

- [Model Context Protocol](https://modelcontextprotocol.io/) – official specification.
Expand Down
6 changes: 6 additions & 0 deletions examples/mcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@ Run the example from the repository root:
```bash
pnpm -F mcp start:stdio
```

`tool-filter-example.ts` shows how to expose only a subset of server tools:

```bash
pnpm -F mcp start:tool-filter
```
3 changes: 2 additions & 1 deletion examples/mcp/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"start:streamable-http": "tsx streamable-http-example.ts",
"start:hosted-mcp-on-approval": "tsx hosted-mcp-on-approval.ts",
"start:hosted-mcp-human-in-the-loop": "tsx hosted-mcp-human-in-the-loop.ts",
"start:hosted-mcp-simple": "tsx hosted-mcp-simple.ts"
"start:hosted-mcp-simple": "tsx hosted-mcp-simple.ts",
"start:tool-filter": "tsx tool-filter-example.ts"
}
}
53 changes: 53 additions & 0 deletions examples/mcp/tool-filter-example.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import {
Agent,
run,
MCPServerStdio,
createMCPToolStaticFilter,
withTrace,
} from '@openai/agents';
import * as path from 'node:path';

async function main() {
const samplesDir = path.join(__dirname, 'sample_files');
const mcpServer = new MCPServerStdio({
name: 'Filesystem Server with filter',
fullCommand: `npx -y @modelcontextprotocol/server-filesystem ${samplesDir}`,
toolFilter: createMCPToolStaticFilter({
allowed: ['read_file', 'list_directory'],
blocked: ['write_file'],
}),
});

await mcpServer.connect();

try {
await withTrace('MCP Tool Filter Example', async () => {
const agent = new Agent({
name: 'MCP Assistant',
instructions: 'Use the filesystem tools to answer questions.',
mcpServers: [mcpServer],
});

console.log('Listing sample files:');
let result = await run(
agent,
'List the files in the sample_files directory.',
);
console.log(result.finalOutput);

console.log('\nAttempting to write a file (should be blocked):');
result = await run(
agent,
'Create a file named sample_files/test.txt with the text "hello"',
);
console.log(result.finalOutput);
});
} finally {
await mcpServer.close();
}
}

main().catch((err) => {
console.error(err);
process.exit(1);
});
12 changes: 8 additions & 4 deletions packages/agents-core/src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,11 @@ export class Agent<
* Fetches the available tools from the MCP servers.
* @returns the MCP powered tools
*/
async getMcpTools(): Promise<Tool<TContext>[]> {
async getMcpTools(
runContext: RunContext<TContext>,
): Promise<Tool<TContext>[]> {
if (this.mcpServers.length > 0) {
return getAllMcpTools(this.mcpServers);
return getAllMcpTools(this.mcpServers, runContext, this, false);
}

return [];
Expand All @@ -527,8 +529,10 @@ export class Agent<
*
* @returns all configured tools
*/
async getAllTools(): Promise<Tool<TContext>[]> {
return [...(await this.getMcpTools()), ...this.tools];
async getAllTools(
runContext: RunContext<TContext>,
): Promise<Tool<TContext>[]> {
return [...(await this.getMcpTools(runContext)), ...this.tools];
}

/**
Expand Down
6 changes: 6 additions & 0 deletions packages/agents-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ export {
MCPServerStdio,
MCPServerStreamableHttp,
} from './mcp';
export {
MCPToolFilterCallable,
MCPToolFilterContext,
MCPToolFilterStatic,
createMCPToolStaticFilter,
} from './mcpUtil';
export {
Model,
ModelProvider,
Expand Down
73 changes: 71 additions & 2 deletions packages/agents-core/src/mcp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import {
JsonObjectSchemaStrict,
UnknownContext,
} from './types';
import type { MCPToolFilterCallable, MCPToolFilterStatic } from './mcpUtil';
import type { RunContext } from './runContext';
import type { Agent } from './agent';

export const DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME =
'openai-agents:stdio-mcp-client';
Expand All @@ -27,6 +30,7 @@ export const DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME =
*/
export interface MCPServer {
cacheToolsList: boolean;
toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
connect(): Promise<void>;
readonly name: string;
close(): Promise<void>;
Expand All @@ -40,12 +44,14 @@ export interface MCPServer {
export abstract class BaseMCPServerStdio implements MCPServer {
public cacheToolsList: boolean;
protected _cachedTools: any[] | undefined = undefined;
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;

protected logger: Logger;
constructor(options: MCPServerStdioOptions) {
this.logger =
options.logger ?? getLogger(DEFAULT_STDIO_MCP_CLIENT_LOGGER_NAME);
this.cacheToolsList = options.cacheToolsList ?? false;
this.toolFilter = options.toolFilter;
}

abstract get name(): string;
Expand All @@ -72,13 +78,15 @@ export abstract class BaseMCPServerStdio implements MCPServer {
export abstract class BaseMCPServerStreamableHttp implements MCPServer {
public cacheToolsList: boolean;
protected _cachedTools: any[] | undefined = undefined;
public toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;

protected logger: Logger;
constructor(options: MCPServerStreamableHttpOptions) {
this.logger =
options.logger ??
getLogger(DEFAULT_STREAMABLE_HTTP_MCP_CLIENT_LOGGER_NAME);
this.cacheToolsList = options.cacheToolsList ?? false;
this.toolFilter = options.toolFilter;
}

abstract get name(): string;
Expand Down Expand Up @@ -195,13 +203,17 @@ export class MCPServerStreamableHttp extends BaseMCPServerStreamableHttp {
*/
export async function getAllMcpFunctionTools<TContext = UnknownContext>(
mcpServers: MCPServer[],
runContext: RunContext<TContext>,
agent: Agent<any, any>,
convertSchemasToStrict = false,
): Promise<Tool<TContext>[]> {
const allTools: Tool<TContext>[] = [];
const toolNames = new Set<string>();
for (const server of mcpServers) {
const serverTools = await getFunctionToolsFromServer(
server,
runContext,
agent,
convertSchemasToStrict,
);
const serverToolNames = new Set(serverTools.map((t) => t.name));
Expand Down Expand Up @@ -233,6 +245,8 @@ export function invalidateServerToolsCache(serverName: string) {
*/
async function getFunctionToolsFromServer<TContext = UnknownContext>(
server: MCPServer,
runContext: RunContext<TContext>,
agent: Agent<any, any>,
convertSchemasToStrict: boolean,
): Promise<FunctionTool<TContext, any, unknown>[]> {
if (server.cacheToolsList && _cachedTools[server.name]) {
Expand All @@ -242,7 +256,53 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
}
return withMCPListToolsSpan(
async (span) => {
const mcpTools = await server.listTools();
const fetchedMcpTools = await server.listTools();
const mcpTools: MCPTool[] = [];
const context = {
runContext,
agent,
serverName: server.name,
};
for (const tool of fetchedMcpTools) {
const filter = server.toolFilter;
if (filter) {
if (filter && typeof filter === 'function') {
const filtered = await filter(context, tool);
if (!filtered) {
globalLogger.debug(
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the callable filter.`,
);
continue; // skip this tool
}
} else {
const allowedToolNames = filter.allowedToolNames ?? [];
const blockedToolNames = filter.blockedToolNames ?? [];
if (allowedToolNames.length > 0 || blockedToolNames.length > 0) {
const allowed =
allowedToolNames.length > 0
? allowedToolNames.includes(tool.name)
: true;
const blocked =
blockedToolNames.length > 0
? blockedToolNames.includes(tool.name)
: false;
if (!allowed || blocked) {
if (blocked) {
globalLogger.debug(
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is blocked by the static filter.`,
);
} else if (!allowed) {
globalLogger.debug(
`MCP Tool (server: ${server.name}, tool: ${tool.name}) is not allowed by the static filter.`,
);
}
continue; // skip this tool
}
}
}
}
mcpTools.push(tool);
}
span.spanData.result = mcpTools.map((t) => t.name);
const tools: FunctionTool<TContext, any, string>[] = mcpTools.map((t) =>
mcpToFunctionTool(t, server, convertSchemasToStrict),
Expand All @@ -261,9 +321,16 @@ async function getFunctionToolsFromServer<TContext = UnknownContext>(
*/
export async function getAllMcpTools<TContext = UnknownContext>(
mcpServers: MCPServer[],
runContext: RunContext<TContext>,
agent: Agent<TContext, any>,
convertSchemasToStrict = false,
): Promise<Tool<TContext>[]> {
return getAllMcpFunctionTools(mcpServers, convertSchemasToStrict);
return getAllMcpFunctionTools(
mcpServers,
runContext,
agent,
convertSchemasToStrict,
);
}

/**
Expand Down Expand Up @@ -353,6 +420,7 @@ export interface BaseMCPServerStdioOptions {
encoding?: string;
encodingErrorHandler?: 'strict' | 'ignore' | 'replace';
logger?: Logger;
toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;
}
export interface DefaultMCPServerStdioOptions
extends BaseMCPServerStdioOptions {
Expand All @@ -373,6 +441,7 @@ export interface MCPServerStreamableHttpOptions {
clientSessionTimeoutSeconds?: number;
name?: string;
logger?: Logger;
toolFilter?: MCPToolFilterCallable | MCPToolFilterStatic;

// ----------------------------------------------------
// OAuth
Expand Down
46 changes: 46 additions & 0 deletions packages/agents-core/src/mcpUtil.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import type { Agent } from './agent';
import type { RunContext } from './runContext';
import type { MCPTool } from './mcp';
import type { UnknownContext } from './types';

/** Context information available to tool filter functions. */
export interface MCPToolFilterContext<TContext = UnknownContext> {
/** The current run context. */
runContext: RunContext<TContext>;
/** The agent requesting the tools. */
agent: Agent<TContext, any>;
/** Name of the MCP server providing the tools. */
serverName: string;
}

/** A function that determines whether a tool should be available. */
export type MCPToolFilterCallable<TContext = UnknownContext> = (
context: MCPToolFilterContext<TContext>,
tool: MCPTool,
) => Promise<boolean>;

/** Static tool filter configuration using allow and block lists. */
export interface MCPToolFilterStatic {
/** Optional list of tool names to allow. */
allowedToolNames?: string[];
/** Optional list of tool names to block. */
blockedToolNames?: string[];
}

/** Convenience helper to create a static tool filter. */
export function createMCPToolStaticFilter(options?: {
allowed?: string[];
blocked?: string[];
}): MCPToolFilterStatic | undefined {
if (!options?.allowed && !options?.blocked) {
return undefined;
}
const filter: MCPToolFilterStatic = {};
if (options?.allowed) {
filter.allowedToolNames = options.allowed;
}
if (options?.blocked) {
filter.blockedToolNames = options.blocked;
}
return filter;
}
4 changes: 2 additions & 2 deletions packages/agents-core/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
setCurrentSpan(state._currentAgentSpan);
}

const tools = await state._currentAgent.getAllTools();
const tools = await state._currentAgent.getAllTools(state._context);
const serializedTools = tools.map((t) => serializeTool(t));
const serializedHandoffs = handoffs.map((h) => serializeHandoff(h));
if (state._currentAgentSpan) {
Expand Down Expand Up @@ -615,7 +615,7 @@ export class Runner extends RunHooks<any, AgentOutputType<unknown>> {
while (true) {
const currentAgent = result.state._currentAgent;
const handoffs = currentAgent.handoffs.map(getHandoff);
const tools = await currentAgent.getAllTools();
const tools = await currentAgent.getAllTools(result.state._context);
const serializedTools = tools.map((t) => serializeTool(t));
const serializedHandoffs = handoffs.map((h) => serializeHandoff(h));

Expand Down
4 changes: 3 additions & 1 deletion packages/agents-core/src/runState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ export class RunState<TContext, TAgent extends Agent<any, any>> {
? await deserializeProcessedResponse(
agentMap,
state._currentAgent,
state._context,
stateJson.lastProcessedResponse,
)
: undefined;
Expand Down Expand Up @@ -707,11 +708,12 @@ export function deserializeItem(
async function deserializeProcessedResponse<TContext = UnknownContext>(
agentMap: Map<string, Agent<any, any>>,
currentAgent: Agent<TContext, any>,
context: RunContext<TContext>,
serializedProcessedResponse: z.infer<
typeof serializedProcessedResponseSchema
>,
): Promise<ProcessedResponse<TContext>> {
const allTools = await currentAgent.getAllTools();
const allTools = await currentAgent.getAllTools(context);
const tools = new Map(
allTools
.filter((tool) => tool.type === 'function')
Expand Down
Loading