4
4
import logging
5
5
import time
6
6
from functools import wraps
7
- from typing import Any , Dict , Iterator , Optional , Union
7
+ from typing import Any , AsyncIterator , Optional , Union
8
8
9
9
import openai
10
10
@@ -56,7 +56,7 @@ async def traced_create_func(*args, **kwargs):
56
56
stream = kwargs .get ("stream" , False )
57
57
58
58
if stream :
59
- return await handle_async_streaming_create (
59
+ return handle_async_streaming_create (
60
60
* args ,
61
61
** kwargs ,
62
62
create_func = create_func ,
@@ -81,7 +81,7 @@ async def handle_async_streaming_create(
81
81
is_azure_openai : bool = False ,
82
82
inference_id : Optional [str ] = None ,
83
83
** kwargs ,
84
- ) -> Iterator [Any ]:
84
+ ) -> AsyncIterator [Any ]:
85
85
"""Handles the create method when streaming is enabled.
86
86
87
87
Parameters
@@ -95,25 +95,12 @@ async def handle_async_streaming_create(
95
95
96
96
Returns
97
97
-------
98
- Iterator [Any]
98
+ AsyncIterator [Any]
99
99
A generator that yields the chunks of the completion.
100
100
"""
101
101
chunks = await create_func (* args , ** kwargs )
102
- return await stream_async_chunks (
103
- chunks = chunks ,
104
- kwargs = kwargs ,
105
- inference_id = inference_id ,
106
- is_azure_openai = is_azure_openai ,
107
- )
108
102
109
-
110
- async def stream_async_chunks (
111
- chunks : Iterator [Any ],
112
- kwargs : Dict [str , any ],
113
- is_azure_openai : bool = False ,
114
- inference_id : Optional [str ] = None ,
115
- ):
116
- """Streams the chunks of the completion and traces the completion."""
103
+ # Create and return a new async generator that processes chunks
117
104
collected_output_data = []
118
105
collected_function_call = {
119
106
"name" : "" ,
@@ -143,9 +130,9 @@ async def stream_async_chunks(
143
130
if delta .function_call .name :
144
131
collected_function_call ["name" ] += delta .function_call .name
145
132
if delta .function_call .arguments :
146
- collected_function_call ["arguments" ] += (
147
- delta . function_call . arguments
148
- )
133
+ collected_function_call [
134
+ " arguments"
135
+ ] += delta . function_call . arguments
149
136
elif delta .tool_calls :
150
137
if delta .tool_calls [0 ].function .name :
151
138
collected_function_call ["name" ] += delta .tool_calls [0 ].function .name
@@ -155,6 +142,7 @@ async def stream_async_chunks(
155
142
].function .arguments
156
143
157
144
yield chunk
145
+
158
146
end_time = time .time ()
159
147
latency = (end_time - start_time ) * 1000
160
148
# pylint: disable=broad-except
0 commit comments