Skip to content

Commit 5e80f0a

Browse files
Refactor libs/oci to meet mypy standard
Removed unnecessary type: ignore comments, improved type annotations, and updated function signatures for better type safety across chat models, embeddings, and tests. Also enhanced mypy configuration for stricter type checking and plugin support.
1 parent 6a2315c commit 5e80f0a

File tree

9 files changed

+69
-57
lines changed

9 files changed

+69
-57
lines changed

libs/oci/langchain_oci/chat_models/oci_data_science.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
agenerate_from_stream,
3232
generate_from_stream,
3333
)
34-
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
34+
from langchain_core.messages import (
35+
AIMessage,
36+
AIMessageChunk,
37+
BaseMessage,
38+
BaseMessageChunk,
39+
)
3540
from langchain_core.output_parsers import (
3641
JsonOutputParser,
3742
PydanticOutputParser,
@@ -598,7 +603,7 @@ def with_structured_output(
598603
if method == "json_mode":
599604
llm = self.bind(response_format={"type": "json_object"})
600605
output_parser = (
601-
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
606+
PydanticOutputParser(pydantic_object=schema)
602607
if is_pydantic_schema
603608
else JsonOutputParser()
604609
)
@@ -767,7 +772,7 @@ def bind_tools(
767772
self,
768773
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
769774
**kwargs: Any,
770-
) -> Runnable[LanguageModelInput, BaseMessage]:
775+
) -> Runnable[LanguageModelInput, AIMessage]:
771776
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
772777
return super().bind(tools=formatted_tools, **kwargs)
773778

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
AIMessage,
3232
AIMessageChunk,
3333
BaseMessage,
34-
ChatMessage,
3534
HumanMessage,
3635
SystemMessage,
3736
ToolCall,
@@ -350,7 +349,7 @@ def get_role(self, message: BaseMessage) -> str:
350349
raise ValueError(f"Unknown message type: {type(message)}")
351350

352351
def messages_to_oci_params(
353-
self, messages: Sequence[ChatMessage], **kwargs: Any
352+
self, messages: Sequence[BaseMessage], **kwargs: Any
354353
) -> Dict[str, Any]:
355354
"""
356355
Convert LangChain messages to OCI parameters for Cohere.
@@ -417,7 +416,7 @@ def messages_to_oci_params(
417416
current_turn = list(reversed(current_turn))
418417

419418
# Process tool results from the current turn
420-
oci_tool_results: List[Any] = []
419+
oci_tool_results: Optional[List[Any]] = []
421420
for message in current_turn:
422421
if isinstance(message, ToolMessage):
423422
tool_msg = message
@@ -434,7 +433,7 @@ def messages_to_oci_params(
434433
parameters=lc_tool_call["args"],
435434
)
436435
tool_result.outputs = [{"output": tool_msg.content}]
437-
oci_tool_results.append(tool_result)
436+
oci_tool_results.append(tool_result) # type: ignore[union-attr]
438437
if not oci_tool_results:
439438
oci_tool_results = None
440439

@@ -552,7 +551,7 @@ def process_stream_tool_calls(
552551
Returns:
553552
List of ToolCallChunk objects
554553
"""
555-
tool_call_chunks = []
554+
tool_call_chunks: List[ToolCallChunk] = []
556555
tool_call_response = self.chat_stream_tool_calls(event_data)
557556

558557
if not tool_call_response:
@@ -813,7 +812,7 @@ def _should_allow_more_tool_calls(
813812
return False
814813

815814
# Detect infinite loop: same tool called with same arguments in succession
816-
recent_calls = []
815+
recent_calls: list = []
817816
for msg in reversed(messages):
818817
if hasattr(msg, "tool_calls") and msg.tool_calls:
819818
for tc in msg.tool_calls:
@@ -895,7 +894,7 @@ def _process_message_content(
895894

896895
def convert_to_oci_tool(
897896
self,
898-
tool: Union[Type[BaseModel], Callable, BaseTool],
897+
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
899898
) -> Dict[str, Any]:
900899
"""Convert a BaseTool instance, TypedDict or BaseModel type
901900
to a OCI tool in Meta's format.
@@ -1016,7 +1015,7 @@ def process_stream_tool_calls(
10161015
Returns:
10171016
List of ToolCallChunk objects
10181017
"""
1019-
tool_call_chunks = []
1018+
tool_call_chunks: List[ToolCallChunk] = []
10201019
tool_call_response = self.chat_stream_tool_calls(event_data)
10211020

10221021
if not tool_call_response:
@@ -1142,7 +1141,7 @@ def _prepare_request(
11421141
stop: Optional[List[str]],
11431142
stream: bool,
11441143
**kwargs: Any,
1145-
) -> Dict[str, Any]:
1144+
) -> Any:
11461145
"""
11471146
Prepare the OCI chat request from LangChain messages.
11481147
@@ -1206,7 +1205,7 @@ def bind_tools(
12061205
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
12071206
] = None,
12081207
**kwargs: Any,
1209-
) -> Runnable[LanguageModelInput, BaseMessage]:
1208+
) -> Runnable[LanguageModelInput, AIMessage]:
12101209
"""Bind tool-like objects to this chat model.
12111210
12121211
Assumes model is compatible with Meta's tool-calling API.
@@ -1299,8 +1298,8 @@ def with_structured_output(
12991298
tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
13001299
if is_pydantic_schema:
13011300
output_parser: OutputParserLike = PydanticToolsParser(
1302-
tools=[schema], # type: ignore[list-item]
1303-
first_tool_only=True, # type: ignore[list-item]
1301+
tools=[schema],
1302+
first_tool_only=True,
13041303
)
13051304
else:
13061305
output_parser = JsonOutputKeyToolsParser(
@@ -1309,15 +1308,15 @@ def with_structured_output(
13091308
elif method == "json_mode":
13101309
llm = self.bind(response_format={"type": "JSON_OBJECT"})
13111310
output_parser = (
1312-
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
1311+
PydanticOutputParser(pydantic_object=schema)
13131312
if is_pydantic_schema
13141313
else JsonOutputParser()
13151314
)
13161315
elif method == "json_schema":
1317-
json_schema_dict = (
1316+
json_schema_dict: Dict[str, Any] = (
13181317
schema.model_json_schema() # type: ignore[union-attr]
13191318
if is_pydantic_schema
1320-
else schema
1319+
else schema # type: ignore[assignment]
13211320
)
13221321

13231322
response_json_schema = self._provider.oci_response_json_schema(

libs/oci/langchain_oci/embeddings/oci_data_science_model_deployment_endpoint.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,9 @@ def _completion_with_retry(**kwargs: Any) -> Any:
104104
response.raise_for_status()
105105
return response
106106
except requests.exceptions.HTTPError as http_err:
107-
if response.status_code == 401 and self._refresh_signer():
108-
raise TokenExpiredError() from http_err
109-
else:
110-
raise ValueError(
111-
f"Server error: {str(http_err)}. Message: {response.text}"
112-
) from http_err
107+
raise ValueError(
108+
f"Server error: {str(http_err)}. Message: {response.text}"
109+
) from http_err
113110
except Exception as e:
114111
raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e
115112

libs/oci/langchain_oci/embeddings/oci_generative_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def validate_environment(cls, values: Dict) -> Dict: # pylint: disable=no-self-
125125
client_kwargs.pop("signer", None)
126126
elif values["auth_type"] == OCIAuthType(2).name:
127127

128-
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
128+
def make_security_token_signer(oci_config):
129129
pk = oci.signer.load_private_key_from_file(
130130
oci_config.get("key_file"), None
131131
)

libs/oci/langchain_oci/llms/oci_generative_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def validate_environment(cls, values: Dict) -> Dict:
151151
client_kwargs.pop("signer", None)
152152
elif values["auth_type"] == OCIAuthType(2).name:
153153

154-
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
154+
def make_security_token_signer(oci_config):
155155
pk = oci.signer.load_private_key_from_file(
156156
oci_config.get("key_file"), None
157157
)

libs/oci/pyproject.toml

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,26 @@ ignore = [
7777
]
7878

7979
[tool.mypy]
80-
ignore_missing_imports = "True"
81-
82-
# Disable specific error codes that are causing issues
83-
disallow_untyped_defs = "False"
84-
disable_error_code = ["attr-defined", "assignment", "var-annotated", "override", "union-attr", "arg-type"]
85-
86-
# TODO: LangChain Google settings
87-
# plugins = ["pydantic.mypy"]
88-
# strict = true
89-
# disallow_untyped_defs = true
90-
91-
# # TODO: activate for 'strict' checking
92-
# disallow_any_generics = false
93-
# warn_return_any = false
80+
plugins = ["pydantic.mypy"]
81+
check_untyped_defs = true
82+
error_summary = false
83+
pretty = true
84+
show_column_numbers = true
85+
show_error_codes = true
86+
show_error_context = true
87+
warn_redundant_casts = true
88+
warn_unreachable = true
89+
warn_unused_configs = true
90+
warn_unused_ignores = true
91+
92+
# Ignore missing imports only for specific untyped packages
93+
[[tool.mypy.overrides]]
94+
module = [
95+
"oci.*",
96+
"ads.*",
97+
"langchain_openai.*",
98+
]
99+
ignore_missing_imports = true
94100

95101
[tool.coverage.run]
96102
omit = ["tests/*"]

libs/oci/tests/integration_tests/chat_models/test_tool_calling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def should_continue(state: MessagesState):
428428

429429
# Invoke agent with a diagnostic scenario
430430
result = agent.invoke(
431-
{
431+
{ # type: ignore[arg-type]
432432
"messages": [
433433
SystemMessage(content=system_prompt),
434434
HumanMessage(

libs/oci/tests/unit_tests/chat_models/test_oci_data_science.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self, json_data: Dict, status_code: int = 200):
8080
def raise_for_status(self) -> None:
8181
"""Mocked raise for status."""
8282
if 400 <= self.status_code < 600:
83-
raise HTTPError() # type: ignore[call-arg]
83+
raise HTTPError()
8484

8585
def json(self) -> Dict:
8686
"""Returns mocked json data."""

libs/oci/tests/unit_tests/chat_models/test_oci_generative_ai.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414

1515
class MockResponseDict(dict):
16-
def __getattr__(self, val): # type: ignore[no-untyped-def]
16+
def __getattr__(self, val):
1717
return self.get(val)
1818

1919

2020
class MockToolCall(dict):
21-
def __getattr__(self, val): # type: ignore[no-untyped-def]
21+
def __getattr__(self, val):
2222
return self[val]
2323

2424

@@ -37,7 +37,7 @@ def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
3737

3838
provider = model_id.split(".")[0].lower()
3939

40-
def mocked_response(*args): # type: ignore[no-untyped-def]
40+
def mocked_response(*args):
4141
response_text = "Assistant chat reply."
4242
response = None
4343
if provider == "cohere":
@@ -166,7 +166,7 @@ def test_meta_tool_calling(monkeypatch: MonkeyPatch) -> None:
166166
oci_gen_ai_client = MagicMock()
167167
llm = ChatOCIGenAI(model_id="meta.llama-3-70b-instruct", client=oci_gen_ai_client)
168168

169-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
169+
def mocked_response(*args, **kwargs):
170170
# Mock response with tool calls
171171
return MockResponseDict(
172172
{
@@ -231,7 +231,7 @@ def get_weather(location: str) -> str:
231231
messages = [HumanMessage(content="What's the weather like?")]
232232

233233
# Test different tool choice options
234-
tool_choices = [
234+
tool_choices: list[str | bool | dict[str, str | dict[str, str]]] = [
235235
"get_weather", # Specific tool
236236
"auto", # Auto mode
237237
"none", # No tools
@@ -254,7 +254,7 @@ def get_weather(location: str) -> str:
254254
assert tool_call["function"]["name"] == "get_weather"
255255

256256
# Test escaped JSON arguments (issue #52)
257-
def mocked_response_escaped(*args, **kwargs): # type: ignore[no-untyped-def]
257+
def mocked_response_escaped(*args, **kwargs):
258258
"""Mock response with escaped JSON arguments."""
259259
return MockResponseDict(
260260
{
@@ -310,6 +310,7 @@ def mocked_response_escaped(*args, **kwargs): # type: ignore[no-untyped-def]
310310
response_escaped = llm.bind_tools(tools=[get_weather]).invoke(messages)
311311

312312
# Verify escaped JSON was correctly parsed to a dict
313+
assert isinstance(response_escaped, AIMessage)
313314
assert len(response_escaped.tool_calls) == 1
314315
assert response_escaped.tool_calls[0]["name"] == "get_weather"
315316
assert response_escaped.tool_calls[0]["args"] == {"location": "San Francisco"}
@@ -337,7 +338,7 @@ def get_weather(location: str) -> str:
337338
).invoke(messages)
338339

339340
# Mock response for the case without tool choice
340-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
341+
def mocked_response(*args, **kwargs):
341342
return MockResponseDict(
342343
{
343344
"status": 200,
@@ -378,7 +379,7 @@ def test_meta_tool_conversion(monkeypatch: MonkeyPatch) -> None:
378379
oci_gen_ai_client = MagicMock()
379380
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
380381

381-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
382+
def mocked_response(*args, **kwargs):
382383
request = args[0]
383384
# Check the conversion of tools to oci generic API spec
384385
# Function tool
@@ -467,6 +468,7 @@ class PydanticTool(BaseModel):
467468

468469
# For tool calls, the response content should be empty.
469470
assert response.content == ""
471+
assert isinstance(response, AIMessage)
470472
assert len(response.tool_calls) == 1
471473
assert response.tool_calls[0]["name"] == "function_tool"
472474

@@ -483,7 +485,7 @@ class WeatherResponse(BaseModel):
483485
oci_gen_ai_client = MagicMock()
484486
llm = ChatOCIGenAI(model_id="cohere.command-r-16k", client=oci_gen_ai_client)
485487

486-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
488+
def mocked_response(*args, **kwargs):
487489
return MockResponseDict(
488490
{
489491
"status": 200,
@@ -533,7 +535,7 @@ class WeatherResponse(BaseModel):
533535
oci_gen_ai_client = MagicMock()
534536
llm = ChatOCIGenAI(model_id="cohere.command-latest", client=oci_gen_ai_client)
535537

536-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
538+
def mocked_response(*args, **kwargs):
537539
# Verify that response_format is a JsonSchemaResponseFormat object
538540
request = args[0]
539541
assert hasattr(request.chat_request, "response_format")
@@ -610,7 +612,7 @@ class WeatherResponse(BaseModel):
610612
oci_gen_ai_client = MagicMock()
611613
llm = ChatOCIGenAI(model_id="cohere.command-r-16k", client=oci_gen_ai_client)
612614

613-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
615+
def mocked_response(*args, **kwargs):
614616
return MockResponseDict(
615617
{
616618
"status": 200,
@@ -663,7 +665,7 @@ def test_ai_message_tool_calls_direct_field(monkeypatch: MonkeyPatch) -> None:
663665
# Track if the tool_calls processing branch is executed
664666
tool_calls_processed = False
665667

666-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
668+
def mocked_response(*args, **kwargs):
667669
nonlocal tool_calls_processed
668670
# Check if the request contains tool_calls in the message
669671
request = args[0]
@@ -746,7 +748,7 @@ def test_ai_message_tool_calls_additional_kwargs(monkeypatch: MonkeyPatch) -> No
746748
oci_gen_ai_client = MagicMock()
747749
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
748750

749-
def mocked_response(*args, **kwargs): # type: ignore[no-untyped-def]
751+
def mocked_response(*args, **kwargs):
750752
return MockResponseDict(
751753
{
752754
"status": 200,
@@ -884,8 +886,11 @@ def get_weather(city: str) -> str:
884886
]
885887

886888
# Prepare the request - need to pass tools from the bound model kwargs
887-
request = llm_with_tools._prepare_request(
888-
messages, stop=None, stream=False, **llm_with_tools.kwargs
889+
request = llm._prepare_request(
890+
messages,
891+
stop=None,
892+
stream=False,
893+
**llm_with_tools.kwargs, # type: ignore[attr-defined]
889894
)
890895

891896
# Verify that tool_choice is set to 'none' because limit was reached

0 commit comments

Comments (0)