Skip to content

Commit b4921c8

Browse files
piyushgodosvc
authored and committed
Merge pull request from migrate_langchain to main
2 parents 0137b5c + 1e6dc7b commit b4921c8

14 files changed

+1113
-44
lines changed

libs/oci/langchain_oci/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,33 @@
1-
# Copyright (c) 2023 Oracle and/or its affiliates.
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

44
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
5+
from langchain_oci.chat_models.oci_data_science import (
6+
ChatOCIModelDeployment,
7+
ChatOCIModelDeploymentTGI,
8+
ChatOCIModelDeploymentVLLM
9+
)
510
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
11+
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
612
from langchain_oci.llms.oci_data_science_model_deployment_endpoint import (
713
BaseOCIModelDeployment,
814
OCIModelDeploymentLLM,
15+
OCIModelDeploymentTGI,
16+
OCIModelDeploymentVLLM,
917
)
1018
from langchain_oci.llms.oci_generative_ai import OCIGenAI, OCIGenAIBase
1119

1220
__all__ = [
1321
"ChatOCIGenAI",
22+
"ChatOCIModelDeployment",
23+
"ChatOCIModelDeploymentTGI",
24+
"ChatOCIModelDeploymentVLLM",
1425
"OCIGenAIEmbeddings",
26+
"OCIModelDeploymentEndpointEmbeddings",
1527
"OCIGenAIBase",
1628
"OCIGenAI",
1729
"BaseOCIModelDeployment",
1830
"OCIModelDeploymentLLM",
31+
"OCIModelDeploymentTGI",
32+
"OCIModelDeploymentVLLM",
1933
]
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
# Copyright (c) 2023 Oracle and/or its affiliates.
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4-
from langchain_oci.chat_models.oci_data_science import ChatOCIModelDeployment
4+
from langchain_oci.chat_models.oci_data_science import (
5+
ChatOCIModelDeployment,
6+
ChatOCIModelDeploymentTGI,
7+
ChatOCIModelDeploymentVLLM
8+
)
59
from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI
610

7-
__all__ = ["ChatOCIGenAI", "ChatOCIModelDeployment"]
11+
__all__ = ["ChatOCIGenAI", "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM"]

libs/oci/langchain_oci/chat_models/oci_data_science.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 Oracle and/or its affiliates.
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

44
"""Chat model for OCI data science model deployment endpoint."""
@@ -48,6 +48,7 @@
4848
)
4949

5050
logger = logging.getLogger(__name__)
51+
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"
5152

5253

5354
def _is_pydantic_class(obj: Any) -> bool:
@@ -57,6 +58,13 @@ def _is_pydantic_class(obj: Any) -> bool:
5758
class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
5859
"""OCI Data Science Model Deployment chat model integration.
5960
61+
Prerequisite
62+
The OCI Model Deployment plugins are installable only on
63+
python version 3.9 and above. If you're working inside the notebook,
64+
try installing the python 3.10 based conda pack and running the
65+
following setup.
66+
67+
6068
Setup:
6169
Install ``oracle-ads`` and ``langchain-openai``.
6270
@@ -91,22 +99,28 @@ class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment):
9199
Key init args — client params:
92100
auth: dict
93101
ADS auth dictionary for OCI authentication.
102+
default_headers: Optional[Dict]
103+
The headers to be added to the Model Deployment request.
94104
95105
Instantiate:
96106
.. code-block:: python
97107
98-
from langchain_community.chat_models import ChatOCIModelDeployment
108+
from langchain_oci.chat_models import ChatOCIModelDeployment
99109
100110
chat = ChatOCIModelDeployment(
101111
endpoint="https://modeldeployment.<region>.oci.customer-oci.com/<ocid>/predict",
102-
model="odsc-llm",
112+
model="odsc-llm", # this is the default model name if deployed with AQUA
103113
streaming=True,
104114
max_retries=3,
105115
model_kwargs={
106116
"max_token": 512,
107117
"temperature": 0.2,
108118
# other model parameters ...
109119
},
120+
default_headers={
121+
"route": "/v1/chat/completions",
122+
# other request headers ...
123+
},
110124
)
111125
112126
Invocation:
@@ -289,6 +303,25 @@ def _default_params(self) -> Dict[str, Any]:
289303
"stream": self.streaming,
290304
}
291305

306+
def _headers(
307+
self, is_async: Optional[bool] = False, body: Optional[dict] = None
308+
) -> Dict:
309+
"""Construct and return the headers for a request.
310+
311+
Args:
312+
is_async (bool, optional): Indicates if the request is asynchronous.
313+
Defaults to `False`.
314+
body (optional): The request body to be included in the headers if
315+
the request is asynchronous.
316+
317+
Returns:
318+
Dict: A dictionary containing the appropriate headers for the request.
319+
"""
320+
return {
321+
"route": DEFAULT_INFERENCE_ENDPOINT_CHAT,
322+
**super()._headers(is_async=is_async, body=body),
323+
}
324+
292325
def _generate(
293326
self,
294327
messages: List[BaseMessage],
@@ -702,7 +735,7 @@ def _process_response(self, response_json: dict) -> ChatResult:
702735

703736
for choice in choices:
704737
message = _convert_dict_to_message(choice["message"])
705-
generation_info = dict(finish_reason=choice.get("finish_reason"))
738+
generation_info = {"finish_reason": choice.get("finish_reason")}
706739
if "logprobs" in choice:
707740
generation_info["logprobs"] = choice["logprobs"]
708741

@@ -746,7 +779,7 @@ class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment):
746779
747780
.. code-block:: python
748781
749-
from langchain_community.chat_models import ChatOCIModelDeploymentVLLM
782+
from langchain_oci.chat_models import ChatOCIModelDeploymentVLLM
750783
751784
chat = ChatOCIModelDeploymentVLLM(
752785
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
@@ -913,7 +946,7 @@ class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment):
913946
914947
.. code-block:: python
915948
916-
from langchain_community.chat_models import ChatOCIModelDeploymentTGI
949+
from langchain_oci.chat_models import ChatOCIModelDeploymentTGI
917950
918951
chat = ChatOCIModelDeploymentTGI(
919952
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<ocid>/predict",
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
# Copyright (c) 2023 Oracle and/or its affiliates.
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

4+
from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import OCIModelDeploymentEndpointEmbeddings
45
from langchain_oci.embeddings.oci_generative_ai import OCIGenAIEmbeddings
56

6-
__all__ = ["OCIGenAIEmbeddings"]
7+
__all__ = ["OCIModelDeploymentEndpointEmbeddings", "OCIGenAIEmbeddings"]
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3+
4+
from langchain_core.embeddings import Embeddings
5+
from langchain_core.language_models.llms import create_base_retry_decorator
6+
from langchain_core.utils import get_from_dict_or_env
7+
from pydantic import BaseModel, Field, model_validator
8+
import requests
9+
from typing import Any, Callable, Dict, List, Mapping, Optional
10+
11+
12+
# Headers sent with a request when ``endpoint_kwargs`` does not supply its own.
DEFAULT_HEADER = {"Content-Type": "application/json"}
15+
16+
17+
class TokenExpiredError(Exception):
    """Signal that a request was rejected and the signer was refreshed.

    It is listed among the retryable error types, so raising it causes the
    request to be attempted again.
    """
19+
20+
21+
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
    """Build the retry decorator used for endpoint requests.

    Args:
        llm: Any object exposing a ``max_retries`` attribute; its value is
            passed through as the retry limit.

    Returns:
        A decorator that retries the wrapped callable on connection
        timeouts and on ``TokenExpiredError``.
    """
    retryable = [requests.exceptions.ConnectTimeout, TokenExpiredError]
    return create_base_retry_decorator(
        error_types=retryable,
        max_retries=llm.max_retries,
    )
28+
29+
30+
class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
    """Embedding model deployed on OCI Data Science Model Deployment.

    Example:

        .. code-block:: python

            from langchain_oci.embeddings import OCIModelDeploymentEndpointEmbeddings

            embeddings = OCIModelDeploymentEndpointEmbeddings(
                endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
            )
    """  # noqa: E501

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then the `ads.common.default_signer()` will be used."""

    endpoint: str = ""
    """The uri of the endpoint from the deployed Model Deployment model."""

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model."""

    endpoint_kwargs: Optional[Dict] = None
    """Optional attributes (except for headers) passed to the request.post
    function.
    """

    max_retries: int = 1
    """The maximum number of retries to make when generating."""

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment.

        Falls back to the ADS default signer when no ``auth`` is given and
        resolves ``endpoint`` from the input values or the
        ``OCI_LLM_ENDPOINT`` environment variable.

        Raises:
            ImportError: If the ``ads`` package cannot be imported.
        """
        try:
            import ads

        except ImportError as ex:
            raise ImportError(
                "Could not import ads python package. "
                "Please install it with `pip install oracle_ads`."
            ) from ex
        if not values.get("auth", None):
            values["auth"] = ads.common.auth.default_signer()
        values["endpoint"] = get_from_dict_or_env(
            values,
            "endpoint",
            "OCI_LLM_ENDPOINT",
        )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint": self.endpoint,
            "model_kwargs": self.model_kwargs or {},
        }

    def _refresh_signer(self) -> bool:
        """Refresh the security token of the configured signer, if possible.

        Returns:
            True when ``auth["signer"]`` exposes a callable
            ``refresh_security_token`` and it was invoked, False otherwise.
        """
        signer = self.auth.get("signer")
        refresh = getattr(signer, "refresh_security_token", None)
        if callable(refresh):
            refresh()
            return True
        return False

    def _embed_with_retry(self, **kwargs) -> Any:
        """Use tenacity to retry the call.

        Args:
            **kwargs: Keyword arguments forwarded to ``requests.post``.

        Returns:
            The successful ``requests.Response``.

        Raises:
            TokenExpiredError: If the endpoint keeps answering 401 after the
                signer was refreshed and the retries are exhausted.
            ValueError: For any other failure (HTTP error status, connection
                problems, timeouts once retries are exhausted).
        """
        retry_decorator = _create_retry_decorator(self)

        @retry_decorator
        def _completion_with_retry(**request_kwargs: Any) -> Any:
            try:
                response = requests.post(self.endpoint, **request_kwargs)
            except requests.exceptions.ConnectTimeout:
                # Must reach the retry decorator untouched; wrapping it here
                # would turn a retryable error into a non-retryable one.
                raise
            except Exception as e:
                raise ValueError(
                    f"Error occurs by inference endpoint: {str(e)}"
                ) from e

            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError as http_err:
                if response.status_code == 401 and self._refresh_signer():
                    # Signer was refreshed: let the decorator try again.
                    raise TokenExpiredError() from http_err
                raise ValueError(
                    f"Server error: {str(http_err)}. Message: {response.text}"
                ) from http_err
            return response

        try:
            return _completion_with_retry(**kwargs)
        except requests.exceptions.ConnectTimeout as e:
            # Retries exhausted: surface the same ValueError callers have
            # always received for connection failures.
            raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e

    def _embedding(self, texts: List[str]) -> List[List[float]]:
        """Call out to OCI Data Science Model Deployment Endpoint.

        Args:
            texts: A list of texts to embed.

        Returns:
            A list of list of floats representing the embeddings.

        Raises:
            ValueError: If the request fails or the response cannot be parsed.
        """
        _model_kwargs = self.model_kwargs or {}
        body = self._construct_request_body(texts, _model_kwargs)
        request_kwargs = self._construct_request_kwargs(body)
        response = self._embed_with_retry(**request_kwargs)
        return self._proceses_response(response)

    def _construct_request_kwargs(self, body: Any) -> dict:
        """Constructs the request kwargs as a dictionary.

        The body is sent as ``json`` when it is JSON serializable and as
        ``data`` otherwise. ``headers`` is taken from ``endpoint_kwargs``
        when present, else ``DEFAULT_HEADER`` is used.
        """
        from ads.model.common.utils import _is_json_serializable

        # Work on a copy: popping from ``self.endpoint_kwargs`` directly would
        # drop the user's custom headers after the first request.
        extra_kwargs = dict(self.endpoint_kwargs or {})
        headers = extra_kwargs.pop("headers", DEFAULT_HEADER)
        payload_key = "json" if _is_json_serializable(body) else "data"
        # ``dict(...)`` (rather than a literal) keeps raising TypeError when
        # ``endpoint_kwargs`` tries to redefine one of the reserved keys.
        return dict(
            headers=headers,
            auth=self.auth.get("signer"),
            **{payload_key: body},
            **extra_kwargs,
        )

    def _construct_request_body(self, texts: List[str], params: dict) -> Any:
        """Constructs the request body.

        NOTE(review): ``params`` (the model kwargs) is accepted but not
        included in the body -- confirm whether that is intentional.
        """
        return {"input": texts}

    def _proceses_response(self, response: requests.Response) -> List[List[float]]:
        """Extracts results from requests.Response.

        Two payload layouts are understood:

        * ``data[0]["embedding"]`` already holds one vector per input text;
          it is returned as is.
        * every item of ``data`` holds a single flat vector under
          ``"embedding"``; the vectors are collected in order.

        Raises:
            ValueError: If the response is not JSON or lacks the expected keys.
        """
        try:
            data = response.json()["data"]
            first = data[0]["embedding"]
            is_flat_vector = (
                isinstance(first, list)
                and bool(first)
                and not isinstance(first[0], (list, tuple))
            )
            if is_flat_vector:
                # Returning ``first`` here would hand back a single vector
                # where a list of vectors is promised.
                return [item["embedding"] for item in data]
            return first
        except Exception as e:
            raise ValueError(
                f"Error raised by inference API: {e}.\nResponse: {response.text}"
            ) from e

    def embed_documents(
        self,
        texts: List[str],
        chunk_size: Optional[int] = None,
    ) -> List[List[float]]:
        """Compute doc embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            texts: The list of texts to embed.
            chunk_size: The chunk size defines how many input texts will
                be grouped together as request. If None, non-positive, or
                larger than ``len(texts)``, all texts are sent in a single
                request.

        Returns:
            List of embeddings, one for each text.
        """
        if not texts:
            # Nothing to embed; also avoids a zero step in ``range`` below.
            return []

        total = len(texts)
        if not chunk_size or chunk_size <= 0 or chunk_size > total:
            batch_size = total
        else:
            batch_size = chunk_size

        results: List[List[float]] = []
        for start in range(0, total, batch_size):
            results.extend(self._embedding(texts[start : start + batch_size]))
        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using OCI Data Science Model Deployment Endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self._embedding([text])[0]
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
# Copyright (c) 2023 Oracle and/or its affiliates.
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
22
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
33

44
from langchain_oci.llms.oci_data_science_model_deployment_endpoint import (
55
BaseOCIModelDeployment,
66
OCIModelDeploymentLLM,
7+
OCIModelDeploymentTGI,
8+
OCIModelDeploymentVLLM,
79
)
810
from langchain_oci.llms.oci_generative_ai import OCIGenAI, OCIGenAIBase
911

@@ -12,4 +14,6 @@
1214
"OCIGenAI",
1315
"BaseOCIModelDeployment",
1416
"OCIModelDeploymentLLM",
17+
"OCIModelDeploymentTGI",
18+
"OCIModelDeploymentVLLM",
1519
]

0 commit comments

Comments
 (0)