Skip to content

Commit ce13918

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
fix(openai tracer): object async_generator can't be used in 'await' expression
1 parent 0c7b25e commit ce13918

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

src/openlayer/lib/integrations/async_openai_tracer.py

+9-21
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import time
66
from functools import wraps
7-
from typing import Any, Dict, Iterator, Optional, Union
7+
from typing import Any, AsyncIterator, Optional, Union
88

99
import openai
1010

@@ -56,7 +56,7 @@ async def traced_create_func(*args, **kwargs):
5656
stream = kwargs.get("stream", False)
5757

5858
if stream:
59-
return await handle_async_streaming_create(
59+
return handle_async_streaming_create(
6060
*args,
6161
**kwargs,
6262
create_func=create_func,
@@ -81,7 +81,7 @@ async def handle_async_streaming_create(
8181
is_azure_openai: bool = False,
8282
inference_id: Optional[str] = None,
8383
**kwargs,
84-
) -> Iterator[Any]:
84+
) -> AsyncIterator[Any]:
8585
"""Handles the create method when streaming is enabled.
8686
8787
Parameters
@@ -95,25 +95,12 @@ async def handle_async_streaming_create(
9595
9696
Returns
9797
-------
98-
Iterator[Any]
98+
AsyncIterator[Any]
9999
A generator that yields the chunks of the completion.
100100
"""
101101
chunks = await create_func(*args, **kwargs)
102-
return await stream_async_chunks(
103-
chunks=chunks,
104-
kwargs=kwargs,
105-
inference_id=inference_id,
106-
is_azure_openai=is_azure_openai,
107-
)
108102

109-
110-
async def stream_async_chunks(
111-
chunks: Iterator[Any],
112-
kwargs: Dict[str, any],
113-
is_azure_openai: bool = False,
114-
inference_id: Optional[str] = None,
115-
):
116-
"""Streams the chunks of the completion and traces the completion."""
103+
# Create and return a new async generator that processes chunks
117104
collected_output_data = []
118105
collected_function_call = {
119106
"name": "",
@@ -143,9 +130,9 @@ async def stream_async_chunks(
143130
if delta.function_call.name:
144131
collected_function_call["name"] += delta.function_call.name
145132
if delta.function_call.arguments:
146-
collected_function_call["arguments"] += (
147-
delta.function_call.arguments
148-
)
133+
collected_function_call[
134+
"arguments"
135+
] += delta.function_call.arguments
149136
elif delta.tool_calls:
150137
if delta.tool_calls[0].function.name:
151138
collected_function_call["name"] += delta.tool_calls[0].function.name
@@ -155,6 +142,7 @@ async def stream_async_chunks(
155142
].function.arguments
156143

157144
yield chunk
145+
158146
end_time = time.time()
159147
latency = (end_time - start_time) * 1000
160148
# pylint: disable=broad-except

0 commit comments

Comments
 (0)