Skip to content

Commit 5e79b29

Browse files
committed
add tests
1 parent d706ce8 commit 5e79b29

File tree

5 files changed

+96
-88
lines changed

5 files changed

+96
-88
lines changed

src/magentic/chat_model/anthropic_chat_model.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import json
32
from collections.abc import AsyncIterator, Callable, Iterable, Iterator, Sequence
43
from enum import Enum
@@ -12,7 +11,6 @@
1211
from magentic.chat_model.function_schema import (
1312
BaseFunctionSchema,
1413
FunctionCallFunctionSchema,
15-
FunctionSchema,
1614
function_schema_for_type,
1715
get_async_function_schemas,
1816
get_function_schemas,
@@ -183,25 +181,6 @@ def _(message: AssistantMessage[Any]) -> MessageParam:
183181
"content": content_blocks,
184182
}
185183

186-
if isinstance(message.content, AsyncStreamedResponse):
187-
from magentic.utilities import ASYNC_RUNNER
188-
189-
async def collect_content_blocks():
190-
content_blocks: list[TextBlockParam | ToolUseBlockParam] = []
191-
async for item in message.content:
192-
if isinstance(item, AsyncStreamedStr):
193-
content_blocks.append(
194-
{"type": "text", "text": await item.to_string()}
195-
)
196-
elif isinstance(item, FunctionCall):
197-
content_blocks.append(_function_call_to_tool_call_block(item))
198-
return content_blocks
199-
200-
return {
201-
"role": AnthropicMessageRole.ASSISTANT.value,
202-
"content": ASYNC_RUNNER.run_coroutine(collect_content_blocks()),
203-
}
204-
205184
function_schema = function_schema_for_type(type(message.content))
206185
return {
207186
"role": AnthropicMessageRole.ASSISTANT.value,
@@ -236,6 +215,24 @@ def _(message: ToolResultMessage[Any]) -> MessageParam:
236215
}
237216

238217

218+
async def async_message_to_anthropic_message(message: Message[Any]) -> MessageParam:
219+
"""Convert a Message to an Anthropic message (async version)."""
220+
if isinstance(message.content, AsyncStreamedResponse):
221+
content_blocks: list[TextBlockParam | ToolUseBlockParam] = []
222+
async for item in message.content:
223+
if isinstance(item, AsyncStreamedStr):
224+
content_blocks.append({"type": "text", "text": await item.to_string()})
225+
elif isinstance(item, FunctionCall):
226+
content_blocks.append(_function_call_to_tool_call_block(item))
227+
228+
return {
229+
"role": AnthropicMessageRole.ASSISTANT.value,
230+
"content": content_blocks,
231+
}
232+
else: # noqa: RET505
233+
return message_to_anthropic_message(message)
234+
235+
239236
# TODO: Move this to the magentic level by allowing `UserMessage` have a list of content
240237
def _combine_messages(messages: Iterable[MessageParam]) -> list[MessageParam]:
241238
"""Combine messages with the same role, to get alternating roles.
@@ -515,7 +512,7 @@ async def acomplete(
515512
] = await self._async_client.messages.stream(
516513
model=self.model,
517514
messages=_combine_messages(
518-
[message_to_anthropic_message(m) for m in messages]
515+
[await async_message_to_anthropic_message(m) for m in messages]
519516
),
520517
max_tokens=self.max_tokens,
521518
stop_sequences=_if_given(stop),

src/magentic/chat_model/openai_chat_model.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from magentic.chat_model.function_schema import (
2424
BaseFunctionSchema,
2525
FunctionCallFunctionSchema,
26-
FunctionSchema,
2726
function_schema_for_type,
2827
get_async_function_schemas,
2928
get_function_schemas,
@@ -174,31 +173,6 @@ def _(message: AssistantMessage[Any]) -> ChatCompletionMessageParam:
174173
],
175174
}
176175

177-
if isinstance(message.content, AsyncStreamedResponse):
178-
from magentic.utilities import ASYNC_RUNNER
179-
180-
async def collect_content_and_function_calls():
181-
content: list[str] = []
182-
function_calls: list[FunctionCall[Any]] = []
183-
async for item in message.content:
184-
if isinstance(item, AsyncStreamedStr):
185-
content.append(await item.to_string())
186-
elif isinstance(item, FunctionCall):
187-
function_calls.append(item)
188-
return content, function_calls
189-
190-
content, function_calls = ASYNC_RUNNER.run_coroutine(
191-
collect_content_and_function_calls()
192-
)
193-
return {
194-
"role": OpenaiMessageRole.ASSISTANT.value,
195-
"content": " ".join(content),
196-
"tool_calls": [
197-
_function_call_to_tool_call_block(function_call)
198-
for function_call in function_calls
199-
],
200-
}
201-
202176
function_schema = function_schema_for_type(type(message.content))
203177
return {
204178
"role": OpenaiMessageRole.ASSISTANT.value,
@@ -230,6 +204,31 @@ def _(message: ToolResultMessage[Any]) -> ChatCompletionMessageParam:
230204
}
231205

232206

207+
async def async_message_to_openai_message(
208+
message: Message[Any],
209+
) -> ChatCompletionMessageParam:
210+
"""Convert a Message to an OpenAI message (async version)."""
211+
if isinstance(message.content, AsyncStreamedResponse):
212+
content: list[str] = []
213+
function_calls: list[FunctionCall[Any]] = []
214+
async for item in message.content:
215+
if isinstance(item, AsyncStreamedStr):
216+
content.append(await item.to_string())
217+
elif isinstance(item, FunctionCall):
218+
function_calls.append(item)
219+
220+
return {
221+
"role": OpenaiMessageRole.ASSISTANT.value,
222+
"content": " ".join(content),
223+
"tool_calls": [
224+
_function_call_to_tool_call_block(function_call)
225+
for function_call in function_calls
226+
],
227+
}
228+
else: # noqa: RET505
229+
return message_to_openai_message(message)
230+
231+
233232
# TODO: Use ToolResultMessage to solve this at magentic level
234233
def _add_missing_tool_calls_responses(
235234
messages: list[ChatCompletionMessageParam],
@@ -556,7 +555,7 @@ async def acomplete(
556555
] = await self._async_client.chat.completions.create(
557556
model=self.model,
558557
messages=_add_missing_tool_calls_responses(
559-
[message_to_openai_message(m) for m in messages]
558+
[await async_message_to_openai_message(m) for m in messages]
560559
),
561560
max_tokens=_if_given(self.max_tokens),
562561
seed=_if_given(self.seed),

src/magentic/utilities.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

tests/chat_model/test_anthropic_chat_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from magentic._streamed_response import AsyncStreamedResponse, StreamedResponse
99
from magentic.chat_model.anthropic_chat_model import (
1010
AnthropicChatModel,
11+
async_message_to_anthropic_message,
1112
message_to_anthropic_message,
1213
)
1314
from magentic.chat_model.base import ToolSchemaParseError
@@ -125,6 +126,31 @@ def test_message_to_anthropic_message(message, expected_anthropic_message):
125126
assert message_to_anthropic_message(message) == expected_anthropic_message
126127

127128

129+
async def test_async_message_to_anthropic_message():
130+
async def generate_async_streamed_response():
131+
async def async_string_generator():
132+
yield "Hello"
133+
yield "World"
134+
135+
yield AsyncStreamedStr(async_string_generator())
136+
yield FunctionCall(plus, 1, 2)
137+
138+
async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response())
139+
message = AssistantMessage(async_streamed_response)
140+
assert await async_message_to_anthropic_message(message) == {
141+
"role": "assistant",
142+
"content": [
143+
{"type": "text", "text": "HelloWorld"},
144+
{
145+
"type": "tool_use",
146+
"id": ANY,
147+
"name": "plus",
148+
"input": {"a": 1, "b": 2},
149+
},
150+
],
151+
}
152+
153+
128154
def test_message_to_anthropic_message_user_image_document_bytes_pdf(document_bytes_pdf):
129155
image_message = UserMessage([DocumentBytes(document_bytes_pdf)])
130156
assert message_to_anthropic_message(image_message) == snapshot(

tests/chat_model/test_openai_chat_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from magentic.chat_model.openai_chat_model import (
2626
OpenaiChatModel,
27+
async_message_to_openai_message,
2728
message_to_openai_message,
2829
)
2930
from magentic.function_call import FunctionCall, ParallelFunctionCall
@@ -134,6 +135,30 @@ def test_message_to_openai_message(message, expected_openai_message):
134135
assert message_to_openai_message(message) == expected_openai_message
135136

136137

138+
async def test_async_message_to_openai_message():
139+
async def generate_async_streamed_response():
140+
async def async_string_generator():
141+
yield "Hello"
142+
yield "World"
143+
144+
yield AsyncStreamedStr(async_string_generator())
145+
yield FunctionCall(plus, 1, 2)
146+
147+
async_streamed_response = AsyncStreamedResponse(generate_async_streamed_response())
148+
message = AssistantMessage(async_streamed_response)
149+
assert await async_message_to_openai_message(message) == {
150+
"role": "assistant",
151+
"content": "HelloWorld",
152+
"tool_calls": [
153+
{
154+
"id": ANY,
155+
"type": "function",
156+
"function": {"name": "plus", "arguments": '{"a":1,"b":2}'},
157+
},
158+
],
159+
}
160+
161+
137162
def test_message_to_openai_message_user_image_message_bytes_jpg(image_bytes_jpg):
138163
image_message = UserMessage([ImageBytes(image_bytes_jpg)])
139164
assert message_to_openai_message(image_message) == snapshot(

0 commit comments

Comments
 (0)