7
7
import argparse
8
8
import contextlib
9
9
import importlib .util
10
+ import random
10
11
import re
12
+ import time
11
13
12
14
import pytest
13
15
import torch
@@ -434,7 +436,7 @@ def test_chat_env(slef, tokenizer):
434
436
)
435
437
)
436
438
# Check history after reset
437
- torchrl_logger .info (' td_reset["history"].content' , td_reset [ "history" ]. content )
439
+ torchrl_logger .info (f' { td_reset ["history" ].content = } ' )
438
440
assert len (td_reset ["history" ][0 ].content ) == 2
439
441
assert td_reset ["history" ][0 , 0 ].content == "I'm system, do what I want."
440
442
assert td_reset ["history" ][0 , 1 ].content .startswith ("I'm the user." )
@@ -593,9 +595,10 @@ def test_ifeval(self):
593
595
env = IFEvalEnv (apply_template = True , tokenizer = tokenizer )
594
596
torchrl_logger .info (env .reset ())
595
597
r = env .reset ()
596
- r [0 ][
597
- "text_response"
598
- ] = """<think>
598
+ r .set (
599
+ "text_response" ,
600
+ [
601
+ """<think>
599
602
The task requires crafting a riddle about a 'house' that's not traditionally considered one. The answer must be included, and the response should be at least 400 words with a title wrapped in double angular brackets. Let's start by brainstorming what could be considered a 'house' in a non-traditional sense. Ideas include natural shelters, abstract concepts, or objects that serve a similar purpose to a house.
600
603
One potential concept is a "womb," as it provides shelter and housing for a developing being. However, we need to ensure our riddle is engaging, meets the word count requirement, and includes the necessary elements like a title.
601
604
Let's construct a narrative around the chosen concept, ensuring it's detailed and follows the required structure.
@@ -637,6 +640,8 @@ def test_ifeval(self):
637
640
By embracing such metaphors, we're encouraged to look beyond the obvious and appreciate the myriad ways 'shelter' manifests in our lives. And so, the riddle serves not just as a puzzle to be solved but as a reflection on the profound connections that bind us to the very essence of existence.
638
641
</answer><|im_end|>
639
642
"""
643
+ ],
644
+ )
640
645
td = env .step (r )
641
646
assert td ["next" , "ifeval_score" ].all ()
642
647
assert td .get (("next" , "reward" )) is not None
@@ -881,7 +886,7 @@ def test_python_interpreter_persistent_reset(self):
881
886
r ["text_response" ] = [
882
887
"""Here is a python code to execute:
883
888
```python
884
- # check if a is still defined
889
+ # check if a is still defined
885
890
if "a" in globals():
886
891
raise RuntimeError("a is still defined")
887
892
else:
@@ -899,7 +904,7 @@ def test_python_interpreter_persistent_reset(self):
899
904
"<|im_start|>assistant\n "
900
905
"Here is a python code to execute:\n "
901
906
"```python\n "
902
- "#\xa0 check if a is still defined\n "
907
+ "# check if a is still defined\n "
903
908
'if "a" in globals():\n '
904
909
' raise RuntimeError("a is still defined")\n '
905
910
"else:\n "
@@ -914,6 +919,216 @@ def test_python_interpreter_persistent_reset(self):
914
919
"<|im_start|>assistant\n " ,
915
920
)
916
921
922
+ @pytest .mark .skipif (not _has_transformers , reason = "requires transformers" )
923
+ def test_mcp_tool_transform (self ):
924
+ """Test the MCPToolTransform with a simple calculator tool."""
925
+ from torchrl .envs .llm import ChatEnv
926
+ from torchrl .envs .llm .transforms .tools import MCPToolTransform
927
+ from transformers import AutoTokenizer
928
+
929
+ # Define a simple calculator tool
930
+ def calculator (operation : str , a : float , b : float ) -> dict :
931
+ if operation == "add" :
932
+ return {"result" : a + b }
933
+ elif operation == "multiply" :
934
+ return {"result" : a * b }
935
+ else :
936
+ raise ValueError (f"Unknown operation: { operation } " )
937
+
938
+ # Define the tool schema
939
+ calculator_schema = {
940
+ "name" : "calculator" ,
941
+ "description" : "A simple calculator that can add or multiply two numbers" ,
942
+ "parameters" : {
943
+ "type" : "object" ,
944
+ "properties" : {
945
+ "operation" : {"type" : "string" , "enum" : ["add" , "multiply" ]},
946
+ "a" : {"type" : "number" },
947
+ "b" : {"type" : "number" },
948
+ },
949
+ "required" : ["operation" , "a" , "b" ],
950
+ },
951
+ }
952
+
953
+ # Create tools dictionary
954
+ tools = {"calculator" : calculator }
955
+ schemas = {"calculator" : calculator_schema }
956
+
957
+ # Create environment and transform
958
+ tokenizer = AutoTokenizer .from_pretrained ("Qwen/Qwen2.5-3B" )
959
+ env = ChatEnv (
960
+ batch_size = (1 ,),
961
+ system_prompt = "You are a helpful assistant that uses a calculator." ,
962
+ apply_template = True ,
963
+ tokenizer = tokenizer ,
964
+ )
965
+ transform = MCPToolTransform (tools , schemas )
966
+ env = env .append_transform (transform )
967
+
968
+ # Test single tool call
969
+ td = TensorDict ({"text" : ["Let me calculate 2 + 3" ]}, batch_size = (1 ,))
970
+ td = env .reset (td )
971
+ td ["text_response" ] = [
972
+ 'I will help you calculate 2 + 3:\n <tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool><|im_end|>'
973
+ ]
974
+ result = env .step (td )
975
+
976
+ # Check that the tool was executed and returned correct result
977
+ history = result ["next" , "history" ]
978
+ assert len (history [0 ]) == 4 # system, user, assistant, tool response
979
+ assert history [0 , - 1 ].role == "tool"
980
+ assert "result': 5" in history [0 , - 1 ].content
981
+
982
+ # Test multiple tool calls in one response
983
+ td = TensorDict ({"text" : ["Calculate 2 + 3 and 4 * 5" ]}, batch_size = (1 ,))
984
+ td = env .reset (td )
985
+ td ["text_response" ] = [
986
+ "I will help you calculate both:\n "
987
+ '<tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool>\n '
988
+ '<tool>calculator\n {"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>'
989
+ ]
990
+ result = env .step (td )
991
+
992
+ # Check that both tools were executed and returned correct results
993
+ history = result ["next" , "history" ]
994
+ assert (
995
+ len (history [0 ]) == 5
996
+ ) # system, user, assistant, tool response 1, tool response 2
997
+ assert history [0 , - 2 ].role == "tool"
998
+ assert history [0 , - 1 ].role == "tool"
999
+ assert "result': 5" in history [0 , - 2 ].content # 2 + 3 = 5
1000
+ assert "result': 20" in history [0 , - 1 ].content # 4 * 5 = 20
1001
+
1002
+ # Test error handling
1003
+ td = TensorDict ({"text" : ["Calculate 2 ? 3" ]}, batch_size = (1 ,))
1004
+ td = env .reset (td )
1005
+ td ["text_response" ] = [
1006
+ 'I will try to calculate:\n <tool>calculator\n {"operation": "invalid", "a": 2, "b": 3}</tool><|im_end|>'
1007
+ ]
1008
+ result = env .step (td )
1009
+
1010
+ # Check that error was handled gracefully
1011
+ history = result ["next" , "history" ]
1012
+ assert len (history [0 ]) == 4
1013
+ assert history [0 , - 1 ].role == "tool"
1014
+ assert "failed" in history [0 , - 1 ].content
1015
+ assert "Unknown operation: invalid" in history [0 , - 1 ].content
1016
+
1017
+ # Test invalid JSON
1018
+ td = TensorDict ({"text" : ["Calculate something" ]}, batch_size = (1 ,))
1019
+ td = env .reset (td )
1020
+ td ["text_response" ] = [
1021
+ "Let me calculate:\n <tool>calculator\n invalid json</tool><|im_end|>"
1022
+ ]
1023
+ result = env .step (td )
1024
+
1025
+ # Check that JSON error was handled gracefully
1026
+ history = result ["next" , "history" ]
1027
+ assert len (history [0 ]) == 4
1028
+ assert history [0 , - 1 ].role == "tool"
1029
+ assert "failed" in history [0 , - 1 ].content
1030
+ assert "Failed to parse tool arguments" in history [0 , - 1 ].content
1031
+
1032
+ # Define a tool that waits for a random amount of time
1033
+ @classmethod
1034
+ def delayed_calculator (cls , operation : str , a : float , b : float ) -> dict :
1035
+ # Random delay between 100ms and 300ms
1036
+ delay = random .uniform (0.1 , 0.3 )
1037
+ time .sleep (delay )
1038
+ if operation == "add" :
1039
+ return {"result" : a + b , "delay" : delay }
1040
+ elif operation == "multiply" :
1041
+ return {"result" : a * b , "delay" : delay }
1042
+ else :
1043
+ raise ValueError (f"Unknown operation: { operation } " )
1044
+
1045
+ # Define the tool schema
1046
+ calculator_schema = {
1047
+ "name" : "delayed_calculator" ,
1048
+ "description" : "A calculator that introduces random delays" ,
1049
+ "parameters" : {
1050
+ "type" : "object" ,
1051
+ "properties" : {
1052
+ "operation" : {"type" : "string" , "enum" : ["add" , "multiply" ]},
1053
+ "a" : {"type" : "number" },
1054
+ "b" : {"type" : "number" },
1055
+ },
1056
+ "required" : ["operation" , "a" , "b" ],
1057
+ },
1058
+ }
1059
+
1060
+ # Create environment factory
1061
+ @classmethod
1062
+ def make_env (cls ):
1063
+ from torchrl .envs .llm .transforms .tools import MCPToolTransform
1064
+
1065
+ tokenizer = AutoTokenizer .from_pretrained ("Qwen/Qwen2.5-3B" )
1066
+ env = ChatEnv (
1067
+ batch_size = (1 ,),
1068
+ system_prompt = "I'm a calculator assistant" ,
1069
+ apply_template = True ,
1070
+ tokenizer = tokenizer ,
1071
+ )
1072
+ tools = {"calculator" : cls .delayed_calculator }
1073
+ schemas = {"calculator" : cls .calculator_schema }
1074
+ return env .append_transform (MCPToolTransform (tools , schemas ))
1075
+
1076
+ @pytest .mark .skipif (not _has_transformers , reason = "requires transformers" )
1077
+ def test_async_mcp_tools (self ):
1078
+ """Test async execution of MCP tools in an AsyncEnvPool."""
1079
+ from tensordict import TensorDict
1080
+ from torchrl .envs import AsyncEnvPool
1081
+
1082
+ # Create async env pool with 2 environments
1083
+ env_pool = AsyncEnvPool (
1084
+ [self .make_env , self .make_env ], backend = "multiprocessing"
1085
+ )
1086
+ try :
1087
+ # Reset both environments
1088
+ tdreset = TensorDict (
1089
+ text = [["Let me calculate 2 + 3" ], ["Let me calculate 4 * 5" ]],
1090
+ batch_size = (2 , 1 ),
1091
+ )
1092
+ td = env_pool .reset (tdreset )
1093
+
1094
+ # Send async steps to both environments
1095
+ td ["text_response" ] = [
1096
+ [
1097
+ 'Let me calculate 2 + 3:\n <tool>calculator\n {"operation": "add", "a": 2, "b": 3}</tool><|im_end|>'
1098
+ ],
1099
+ [
1100
+ 'Let me calculate 4 * 5:\n <tool>calculator\n {"operation": "multiply", "a": 4, "b": 5}</tool><|im_end|>'
1101
+ ],
1102
+ ]
1103
+ env_pool .async_step_send (td )
1104
+
1105
+ # Get results as they complete
1106
+ results = env_pool .async_step_recv (min_get = 1 ) # Get at least one result
1107
+ assert len (results ) >= 1 # We should get at least one result
1108
+
1109
+ # Get remaining results
1110
+ if len (results ) < 2 :
1111
+ remaining = env_pool .async_step_recv ()
1112
+ else :
1113
+ remaining = []
1114
+
1115
+ # Combine results
1116
+ all_results = torch .stack (list (results ) + list (remaining ))
1117
+
1118
+ # Verify results
1119
+ history = all_results ["next" , "history" ]
1120
+ assert len (history [0 , 0 ]) == 4 # system, user, assistant, tool response
1121
+ assert history [0 , 0 , - 1 ].role == "tool"
1122
+ assert any (
1123
+ "result': 5" in c for c in history [:, 0 , - 1 ].content
1124
+ ) # 2 + 3 = 5
1125
+ assert any (
1126
+ "result': 20" in c for c in history [:, 0 , - 1 ].content
1127
+ ) # 4 * 5 = 20
1128
+
1129
+ finally :
1130
+ env_pool .close ()
1131
+
917
1132
918
1133
if __name__ == "__main__" :
919
1134
args , unknown = argparse .ArgumentParser ().parse_known_args ()
0 commit comments