Skip to content

Update tests to be compatible with new OpenAI, MistralAI and MCP versions #2094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
5 changes: 4 additions & 1 deletion docs/mcp/server.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ async def sampling_callback(
SamplingMessage(
role='user',
content=TextContent(
type='text', text='write a poem about socks', annotations=None
type='text',
text='write a poem about socks',
annotations=None,
meta=None,
),
)
]
Expand Down
32 changes: 21 additions & 11 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def call_tool(
except McpError as e:
raise exceptions.ModelRetry(e.error.message)

content = [self._map_tool_result_part(part) for part in result.content]
content = [await self._map_tool_result_part(part) for part in result.content]

if result.isError:
text = '\n'.join(str(part) for part in content)
Expand Down Expand Up @@ -207,8 +207,8 @@ async def _sampling_callback(
model=self.sampling_model.model_name,
)

def _map_tool_result_part(
self, part: mcp_types.Content
async def _map_tool_result_part(
self, part: mcp_types.ContentBlock
) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
# See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

Expand All @@ -230,18 +230,28 @@ def _map_tool_result_part(
) # pragma: no cover
elif isinstance(part, mcp_types.EmbeddedResource):
resource = part.resource
if isinstance(resource, mcp_types.TextResourceContents):
return resource.text
elif isinstance(resource, mcp_types.BlobResourceContents):
return messages.BinaryContent(
data=base64.b64decode(resource.blob),
media_type=resource.mimeType or 'application/octet-stream',
)
return self._get_content(resource)
elif isinstance(part, mcp_types.ResourceLink):
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
if len(resource_result.contents) > 1:
return [self._get_content(resource) for resource in resource_result.contents]
else:
assert_never(resource)
return self._get_content(resource_result.contents[0])
else:
assert_never(part)

def _get_content(
self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
) -> str | messages.BinaryContent:
if isinstance(resource, mcp_types.TextResourceContents):
return resource.text
elif isinstance(resource, mcp_types.BlobResourceContents):
return messages.BinaryContent(
data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
)
else:
assert_never(resource)


@dataclass
class MCPServerStdio(MCPServer):
Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
'gpt-4o-audio-preview',
'gpt-4o-audio-preview-2024-10-01',
'gpt-4o-audio-preview-2024-12-17',
'gpt-4o-audio-preview-2025-06-03',
'gpt-4o-mini',
'gpt-4o-mini-2024-07-18',
'gpt-4o-mini-audio-preview',
Expand Down Expand Up @@ -242,6 +243,7 @@
'o3-mini',
'o3-mini-2025-01-31',
'openai:chatgpt-4o-latest',
'openai:codex-mini-latest',
'openai:gpt-3.5-turbo',
'openai:gpt-3.5-turbo-0125',
'openai:gpt-3.5-turbo-0301',
Expand Down Expand Up @@ -274,6 +276,7 @@
'openai:gpt-4o-audio-preview',
'openai:gpt-4o-audio-preview-2024-10-01',
'openai:gpt-4o-audio-preview-2024-12-17',
'openai:gpt-4o-audio-preview-2025-06-03',
'openai:gpt-4o-mini',
'openai:gpt-4o-mini-2024-07-18',
'openai:gpt-4o-mini-audio-preview',
Expand Down
9 changes: 3 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
ChatMessageV2,
ChatResponse,
SystemChatMessageV2,
TextAssistantMessageContentItem,
TextAssistantMessageV2ContentItem,
ToolCallV2,
ToolCallV2Function,
ToolChatMessageV2,
Expand Down Expand Up @@ -111,7 +111,6 @@ def __init__(
*,
provider: Literal['cohere'] | Provider[AsyncClientV2] = 'cohere',
profile: ModelProfileSpec | None = None,
settings: ModelSettings | None = None,
):
"""Initialize an Cohere model.

Expand All @@ -122,15 +121,13 @@ def __init__(
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
created using the other parameters.
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
settings: Model-specific settings that will be used as defaults for this model.
"""
self._model_name = model_name

if isinstance(provider, str):
provider = infer_provider(provider)
self.client = provider.client

super().__init__(settings=settings, profile=profile or provider.model_profile)
self._profile = profile or provider.model_profile

@property
def base_url(self) -> str:
Expand Down Expand Up @@ -227,7 +224,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
assert_never(item)
message_param = AssistantChatMessageV2(role='assistant')
if texts:
message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
if tool_calls:
message_param.tool_calls = tool_calls
cohere_messages.append(message_param)
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ dependencies = [
# WARNING if you add optional groups, please update docs/install.md
logfire = ["logfire>=3.11.0"]
# Models
openai = ["openai>=1.76.0"]
cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
openai = ["openai>=1.86.0"]
cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.24.0"]
anthropic = ["anthropic>=0.52.0"]
Expand All @@ -75,7 +75,7 @@ tavily = ["tavily-python>=0.5.0"]
# CLI
cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
# MCP
mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
# Evals
evals = ["pydantic-evals=={{ version }}"]
# A2A
Expand Down
1 change: 1 addition & 0 deletions tests/assets/product_name.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PydanticAI
40 changes: 37 additions & 3 deletions tests/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from mcp.server.fastmcp import Context, FastMCP, Image
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import BlobResourceContents, EmbeddedResource, SamplingMessage, TextContent, TextResourceContents
from mcp.types import (
BlobResourceContents,
EmbeddedResource,
ResourceLink,
SamplingMessage,
TextContent,
TextResourceContents,
)
from pydantic import AnyUrl

mcp = FastMCP('PydanticAI MCP Server')
Expand Down Expand Up @@ -44,13 +51,22 @@ async def get_image_resource() -> EmbeddedResource:
return EmbeddedResource(
type='resource',
resource=BlobResourceContents(
uri='resource://kiwi.png', # type: ignore
uri=AnyUrl('resource://kiwi.png'),
blob=base64.b64encode(data).decode('utf-8'),
mimeType='image/png',
),
)


@mcp.tool()
async def get_image_resource_1() -> ResourceLink:
return ResourceLink(
type='resource_link',
uri=AnyUrl(Path(__file__).parent.joinpath('assets/kiwi.png').absolute().as_uri()),
name='kiwi.png',
)


@mcp.tool()
async def get_audio_resource() -> EmbeddedResource:
data = Path(__file__).parent.joinpath('assets/marcelo.mp3').read_bytes()
Expand All @@ -64,17 +80,35 @@ async def get_audio_resource() -> EmbeddedResource:
)


@mcp.tool()
async def get_audio_resource_1() -> ResourceLink:
return ResourceLink(
type='resource_link',
uri=AnyUrl(Path(__file__).parent.joinpath('assets/marcelo.mp3').absolute().as_uri()),
name='marcelo.mp3',
)


@mcp.tool()
async def get_product_name() -> EmbeddedResource:
return EmbeddedResource(
type='resource',
resource=TextResourceContents(
uri='resource://product_name.txt', # type: ignore
uri=AnyUrl('resource://product_name.txt'),
text='PydanticAI',
),
)


@mcp.tool()
async def get_product_name_1() -> ResourceLink:
return ResourceLink(
type='resource_link',
uri=AnyUrl(Path(__file__).parent.joinpath('assets/product_name.txt').absolute().as_uri()),
name='product_name.txt',
)


@mcp.tool()
async def get_image() -> Image:
data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def completion_message(
return MistralChatCompletionResponse(
id='123',
choices=[MistralChatCompletionChoice(finish_reason='stop', index=0, message=message)],
created=1704067200 if with_created else None, # 2024-01-01
created=1704067200 if with_created else 0, # 2024-01-01
model='mistral-large-123',
object='chat.completion',
usage=usage or MistralUsageInfo(prompt_tokens=1, completion_tokens=1, total_tokens=1),
Expand Down
Loading
Loading