Skip to content

Commit

Permalink
Use function instead of FunctionCall in FunctionResultMessage (#110)
Browse files Browse the repository at this point in the history
* Remove from_function_call methods from FunctionResultMessage

* Add tests for message_to_openai_message

* Fix: Use FunctionCallFunctionSchema to name function in result message

* Use function instead of function_call in FunctionResultMessage

* Add test for chatprompt with function call
  • Loading branch information
jackmpcollins authored Feb 19, 2024
1 parent 1028c8b commit f66feaf
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 40 deletions.
15 changes: 11 additions & 4 deletions src/magentic/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,21 +106,28 @@ def exec_function_call(self: Self) -> Self:
msg = "Last message is not a function call."
raise TypeError(msg)

function_call = last_message.content
output = function_call()
return self.add_message(
FunctionResultMessage.from_function_call(last_message.content)
FunctionResultMessage(content=output, function=function_call.function)
)

async def aexec_function_call(self: Self) -> Self:
"""Async version of `exec_function_call`."""
"""Async version of `exec_function_call`.
Additionally, if the result of the function is awaitable, await it
before adding the message.
"""
last_message = self._messages[-1]
if not isinstance(last_message.content, FunctionCall):
msg = "Last message is not a function call."
raise TypeError(msg)

output = last_message.content()
function_call = last_message.content
output = function_call()
if inspect.isawaitable(output):
output = await output

return self.add_message(
FunctionResultMessage(content=output, function_call=last_message.content)
FunctionResultMessage(content=output, function=function_call.function)
)
40 changes: 9 additions & 31 deletions src/magentic/chat_model/message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import Awaitable, Generic, TypeVar, overload

from magentic.function_call import FunctionCall
from typing import Awaitable, Callable, Generic, TypeVar, overload

T = TypeVar("T")

Expand Down Expand Up @@ -54,45 +52,25 @@ class FunctionResultMessage(Message[T], Generic[T]):
"""A message containing the result of a function call."""

@overload
def __init__(self, content: T, function_call: FunctionCall[T]):
def __init__(self, content: T, function: Callable[..., Awaitable[T]]):
...

@overload
def __init__(self, content: T, function_call: FunctionCall[Awaitable[T]]):
def __init__(self, content: T, function: Callable[..., T]):
...

def __init__(
self, content: T, function_call: FunctionCall[T] | FunctionCall[Awaitable[T]]
self, content: T, function: Callable[..., Awaitable[T]] | Callable[..., T]
):
super().__init__(content)
self._function_call = function_call
self._function = function

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.content!r}, {self._function_call!r})"
return f"{self.__class__.__name__}({self.content!r}, {self._function!r})"

@property
def function_call(self) -> FunctionCall[T] | FunctionCall[Awaitable[T]]:
return self._function_call
def function(self) -> Callable[..., Awaitable[T]] | Callable[..., T]:
return self._function

def with_content(self, content: T) -> "FunctionResultMessage[T]":
return FunctionResultMessage(content, self._function_call)

@classmethod
def from_function_call(
cls, function_call: FunctionCall[T]
) -> "FunctionResultMessage[T]":
"""Create a message containing the result of a function call."""
return cls(
content=function_call(),
function_call=function_call,
)

@classmethod
async def afrom_function_call(
cls, function_call: FunctionCall[Awaitable[T]]
) -> "FunctionResultMessage[T]":
"""Async version of `from_function_call`."""
return cls(
content=await function_call(),
function_call=function_call,
)
return FunctionResultMessage(content, self._function)
2 changes: 1 addition & 1 deletion src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def message_to_openai_message(message: Message[Any]) -> ChatCompletionMessagePar
function_schema = function_schema_for_type(type(message.content))
return {
"role": OpenaiMessageRole.FUNCTION.value,
"name": function_schema.name,
"name": FunctionCallFunctionSchema(message.function).name,
"content": function_schema.serialize_args(message.content),
}

Expand Down
53 changes: 51 additions & 2 deletions tests/chat_model/test_openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,57 @@
import openai
import pytest

from magentic.chat_model.message import UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel
from magentic.chat_model.message import (
AssistantMessage,
FunctionResultMessage,
SystemMessage,
UserMessage,
)
from magentic.chat_model.openai_chat_model import (
OpenaiChatModel,
message_to_openai_message,
)
from magentic.function_call import FunctionCall


def plus(a: int, b: int) -> int:
return a + b


@pytest.mark.parametrize(
("message", "expected_openai_message"),
[
(SystemMessage("Hello"), {"role": "system", "content": "Hello"}),
(UserMessage("Hello"), {"role": "user", "content": "Hello"}),
(AssistantMessage("Hello"), {"role": "assistant", "content": "Hello"}),
(
AssistantMessage(42),
{
"role": "assistant",
"content": None,
"function_call": {"name": "return_int", "arguments": '{"value":42}'},
},
),
(
AssistantMessage(FunctionCall(plus, 1, 2)),
{
"role": "assistant",
"content": None,
"function_call": {"name": "plus", "arguments": '{"a":1,"b":2}'},
},
),
(
FunctionResultMessage(3, plus),
{
"role": "function",
"name": "plus",
"content": '{"value":3}',
},
),
],
)
def test_message_to_openai_message(message, expected_openai_message):
assert message_to_openai_message(message) == expected_openai_message


@pytest.mark.openai
Expand Down
22 changes: 20 additions & 2 deletions tests/test_chatprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
chatprompt,
escape_braces,
)
from magentic.function_call import FunctionCall


@pytest.mark.parametrize(
Expand Down Expand Up @@ -51,12 +52,12 @@ def test_escape_braces(text):
(
[
FunctionResultMessage(
"Function result message with {param}", function_call=Mock()
"Function result message with {param}", function=Mock()
)
],
[
FunctionResultMessage(
"Function result message with arg", function_call=Mock()
"Function result message with arg", function=Mock()
)
],
),
Expand Down Expand Up @@ -156,3 +157,20 @@ def get_movie_quote(movie: str) -> Quote:

movie_quote = get_movie_quote("Iron Man")
assert isinstance(movie_quote, Quote)


@pytest.mark.openai
def test_chatprompt_with_function_call_and_result():
def plus(a: int, b: int) -> int:
return a + b

@chatprompt(
UserMessage("Use the plus function to add 1 and 2."),
AssistantMessage(FunctionCall(plus, 1, 2)),
FunctionResultMessage(3, plus),
)
def do_math() -> str:
...

output = do_math()
assert isinstance(output, str)

0 comments on commit f66feaf

Please sign in to comment.