
Commit c86b153

Merge branch 'main' into ducanhdt:add_docker

2 parents 752af61 + e10cdf4

File tree

16 files changed: +438, -118 lines


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -24,4 +24,4 @@ repos:
     rev: 6.1.0
     hooks:
       - id: flake8
-        args: [--max-line-length=88]
+        args: [--max-line-length=120]

codeinterpreterapi/agents/functions_agent.py

Lines changed: 88 additions & 24 deletions
@@ -1,28 +1,32 @@
-"""
-Module implements an agent that uses OpenAI's APIs function enabled API.
-
-This file is a modified version of the original file
-from langchain/agents/openai_functions_agent/base.py.
-Credits go to the original authors :)
-"""
-
+"""Module implements an agent that uses OpenAI's APIs function enabled API."""
 import json
 from dataclasses import dataclass
 from json import JSONDecodeError
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
 from langchain.agents import BaseSingleActionAgent
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.chat_models.openai import ChatOpenAI
-from langchain.prompts.chat import (BaseMessagePromptTemplate,
-                                    ChatPromptTemplate,
-                                    HumanMessagePromptTemplate,
-                                    MessagesPlaceholder)
-from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage,
-                              BasePromptTemplate, FunctionMessage,
-                              OutputParserException, SystemMessage)
+from langchain.prompts.chat import (
+    BaseMessagePromptTemplate,
+    ChatPromptTemplate,
+    HumanMessagePromptTemplate,
+    MessagesPlaceholder,
+)
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BasePromptTemplate,
+    OutputParserException,
+)
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    SystemMessage,
+)
 from langchain.tools import BaseTool
 from langchain.tools.convert_to_openai import format_tool_to_openai_function
 from pydantic import root_validator
@@ -95,17 +99,14 @@ def _format_intermediate_steps(
     return messages
 
 
-async def _parse_ai_message(
-    message: BaseMessage, llm: BaseLanguageModel
-) -> Union[AgentAction, AgentFinish]:
+def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
     """Parse an AI message."""
     if not isinstance(message, AIMessage):
         raise TypeError(f"Expected an AI message got {type(message)}")
 
     function_call = message.additional_kwargs.get("function_call", {})
 
     if function_call:
-        function_call = message.additional_kwargs["function_call"]
         function_name = function_call["name"]
         try:
             _tool_input = json.loads(function_call["arguments"])
@@ -189,8 +190,42 @@ def input_keys(self) -> List[str]:
     def functions(self) -> List[dict]:
         return [dict(format_tool_to_openai_function(t)) for t in self.tools]
 
-    def plan(self):
-        raise NotImplementedError
+    def plan(
+        self,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        callbacks: Callbacks = None,
+        with_functions: bool = True,
+        **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations.
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
+        selected_inputs = {
+            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
+        }
+        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
+        prompt = self.prompt.format_prompt(**full_inputs)
+        messages = prompt.to_messages()
+        if with_functions:
+            predicted_message = self.llm.predict_messages(
+                messages,
+                functions=self.functions,
+                callbacks=callbacks,
+            )
+        else:
+            predicted_message = self.llm.predict_messages(
+                messages,
+                callbacks=callbacks,
+            )
+        agent_decision = _parse_ai_message(predicted_message)
+        return agent_decision
 
     async def aplan(
         self,
@@ -218,9 +253,38 @@ async def aplan(
         predicted_message = await self.llm.apredict_messages(
             messages, functions=self.functions, callbacks=callbacks
         )
-        agent_decision = await _parse_ai_message(predicted_message, self.llm)
+        agent_decision = _parse_ai_message(predicted_message)
         return agent_decision
 
+    def return_stopped_response(
+        self,
+        early_stopping_method: str,
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        **kwargs: Any,
+    ) -> AgentFinish:
+        """Return response when agent has been stopped due to max iterations."""
+        if early_stopping_method == "force":
+            # `force` just returns a constant string
+            return AgentFinish(
+                {"output": "Agent stopped due to iteration limit or time limit."}, ""
+            )
+        elif early_stopping_method == "generate":
+            # Generate does one final forward pass
+            agent_decision = self.plan(
+                intermediate_steps, with_functions=False, **kwargs
+            )
+            if type(agent_decision) == AgentFinish:  # noqa: E721
+                return agent_decision
+            else:
+                raise ValueError(
+                    f"got AgentAction with no functions provided: {agent_decision}"
+                )
+        else:
+            raise ValueError(
+                "early_stopping_method should be one of `force` or `generate`, "
+                f"got {early_stopping_method}"
+            )
+
     @classmethod
     def create_prompt(
         cls,
@@ -275,7 +339,7 @@ def from_llm_and_tools(
             extra_prompt_messages=extra_prompt_messages,
             system_message=system_message,
         )
-        return cls(
+        return cls(  # type: ignore
             llm=llm,
             prompt=prompt,
             tools=tools,
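
Note on the agent changes above: `_parse_ai_message` is now synchronous and no longer takes an `llm` argument, and the new `return_stopped_response` reuses the synchronous `plan(..., with_functions=False)` path when `early_stopping_method="generate"`, so the model must answer in plain text on the final pass. A minimal sketch of the new parsing call shape, assuming the langchain version this commit targets (the function-call payload and the printed attribute names are illustrative, not part of the commit):

import json

from langchain.schema.messages import AIMessage

from codeinterpreterapi.agents.functions_agent import _parse_ai_message

# An assistant message that requests a function call, shaped like the
# OpenAI chat API returns it (payload values are made up).
message = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {
            "name": "python",
            "arguments": json.dumps({"code": "print(1 + 1)"}),
        }
    },
)

# Plain call now: no llm argument and no await.
decision = _parse_ai_message(message)
print(decision.tool, decision.tool_input)  # assumes standard AgentAction fields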

codeinterpreterapi/chains/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,11 @@
 from .extract_code import extract_python_code
-from .modifications_check import get_file_modifications
-from .rm_dl_link import remove_download_link
+from .modifications_check import aget_file_modifications, get_file_modifications
+from .rm_dl_link import aremove_download_link, remove_download_link
 
 __all__ = [
     "extract_python_code",
     "get_file_modifications",
+    "aget_file_modifications",
     "remove_download_link",
+    "aremove_download_link",
 ]
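
The chains now follow an `a`-prefix convention: `get_file_modifications` and `remove_download_link` are blocking, with `aget_file_modifications` and `aremove_download_link` as their async twins, and both sets are exported. A hedged usage sketch (the model name and API-key setup are assumptions, not part of this commit):

import asyncio

from langchain.chat_models import ChatOpenAI

from codeinterpreterapi.chains import aremove_download_link, remove_download_link

llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # type: ignore
text = "I have created the plot.\n\nDownload it [here](sandbox:/plot.png)."

print(remove_download_link(text, llm))                # blocking variant
print(asyncio.run(aremove_download_link(text, llm)))  # async variant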

codeinterpreterapi/chains/modifications_check.py

Lines changed: 25 additions & 4 deletions
@@ -7,7 +7,28 @@
 from codeinterpreterapi.prompts import determine_modifications_prompt
 
 
-async def get_file_modifications(
+def get_file_modifications(
+    code: str,
+    llm: BaseLanguageModel,
+    retry: int = 2,
+) -> Optional[List[str]]:
+    if retry < 1:
+        return None
+
+    prompt = determine_modifications_prompt.format(code=code)
+
+    result = llm.predict(prompt, stop="```")
+
+    try:
+        result = json.loads(result)
+    except json.JSONDecodeError:
+        result = ""
+    if not result or not isinstance(result, dict) or "modifications" not in result:
+        return get_file_modifications(code, llm, retry=retry - 1)
+    return result["modifications"]
+
+
+async def aget_file_modifications(
     code: str,
     llm: BaseLanguageModel,
     retry: int = 2,
@@ -24,12 +45,12 @@ async def get_file_modifications(
     except json.JSONDecodeError:
         result = ""
     if not result or not isinstance(result, dict) or "modifications" not in result:
-        return await get_file_modifications(code, llm, retry=retry - 1)
+        return await aget_file_modifications(code, llm, retry=retry - 1)
     return result["modifications"]
 
 
 async def test():
-    llm = ChatAnthropic(model="claude-1.3")  # type: ignore
+    llm = ChatAnthropic(model="claude-2")  # type: ignore
 
     code = """
     import matplotlib.pyplot as plt
@@ -45,7 +66,7 @@ async def test():
     plt.show()
     """
 
-    print(await get_file_modifications(code, llm))
+    print(get_file_modifications(code, llm))
 
 
 if __name__ == "__main__":
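
The new synchronous `get_file_modifications` mirrors the async body line for line: format the prompt, call `llm.predict(..., stop="```")`, parse the JSON answer, and recurse with `retry - 1` until it succeeds or gives up with `None`. A sketch of a direct call (the model choice and the example output are assumptions):

from langchain.chat_models import ChatAnthropic

from codeinterpreterapi.chains import get_file_modifications

llm = ChatAnthropic(model="claude-2")  # type: ignore

code = (
    "import matplotlib.pyplot as plt\n"
    "plt.plot([1, 2, 3], [4, 5, 6])\n"
    "plt.savefig('plot.png')\n"
)

# Expected to list the files the snippet writes, e.g. ['plot.png'];
# returns None once the retries are exhausted.
print(get_file_modifications(code, llm))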

codeinterpreterapi/chains/rm_dl_link.py

Lines changed: 19 additions & 6 deletions
@@ -5,7 +5,22 @@
 from codeinterpreterapi.prompts import remove_dl_link_prompt
 
 
-async def remove_download_link(
+def remove_download_link(
+    input_response: str,
+    llm: BaseLanguageModel,
+) -> str:
+    messages = remove_dl_link_prompt.format_prompt(
+        input_response=input_response
+    ).to_messages()
+    message = llm.predict_messages(messages)
+
+    if not isinstance(message, AIMessage):
+        raise OutputParserException("Expected an AIMessage")
+
+    return message.content
+
+
+async def aremove_download_link(
     input_response: str,
     llm: BaseLanguageModel,
 ) -> str:
@@ -20,21 +35,19 @@ async def remove_download_link(
     return message.content
 
 
-async def test():
+def test():
     llm = ChatOpenAI(model="gpt-3.5-turbo-0613")  # type: ignore
 
     example = (
         "I have created the plot to your dataset.\n\n"
         "Link to the file [here](sandbox:/plot.png)."
     )
-    print(await remove_download_link(example, llm))
+    print(remove_download_link(example, llm))
 
 
 if __name__ == "__main__":
-    import asyncio
-
     from dotenv import load_dotenv
 
     load_dotenv()
 
-    asyncio.run(test())
+    test()

codeinterpreterapi/utils/parser.py renamed to codeinterpreterapi/parser.py

Lines changed: 1 addition & 2 deletions
@@ -39,8 +39,7 @@ def _type(self) -> str:
 
 class CodeChatAgentOutputParser(AgentOutputParser):
     def get_format_instructions(self) -> str:
-        from langchain.agents.conversational_chat.prompt import \
-            FORMAT_INSTRUCTIONS
+        from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
 
         return FORMAT_INSTRUCTIONS

codeinterpreterapi/prompts/remove_dl_link.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from langchain.prompts.chat import (ChatPromptTemplate,
-                                    HumanMessagePromptTemplate)
+from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import AIMessage, HumanMessage, SystemMessage
 
 remove_dl_link_prompt = ChatPromptTemplate(

codeinterpreterapi/schema/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,11 +1,13 @@
 from .file import File
 from .input import CodeInput, FileInput
 from .response import CodeInterpreterResponse, UserRequest
+from .status import SessionStatus
 
 __all__ = [
     "CodeInterpreterResponse",
     "CodeInput",
    "File",
     "FileInput",
     "UserRequest",
+    "SessionStatus",
 ]

codeinterpreterapi/schema/response.py

Lines changed: 8 additions & 2 deletions
@@ -14,9 +14,15 @@ def __repr__(self):
 
 
 class CodeInterpreterResponse(AIMessage):
+    """
+    Response from the code interpreter agent.
+
+    files: list of files to be sent to the user (File)
+    code_log: list[tuple[str, str]] = []
+    """
+
     files: list[File] = []
-    # final_code: str = "" TODO: implement
-    # final_output: str = "" TODO: implement
+    code_log: list[tuple[str, str]] = []
 
     def show(self):
         print("AI: ", self.content)

codeinterpreterapi/schema/status.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from codeboxapi.schema import CodeBoxStatus  # type: ignore
+
+
+class SessionStatus(CodeBoxStatus):
+    @classmethod
+    def from_codebox_status(cls, cbs: CodeBoxStatus) -> "SessionStatus":
+        return cls(status=cbs.status)
+
+    def __repr__(self):
+        return f"<SessionStatus status={self.status}>"
