Skip to content

Commit 955bd2e

Browse files
abhiksinglaAbhik Singla
and
Abhik Singla
authored
Fixed Ast Python Repl for Chatgpt multiline commands (#2406)
Resolves issue #2252 --------- Co-authored-by: Abhik Singla <[email protected]>
1 parent 1271c00 commit 955bd2e

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

langchain/tools/python/tool.py

+4
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
5050
)
5151
globals: Optional[Dict] = Field(default_factory=dict)
5252
locals: Optional[Dict] = Field(default_factory=dict)
53+
sanitize_input: bool = True
5354

5455
@root_validator(pre=True)
5556
def validate_python_version(cls, values: Dict) -> Dict:
@@ -65,6 +66,9 @@ def validate_python_version(cls, values: Dict) -> Dict:
6566
def _run(self, query: str) -> str:
6667
"""Use the tool."""
6768
try:
69+
if self.sanitize_input:
70+
# Remove the triple backticks from the query.
71+
query = query.strip().strip("```")
6872
tree = ast.parse(query)
6973
module = ast.Module(tree.body[:-1], type_ignores=[])
7074
exec(ast.unparse(module), self.globals, self.locals) # type: ignore

tests/unit_tests/test_python.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Test functionality of Python REPL."""
2+
import sys
3+
4+
import pytest
25

36
from langchain.python import PythonREPL
4-
from langchain.tools.python.tool import PythonREPLTool
7+
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
58

69
_SAMPLE_CODE = """
710
```
@@ -11,6 +14,14 @@ def multiply():
1114
```
1215
"""
1316

17+
_AST_SAMPLE_CODE = """
18+
```
19+
def multiply():
20+
return(5*6)
21+
multiply()
22+
```
23+
"""
24+
1425

1526
def test_python_repl() -> None:
1627
"""Test functionality when globals/locals are not provided."""
@@ -60,6 +71,15 @@ def test_functionality_multiline() -> None:
6071
assert output == "30\n"
6172

6273

74+
def test_python_ast_repl_multiline() -> None:
75+
"""Test correct functionality for ChatGPT multiline commands."""
76+
if sys.version_info < (3, 9):
77+
pytest.skip("Python 3.9+ is required for this test")
78+
tool = PythonAstREPLTool()
79+
output = tool.run(_AST_SAMPLE_CODE)
80+
assert output == 30
81+
82+
6383
def test_function() -> None:
6484
"""Test correct functionality."""
6585
chain = PythonREPL()

0 commit comments

Comments
 (0)