File tree 2 files changed +25
-1
lines changed
2 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
50
50
)
51
51
globals : Optional [Dict ] = Field (default_factory = dict )
52
52
locals : Optional [Dict ] = Field (default_factory = dict )
53
+ sanitize_input : bool = True
53
54
54
55
@root_validator (pre = True )
55
56
def validate_python_version (cls , values : Dict ) -> Dict :
@@ -65,6 +66,9 @@ def validate_python_version(cls, values: Dict) -> Dict:
65
66
def _run (self , query : str ) -> str :
66
67
"""Use the tool."""
67
68
try :
69
+ if self .sanitize_input :
70
+ # Remove the triple backticks from the query.
71
+ query = query .strip ().strip ("```" )
68
72
tree = ast .parse (query )
69
73
module = ast .Module (tree .body [:- 1 ], type_ignores = [])
70
74
exec (ast .unparse (module ), self .globals , self .locals ) # type: ignore
Original file line number Diff line number Diff line change 1
1
"""Test functionality of Python REPL."""
2
+ import sys
3
+
4
+ import pytest
2
5
3
6
from langchain .python import PythonREPL
4
- from langchain .tools .python .tool import PythonREPLTool
7
+ from langchain .tools .python .tool import PythonAstREPLTool , PythonREPLTool
5
8
6
9
_SAMPLE_CODE = """
7
10
```
@@ -11,6 +14,14 @@ def multiply():
11
14
```
12
15
"""
13
16
17
+ _AST_SAMPLE_CODE = """
18
+ ```
19
+ def multiply():
20
+ return(5*6)
21
+ multiply()
22
+ ```
23
+ """
24
+
14
25
15
26
def test_python_repl () -> None :
16
27
"""Test functionality when globals/locals are not provided."""
@@ -60,6 +71,15 @@ def test_functionality_multiline() -> None:
60
71
assert output == "30\n "
61
72
62
73
74
+ def test_python_ast_repl_multiline () -> None :
75
+ """Test correct functionality for ChatGPT multiline commands."""
76
+ if sys .version_info < (3 , 9 ):
77
+ pytest .skip ("Python 3.9+ is required for this test" )
78
+ tool = PythonAstREPLTool ()
79
+ output = tool .run (_AST_SAMPLE_CODE )
80
+ assert output == 30
81
+
82
+
63
83
def test_function () -> None :
64
84
"""Test correct functionality."""
65
85
chain = PythonREPL ()
You can’t perform that action at this time.
0 commit comments