Skip to content

Commit bb7ac7f

Browse files
committed
Fix mypy type errors for LangChain 1.x compatibility
- Update bind_tools signature to match BaseChatModel (AIMessage return, tool_choice parameter)
- Add isinstance checks for content type in integration tests
- Remove unused type: ignore comments
- Add proper type annotations for message lists
- Import AIMessage in oci_data_science.py
1 parent 7cca81d commit bb7ac7f

File tree

8 files changed

+28
-15
lines changed

8 files changed

+28
-15
lines changed

libs/oci/langchain_oci/chat_models/oci_data_science.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
agenerate_from_stream,
3232
generate_from_stream,
3333
)
34-
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
34+
from langchain_core.messages import (
35+
AIMessage,
36+
AIMessageChunk,
37+
BaseMessage,
38+
BaseMessageChunk,
39+
)
3540
from langchain_core.output_parsers import (
3641
JsonOutputParser,
3742
PydanticOutputParser,
@@ -765,10 +770,14 @@ def _process_response(self, response_json: dict) -> ChatResult:
765770

766771
def bind_tools(
767772
self,
768-
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
773+
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]],
774+
*,
775+
tool_choice: Optional[str] = None,
769776
**kwargs: Any,
770-
) -> Runnable[LanguageModelInput, BaseMessage]:
777+
) -> Runnable[LanguageModelInput, AIMessage]:
771778
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
779+
if tool_choice is not None:
780+
kwargs["tool_choice"] = tool_choice
772781
return super().bind(tools=formatted_tools, **kwargs)
773782

774783

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,8 @@ def convert_to_oci_tool(
931931
"required": parameters.get("required", []),
932932
},
933933
)
934-
elif isinstance(tool, BaseTool): # type: ignore[unreachable]
935-
return self.oci_function_definition( # type: ignore[unreachable]
934+
elif isinstance(tool, BaseTool):
935+
return self.oci_function_definition(
936936
name=tool.name,
937937
description=OCIUtils.remove_signature_from_tool_description(
938938
tool.name, tool.description
@@ -1206,13 +1206,13 @@ def _prepare_request(
12061206

12071207
def bind_tools(
12081208
self,
1209-
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
1209+
tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]],
12101210
*,
12111211
tool_choice: Optional[
12121212
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
12131213
] = None,
12141214
**kwargs: Any,
1215-
) -> Runnable[LanguageModelInput, BaseMessage]:
1215+
) -> Runnable[LanguageModelInput, AIMessage]:
12161216
"""Bind tool-like objects to this chat model.
12171217
12181218
Assumes model is compatible with Meta's tool-calling API.

libs/oci/tests/integration_tests/chat_models/test_chat_features.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def test_max_tokens_limit():
308308

309309
# Response should be truncated due to max_tokens
310310
# Token count varies, but should be reasonably short
311+
assert isinstance(response.content, str)
311312
assert len(response.content.split()) <= 20 # Rough word count check
312313

313314

@@ -405,7 +406,7 @@ def test_multi_turn_context_retention(llm):
405406
def test_long_context_handling(llm):
406407
"""Test handling of longer context windows."""
407408
# Create a conversation with multiple turns
408-
messages = [
409+
messages: list[SystemMessage | HumanMessage | AIMessage] = [
409410
SystemMessage(content="You are a helpful assistant tracking a story."),
410411
]
411412

@@ -426,4 +427,5 @@ def test_long_context_handling(llm):
426427
messages.append(HumanMessage(content="What was the knight's horse named?"))
427428
final_response = llm.invoke(messages)
428429

430+
assert isinstance(final_response.content, str)
429431
assert "thunder" in final_response.content.lower()

libs/oci/tests/integration_tests/chat_models/test_langchain_compatibility.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_basic_invoke(chat_model):
7171
assert isinstance(response, AIMessage)
7272
assert response.content is not None
7373
assert len(response.content) > 0
74-
assert "hello" in response.content.lower()
74+
assert isinstance(response.content, str) and "hello" in response.content.lower()
7575

7676

7777
@pytest.mark.requires("oci")
@@ -100,7 +100,7 @@ def test_invoke_multi_turn(chat_model):
100100
response2 = chat_model.invoke(messages)
101101

102102
assert isinstance(response2, AIMessage)
103-
assert "alice" in response2.content.lower()
103+
assert isinstance(response2.content, str) and "alice" in response2.content.lower()
104104

105105

106106
# =============================================================================
@@ -269,6 +269,7 @@ def test_response_format_json_object(chat_model):
269269
)
270270

271271
assert isinstance(response, AIMessage)
272+
assert isinstance(response.content, str)
272273
# Response should contain valid JSON (may be wrapped in markdown)
273274
import json
274275
import re

libs/oci/tests/integration_tests/chat_models/test_tool_calling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,9 @@ def should_continue(state: MessagesState):
427427
comprehensive analysis."""
428428

429429
# Invoke agent with a diagnostic scenario
430+
# Langgraph invoke signature is generic; passing dict is valid at runtime
430431
result = agent.invoke(
431-
{
432+
{ # type: ignore[arg-type]
432433
"messages": [
433434
SystemMessage(content=system_prompt),
434435
HumanMessage(

libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, json_data: Dict, status_code: int = 200):
8080
def raise_for_status(self) -> None:
8181
"""Mocked raise for status."""
8282
if 400 <= self.status_code < 600:
83-
raise HTTPError(response=self) # type: ignore[arg-type]
83+
raise HTTPError(response=self)
8484

8585
def json(self) -> Dict:
8686
"""Returns mocked json data."""
@@ -152,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
152152
if output is None:
153153
output = chunk
154154
else:
155-
output += chunk # type: ignore[assignment]
155+
output += chunk
156156
count += 1
157157
# LangChain 1.x adds a final chunk with chunk_position='last', so we get 6 chunks
158158
assert count >= 5

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai_responses_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def call_model(state: AgentState):
385385
# ---- Act ----
386386
app = workflow.compile()
387387
input_message = HumanMessage(content="What is the capital of France?")
388-
result = app.invoke({"messages": [input_message]})
388+
result = app.invoke({"messages": [input_message]}) # type: ignore[arg-type]
389389

390390
# ---- Assert ----
391391
content = result["messages"][1].content[0]

libs/oci/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __init__(self, json_data: Dict, status_code: int = 200) -> None:
6161
def raise_for_status(self) -> None:
6262
"""Mocked raise for status."""
6363
if 400 <= self.status_code < 600:
64-
raise HTTPError(response=self) # type: ignore[arg-type]
64+
raise HTTPError(response=self)
6565

6666
def json(self) -> Dict:
6767
"""Returns mocked json data."""

0 commit comments

Comments (0)