Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ general:
weave:
_type: weave
project: "nat-demo"
front_end:
_type: fastapi
endpoints:
- path: /chat_feedback
method: POST
description: Set reaction feedback for an assistant message via Weave call ID
function_name: chat_feedback

functions:
calculator_multiply:
Expand All @@ -32,6 +39,8 @@ functions:
_type: current_datetime
calculator_subtract:
_type: calculator_subtract
chat_feedback:
_type: chat_feedback

llms:
nim_llm:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Generator
from contextlib import contextmanager

from nat.builder.context import Context
from nat.data_models.intermediate_step import IntermediateStep
from nat.data_models.span import Span
from nat.observability.exporter.base_exporter import IsolatedAttribute
Expand Down Expand Up @@ -80,7 +81,18 @@ def _process_start_event(self, event: IntermediateStep):
if span is None:
logger.warning("No span found for event %s", event.UUID)
return
self._create_weave_call(event, span)
call = self._create_weave_call(event, span)

# capture the call ID for mapping reaction feedbacks to specific traces
if (event.payload.event_type == "FUNCTION_START" and event.payload.name == "<workflow>"):
try:
# Store the workflow call ID in the context for later retrieval
context = Context.get()
context._context_state.trace_id.set(call.id)
logger.info("DEBUG: Captured workflow weave call ID: %s", call.id)

except Exception as e:
logger.debug("Could not store workflow trace ID: %s", e)

def _process_end_event(self, event: IntermediateStep):
"""Process the end event for a Weave call.
Expand Down
11 changes: 8 additions & 3 deletions src/nat/agent/react_agent/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydantic import PositiveInt

from nat.builder.builder import Builder
from nat.builder.context import Context
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
Expand Down Expand Up @@ -115,6 +116,10 @@ async def react_agent_workflow(config: ReActAgentWorkflowConfig, builder: Builde
normalize_tool_input_quotes=config.normalize_tool_input_quotes).build_graph()

async def _response_fn(input_message: ChatRequest) -> ChatResponse:
# Get the trace ID for feedback tracking
context = Context.get()
trace_id = context.trace_id

try:
# initialize the starting state with the user query
messages: list[BaseMessage] = trim_messages(messages=[m.model_dump() for m in input_message.messages],
Expand All @@ -135,14 +140,14 @@ async def _response_fn(input_message: ChatRequest) -> ChatResponse:
# get and return the output from the state
state = ReActGraphState(**state)
output_message = state.messages[-1]
return ChatResponse.from_string(str(output_message.content))
return ChatResponse.from_string(str(output_message.content), trace_id=trace_id)

except Exception as ex:
logger.exception("%s ReAct Agent failed with exception: %s", AGENT_LOG_PREFIX, ex)
# here, we can implement custom error messages
if config.verbose:
return ChatResponse.from_string(str(ex))
return ChatResponse.from_string("I seem to be having a problem.")
return ChatResponse.from_string(str(ex), trace_id=trace_id)
return ChatResponse.from_string("I seem to be having a problem.", trace_id=trace_id)

if (config.use_openai_api):
yield FunctionInfo.from_fn(_response_fn, description=config.description)
Expand Down
14 changes: 14 additions & 0 deletions src/nat/builder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class ContextState(metaclass=Singleton):
def __init__(self):
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
self.trace_id: ContextVar[str | None] = ContextVar("trace_id", default=None)
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
self.metadata: ContextVar[RequestAttributes] = ContextVar("request_attributes", default=RequestAttributes())
Expand Down Expand Up @@ -174,6 +175,19 @@ def user_message_id(self) -> str | None:
"""
return self._context_state.user_message_id.get()

@property
def trace_id(self) -> str | None:
"""
Retrieves the trace ID from the current context state.

This can be used to identify traces across different tracing systems
(e.g., Weave call IDs, Phoenix Trace IDs, OpenTelemetry trace IDs, etc.).

Returns:
str | None: The trace ID if available, None otherwise.
"""
return self._context_state.trace_id.get()

@contextmanager
def push_active_function(self,
function_name: str,
Expand Down
44 changes: 35 additions & 9 deletions src/nat/data_models/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ class ChatResponse(ResponseBaseModelOutput):
usage: Usage | None = None
system_fingerprint: str | None = None
service_tier: typing.Literal["scale", "default"] | None = None
trace_id: str | None = None

@field_serializer('created')
def serialize_created(self, created: datetime.datetime) -> int:
Expand All @@ -264,7 +265,8 @@ def from_string(data: str,
object_: str | None = None,
model: str | None = None,
created: datetime.datetime | None = None,
usage: Usage | None = None) -> "ChatResponse":
usage: Usage | None = None,
trace_id: str | None = None) -> "ChatResponse":

if id_ is None:
id_ = str(uuid.uuid4())
Expand All @@ -280,7 +282,8 @@ def from_string(data: str,
model=model,
created=created,
choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
usage=usage)
usage=usage,
trace_id=trace_id)


class ChatResponseChunk(ResponseBaseModelOutput):
Expand All @@ -300,6 +303,7 @@ class ChatResponseChunk(ResponseBaseModelOutput):
system_fingerprint: str | None = None
service_tier: typing.Literal["scale", "default"] | None = None
usage: Usage | None = None
trace_id: str | None = None

@field_serializer('created')
def serialize_created(self, created: datetime.datetime) -> int:
Expand All @@ -312,7 +316,8 @@ def from_string(data: str,
id_: str | None = None,
created: datetime.datetime | None = None,
model: str | None = None,
object_: str | None = None) -> "ChatResponseChunk":
object_: str | None = None,
trace_id: str | None = None) -> "ChatResponseChunk":

if id_ is None:
id_ = str(uuid.uuid4())
Expand All @@ -327,7 +332,8 @@ def from_string(data: str,
choices=[Choice(index=0, message=ChoiceMessage(content=data), finish_reason="stop")],
created=created,
model=model,
object=object_)
object=object_,
trace_id=trace_id)

@staticmethod
def create_streaming_chunk(content: str,
Expand All @@ -338,7 +344,8 @@ def create_streaming_chunk(content: str,
role: str | None = None,
finish_reason: str | None = None,
usage: Usage | None = None,
system_fingerprint: str | None = None) -> "ChatResponseChunk":
system_fingerprint: str | None = None,
trace_id: str | None = None) -> "ChatResponseChunk":
"""Create an OpenAI-compatible streaming chunk"""
if id_ is None:
id_ = str(uuid.uuid4())
Expand All @@ -358,7 +365,8 @@ def create_streaming_chunk(content: str,
model=model,
object="chat.completion.chunk",
usage=usage,
system_fingerprint=system_fingerprint)
system_fingerprint=system_fingerprint,
trace_id=trace_id)


class ResponseIntermediateStep(ResponseBaseModelIntermediate):
Expand Down Expand Up @@ -631,7 +639,11 @@ def _string_to_nat_chat_request(data: str) -> ChatRequest:
# ======== ChatResponse Converters ========
def _nat_chat_response_to_string(data: ChatResponse) -> str:
if data.choices and data.choices[0].message:
return data.choices[0].message.content or ""
content = data.choices[0].message.content or ""
# Include trace ID in the string if available, using a special format
if data.trace_id:
return f"{content}__TRACE_ID__:{data.trace_id}"
return content
return ""


Expand All @@ -656,7 +668,11 @@ def _string_to_nat_chat_response(data: str) -> ChatResponse:

def _chat_response_to_chat_response_chunk(data: ChatResponse) -> ChatResponseChunk:
    """Convert a full ChatResponse into a ChatResponseChunk.

    The original choices/message structure is carried over unchanged for
    backward compatibility, and the trace ID is propagated alongside it.
    """
    carried_over = {
        "id": data.id,
        "choices": data.choices,
        "created": data.created,
        "model": data.model,
        "trace_id": data.trace_id,
    }
    return ChatResponseChunk(**carried_over)


GlobalTypeConverter.register_converter(_chat_response_to_chat_response_chunk)
Expand All @@ -679,8 +695,18 @@ def _chat_response_chunk_to_string(data: ChatResponseChunk) -> str:
def _string_to_nat_chat_response_chunk(data: str) -> ChatResponseChunk:
    """Converts a string to a ChatResponseChunk object.

    Recognizes the ``__TRACE_ID__:<id>`` sentinel appended by
    ``_nat_chat_response_to_string``. Everything after the LAST occurrence of
    the marker is treated as the trace ID, so the conversion stays correct
    even when the message content itself happens to contain the marker (the
    previous two-part ``split`` silently returned the raw string with no
    trace ID in that case).

    Args:
        data: Plain-text response, optionally suffixed with the trace marker.

    Returns:
        ChatResponseChunk: Chunk built from the content, with any recovered
        trace ID attached.
    """
    content, sep, suffix = data.rpartition("__TRACE_ID__:")
    if sep:
        trace_id = suffix
    else:
        # No marker present: rpartition puts the whole input in `suffix`.
        content, trace_id = data, None

    # Build and return the response
    return ChatResponseChunk.from_string(content, trace_id=trace_id)


GlobalTypeConverter.register_converter(_string_to_nat_chat_response_chunk)
Expand Down
60 changes: 60 additions & 0 deletions src/nat/tool/chat_feedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nat.builder.builder import Builder
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.function import FunctionBaseConfig


class ChatFeedbackTool(FunctionBaseConfig, name="chat_feedback"):
    """Configuration for the ``chat_feedback`` function.

    The function retrieves a Weave call by its ID and attaches a reaction
    (e.g. thumbs up/down) as feedback on that call's output. The Weave
    project is discovered automatically from the builder's telemetry
    exporter configuration, so no extra fields are required here.
    """


@register_function(config_type=ChatFeedbackTool)
async def chat_feedback(config: ChatFeedbackTool, builder: Builder):
    """Register the ``chat_feedback`` function.

    Yields a function that adds a reaction (e.g. thumbs up/down) to an
    existing Weave call identified by its call ID. The Weave project name is
    resolved from the builder's telemetry exporter configuration.
    """

    def _resolve_weave_project() -> str | None:
        """Return the "entity/project" (or bare "project") from the first
        telemetry exporter config declaring a Weave project, else None."""
        # Handle both ChildBuilder and WorkflowBuilder: a ChildBuilder exposes
        # the real WorkflowBuilder via its `_workflow_builder` attribute.
        workflow_builder = getattr(builder, '_workflow_builder', builder)

        for exporter_config in getattr(workflow_builder, '_telemetry_exporters', {}).values():
            if hasattr(exporter_config.config, 'project'):
                # Construct the project name in the same format as the weave exporter
                entity = getattr(exporter_config.config, 'entity', None)
                project = exporter_config.config.project
                return f"{entity}/{project}" if entity else project
        return None

    async def _add_chat_feedback(weave_call_id: str, reaction_type: str) -> str:
        import weave

        weave_project = _resolve_weave_project()
        if weave_project is None:
            # Previously `weave.init(None)` was reached here and failed with an
            # opaque error; fail fast with an actionable message instead.
            raise ValueError("No Weave telemetry exporter with a 'project' is configured; "
                             "cannot add feedback without a Weave project.")

        client = weave.init(weave_project)
        call = client.get_call(weave_call_id)
        call.feedback.add_reaction(reaction_type)

        return f"Added reaction '{reaction_type}' to call {weave_call_id}"

    yield FunctionInfo.from_fn(
        _add_chat_feedback,
        description="Adds a reaction/feedback to a Weave call using the provided call ID and reaction type.")
1 change: 1 addition & 0 deletions src/nat/tool/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# Import any tools which need to be automatically registered here
from . import chat_completion
from . import chat_feedback
from . import datetime_tools
from . import document_search
from . import github_tools
Expand Down