12 changes: 12 additions & 0 deletions src/openlayer/lib/__init__.py
@@ -39,6 +39,18 @@ def trace_openai(client):
return openai_tracer.trace_openai(client)


def trace_async_openai(client):
"""Trace OpenAI chat completions."""
# pylint: disable=import-outside-toplevel
import openai

from .integrations import async_openai_tracer

if not isinstance(client, (openai.AsyncOpenAI, openai.AsyncAzureOpenAI)):
raise ValueError("Invalid client. Please provide an OpenAI client.")
return async_openai_tracer.trace_async_openai(client)


def trace_openai_assistant_thread_run(client, run):
"""Trace OpenAI Assistant thread run."""
# pylint: disable=import-outside-toplevel
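For context (not part of this diff): a minimal usage sketch of the new `trace_async_openai` helper, assuming `openai` and `openlayer` are installed and `OPENAI_API_KEY` is set in the environment; the model name and prompt are illustrative.

import asyncio

import openai
from openlayer.lib import trace_async_openai

async def main():
    # Patch the async client; subsequent chat completions are traced.
    client = trace_async_openai(openai.AsyncOpenAI())
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())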
264 changes: 264 additions & 0 deletions src/openlayer/lib/integrations/async_openai_tracer.py
@@ -0,0 +1,264 @@
"""Module with methods used to trace async OpenAI / Azure OpenAI LLMs."""

import json
import logging
import time
from functools import wraps
from typing import Any, AsyncIterator, Dict, Optional, Union

import openai

from .openai_tracer import (
get_model_parameters,
create_trace_args,
add_to_trace,
parse_non_streaming_output_data,
)

logger = logging.getLogger(__name__)


def trace_async_openai(
client: Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI],
) -> Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]:
"""Patch the AsyncOpenAI or AsyncAzureOpenAI client to trace chat completions.

The following information is collected for each chat completion:
- start_time: The time when the completion was requested.
- end_time: The time when the completion was received.
- latency: The time it took to generate the completion.
- tokens: The total number of tokens used to generate the completion.
- prompt_tokens: The number of tokens in the prompt.
- completion_tokens: The number of tokens in the completion.
- model: The model used to generate the completion.
- model_parameters: The parameters used to configure the model.
- raw_output: The raw output of the model.
- inputs: The inputs used to generate the completion.
- metadata: Additional metadata about the completion. For example, the time it
took to generate the first token, when streaming.

Parameters
----------
client : Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
The AsyncOpenAI client to patch.

Returns
-------
Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]
The patched AsyncOpenAI client.
"""
is_azure_openai = isinstance(client, openai.AsyncAzureOpenAI)
create_func = client.chat.completions.create

@wraps(create_func)
async def traced_create_func(*args, **kwargs):
inference_id = kwargs.pop("inference_id", None)
stream = kwargs.get("stream", False)

if stream:
return await handle_async_streaming_create(
*args,
**kwargs,
create_func=create_func,
inference_id=inference_id,
is_azure_openai=is_azure_openai,
)
return await handle_async_non_streaming_create(
*args,
**kwargs,
create_func=create_func,
inference_id=inference_id,
is_azure_openai=is_azure_openai,
)

client.chat.completions.create = traced_create_func
return client


async def handle_async_streaming_create(
create_func: callable,
*args,
is_azure_openai: bool = False,
inference_id: Optional[str] = None,
**kwargs,
) -> AsyncIterator[Any]:
"""Handles the create method when streaming is enabled.

Parameters
----------
create_func : callable
The create method to handle.
is_azure_openai : bool, optional
Whether the client is an Azure OpenAI client, by default False
inference_id : Optional[str], optional
A user-generated inference id, by default None

Returns
-------
AsyncIterator[Any]
An async generator that yields the chunks of the completion.
"""
chunks = await create_func(*args, **kwargs)
# stream_async_chunks is an async generator function, so calling it
# returns the async generator itself; it must not be awaited here.
return stream_async_chunks(
chunks=chunks,
kwargs=kwargs,
inference_id=inference_id,
is_azure_openai=is_azure_openai,
)


async def stream_async_chunks(
chunks: AsyncIterator[Any],
kwargs: Dict[str, Any],
is_azure_openai: bool = False,
inference_id: Optional[str] = None,
):
"""Streams the chunks of the completion and traces the completion."""
collected_output_data = []
collected_function_call = {
"name": "",
"arguments": "",
}
raw_outputs = []
start_time = time.time()
end_time = None
first_token_time = None
num_of_completion_tokens = None
latency = None
try:
i = 0
async for chunk in chunks:
raw_outputs.append(chunk.model_dump())
if i == 0:
first_token_time = time.time()
if i > 0:
num_of_completion_tokens = i + 1
i += 1

delta = chunk.choices[0].delta

if delta.content:
collected_output_data.append(delta.content)
elif delta.function_call:
if delta.function_call.name:
collected_function_call["name"] += delta.function_call.name
if delta.function_call.arguments:
collected_function_call["arguments"] += (
delta.function_call.arguments
)
elif delta.tool_calls:
if delta.tool_calls[0].function.name:
collected_function_call["name"] += delta.tool_calls[0].function.name
if delta.tool_calls[0].function.arguments:
collected_function_call["arguments"] += delta.tool_calls[
0
].function.arguments

yield chunk
end_time = time.time()
latency = (end_time - start_time) * 1000
# pylint: disable=broad-except
except Exception as e:
logger.error("Failed yield chunk. %s", e)
finally:
# Try to add step to the trace
try:
collected_output_data = [
message for message in collected_output_data if message is not None
]
if collected_output_data:
output_data = "".join(collected_output_data)
else:
collected_function_call["arguments"] = json.loads(
collected_function_call["arguments"]
)
output_data = collected_function_call

trace_args = create_trace_args(
end_time=end_time,
inputs={"prompt": kwargs["messages"]},
output=output_data,
latency=latency,
tokens=num_of_completion_tokens,
prompt_tokens=0,
completion_tokens=num_of_completion_tokens,
model=kwargs.get("model"),
model_parameters=get_model_parameters(kwargs),
raw_output=raw_outputs,
id=inference_id,
metadata={
"timeToFirstToken": (
(first_token_time - start_time) * 1000
if first_token_time
else None
)
},
)
add_to_trace(
**trace_args,
is_azure_openai=is_azure_openai,
)

# pylint: disable=broad-except
except Exception as e:
logger.error(
"Failed to trace the create chat completion request with Openlayer. %s",
e,
)
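For reference, a hedged sketch of how the streaming path is consumed once the client is patched. With `stream=True`, the traced `create` returns the async generator produced by `stream_async_chunks`, and the trace step is recorded in that generator's `finally` block once iteration ends. This assumes the patched `client` from the earlier example, runs inside an async function, and uses illustrative model/prompt values.

stream = await client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative
    messages=[{"role": "user", "content": "Stream, please."}],
    stream=True,
)
async for chunk in stream:
    # Each chunk is also captured into raw_outputs for the trace.
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)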


async def handle_async_non_streaming_create(
create_func: callable,
*args,
is_azure_openai: bool = False,
inference_id: Optional[str] = None,
**kwargs,
) -> "openai.types.chat.chat_completion.ChatCompletion":
"""Handles the create method when streaming is disabled.

Parameters
----------
create_func : callable
The create method to handle.
is_azure_openai : bool, optional
Whether the client is an Azure OpenAI client, by default False
inference_id : Optional[str], optional
A user-generated inference id, by default None

Returns
-------
openai.types.chat.chat_completion.ChatCompletion
The chat completion response.
"""
start_time = time.time()
response = await create_func(*args, **kwargs)
end_time = time.time()

# Try to add step to the trace
try:
output_data = parse_non_streaming_output_data(response)
trace_args = create_trace_args(
end_time=end_time,
inputs={"prompt": kwargs["messages"]},
output=output_data,
latency=(end_time - start_time) * 1000,
tokens=response.usage.total_tokens,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
model=response.model,
model_parameters=get_model_parameters(kwargs),
raw_output=response.model_dump(),
id=inference_id,
)

add_to_trace(
is_azure_openai=is_azure_openai,
**trace_args,
)
# pylint: disable=broad-except
except Exception as e:
logger.error(
"Failed to trace the create chat completion request with Openlayer. %s", e
)

return response
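One design note: the traced wrapper pops `inference_id` from kwargs before invoking the real `create`, so callers can attach their own id without the OpenAI SDK ever seeing the extra keyword argument. A hedged sketch, assuming the patched async client from above and illustrative values:

response = await client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative
    messages=[{"role": "user", "content": "Hi"}],
    inference_id="my-custom-id-123",  # forwarded to the trace, not to OpenAI
)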
44 changes: 34 additions & 10 deletions src/openlayer/lib/integrations/openai_tracer.py
@@ -137,12 +137,16 @@ def stream_chunks(
if delta.function_call.name:
collected_function_call["name"] += delta.function_call.name
if delta.function_call.arguments:
collected_function_call["arguments"] += delta.function_call.arguments
collected_function_call["arguments"] += (
delta.function_call.arguments
)
elif delta.tool_calls:
if delta.tool_calls[0].function.name:
collected_function_call["name"] += delta.tool_calls[0].function.name
if delta.tool_calls[0].function.arguments:
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments
collected_function_call["arguments"] += delta.tool_calls[
0
].function.arguments

yield chunk
end_time = time.time()
@@ -153,11 +157,15 @@
finally:
# Try to add step to the trace
try:
collected_output_data = [message for message in collected_output_data if message is not None]
collected_output_data = [
message for message in collected_output_data if message is not None
]
if collected_output_data:
output_data = "".join(collected_output_data)
else:
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
collected_function_call["arguments"] = json.loads(
collected_function_call["arguments"]
)
output_data = collected_function_call

trace_args = create_trace_args(
@@ -172,7 +180,13 @@
model_parameters=get_model_parameters(kwargs),
raw_output=raw_outputs,
id=inference_id,
metadata={"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None)},
metadata={
"timeToFirstToken": (
(first_token_time - start_time) * 1000
if first_token_time
else None
)
},
)
add_to_trace(
**trace_args,
@@ -240,8 +254,12 @@ def create_trace_args(
def add_to_trace(is_azure_openai: bool = False, **kwargs) -> None:
"""Add a chat completion step to the trace."""
if is_azure_openai:
tracer.add_chat_completion_step_to_trace(**kwargs, name="Azure OpenAI Chat Completion", provider="Azure")
tracer.add_chat_completion_step_to_trace(**kwargs, name="OpenAI Chat Completion", provider="OpenAI")
tracer.add_chat_completion_step_to_trace(
**kwargs, name="Azure OpenAI Chat Completion", provider="Azure"
)
tracer.add_chat_completion_step_to_trace(
**kwargs, name="OpenAI Chat Completion", provider="OpenAI"
)


def handle_non_streaming_create(
@@ -294,7 +312,9 @@ def handle_non_streaming_create(
)
# pylint: disable=broad-except
except Exception as e:
logger.error("Failed to trace the create chat completion request with Openlayer. %s", e)
logger.error(
"Failed to trace the create chat completion request with Openlayer. %s", e
)

return response

@@ -336,7 +356,9 @@ def parse_non_streaming_output_data(


# --------------------------- OpenAI Assistants API -------------------------- #
def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.beta.threads.run.Run") -> None:
def trace_openai_assistant_thread_run(
client: openai.OpenAI, run: "openai.types.beta.threads.run.Run"
) -> None:
"""Trace a run from an OpenAI assistant.

Once the run is completed, the thread data is published to Openlayer,
@@ -353,7 +375,9 @@ def trace_openai_assistant_thread_run(client: openai.OpenAI, run: "openai.types.
metadata = _extract_run_metadata(run)

# Convert thread to prompt
messages = client.beta.threads.messages.list(thread_id=run.thread_id, order="asc")
messages = client.beta.threads.messages.list(
thread_id=run.thread_id, order="asc"
)
prompt = _thread_messages_to_prompt(messages)

# Add step to the trace
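For context, one possible polling pattern for the Assistants path, as a hedged sketch: the docstring above says the thread data is published to Openlayer once the run is completed, so the helper is called after the run leaves its in-flight states. Thread and assistant ids are illustrative.

import time

import openai
from openlayer.lib import trace_openai_assistant_thread_run

client = openai.OpenAI()
run = client.beta.threads.runs.create(
    thread_id="thread_abc123", assistant_id="asst_abc123"  # illustrative ids
)
while run.status in ("queued", "in_progress"):
    time.sleep(1)
    run = client.beta.threads.runs.retrieve(
        thread_id="thread_abc123", run_id=run.id
    )
# Publishes the thread data to Openlayer once the run has completed.
trace_openai_assistant_thread_run(client, run)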