Merged
47 changes: 33 additions & 14 deletions posthog/ai/langchain/callbacks.py
@@ -32,13 +32,22 @@


class RunMetadata(TypedDict, total=False):
messages: Union[List[Dict[str, Any]], List[str]]
input: Any
"""Input of the run: messages, prompt variables, etc."""
name: str
"""Name of the run: chain name, model name, etc."""
provider: str
"""Provider of the run: OpenAI, Anthropic"""
model: str
"""Model used in the run"""
model_params: Dict[str, Any]
"""Model parameters of the run: temperature, max_tokens, etc."""
base_url: str
"""Base URL of the provider's API used in the run."""
start_time: float
"""Start time of the run."""
end_time: float
"""End time of the run."""
Comment on lines +35 to +50
Member:
Was also thinking of traces as runs in my earlier PR, but then thought Run is maybe a bit too generation-specific for the trace use case (e.g. traces don't have a model). So possibly we could have a GenerationMetadata dataclass, a TraceMetadata one, and later a SpanMetadata one for clearer separation.
But for now, RunMetadata for both is indeed a better model than the simplistic one I went with before ✅

Contributor Author:

Yeah, it makes sense. I'll do that in the spans PR.
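
As a rough, illustrative sketch of that separation (the dataclass names and fields below are hypothetical, not part of this PR), the shared run fields could live on a base metadata type, with generation-only fields on a subclass:

from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TraceMetadata:
    # Fields shared by traces and spans; no model-specific data.
    name: str
    input: Any
    start_time: float
    end_time: Optional[float] = None


@dataclass
class GenerationMetadata(TraceMetadata):
    # Generation-only fields: provider and model details.
    provider: Optional[str] = None
    model: Optional[str] = None
    model_params: Dict[str, Any] = field(default_factory=dict)
    base_url: Optional[str] = None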



RunStorage = Dict[UUID, RunMetadata]
@@ -119,8 +128,7 @@ def on_chain_start(
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
self._set_parent_of_run(run_id, parent_run_id)
if parent_run_id is None and self._trace_name is None:
self._trace_name = self._get_langchain_run_name(serialized, **kwargs)
self._trace_input = inputs
self._set_span_metadata(run_id, self._get_langchain_run_name(serialized, **kwargs), inputs)

def on_chat_model_start(
self,
@@ -134,7 +142,7 @@ def on_chat_model_start(
self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
self._set_parent_of_run(run_id, parent_run_id)
input = [_convert_message_to_dict(message) for row in messages for message in row]
self._set_run_metadata(serialized, run_id, input, **kwargs)
self._set_llm_metadata(serialized, run_id, input, **kwargs)

def on_llm_start(
self,
@@ -147,7 +155,7 @@
):
self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
self._set_parent_of_run(run_id, parent_run_id)
self._set_run_metadata(serialized, run_id, prompts, **kwargs)
self._set_llm_metadata(serialized, run_id, prompts, **kwargs)

def on_llm_new_token(
self,
@@ -204,7 +212,7 @@ def on_chain_end(
self._pop_parent_of_run(run_id)

if parent_run_id is None:
self._capture_trace(run_id, outputs=outputs)
self._pop_trace_and_capture(run_id, outputs=outputs)

def on_chain_error(
self,
@@ -218,7 +226,7 @@ def on_chain_error(
self._pop_parent_of_run(run_id)

if parent_run_id is None:
self._capture_trace(run_id, outputs=None)
self._pop_trace_and_capture(run_id, outputs=None)

def on_llm_end(
self,
@@ -253,7 +261,7 @@ def on_llm_end(
"$ai_provider": run.get("provider"),
"$ai_model": run.get("model"),
"$ai_model_parameters": run.get("model_params"),
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
"$ai_output_choices": with_privacy_mode(self._client, self._privacy_mode, output),
"$ai_http_status": 200,
"$ai_input_tokens": input_tokens,
@@ -292,7 +300,7 @@ def on_llm_error(
"$ai_provider": run.get("provider"),
"$ai_model": run.get("model"),
"$ai_model_parameters": run.get("model_params"),
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("messages")),
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
"$ai_http_status": _get_http_status(error),
"$ai_latency": latency,
"$ai_trace_id": trace_id,
@@ -377,7 +385,14 @@ def _find_root_run(self, run_id: UUID) -> UUID:
id = self._parent_tree[id]
return id

def _set_run_metadata(
def _set_span_metadata(self, run_id: UUID, name: str, input: Any):
self._runs[run_id] = {
"name": name,
"input": input,
"start_time": time.time(),
}

def _set_llm_metadata(
self,
serialized: Dict[str, Any],
run_id: UUID,
@@ -387,7 +402,7 @@ def _set_run_metadata(
**kwargs,
):
run: RunMetadata = {
"messages": messages,
"input": messages,
"start_time": time.time(),
}
if isinstance(invocation_params, dict):
@@ -450,12 +465,16 @@ def _get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs
except (KeyError, TypeError):
pass

def _capture_trace(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
def _pop_trace_and_capture(self, run_id: UUID, *, outputs: Optional[Dict[str, Any]]):
trace_id = self._get_trace_id(run_id)
run = self._pop_run_metadata(run_id)
if not run:
return
event_properties = {
"$ai_trace_name": self._trace_name,
"$ai_trace_name": run.get("name"),
"$ai_trace_id": trace_id,
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, self._trace_input),
"$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.get("input")),
"$ai_latency": run.get("end_time", 0) - run.get("start_time", 0),
**self._properties,
}
if outputs is not None:
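Taken together, the renamed helpers mean the root chain run now carries its own name, input and timing, which feed the $ai_trace event. A minimal usage sketch of the handler after this change (the API key, host and chain below are placeholders):

from langchain_core.runnables import RunnableLambda
from posthog import Posthog
from posthog.ai.langchain import CallbackHandler

posthog_client = Posthog("phc_test", host="https://us.i.posthog.com")  # placeholder credentials
callbacks = CallbackHandler(posthog_client)

# on_chain_start stores the root run's name, input and start_time via _set_span_metadata;
# when the chain finishes, _pop_trace_and_capture pops that metadata and captures a
# $ai_trace event carrying $ai_trace_name (e.g. "RunnableSequence"), $ai_input_state and $ai_latency.
chain = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: {"answer": x})
result = chain.invoke({"question": "Foo"}, config={"callbacks": [callbacks]})
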
54 changes: 51 additions & 3 deletions posthog/test/ai/langchain/test_callbacks.py
@@ -1,3 +1,4 @@
import asyncio
import logging
import math
import os
@@ -67,7 +68,7 @@ def test_metadata_capture(mock_client):
callbacks = CallbackHandler(mock_client)
run_id = uuid.uuid4()
with patch("time.time", return_value=1234567890):
callbacks._set_run_metadata(
callbacks._set_llm_metadata(
{"kwargs": {"openai_api_base": "https://us.posthog.com"}},
run_id,
messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
@@ -76,7 +77,7 @@ def test_metadata_capture(mock_client):
)
expected = {
"model": "hog-mini",
"messages": [{"role": "user", "content": "Who won the world series in 2020?"}],
"input": [{"role": "user", "content": "Who won the world series in 2020?"}],
"start_time": 1234567890,
"model_params": {"temperature": 0.5},
"provider": "posthog",
@@ -90,6 +91,19 @@
callbacks._pop_run_metadata(uuid.uuid4()) # should not raise


def test_run_metadata_capture(mock_client):
callbacks = CallbackHandler(mock_client)
run_id = uuid.uuid4()
with patch("time.time", return_value=1234567890):
callbacks._set_span_metadata(run_id, "test", 1)
expected = {
"name": "test",
"input": 1,
"start_time": 1234567890,
}
assert callbacks._runs[run_id] == expected


@pytest.mark.parametrize("stream", [True, False])
def test_basic_chat_chain(mock_client, stream):
prompt = ChatPromptTemplate.from_messages(
@@ -514,7 +528,11 @@ def test_callbacks_logic(mock_client):
assert callbacks._parent_tree == {}

def assert_intermediary_run(m):
assert callbacks._runs == {}
assert len(callbacks._runs) != 0
run = next(iter(callbacks._runs.values()))
assert run["name"] == "RunnableSequence"
assert run["input"] == {}
assert run["start_time"] is not None
assert len(callbacks._parent_tree.items()) == 1
return [m]

@@ -981,3 +999,33 @@ def test_tool_calls(mock_client):
}
]
assert "additional_kwargs" not in generation_call["properties"]["$ai_output_choices"][0]


async def test_async_traces(mock_client):
async def sleep(x): # -> Any:
await asyncio.sleep(0.1)
return x

prompt = ChatPromptTemplate.from_messages([("user", "Foo")])
chain1 = RunnableLambda(sleep)
chain2 = prompt | FakeMessagesListChatModel(responses=[AIMessage(content="Bar")])

cb = CallbackHandler(mock_client)

start_time = time.time()
await asyncio.gather(
chain1.ainvoke({}, config={"callbacks": [cb]}),
chain2.ainvoke({}, config={"callbacks": [cb]}),
)
approximate_latency = math.floor(time.time() - start_time)
assert mock_client.capture.call_count == 3

first_call, second_call, third_call = mock_client.capture.call_args_list
assert first_call[1]["event"] == "$ai_generation"
assert second_call[1]["event"] == "$ai_trace"
assert second_call[1]["properties"]["$ai_trace_name"] == "RunnableSequence"
assert third_call[1]["event"] == "$ai_trace"
assert third_call[1]["properties"]["$ai_trace_name"] == "sleep"
assert (
min(approximate_latency - 1, 0) <= math.floor(third_call[1]["properties"]["$ai_latency"]) <= approximate_latency
)
2 changes: 1 addition & 1 deletion posthog/version.py
@@ -1,4 +1,4 @@
VERSION = "3.9.2"
VERSION = "3.9.3"

if __name__ == "__main__":
print(VERSION, end="") # noqa: T201