Skip to content

Commit 78b6026

Browse files
authored
[Feature] MCPToolTransform (#2993)
1 parent c6440df commit 78b6026

File tree

13 files changed

+1162
-62
lines changed

13 files changed

+1162
-62
lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ Intermediate
105105
tutorials/dqn_with_rnn
106106
tutorials/rb_tutorial
107107
tutorials/export
108+
tutorials/llm_browser
108109

109110
Advanced
110111
--------

docs/source/reference/llms.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Therefore, the fundamental structure of an LLM post-training pipeline is:
5252
- An environment that handles the world around the LLM:
5353
- Loading data (through :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer`)
5454
- Formatting data (through :class:`~torchrl.envs.llm.transforms.TemplateTransform`)
55-
- Executing tools (through :class:`~torchrl.envs.llm.transforms.PythonInterpreter`)
55+
- Executing tools (through :class:`~torchrl.envs.llm.transforms.PythonInterpreter` or :class:`~torchrl.envs.llm.transforms.MCPToolTransform`)
5656
- Computing rewards online, if needed (through :class:`~torchrl.envs.llm.transforms.KLRewardTransform`)
5757
- A data collector that takes the policy (the LLM) and the environment, and handles the inference part of the pipeline:
5858
- Running reset, step and gathering actions;
@@ -179,6 +179,8 @@ transforms).
179179

180180
DataLoadingPrimer
181181
KLRewardTransform
182+
MCPToolTransform
183+
BrowserTransform
182184
PythonInterpreter
183185
TemplateTransform
184186
Tokenizer

setup.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,20 @@ def _main(argv):
237237
],
238238
"marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"],
239239
"open_spiel": ["open_spiel>=1.5"],
240+
"llm": [
241+
"transformers", # For tokenizer and model support
242+
"vllm", # For efficient inference
243+
"playwright", # For browser automation
244+
"datasets", # For data loading
245+
"langdetect", # For language detection in IFEval
246+
"nltk", # For text processing in IFEval
247+
"immutabledict", # For IFEval
248+
"accelerate", # For model loading and inference
249+
"sentencepiece", # For tokenization
250+
"protobuf", # Required by some models
251+
"einops", # For tensor operations
252+
"safetensors", # For model loading
253+
],
240254
}
241255
extra_requires["all"] = set()
242256
for key in list(extra_requires.keys()):

test/llm/test_envs.py

Lines changed: 221 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import argparse
88
import contextlib
99
import importlib.util
10+
import random
1011
import re
12+
import time
1113

1214
import pytest
1315
import torch
@@ -434,7 +436,7 @@ def test_chat_env(slef, tokenizer):
434436
)
435437
)
436438
# Check history after reset
437-
torchrl_logger.info('td_reset["history"].content', td_reset["history"].content)
439+
torchrl_logger.info(f'{td_reset["history"].content=}')
438440
assert len(td_reset["history"][0].content) == 2
439441
assert td_reset["history"][0, 0].content == "I'm system, do what I want."
440442
assert td_reset["history"][0, 1].content.startswith("I'm the user.")
@@ -593,9 +595,10 @@ def test_ifeval(self):
593595
env = IFEvalEnv(apply_template=True, tokenizer=tokenizer)
594596
torchrl_logger.info(env.reset())
595597
r = env.reset()
596-
r[0][
597-
"text_response"
598-
] = """<think>
598+
r.set(
599+
"text_response",
600+
[
601+
"""<think>
599602
The task requires crafting a riddle about a 'house' that's not traditionally considered one. The answer must be included, and the response should be at least 400 words with a title wrapped in double angular brackets. Let's start by brainstorming what could be considered a 'house' in a non-traditional sense. Ideas include natural shelters, abstract concepts, or objects that serve a similar purpose to a house.
600603
One potential concept is a "womb," as it provides shelter and housing for a developing being. However, we need to ensure our riddle is engaging, meets the word count requirement, and includes the necessary elements like a title.
601604
Let's construct a narrative around the chosen concept, ensuring it's detailed and follows the required structure.
@@ -637,6 +640,8 @@ def test_ifeval(self):
637640
By embracing such metaphors, we're encouraged to look beyond the obvious and appreciate the myriad ways 'shelter' manifests in our lives. And so, the riddle serves not just as a puzzle to be solved but as a reflection on the profound connections that bind us to the very essence of existence.
638641
</answer><|im_end|>
639642
"""
643+
],
644+
)
640645
td = env.step(r)
641646
assert td["next", "ifeval_score"].all()
642647
assert td.get(("next", "reward")) is not None
@@ -881,7 +886,7 @@ def test_python_interpreter_persistent_reset(self):
881886
r["text_response"] = [
882887
"""Here is a python code to execute:
883888
```python
884-
# check if a is still defined
889+
# check if a is still defined
885890
if "a" in globals():
886891
raise RuntimeError("a is still defined")
887892
else:
@@ -899,7 +904,7 @@ def test_python_interpreter_persistent_reset(self):
899904
"<|im_start|>assistant\n"
900905
"Here is a python code to execute:\n"
901906
"```python\n"
902-
"#\xa0check if a is still defined\n"
907+
"# check if a is still defined\n"
903908
'if "a" in globals():\n'
904909
' raise RuntimeError("a is still defined")\n'
905910
"else:\n"
@@ -914,6 +919,216 @@ def test_python_interpreter_persistent_reset(self):
914919
"<|im_start|>assistant\n",
915920
)
916921

922+
@pytest.mark.skipif(not _has_transformers, reason="requires transformers")
def test_mcp_tool_transform(self):
    """Drive MCPToolTransform through a ChatEnv with an in-process calculator tool.

    The same environment is reset and stepped four times, covering: a single
    tool call, two tool calls in one assistant response, a tool that raises,
    and a tool call whose argument payload is not valid JSON.
    """
    from torchrl.envs.llm import ChatEnv
    from torchrl.envs.llm.transforms.tools import MCPToolTransform
    from transformers import AutoTokenizer

    def calculator(operation: str, a: float, b: float) -> dict:
        # Tool under test: supports "add" and "multiply", rejects the rest.
        if operation == "add":
            return {"result": a + b}
        if operation == "multiply":
            return {"result": a * b}
        raise ValueError(f"Unknown operation: {operation}")

    # Name -> callable and name -> schema mappings handed to the transform.
    tools = {"calculator": calculator}
    schemas = {
        "calculator": {
            "name": "calculator",
            "description": "A simple calculator that can add or multiply two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "operation": {"type": "string", "enum": ["add", "multiply"]},
                    "a": {"type": "number"},
                    "b": {"type": "number"},
                },
                "required": ["operation", "a", "b"],
            },
        }
    }

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    env = ChatEnv(
        batch_size=(1,),
        system_prompt="You are a helpful assistant that uses a calculator.",
        apply_template=True,
        tokenizer=tokenizer,
    ).append_transform(MCPToolTransform(tools, schemas))

    def run_turn(prompt, response):
        # One reset + one step; returns the history stored under "next".
        td = env.reset(TensorDict({"text": [prompt]}, batch_size=(1,)))
        td["text_response"] = [response]
        return env.step(td)["next", "history"]

    # --- single tool call ---
    history = run_turn(
        "Let me calculate 2 + 3",
        'I will help you calculate 2 + 3:\n<tool>calculator\n{"operation": "add", "a": 2, "b": 3}</tool><|im_end|>',
    )
    assert len(history[0]) == 4  # system, user, assistant, tool response
    assert history[0, -1].role == "tool"
    assert "result': 5" in history[0, -1].content

    # --- two tool calls in a single response ---
    history = run_turn(
        "Calculate 2 + 3 and 4 * 5",
        "I will help you calculate both:\n"
        '<tool>calculator\n{"operation": "add", "a": 2, "b": 3}</tool>\n'
        '<tool>calculator\n{"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>',
    )
    # system, user, assistant, tool response 1, tool response 2
    assert len(history[0]) == 5
    assert history[0, -2].role == "tool"
    assert history[0, -1].role == "tool"
    assert "result': 5" in history[0, -2].content  # 2 + 3 = 5
    assert "result': 20" in history[0, -1].content  # 4 * 5 = 20

    # --- tool raises: the failure must surface as a tool message ---
    history = run_turn(
        "Calculate 2 ? 3",
        'I will try to calculate:\n<tool>calculator\n{"operation": "invalid", "a": 2, "b": 3}</tool><|im_end|>',
    )
    assert len(history[0]) == 4
    assert history[0, -1].role == "tool"
    assert "failed" in history[0, -1].content
    assert "Unknown operation: invalid" in history[0, -1].content

    # --- arguments are not valid JSON ---
    history = run_turn(
        "Calculate something",
        "Let me calculate:\n<tool>calculator\ninvalid json</tool><|im_end|>",
    )
    assert len(history[0]) == 4
    assert history[0, -1].role == "tool"
    assert "failed" in history[0, -1].content
    assert "Failed to parse tool arguments" in history[0, -1].content
1032+
# Define a tool that waits for a random amount of time
1033+
@classmethod
def delayed_calculator(cls, operation: str, a: float, b: float) -> dict:
    """Add or multiply ``a`` and ``b`` after sleeping for a random delay.

    The delay is drawn uniformly from [0.1, 0.3] seconds and is slept
    before the operation is validated, so unknown operations also wait.

    Returns:
        A dict with the computed ``"result"`` and the ``"delay"`` that was slept.

    Raises:
        ValueError: if ``operation`` is neither ``"add"`` nor ``"multiply"``.
    """
    delay = random.uniform(0.1, 0.3)
    time.sleep(delay)
    if operation not in ("add", "multiply"):
        raise ValueError(f"Unknown operation: {operation}")
    outcome = a + b if operation == "add" else a * b
    return {"result": outcome, "delay": delay}
1044+
1045+
# Define the tool schema
1046+
# JSON-schema style description of the delayed calculator tool: one string
# "operation" restricted to add/multiply plus two numeric operands, all required.
calculator_schema = {
    "name": "delayed_calculator",
    "description": "A calculator that introduces random delays",
    "parameters": {
        "type": "object",
        "properties": {
            "operation": {"type": "string", "enum": ["add", "multiply"]},
            # Both operands share the same shape; each gets its own dict.
            **{operand: {"type": "number"} for operand in ("a", "b")},
        },
        "required": ["operation", "a", "b"],
    },
}
1059+
1060+
# Create environment factory
1061+
@classmethod
def make_env(cls):
    """Build a ChatEnv with an MCPToolTransform exposing the delayed calculator.

    Used as an environment factory: every call loads the tokenizer, creates
    a batch-size-1 ChatEnv and appends a transform that registers
    ``cls.delayed_calculator`` under the tool name ``"calculator"`` together
    with ``cls.calculator_schema``.

    Returns:
        The value returned by ``ChatEnv.append_transform``.
    """
    # Everything the factory needs is imported here so it does not depend on
    # module-level names being present wherever it is executed; this mirrors
    # the local imports used by test_mcp_tool_transform.
    from torchrl.envs.llm import ChatEnv
    from torchrl.envs.llm.transforms.tools import MCPToolTransform
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
    env = ChatEnv(
        batch_size=(1,),
        system_prompt="I'm a calculator assistant",
        apply_template=True,
        tokenizer=tokenizer,
    )
    tools = {"calculator": cls.delayed_calculator}
    schemas = {"calculator": cls.calculator_schema}
    return env.append_transform(MCPToolTransform(tools, schemas))
1075+
1076+
@pytest.mark.skipif(not _has_transformers, reason="requires transformers")
def test_async_mcp_tools(self):
    """Step two MCP-tool environments through an AsyncEnvPool and check both answers.

    Two environments built by ``self.make_env`` are reset together, sent one
    tool-calling response each, and their results are collected in up to two
    ``async_step_recv`` calls before being stacked and inspected.
    """
    from tensordict import TensorDict
    from torchrl.envs import AsyncEnvPool

    pool = AsyncEnvPool([self.make_env, self.make_env], backend="multiprocessing")
    try:
        # One prompt per environment; each env has its own batch dim of 1.
        reset_input = TensorDict(
            text=[["Let me calculate 2 + 3"], ["Let me calculate 4 * 5"]],
            batch_size=(2, 1),
        )
        td = pool.reset(reset_input)

        td["text_response"] = [
            [
                'Let me calculate 2 + 3:\n<tool>calculator\n{"operation": "add", "a": 2, "b": 3}</tool><|im_end|>'
            ],
            [
                'Let me calculate 4 * 5:\n<tool>calculator\n{"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>'
            ],
        ]
        pool.async_step_send(td)

        # First receive asks for at least one result; a second receive is
        # issued only if fewer than two came back.
        first_batch = pool.async_step_recv(min_get=1)
        assert len(first_batch) >= 1  # We should get at least one result
        leftovers = pool.async_step_recv() if len(first_batch) < 2 else []

        stacked = torch.stack([*first_batch, *leftovers])

        history = stacked["next", "history"]
        assert len(history[0, 0]) == 4  # system, user, assistant, tool response
        assert history[0, 0, -1].role == "tool"

        # Results are matched by content rather than by position in the stack.
        last_contents = history[:, 0, -1].content
        assert any("result': 5" in c for c in last_contents)  # 2 + 3 = 5
        assert any("result': 20" in c for c in last_contents)  # 4 * 5 = 20
    finally:
        pool.close()
1131+
9171132

9181133
if __name__ == "__main__":
9191134
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/envs/async_envs.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525

2626
from tensordict.tensorclass import NonTensorData, NonTensorStack
27-
from tensordict.utils import _zip_strict
27+
from tensordict.utils import _zip_strict, expand_as_right
2828

2929
from torchrl.data.tensor_specs import NonTensor
3030
from torchrl.envs.common import _EnvPostInit, EnvBase
@@ -250,7 +250,10 @@ def _reset(
250250
tensordict = TensorDict(
251251
batch_size=(self.num_envs,) + self.env_batch_sizes[0]
252252
)
253-
tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0]))
253+
env_idx_nt = NonTensorStack(*range(tensordict.shape[0]))
254+
while env_idx_nt.batch_dims < tensordict.batch_dims:
255+
env_idx_nt = expand_as_right(env_idx_nt, tensordict)
256+
tensordict[self._env_idx_key] = env_idx_nt
254257
self._async_private_reset_send(tensordict)
255258
tensordict = self._async_private_reset_recv(min_get=self.num_envs)
256259
return tensordict
@@ -304,7 +307,10 @@ def reset(
304307
tensordict = TensorDict(
305308
batch_size=(self.num_envs,) + self.env_batch_sizes[0]
306309
)
307-
tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0]))
310+
indices = NonTensorStack(*range(tensordict.shape[0]))
311+
if indices.shape != tensordict.shape:
312+
indices = expand_as_right(indices, tensordict)
313+
tensordict[self._env_idx_key] = indices
308314
self.async_reset_send(tensordict)
309315
tensordict = self.async_reset_recv(min_get=self.num_envs)
310316
return tensordict
@@ -329,7 +335,7 @@ def _setup(self) -> None:
329335

330336
def _maybe_make_tensordict(self, tensordict, env_index, make_if_none):
331337
if env_index is None:
332-
env_idx = tensordict[self._env_idx_key]
338+
env_idx = tensordict.view(-1)[self._env_idx_key]
333339
if isinstance(env_idx, torch.Tensor):
334340
env_idx = env_idx.tolist()
335341
if isinstance(env_idx, int):
@@ -784,7 +790,6 @@ def async_step_send(
784790
self, tensordict: TensorDictBase, env_index: int | list[int] | None = None
785791
) -> None:
786792
tensordict, env_idx = self._maybe_make_tensordict(tensordict, env_index, False)
787-
788793
if self._busy.intersection(env_idx):
789794
raise RuntimeError(
790795
f"Some envs are still processing a step: envs that are busy: {self._busy}, queried: {env_idx}."

torchrl/envs/common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import torch.nn as nn
1717
from tensordict import (
1818
is_tensor_collection,
19+
lazy_stack,
1920
LazyStackedTensorDict,
2021
TensorDictBase,
2122
unravel_key,
@@ -3325,9 +3326,7 @@ def rollout(
33253326
)
33263327
raise
33273328
else:
3328-
out_td = LazyStackedTensorDict.maybe_dense_stack(
3329-
tensordicts, len(batch_size), out=out
3330-
)
3329+
out_td = lazy_stack(tensordicts, len(batch_size), out=out)
33313330
if set_truncated:
33323331
found_truncated = False
33333332
for key in self.done_keys:

0 commit comments

Comments
 (0)