Skip to content

Commit cc1b6cb

Browse files
committed
chore(ag-ui): switch tests to FunctionModel
Switch the tests from TestModel to FunctionModel, which needs fewer changes to provide the same functionality. Also use `IsStr` and `snapshot` to improve test readability and maintainability. This adds `ThinkingPart` support to the `FunctionModel` via the new `DeltaThinkingCall` class.
1 parent 7c5c49d commit cc1b6cb

File tree

4 files changed

+877
-946
lines changed

4 files changed

+877
-946
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def _convert_history(messages: list[Message]) -> _History:
535535
result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)]))
536536
elif isinstance(msg, ToolMessage):
537537
tool_name = tool_calls.get(msg.tool_call_id)
538-
if tool_name is None:
538+
if tool_name is None: # pragma: no cover
539539
raise ToolCallNotFoundError(tool_call_id=msg.tool_call_id)
540540

541541
result.append(
@@ -587,7 +587,7 @@ class _RunError(Exception):
587587
message: str
588588
code: str
589589

590-
def __str__(self) -> str:
590+
def __str__(self) -> str: # pragma: no cover
591591
return self.message
592592

593593

@@ -620,7 +620,7 @@ class ToolCallNotFoundError(_RunError, ValueError):
620620

621621
def __init__(self, tool_call_id: str) -> None:
622622
"""Initialize the exception with the tool call ID."""
623-
super().__init__(
623+
super().__init__( # pragma: no cover
624624
message=f'Tool call with ID {tool_call_id} not found in the history.',
625625
code='tool_call_not_found',
626626
)

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -203,21 +203,39 @@ class DeltaToolCall:
203203
"""Incremental change to the tool call ID."""
204204

205205

206+
@dataclass
207+
class DeltaThinkingCall:
208+
"""Incremental change to a thinking part.
209+
210+
Used to describe a chunk when streaming thinking responses.
211+
"""
212+
213+
content_delta: str | None = None
214+
"""Incremental change to the thinking content."""
215+
signature_delta: str | None = None
216+
"""Incremental change to the thinking signature."""
217+
218+
206219
DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
207220
"""A mapping of tool call IDs to incremental changes."""
208221

222+
DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingCall]
223+
"""A mapping of thinking call IDs to incremental changes."""
224+
209225
# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
210226
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
211227
"""A function used to generate a non-streamed response."""
212228

213229
# TODO: Change signature as indicated above
214-
StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
230+
StreamFunctionDef: TypeAlias = Callable[
231+
[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]
232+
]
215233
"""A function used to generate a streamed response.
216234
217-
While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
218-
really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`,
235+
While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should
236+
really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`,
219237
220-
E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
238+
E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them.
221239
"""
222240

223241

@@ -226,7 +244,7 @@ class FunctionStreamedResponse(StreamedResponse):
226244
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
227245

228246
_model_name: str
229-
_iter: AsyncIterator[str | DeltaToolCalls]
247+
_iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls]
230248
_timestamp: datetime = field(default_factory=_utils.now_utc)
231249

232250
def __post_init__(self):
@@ -238,20 +256,41 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
238256
response_tokens = _estimate_string_tokens(item)
239257
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
240258
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
241-
else:
242-
delta_tool_calls = item
243-
for dtc_index, delta_tool_call in delta_tool_calls.items():
244-
if delta_tool_call.json_args:
245-
response_tokens = _estimate_string_tokens(delta_tool_call.json_args)
246-
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
247-
maybe_event = self._parts_manager.handle_tool_call_delta(
248-
vendor_part_id=dtc_index,
249-
tool_name=delta_tool_call.name,
250-
args=delta_tool_call.json_args,
251-
tool_call_id=delta_tool_call.tool_call_id,
252-
)
253-
if maybe_event is not None:
254-
yield maybe_event
259+
elif isinstance(item, dict) and item:
260+
first_value = next(iter(item.values()))
261+
if isinstance(first_value, DeltaThinkingCall):
262+
# Handle DeltaThinkingCalls.
263+
for dtc_index, delta_call in item.items():
264+
if not isinstance(delta_call, DeltaThinkingCall): # pragma: no branch
265+
raise TypeError( # pragma: no cover
266+
f'Expected DeltaThinkingCall, got {type(delta_call).__name__} for index {dtc_index}'
267+
)
268+
if delta_call.content_delta: # pragma: no branch
269+
response_tokens = _estimate_string_tokens(delta_call.content_delta)
270+
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
271+
yield self._parts_manager.handle_thinking_delta(
272+
vendor_part_id=dtc_index,
273+
content=delta_call.content_delta,
274+
signature=delta_call.signature_delta,
275+
)
276+
else:
277+
# Handle DeltaToolCalls.
278+
for dtc_index, delta_call in item.items():
279+
if not isinstance(delta_call, DeltaToolCall): # pragma: no branch
280+
raise TypeError( # pragma: no cover
281+
f'Expected DeltaToolCall, got {type(delta_call).__name__} for index {dtc_index}'
282+
)
283+
if delta_call.json_args:
284+
response_tokens = _estimate_string_tokens(delta_call.json_args)
285+
self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens)
286+
maybe_event = self._parts_manager.handle_tool_call_delta(
287+
vendor_part_id=dtc_index,
288+
tool_name=delta_call.name,
289+
args=delta_call.json_args,
290+
tool_call_id=delta_call.tool_call_id,
291+
)
292+
if maybe_event is not None:
293+
yield maybe_event
255294

256295
@property
257296
def model_name(self) -> str:
@@ -288,12 +327,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
288327
if isinstance(part, TextPart):
289328
response_tokens += _estimate_string_tokens(part.content)
290329
elif isinstance(part, ThinkingPart):
291-
# NOTE: We don't send ThinkingPart to the providers yet.
292-
# If you are unsatisfied with this, please open an issue.
293-
pass
330+
response_tokens += _estimate_string_tokens(part.content) # pragma: no cover
294331
elif isinstance(part, ToolCallPart):
295-
call = part
296-
response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str())
332+
response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str())
297333
else:
298334
assert_never(part)
299335
else:

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 17 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from contextlib import asynccontextmanager
77
from dataclasses import InitVar, dataclass, field
88
from datetime import date, datetime, timedelta
9-
from typing import Any, Literal, Union
9+
from typing import Any, Literal
1010

1111
import pydantic_core
12-
from typing_extensions import TypeAlias, assert_never
12+
from typing_extensions import assert_never
1313

1414
from .. import _utils
1515
from ..messages import (
@@ -45,54 +45,6 @@ class _WrappedToolOutput:
4545
value: Any | None
4646

4747

48-
@dataclass
49-
class TestToolCallPart:
50-
"""Represents a tool call in the test model."""
51-
52-
# NOTE: Avoid test discovery by pytest.
53-
__test__ = False
54-
55-
call_tools: list[str] | Literal['all'] = 'all'
56-
57-
58-
@dataclass
59-
class TestTextPart:
60-
"""Represents a text part in the test model."""
61-
62-
# NOTE: Avoid test discovery by pytest.
63-
__test__ = False
64-
65-
text: str
66-
67-
68-
@dataclass
69-
class TestThinkingPart:
70-
"""Represents a thinking part in the test model.
71-
72-
This is used to simulate the model thinking about the response.
73-
"""
74-
75-
# NOTE: Avoid test discovery by pytest.
76-
__test__ = False
77-
78-
content: str = 'Thinking...'
79-
80-
81-
TestPart: TypeAlias = Union[TestTextPart, TestToolCallPart, TestThinkingPart]
82-
"""A part of the test model response."""
83-
84-
85-
@dataclass
86-
class TestNode:
87-
"""A node in the test model."""
88-
89-
# NOTE: Avoid test discovery by pytest.
90-
__test__ = False
91-
92-
parts: list[TestPart]
93-
id: str = field(default_factory=_utils.generate_tool_call_id)
94-
95-
9648
@dataclass
9749
class TestModel(Model):
9850
"""A model specifically for testing purposes.
@@ -111,10 +63,6 @@ class TestModel(Model):
11163

11264
call_tools: list[str] | Literal['all'] = 'all'
11365
"""List of tools to call. If `'all'`, all tools will be called."""
114-
tool_call_deltas: set[str] = field(default_factory=set)
115-
"""A set of tool call names which should result in tool call part deltas."""
116-
custom_response_nodes: list[TestNode] | None = None
117-
"""A list of nodes which defines a custom model response."""
11866
custom_output_text: str | None = None
11967
"""If set, this text is returned as the final output."""
12068
custom_output_args: Any | None = None
@@ -154,10 +102,7 @@ async def request_stream(
154102

155103
model_response = self._request(messages, model_settings, model_request_parameters)
156104
yield TestStreamedResponse(
157-
_model_name=self._model_name,
158-
_structured_response=model_response,
159-
_messages=messages,
160-
_tool_call_deltas=self.tool_call_deltas,
105+
_model_name=self._model_name, _structured_response=model_response, _messages=messages
161106
)
162107

163108
@property
@@ -196,84 +141,32 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap
196141

197142
if k := output_tool.outer_typed_dict_key:
198143
return _WrappedToolOutput({k: self.custom_output_args})
199-
200-
return _WrappedToolOutput(self.custom_output_args)
144+
else:
145+
return _WrappedToolOutput(self.custom_output_args)
201146
elif model_request_parameters.allow_text_output:
202147
return _WrappedTextOutput(None)
203-
elif model_request_parameters.output_tools: # pragma: no branch
148+
elif model_request_parameters.output_tools:
204149
return _WrappedToolOutput(None)
205150
else:
206-
return _WrappedTextOutput(None) # pragma: no cover
207-
208-
def _node_response(
209-
self,
210-
messages: list[ModelMessage],
211-
model_request_parameters: ModelRequestParameters,
212-
) -> ModelResponse | None:
213-
"""Returns a ModelResponse based on configured nodes.
214-
215-
Args:
216-
messages: The messages sent to the model.
217-
model_request_parameters: The parameters for the model request.
218-
219-
Returns:
220-
The response from the model, or `None` if no nodes configured or
221-
all nodes have already been processed.
222-
"""
223-
if not self.custom_response_nodes:
224-
# No nodes configured, follow the default behaviour.
225-
return None
226-
227-
# Pick up where we left off by counting the number of ModelResponse messages in the stream.
228-
# This allows us to stream the response in chunks, simulating a real model response.
229-
node: TestNode
230-
count: int = sum(isinstance(m, ModelResponse) for m in messages)
231-
if count < len(self.custom_response_nodes):
232-
node: TestNode = self.custom_response_nodes[count]
233-
assert node.parts, 'Node parts should not be empty.'
234-
235-
parts: list[ModelResponsePart] = []
236-
part: TestPart
237-
for part in node.parts:
238-
if isinstance(part, TestTextPart): # pragma: no branch
239-
assert model_request_parameters.allow_text_output, ( # pragma: no cover
240-
'Plain response not allowed, but `part` is a `TestText`.'
241-
)
242-
parts.append(TextPart(part.text)) # pragma: no cover
243-
elif isinstance(part, TestToolCallPart): # pragma: no branch
244-
tool_calls = self._get_tool_calls(model_request_parameters)
245-
if part.call_tools == 'all': # pragma: no branch
246-
parts.extend(
247-
ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls
248-
) # pragma: no cover
249-
else:
250-
parts.extend(
251-
ToolCallPart(name, self.gen_tool_args(args))
252-
for name, args in tool_calls
253-
if name in part.call_tools
254-
)
255-
elif isinstance(part, TestThinkingPart): # pragma: no branch
256-
parts.append(ThinkingPart(content=part.content))
257-
return ModelResponse(vendor_id=node.id, parts=parts, model_name=self._model_name)
151+
return _WrappedTextOutput(None)
258152

259153
def _request(
260154
self,
261155
messages: list[ModelMessage],
262156
model_settings: ModelSettings | None,
263157
model_request_parameters: ModelRequestParameters,
264158
) -> ModelResponse:
265-
if (response := self._node_response(messages, model_request_parameters)) is not None:
266-
return response
267-
268159
tool_calls = self._get_tool_calls(model_request_parameters)
160+
output_wrapper = self._get_output(model_request_parameters)
161+
output_tools = model_request_parameters.output_tools
162+
163+
# if there are tools, the first thing we want to do is call all of them
269164
if tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
270165
return ModelResponse(
271166
parts=[ToolCallPart(name, self.gen_tool_args(args)) for name, args in tool_calls],
272167
model_name=self._model_name,
273168
)
274169

275-
output_wrapper = self._get_output(model_request_parameters)
276-
output_tools = model_request_parameters.output_tools
277170
if messages: # pragma: no branch
278171
last_message = messages[-1]
279172
assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
@@ -339,7 +232,6 @@ class TestStreamedResponse(StreamedResponse):
339232
_model_name: str
340233
_structured_response: ModelResponse
341234
_messages: InitVar[Iterable[ModelMessage]]
342-
_tool_call_deltas: set[str]
343235
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
344236

345237
def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -361,47 +253,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
361253
self._usage += _get_string_usage(word)
362254
yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
363255
elif isinstance(part, ToolCallPart):
364-
if part.tool_name in self._tool_call_deltas:
365-
# Start with empty tool call delta.
366-
event = self._parts_manager.handle_tool_call_delta(
367-
vendor_part_id=i, tool_name=part.tool_name, args='', tool_call_id=part.tool_call_id
368-
)
369-
if event is not None: # pragma: no branch
370-
yield event
371-
372-
# Stream the args as JSON string in chunks.
373-
args_json = pydantic_core.to_json(part.args).decode()
374-
*chunks, last_chunk = args_json.split(',') if ',' in args_json else [args_json]
375-
chunks = [f'{chunk},' for chunk in chunks] if chunks else []
376-
if last_chunk: # pragma: no branch
377-
chunks.append(last_chunk)
378-
379-
for chunk in chunks:
380-
event = self._parts_manager.handle_tool_call_delta(
381-
vendor_part_id=i, tool_name=None, args=chunk, tool_call_id=part.tool_call_id
382-
)
383-
if event is not None: # pragma: no branch
384-
yield event
385-
else:
386-
yield self._parts_manager.handle_tool_call_part(
387-
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
388-
)
389-
elif isinstance(part, ThinkingPart):
390-
content_json = pydantic_core.to_json(part.content).decode()
391-
*chunks, last_chunk = content_json.split(' ') if ' ' in content_json else [content_json]
392-
if len(chunks) == 0:
393-
# Single word thinking delta.
394-
yield self._parts_manager.handle_thinking_delta(vendor_part_id=i, content=content_json)
395-
else:
396-
# Start with empty thinking delta.
397-
yield self._parts_manager.handle_thinking_delta(vendor_part_id=i, content='')
398-
399-
# Stream the content as JSON string in chunks.
400-
chunks = [f'{chunk} ' for chunk in chunks] if chunks else []
401-
chunks.append(last_chunk)
402-
403-
for chunk in chunks:
404-
yield self._parts_manager.handle_thinking_delta(vendor_part_id=i, content=chunk)
256+
yield self._parts_manager.handle_tool_call_part(
257+
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
258+
)
259+
elif isinstance(part, ThinkingPart): # pragma: no cover
260+
# NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
261+
assert False, "This should be unreachable — we don't generate ThinkingPart on TestModel."
405262
else:
406263
assert_never(part)
407264

0 commit comments

Comments (0)