6 changes: 3 additions & 3 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -712,9 +712,9 @@ def messages_to_oci_params(
                )
            else:
                oci_message = self.oci_chat_message[role](content=tool_content)
-        elif isinstance(message, AIMessage) and message.additional_kwargs.get(
-            "tool_calls"
-        ):
+        elif isinstance(message, AIMessage) and (
+            message.tool_calls or
+            message.additional_kwargs.get("tool_calls")):
            # Process content and tool calls for assistant messages
            content = self._process_message_content(message.content)
            tool_calls = []
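The fix widens the condition so assistant messages are matched whether their tool calls live in the structured `tool_calls` field or in `additional_kwargs`. A minimal sketch of the two shapes, assuming langchain_core's standard `AIMessage`; the `has_tool_calls` helper is hypothetical and only mirrors the new condition, it is not part of langchain_oci:

# Hypothetical helper mirroring the new elif condition; not part of langchain_oci.
from langchain_core.messages import AIMessage


def has_tool_calls(message: AIMessage) -> bool:
    return bool(message.tool_calls or message.additional_kwargs.get("tool_calls"))


# Structured field: previously skipped by the old condition.
direct = AIMessage(
    content="",
    tool_calls=[{"id": "call_1", "name": "get_weather", "args": {"location": "SF"}}],
)

# Raw payload left in additional_kwargs: the only shape the old condition matched.
legacy = AIMessage(
    content="",
    additional_kwargs={
        "tool_calls": [{"id": "call_2", "name": "get_weather", "args": {}}]
    },
)

assert has_tool_calls(direct) and has_tool_calls(legacy)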
161 changes: 160 additions & 1 deletion libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py
@@ -6,9 +6,9 @@
from unittest.mock import MagicMock

import pytest
-from langchain_core.messages import HumanMessage
from pytest import MonkeyPatch

+from langchain_core.messages import HumanMessage, AIMessage
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI


@@ -575,6 +575,165 @@ def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
    assert response["parsed"].conditions == "Sunny"


@pytest.mark.requires("oci")
def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None:
    """Test AIMessage with tool_calls in the direct tool_calls field."""

    oci_gen_ai_client = MagicMock()
    llm = ChatOCIGenAI(
        model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client
    )

    # Track whether the tool_calls processing branch is executed
    tool_calls_processed = False

    def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
        nonlocal tool_calls_processed
        # Check whether the outgoing request carries tool_calls on any message
        request = args[0]
        if hasattr(request, "chat_request") and hasattr(
            request.chat_request, "messages"
        ):
            for msg in request.chat_request.messages:
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    tool_calls_processed = True
                    break
        return MockResponseDict(
            {
                "status": 200,
                "data": MockResponseDict(
                    {
                        "chat_response": MockResponseDict(
                            {
                                "api_format": "GENERIC",
                                "choices": [
                                    MockResponseDict(
                                        {
                                            "message": MockResponseDict(
                                                {
                                                    "role": "ASSISTANT",
                                                    "name": None,
                                                    "content": [
                                                        MockResponseDict(
                                                            {
                                                                "text": "I'll help you.",
                                                                "type": "TEXT",
                                                            }
                                                        )
                                                    ],
                                                    "tool_calls": [],
                                                }
                                            ),
                                            "finish_reason": "completed",
                                        }
                                    )
                                ],
                                "time_created": "2025-08-14T10:00:01.100000+00:00",
                            }
                        ),
                        "model_id": "meta.llama-3.3-70b-instruct",
                        "model_version": "1.0.0",
                    }
                ),
                "request_id": "1234567890",
                "headers": MockResponseDict({"content-length": "123"}),
            }
        )

    monkeypatch.setattr(llm.client, "chat", mocked_response)

    # Create an AIMessage with tool_calls in the direct tool_calls field
    ai_message = AIMessage(
        content="I need to call a function",
        tool_calls=[
            {
                "id": "call_123",
                "name": "get_weather",
                "args": {"location": "San Francisco"},
            }
        ],
    )

    messages = [ai_message]

    # This should not raise and should forward the tool_calls to the request
    response = llm.invoke(messages)
    assert response.content == "I'll help you."
    assert tool_calls_processed, "tool_calls were not forwarded in the request"


@pytest.mark.requires("oci")
def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> None:
    """Test AIMessage with tool_calls in the additional_kwargs field."""

    oci_gen_ai_client = MagicMock()
    llm = ChatOCIGenAI(
        model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client
    )

    def mocked_response(*args, **kwargs):  # type: ignore[no-untyped-def]
        return MockResponseDict(
            {
                "status": 200,
                "data": MockResponseDict(
                    {
                        "chat_response": MockResponseDict(
                            {
                                "api_format": "GENERIC",
                                "choices": [
                                    MockResponseDict(
                                        {
                                            "message": MockResponseDict(
                                                {
                                                    "role": "ASSISTANT",
                                                    "name": None,
                                                    "content": [
                                                        MockResponseDict(
                                                            {
                                                                "text": "I'll help you.",
                                                                "type": "TEXT",
                                                            }
                                                        )
                                                    ],
                                                    "tool_calls": [],
                                                }
                                            ),
                                            "finish_reason": "completed",
                                        }
                                    )
                                ],
                                "time_created": "2025-08-14T10:00:01.100000+00:00",
                            }
                        ),
                        "model_id": "meta.llama-3.3-70b-instruct",
                        "model_version": "1.0.0",
                    }
                ),
                "request_id": "1234567890",
                "headers": MockResponseDict({"content-length": "123"}),
            }
        )

    monkeypatch.setattr(llm.client, "chat", mocked_response)

    # Create an AIMessage with tool_calls in additional_kwargs
    ai_message = AIMessage(
        content="I need to call a function",
        additional_kwargs={
            "tool_calls": [
                {
                    "id": "call_456",
                    "name": "get_weather",
                    "args": {"location": "New York"},
                }
            ]
        },
    )

    messages = [ai_message]

    # This should not raise and should process the tool_calls correctly
    response = llm.invoke(messages)
    assert response.content == "I'll help you."


def test_get_provider():
    """Test determining the provider based on the model_id."""
    model_provider_map = {