212 changes: 110 additions & 102 deletions src/openlayer/lib/integrations/litellm_tracer.py
@@ -1,5 +1,6 @@
"""Module with methods used to trace LiteLLM completions."""

import contextvars
import json
import logging
import time
@@ -16,6 +17,7 @@
import litellm

from ..tracing import tracer
from ..tracing import enums as tracer_enums

logger = logging.getLogger(__name__)

@@ -154,113 +156,119 @@ def stream_chunks(
provider = "unknown"
latest_chunk_metadata = {}

try:
i = 0
for i, chunk in enumerate(chunks):
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))

if i == 0:
first_token_time = time.time()
# Try to detect provider from the first chunk
provider = detect_provider_from_chunk(chunk, model_name)

# Extract usage data from this chunk if available (usually in final chunks)
chunk_usage = extract_usage_from_chunk(chunk)
if any(v is not None for v in chunk_usage.values()):
latest_usage_data = chunk_usage
# Create step immediately so it's added to parent trace before parent publishes
with tracer.create_step(
name="LiteLLM Chat Completion",
step_type=tracer_enums.StepType.CHAT_COMPLETION,
inputs={"prompt": kwargs.get("messages", [])},
) as step:
try:
i = 0
for i, chunk in enumerate(chunks):
raw_outputs.append(chunk.model_dump() if hasattr(chunk, 'model_dump') else str(chunk))

# Always update metadata from latest chunk (for cost, headers, etc.)
chunk_metadata = extract_litellm_metadata(chunk, model_name)
if chunk_metadata:
latest_chunk_metadata.update(chunk_metadata)
if i == 0:
first_token_time = time.time()
# Try to detect provider from the first chunk
provider = detect_provider_from_chunk(chunk, model_name)

if i > 0:
num_of_completion_tokens = i + 1

# Handle different chunk formats based on provider
delta = get_delta_from_chunk(chunk)

if delta and hasattr(delta, 'content') and delta.content:
collected_output_data.append(delta.content)
elif delta and hasattr(delta, 'function_call') and delta.function_call:
if delta.function_call.name:
collected_function_call["name"] += delta.function_call.name
if delta.function_call.arguments:
collected_function_call["arguments"] += delta.function_call.arguments
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
if delta.tool_calls[0].function.name:
collected_function_call["name"] += delta.tool_calls[0].function.name
if delta.tool_calls[0].function.arguments:
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments

yield chunk

end_time = time.time()
latency = (end_time - start_time) * 1000

# pylint: disable=broad-except
except Exception as e:
logger.error("Failed to yield chunk. %s", e)
finally:
# Try to add step to the trace
try:
collected_output_data = [message for message in collected_output_data if message is not None]
if collected_output_data:
output_data = "".join(collected_output_data)
else:
if collected_function_call["arguments"]:
try:
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
except json.JSONDecodeError:
pass
output_data = collected_function_call

# Post-streaming calculations (after streaming is finished)
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
chunks=raw_outputs,
messages=kwargs.get("messages", []),
output_content=output_data,
model_name=model_name,
latest_usage_data=latest_usage_data,
latest_chunk_metadata=latest_chunk_metadata
)

# Use calculated values (fall back to extracted data if calculation fails)
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}

final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)
# Extract usage data from this chunk if available (usually in final chunks)
chunk_usage = extract_usage_from_chunk(chunk)
if any(v is not None for v in chunk_usage.values()):
latest_usage_data = chunk_usage

# Always update metadata from latest chunk (for cost, headers, etc.)
chunk_metadata = extract_litellm_metadata(chunk, model_name)
if chunk_metadata:
latest_chunk_metadata.update(chunk_metadata)

if i > 0:
num_of_completion_tokens = i + 1

# Handle different chunk formats based on provider
delta = get_delta_from_chunk(chunk)

if delta and hasattr(delta, 'content') and delta.content:
collected_output_data.append(delta.content)
elif delta and hasattr(delta, 'function_call') and delta.function_call:
if delta.function_call.name:
collected_function_call["name"] += delta.function_call.name
if delta.function_call.arguments:
collected_function_call["arguments"] += delta.function_call.arguments
elif delta and hasattr(delta, 'tool_calls') and delta.tool_calls:
if delta.tool_calls[0].function.name:
collected_function_call["name"] += delta.tool_calls[0].function.name
if delta.tool_calls[0].function.arguments:
collected_function_call["arguments"] += delta.tool_calls[0].function.arguments

yield chunk

end_time = time.time()
latency = (end_time - start_time) * 1000

trace_args = create_trace_args(
end_time=end_time,
inputs={"prompt": kwargs.get("messages", [])},
output=output_data,
latency=latency,
tokens=final_total_tokens,
prompt_tokens=final_prompt_tokens,
completion_tokens=final_completion_tokens,
model=model_name,
model_parameters=get_model_parameters(kwargs),
raw_output=raw_outputs,
id=inference_id,
cost=final_cost, # Use calculated cost
metadata={
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
"provider": provider,
"litellm_model": model_name,
**latest_chunk_metadata, # Add all LiteLLM-specific metadata
},
)
add_to_trace(**trace_args)

# pylint: disable=broad-except
except Exception as e:
logger.error(
"Failed to trace the LiteLLM completion request with Openlayer. %s",
e,
)
logger.error("Failed to yield chunk. %s", e)
finally:
# Update step with final data before context manager exits
try:
collected_output_data = [message for message in collected_output_data if message is not None]
if collected_output_data:
output_data = "".join(collected_output_data)
else:
if collected_function_call["arguments"]:
try:
collected_function_call["arguments"] = json.loads(collected_function_call["arguments"])
except json.JSONDecodeError:
pass
output_data = collected_function_call

# Post-streaming calculations (after streaming is finished)
completion_tokens_calculated, prompt_tokens_calculated, total_tokens_calculated, cost_calculated = calculate_streaming_usage_and_cost(
chunks=raw_outputs,
messages=kwargs.get("messages", []),
output_content=output_data,
model_name=model_name,
latest_usage_data=latest_usage_data,
latest_chunk_metadata=latest_chunk_metadata
)

# Use calculated values (fall back to extracted data if calculation fails)
usage_data = latest_usage_data if any(v is not None for v in latest_usage_data.values()) else {}

final_prompt_tokens = prompt_tokens_calculated if prompt_tokens_calculated is not None else usage_data.get("prompt_tokens", 0)
final_completion_tokens = completion_tokens_calculated if completion_tokens_calculated is not None else usage_data.get("completion_tokens", num_of_completion_tokens)
final_total_tokens = total_tokens_calculated if total_tokens_calculated is not None else usage_data.get("total_tokens", final_prompt_tokens + final_completion_tokens)
final_cost = cost_calculated if cost_calculated is not None else latest_chunk_metadata.get('cost', None)

# Update the step with final trace data
step.log(
output=output_data,
latency=latency,
tokens=final_total_tokens,
prompt_tokens=final_prompt_tokens,
completion_tokens=final_completion_tokens,
model=model_name,
model_parameters=get_model_parameters(kwargs),
raw_output=raw_outputs,
id=inference_id,
cost=final_cost,
provider=provider,
metadata={
"timeToFirstToken": ((first_token_time - start_time) * 1000 if first_token_time else None),
"provider": provider,
"litellm_model": model_name,
**latest_chunk_metadata,
},
)

# pylint: disable=broad-except
except Exception as e:
if logger is not None:
logger.error(
"Failed to trace the LiteLLM completion request with Openlayer. %s",
e,
)


def handle_non_streaming_completion(
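
For context, here is a rough usage sketch of the streaming path this diff changes. It is not part of the PR: it assumes `trace` is the public decorator exported by the `openlayer.lib.tracing.tracer` module edited below, and that the LiteLLM integration has already been enabled (the setup call is omitted). It illustrates that, with the change above, the "LiteLLM Chat Completion" step is created as soon as the stream is consumed, so it attaches to the parent trace before the parent step publishes.

import litellm

# Assumption: `trace` is the decorator exposed by the tracer module in this
# repo; the LiteLLM integration is assumed to be enabled elsewhere.
from openlayer.lib.tracing.tracer import trace


@trace()
def answer(question: str) -> str:
    """Parent step; the streamed chat completion is nested under it."""
    chunks = litellm.completion(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
        stream=True,
    )
    # Iterating the wrapped stream drives stream_chunks(); the chat-completion
    # step now exists before this parent step finishes and publishes.
    return "".join(chunk.choices[0].delta.content or "" for chunk in chunks)
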
72 changes: 55 additions & 17 deletions src/openlayer/lib/tracing/tracer.py
@@ -496,6 +496,7 @@ def __init__(self):
self._token = None
self._output_chunks = []
self._trace_initialized = False
self._captured_context = None # Capture context for ASGI compatibility

def __iter__(self):
return self
@@ -522,26 +523,26 @@ def __next__(self):
try:
chunk = next(self._original_gen)
self._output_chunks.append(chunk)
if self._captured_context is None:
self._captured_context = contextvars.copy_context()
return chunk
except StopIteration:
# Finalize trace when generator is exhausted
# Use captured context to ensure we have access to the trace
output = _join_output_chunks(self._output_chunks)
_finalize_sync_generator_step(
step=self._step,
token=self._token,
is_root_step=self._is_root_step,
step_name=step_name,
inputs=self._inputs,
output=output,
inference_pipeline_id=inference_pipeline_id,
on_flush_failure=on_flush_failure,
)
raise
except Exception as exc:
# Handle exceptions
if self._step:
_log_step_exception(self._step, exc)
output = _join_output_chunks(self._output_chunks)
if self._captured_context:
self._captured_context.run(
_finalize_sync_generator_step,
step=self._step,
token=self._token,
is_root_step=self._is_root_step,
step_name=step_name,
inputs=self._inputs,
output=output,
inference_pipeline_id=inference_pipeline_id,
on_flush_failure=on_flush_failure,
)
else:
_finalize_sync_generator_step(
step=self._step,
token=self._token,
@@ -553,6 +554,35 @@ def __next__(self):
on_flush_failure=on_flush_failure,
)
raise
except Exception as exc:
# Handle exceptions
if self._step:
_log_step_exception(self._step, exc)
output = _join_output_chunks(self._output_chunks)
if self._captured_context:
self._captured_context.run(
_finalize_sync_generator_step,
step=self._step,
token=self._token,
is_root_step=self._is_root_step,
step_name=step_name,
inputs=self._inputs,
output=output,
inference_pipeline_id=inference_pipeline_id,
on_flush_failure=on_flush_failure,
)
else:
_finalize_sync_generator_step(
step=self._step,
token=self._token,
is_root_step=self._is_root_step,
step_name=step_name,
inputs=self._inputs,
output=output,
inference_pipeline_id=inference_pipeline_id,
on_flush_failure=on_flush_failure,
)
raise

return TracedSyncGenerator()
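
The `_captured_context` handling above is there for the case where the generator returned by a traced function is drained outside the context in which its first chunk was pulled (for example, by an ASGI server streaming the response body), so the ContextVar holding the current trace is no longer visible when the stream ends. Below is a minimal, self-contained sketch of the same pattern using only the standard library, no Openlayer code: copy the context on the first chunk, then run the finalization inside that copy.

import contextvars

current_trace = contextvars.ContextVar("current_trace", default="<no trace>")


def traced_stream(chunks):
    """Capture the context of the first next() call and finalize inside it."""
    captured = None
    for chunk in chunks:
        if captured is None:
            captured = contextvars.copy_context()
        yield chunk

    def finalize():
        print("finalizing with trace:", current_trace.get())

    if captured is not None:
        captured.run(finalize)  # sees the ContextVar from the capture point
    else:
        finalize()


current_trace.set("trace-123")
gen = traced_stream(["a", "b", "c"])
next(gen)  # first chunk pulled here, so this context gets captured

# Drain the rest from a fresh, empty context, the way a server might:
contextvars.Context().run(lambda: list(gen))
# Prints "finalizing with trace: trace-123"; without the captured context,
# the finalizer would only see the "<no trace>" default.
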

@@ -1349,6 +1379,14 @@ def _handle_trace_completion(
logger.debug("Ending the trace...")
current_trace = get_current_trace()

if current_trace is None:
logger.warning(
"Cannot complete trace for step '%s': no active trace found. "
"This can happen when OPENLAYER_DISABLE_PUBLISH=true or trace context was lost.",
step_name,
)
return

trace_data, input_variable_names = post_process_trace(current_trace)

config = dict(
@@ -1644,7 +1682,7 @@ async def _invoke_with_context(


def post_process_trace(
trace_obj: traces.Trace,
trace_obj: Optional[traces.Trace],
) -> Tuple[Dict[str, Any], List[str]]:
"""Post processing of the trace data before uploading to Openlayer.
