Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,31 @@ def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
)

@staticmethod
def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
OCI Generative AI doesn't support $ref and $defs, so we inline all references.
"""
defs = schema.get("$defs", {}) # OCI Generative AI doesn't support $defs

def resolve(obj: Any) -> Any:
if isinstance(obj, dict):
if "$ref" in obj:
ref = obj["$ref"]
if ref.startswith("#/$defs/"):
key = ref.split("/")[-1]
return resolve(defs.get(key, obj))
return obj # Cannot resolve $ref, return unchanged
return {k: resolve(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [resolve(item) for item in obj]
return obj

resolved = resolve(schema)
if isinstance(resolved, dict):
resolved.pop("$defs", None)
return resolved


class Provider(ABC):
"""Abstract base class for OCI Generative AI providers."""
Expand Down Expand Up @@ -1371,6 +1396,9 @@ def with_structured_output(
else schema # type: ignore[assignment]
)

# Resolve $ref references as OCI doesn't support $ref and $defs
json_schema_dict = OCIUtils.resolve_schema_refs(json_schema_dict)

response_json_schema = self._provider.oci_response_json_schema(
name=json_schema_dict.get("title", "response"),
description=json_schema_dict.get("description", ""),
Expand Down
31 changes: 30 additions & 1 deletion libs/oci/tests/unit_tests/chat_models/test_response_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def test_with_structured_output_json_schema():
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)

# This should not raise TypeError anymore
from pydantic import BaseModel

class TestSchema(BaseModel):
Expand All @@ -126,6 +125,36 @@ class TestSchema(BaseModel):
assert structured_llm is not None


@pytest.mark.requires("oci")
def test_with_structured_output_json_schema_nested_refs():
"""Test with_structured_output with json_schema method and nested refs."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)

from enum import Enum
from typing import List

from pydantic import BaseModel

class Color(Enum):
RED = "RED"
BLUE = "BLUE"
GREEN = "GREEN"

class Item(BaseModel):
name: str
color: Color # Creates $ref to Color

class Response(BaseModel):
message: str
items: List[Item] # Array with $ref inside

structured_llm = llm.with_structured_output(schema=Response, method="json_schema")

# The structured LLM should be created without errors
assert structured_llm is not None


@pytest.mark.requires("oci")
def test_response_format_json_schema_object():
"""Test response_format with JsonSchemaResponseFormat object."""
Expand Down