6
6
from contextlib import asynccontextmanager
7
7
from dataclasses import InitVar , dataclass , field
8
8
from datetime import date , datetime , timedelta
9
- from typing import Any , Literal , Union
9
+ from typing import Any , Literal
10
10
11
11
import pydantic_core
12
- from typing_extensions import TypeAlias , assert_never
12
+ from typing_extensions import assert_never
13
13
14
14
from .. import _utils
15
15
from ..messages import (
@@ -45,54 +45,6 @@ class _WrappedToolOutput:
45
45
value : Any | None
46
46
47
47
48
- @dataclass
49
- class TestToolCallPart :
50
- """Represents a tool call in the test model."""
51
-
52
- # NOTE: Avoid test discovery by pytest.
53
- __test__ = False
54
-
55
- call_tools : list [str ] | Literal ['all' ] = 'all'
56
-
57
-
58
- @dataclass
59
- class TestTextPart :
60
- """Represents a text part in the test model."""
61
-
62
- # NOTE: Avoid test discovery by pytest.
63
- __test__ = False
64
-
65
- text : str
66
-
67
-
68
- @dataclass
69
- class TestThinkingPart :
70
- """Represents a thinking part in the test model.
71
-
72
- This is used to simulate the model thinking about the response.
73
- """
74
-
75
- # NOTE: Avoid test discovery by pytest.
76
- __test__ = False
77
-
78
- content : str = 'Thinking...'
79
-
80
-
81
- TestPart : TypeAlias = Union [TestTextPart , TestToolCallPart , TestThinkingPart ]
82
- """A part of the test model response."""
83
-
84
-
85
- @dataclass
86
- class TestNode :
87
- """A node in the test model."""
88
-
89
- # NOTE: Avoid test discovery by pytest.
90
- __test__ = False
91
-
92
- parts : list [TestPart ]
93
- id : str = field (default_factory = _utils .generate_tool_call_id )
94
-
95
-
96
48
@dataclass
97
49
class TestModel (Model ):
98
50
"""A model specifically for testing purposes.
@@ -111,10 +63,6 @@ class TestModel(Model):
111
63
112
64
call_tools : list [str ] | Literal ['all' ] = 'all'
113
65
"""List of tools to call. If `'all'`, all tools will be called."""
114
- tool_call_deltas : set [str ] = field (default_factory = set )
115
- """A set of tool call names which should result in tool call part deltas."""
116
- custom_response_nodes : list [TestNode ] | None = None
117
- """A list of nodes which defines a custom model response."""
118
66
custom_output_text : str | None = None
119
67
"""If set, this text is returned as the final output."""
120
68
custom_output_args : Any | None = None
@@ -154,10 +102,7 @@ async def request_stream(
154
102
155
103
model_response = self ._request (messages , model_settings , model_request_parameters )
156
104
yield TestStreamedResponse (
157
- _model_name = self ._model_name ,
158
- _structured_response = model_response ,
159
- _messages = messages ,
160
- _tool_call_deltas = self .tool_call_deltas ,
105
+ _model_name = self ._model_name , _structured_response = model_response , _messages = messages
161
106
)
162
107
163
108
@property
@@ -196,84 +141,32 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap
196
141
197
142
if k := output_tool .outer_typed_dict_key :
198
143
return _WrappedToolOutput ({k : self .custom_output_args })
199
-
200
- return _WrappedToolOutput (self .custom_output_args )
144
+ else :
145
+ return _WrappedToolOutput (self .custom_output_args )
201
146
elif model_request_parameters .allow_text_output :
202
147
return _WrappedTextOutput (None )
203
- elif model_request_parameters .output_tools : # pragma: no branch
148
+ elif model_request_parameters .output_tools :
204
149
return _WrappedToolOutput (None )
205
150
else :
206
- return _WrappedTextOutput (None ) # pragma: no cover
207
-
208
- def _node_response (
209
- self ,
210
- messages : list [ModelMessage ],
211
- model_request_parameters : ModelRequestParameters ,
212
- ) -> ModelResponse | None :
213
- """Returns a ModelResponse based on configured nodes.
214
-
215
- Args:
216
- messages: The messages sent to the model.
217
- model_request_parameters: The parameters for the model request.
218
-
219
- Returns:
220
- The response from the model, or `None` if no nodes configured or
221
- all nodes have already been processed.
222
- """
223
- if not self .custom_response_nodes :
224
- # No nodes configured, follow the default behaviour.
225
- return None
226
-
227
- # Pick up where we left off by counting the number of ModelResponse messages in the stream.
228
- # This allows us to stream the response in chunks, simulating a real model response.
229
- node : TestNode
230
- count : int = sum (isinstance (m , ModelResponse ) for m in messages )
231
- if count < len (self .custom_response_nodes ):
232
- node : TestNode = self .custom_response_nodes [count ]
233
- assert node .parts , 'Node parts should not be empty.'
234
-
235
- parts : list [ModelResponsePart ] = []
236
- part : TestPart
237
- for part in node .parts :
238
- if isinstance (part , TestTextPart ): # pragma: no branch
239
- assert model_request_parameters .allow_text_output , ( # pragma: no cover
240
- 'Plain response not allowed, but `part` is a `TestText`.'
241
- )
242
- parts .append (TextPart (part .text )) # pragma: no cover
243
- elif isinstance (part , TestToolCallPart ): # pragma: no branch
244
- tool_calls = self ._get_tool_calls (model_request_parameters )
245
- if part .call_tools == 'all' : # pragma: no branch
246
- parts .extend (
247
- ToolCallPart (name , self .gen_tool_args (args )) for name , args in tool_calls
248
- ) # pragma: no cover
249
- else :
250
- parts .extend (
251
- ToolCallPart (name , self .gen_tool_args (args ))
252
- for name , args in tool_calls
253
- if name in part .call_tools
254
- )
255
- elif isinstance (part , TestThinkingPart ): # pragma: no branch
256
- parts .append (ThinkingPart (content = part .content ))
257
- return ModelResponse (vendor_id = node .id , parts = parts , model_name = self ._model_name )
151
+ return _WrappedTextOutput (None )
258
152
259
153
def _request (
260
154
self ,
261
155
messages : list [ModelMessage ],
262
156
model_settings : ModelSettings | None ,
263
157
model_request_parameters : ModelRequestParameters ,
264
158
) -> ModelResponse :
265
- if (response := self ._node_response (messages , model_request_parameters )) is not None :
266
- return response
267
-
268
159
tool_calls = self ._get_tool_calls (model_request_parameters )
160
+ output_wrapper = self ._get_output (model_request_parameters )
161
+ output_tools = model_request_parameters .output_tools
162
+
163
+ # if there are tools, the first thing we want to do is call all of them
269
164
if tool_calls and not any (isinstance (m , ModelResponse ) for m in messages ):
270
165
return ModelResponse (
271
166
parts = [ToolCallPart (name , self .gen_tool_args (args )) for name , args in tool_calls ],
272
167
model_name = self ._model_name ,
273
168
)
274
169
275
- output_wrapper = self ._get_output (model_request_parameters )
276
- output_tools = model_request_parameters .output_tools
277
170
if messages : # pragma: no branch
278
171
last_message = messages [- 1 ]
279
172
assert isinstance (last_message , ModelRequest ), 'Expected last message to be a `ModelRequest`.'
@@ -339,7 +232,6 @@ class TestStreamedResponse(StreamedResponse):
339
232
_model_name : str
340
233
_structured_response : ModelResponse
341
234
_messages : InitVar [Iterable [ModelMessage ]]
342
- _tool_call_deltas : set [str ]
343
235
_timestamp : datetime = field (default_factory = _utils .now_utc , init = False )
344
236
345
237
def __post_init__ (self , _messages : Iterable [ModelMessage ]):
@@ -361,47 +253,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
361
253
self ._usage += _get_string_usage (word )
362
254
yield self ._parts_manager .handle_text_delta (vendor_part_id = i , content = word )
363
255
elif isinstance (part , ToolCallPart ):
364
- if part .tool_name in self ._tool_call_deltas :
365
- # Start with empty tool call delta.
366
- event = self ._parts_manager .handle_tool_call_delta (
367
- vendor_part_id = i , tool_name = part .tool_name , args = '' , tool_call_id = part .tool_call_id
368
- )
369
- if event is not None : # pragma: no branch
370
- yield event
371
-
372
- # Stream the args as JSON string in chunks.
373
- args_json = pydantic_core .to_json (part .args ).decode ()
374
- * chunks , last_chunk = args_json .split (',' ) if ',' in args_json else [args_json ]
375
- chunks = [f'{ chunk } ,' for chunk in chunks ] if chunks else []
376
- if last_chunk : # pragma: no branch
377
- chunks .append (last_chunk )
378
-
379
- for chunk in chunks :
380
- event = self ._parts_manager .handle_tool_call_delta (
381
- vendor_part_id = i , tool_name = None , args = chunk , tool_call_id = part .tool_call_id
382
- )
383
- if event is not None : # pragma: no branch
384
- yield event
385
- else :
386
- yield self ._parts_manager .handle_tool_call_part (
387
- vendor_part_id = i , tool_name = part .tool_name , args = part .args , tool_call_id = part .tool_call_id
388
- )
389
- elif isinstance (part , ThinkingPart ):
390
- content_json = pydantic_core .to_json (part .content ).decode ()
391
- * chunks , last_chunk = content_json .split (' ' ) if ' ' in content_json else [content_json ]
392
- if len (chunks ) == 0 :
393
- # Single word thinking delta.
394
- yield self ._parts_manager .handle_thinking_delta (vendor_part_id = i , content = content_json )
395
- else :
396
- # Start with empty thinking delta.
397
- yield self ._parts_manager .handle_thinking_delta (vendor_part_id = i , content = '' )
398
-
399
- # Stream the content as JSON string in chunks.
400
- chunks = [f'{ chunk } ' for chunk in chunks ] if chunks else []
401
- chunks .append (last_chunk )
402
-
403
- for chunk in chunks :
404
- yield self ._parts_manager .handle_thinking_delta (vendor_part_id = i , content = chunk )
256
+ yield self ._parts_manager .handle_tool_call_part (
257
+ vendor_part_id = i , tool_name = part .tool_name , args = part .args , tool_call_id = part .tool_call_id
258
+ )
259
+ elif isinstance (part , ThinkingPart ): # pragma: no cover
260
+ # NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
261
+ assert False , "This should be unreachable — we don't generate ThinkingPart on TestModel."
405
262
else :
406
263
assert_never (part )
407
264
0 commit comments